diff --git a/.gitignore b/.gitignore index b5306b8b79c37166e5496cf17a3e39b86b9a6314..5afe375f46f07b3b557ae23f75740b337517d3bd 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,7 @@ Podfile.lock /tensorflow/contrib/lite/examples/ios/simple/data/*.tflite xcuserdata/** /api_init_files_list.txt +/estimator_api_init_files_list.txt # Android .gradle diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index db4b1581ae671b1e676e215c9a80dfaab832fa21..f598999f351c10f8bd01dfbd3ad8897f19d570e8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -107,7 +107,7 @@ diff /tmp/my_cc_file.cc #### Python coding style Changes to TensorFlow Python code should conform to -[Google Python Style Guide](https://google.github.io/styleguide/pyguide.html) +[Google Python Style Guide](https://github.com/google/styleguide/blob/gh-pages/pyguide.md) Use `pylint` to check your Python changes. To install `pylint` and retrieve TensorFlow's custom style definition: diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md index 2f3df7cda9cec29ed0c2266629022f0a22b37df9..52faed9297cfcaf8c93bb9c79686c9258a53c560 100644 --- a/ISSUE_TEMPLATE.md +++ b/ISSUE_TEMPLATE.md @@ -15,9 +15,10 @@ If you open a GitHub issue, here is our policy: ### System information - **Have I written custom code (as opposed to using a stock example script provided in TensorFlow)**: - **OS Platform and Distribution (e.g., Linux Ubuntu 16.04)**: +- **Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device**: - **TensorFlow installed from (source or binary)**: - **TensorFlow version (use command below)**: -- **Python version**: +- **Python version**: - **Bazel version (if compiling from source)**: - **GCC/Compiler version (if compiling from source)**: - **CUDA/cuDNN version**: diff --git a/README.md b/README.md index 05fcb23f7edd657f2ea495d848fadc226e56b524..1aaecae6de718a8ef32b831e235ab4179583cd22 100644 --- a/README.md +++ b/README.md @@ -82,12 +82,12 @@ The TensorFlow project strives to abide by generally accepted best practices in | Build Type | Status | Artifacts | | --- | --- | --- | | **Linux CPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) | -| **Linux GPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-cc.png) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | -| **Linux XLA** | TBA | TBA | +| **Linux GPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.png) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | +| **Linux XLA** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.png) | TBA | | **MacOS** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) | -| **Windows CPU** | [![Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [pypi](https://pypi.org/project/tf-nightly/) | -| **Windows GPU** | [![Status](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/badge/icon)](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | -| **Android** | [![Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) [build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/) | +| **Windows CPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.png) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Windows GPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.png) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | +| **Android** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.png) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) | ### Community Supported Builds @@ -97,7 +97,8 @@ The TensorFlow project strives to abide by generally accepted best practices in | **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA | | **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA | | **IBM ppc64le GPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA | -| **Linux CPU with Intel® MKL-DNN®** | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | TBA | +| **Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | +| **Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6| ![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)|[1.9.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp27-cp27mu-linux_x86_64.whl)
[1.9.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp35-cp35m-linux_x86_64.whl)
[1.9.0 py3.6](https://storage.cloud.google.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp36-cp36m-linux_x86_64.whl) | ## For more information diff --git a/RELEASE.md b/RELEASE.md index 7e6325af14d007a39d272817e2c4d476da9ce119..078aafd3746e5ce5c16af15de80d99c1a9e8c567 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,68 @@ +# Release 1.10.0 + +## Major Features And Improvements + +* The `tf.lite` runtime now supports `complex64`. +* Initial Bigtable integration for `tf.data`. +* Improved local run behavior in `tf.estimator.train_and_evaluate` which does not reload checkpoints for evaluation. +* `RunConfig` now sets device_filters to restrict how workers and PS can communicate. This can speed up training and ensure clean shutdowns in some situations. But if you have jobs that require communication between workers, you will have to set custom session_options in your `RunConfig`. +* Moved Distributions and Bijectors from `tf.contrib.distributions` to [Tensorflow Probability (TFP)](https://github.com/tensorflow/probability). `tf.contrib.distributions` is now deprecated and will be removed by the end of 2018. +* Adding new endpoints for existing tensorflow symbols. These endpoints are going to be the preferred endpoints going forward and may replace some of the existing endpoints in the future. See below for the complete list. New symbols have been added to the following modules: [`tf.debugging`](https://www.tensorflow.org/versions/master/api_docs/python/tf/debugging), [`tf.dtypes`](https://www.tensorflow.org/versions/master/api_docs/python/tf/dtypes), [`tf.image`](https://www.tensorflow.org/versions/master/api_docs/python/tf/image), [`tf.io`](https://www.tensorflow.org/versions/master/api_docs/python/tf/io), [`tf.linalg`](https://www.tensorflow.org/versions/master/api_docs/python/tf/linalg), [`tf.manip`](https://www.tensorflow.org/versions/master/api_docs/python/tf/manip), [`tf.math`](https://www.tensorflow.org/versions/master/api_docs/python/tf/math), [`tf.quantization`](https://www.tensorflow.org/versions/master/api_docs/python/tf/quantization), [`tf.strings`](https://www.tensorflow.org/versions/master/api_docs/python/tf/strings) + +## Breaking Changes + +* Prebuilt binaries are now (as of TensorFlow 1.10) built against NCCL 2.2 and no longer include NCCL in the binary install. TensorFlow usage with multiple GPUs and NCCL requires upgrade to [NCCL 2.2](https://developer.nvidia.com/nccl). See updated install guides: [Installing TensorFlow on Ubuntu](https://www.tensorflow.org/install/install_linux#tensorflow_gpu_support) and [Install TensorFlow from Sources](https://www.tensorflow.org/install/install_sources#optional_install_tensorflow_for_gpu_prerequisites). +* Starting from TensorFlow 1.11, Windows builds will use Bazel. Therefore, we will drop official support for cmake. + +## Bug Fixes and Other Changes + +* `tf.data`: + * `tf.contrib.data.group_by_reducer()` is now available via the public API. + * `tf.contrib.data.choose_from_datasets()` is now available via the public API. + * Adding `drop_remainder` argument to `tf.data.Dataset.batch()` and `tf.data.Dataset.padded_batch()`, deprecating tf.contrib.data.batch_and_drop_remainder()` and `tf.contrib.data.padded_batch_and_drop_remainder()`. +* `tf.estimator`: + * `Estimator`s now use custom savers included in `EstimatorSpec` scaffolds for saving SavedModels during export. + * `EstimatorSpec` will now add a default prediction output for export if no `export_output` is provided, eliminating the need to explicitly include a `PredictOutput` object in the `model_fn` for simple use-cases. + * Support sparse_combiner in canned Linear Estimators. + * Added batch normalization to `DNNClassifier`, `DNNRegressor`, and `DNNEstimator`. + * Adding ranking support for boosted trees. + * Adding center bias option for boosted trees. +* Add `synchronization` and `aggregation` args to get_variable(). These args will be used for distributed variables. +* Add `synchronization` and `aggregation` args to the layer `add_weight()` API. These args will be used for distributed variables. +* `tf.losses.*` do not add to the global collection when executing eagerly (to avoid leaking memory). +* Support different summary and checkpoint directories in `tf.train.MonitoredTrainingSession()`. +* Added IndRNN, IndyGRU, and IndyLSTM cells to `tf.contrib.rnn`. +* Add safe static factory functions for SparseTensor and convert all CHECKs to DCHECKs. Using the constructor directly is unsafe and deprecated. +* Make the Bigtable client connection pool configurable & increase the default # of connections for performance. +* Added derivative of `tf.random_gamma` with respect to the alpha parameter. +* Added derivative of `tf.igamma(a, x)` and `tf.igammac(a, x)` with respect to a. +* Modified Bessel functions of order zero and one. +* Add FillTriangular Bijector to create triangular matrices. +* Added support for Type III DCT, and `tf.spectral.idct(type=2|3)`. +* Correctly handle CuDNN RNN weight loaded when nest in `TimeDistributed`. +* Adding per-element weight support for `WALSComputePartialLhsAndRhsOp`. +* ZerosLike and OnesLike ops treated as constants by Graph Transform Tool. +* Gamma distribution and the derived distributions (Beta, Dirichlet, Student's t, inverse Gamma) now fully reparameterized. +* Java: Experimental wrapper classes to make graph generation easier. Thanks @karllessard and @kbsriram +* Build & link in secure gRPC components (switch from the insecure grpc dependency to secure grpc dependency). +* Adding new endpoints for existing tensorflow symbols. These endpoints are going to be the preferred endpoints going forward and may replace some of the existing endpoints in the future. List of new endpoints: + * New endpoints in `tf.image` namespace: `tf.image.extract_image_patches` + * New endpoints in `tf.debugging` namespace: `tf.debugging.check_numerics`, `tf.debugging.is_finite`, `tf.debugging.is_inf`, `tf.debugging.is_nan`. + * New endpoints in `tf.dtypes` namespace: `tf.dtypes.as_string`. + * New endpoints in `tf.io` namespace: `tf.io.decode_base64`, `tf.io.decode_compressed`, `tf.io.decode_json_example`, `tf.io.decode_raw`, `tf.io.encode_base64`, `tf.io.matching_files`, `tf.io.parse_tensor`, `tf.io.read_file, `tf.io.write_file`. + * New endpoints in tf.linalg namespace: `tf.linalg.cross`, `tf.linalg.tensor_diag` (corresponds to `tf.diag`), `tf.linalg.tensor_diag_part` (corresponds to `tf.diag_part`). + * New endpoints in tf.manip namespace: `tf.manip.batch_to_space_nd`, `tf.manip.gather_nd`, `tf.manip.reshape`, `tf.manip.reverse`, `tf.manip.scatter_nd`, `tf.manip.space_to_batch_nd`, `tf.manip.tile` + * New endpoints in tf.math namespace: `tf.math.acos`, `tf.math.acosh`, `tf.math.add`, `tf.math.asin`, `tf.math.asinh`, `tf.math.atan`, `tf.math.atan2`, `tf.math.atanh`, `tf.math.betainc`, `tf.math.ceil`, `tf.math.cos`, `tf.math.cosh`, `tf.math.digamma`, `tf.math.equal`, `tf.math.erfc`, `tf.math.exp`, `tf.math.expm1`, `tf.math.floor`, `tf.math.greater`, `tf.math.greater_equal`, `tf.math.igamma`, `tf.math.igammac`, `tf.math.invert_permutation`, `tf.math.less`, `tf.math.less_equal`, `tf.math.lgamma`, `tf.math.log`, `tf.math.log1p`, `tf.math.logical_and`, `tf.math.logical_not`, `tf.math.logical_or`, `tf.math.maximum`, `tf.math.minimum`, `tf.math.not_equal`, `tf.math.polygamma`, `tf.math.reciprocal`, `tf.math.rint`, `tf.math.rsqrt`, `tf.math.segment_max`, `tf.math.segment_mean`, `tf.math.segment_min`, `tf.math.segment_prod`, `tf.math.segment_sum`, `tf.math.sin`, `tf.math.sinh`, `tf.math.softplus`, `tf.math.softsign`, `tf.math.squared_difference`, `tf.math.tan`, `tf.math.unsorted_segment_max`, `tf.math.unsorted_segment_min`, `tf.math.unsorted_segment_prod`, `tf.math.unsorted_segment_sum`, `tf.math.zeta`. + * New endpoints in `tf.quantization` namespace: `tf.quantization.dequantize`, `tf.quantization.fake_quant_with_min_max_args`, `tf.quantization.fake_quant_with_min_max_args_gradient`, `tf.quantization.fake_quant_with_min_max_vars`, `tf.quantization.fake_quant_with_min_max_vars_gradient`, `tf.quantization.fake_quant_with_min_max_vars_per_channel`, `tf.quantization.fake_quant_with_min_max_vars_per_channel_gradient`. + * New endpoints in tf.strings namespace: `tf.strings.join` (corresponds to `tf.string_join`), `tf.strings.regex_replace`, `tf.strings.to_number` (corresponds to `tf.string_to_number`), `tf.strings.strip` (corresponds to `tf.string_strip`), `tf.strings.substr`, `tf.strings.to_hash_bucket` (corresponds to `tf.string_to_hash_bucket`), `tf.strings.to_hash_bucket_fast` (corresponds to `tf.string_to_hash_bucket_fast`), `tf.strings.to_hash_bucket_strong` (corresponds to `tf.string_to_hash_bucket_strong`). + + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, Andrei Nigmatulin, Andrew Ginns, BjøRn Moholt, Brett Koonce, Chengzhi Chen, Chinmay Das, Christian Ertler, Christoph Boeddeker, Clayne Robison, Courtial Florian, ctiijima, Dan Douthit, Dan J, Dan Ringwalt, EFanZh, Emanuele Ballarin, eqy, Evgeniy Zheltonozhskiy, Freedom" Koan-Sin Tan, FréDéRic Branchaud-Charron, G K, gracehoney, Guillaume Klein, Guozhong Zhuang, Hsien-Yang Li, hsm207, ImSheridan, Jayaram Bobba, Jiandong Ruan, Jie, Joel Shor, Jonas Rauber, Jongmin Baek, jsawruk, Karan Kaw, Karl Lessard, karl@kubx.ca, Kb Sriram, KinmanLam, leiiwang, Li, Yiqiang, Loo Rong Jie, Mahmoud Abuzaina, Mahmoud Aslan, ManHyuk, Martin Patz, Martin Zeitler, mktozk, Mohammad Ashraf Bhuiyan, mrTsjolder, Naman Bhalla, Nick Felt, Nicolas Lopez, Niranjan Hasabnis, Nishidha Panpaliya, Nitish, nrstott, Nutti, Parag Jain, PeterLee, Philipp Jund, Rach L, Rafal Wojdyla, Roland Zimmermann, Sergei Lebedev, SneakyFish5, Soila Kavulya, Sriram Veturi, Steven Schmatz, Taehoon Lee, Tang, Wenyi, Taras Sereda, Ted Chang, Tim Zaman, Tristan Rice, tucan, vchigrin, Vikram Tiwari, Vincent, WeberXie, William D. Irons, Yan Facai (颜发才), Yong Tang, Yu Yi, Yuxin Wu, Zé ViníCius + # Release 1.9.0 ## Major Features And Improvements @@ -6,7 +71,7 @@ * Update `tf.keras` to the Keras 2.1.6 API. * Added [`tf.keras.layers.CuDNNGRU`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNGRU) and [`tf.keras.layers.CuDNNLSTM`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNLSTM) layers. [Try it](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb?linkId=53292082). * Adding support of core [feature columns](https://www.tensorflow.org/get_started/feature_columns) and [losses](https://www.tensorflow.org/api_docs/python/tf/losses) to [gradient boosted trees estimators](https://github.com/tensorflow/models/tree/master/official/boosted_trees). -* The [python interface](https://tensorflow-dot-devsite.googleplex.com/versions/r1.9/api_docs/python/tf/contrib/lite) +* The [python interface](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/lite) for the [TFLite Optimizing Converter](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/README.md) has been expanded, and the command line interface (AKA: `toco`, `tflite_convert`) is once again included in the standard `pip` installation. @@ -21,7 +86,7 @@ * The [distributions.Bijector](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/distributions/bijectors/Bijector) API supports broadcasting for Bijectors with new API changes. -## Breaking Chances +## Breaking Changes * If you're opening empty variable scopes; replace `variable_scope('', ...)` by `variable_scope(tf.get_variable_scope(), ...)`. * Headers used for building custom ops have been moved from site-packages/external into site-packages/tensorflow/include/external. @@ -34,18 +99,22 @@ * Using `tf.layers` in a subclassed `tf.keras.Model` class. See [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) for more details * `tf.data`: - * The `DatasetBase::DebugString()` method is now `const`. - * Added the `tf.contrib.data.sample_from_datasets()` API for randomly sampling from multiple datasets. + * `Dataset.from_generator()` now accepts an `args` list, in order to create nested generators. + * `Dataset.list_files()` now produces determinstic results when `shuffle=False` or a `seed` is passed. + * `tf.contrib.data.sample_from_datasets()` and `tf.contrib.data.choose_from_datasets()` make it easier to sample or deterministically choose elements from multiple datasets. + * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted strings, and two infrequently used arguments removed. + * (C++) `DatasetBase::DebugString()` is now `const`. + * (C++) `DatasetBase::MakeIterator()` has been renamed to `DatasetBase::MakeIteratorInternal()`. + * (C++) `IteratorBase::Initialize()` method was added to support raising errors during iterator construction. * Eager Execution: + * Added the ability to pause recording operations for gradient computation via `tf.GradientTape.stop_recording`. + * Updated documentation, introductory notebooks. * `tf.keras`: * Move Keras code out of _impl folder and remove API files. * `tf.keras.Model.save_weights` now saves in TensorFlow format by default. * Enable dataset iterators to be passed to `tf.keras.Model` training/eval methods. -* Accelerated Linear Algebra (XLA): -* TensorFlow Debugger (tfdbg): fix an issue in which the TensorBoard Debugger Plugin could not handle total source file size exceeding gRPC message size limit (4 MB). +* TensorFlow Debugger (tfdbg) CLI: fix an issue in which the TensorBoard Debugger Plugin could not handle total source file size exceeding gRPC message size limit (4 MB). * `tf.contrib`: - * Add `tf.contrib.data.choose_from_datasets()`. - * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted strings. Two arguments were removed from `make_csv_dataset`. * `tf.contrib.framework.zero_initializer` supports ResourceVariable. * Adding "constrained_optimization" to tensorflow/contrib. * Other: @@ -55,7 +124,6 @@ * More consistent GcsFileSystem behavior for certain reads past EOF. * Update benchmark for tf.scan to match ranges across eager and graph modes. * Fixed bug in `tf.reduce_prod gradient` for complex dtypes. - * Add optional `args` argument to `Dataset.from_generator()`. * Allow the use of '.' in variables (e.g. "hparams.parse('a.b=1.0')"), which would previously raise an error. This will correspond to an attribute name with an embedded '.' symbol (e.g. 'a.b'), which can only be accessed indirectly (e.g. through getattr and setattr). To set this up the user will first need to explicitly add the variable to the hparam object (e.g. "hparams.add_hparam(name='a.b', value=0.0)"). * Benchmark for tf.scan in graph and eager modes. * Added complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D. @@ -65,7 +133,6 @@ * LinearOperator[1D,2D,3D]Circulant added to `tensorflow.linalg`. * Conv3D, Conv3DBackpropInput, Conv3DBackpropFilter now supports arbitrary. * Added `tf.train.Checkpoint` for reading/writing object-based checkpoints. - * `Dataset.list_files()` now produces determinstic results when `shuffle=False` or a `seed` is passed. * Added LinearOperatorKronecker, a dense-free implementation of the Kronecker Product. * Allow LinearOperator to broadcast. * SavedModelBuilder will now deduplicate asset names that point to files with the same basename and the same contents. Note that this may result in new asset files included in SavedModels in cases where assets with the same name but different contents were previously overwriting each other. diff --git a/WORKSPACE b/WORKSPACE index fd7570a80ae2ee0087f7d2fd771fcce5b9690028..17961829a605c2d1f2d2ba86a7c30c47618c139b 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -18,7 +18,7 @@ closure_repositories() # files, in case the parsing of those build files depends on the bazel # version we require here. load("//tensorflow:version_check.bzl", "check_bazel_version_at_least") -check_bazel_version_at_least("0.10.0") +check_bazel_version_at_least("0.15.0") load("//tensorflow:workspace.bzl", "tf_workspace") diff --git a/configure.py b/configure.py index ad585fa52e571d62d11864531476e46b2f15f297..f97bf8a66836a6647ba6aca625cb1526e11b39af 100644 --- a/configure.py +++ b/configure.py @@ -35,8 +35,8 @@ except ImportError: _DEFAULT_CUDA_VERSION = '9.0' _DEFAULT_CUDNN_VERSION = '7' -_DEFAULT_NCCL_VERSION = '1.3' -_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2' +_DEFAULT_NCCL_VERSION = '2.2' +_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0' _DEFAULT_CUDA_PATH = '/usr/local/cuda' _DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' @@ -680,7 +680,7 @@ def create_android_sdk_rule(environ_cp): if is_windows() or is_cygwin(): default_sdk_path = cygpath('%s/Android/Sdk' % environ_cp['APPDATA']) elif is_macos(): - default_sdk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME'] + default_sdk_path = '%s/library/Android/Sdk' % environ_cp['HOME'] else: default_sdk_path = '%s/Android/Sdk' % environ_cp['HOME'] @@ -835,6 +835,8 @@ def set_tf_cuda_version(environ_cp): '[Default is %s]: ') % (tf_cuda_version, default_cuda_path) cuda_toolkit_path = get_from_env_or_user_or_default( environ_cp, 'CUDA_TOOLKIT_PATH', ask_cuda_path, default_cuda_path) + if is_windows() or is_cygwin(): + cuda_toolkit_path = cygpath(cuda_toolkit_path) if is_windows(): cuda_rt_lib_path = 'lib/x64/cudart.lib' @@ -880,7 +882,7 @@ def set_tf_cudnn_version(environ_cp): default_cudnn_path = environ_cp.get('CUDA_TOOLKIT_PATH') ask_cudnn_path = (r'Please specify the location where cuDNN %s library is ' 'installed. Refer to README.md for more details. [Default' - ' is %s]:') % (tf_cudnn_version, default_cudnn_path) + ' is %s]: ') % (tf_cudnn_version, default_cudnn_path) cudnn_install_path = get_from_env_or_user_or_default( environ_cp, 'CUDNN_INSTALL_PATH', ask_cudnn_path, default_cudnn_path) @@ -1095,8 +1097,10 @@ def set_tf_nccl_install_path(environ_cp): raise ValueError('Currently NCCL is only supported on Linux platforms.') ask_nccl_version = ( - 'Please specify the NCCL version you want to use. ' - '[Leave empty to default to NCCL %s]: ') % _DEFAULT_NCCL_VERSION + 'Please specify the NCCL version you want to use. If NCCL %s is not ' + 'installed, then you can use version 1.3 that can be fetched ' + 'automatically but it may have worse performance with multiple GPUs. ' + '[Default is %s]: ') % (_DEFAULT_NCCL_VERSION, _DEFAULT_NCCL_VERSION) for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): tf_nccl_version = get_from_env_or_user_or_default( @@ -1197,7 +1201,7 @@ def set_tf_cuda_compute_capabilities(environ_cp): 'https://developer.nvidia.com/cuda-gpus.\nPlease' ' note that each additional compute ' 'capability significantly increases your ' - 'build time and binary size. [Default is: %s]' % + 'build time and binary size. [Default is: %s]: ' % default_cuda_compute_capabilities) tf_cuda_compute_capabilities = get_from_env_or_user_or_default( environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES', @@ -1232,28 +1236,13 @@ def set_tf_cuda_compute_capabilities(environ_cp): def set_other_cuda_vars(environ_cp): """Set other CUDA related variables.""" - if is_windows(): - # The following three variables are needed for MSVC toolchain configuration - # in Bazel - environ_cp['CUDA_PATH'] = environ_cp.get('CUDA_TOOLKIT_PATH') - environ_cp['CUDA_COMPUTE_CAPABILITIES'] = environ_cp.get( - 'TF_CUDA_COMPUTE_CAPABILITIES') - environ_cp['NO_WHOLE_ARCHIVE_OPTION'] = 1 - write_action_env_to_bazelrc('CUDA_PATH', environ_cp.get('CUDA_PATH')) - write_action_env_to_bazelrc('CUDA_COMPUTE_CAPABILITIE', - environ_cp.get('CUDA_COMPUTE_CAPABILITIE')) - write_action_env_to_bazelrc('NO_WHOLE_ARCHIVE_OPTION', - environ_cp.get('NO_WHOLE_ARCHIVE_OPTION')) - write_to_bazelrc('build --config=win-cuda') - write_to_bazelrc('test --config=win-cuda') + # If CUDA is enabled, always use GPU during build and test. + if environ_cp.get('TF_CUDA_CLANG') == '1': + write_to_bazelrc('build --config=cuda_clang') + write_to_bazelrc('test --config=cuda_clang') else: - # If CUDA is enabled, always use GPU during build and test. - if environ_cp.get('TF_CUDA_CLANG') == '1': - write_to_bazelrc('build --config=cuda_clang') - write_to_bazelrc('test --config=cuda_clang') - else: - write_to_bazelrc('build --config=cuda') - write_to_bazelrc('test --config=cuda') + write_to_bazelrc('build --config=cuda') + write_to_bazelrc('test --config=cuda') def set_host_cxx_compiler(environ_cp): @@ -1413,14 +1402,36 @@ def set_build_strip_flag(): write_to_bazelrc('build --strip=always') -def set_windows_build_flags(): - if is_windows(): - # The non-monolithic build is not supported yet - write_to_bazelrc('build --config monolithic') - # Suppress warning messages - write_to_bazelrc('build --copt=-w --host_copt=-w') - # Output more verbose information when something goes wrong - write_to_bazelrc('build --verbose_failures') +def set_windows_build_flags(environ_cp): + """Set Windows specific build options.""" + # The non-monolithic build is not supported yet + write_to_bazelrc('build --config monolithic') + # Suppress warning messages + write_to_bazelrc('build --copt=-w --host_copt=-w') + # Output more verbose information when something goes wrong + write_to_bazelrc('build --verbose_failures') + # The host and target platforms are the same in Windows build. So we don't + # have to distinct them. This avoids building the same targets twice. + write_to_bazelrc('build --distinct_host_configuration=false') + # Enable short object file path to avoid long path issue on Windows. + # TODO(pcloudy): Remove this flag when upgrading Bazel to 0.16.0 + # Short object file path will be enabled by default. + write_to_bazelrc('build --experimental_shortened_obj_file_path=true') + + if get_var( + environ_cp, 'TF_OVERRIDE_EIGEN_STRONG_INLINE', 'Eigen strong inline', + True, + ('Would you like to override eigen strong inline for some C++ ' + 'compilation to reduce the compilation time?'), + 'Eigen strong inline overridden.', + 'Not overriding eigen strong inline, ' + 'some compilations could take more than 20 mins.'): + # Due to a known MSVC compiler issue + # https://github.com/tensorflow/tensorflow/issues/10521 + # Overriding eigen strong inline speeds up the compiling of + # conv_grad_ops_3d.cc and conv_ops_3d.cc by 20 minutes, + # but this also hurts the performance. Let users decide what they want. + write_to_bazelrc('build --define=override_eigen_strong_inline=true') def config_info_line(name, help_text): @@ -1440,14 +1451,14 @@ def main(): # environment variables. environ_cp = dict(os.environ) - check_bazel_version('0.10.0') + check_bazel_version('0.15.0') reset_tf_configure_bazelrc(args.workspace) cleanup_makefile() setup_python(environ_cp) if is_windows(): - environ_cp['TF_NEED_S3'] = '0' + environ_cp['TF_NEED_AWS'] = '0' environ_cp['TF_NEED_GCP'] = '0' environ_cp['TF_NEED_HDFS'] = '0' environ_cp['TF_NEED_JEMALLOC'] = '0' @@ -1460,19 +1471,31 @@ def main(): # TODO(ibiryukov): Investigate using clang as a cpu or cuda compiler on # Windows. environ_cp['TF_DOWNLOAD_CLANG'] = '0' + environ_cp['TF_ENABLE_XLA'] = '0' + environ_cp['TF_NEED_GDR'] = '0' + environ_cp['TF_NEED_VERBS'] = '0' + environ_cp['TF_NEED_MPI'] = '0' + environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0' if is_macos(): environ_cp['TF_NEED_JEMALLOC'] = '0' environ_cp['TF_NEED_TENSORRT'] = '0' + # The numpy package on ppc64le uses OpenBLAS which has multi-threading + # issues that lead to incorrect answers. Set OMP_NUM_THREADS=1 at + # runtime to allow the Tensorflow testcases which compare numpy + # results to Tensorflow results to succeed. + if is_ppc64le(): + write_action_env_to_bazelrc("OMP_NUM_THREADS", 1) + set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc', 'with_jemalloc', True) set_build_var(environ_cp, 'TF_NEED_GCP', 'Google Cloud Platform', 'with_gcp_support', True, 'gcp') set_build_var(environ_cp, 'TF_NEED_HDFS', 'Hadoop File System', 'with_hdfs_support', True, 'hdfs') - set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System', - 'with_s3_support', True, 's3') + set_build_var(environ_cp, 'TF_NEED_AWS', 'Amazon AWS Platform', + 'with_aws_support', True, 'aws') set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform', 'with_kafka_support', True, 'kafka') set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', @@ -1536,7 +1559,8 @@ def main(): set_grpc_build_flags() set_cc_opt_flags(environ_cp) set_build_strip_flag() - set_windows_build_flags() + if is_windows(): + set_windows_build_flags(environ_cp) if get_var( environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', @@ -1548,11 +1572,15 @@ def main(): create_android_ndk_rule(environ_cp) create_android_sdk_rule(environ_cp) - print('Preconfigured Bazel build configs. You can use any of the below by ' - 'adding "--config=<>" to your build command. See tools/bazel.rc for ' - 'more details.') - config_info_line('mkl', 'Build with MKL support.') - config_info_line('monolithic', 'Config for mostly static monolithic build.') + # On Windows, we don't have MKL support and the build is always monolithic. + # So no need to print the following message. + # TODO(pcloudy): remove the following if check when they make sense on Windows + if not is_windows(): + print('Preconfigured Bazel build configs. You can use any of the below by ' + 'adding "--config=<>" to your build command. See tools/bazel.rc for ' + 'more details.') + config_info_line('mkl', 'Build with MKL support.') + config_info_line('monolithic', 'Config for mostly static monolithic build.') if __name__ == '__main__': main() diff --git a/tensorflow/BUILD b/tensorflow/BUILD index f362900387e506e935d4ede9aa781a83948fe0da..f8cd6820244aa05724ce0980419eb7b77962ff91 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -20,10 +20,18 @@ load( "tf_additional_binary_deps", ) load( - "//tensorflow/tools/api/generator:api_gen.bzl", + "//tensorflow/python/tools/api/generator:api_gen.bzl", "gen_api_init_files", # @unused ) +# Config setting used when building for products +# which requires restricted licenses to be avoided. +config_setting( + name = "no_lgpl_deps", + values = {"define": "__TENSORFLOW_NO_LGPL_DEPS__=1"}, + visibility = ["//visibility:public"], +) + # Config setting for determining if we are building for Android. config_setting( name = "android", @@ -216,8 +224,8 @@ config_setting( ) config_setting( - name = "with_s3_support", - define_values = {"with_s3_support": "true"}, + name = "with_aws_support", + define_values = {"with_aws_support": "true"}, visibility = ["//visibility:public"], ) @@ -244,8 +252,8 @@ config_setting( ) config_setting( - name = "with_s3_support_windows_override", - define_values = {"with_s3_support": "true"}, + name = "with_aws_support_windows_override", + define_values = {"with_aws_support": "true"}, values = {"cpu": "x64_windows"}, visibility = ["//visibility:public"], ) @@ -279,8 +287,8 @@ config_setting( ) config_setting( - name = "with_s3_support_android_override", - define_values = {"with_s3_support": "true"}, + name = "with_aws_support_android_override", + define_values = {"with_aws_support": "true"}, values = {"crosstool_top": "//external:android/crosstool"}, visibility = ["//visibility:public"], ) @@ -300,8 +308,8 @@ config_setting( ) config_setting( - name = "with_s3_support_ios_override", - define_values = {"with_s3_support": "true"}, + name = "with_aws_support_ios_override", + define_values = {"with_aws_support": "true"}, values = {"crosstool_top": "//tools/osx/crosstool:crosstool"}, visibility = ["//visibility:public"], ) @@ -373,6 +381,14 @@ config_setting( }, ) +# Setting to use when loading kernels dynamically +config_setting( + name = "dynamic_loaded_kernels", + define_values = { + "dynamic_loaded_kernels": "true", + }, +) + config_setting( name = "using_cuda_nvcc", define_values = { @@ -400,14 +416,6 @@ config_setting( visibility = ["//visibility:public"], ) -# TODO(laigd): consider removing this option and make TensorRT enabled -# automatically when CUDA is enabled. -config_setting( - name = "with_tensorrt_support", - values = {"define": "with_tensorrt_support=true"}, - visibility = ["//visibility:public"], -) - package_group( name = "internal", packages = [ @@ -433,11 +441,6 @@ filegroup( ), ) -filegroup( - name = "docs_src", - data = glob(["docs_src/**/*.md"]), -) - cc_library( name = "grpc", deps = select({ @@ -581,6 +584,7 @@ exports_files( gen_api_init_files( name = "tensorflow_python_api_gen", srcs = ["api_template.__init__.py"], + api_version = 1, root_init_template = "api_template.__init__.py", ) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index a8ad8e4b9477eecd6a7cabdada19439ac48dc5be..19ccb6e71d2f3021c1ce5c8905d8a72059c1cfcb 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/eval_const_tensor.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" @@ -51,6 +52,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" @@ -327,6 +329,7 @@ TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) { } void TF_DeleteBuffer(TF_Buffer* buffer) { + if (buffer == nullptr) return; if (buffer->data_deallocator != nullptr) { (*buffer->data_deallocator)(const_cast(buffer->data), buffer->length); @@ -356,6 +359,7 @@ void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) { void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) { status->status = Status::OK(); + if (s == nullptr) return; delete s->session; delete s; } @@ -906,6 +910,7 @@ TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) { TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; } void TF_DeleteLibraryHandle(TF_Library* lib_handle) { + if (lib_handle == nullptr) return; tensorflow::port::Free(const_cast(lib_handle->op_list.data)); delete lib_handle; } @@ -963,6 +968,7 @@ TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr); TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(), nullptr); TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1); +TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0); #undef TF_DEVICELIST_METHOD @@ -1852,6 +1858,7 @@ TF_Graph::TF_Graph() TF_Graph* TF_NewGraph() { return new TF_Graph; } void TF_DeleteGraph(TF_Graph* g) { + if (g == nullptr) return; g->mu.lock(); g->delete_requested = true; const bool del = g->sessions.empty(); @@ -2068,7 +2075,8 @@ TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults( TF_Graph* graph, const TF_Buffer* graph_def, const TF_ImportGraphDefOptions* options, TF_Status* status) { GraphDef def; - if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, graph_def->length)) { + if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, + graph_def->length)) { status->status = InvalidArgument("Invalid GraphDef"); return nullptr; } @@ -2098,7 +2106,8 @@ void TF_GraphImportGraphDefWithReturnOutputs( return; } GraphDef def; - if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, graph_def->length)) { + if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, + graph_def->length)) { status->status = InvalidArgument("Invalid GraphDef"); return; } @@ -2381,6 +2390,12 @@ void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); } void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, TF_Output* dx, TF_Status* status, TF_Output* dy) { + TF_AddGradientsWithPrefix(g, nullptr, y, ny, x, nx, dx, status, dy); +} + +void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y, + int ny, TF_Output* x, int nx, TF_Output* dx, + TF_Status* status, TF_Output* dy) { #ifdef __ANDROID__ status->status = tensorflow::errors::Unimplemented( "Adding gradients is not supported in Android. File a bug at " @@ -2397,9 +2412,29 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, const int first_new_node_id = g->graph.num_node_ids(); + string prefix_cmp; + const char* child_scope_name; + if (prefix == nullptr) { + child_scope_name = "gradients"; + } else { + prefix_cmp = string(prefix) + "/"; + // The operation should fail if the provided name prefix has already been + // used in this graph + for (const auto& pair : g->name_map) { + const string& name = pair.first; + if (name.compare(prefix) == 0 || + tensorflow::str_util::StartsWith(name, prefix_cmp)) { + status->status = InvalidArgument( + "prefix [", prefix, + "] conflicts with existing node in the graph named [", name, "]"); + return; + } + } + child_scope_name = prefix; + } tensorflow::Scope scope = NewInternalScope(&g->graph, &status->status, &g->refiner) - .NewSubScope("gradients"); + .NewSubScope(child_scope_name); if (dx != nullptr) { std::vector dx_arg = OutputsFromTFOutputs(dx, ny); @@ -2414,6 +2449,18 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) { Node* n = g->graph.FindNodeId(i); if (n == nullptr) continue; + + // Adding the gradients to the graph can alter the prefix to prevent + // name collisions only if this prefix has not been provided explicitly + // by the user. If it was provided, assert that it remained intact. + if (prefix != nullptr && + !tensorflow::str_util::StartsWith(n->name(), prefix_cmp)) { + status->status = tensorflow::errors::Internal( + "BUG: The gradients prefix have been unexpectedly altered when " + "adding the nodes to the graph. This is a bug. Please file an " + "issue at https://github.com/tensorflow/tensorflow/issues."); + return; + } // We have a convoluted scheme here: Using the C++ graph construction API // to add potentially many nodes to the graph without running the checks // (such as uniqueness of the names of nodes) we run with other functions @@ -2525,6 +2572,7 @@ void TF_CloseSession(TF_Session* s, TF_Status* status) { void TF_DeleteSession(TF_Session* s, TF_Status* status) { status->status = Status::OK(); + if (s == nullptr) return; TF_Graph* const graph = s->graph; if (graph != nullptr) { graph->mu.lock(); @@ -2723,7 +2771,34 @@ TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name, TF_Buffer* ret = TF_NewBuffer(); status->status = MessageToBuffer(*api_def, ret); + if (!status->status.ok()) { + TF_DeleteBuffer(ret); + return nullptr; + } return ret; #endif // __ANDROID__ } + +TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) { + tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels(); + TF_Buffer* ret = TF_NewBuffer(); + status->status = MessageToBuffer(kernel_list, ret); + if (!status->status.ok()) { + TF_DeleteBuffer(ret); + return nullptr; + } + return ret; +} + +TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) { + tensorflow::KernelList kernel_list = + tensorflow::GetRegisteredKernelsForOp(name); + TF_Buffer* ret = TF_NewBuffer(); + status->status = MessageToBuffer(kernel_list, ret); + if (!status->status.ok()) { + TF_DeleteBuffer(ret); + return nullptr; + } + return ret; +} } // end extern "C" diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 1eb75ef11ff337dfcb2e016e09804fc04662fcda..850f6ecd637d768bca99720e0add07680829e17a 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -44,6 +44,7 @@ limitations under the License. // * size_t is used to represent byte sizes of objects that are // materialized in the address space of the calling process. // * int is used as an index into arrays. +// * Deletion functions are safe to call on nullptr. // // Questions left to address: // * Might at some point need a way for callers to provide their own Env. @@ -1130,6 +1131,7 @@ TF_CAPI_EXPORT extern void TF_AbortWhile(const TF_WhileParams* params); // Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s, // i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2... +// // `dx` are used as initial gradients (which represent the symbolic partial // derivatives of some loss function `L` w.r.t. `y`). // `dx` must be nullptr or have size `ny`. @@ -1138,6 +1140,12 @@ TF_CAPI_EXPORT extern void TF_AbortWhile(const TF_WhileParams* params); // The partial derivatives are returned in `dy`. `dy` should be allocated to // size `nx`. // +// Gradient nodes are automatically named under the "gradients/" prefix. To +// guarantee name uniqueness, subsequent calls to the same graph will +// append an incremental tag to the prefix: "gradients_1/", "gradients_2/", ... +// See TF_AddGradientsWithPrefix, which provides a means to specify a custom +// name prefix for operations added to a graph to compute the gradients. +// // WARNING: This function does not yet support all the gradients that python // supports. See // https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md @@ -1146,6 +1154,33 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, TF_Output* dx, TF_Status* status, TF_Output* dy); +// Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s, +// i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2... +// This is a variant of TF_AddGradients that allows to caller to pass a custom +// name prefix to the operations added to a graph to compute the gradients. +// +// `dx` are used as initial gradients (which represent the symbolic partial +// derivatives of some loss function `L` w.r.t. `y`). +// `dx` must be nullptr or have size `ny`. +// If `dx` is nullptr, the implementation will use dx of `OnesLike` for all +// shapes in `y`. +// The partial derivatives are returned in `dy`. `dy` should be allocated to +// size `nx`. +// `prefix` names the scope into which all gradients operations are being added. +// `prefix` must be unique within the provided graph otherwise this operation +// will fail. If `prefix` is nullptr, the default prefixing behaviour takes +// place, see TF_AddGradients for more details. +// +// WARNING: This function does not yet support all the gradients that python +// supports. See +// https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md +// for instructions on how to add C++ more gradients. +TF_CAPI_EXPORT void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, + TF_Output* y, int ny, + TF_Output* x, int nx, + TF_Output* dx, TF_Status* status, + TF_Output* dy); + // Create a TF_Function from a TF_Graph // // Params: @@ -1235,6 +1270,11 @@ TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction( int noutputs, const TF_Output* outputs, const char* const* output_names, const TF_FunctionOptions* opts, const char* description, TF_Status* status); +// Returns the name of the graph function. +// The return value points to memory that is only usable until the next +// mutation to *func. +TF_CAPI_EXPORT extern const char* TF_FunctionName(TF_Function* func); + // Write out a serialized representation of `func` (as a FunctionDef protocol // message) to `output_func_def` (allocated by TF_NewBuffer()). // `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer() @@ -1521,6 +1561,13 @@ TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list, TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes( const TF_DeviceList* list, int index, TF_Status* status); +// Retrieve the incarnation number of a given device. +// +// If index is out of bounds, an error code will be set in the status object, +// and 0 will be returned. +TF_CAPI_EXPORT extern uint64_t TF_DeviceListIncarnation( + const TF_DeviceList* list, int index, TF_Status* status); + // -------------------------------------------------------------------------- // Load plugins containing custom ops and kernels @@ -1603,6 +1650,18 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, size_t name_len, TF_Status* status); +// -------------------------------------------------------------------------- +// Kernel definition information. + +// Returns a serialized KernelList protocol buffer containing KernelDefs for all +// registered kernels. +TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status); + +// Returns a serialized KernelList protocol buffer containing KernelDefs for all +// kernels registered for the operation named `name`. +TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp( + const char* name, TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 95b04f9058afdfaadbc24f0238860279fcd3e800..69b3ffe2a1f620e346405607ecf742fb863aa644 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -57,6 +57,45 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { } } +TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, + unsigned char gpu_memory_allow_growth) { + tensorflow::ConfigProto config; + auto* optimizer_options = + config.mutable_graph_options()->mutable_optimizer_options(); + if (enable_xla_compilation) { + optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1); + + // These XLA flags are needed to trigger XLA properly from C (more generally + // non-Python) clients. If this API is called again with `enable` set to + // false, it is safe to keep these flag values as is. + tensorflow::legacy_flags::MarkForCompilationPassFlags* flags = + tensorflow::legacy_flags::GetMarkForCompilationPassFlags(); + flags->tf_xla_cpu_global_jit = true; + flags->tf_xla_min_cluster_size = 1; + } else { + optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF); + } + + auto* gpu_options = config.mutable_gpu_options(); + gpu_options->set_allow_growth(gpu_memory_allow_growth); + + TF_Buffer* ret = TF_NewBuffer(); + TF_CHECK_OK(MessageToBuffer(config, ret)); + return ret; +} + +TF_Buffer* TF_CreateRunOptions(unsigned char enable_full_trace) { + tensorflow::RunOptions options; + if (enable_full_trace) { + options.set_trace_level(tensorflow::RunOptions::FULL_TRACE); + } else { + options.set_trace_level(tensorflow::RunOptions::NO_TRACE); + } + TF_Buffer* ret = TF_NewBuffer(); + TF_CHECK_OK(MessageToBuffer(options, ret)); + return ret; +} + const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) { tensorflow::mutex_lock c(graph->mu); const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString(); diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 20bdace40f1272ded06e710034053a7610326e7f..6617c5a572e90e78369f73d714f39942f213040f 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -55,11 +55,27 @@ extern "C" { // set XLA flag values to prepare for XLA compilation. Otherwise set // global_jit_level to OFF. // -// This API is syntax sugar over TF_SetConfig(), and is used by clients that -// cannot read/write the tensorflow.ConfigProto proto. +// This and the next API are syntax sugar over TF_SetConfig(), and is used by +// clients that cannot read/write the tensorflow.ConfigProto proto. +// TODO: Migrate to TF_CreateConfig() below. TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable); +// Create a serialized tensorflow.ConfigProto proto, where: +// +// a) ConfigProto.optimizer_options.global_jit_level is set to to ON_1 if +// `enable_xla_compilation` is non-zero, and OFF otherwise. +// b) ConfigProto.gpu_options.allow_growth is set to `gpu_memory_allow_growth`. +TF_CAPI_EXPORT extern TF_Buffer* TF_CreateConfig( + unsigned char enable_xla_compilation, + unsigned char gpu_memory_allow_growth); + +// Create a serialized tensorflow.RunOptions proto, where RunOptions.trace_level +// is set to FULL_TRACE if `enable_full_trace` is non-zero, and NO_TRACE +// otherwise. +TF_CAPI_EXPORT extern TF_Buffer* TF_CreateRunOptions( + unsigned char enable_full_trace); + // Returns the graph content in a human-readable format, with length set in // `len`. The format is subject to change in the future. // The returned string is heap-allocated, and caller should call free() on it. diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 384e6c8cb97022264c5327da5ca5861057608fbe..a2c5a42c11361779de61b515e0f08dcc45e609b9 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -536,6 +536,10 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, return tf_function; } +const char* TF_FunctionName(TF_Function* func) { + return func->fdef.signature().name().c_str(); +} + void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func, const TF_Function* grad, TF_Status* status) { if (func == nullptr) { diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index 610274696f5940c063e68f2310cfd9cc1e0bd964..73fe73769bc1219ce865149d67d333c53371ccc5 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -193,6 +193,7 @@ class CApiFunctionTest : public ::testing::Test { ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); ASSERT_NE(func_, nullptr); + ASSERT_EQ(std::string(func_name_), std::string(TF_FunctionName(func_))); TF_GraphCopyFunction(host_graph_, func_, nullptr, s_); ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); } @@ -1516,7 +1517,8 @@ void DefineStatefulFunction(const char* name, TF_Function** func) { TF_Output inputs[] = {}; TF_Output outputs[] = {{random, 0}}; - *func = TF_GraphToFunction(func_graph.get(), name, /*append_hash=*/false, -1, + *func = TF_GraphToFunction(func_graph.get(), name, + /*append_hash_to_fn_name=*/false, -1, /*opers=*/nullptr, 0, inputs, 1, outputs, /*output_names=*/nullptr, /*opts=*/nullptr, "", s.get()); @@ -1617,5 +1619,66 @@ TEST_F(CApiFunctionTest, GetFunctionsFromGraph) { TF_DeleteFunction(func1); } +// This test only works when the TF build includes XLA compiler. One way to set +// this up is via bazel build option "--define with_xla_support=true". +// +// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to +// something like TENSORFLOW_CAPI_USE_XLA. +#ifdef TENSORFLOW_EAGER_USE_XLA +TEST_F(CApiFunctionTest, StatelessIf_XLA) { + TF_Function* func; + const std::string funcName = "BranchFunc"; + DefineFunction(funcName.c_str(), &func); + TF_GraphCopyFunction(host_graph_, func, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Operation* feed = Placeholder(host_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Operation* true_cond = ScalarConst(true, host_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_OperationDescription* desc = + TF_NewOperation(host_graph_, "StatelessIf", "IfNode"); + TF_AddInput(desc, {true_cond, 0}); + TF_Output inputs[] = {{feed, 0}}; + TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs)); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_SetAttrType(desc, "Tcond", TF_BOOL); + TF_DataType inputType = TF_INT32; + TF_SetAttrTypeList(desc, "Tin", &inputType, 1); + TF_SetAttrTypeList(desc, "Tout", &inputType, 1); + TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size()); + TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size()); + TF_SetDevice(desc, "/device:XLA_CPU:0"); + auto op = TF_FinishOperation(desc, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + ASSERT_NE(op, nullptr); + + // Create a session for this graph. + CSession csession(host_graph_, s_, /*use_XLA*/ true); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Run the graph. + csession.SetInputs({{feed, Int32Tensor(17)}}); + csession.SetOutputs({op}); + csession.Run(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Tensor* out = csession.output_tensor(0); + ASSERT_TRUE(out != nullptr); + EXPECT_EQ(TF_INT32, TF_TensorType(out)); + EXPECT_EQ(0, TF_NumDims(out)); // scalar + ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out)); + int32* output_contents = static_cast(TF_TensorData(out)); + EXPECT_EQ(-17, *output_contents); + + // Clean up + csession.CloseAndDelete(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_DeleteFunction(func); +} +#endif // TENSORFLOW_EAGER_USE_XLA + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index bc04b53fbb7fa9ba46228ae5a4ec8ee96df5f3dc..aa2a537f03be31ae45ff3d6f7815b449d661cf9c 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -29,9 +29,11 @@ limitations under the License. #include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/graph.pb_text.h" +#include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/node_def.pb_text.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -1424,6 +1426,29 @@ TEST(CAPI, SavedModelNullArgsAreValid) { TF_DeleteStatus(s); } +TEST(CAPI, DeletingNullPointerIsSafe) { + TF_Status* status = TF_NewStatus(); + + TF_DeleteStatus(nullptr); + TF_DeleteBuffer(nullptr); + TF_DeleteTensor(nullptr); + TF_DeleteSessionOptions(nullptr); + TF_DeleteGraph(nullptr); + TF_DeleteImportGraphDefOptions(nullptr); + TF_DeleteImportGraphDefResults(nullptr); + TF_DeleteFunction(nullptr); + TF_DeleteSession(nullptr, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeletePRunHandle(nullptr); + TF_DeleteDeprecatedSession(nullptr, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteDeviceList(nullptr); + TF_DeleteLibraryHandle(nullptr); + TF_DeleteApiDefMap(nullptr); + + TF_DeleteStatus(status); +} + REGISTER_OP("TestOpWithNoGradient") .Input("x: T") .Output("y: T") @@ -1458,8 +1483,8 @@ class CApiGradientsTest : public ::testing::Test { BuildSuccessGraph(inputs, outputs); BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs); - AddGradients(grad_inputs_provided, inputs, 2, outputs, 1, grad_outputs); - + AddGradients(grad_inputs_provided, nullptr, inputs, 2, outputs, 1, + grad_outputs); EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); // Compare that the graphs match. @@ -1480,7 +1505,8 @@ class CApiGradientsTest : public ::testing::Test { BuildErrorGraph(inputs, outputs); - AddGradients(grad_inputs_provided, inputs, 1, outputs, 1, grad_outputs); + AddGradients(grad_inputs_provided, nullptr, inputs, 1, outputs, 1, + grad_outputs); string expected_msg = "No gradient defined for op: TestOpWithNoGradient. Please see " @@ -1524,19 +1550,20 @@ class CApiGradientsTest : public ::testing::Test { EXPECT_EQ(*a_data, *b_data); } - void AddGradients(bool grad_inputs_provided, TF_Output* inputs, int ninputs, - TF_Output* outputs, int noutputs, TF_Output* grad_outputs) { + void AddGradients(bool grad_inputs_provided, const char* prefix, + TF_Output* inputs, int ninputs, TF_Output* outputs, + int noutputs, TF_Output* grad_outputs) { if (grad_inputs_provided) { TF_Output grad_inputs[1]; const float grad_inputs_val[] = {1.0, 1.0, 1.0, 1.0}; TF_Operation* grad_inputs_op = FloatConst2x2(graph_, s_, grad_inputs_val, "GradInputs"); grad_inputs[0] = TF_Output{grad_inputs_op, 0}; - TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, grad_inputs, - s_, grad_outputs); + TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs, + ninputs, grad_inputs, s_, grad_outputs); } else { - TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, nullptr, s_, - grad_outputs); + TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs, + ninputs, nullptr, s_, grad_outputs); } } @@ -1681,6 +1708,20 @@ class CApiGradientsTest : public ::testing::Test { return op; } + void BuildGraphAndAddGradientsWithPrefixes(const char* prefix1, + const char* prefix2 = nullptr) { + TF_Output inputs[2]; + TF_Output outputs[1]; + TF_Output grad_outputs[2]; + + BuildSuccessGraph(inputs, outputs); + + AddGradients(false, prefix1, inputs, 2, outputs, 1, grad_outputs); + if (prefix2 != nullptr) { + AddGradients(false, prefix2, inputs, 2, outputs, 1, grad_outputs); + } + } + TF_Status* s_; TF_Graph* graph_; TF_Graph* expected_graph_; @@ -1700,6 +1741,56 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) { TestGradientsError(false); } +TEST_F(CApiGradientsTest, GradientsPrefix_PrefixIsOk) { + BuildGraphAndAddGradientsWithPrefixes("gradients"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithDistinctPrefixes) { + BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients_1"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInSameScope) { + BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope/gradients_1"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInDifferentScopes) { + BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope_1/gradients"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsSubScopeOf1st) { + BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/sub"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_PrefixMatchesExistingNodeName) { + BuildGraphAndAddGradientsWithPrefixes("Const_0"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithIdenticalPrefixes) { + BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsMatchingNodeOf1st) { + BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/MatMul"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_1stGradientsMatchingNodeOf2nd) { + BuildGraphAndAddGradientsWithPrefixes("gradients/MatMul", "gradients"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsParentScopeOf1st) { + BuildGraphAndAddGradientsWithPrefixes("gradients/sub", "gradients"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + void ScalarFloatFromTensor(const TF_Tensor* t, float* f) { ASSERT_TRUE(t != nullptr); ASSERT_EQ(TF_FLOAT, TF_TensorType(t)); @@ -2312,6 +2403,57 @@ TEST(TestApiDef, TestCreateApiDefWithOverwrites) { TF_DeleteLibraryHandle(lib); } +class DummyKernel : public tensorflow::OpKernel { + public: + explicit DummyKernel(tensorflow::OpKernelConstruction* context) + : OpKernel(context) {} + void Compute(tensorflow::OpKernelContext* context) override {} +}; + +// Test we can query kernels +REGISTER_OP("TestOpWithSingleKernel") + .Input("a: float") + .Input("b: float") + .Output("o: float"); +REGISTER_KERNEL_BUILDER( + Name("TestOpWithSingleKernel").Device(tensorflow::DEVICE_CPU), DummyKernel); + +TEST(TestKernel, TestGetAllRegisteredKernels) { + TF_Status* status = TF_NewStatus(); + TF_Buffer* kernel_list_buf = TF_GetAllRegisteredKernels(status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + KernelList kernel_list; + kernel_list.ParseFromArray(kernel_list_buf->data, kernel_list_buf->length); + ASSERT_GT(kernel_list.kernel_size(), 0); + TF_DeleteBuffer(kernel_list_buf); + TF_DeleteStatus(status); +} + +TEST(TestKernel, TestGetRegisteredKernelsForOp) { + TF_Status* status = TF_NewStatus(); + TF_Buffer* kernel_list_buf = + TF_GetRegisteredKernelsForOp("TestOpWithSingleKernel", status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + KernelList kernel_list; + kernel_list.ParseFromArray(kernel_list_buf->data, kernel_list_buf->length); + ASSERT_EQ(kernel_list.kernel_size(), 1); + EXPECT_EQ(kernel_list.kernel(0).op(), "TestOpWithSingleKernel"); + EXPECT_EQ(kernel_list.kernel(0).device_type(), "CPU"); + TF_DeleteBuffer(kernel_list_buf); + TF_DeleteStatus(status); +} + +TEST(TestKernel, TestGetRegisteredKernelsForOpNoKernels) { + TF_Status* status = TF_NewStatus(); + TF_Buffer* kernel_list_buf = TF_GetRegisteredKernelsForOp("Unknown", status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + KernelList kernel_list; + kernel_list.ParseFromArray(kernel_list_buf->data, kernel_list_buf->length); + ASSERT_EQ(kernel_list.kernel_size(), 0); + TF_DeleteBuffer(kernel_list_buf); + TF_DeleteStatus(status); +} + #undef EXPECT_TF_META } // namespace diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index 24eb6c069b21349fce288db3e79fbf14e824ad11..f15d9ee20adb31a0b76e2cd0d1e67f17a9deff05 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -26,6 +26,10 @@ limitations under the License. using tensorflow::GraphDef; using tensorflow::NodeDef; +static void BoolDeallocator(void* data, size_t, void* arg) { + delete[] static_cast(data); +} + static void Int32Deallocator(void* data, size_t, void* arg) { delete[] static_cast(data); } @@ -38,6 +42,14 @@ static void FloatDeallocator(void* data, size_t, void* arg) { delete[] static_cast(data); } +TF_Tensor* BoolTensor(bool v) { + const int num_bytes = sizeof(bool); + bool* values = new bool[1]; + values[0] = v; + return TF_NewTensor(TF_BOOL, nullptr, 0, values, num_bytes, &BoolDeallocator, + nullptr); +} + TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) { int64_t num_values = 1; for (int i = 0; i < num_dims; ++i) { @@ -131,6 +143,12 @@ TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, return op; } +TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s, + const char* name) { + unique_tensor_ptr tensor(BoolTensor(v), TF_DeleteTensor); + return Const(tensor.get(), graph, s, name); +} + TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, const char* name) { unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor); diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index 38313d647ca93d4779bb1325f8ed7bde4b743879..7eeb1ee5e17ad7e5644f8bc8a18ca967b108475d 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -31,6 +31,8 @@ using ::tensorflow::string; typedef std::unique_ptr unique_tensor_ptr; +TF_Tensor* BoolTensor(int32_t v); + // Create a tensor with values of type TF_INT8 provided by `values`. TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values); @@ -55,6 +57,9 @@ TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name = "const"); +TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s, + const char* name = "scalar"); + TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, const char* name = "scalar"); diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 82ca2be2cff885967dd798a1cb84b164a9df399e..dfb1c9a37644c726e1eabab775593596d5b556b9 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -110,7 +110,7 @@ tensorflow::Status GetAllRemoteDevices( tensorflow::Status CreateRemoteContexts( const std::vector& remote_workers, int64 rendezvous_id, - const tensorflow::ServerDef& server_def, + int keep_alive_secs, const tensorflow::ServerDef& server_def, tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, tensorflow::gtl::FlatMap* remote_contexts) { for (int i = 0; i < remote_workers.size(); i++) { @@ -129,6 +129,7 @@ tensorflow::Status CreateRemoteContexts( request.mutable_server_def()->set_job_name(parsed_name.job); request.mutable_server_def()->set_task_index(parsed_name.task); request.set_async(async); + request.set_keep_alive_secs(keep_alive_secs); auto* eager_client = remote_eager_workers->GetClient(remote_worker); if (eager_client == nullptr) { return tensorflow::errors::Internal( @@ -150,8 +151,9 @@ tensorflow::Status CreateRemoteContexts( return tensorflow::Status::OK(); } -tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, - TFE_Context** ctx) { +tensorflow::Status UpdateTFE_ContextWithServerDef( + int keep_alive_secs, const tensorflow::ServerDef& server_def, + TFE_Context* ctx) { // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the // server object (which currently CHECK-fails) and we miss the error, instead, // we log the error, and then return to allow the user to see the error @@ -165,12 +167,12 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, } \ } while (0); - string worker_name = tensorflow::strings::StrCat( - "/job:", opts->server_def.job_name(), - "/replica:0/task:", opts->server_def.task_index()); + string worker_name = + tensorflow::strings::StrCat("/job:", server_def.job_name(), + "/replica:0/task:", server_def.task_index()); std::unique_ptr server; - LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(opts->server_def, &server)); + LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &server)); tensorflow::GrpcServer* grpc_server = dynamic_cast(server.get()); @@ -202,15 +204,15 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, // Initialize remote eager workers. tensorflow::gtl::FlatMap remote_contexts; LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( - remote_workers, rendezvous_id, opts->server_def, - remote_eager_workers.get(), opts->async, &remote_contexts)); + remote_workers, rendezvous_id, keep_alive_secs, server_def, + remote_eager_workers.get(), ctx->context.Async(), &remote_contexts)); tensorflow::RemoteRendezvous* r = grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id); auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id); TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession( - session_name, opts->server_def, true)); + session_name, server_def, true)); std::shared_ptr worker_session; TF_RETURN_IF_ERROR( @@ -221,10 +223,11 @@ tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, TF_RETURN_IF_ERROR(r->Initialize(worker_session.get())); auto* device_mgr = grpc_server->worker_env()->device_mgr; - *ctx = new TFE_Context(opts->session_options.options, opts->policy, - opts->async, device_mgr, r, std::move(server), - std::move(remote_eager_workers), - std::move(remote_device_mgr), remote_contexts); + + ctx->context.InitializeRemote(std::move(server), + std::move(remote_eager_workers), + std::move(remote_device_mgr), remote_contexts, + r, device_mgr, keep_alive_secs); return tensorflow::Status::OK(); #undef LOG_AND_RETURN_IF_ERROR @@ -249,15 +252,6 @@ void TFE_ContextOptionsSetDevicePlacementPolicy( options->policy = policy; } -TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef( - TFE_ContextOptions* options, const void* proto, size_t proto_len, - TF_Status* status) { - if (!options->server_def.ParseFromArray(proto, proto_len)) { - status->status = tensorflow::errors::InvalidArgument( - "Invalid tensorflow.ServerDef protocol buffer"); - } -} - TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, unsigned char async, TF_Status* status) { @@ -267,12 +261,6 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { - if (!opts->server_def.job_name().empty()) { - TFE_Context* ctx = nullptr; - status->status = NewRemoteAwareTFE_Context(opts, &ctx); - return ctx; - } - std::vector devices; status->status = tensorflow::DeviceFactory::AddDevices( opts->session_options.options, "/job:localhost/replica:0/task:0", @@ -288,7 +276,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { opts->async, std::move(device_mgr), r); } -void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { delete ctx; } +void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; } TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { TF_DeviceList* list = new TF_DeviceList; @@ -301,6 +289,22 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); } +// Set server_def on the context, possibly updating it. +TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, + int keep_alive_secs, + const void* proto, + size_t proto_len, + TF_Status* status) { + tensorflow::ServerDef server_def; + if (!server_def.ParseFromArray(proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument( + "Invalid tensorflow.ServerDef protocol buffer"); + return; + } + status->status = + UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx); +} + void TFE_ContextSetThreadLocalDevicePlacementPolicy( TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { ctx->context.SetThreadLocalDevicePlacementPolicy( @@ -336,7 +340,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { } void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { - DCHECK(h); + if (h == nullptr) return; if (h->handle) { h->handle->Unref(); } @@ -348,6 +352,11 @@ TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { } int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return -1; + } int result; status->status = h->handle->NumDims(&result); return result; @@ -355,12 +364,22 @@ int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return -1; + } tensorflow::int64 result; status->status = h->handle->Dim(dim_index, &result); return result; } const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return nullptr; + } tensorflow::Device* d = nullptr; status->status = h->handle->OpDevice(&d); return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" @@ -368,6 +387,11 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { } TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return nullptr; + } // TODO(agarwal): move this implementation inside TFE_TensorHandle. tensorflow::Device* d = nullptr; tensorflow::Device* op_device = nullptr; @@ -664,17 +688,17 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) { const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory( TFE_TensorHandle* h, TF_Status* status) { - tensorflow::Device* d = nullptr; - tensorflow::Device* op_device = nullptr; - const tensorflow::Tensor* t = nullptr; - status->status = h->handle->TensorAndDevice(&t, &d, &op_device); - if (!status->status.ok()) return nullptr; - if (d != nullptr) { + if (!h->handle->OnHostCPU()) { status->status = tensorflow::errors::FailedPrecondition( "TFE_TensorHandle is placed in device (not host) memory. Cannot return " "a tensorflow::Tensor"); return nullptr; } + tensorflow::Device* d = nullptr; + tensorflow::Device* op_device = nullptr; + const tensorflow::Tensor* t = nullptr; + status->status = h->handle->TensorAndDevice(&t, &d, &op_device); + if (!status->status.ok()) return nullptr; return t; } @@ -700,6 +724,10 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, } } // namespace +void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); } + +void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); } + namespace tensorflow { void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value, diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index fdbd5374b2afe815c3a81b453930eb8f1fa351d3..a0ebc6fa0a22ed61be91c2974352c2988fb4cd92 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -81,16 +81,6 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*, TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy); -// A tensorflow.ServerDef specifies remote workers (in addition to the current -// workers name). Operations created on this context can then be executed on -// any of these remote workers by setting an appropriate device. -// -// If the following is set, all servers identified by the -// ServerDef must be up when the context is created. -TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef( - TFE_ContextOptions* options, const void* proto, size_t proto_len, - TF_Status* status); - // Destroy an options object. TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*); @@ -102,8 +92,7 @@ typedef struct TFE_Context TFE_Context; TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext( const TFE_ContextOptions* opts, TF_Status* status); -TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, - TF_Status* status); +TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx); TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status); @@ -128,6 +117,18 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*, unsigned char async, TF_Status* status); +// A tensorflow.ServerDef specifies remote workers (in addition to the current +// workers name). Operations created on this context can then be executed on +// any of these remote workers by setting an appropriate device. +// +// If the following is set, all servers identified by the +// ServerDef must be up when the context is created. +TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, + int keep_alive_secs, + const void* proto, + size_t proto_len, + TF_Status* status); + // Causes the calling thread to block till all ops dispatched in async mode // have been executed. Note that "execution" here refers to kernel execution / // scheduling of copies, etc. Similar to sync execution, it doesn't guarantee @@ -380,6 +381,16 @@ TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status); +// Some TF ops need a step container to be set to limit the lifetime of some +// resources (mostly TensorArray and Stack, used in while loop gradients in +// graph mode). Calling this on a context tells it to start a step. +TF_CAPI_EXPORT extern void TFE_ContextStartStep(TFE_Context* ctx); + +// Ends a step. When there is no active step (that is, every started step has +// been ended) step containers will be cleared. Note: it is not safe to call +// TFE_ContextEndStep while ops which rely on the step container may be running. +TF_CAPI_EXPORT extern void TFE_ContextEndStep(TFE_Context* ctx); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 4c5077023d5bb3b83808bf3908e7110dd026e3ad..a5c0681e2e4eddae08954d9d0178ca96a3f8f29a 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -59,7 +59,6 @@ struct TFE_ContextOptions { // true if async execution is enabled. bool async = false; TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT}; - tensorflow::ServerDef server_def; }; struct TFE_Context { @@ -73,23 +72,6 @@ struct TFE_Context { default_policy), async, std::move(device_mgr), rendezvous) {} - explicit TFE_Context( - const tensorflow::SessionOptions& opts, - TFE_ContextDevicePlacementPolicy default_policy, bool async, - tensorflow::DeviceMgr* local_device_mgr, - tensorflow::Rendezvous* rendezvous, - std::unique_ptr server, - std::unique_ptr remote_eager_workers, - std::unique_ptr remote_device_mgr, - const tensorflow::gtl::FlatMap& - remote_contexts) - : context(opts, - static_cast( - default_policy), - async, local_device_mgr, rendezvous, std::move(server), - std::move(remote_eager_workers), std::move(remote_device_mgr), - remote_contexts) {} - tensorflow::EagerContext context; }; diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 3504a8b5e78480732d3454097c1b2197ac2b2e17..71d5f3613c89762633113b4e1dfb82b8199a1cd1 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -49,7 +49,7 @@ void BM_InitOp(int iters) { } tensorflow::testing::StopTiming(); TFE_DeleteTensorHandle(m); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -80,7 +80,7 @@ void BM_Execute(int iters, int async) { tensorflow::testing::StopTiming(); TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -95,7 +95,7 @@ TEST(CAPI, Context) { TF_DeviceList* devices = TFE_ContextListDevices(ctx, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); const int num_devices = TF_DeviceListCount(devices); @@ -108,14 +108,14 @@ TEST(CAPI, Context) { TF_DeleteStatus(status); } -tensorflow::ServerDef GetServerDef(int num_tasks) { +tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) { tensorflow::ServerDef server_def; server_def.set_protocol("grpc"); - server_def.set_job_name("localhost"); + server_def.set_job_name(job_name); server_def.set_task_index(0); tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); tensorflow::JobDef* job_def = cluster_def->add_job(); - job_def->set_name("localhost"); + job_def->set_name(job_name); for (int i = 0; i < num_tasks; i++) { int port = tensorflow::testing::PickUnusedPortOrDie(); job_def->mutable_tasks()->insert( @@ -124,6 +124,10 @@ tensorflow::ServerDef GetServerDef(int num_tasks) { return server_def; } +tensorflow::ServerDef GetServerDef(int num_tasks) { + return GetServerDef("localhost", num_tasks); +} + void TestRemoteExecute(bool async) { tensorflow::ServerDef server_def = GetServerDef(2); @@ -140,9 +144,6 @@ void TestRemoteExecute(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), - status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_EXPLICIT); @@ -150,6 +151,9 @@ void TestRemoteExecute(bool async) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); const char remote_device_name[] = @@ -195,8 +199,8 @@ void TestRemoteExecute(bool async) { TFE_DeleteOp(matmul); TFE_ContextAsyncWait(ctx, status); - TFE_DeleteContext(ctx, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); @@ -229,15 +233,15 @@ void TestRemoteExecuteSilentCopies(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), - status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); TFE_Context* ctx = TFE_NewContext(opts, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; @@ -281,7 +285,7 @@ void TestRemoteExecuteSilentCopies(bool async) { TFE_DeleteOp(matmul); TFE_ContextAsyncWait(ctx, status); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); @@ -296,6 +300,147 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) { TestRemoteExecuteSilentCopies(true); } +void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle, + const std::vector& expected_values) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + std::unique_ptr actual_values(new float[expected_values.size()]); + EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t)); + memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + + for (int i = 0; i < expected_values.size(); i++) { + EXPECT_EQ(expected_values[i], actual_values[i]) + << "Mismatch in expected values at (zero-based) index " << i; + } +} + +void CheckRemoteMatMulExecutesOK(TFE_Context* ctx, + const char* remote_device_name, + const char* local_device_name) { + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + + TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0); + TFE_OpSetDevice(matmul, remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + auto* retval_task0 = + TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22}); + + TFE_DeleteTensorHandle(retval_task0); + TFE_DeleteTensorHandle(h0_task0); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(matmul); + + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); +} + +void TestRemoteExecuteChangeServerDef(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + + std::unique_ptr worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + const char local_device_name[] = + "/job:localhost/replica:0/task:0/device:CPU:0"; + CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name); + + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. + worker_server.release(); + + // Update the server def with a new set of names (worker instead of + // localhost). + tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2); + serialized = updated_server_def.SerializeAsString(); + + updated_server_def.set_task_index(1); + tensorflow::Status s = tensorflow::GrpcServer::Create( + updated_server_def, tensorflow::Env::Default(), &worker_server); + ASSERT_TRUE(s.ok()) << s.error_message(); + ASSERT_TRUE(worker_server->Start().ok()); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Create a new tensor_handle. + TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(); + + // Check that copying it to the old remote device (named localhost) fails. + TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status); + EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status); + + // Copying and executing on the new remote device works. + const char new_remote_device_name[] = + "/job:worker/replica:0/task:1/device:CPU:0"; + const char new_local_device_name[] = + "/job:worker/replica:0/task:0/device:CPU:0"; + + auto* h0_task1_new = TFE_TensorHandleCopyToDevice( + h0_task0_new, ctx, new_remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_DeleteTensorHandle(h0_task0_new); + TFE_DeleteTensorHandle(h0_task1_new); + + CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name, + new_local_device_name); + + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_DeleteStatus(status); + + TFE_DeleteContext(ctx); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. + worker_server.release(); +} + +TEST(CAPI, RemoteExecuteChangeServerDef) { + TestRemoteExecuteChangeServerDef(false); +} +TEST(CAPI, RemoteExecuteChangeServerDefAsync) { + TestRemoteExecuteChangeServerDef(true); +} + TEST(CAPI, TensorHandle) { TFE_TensorHandle* h = TestMatrixTensorHandle(); EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); @@ -380,8 +525,7 @@ void TensorHandleCopyBetweenDevices(bool async) { TF_DeleteDeviceList(devices); TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); - TFE_DeleteContext(ctx, status.get()); - EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContext(ctx); } TEST(CAPI, TensorHandleCopyBetweenDevices) { @@ -418,7 +562,7 @@ void TensorHandleCopyBetweenDevicesError(bool async) { TFE_DeleteTensorHandle(hcopy); TFE_DeleteTensorHandle(hcpu); if (hdevice != nullptr) TFE_DeleteTensorHandle(hdevice); - TFE_DeleteContext(ctx, status.get()); + TFE_DeleteContext(ctx); } TEST(CAPI, TensorHandleCopyBetweenDevicesError) { @@ -451,7 +595,7 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) { TF_DeleteDeviceList(devices); TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); - TFE_DeleteContext(ctx, status.get()); + TFE_DeleteContext(ctx); return; } const string gpu_1_name(TF_DeviceListName(devices, 1, status.get())); @@ -484,8 +628,7 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) { TF_DeleteDeviceList(devices); TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); - TFE_DeleteContext(ctx, status.get()); - EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContext(ctx); } TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) { @@ -533,8 +676,7 @@ void TensorHandleSilentCopy(bool async) { TFE_DeleteTensorHandle(hcpu); TFE_ContextAsyncWait(ctx, status.get()); EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TFE_DeleteContext(ctx, status.get()); - EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContext(ctx); } TEST(CAPI, TensorHandleSilentCopy) { TensorHandleSilentCopy(false); } @@ -580,8 +722,7 @@ void TensorHandleSilentCopyLocal(bool async) { TFE_DeleteTensorHandle(hcpu); TFE_ContextAsyncWait(ctx, status.get()); EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TFE_DeleteContext(ctx, status.get()); - EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContext(ctx); } TEST(CAPI, TensorHandleSilentCopyLocal) { TensorHandleSilentCopyLocal(false); } TEST(CAPI, TensorHandleSilentCopyLocalAsync) { @@ -614,11 +755,47 @@ void SetAndGetOpDevices(bool async) { TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } +TEST(CAPI, TensorHandleNullptr) { + TFE_TensorHandle* h = nullptr; + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + TF_Tensor* t = TFE_TensorHandleResolve(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(t, nullptr); + ASSERT_EQ("The passed in handle is a nullptr", + string(TF_Message(status.get()))); + + TF_SetStatus(status.get(), TF_OK, ""); + + const char* device_name = TFE_TensorHandleDeviceName(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(device_name, nullptr); + ASSERT_EQ("The passed in handle is a nullptr", + string(TF_Message(status.get()))); + + TF_SetStatus(status.get(), TF_OK, ""); + + int num_dims = TFE_TensorHandleNumDims(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(num_dims, -1); + ASSERT_EQ("The passed in handle is a nullptr", + string(TF_Message(status.get()))); + + TF_SetStatus(status.get(), TF_OK, ""); + + int dim = TFE_TensorHandleDim(h, 0, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(dim, -1); + ASSERT_EQ("The passed in handle is a nullptr", + string(TF_Message(status.get()))); +} + void Execute_MatMul_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -640,7 +817,7 @@ void Execute_MatMul_CPU(bool async) { TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteTensorHandle(retvals[0]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float product[4] = {0}; EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); @@ -712,7 +889,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) { TFE_DeleteTensorHandle(m1); TFE_DeleteTensorHandle(m2); TFE_DeleteTensorHandle(retvals[0]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); } TEST(CAPI, Execute_MatMul_CPU_Runtime_Error) { @@ -743,7 +920,7 @@ void Execute_MatMul_CPU_Type_Error(bool async) { if (retvals[0] != nullptr) { TFE_DeleteTensorHandle(retvals[0]); } - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); } @@ -781,7 +958,7 @@ TEST(CAPI, Execute_Min_CPU) { TF_DeleteTensor(t); EXPECT_EQ(1, output[0]); EXPECT_EQ(3, output[1]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -823,7 +1000,7 @@ void Execute_MatMul_XLA_CPU(bool async) { EXPECT_EQ(10, product[1]); EXPECT_EQ(15, product[2]); EXPECT_EQ(22, product[3]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); } TEST(CAPI, Execute_MatMul_XLA_CPU) { Execute_MatMul_XLA_CPU(false); } @@ -862,7 +1039,7 @@ void Execute_Min_XLA_CPU(bool async) { TF_DeleteTensor(t); EXPECT_EQ(1, output[0]); EXPECT_EQ(3, output[1]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); TF_DeleteStatus(status); } TEST(CAPI, Execute_Min_XLA_CPU) { Execute_Min_XLA_CPU(false); } @@ -898,7 +1075,7 @@ void ExecuteWithTracing(bool async) { TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); TFE_DeleteTensorHandle(retvals[0]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float product[4] = {0}; EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); @@ -974,7 +1151,7 @@ TEST(CAPI, Function_ident_CPU) { TF_DeleteTensor(r); TFE_DeleteTensorHandle(result[0]); } - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteStatus(status); } @@ -1044,7 +1221,7 @@ TEST(CAPI, Function_ident_XLA_CPU) { TF_DeleteTensor(r); TFE_DeleteTensorHandle(result[0]); } - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteStatus(status); } @@ -1120,7 +1297,7 @@ void FunctionDefAndExecute(bool async) { EXPECT_EQ(10, product[1]); EXPECT_EQ(15, product[2]); EXPECT_EQ(22, product[3]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -1161,7 +1338,7 @@ void BM_ExecuteFunction(int iters, int async) { tensorflow::testing::StopTiming(); TFE_DeleteTensorHandle(m); TFE_DeleteTensorHandle(retval[0]); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -1249,7 +1426,7 @@ TEST(CAPI, Variables) { TFE_DeleteTensorHandle(var_handle); TFE_DeleteTensorHandle(value_handle); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } @@ -1288,7 +1465,7 @@ void BM_ReadVariable(int iters) { TFE_DeleteOp(op); TFE_DeleteTensorHandle(var_handle); - TFE_DeleteContext(ctx, status); + TFE_DeleteContext(ctx); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 734e712daa39c03f0177eb199b1acb1b19e5d845..1adb0458c35193117b5fa5cfe9ceffbaaf699af7 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -520,7 +520,12 @@ Status GradientTape::ComputeGradient( } } else { any_gradient_nonzero = true; - auto new_gradients = vspace.AggregateGradients(grad_it->second); + Gradient* new_gradients = nullptr; + if (grad_it->second.size() == 1) { + new_gradients = grad_it->second.at(0); + } else { + new_gradients = vspace.AggregateGradients(grad_it->second); + } if (sources_set.find(grad_it->first) == sources_set.end()) { gradients.erase(grad_it); } else { diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index e18fdf6c57bd3f432d8cb73536fb816df90b3963..8486b585c8587e18e8eea18a893fac0a40ff4a27 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -155,7 +155,7 @@ void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output, tensorflow::shape_inference::ShapeHandle shape; status->status = ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape); - if (status->status.ok()) return; + if (!status->status.ok()) return; shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype()); } ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index a98f0b00b2c70055f697ed4f15cb14708384b62f..588a45ea43f90c4d9b3d04fea305d2c562ae1d72 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -121,6 +121,7 @@ cc_library( deps = [ ":array_grad", ":data_flow_grad", + ":image_grad", ":math_grad", ":nn_grad", ], @@ -331,6 +332,36 @@ tf_cc_test( ], ) +cc_library( + name = "image_grad", + srcs = ["gradients/image_grad.cc"], + deps = [ + ":cc_ops", + ":cc_ops_internal", + ":grad_op_registry", + ":gradients", + ], + alwayslink = 1, +) + +tf_cc_test( + name = "gradients_image_grad_test", + srcs = ["gradients/image_grad_test.cc"], + deps = [ + ":cc_ops", + ":client_session", + ":grad_op_registry", + ":grad_testutil", + ":gradient_checker", + ":image_grad", + ":testutil", + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + cc_library( name = "math_grad", srcs = ["gradients/math_grad.cc"], diff --git a/tensorflow/cc/client/client_session.cc b/tensorflow/cc/client/client_session.cc index ba056a8f3a84910aebf5079573cb64c19f41469d..0e61089a5950ee894ad5489317757cff8a85e966 100644 --- a/tensorflow/cc/client/client_session.cc +++ b/tensorflow/cc/client/client_session.cc @@ -127,4 +127,22 @@ Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs, target_node_names, outputs, run_metadata); } +Status ClientSession::MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle) { + TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph()); + return impl()->session_->MakeCallable(callable_options, out_handle); +} + +Status ClientSession::RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata) { + return impl()->session_->RunCallable(handle, feed_tensors, fetch_tensors, + run_metadata); +} + +Status ClientSession::ReleaseCallable(CallableHandle handle) { + return impl()->session_->ReleaseCallable(handle); +} + } // end namespace tensorflow diff --git a/tensorflow/cc/client/client_session.h b/tensorflow/cc/client/client_session.h index 5fb4109f7d15d5997f745acd913e60a02855fd73..7dd653eec4ec729b652cb779d06e820bfb437b3c 100644 --- a/tensorflow/cc/client/client_session.h +++ b/tensorflow/cc/client/client_session.h @@ -87,7 +87,33 @@ class ClientSession { const std::vector& run_outputs, std::vector* outputs, RunMetadata* run_metadata) const; - // TODO(keveman): Add support for partial run. + /// \brief A handle to a subgraph, created with + /// `ClientSession::MakeCallable()`. + typedef int64 CallableHandle; + + /// \brief Creates a `handle` for invoking the subgraph defined by + /// `callable_options`. + /// NOTE: This API is still experimental and may change. + Status MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle); + + /// \brief Invokes the subgraph named by `handle` with the given options and + /// input tensors. + /// + /// The order of tensors in `feed_tensors` must match the order of names in + /// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will + /// match the order of names in `CallableOptions::fetch()` when this subgraph + /// was created. + /// NOTE: This API is still experimental and may change. + Status RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata); + + /// \brief Releases resources associated with the given `handle` in this + /// session. + /// NOTE: This API is still experimental and may change. + Status ReleaseCallable(CallableHandle handle); private: class Impl; diff --git a/tensorflow/cc/client/client_session_test.cc b/tensorflow/cc/client/client_session_test.cc index ea5cf5a1f12be316cc6e0d0a02cd3caf4d177400..559ffea7e817526e7f1396cd0e8187d01364f23b 100644 --- a/tensorflow/cc/client/client_session_test.cc +++ b/tensorflow/cc/client/client_session_test.cc @@ -95,5 +95,26 @@ TEST(ClientSessionTest, MultiThreaded) { test::ExpectTensorEqual(outputs[0], test::AsTensor({-1, 2}, {2})); } +TEST(ClientSessionTest, Callable) { + Scope root = Scope::NewRootScope(); + auto a = Placeholder(root, DT_INT32); + auto b = Placeholder(root, DT_INT32); + auto c = Add(root, a, b); + ClientSession session(root); + std::vector outputs; + + CallableOptions options; + options.add_feed(a.node()->name()); + options.add_feed(b.node()->name()); + options.add_fetch(c.node()->name()); + ClientSession::CallableHandle callable; + TF_CHECK_OK(session.MakeCallable(options, &callable)); + TF_EXPECT_OK(session.RunCallable( + callable, {test::AsTensor({1}, {}), test::AsTensor({41}, {})}, + &outputs, nullptr)); + test::ExpectTensorEqual(outputs[0], test::AsTensor({42}, {})); + TF_EXPECT_OK(session.ReleaseCallable(callable)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/framework/gradient_checker.cc b/tensorflow/cc/framework/gradient_checker.cc index de2645cb440bda1f35e764af9197ca97bb760c08..e9f9c59e3aa0e8a9dc5d5e658540e9da73adaca5 100644 --- a/tensorflow/cc/framework/gradient_checker.cc +++ b/tensorflow/cc/framework/gradient_checker.cc @@ -247,7 +247,7 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs, auto y_pos_flat = y_pos[y_idx].flat(); auto y_neg_flat = y_neg[y_idx].flat(); const int64 y_size = y_shapes[y_idx].num_elements(); - const Y_T scale = Y_T{2 * delta}; + const Y_T scale = 2 * delta; auto jacobian = (*jacobian_ts)[x_idx * y_num + y_idx].matrix(); for (int c = 0; c < y_size; ++c) { SetJacobian(&jacobian, r * x_stride + unit_dimension, @@ -351,7 +351,14 @@ Status ComputeGradientErrorInternal(const Scope& scope, const OutputList& xs, auto jac_n = jacobian_ns[i].matrix(); for (int r = 0; r < jacobian_ts[i].dim_size(0); ++r) { for (int c = 0; c < jacobian_ts[i].dim_size(1); ++c) { - *max_error = std::max(*max_error, std::fabs(jac_t(r, c) - jac_n(r, c))); + auto cur_error = std::fabs(jac_t(r, c) - jac_n(r, c)); + // Treat any NaN as max_error and immediately return. + // (Note that std::max may ignore NaN arguments.) + if (std::isnan(cur_error)) { + *max_error = cur_error; + return Status::OK(); + } + *max_error = std::max(*max_error, cur_error); } } } @@ -409,6 +416,7 @@ Status ComputeGradientError(const Scope& scope, const Output& x, const Output& y, const TensorShape& y_shape, JAC_T* max_error); INSTANTIATE_GRAD_ERR_TYPE(float, float, float); +INSTANTIATE_GRAD_ERR_TYPE(double, float, double); INSTANTIATE_GRAD_ERR_TYPE(double, double, double); INSTANTIATE_GRAD_ERR_TYPE(complex64, float, float); INSTANTIATE_GRAD_ERR_TYPE(float, complex64, float); diff --git a/tensorflow/cc/framework/gradient_checker_test.cc b/tensorflow/cc/framework/gradient_checker_test.cc index d4f0a7f5ab3716be41e22c02a21aca028f76fb88..8dd762c282eff287bddd49ea6f38b2b8060949b0 100644 --- a/tensorflow/cc/framework/gradient_checker_test.cc +++ b/tensorflow/cc/framework/gradient_checker_test.cc @@ -28,12 +28,14 @@ namespace { using ops::Complex; using ops::Const; +using ops::Div; using ops::MatMul; using ops::Placeholder; using ops::Real; using ops::Split; using ops::Square; using ops::Stack; +using ops::Sub; using ops::Unstack; TEST(GradientCheckerTest, BasicFloat) { @@ -104,6 +106,20 @@ TEST(GradientCheckerTest, Complex64ToFloat) { EXPECT_LT(max_error, 1e-4); } +// When calculating gradients that are undefined, test we get NaN +// as the computed error rather than 0. +TEST(GradientCheckerTest, BasicNan) { + Scope scope = Scope::NewRootScope(); + TensorShape shape({2, 4, 3}); + auto x = Placeholder(scope, DT_FLOAT, Placeholder::Shape(shape)); + // y = x/(x-x) should always return NaN + auto y = Div(scope, x, Sub(scope, x, x)); + float max_error; + TF_ASSERT_OK((ComputeGradientError( + scope, {x}, {shape}, {y}, {shape}, &max_error))); + EXPECT_TRUE(std::isnan(max_error)); +} + TEST(GradientCheckerTest, MatMulGrad) { Scope scope = Scope::NewRootScope(); diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index b353accddcb6db9a07c112de03ead2f02c4ee6a6..e9173227aadbf86eab666e6c17bacacb92888572 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -120,6 +120,24 @@ Status SplitGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Split", SplitGrad); +Status FillGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // y = fill(fill_shape, x) + // No gradient returned for the fill_shape argument. + grad_outputs->push_back(NoGradient()); + // The gradient for x (which must be a scalar) is just the sum of + // all the gradients from the shape it fills. + // We use ReduceSum to implement this, which needs an argument providing + // the indices of all the dimensions of the incoming gradient. + // grad(x) = reduce_sum(grad(y), [0..rank(grad(y))]) + auto all_dims = Range(scope, Const(scope, 0), Rank(scope, grad_inputs[0]), + Const(scope, 1)); + grad_outputs->push_back(ReduceSum(scope, grad_inputs[0], all_dims)); + return scope.status(); +} +REGISTER_GRADIENT_OP("Fill", FillGrad); + Status DiagGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc index d09275b6487b4212aa35a0476002f2bb587fa210..f41de3dc2098df55fbbb616557f264a4e70db6b6 100644 --- a/tensorflow/cc/gradients/array_grad_test.cc +++ b/tensorflow/cc/gradients/array_grad_test.cc @@ -108,6 +108,14 @@ TEST_F(ArrayGradTest, SplitGrad) { RunTest({x}, {x_shape}, y.output, {y_shape, y_shape}); } +TEST_F(ArrayGradTest, FillGrad) { + TensorShape x_shape({}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + TensorShape y_shape({2, 5, 3}); + auto y = Fill(scope_, {2, 5, 3}, x); + RunTest(x, x_shape, y, y_shape); +} + TEST_F(ArrayGradTest, DiagGrad) { TensorShape x_shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); diff --git a/tensorflow/cc/gradients/image_grad.cc b/tensorflow/cc/gradients/image_grad.cc new file mode 100644 index 0000000000000000000000000000000000000000..882709e1e2817431a32c453fe0f35f2b2e6c69b0 --- /dev/null +++ b/tensorflow/cc/gradients/image_grad.cc @@ -0,0 +1,74 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include "tensorflow/cc/framework/grad_op_registry.h" +#include "tensorflow/cc/framework/gradients.h" +#include "tensorflow/cc/ops/image_ops_internal.h" +#include "tensorflow/cc/ops/standard_ops.h" + +namespace tensorflow { +namespace ops { +namespace { + +Status ResizeNearestNeighborGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool align_corners; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "align_corners", &align_corners)); + // The internal gradient implementation needs the shape of the input image. + // x_shape = shape(x)[1:3] + // = slice(shape(x), {1}, {3 - 1}) + auto x_shape = Slice(scope, Shape(scope, op.input(0)), {1}, {2}); + grad_outputs->push_back(internal::ResizeNearestNeighborGrad( + scope, grad_inputs[0], x_shape, + internal::ResizeNearestNeighborGrad::AlignCorners(align_corners))); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("ResizeNearestNeighbor", ResizeNearestNeighborGradHelper); + +Status ResizeBilinearGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool align_corners; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "align_corners", &align_corners)); + grad_outputs->push_back(internal::ResizeBilinearGrad( + scope, grad_inputs[0], op.input(0), + internal::ResizeBilinearGrad::AlignCorners(align_corners))); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("ResizeBilinear", ResizeBilinearGradHelper); + +Status ResizeBicubicGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool align_corners; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "align_corners", &align_corners)); + grad_outputs->push_back(internal::ResizeBicubicGrad( + scope, grad_inputs[0], op.input(0), + internal::ResizeBicubicGrad::AlignCorners(align_corners))); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("ResizeBicubic", ResizeBicubicGradHelper); + +} // anonymous namespace +} // namespace ops +} // namespace tensorflow diff --git a/tensorflow/cc/gradients/image_grad_test.cc b/tensorflow/cc/gradients/image_grad_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2e55c7561b030c50bd67bd53fd0d55710085c5d2 --- /dev/null +++ b/tensorflow/cc/gradients/image_grad_test.cc @@ -0,0 +1,157 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/framework/grad_op_registry.h" +#include "tensorflow/cc/framework/gradient_checker.h" +#include "tensorflow/cc/framework/testutil.h" +#include "tensorflow/cc/gradients/grad_testutil.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +using ops::Const; +using ops::ResizeBicubic; +using ops::ResizeBilinear; +using ops::ResizeNearestNeighbor; + +class ImageGradTest : public ::testing::Test { + protected: + ImageGradTest() : scope_(Scope::NewRootScope()) {} + + enum OpType { RESIZE_NEAREST, RESIZE_BILINEAR, RESIZE_BICUBIC }; + + template + Tensor MakeData(const TensorShape& data_shape) { + DataType data_type = DataTypeToEnum::v(); + Tensor data(data_type, data_shape); + auto data_flat = data.flat(); + for (int i = 0; i < data_flat.size(); ++i) { + data_flat(i) = T(i); + } + return data; + } + + template + void MakeOp(const OpType op_type, const Tensor& x_data, const Input& y_shape, + const bool align_corners, Output* x, Output* y) { + *x = Const(scope_, x_data); + switch (op_type) { + case RESIZE_NEAREST: + *y = ResizeNearestNeighbor( + scope_, *x, y_shape, + ResizeNearestNeighbor::AlignCorners(align_corners)); + return; + case RESIZE_BILINEAR: + *y = ResizeBilinear(scope_, *x, y_shape, + ResizeBilinear::AlignCorners(align_corners)); + return; + case RESIZE_BICUBIC: + *y = ResizeBicubic(scope_, *x, y_shape, + ResizeBicubic::AlignCorners(align_corners)); + return; + } + assert(false); + } + + template + void TestResizedShapeForType(const OpType op_type, const bool align_corners) { + TensorShape x_shape({1, 2, 2, 1}); + Tensor x_data = MakeData(x_shape); + Output x, y; + MakeOp(op_type, x_data, {4, 6}, align_corners, &x, &y); + + ClientSession session(scope_); + std::vector outputs; + TF_ASSERT_OK(session.Run({y}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + EXPECT_EQ(outputs[0].shape(), TensorShape({1, 4, 6, 1})); + } + + void TestResizedShape(OpType op_type) { + for (const bool align_corners : {true, false}) { + TestResizedShapeForType(op_type, align_corners); + TestResizedShapeForType(op_type, align_corners); + TestResizedShapeForType(op_type, align_corners); + } + } + + template + void TestResizeToSmallerAndAlign(const OpType op_type, + const bool align_corners) { + TensorShape x_shape({1, 4, 6, 1}); + Tensor x_data = MakeData(x_shape); + Output x, y; + MakeOp(op_type, x_data, {2, 3}, align_corners, &x, &y); + JAC_T max_error; + TF_ASSERT_OK((ComputeGradientError( + scope_, x, x_data, y, {1, 2, 3, 1}, &max_error))); + EXPECT_LT(max_error, 1e-3); + } + + template + void TestResizeToLargerAndAlign(const OpType op_type, + const bool align_corners) { + TensorShape x_shape({1, 2, 3, 1}); + Tensor x_data = MakeData(x_shape); + Output x, y; + MakeOp(op_type, x_data, {4, 6}, align_corners, &x, &y); + JAC_T max_error; + TF_ASSERT_OK((ComputeGradientError( + scope_, x, x_data, y, {1, 4, 6, 1}, &max_error))); + EXPECT_LT(max_error, 1e-3); + } + + template + void TestResize(OpType op_type) { + for (const bool align_corners : {true, false}) { + TestResizeToSmallerAndAlign(op_type, align_corners); + TestResizeToLargerAndAlign(op_type, align_corners); + } + } + + Scope scope_; +}; + +TEST_F(ImageGradTest, TestNearestNeighbor) { + TestResizedShape(RESIZE_NEAREST); + TestResize(RESIZE_NEAREST); + TestResize(RESIZE_NEAREST); +} + +TEST_F(ImageGradTest, TestBilinear) { + TestResizedShape(RESIZE_BILINEAR); + TestResize(RESIZE_BILINEAR); + // Note that Y_T is always float for this op. We choose + // double for the jacobian to capture the higher precision + // between X_T and Y_T. + TestResize(RESIZE_BILINEAR); +} + +TEST_F(ImageGradTest, TestBicubic) { + TestResizedShape(RESIZE_BICUBIC); + TestResize(RESIZE_BICUBIC); + // Note that Y_T is always float for this op. We choose + // double for the jacobian to capture the higher precision + // between X_T and Y_T. + TestResize(RESIZE_BICUBIC); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index fd7b6fe6625f27bda92e2f56f60908658cdecd7e..1c9bdff5e1295135abe60c282d565c39071fd78a 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -475,11 +475,7 @@ TEST_F(CWiseUnaryGradTest, Tan_Complex) { auto x_fn = [this](const int i) { return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); }; - // TODO(kbsriram) - // Enable when tan kernel supports complex inputs - if (false) { - TestCWiseGrad(TAN, x_fn); - } + TestCWiseGrad(TAN, x_fn); } TEST_F(CWiseUnaryGradTest, Atan) { diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index c73482d5f4d13ade0dc0412941251d1651371b6e..588e96cb196189780037f66266484962ba0385e4 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -47,6 +47,72 @@ Status SoftmaxGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad); +bool IsZero(const Scope& scope, const Output& grad) { + string op_type_name = grad.op().node()->type_string(); + if (op_type_name == "ZerosLike" || op_type_name == "Zeros") { + return true; + } + // The Operation we were provided is not named something obvious so + // we need to actually look at its contents. + // The original python code did this by calling a utility function called + // tensor_util.constant_value. + // There is no C++ equivalent to tensor_util.constant_value so we do nothing + // for the moment. + return false; +} + +// Multiply after broadcasting vec to match dimensions of mat. +// Args: +// vec: A 1-D tensor of dimension [D0] +// mat: A 2-D tensor of dimesnion [D0, D1] +// +// Returns: +// A tensor of dimension [D0, D1], the result fo vec * mat. +Output BroadcastMul(const Scope& scope, const Output& vec, const Output& mat) { + auto reshaped = ExpandDims(scope, vec, -1); + return Multiply(scope, reshaped, mat); +} + +Status SoftmaxCrossEntropyWithLogitsGrad(const Scope& scope, + const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // Softmax gradient with cross entropy logits function. + // We multiply the backprop for cost with the gradients - op.output[1]. + // There is no gradient for labels. + + // The outputs of the network are at input index 0. + auto logits = op.input(0); + // The "truth" labels are at index 1. + auto softmax_grad = op.output(1); + + // The loss is the output at index 0, and backprop is the output at index 1. + auto grad_loss = grad_inputs[0]; + auto grad_grad = grad_inputs[1]; + + auto grad = BroadcastMul(scope, grad_loss, softmax_grad); + if (!IsZero(scope, grad_grad)) { + std::vector axis; + auto logits_softmax = Softmax(scope, logits); + + auto grad_grad_expand = ExpandDims(scope, grad_grad, 1); + auto logits_softmax_expand = ExpandDims(scope, logits_softmax, 2); + auto matmul_result = + BatchMatMul(scope, grad_grad_expand, logits_softmax_expand); + axis.push_back(1); + auto squeeze_result = Squeeze(scope, matmul_result, Squeeze::Axis(axis)); + auto subtraction_result = Subtract(scope, grad_grad, squeeze_result); + auto multiply_result = Multiply(scope, subtraction_result, logits_softmax); + grad = Add(scope, grad, multiply_result); + } + auto minus_log_softmax = Multiply(scope, LogSoftmax(scope, logits), -1.0f); + grad_outputs->push_back(grad); + grad_outputs->push_back(BroadcastMul(scope, grad_loss, minus_log_softmax)); + return scope.status(); +} +REGISTER_GRADIENT_OP("SoftmaxCrossEntropyWithLogits", + SoftmaxCrossEntropyWithLogitsGrad); + Status LogSoftmaxGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { @@ -195,9 +261,9 @@ Status MaxPool3DGradHelper(const Scope& scope, const Operation& op, TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); MaxPool3DGrad::Attrs grad_attrs; - auto dx = MaxPool3DGrad(scope, op.input(0), op.output(0), grad_inputs[0], - ksize, strides, padding, - grad_attrs.DataFormat(data_format)); + auto dx = + MaxPool3DGrad(scope, op.input(0), op.output(0), grad_inputs[0], ksize, + strides, padding, grad_attrs.DataFormat(data_format)); grad_outputs->push_back(dx); return scope.status(); } @@ -216,10 +282,9 @@ Status AvgPoolGradHelper(const Scope& scope, const Operation& op, TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); internal::AvgPoolGrad::Attrs grad_attrs; - auto dx = - internal::AvgPoolGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], - ksize, strides, padding, - grad_attrs.DataFormat(data_format)); + auto dx = internal::AvgPoolGrad(scope, Shape(scope, op.input(0)), + grad_inputs[0], ksize, strides, padding, + grad_attrs.DataFormat(data_format)); grad_outputs->push_back(dx); return scope.status(); } @@ -238,9 +303,9 @@ Status AvgPool3DGradHelper(const Scope& scope, const Operation& op, TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); AvgPool3DGrad::Attrs grad_attrs; - auto dx = AvgPool3DGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], - ksize, strides, padding, - grad_attrs.DataFormat(data_format)); + auto dx = + AvgPool3DGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], ksize, + strides, padding, grad_attrs.DataFormat(data_format)); grad_outputs->push_back(dx); return scope.status(); } diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index b4d457a9d14eb79232cda9412fa0050f6a9968cc..aa72cf7ba2a958f54d50b59f0edaefb27edf0e86 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -25,6 +25,8 @@ limitations under the License. namespace tensorflow { namespace { +using ops::AvgPool; +using ops::AvgPool3D; using ops::BiasAdd; using ops::Conv2D; using ops::Elu; @@ -33,11 +35,9 @@ using ops::FractionalMaxPool; using ops::L2Loss; using ops::LogSoftmax; using ops::LRN; -using ops::AvgPool; -using ops::AvgPool3D; using ops::MaxPool; -using ops::MaxPoolV2; using ops::MaxPool3D; +using ops::MaxPoolV2; using ops::Placeholder; using ops::Relu; using ops::Relu6; @@ -111,6 +111,20 @@ TEST_F(NNGradTest, SoftmaxGrad) { RunTest(x, shape, y, shape); } +TEST_F(NNGradTest, SoftmaxCrossEntropyWithLogitsGrad) { + TensorShape logits_shape({5, 3}); + TensorShape loss_shape({5}); + + auto logits = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(logits_shape)); + auto labels = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(logits_shape)); + auto y = + tensorflow::ops::SoftmaxCrossEntropyWithLogits(scope_, logits, labels); + // Note the reversal of the backprop and loss orders. Issue #18734 has been + // opened for this. + RunTest({logits, labels}, {logits_shape, logits_shape}, {y.backprop, y.loss}, + {logits_shape, loss_shape}); +} + TEST_F(NNGradTest, LogSoftmaxGrad) { TensorShape shape({5, 3}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); @@ -253,7 +267,7 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) { RunTest(x, x_shape, y, y_shape); } -TEST_F(NNGradTest, LRN){ +TEST_F(NNGradTest, LRN) { TensorShape x_shape({1, 1, 2, 1}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); auto y = LRN(scope_, x); diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 06a3be18e08f611d3ecf9804908d791d15fdab13..3d3895c8fa82c3c0e2974228e9cad767d0e00df4 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -33,6 +33,46 @@ cc_library( hdrs = ["tag_constants.h"], ) +cc_library( + name = "reader", + srcs = ["reader.cc"], + hdrs = ["reader.h"], + deps = [ + ":constants", + ] + if_not_mobile([ + # TODO(b/111634734): :lib and :protos_all contain dependencies that + # cannot be built on mobile platforms. Instead, include the appropriate + # tf_lib depending on the build platform. + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ]) + if_mobile([ + # Mobile-friendly SavedModel proto. See go/portable-proto for more info. + "//tensorflow/core:saved_model_portable_proto", + ]) + if_android([ + "//tensorflow/core:android_tensorflow_lib", + ]) + if_ios([ + "//tensorflow/core:ios_tensorflow_lib", + ]), +) + +tf_cc_test( + name = "reader_test", + srcs = ["reader_test.cc"], + data = [ + ":saved_model_half_plus_two", + ], + linkstatic = 1, + deps = [ + ":constants", + ":reader", + ":tag_constants", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + cc_library( name = "loader", hdrs = ["loader.h"], @@ -54,6 +94,7 @@ cc_library( hdrs = ["loader.h"], deps = [ ":constants", + ":reader", ] + if_not_mobile([ "//tensorflow/core:core_cpu", "//tensorflow/core:framework", diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index faa1e378d07ea94ad08ee084d18bf6a113f054af..3830416159158cca8bfb8422c2959b49fa42406d 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -18,8 +18,10 @@ limitations under the License. #include #include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/cc/saved_model/reader.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf_internal.h" @@ -43,56 +45,6 @@ auto* load_latency = monitoring::Counter<1>::New( constexpr char kLoadAttemptFail[] = "fail"; constexpr char kLoadAttemptSuccess[] = "success"; -Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) { - const string saved_model_pb_path = - io::JoinPath(export_dir, kSavedModelFilenamePb); - if (Env::Default()->FileExists(saved_model_pb_path).ok()) { - return ReadBinaryProto(Env::Default(), saved_model_pb_path, - saved_model_proto); - } - const string saved_model_pbtxt_path = - io::JoinPath(export_dir, kSavedModelFilenamePbTxt); - if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) { - return ReadTextProto(Env::Default(), saved_model_pbtxt_path, - saved_model_proto); - } - return Status(error::Code::NOT_FOUND, - "Could not find SavedModel .pb or .pbtxt at supplied export " - "directory path: " + - export_dir); -} - -string GetTagsAsString(const std::unordered_set& tags) { - string tags_as_string = "{ "; - for (const string& tag : tags) { - tags_as_string = strings::StrCat(tags_as_string, tag, " "); - } - tags_as_string = strings::StrCat(tags_as_string, "}"); - return tags_as_string; -} - -Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto, - const std::unordered_set& tags, - MetaGraphDef* meta_graph_def_to_load) { - for (const MetaGraphDef& meta_graph_def : saved_model_proto.meta_graphs()) { - // Get tags from the meta_graph_def. - std::unordered_set graph_tags; - for (const string& tag : meta_graph_def.meta_info_def().tags()) { - graph_tags.insert(tag); - } - // Match with the set of tags provided. - if (graph_tags == tags) { - *meta_graph_def_to_load = meta_graph_def; - return Status::OK(); - } - } - return Status(error::Code::NOT_FOUND, - "Could not find meta graph def matching supplied tags: " + - GetTagsAsString(tags) + - ". To inspect available tag-sets in the SavedModel, please " - "use the SavedModel CLI: `saved_model_cli`"); -} - Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, const SessionOptions& session_options, std::unique_ptr* session) { @@ -122,6 +74,54 @@ void AddAssetsTensorsToInputs(const StringPiece export_dir, } } +// Like Session::Run(), but uses the Make/Run/ReleaseCallable() API to avoid +// leaving behind non-GC'ed state. +// +// Detailed motivation behind this approach, from ashankar@: +// +// Each call to Session::Run() that identifies a new subgraph (based on feeds +// and fetches) creates some datastructures that live as long as the session +// (the partitioned graph, associated executors etc.). +// +// A pathological case of this would be if say the initialization op +// (main_op/legacy_init_op) involves the use of a large constant. Then we +// allocate memory for that large constant that will just stick around till the +// session dies. With this Callable mechanism, that memory will be released +// right after ReleaseCallable returns. +// +// However, the resource manager state remains. +Status RunOnce(const RunOptions& run_options, + const std::vector>& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs, RunMetadata* run_metadata, + Session* session) { + CallableOptions callable_options; + std::vector feed_tensors; + *callable_options.mutable_run_options() = run_options; + for (const auto& input : inputs) { + const string& name = input.first; + const Tensor& tensor = input.second; + callable_options.add_feed(name); + feed_tensors.push_back(tensor); + } + for (const string& output_tensor_name : output_tensor_names) { + callable_options.add_fetch(output_tensor_name); + } + for (const string& target_node_name : target_node_names) { + callable_options.add_target(target_node_name); + } + + Session::CallableHandle callable_handle; + TF_RETURN_IF_ERROR(session->MakeCallable(callable_options, &callable_handle)); + const Status run_status = session->RunCallable(callable_handle, feed_tensors, + outputs, run_metadata); + // Be sure to call ReleaseCallable() regardless of the outcome of + // RunCallable(). + session->ReleaseCallable(callable_handle).IgnoreError(); + return run_status; +} + bool HasMainOp(const MetaGraphDef& meta_graph_def) { const auto& collection_def_map = meta_graph_def.collection_def(); if (collection_def_map.find(kSavedModelMainOpKey) != @@ -134,10 +134,11 @@ bool HasMainOp(const MetaGraphDef& meta_graph_def) { Status RunMainOp(const RunOptions& run_options, const string& export_dir, const MetaGraphDef& meta_graph_def, const std::vector& asset_file_defs, - Session* session) { - LOG(INFO) << "Running MainOp on SavedModel bundle."; + Session* session, const string& main_op_key) { + LOG(INFO) << "Running MainOp with key " << main_op_key + << " on SavedModel bundle."; const auto& collection_def_map = meta_graph_def.collection_def(); - const auto main_op_it = collection_def_map.find(kSavedModelMainOpKey); + const auto main_op_it = collection_def_map.find(main_op_key); if (main_op_it != collection_def_map.end()) { if (main_op_it->second.node_list().value_size() != 1) { return errors::FailedPrecondition( @@ -147,8 +148,8 @@ Status RunMainOp(const RunOptions& run_options, const string& export_dir, AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); RunMetadata run_metadata; const StringPiece main_op_name = main_op_it->second.node_list().value(0); - return session->Run(run_options, inputs, {}, {main_op_name.ToString()}, - nullptr /* outputs */, &run_metadata); + return RunOnce(run_options, inputs, {}, {main_op_name.ToString()}, + nullptr /* outputs */, &run_metadata, session); } return Status::OK(); } @@ -169,7 +170,8 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, variables_directory, MetaFilename(kSavedModelVariablesFilename)); if (!Env::Default()->FileExists(variables_index_path).ok()) { LOG(INFO) << "The specified SavedModel has no variables; no checkpoints " - "were restored."; + "were restored. File does not exist: " + << variables_index_path; return Status::OK(); } const string variables_path = @@ -185,32 +187,8 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); RunMetadata run_metadata; - return session->Run(run_options, inputs, {}, {restore_op_name.ToString()}, - nullptr /* outputs */, &run_metadata); -} - -Status RunLegacyInitOp(const RunOptions& run_options, const string& export_dir, - const MetaGraphDef& meta_graph_def, - const std::vector& asset_file_defs, - Session* session) { - LOG(INFO) << "Running LegacyInitOp on SavedModel bundle."; - const auto& collection_def_map = meta_graph_def.collection_def(); - const auto init_op_it = collection_def_map.find(kSavedModelLegacyInitOpKey); - if (init_op_it != collection_def_map.end()) { - if (init_op_it->second.node_list().value_size() != 1) { - return errors::FailedPrecondition(strings::StrCat( - "Expected exactly one serving init op in : ", export_dir)); - } - std::vector> inputs; - AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); - RunMetadata run_metadata; - const StringPiece legacy_init_op_name = - init_op_it->second.node_list().value(0); - return session->Run(run_options, inputs, {}, - {legacy_init_op_name.ToString()}, nullptr /* outputs */, - &run_metadata); - } - return Status::OK(); + return RunOnce(run_options, inputs, {}, {restore_op_name.ToString()}, + nullptr /* outputs */, &run_metadata, session); } Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, @@ -235,18 +213,8 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, const string& export_dir, const std::unordered_set& tags, SavedModelBundle* const bundle) { - if (!MaybeSavedModelDirectory(export_dir)) { - return Status(error::Code::NOT_FOUND, - "SavedModel not found in export directory: " + export_dir); - } - LOG(INFO) << "Loading SavedModel with tags: " << GetTagsAsString(tags) - << "; from: " << export_dir; - - SavedModel saved_model_proto; - TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto)); - - TF_RETURN_IF_ERROR( - FindMetaGraphDefToLoad(saved_model_proto, tags, &bundle->meta_graph_def)); + TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags, + &bundle->meta_graph_def)); TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession( bundle->meta_graph_def, session_options, &bundle->session)); @@ -262,11 +230,11 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, if (HasMainOp(bundle->meta_graph_def)) { TF_RETURN_IF_ERROR(RunMainOp(run_options, export_dir, bundle->meta_graph_def, asset_file_defs, - bundle->session.get())); + bundle->session.get(), kSavedModelMainOpKey)); } else { - TF_RETURN_IF_ERROR(RunLegacyInitOp(run_options, export_dir, - bundle->meta_graph_def, asset_file_defs, - bundle->session.get())); + TF_RETURN_IF_ERROR(RunMainOp( + run_options, export_dir, bundle->meta_graph_def, asset_file_defs, + bundle->session.get(), kSavedModelLegacyInitOpKey)); } return Status::OK(); } @@ -288,8 +256,8 @@ Status LoadSavedModel(const SessionOptions& session_options, return end_microseconds - start_microseconds; }(); auto log_and_count = [&](const string& status_str) { - LOG(INFO) << "SavedModel load for tags " << GetTagsAsString(tags) - << "; Status: " << status_str << ". Took " + LOG(INFO) << "SavedModel load for tags { " << str_util::Join(tags, " ") + << " }; Status: " << status_str << ". Took " << load_latency_microsecs << " microseconds."; load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1); }; diff --git a/tensorflow/cc/saved_model/reader.cc b/tensorflow/cc/saved_model/reader.cc new file mode 100644 index 0000000000000000000000000000000000000000..2146c8a19745fa9ea2484c4bb4a2104a38d85144 --- /dev/null +++ b/tensorflow/cc/saved_model/reader.cc @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/saved_model/reader.h" + +#include + +#include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/protobuf/saved_model.pb.h" + +namespace tensorflow { +namespace { + +Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) { + LOG(INFO) << "Reading SavedModel from: " << export_dir; + + const string saved_model_pb_path = + io::JoinPath(export_dir, kSavedModelFilenamePb); + if (Env::Default()->FileExists(saved_model_pb_path).ok()) { + return ReadBinaryProto(Env::Default(), saved_model_pb_path, + saved_model_proto); + } + const string saved_model_pbtxt_path = + io::JoinPath(export_dir, kSavedModelFilenamePbTxt); + if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) { + return ReadTextProto(Env::Default(), saved_model_pbtxt_path, + saved_model_proto); + } + return Status(error::Code::NOT_FOUND, + "Could not find SavedModel .pb or .pbtxt at supplied export " + "directory path: " + + export_dir); +} + +Status FindMetaGraphDef(const SavedModel& saved_model_proto, + const std::unordered_set& tags, + MetaGraphDef* meta_graph_def) { + LOG(INFO) << "Reading meta graph with tags { " << str_util::Join(tags, " ") + << " }"; + for (const MetaGraphDef& graph_def : saved_model_proto.meta_graphs()) { + // Get tags from the graph_def. + std::unordered_set graph_tags; + for (const string& tag : graph_def.meta_info_def().tags()) { + graph_tags.insert(tag); + } + // Match with the set of tags provided. + if (graph_tags == tags) { + *meta_graph_def = graph_def; + return Status::OK(); + } + } + return Status( + error::Code::NOT_FOUND, + strings::StrCat( + "Could not find meta graph def matching supplied tags: { ", + str_util::Join(tags, " "), + " }. To inspect available tag-sets in the SavedModel, please " + "use the SavedModel CLI: `saved_model_cli`")); +} + +} // namespace + +Status ReadMetaGraphDefFromSavedModel(const string& export_dir, + const std::unordered_set& tags, + MetaGraphDef* const meta_graph_def) { + SavedModel saved_model_proto; + TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto)); + TF_RETURN_IF_ERROR(FindMetaGraphDef(saved_model_proto, tags, meta_graph_def)); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/cc/saved_model/reader.h b/tensorflow/cc/saved_model/reader.h new file mode 100644 index 0000000000000000000000000000000000000000..5815108df2a1883b6618e801f30c1915cde8c895 --- /dev/null +++ b/tensorflow/cc/saved_model/reader.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/// Functions to read the SavedModel proto, or parts of it. + +#ifndef TENSORFLOW_CC_SAVED_MODEL_READER_H_ +#define TENSORFLOW_CC_SAVED_MODEL_READER_H_ + +#include +#include + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { + +// Reads the SavedModel proto from saved_model.pb(txt) in the given directory, +// finds the MetaGraphDef that matches the given set of tags and writes it to +// the `meta_graph_def` parameter. Returns a failure status when the SavedModel +// file does not exist or no MetaGraphDef matches the tags. +Status ReadMetaGraphDefFromSavedModel(const string& export_dir, + const std::unordered_set& tags, + MetaGraphDef* const meta_graph_def); + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_READER_H_ diff --git a/tensorflow/cc/saved_model/reader_test.cc b/tensorflow/cc/saved_model/reader_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..620e9c2eece886c9600a8c93cede3b132ccbccaa --- /dev/null +++ b/tensorflow/cc/saved_model/reader_test.cc @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/saved_model/reader.h" + +#include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/cc/saved_model/tag_constants.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +constexpr char kTestDataPbTxt[] = + "cc/saved_model/testdata/half_plus_two_pbtxt/00000123"; +constexpr char kTestDataSharded[] = + "cc/saved_model/testdata/half_plus_two/00000123"; + +class ReaderTest : public ::testing::Test { + protected: + ReaderTest() {} + + void CheckMetaGraphDef(const MetaGraphDef& meta_graph_def) { + const auto& tags = meta_graph_def.meta_info_def().tags(); + EXPECT_TRUE(std::find(tags.begin(), tags.end(), kSavedModelTagServe) != + tags.end()); + EXPECT_NE(meta_graph_def.meta_info_def().tensorflow_version(), ""); + EXPECT_EQ( + meta_graph_def.signature_def().at("serving_default").method_name(), + "tensorflow/serving/predict"); + } +}; + +TEST_F(ReaderTest, TagMatch) { + MetaGraphDef meta_graph_def; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe}, + &meta_graph_def)); + CheckMetaGraphDef(meta_graph_def); +} + +TEST_F(ReaderTest, NoTagMatch) { + MetaGraphDef meta_graph_def; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"}, + &meta_graph_def); + EXPECT_FALSE(st.ok()); + EXPECT_TRUE(str_util::StrContains( + st.error_message(), + "Could not find meta graph def matching supplied tags: { missing-tag }")) + << st.error_message(); +} + +TEST_F(ReaderTest, NoTagMatchMultiple) { + MetaGraphDef meta_graph_def; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + Status st = ReadMetaGraphDefFromSavedModel( + export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def); + EXPECT_FALSE(st.ok()); + EXPECT_TRUE(str_util::StrContains( + st.error_message(), + "Could not find meta graph def matching supplied tags: ")) + << st.error_message(); +} + +TEST_F(ReaderTest, PbtxtFormat) { + MetaGraphDef meta_graph_def; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt); + TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe}, + &meta_graph_def)); + CheckMetaGraphDef(meta_graph_def); +} + +TEST_F(ReaderTest, InvalidExportPath) { + MetaGraphDef meta_graph_def; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path"); + Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe}, + &meta_graph_def); + EXPECT_FALSE(st.ok()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 2119c8ec47f941a76e81346ae5d20da78eae11a3..d2f803bd18b38ad5c1a8b5afd70531db117826ea 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -8,28 +8,6 @@ load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -# Optional runtime utilities for use by code generated by tfcompile. -cc_library( - name = "runtime", - srcs = ["runtime.cc"], - hdrs = ["runtime.h"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework_lite", - ], -) - -tf_cc_test( - name = "runtime_test", - srcs = ["runtime_test.cc"], - deps = [ - ":runtime", - "//tensorflow/core:framework", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - # Don't depend on this directly; this is only used for the benchmark test # generated by tf_library. cc_library( @@ -53,9 +31,9 @@ cc_library( ], deps = [ ":embedded_protocol_buffers", - ":runtime", # needed by codegen to print aligned_buffer_bytes "//tensorflow/compiler/tf2xla", "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/tf2xla:tf2xla_proto", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -68,6 +46,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:compile_only_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/core:core_cpu_internal", @@ -237,7 +216,6 @@ test_suite( tests = [ ":benchmark_test", ":codegen_test", - ":runtime_test", ":test_graph_tfadd_test", ":test_graph_tfunknownop2_test", ":test_graph_tfunknownop3_test", diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 28070d60dbbe6dd8f930b8e6509cedcf09f94e11..8dbe1e11b7c392cca29fc8792d3cf9f1bf44f1fb 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/aot/embedded_protocol_buffers.h" -#include "tensorflow/compiler/aot/runtime.h" +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/compiler/tf2xla/str_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -303,10 +303,10 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, const std::vector iarg(arg_sizes.begin(), arg_sizes.end()); const std::vector itemp(temp_sizes.begin(), temp_sizes.end()); const size_t arg_bytes_aligned = - runtime::aligned_buffer_bytes(iarg.data(), iarg.size()); + cpu_function_runtime::AlignedBufferBytes(iarg.data(), iarg.size()); const size_t arg_bytes_total = total_buffer_bytes(iarg.data(), iarg.size()); const size_t temp_bytes_aligned = - runtime::aligned_buffer_bytes(itemp.data(), itemp.size()); + cpu_function_runtime::AlignedBufferBytes(itemp.data(), itemp.size()); const size_t temp_bytes_total = total_buffer_bytes(itemp.data(), itemp.size()); diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index bbc35da2ef6d14ff0d3570ef2d5cf6743456c674..2b5f97b34cd928d32eb220536342c715d91d45bb 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/compile_only_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 5c57fee326ca743dcb8aaae354d261ed4d7f44be..326f73b975aec3a7a6bc7cdc9a92f540ad545ad6 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -16,339 +16,365 @@ tf_library( ) """ -load("//tensorflow:tensorflow.bzl", - "if_android", "tf_cc_test", "tf_copts") - -def tf_library(name, graph, config, - freeze_checkpoint=None, freeze_saver=None, - cpp_class=None, gen_test=True, gen_benchmark=True, - visibility=None, testonly=None, - tfcompile_flags=None, - tfcompile_tool="//tensorflow/compiler/aot:tfcompile", - include_standard_runtime_deps=True, - enable_xla_hlo_profiling=False, deps=None, tags=None): - """Runs tfcompile to compile a TensorFlow graph into executable code. - - Given an invocation of tf_library(name="foo", ...), generates the following - build targets: - foo: A cc_library containing the generated header and computation. - foo_test: A cc_test with simple tests and benchmarks. Only created if - gen_test=True. - foo_benchmark: A cc_binary that runs a minimal-dependency benchmark, useful - for mobile devices or other platforms that can't compile the - full test libraries. Only created if gen_benchmark=True. - - Args: - name: The name of the build rule. - graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' it - is expected to be in the human-readable proto text format, otherwise it is - expected to be in the proto binary format. - config: File containing tensorflow.tf2xla.Config proto. If the file ends - in '.pbtxt' it is expected to be in the human-readable proto text format, - otherwise it is expected to be in the proto binary format. - freeze_checkpoint: If provided, run freeze_graph with this checkpoint to - convert variables into constants. - freeze_saver: If provided, run freeze_graph with this saver, in SaverDef - binary form, to convert variables into constants. - cpp_class: The name of the generated C++ class, wrapping the generated - function. The syntax of this flag is - [[::],...]. This mirrors the C++ syntax - for referring to a class, where multiple namespaces may precede the class - name, separated by double-colons. The class will be generated in the - given namespace(s), or if no namespaces are given, within the global - namespace. - gen_test: If True, also generate a cc_test rule that builds a simple - test and benchmark. - gen_benchmark: If True, also generate a binary with a simple benchmark. - Unlike the output of gen_test, this benchmark can be run on android. - visibility: Bazel build visibility. - testonly: Bazel testonly attribute. - tfcompile_flags: Extra flags to pass to tfcompile to control compilation. - tfcompile_tool: The tfcompile binary. A non-default can be passed to - use a tfcompile built with extra dependencies. - include_standard_runtime_deps: If True, the standard list of kernel/runtime - deps is added to deps. If False, deps must contain the full set of deps - needed by the generated library. - enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated program, - and emit metadata that lets us pretty-print the gathered profile counters. - deps: a list of deps to include on the build rules for the generated - library, added to the standard deps if standard_runtime_deps is True. - tags: tags to apply to subsidiary build rules. - - The output header is called .h. - """ - if not cpp_class: - fail("cpp_class must be specified") - - tfcompile_graph = graph - if freeze_checkpoint or freeze_saver: - if not freeze_checkpoint: - fail("freeze_checkpoint must be specified when freeze_saver is specified") +load( + "//tensorflow:tensorflow.bzl", + "if_android", + "tf_cc_test", + "tf_copts", +) - freeze_name = "freeze_" + name - freeze_file = freeze_name + ".pb" +def tf_library( + name, + graph, + config, + freeze_checkpoint = None, + freeze_saver = None, + cpp_class = None, + gen_test = True, + gen_benchmark = True, + visibility = None, + testonly = None, + tfcompile_flags = None, + tfcompile_tool = "//tensorflow/compiler/aot:tfcompile", + include_standard_runtime_deps = True, + enable_xla_hlo_profiling = False, + deps = None, + tags = None): + """Runs tfcompile to compile a TensorFlow graph into executable code. - # First run tfcompile to generate the list of out_nodes. - out_nodes_file = "out_nodes_" + freeze_name - native.genrule( - name=("gen_" + out_nodes_file), - srcs=[config], - outs=[out_nodes_file], - cmd=("$(location " + tfcompile_tool + ")" + - " --config=$(location " + config + ")" + - " --dump_fetch_nodes > $@"), - tools=[tfcompile_tool], - # Run tfcompile on the build host, rather than forge, since it's - # typically way faster on the local machine. - local=1, - tags=tags, - ) + Given an invocation of tf_library(name="foo", ...), generates the following + build targets: + foo: A cc_library containing the generated header and + computation. + foo_test: A cc_test with simple tests and benchmarks. Only created if + gen_test=True. + foo_benchmark: A cc_binary that runs a minimal-dependency benchmark, + useful for mobile devices or other platforms that can't + compile the full test libraries. Only created if + gen_benchmark=True. + The output header is called .h. - # Now run freeze_graph to convert variables into constants. - freeze_args = (" --input_graph=$(location " + graph + ")" + - " --checkpoint_version=1" + - " --input_binary=" + str(not graph.endswith(".pbtxt")) + - " --input_checkpoint=$(location " + freeze_checkpoint + ")" + - " --output_graph=$(location " + freeze_file + ")" + - " --output_node_names=$$(<$(location " + out_nodes_file + - "))") - freeze_saver_srcs = [] - if freeze_saver: - freeze_args += " --input_saver=$(location " + freeze_saver + ")" - freeze_saver_srcs += [freeze_saver] - native.genrule( - name=freeze_name, - srcs=[ - graph, - freeze_checkpoint, - out_nodes_file, - ] + freeze_saver_srcs, - outs=[freeze_file], - cmd=("$(location //tensorflow/python/tools:freeze_graph)" + - freeze_args), - tools=["//tensorflow/python/tools:freeze_graph"], - tags=tags, - ) - tfcompile_graph = freeze_file + Args: + name: The name of the build rule. + graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' + it is expected to be in the human-readable proto text format, otherwise + it is expected to be in the proto binary format. + config: File containing tensorflow.tf2xla.Config proto. If the file ends + in '.pbtxt' it is expected to be in the human-readable proto text + format, otherwise it is expected to be in the proto binary format. + freeze_checkpoint: If provided, run freeze_graph with this checkpoint to + convert variables into constants. + freeze_saver: If provided, run freeze_graph with this saver, in SaverDef + binary form, to convert variables into constants. + cpp_class: The name of the generated C++ class, wrapping the generated + function. The syntax of this flag is + [[::],...]. This mirrors the C++ syntax + for referring to a class, where multiple namespaces may precede the + class name, separated by double-colons. The class will be generated in + the given namespace(s), or if no namespaces are given, within the global + namespace. + gen_test: If True, also generate a cc_test rule that builds a simple + test and benchmark. + gen_benchmark: If True, also generate a binary with a simple benchmark. + Unlike the output of gen_test, this benchmark can be run on android. + visibility: Bazel build visibility. + testonly: Bazel testonly attribute. + tfcompile_flags: Extra flags to pass to tfcompile to control compilation. + tfcompile_tool: The tfcompile binary. A non-default can be passed to + use a tfcompile built with extra dependencies. + include_standard_runtime_deps: If True, the standard list of + kernel/runtime deps is added to deps. If False, deps must contain the + full set of deps needed by the generated library. + enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated + program, and emit metadata that lets us pretty-print the gathered + profile counters. + deps: a list of deps to include on the build rules for the generated + library, added to the standard deps if standard_runtime_deps is True. + tags: tags to apply to subsidiary build rules. + """ + if not cpp_class: + fail("cpp_class must be specified") - # Rule that runs tfcompile to produce the header and object file. - header_file = name + ".h" - metadata_object_file = name + "_tfcompile_metadata.o" - function_object_file = name + "_tfcompile_function.o" - ep = ("__" + native.package_name() + "__" + name).replace("/", "_") - if type(tfcompile_flags) == type(""): - flags = tfcompile_flags - else: - flags = " ".join(["'" + arg.replace("'", "'\\''") + "'" for arg in (tfcompile_flags or [])]) - if enable_xla_hlo_profiling: - profiling_flag = "--xla_hlo_profile" - else: - profiling_flag = "" - native.genrule( - name=("gen_" + name), - srcs=[ - tfcompile_graph, - config, - ], - outs=[ - header_file, - metadata_object_file, - function_object_file, - ], - cmd=("$(location " + tfcompile_tool + ")" + - " --graph=$(location " + tfcompile_graph + ")" + - " --config=$(location " + config + ")" + - " --entry_point=" + ep + - " --cpp_class=" + cpp_class + - " --target_triple=" + target_llvm_triple() + - " --out_header=$(@D)/" + header_file + - " --out_metadata_object=$(@D)/" + metadata_object_file + - " --out_function_object=$(@D)/" + function_object_file + - " " + flags + " " + profiling_flag), - tools=[tfcompile_tool], - visibility=visibility, - testonly=testonly, - # Run tfcompile on the build host since it's typically faster on the local - # machine. - # - # Note that setting the local=1 attribute on a *test target* causes the - # test infrastructure to skip that test. However this is a genrule, not a - # test target, and runs with --genrule_strategy=forced_forge, meaning the - # local=1 attribute is ignored, and the genrule is still run. - # - # https://www.bazel.io/versions/master/docs/be/general.html#genrule - local=1, - tags=tags, - ) + tfcompile_graph = graph + if freeze_checkpoint or freeze_saver: + if not freeze_checkpoint: + fail("freeze_checkpoint must be specified when freeze_saver is " + + "specified") - # Rule that runs tfcompile to produce the SessionModule proto, useful for - # debugging. TODO(b/64813587): Once the SessionModule proto is - # deterministic, move this into the main rule above. - session_module_pb = name + "_session_module.pb" - native.genrule( - name=(name + "_session_module"), - srcs=[ - tfcompile_graph, - config, - ], - outs=[ - session_module_pb, - ], - cmd=("$(location " + tfcompile_tool + ")" + - " --graph=$(location " + tfcompile_graph + ")" + - " --config=$(location " + config + ")" + - " --entry_point=" + ep + - " --cpp_class=" + cpp_class + - " --target_triple=" + target_llvm_triple() + - " --out_session_module=$(@D)/" + session_module_pb + - " " + flags), - tools=[tfcompile_tool], - visibility=visibility, - testonly=testonly, - local=1, - tags=tags, - ) + freeze_name = "freeze_" + name + freeze_file = freeze_name + ".pb" - # The cc_library rule packaging up the header and object file, and needed - # kernel implementations. - need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1) - native.cc_library( - name=name, - srcs=[function_object_file, metadata_object_file], - hdrs=[header_file], - visibility=visibility, - testonly=testonly, - deps = [ - # These deps are required by all tf_library targets even if - # include_standard_runtime_deps is False. Without them, the - # generated code will fail to compile. - "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", - "//tensorflow/core:framework_lite", - ] + (need_xla_data_proto and [ - # If we're generating the program shape, we must depend on the proto. - "//tensorflow/compiler/xla:xla_data_proto", - ] or []) + (enable_xla_hlo_profiling and [ - "//tensorflow/compiler/xla/service:hlo_profile_printer_data" - ] or []) + (include_standard_runtime_deps and [ - # TODO(cwhipkey): only depend on kernel code that the model actually needed. - "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", - "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", - "//tensorflow/compiler/xla/service/cpu:runtime_conv2d", - "//tensorflow/compiler/xla/service/cpu:runtime_matmul", - "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", - "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", - "//third_party/eigen3", - ] or []) + (deps or []), - tags=tags, - ) + # First run tfcompile to generate the list of out_nodes. + out_nodes_file = "out_nodes_" + freeze_name + native.genrule( + name = ("gen_" + out_nodes_file), + srcs = [config], + outs = [out_nodes_file], + cmd = ("$(location " + tfcompile_tool + ")" + + " --config=$(location " + config + ")" + + " --dump_fetch_nodes > $@"), + tools = [tfcompile_tool], + # Run tfcompile on the build host, rather than forge, since it's + # typically way faster on the local machine. + local = 1, + tags = tags, + ) - # Variables used for gen_test and gen_benchmark. - no_ns_name = "" - cpp_class_split = cpp_class.rsplit("::", maxsplit=2) - if len(cpp_class_split) == 1: - no_ns_name = cpp_class_split[0] - else: - no_ns_name = cpp_class_split[1] - sed_replace = ( - "-e \"s|{{TFCOMPILE_HEADER}}|$(location " + header_file + ")|g\" " + - "-e \"s|{{TFCOMPILE_CPP_CLASS}}|" + cpp_class + "|g\" " + - "-e \"s|{{TFCOMPILE_NAME}}|" + no_ns_name + "|g\" ") + # Now run freeze_graph to convert variables into constants. + freeze_args = ( + " --input_graph=$(location " + graph + ")" + + " --checkpoint_version=1" + + " --input_binary=" + str(not graph.endswith(".pbtxt")) + + " --input_checkpoint=$(location " + freeze_checkpoint + ")" + + " --output_graph=$(location " + freeze_file + ")" + + " --output_node_names=$$(<$(location " + out_nodes_file + + "))" + ) + freeze_saver_srcs = [] + if freeze_saver: + freeze_args += " --input_saver=$(location " + freeze_saver + ")" + freeze_saver_srcs += [freeze_saver] + native.genrule( + name = freeze_name, + srcs = [ + graph, + freeze_checkpoint, + out_nodes_file, + ] + freeze_saver_srcs, + outs = [freeze_file], + cmd = ("$(location " + + "//tensorflow/python/tools:freeze_graph)" + + freeze_args), + tools = ["//tensorflow/python/tools:freeze_graph"], + tags = tags, + ) + tfcompile_graph = freeze_file - if gen_test: - test_name = name + "_test" - test_file = test_name + ".cc" - # Rule to rewrite test.cc to produce the test_file. + # Rule that runs tfcompile to produce the header and object file. + header_file = name + ".h" + metadata_object_file = name + "_tfcompile_metadata.o" + function_object_file = name + "_tfcompile_function.o" + ep = ("__" + native.package_name() + "__" + name).replace("/", "_") + if type(tfcompile_flags) == type(""): + flags = tfcompile_flags + else: + flags = " ".join([ + "'" + arg.replace("'", "'\\''") + "'" + for arg in (tfcompile_flags or []) + ]) + if enable_xla_hlo_profiling: + profiling_flag = "--xla_hlo_profile" + else: + profiling_flag = "" native.genrule( - name=("gen_" + test_name), - testonly=1, - srcs=[ - "//tensorflow/compiler/aot:test.cc", + name = ("gen_" + name), + srcs = [ + tfcompile_graph, + config, + ], + outs = [ header_file, + metadata_object_file, + function_object_file, ], - outs=[test_file], - cmd=("sed " + sed_replace + - " $(location //tensorflow/compiler/aot:test.cc) " + - "> $(OUTS)"), - tags=tags, - ) - - # The cc_test rule for the generated code. To ensure that this works - # reliably across build configurations, we must use tf_cc_test instead of - # native.cc_test. This is related to how we build - # //tensorflow/core:lib -- see the note in tensorflow/core/BUILD - # for more details. - tf_cc_test( - name=test_name, - srcs=[test_file], - deps=[ - ":" + name, - "//tensorflow/compiler/aot:runtime", - "//tensorflow/compiler/aot:tf_library_test_main", - "//tensorflow/compiler/xla:executable_run_options", - "//third_party/eigen3", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], - tags=tags, + cmd = ("$(location " + tfcompile_tool + ")" + + " --graph=$(location " + tfcompile_graph + ")" + + " --config=$(location " + config + ")" + + " --entry_point=" + ep + + " --cpp_class=" + cpp_class + + " --target_triple=" + target_llvm_triple() + + " --out_header=$(@D)/" + header_file + + " --out_metadata_object=$(@D)/" + metadata_object_file + + " --out_function_object=$(@D)/" + function_object_file + + " " + flags + " " + profiling_flag), + tools = [tfcompile_tool], + visibility = visibility, + testonly = testonly, + # Run tfcompile on the build host since it's typically faster on the + # local machine. + # + # Note that setting the local=1 attribute on a *test target* causes the + # test infrastructure to skip that test. However this is a genrule, not + # a test target, and runs with --genrule_strategy=forced_forge, meaning + # the local=1 attribute is ignored, and the genrule is still run. + # + # https://www.bazel.io/versions/master/docs/be/general.html#genrule + local = 1, + tags = tags, ) - if gen_benchmark: - benchmark_name = name + "_benchmark" - benchmark_file = benchmark_name + ".cc" - benchmark_main = ("//tensorflow/compiler/aot:" + - "benchmark_main.template") - - # Rule to rewrite benchmark.cc to produce the benchmark_file. + # Rule that runs tfcompile to produce the SessionModule proto, useful for + # debugging. TODO(b/64813587): Once the SessionModule proto is + # deterministic, move this into the main rule above. + session_module_pb = name + "_session_module.pb" native.genrule( - name=("gen_" + benchmark_name), - srcs=[ - benchmark_main, - header_file, + name = (name + "_session_module"), + srcs = [ + tfcompile_graph, + config, ], + outs = [ + session_module_pb, + ], + cmd = ("$(location " + tfcompile_tool + ")" + + " --graph=$(location " + tfcompile_graph + ")" + + " --config=$(location " + config + ")" + + " --entry_point=" + ep + + " --cpp_class=" + cpp_class + + " --target_triple=" + target_llvm_triple() + + " --out_session_module=$(@D)/" + session_module_pb + + " " + flags), + tools = [tfcompile_tool], + visibility = visibility, testonly = testonly, - outs=[benchmark_file], - cmd=("sed " + sed_replace + - " $(location " + benchmark_main + ") " + - "> $(OUTS)"), - tags=tags, + local = 1, + tags = tags, ) - # The cc_benchmark rule for the generated code. This does not need the - # tf_cc_binary since we (by deliberate design) do not depend on - # //tensorflow/core:lib. - # - # Note: to get smaller size on android for comparison, compile with: - # --copt=-fvisibility=hidden - # --copt=-D_LIBCPP_TYPE_VIS=_LIBCPP_HIDDEN - # --copt=-D_LIBCPP_EXCEPTION_ABI=_LIBCPP_HIDDEN - native.cc_binary( - name=benchmark_name, - srcs=[benchmark_file], + # The cc_library rule packaging up the header and object file, and needed + # kernel implementations. + need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1) + native.cc_library( + name = name, + srcs = [function_object_file, metadata_object_file], + hdrs = [header_file], + visibility = visibility, testonly = testonly, - copts = tf_copts(), - linkopts = if_android(["-pie", "-s"]), - deps=[ - ":" + name, - "//tensorflow/compiler/aot:benchmark", - "//tensorflow/compiler/aot:runtime", - "//tensorflow/compiler/xla:executable_run_options", + deps = [ + # These deps are required by all tf_library targets even if + # include_standard_runtime_deps is False. Without them, the + # generated code will fail to compile. + "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", + "//tensorflow/core:framework_lite", + ] + (need_xla_data_proto and [ + # If we're generating the program shape, we must depend on the + # proto. + "//tensorflow/compiler/xla:xla_data_proto", + ] or []) + (enable_xla_hlo_profiling and [ + "//tensorflow/compiler/xla/service:hlo_profile_printer_data", + ] or []) + (include_standard_runtime_deps and [ + # TODO(cwhipkey): only depend on kernel code that the model actually + # needed. + "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", + "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", + "//tensorflow/compiler/xla/service/cpu:runtime_conv2d", + "//tensorflow/compiler/xla/service/cpu:runtime_matmul", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//third_party/eigen3", - ] + if_android([ - "//tensorflow/compiler/aot:benchmark_extra_android", - ]), - tags=tags, + ] or []) + (deps or []), + tags = tags, + ) + + # Variables used for gen_test and gen_benchmark. + cpp_class_split = cpp_class.rsplit("::", maxsplit = 2) + if len(cpp_class_split) == 1: + no_ns_name = cpp_class_split[0] + else: + no_ns_name = cpp_class_split[1] + sed_replace = ( + "-e \"s|{{TFCOMPILE_HEADER}}|$(location " + header_file + ")|g\" " + + "-e \"s|{{TFCOMPILE_CPP_CLASS}}|" + cpp_class + "|g\" " + + "-e \"s|{{TFCOMPILE_NAME}}|" + no_ns_name + "|g\" " ) + if gen_test: + test_name = name + "_test" + test_file = test_name + ".cc" + + # Rule to rewrite test.cc to produce the test_file. + native.genrule( + name = ("gen_" + test_name), + testonly = 1, + srcs = [ + "//tensorflow/compiler/aot:test.cc", + header_file, + ], + outs = [test_file], + cmd = ( + "sed " + sed_replace + + " $(location //tensorflow/compiler/aot:test.cc) " + + "> $(OUTS)" + ), + tags = tags, + ) + + # The cc_test rule for the generated code. To ensure that this works + # reliably across build configurations, we must use tf_cc_test instead + # of native.cc_test. This is related to how we build + # //tensorflow/core:lib -- see the note in + # tensorflow/core/BUILD for more details. + tf_cc_test( + name = test_name, + srcs = [test_file], + deps = [ + ":" + name, + "//tensorflow/compiler/aot:tf_library_test_main", + "//tensorflow/compiler/xla:executable_run_options", + "//third_party/eigen3", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], + tags = tags, + ) + + if gen_benchmark: + benchmark_name = name + "_benchmark" + benchmark_file = benchmark_name + ".cc" + benchmark_main = ("//tensorflow/compiler/aot:" + + "benchmark_main.template") + + # Rule to rewrite benchmark.cc to produce the benchmark_file. + native.genrule( + name = ("gen_" + benchmark_name), + srcs = [ + benchmark_main, + header_file, + ], + testonly = testonly, + outs = [benchmark_file], + cmd = ("sed " + sed_replace + + " $(location " + benchmark_main + ") " + + "> $(OUTS)"), + tags = tags, + ) + + # The cc_benchmark rule for the generated code. This does not need the + # tf_cc_binary since we (by deliberate design) do not depend on + # //tensorflow/core:lib. + # + # Note: to get smaller size on android for comparison, compile with: + # --copt=-fvisibility=hidden + # --copt=-D_LIBCPP_TYPE_VIS=_LIBCPP_HIDDEN + # --copt=-D_LIBCPP_EXCEPTION_ABI=_LIBCPP_HIDDEN + native.cc_binary( + name = benchmark_name, + srcs = [benchmark_file], + testonly = testonly, + copts = tf_copts(), + linkopts = if_android(["-pie", "-s"]), + deps = [ + ":" + name, + "//tensorflow/compiler/aot:benchmark", + "//tensorflow/compiler/xla:executable_run_options", + "//third_party/eigen3", + ] + if_android([ + "//tensorflow/compiler/aot:benchmark_extra_android", + ]), + tags = tags, + ) + def target_llvm_triple(): - """Returns the target LLVM triple to be used for compiling the target.""" - # TODO(toddw): Add target_triple for other targets. For details see: - # http://llvm.org/docs/doxygen/html/Triple_8h_source.html - return select({ - "//tensorflow:android_armeabi": "armv5-none-android", - "//tensorflow:android_arm": "armv7-none-android", - "//tensorflow:android_arm64": "aarch64-none-android", - "//tensorflow:android_x86": "i686-none-android", - "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", - "//tensorflow:darwin": "x86_64-none-darwin", - "//conditions:default": "x86_64-pc-linux", - }) + """Returns the target LLVM triple to be used for compiling the target.""" + + # TODO(toddw): Add target_triple for other targets. For details see: + # http://llvm.org/docs/doxygen/html/Triple_8h_source.html + return select({ + "//tensorflow:android_armeabi": "armv5-none-android", + "//tensorflow:android_arm": "armv7-none-android", + "//tensorflow:android_arm64": "aarch64-none-android", + "//tensorflow:android_x86": "i686-none-android", + "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", + "//tensorflow:darwin": "x86_64-none-darwin", + "//conditions:default": "x86_64-pc-linux", + }) diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index d976f8296c6dc26bd13833f67874849aba91ad65..15f9ba217f2c2762de36a1e1c0fc7227449bb730 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -166,6 +166,7 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -176,13 +177,19 @@ cc_library( "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:fifo_queue", + "//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:no_op", + "//tensorflow/core/kernels:queue_op", "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:sendrecv_ops", "//tensorflow/core/kernels:shape_ops", "//tensorflow/core/kernels:variable_ops", + "//tensorflow/core/kernels/data:generator_dataset_op", + "//tensorflow/core/kernels/data:iterator_ops", + "//tensorflow/core/kernels/data:prefetch_dataset_op", ], ) @@ -302,11 +309,14 @@ cc_library( name = "compilation_passes", srcs = [ "build_xla_launch_ops_pass.cc", + "deadness_analysis.cc", + "deadness_analysis_internal.h", "encapsulate_subgraphs_pass.cc", "mark_for_compilation_pass.cc", ], hdrs = [ "build_xla_launch_ops_pass.h", + "deadness_analysis.h", "encapsulate_subgraphs_pass.h", "mark_for_compilation_pass.h", ], @@ -323,6 +333,7 @@ cc_library( "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -371,6 +382,35 @@ tf_cc_test( ], ) +tf_cc_test( + name = "deadness_analysis_test", + size = "small", + srcs = [ + "deadness_analysis_internal.h", + "deadness_analysis_test.cc", + ], + deps = [ + ":common", + ":compilation_passes", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:sendrecv_ops", + "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_cc_test( name = "compilation_passes_test", size = "small", @@ -385,6 +425,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", + "//tensorflow/cc:sendrecv_ops", "//tensorflow/compiler/jit/kernels:xla_launch_op", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", @@ -456,6 +497,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":common", + ":compilation_passes", ":union_find", ":xla_cluster_util", "//tensorflow/compiler/jit/graphcycles", diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc new file mode 100644 index 0000000000000000000000000000000000000000..8aff87e5e620fefd30eeb902209c9bc17540f468 --- /dev/null +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -0,0 +1,592 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/deadness_analysis.h" +#include "tensorflow/compiler/jit/deadness_analysis_internal.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/hash/hash.h" + +// ALGORITHM OVERVIEW +// +// We map every output produced by each node in the TensorFlow graph (including +// control dependence) into an instance of the Predicate class. Instances of +// Predicate denote logical formulas and mapping a node `n` to a predicate +// `pred` implies that `n` is executed whenver `pred` is true. Then we can +// deduce mismatching liveness in the inputs to node by comparing the predicate +// those inputs are mapped to. +// +// Loops are handled pessimistically -- we map Merge nodes with backedges to +// uninterpreted symbols (the same kind we use to represent Switch and _Recv). +// Predicate equality has to hold over all possible assignments to these +// uninterpreted symbols. + +namespace tensorflow { + +namespace { + +// Represents a logical predicate, used as described in the algorithm overview +// above. +class Predicate { + public: + enum class Kind { kAnd, kOr, kNot, kSymbol }; + + virtual string ToString() const = 0; + int64 hash() const { return hash_; } + + virtual Kind kind() const = 0; + virtual ~Predicate() {} + + protected: + explicit Predicate(int64 hash) : hash_(hash) {} + + private: + const int64 hash_; + + TF_DISALLOW_COPY_AND_ASSIGN(Predicate); +}; + +int64 HashPredicateSequence(Predicate::Kind kind, + gtl::ArraySlice preds) { + int64 hash = ::tensorflow::hash()(kind); + for (Predicate* pred : preds) { + hash = Hash64Combine(hash, pred->hash()); + } + return hash; +} + +// Represents a logical conjunction of a set of predicates. +class AndPredicate : public Predicate { + public: + explicit AndPredicate(std::vector operands) + : Predicate(HashPredicateSequence(Kind::kAnd, operands)), + operands_(std::move(operands)) {} + + string ToString() const override { + if (operands().empty()) { + return "#true"; + } + + std::vector operands_str; + std::transform(operands().begin(), operands().end(), + std::back_inserter(operands_str), + [](Predicate* pred) { return pred->ToString(); }); + + return strings::StrCat("(", str_util::Join(operands_str, " & "), ")"); + } + + Kind kind() const override { return Kind::kAnd; } + + const gtl::ArraySlice operands() const { return operands_; } + + private: + std::vector operands_; +}; + +// Represents a logical disjunction of a set of predicates. +class OrPredicate : public Predicate { + public: + explicit OrPredicate(std::vector operands) + : Predicate(HashPredicateSequence(Kind::kOr, operands)), + operands_(std::move(operands)) {} + + string ToString() const override { + if (operands().empty()) { + return "#false"; + } + + std::vector operands_str; + std::transform(operands().begin(), operands().end(), + std::back_inserter(operands_str), + [](Predicate* pred) { return pred->ToString(); }); + + return strings::StrCat("(", str_util::Join(operands_str, " | "), ")"); + } + + Kind kind() const override { return Kind::kOr; } + const gtl::ArraySlice operands() const { return operands_; } + + private: + std::vector operands_; +}; + +// Represents a logical negation of a set of predicates. +class NotPredicate : public Predicate { + public: + explicit NotPredicate(Predicate* operand) + : Predicate(HashPredicateSequence(Kind::kNot, {operand})), + operand_(operand) {} + + string ToString() const override { + return strings::StrCat("~", operand()->ToString()); + } + + Kind kind() const override { return Kind::kNot; } + Predicate* operand() const { return operand_; } + + private: + Predicate* operand_; +}; + +// Represents an uninterpreted symbol in a logical predicate. +// +// Two predicates are equivalent iff they are equivalent for all assignments to +// the symbols contained in them. +class SymbolPredicate : public Predicate { + public: + explicit SymbolPredicate(TensorId tensor_id, bool must_be_true) + : Predicate(Hash(tensor_id, must_be_true)), + tensor_id_(std::move(tensor_id)), + must_be_true_(must_be_true) {} + + string ToString() const override { + return must_be_true() ? strings::StrCat("*", tensor_id_.ToString()) + : tensor_id_.ToString(); + } + + Kind kind() const override { return Kind::kSymbol; } + + // If `must_be_true()` is true this SymbolPredicate represents the proposition + // "tensor_id() is live and evaluates to true". + // + // If `must_be_true()` is false then this SymbolPredicate represents the + // proposition "tensor_id() is live (and may evalutate to any value)" + TensorId tensor_id() const { return tensor_id_; } + bool must_be_true() const { return must_be_true_; } + + private: + TensorId tensor_id_; + bool must_be_true_; + + static int64 Hash(const TensorId tensor_id, bool must_be_true) { + return Hash64Combine( + ::tensorflow::hash()(must_be_true), + Hash64Combine(::tensorflow::hash()(Kind::kSymbol), + TensorId::Hasher{}(tensor_id))); + } +}; + +// Creates and owns Predicate instances. Simplifies predicates as it creates +// them. +class PredicateFactory { + public: + Predicate* MakeAndPredicate(gtl::ArraySlice operands) { + return MakeAndOrImpl(operands, /*is_and=*/true); + } + + Predicate* MakeOrPredicate(gtl::ArraySlice operands) { + return MakeAndOrImpl(operands, /*is_and=*/false); + } + + Predicate* MakeNotPredicate(Predicate* pred) { + SignatureForNot signature = pred; + auto it = interned_not_instances_.find(signature); + if (it == interned_not_instances_.end()) { + std::unique_ptr new_pred = Make(pred); + Predicate* new_pred_ptr = new_pred.get(); + interned_not_instances_.emplace(signature, std::move(new_pred)); + return new_pred_ptr; + } else { + return it->second.get(); + } + } + + Predicate* MakeSymbolPredicate(TensorId tensor_id, bool must_be_true) { + SignatureForSymbol signature = {tensor_id, must_be_true}; + auto it = interned_symbol_instances_.find(signature); + if (it == interned_symbol_instances_.end()) { + std::unique_ptr new_pred = + Make(tensor_id, must_be_true); + Predicate* new_pred_ptr = new_pred.get(); + interned_symbol_instances_.emplace(std::move(signature), + std::move(new_pred)); + return new_pred_ptr; + } else { + return it->second.get(); + } + } + + Predicate* MakeTrue() { return MakeAndPredicate({}); } + Predicate* MakeFalse() { return MakeOrPredicate({}); } + + private: + template + std::unique_ptr Make(Args&&... args) { + return std::unique_ptr( + new PredicateT(std::forward(args)...)); + } + + Predicate* MakeAndOrImpl(gtl::ArraySlice operands, bool is_and); + + // Predicate instances are interned, meaning that there is only a single + // instance of a Predicate object with a given content. This makes checking + // for structural equality super-cheap -- we can just compare pointers. + // + // We intern predicates by maintaining a map from the content of a Predicate + // to the only instance of said predicate we allow to exist in the + // interned_and_or_instances_, interned_not_instances_ and + // interned_symbol_instances_ fields. These maps also double up as storage + // for the owning pointers to predicate instances. + + using SignatureForAndOr = + std::pair>; + using SignatureForNot = Predicate*; + using SignatureForSymbol = std::pair; + + struct HashSignatureForAndOr { + size_t operator()(const SignatureForAndOr& signature) const { + size_t hash = ::tensorflow::hash()(signature.first); + for (Predicate* p : signature.second) { + hash = Hash64Combine(hash, ::tensorflow::hash()(p)); + } + return hash; + } + }; + + struct HashSignatureForSymbol { + size_t operator()(const SignatureForSymbol& signature) const { + return Hash64Combine(SafeTensorId::Hasher()(signature.first), + ::tensorflow::hash()(signature.second)); + } + }; + + gtl::FlatMap, + HashSignatureForAndOr> + interned_and_or_instances_; + gtl::FlatMap> + interned_not_instances_; + gtl::FlatMap, + HashSignatureForSymbol> + interned_symbol_instances_; +}; + +// Common code to create AndPredicate or OrPredicate instances. +Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice operands, + bool is_and) { + Predicate::Kind pred_kind = + is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr; + gtl::FlatSet simplified_ops_set; + std::vector simplified_ops; + for (Predicate* op : operands) { + // Simplify A&A => A and A|A => A. + if (!simplified_ops_set.insert(op).second) { + continue; + } + + if (op->kind() == pred_kind) { + // "Inline" the operands of an inner And/Or into the parent And/Or. + gtl::ArraySlice operands = + is_and ? dynamic_cast(op)->operands() + : dynamic_cast(op)->operands(); + for (Predicate* subop : operands) { + if (simplified_ops_set.insert(subop).second) { + simplified_ops.push_back(subop); + } + } + } else { + simplified_ops.push_back(op); + } + } + + if (simplified_ops.size() == 1) { + return simplified_ops[0]; + } + + // Simplify "A&~A=>False" and "A|~A=>True". + gtl::FlatSet negated_ops; + for (Predicate* op : simplified_ops) { + if (op->kind() == Predicate::Kind::kNot) { + negated_ops.insert(dynamic_cast(*op).operand()); + } + } + + for (Predicate* op : simplified_ops) { + if (negated_ops.count(op)) { + return is_and ? MakeFalse() : MakeTrue(); + } + } + + std::stable_sort( + simplified_ops.begin(), simplified_ops.end(), + [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); }); + + auto it = interned_and_or_instances_.find({pred_kind, simplified_ops}); + if (it == interned_and_or_instances_.end()) { + simplified_ops.shrink_to_fit(); + // NB! Because we'll use a non-owning reference to simplified_ops in the + // key for interned_and_or_instances_ we need to be careful to std::move() + // it all the way through. + gtl::ArraySlice operands_slice = simplified_ops; + std::unique_ptr new_pred = + is_and ? Make(std::move(simplified_ops)) + : Make(std::move(simplified_ops)); + + Predicate* new_pred_ptr = new_pred.get(); + CHECK(interned_and_or_instances_ + .emplace(SignatureForAndOr(pred_kind, operands_slice), + std::move(new_pred)) + .second); + return new_pred_ptr; + } else { + return it->second.get(); + } +} + +class DeadnessAnalysisImpl : public DeadnessAnalysis { + public: + explicit DeadnessAnalysisImpl(const Graph* graph) + : graph_(*graph), vlog_(VLOG_IS_ON(2)) {} + + Status Populate(); + bool HasInputsWithMismatchingDeadness(const Node& node) override; + void Print() const override; + gtl::FlatMap PredicateMapAsString() const; + + private: + enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly }; + + std::vector GetIncomingPreds(Node* n, EdgeKind edge_kind); + void SetPred(Node* n, int output_idx, Predicate* pred) { + CHECK( + predicate_map_.insert({TensorId(n->name(), output_idx), pred}).second); + } + void SetPred(Node* n, gtl::ArraySlice output_idxs, Predicate* pred) { + for (int output_idx : output_idxs) { + SetPred(n, output_idx, pred); + } + } + + Status HandleSwitch(Node* n); + Status HandleMerge(Node* n); + Status HandleRecv(Node* n); + Status HandleGeneric(Node* n); + + const Graph& graph_; + gtl::FlatMap predicate_map_; + PredicateFactory predicate_factory_; + bool vlog_; +}; + +TensorId InputEdgeToTensorId(const Edge* e) { + return TensorId(e->src()->name(), e->src_output()); +} + +std::vector DeadnessAnalysisImpl::GetIncomingPreds( + Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind) { + std::vector incoming_preds; + for (const Edge* in_edge : n->in_edges()) { + bool should_process = + edge_kind == EdgeKind::kDataAndControl || + (in_edge->IsControlEdge() && edge_kind == EdgeKind::kControlOnly) || + (!in_edge->IsControlEdge() && edge_kind == EdgeKind::kDataOnly); + + if (should_process) { + auto it = predicate_map_.find(InputEdgeToTensorId(in_edge)); + CHECK(it != predicate_map_.end()); + incoming_preds.push_back(it->second); + } + } + return incoming_preds; +} + +Status DeadnessAnalysisImpl::HandleSwitch(Node* n) { + std::vector input_preds = + GetIncomingPreds(n, EdgeKind::kDataAndControl); + const Edge* pred_edge; + TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge)); + Predicate* true_switch = predicate_factory_.MakeSymbolPredicate( + TensorId(pred_edge->src()->name(), pred_edge->src_output()), + /*must_be_true=*/true); + Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch); + + // Output 0 is alive iff all inputs are alive and the condition is false. + input_preds.push_back(false_switch); + SetPred(n, 0, predicate_factory_.MakeAndPredicate(input_preds)); + input_preds.pop_back(); + + // Output 1 is alive iff all inputs are alive and the condition is true. + input_preds.push_back(true_switch); + SetPred(n, 1, predicate_factory_.MakeAndPredicate(input_preds)); + input_preds.pop_back(); + + // Control is alive iff any inputs are alive. + SetPred(n, Graph::kControlSlot, + predicate_factory_.MakeAndPredicate(input_preds)); + + return Status::OK(); +} + +Status DeadnessAnalysisImpl::HandleMerge(Node* n) { + // Merge ignores deadness of its control inputs. A merge that isn't the + // target of a backedge has is alive iff any of its data inputs are. We treat + // the liveness of a merge that is the target of a backedge symbolically. + + bool has_backedge = std::any_of( + n->in_edges().begin(), n->in_edges().end(), [](const Edge* e) { + return !e->IsControlEdge() && e->src()->IsNextIteration(); + }); + + Predicate* input_data_pred = + has_backedge ? predicate_factory_.MakeSymbolPredicate( + TensorId(n->name(), 0), /*must_be_true=*/false) + : predicate_factory_.MakeOrPredicate( + GetIncomingPreds(n, EdgeKind::kDataOnly)); + + SetPred(n, {0, 1, Graph::kControlSlot}, input_data_pred); + return Status::OK(); +} + +Status DeadnessAnalysisImpl::HandleRecv(Node* n) { + // In addition to being alive or dead based on the inputs, a _Recv can also + // acquire a dead signal from a _Send. + std::vector input_preds = + GetIncomingPreds(n, EdgeKind::kDataAndControl); + input_preds.push_back(predicate_factory_.MakeSymbolPredicate( + TensorId(n->name(), 0), /*must_be_true=*/false)); + SetPred(n, {0, Graph::kControlSlot}, + predicate_factory_.MakeAndPredicate(input_preds)); + return Status::OK(); +} + +Status DeadnessAnalysisImpl::HandleGeneric(Node* n) { + // Generally nodes are alive iff all their inputs are alive. + Predicate* pred = predicate_factory_.MakeAndPredicate( + GetIncomingPreds(n, EdgeKind::kDataAndControl)); + for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) { + SetPred(n, output_idx, pred); + } + SetPred(n, Graph::kControlSlot, pred); + return Status::OK(); +} + +Status DeadnessAnalysisImpl::Populate() { + std::vector rpo; + GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/{}, + /*edge_filter=*/[](const Edge& edge) { + return !edge.src()->IsNextIteration(); + }); + + // This an abstract interpretation over the deadness propagation semantics of + // the graph executor. + for (Node* n : rpo) { + if (n->IsSwitch()) { + TF_RETURN_IF_ERROR(HandleSwitch(n)); + } else if (n->IsMerge()) { + TF_RETURN_IF_ERROR(HandleMerge(n)); + } else if (n->IsControlTrigger()) { + SetPred(n, Graph::kControlSlot, predicate_factory_.MakeTrue()); + } else if (n->IsRecv() || n->IsHostRecv()) { + TF_RETURN_IF_ERROR(HandleRecv(n)); + } else { + TF_RETURN_IF_ERROR(HandleGeneric(n)); + } + } + + return Status::OK(); +} + +bool DeadnessAnalysisImpl::HasInputsWithMismatchingDeadness(const Node& node) { + CHECK(!node.IsMerge()); + + if (vlog_) { + VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() << ")"; + } + + Predicate* pred = nullptr; + for (const Edge* edge : node.in_edges()) { + auto it = predicate_map_.find(InputEdgeToTensorId(edge)); + CHECK(it != predicate_map_.end()); + if (vlog_) { + VLOG(2) << " " << InputEdgeToTensorId(edge).ToString() << ": " + << it->second->ToString(); + } + + // Today we just compare the predicates for equality (with some + // canonicalization/simplification happening before) but we could be more + // sophisticated here if need be. Comparing pointers is sufficient because + // we intern Predicate instances by their content. + if (pred != nullptr && pred != it->second) { + if (vlog_) { + VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() + << ") -> true"; + } + return true; + } + pred = it->second; + } + + if (vlog_) { + VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() + << ") -> false"; + } + + return false; +} + +void DeadnessAnalysisImpl::Print() const { + std::vector tensor_ids; + for (const auto& kv_pair : predicate_map_) { + tensor_ids.push_back(kv_pair.first); + } + + std::sort(tensor_ids.begin(), tensor_ids.end()); + + for (TensorId tensor_id : tensor_ids) { + auto it = predicate_map_.find(tensor_id); + CHECK(it != predicate_map_.end()) << tensor_id.ToString(); + VLOG(2) << tensor_id.ToString() << " -> " << it->second->ToString(); + } +} + +} // namespace + +DeadnessAnalysis::~DeadnessAnalysis() {} + +/*static*/ Status DeadnessAnalysis::Run( + const Graph& graph, std::unique_ptr* result) { + std::unique_ptr analysis( + new DeadnessAnalysisImpl(&graph)); + TF_RETURN_IF_ERROR(analysis->Populate()); + + if (VLOG_IS_ON(2)) { + analysis->Print(); + } + + *result = std::move(analysis); + return Status::OK(); +} + +gtl::FlatMap +DeadnessAnalysisImpl::PredicateMapAsString() const { + gtl::FlatMap result; + std::vector tensor_ids; + for (const auto& kv_pair : predicate_map_) { + CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second); + } + return result; +} + +namespace deadness_analysis_internal { +Status ComputePredicates(const Graph& graph, + PredicateMapTy* out_predicate_map) { + DeadnessAnalysisImpl impl(&graph); + TF_RETURN_IF_ERROR(impl.Populate()); + *out_predicate_map = impl.PredicateMapAsString(); + return Status::OK(); +} +} // namespace deadness_analysis_internal + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis.h b/tensorflow/compiler/jit/deadness_analysis.h new file mode 100644 index 0000000000000000000000000000000000000000..6e7ab411619ba08060aa4925e91dce06299d1d23 --- /dev/null +++ b/tensorflow/compiler/jit/deadness_analysis.h @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_ + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// This analyzes a TensorFlow graph to identify nodes which may have partially +// dead inputs (i.e. these nodes may have some dead inputs and some alive +// inputs). +// +// For example, the ADD node in the following graph +// +// V0 PRED0 V1 PRED1 +// | | | | +// v v v v +// SWITCH SWITCH +// | | +// +---+ + ---+ +// | | +// v v +// ADD +// +// can have its inputs independently dead or alive based on the runtime values +// of PRED0 and PRED1. +// +// It is tempting to call this a liveness analysis but I avoided that because +// "liveness" already has other connotations. +class DeadnessAnalysis { + public: + // Returns true if `node` may have some live inputs and some dead inputs. + // + // This is a conservatively correct routine -- if it returns false then `node` + // is guaranteed to not have inputs with mismatching liveness, but not the + // converse. + // + // REQUIRES: node is not a Merge operation. + virtual bool HasInputsWithMismatchingDeadness(const Node& node) = 0; + + // Prints out the internal state of this instance. For debugging purposes + // only. + virtual void Print() const = 0; + virtual ~DeadnessAnalysis(); + + // Run the deadness analysis over `graph` and returns an error or a populated + // instance of DeadnessAnalysis in `result`. + static Status Run(const Graph& graph, + std::unique_ptr* result); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_ diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h new file mode 100644 index 0000000000000000000000000000000000000000..cdef4051108fdc5d063ab592676c7644989155bf --- /dev/null +++ b/tensorflow/compiler/jit/deadness_analysis_internal.h @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ +#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ + +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace tensorflow { +namespace deadness_analysis_internal { + +// Returns a map describing the predicate each Tensor was mapped to. For +// testing purposes only. +using PredicateMapTy = gtl::FlatMap; +Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map); +} // namespace deadness_analysis_internal +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6881095b51758d2e0b06c60021bc8c2860ac566e --- /dev/null +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -0,0 +1,467 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/deadness_analysis.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/deadness_analysis_internal.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +Status AnalyzeDeadness(Graph* graph, + std::unique_ptr* result) { + FixupSourceAndSinkEdges(graph); + return DeadnessAnalysis::Run(*graph, result); +} + +ops::Switch CreateSwitch(const Scope& root, const string& prefix) { + Output value = ops::Placeholder(root.WithOpName(prefix + "/value"), DT_FLOAT); + Output predicate = + ops::Placeholder(root.WithOpName(prefix + "/pred"), DT_BOOL); + return ops::Switch(root.WithOpName(prefix + "/switch"), value, predicate); +} + +Output CreateInductionVariable(const Scope& root, const string& prefix, + const string& frame_name, int32 init) { + Output initial_value = ops::Const(root.WithOpName(prefix + "/init"), init); + Output enter_initial_value = ops::internal::Enter( + root.WithOpName(prefix + "/enter"), initial_value, frame_name); + + ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_initial_value}); + Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1); + Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10); + Output loop_cond_expr = + ops::Less(root.WithOpName(prefix + "/less"), iv.output, final_value); + Output loop_cond = + ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr); + ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); + ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output); + Output iv_next = + ops::Add(root.WithOpName(prefix + "/ivnext"), iv.output, increment_by); + Output next_iteration = + ops::NextIteration(root.WithOpName(prefix + "next_iteration"), iv_next); + + root.graph()->AddEdge(next_iteration.node(), 0, iv.output.node(), 1); + root.graph()->AddControlEdge(iv.output.node(), increment_by.node()); + root.graph()->AddControlEdge(iv.output.node(), final_value.node()); + + return iv.output; +} + +TEST(DeadnessAnalysisTest, BasicPositive) { + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw = CreateSwitch(root, "0"); + Output add = + ops::Add(root.WithOpName("add"), sw.output_true, sw.output_false); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, BasicNegative) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT); + Output b = ops::Placeholder(root.WithOpName("b"), DT_FLOAT); + Output add = ops::Add(root.WithOpName("add"), a, b); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, AndIsCommutative) { + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "0"); + ops::Switch sw_1 = CreateSwitch(root, "1"); + + Output a0 = + ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false); + Output a1 = + ops::Add(root.WithOpName("a1"), sw_1.output_false, sw_0.output_false); + + Output b0 = + ops::Add(root.WithOpName("b0"), sw_0.output_false, sw_1.output_true); + Output b1 = + ops::Add(root.WithOpName("b1"), sw_1.output_true, sw_0.output_false); + + Output live0 = ops::Add(root.WithOpName("live0"), a0, a1); + Output live1 = ops::Add(root.WithOpName("live1"), b0, b1); + + Output halfdead0 = ops::Add(root.WithOpName("halfdead0"), a0, b0); + Output halfdead1 = ops::Add(root.WithOpName("halfdead1"), a1, b1); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node())); + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node())); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node())); + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node())); +} + +TEST(DeadnessAnalysisTest, AndIsAssociative) { + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "0"); + ops::Switch sw_1 = CreateSwitch(root, "1"); + ops::Switch sw_2 = CreateSwitch(root, "2"); + + Output a0 = + ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false); + Output a1 = ops::Add(root.WithOpName("a1"), a0, sw_2.output_false); + + Output b0 = + ops::Add(root.WithOpName("b0"), sw_1.output_false, sw_2.output_false); + Output b1 = ops::Add(root.WithOpName("b1"), sw_0.output_false, b0); + + Output add = ops::Add(root.WithOpName("add"), a1, b1); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, OrIsCommutative) { + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "0"); + ops::Switch sw_1 = CreateSwitch(root, "1"); + + ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false}); + ops::Merge m1(root.WithOpName("m1"), {sw_1.output_false, sw_0.output_false}); + ops::Merge m2(root.WithOpName("m2"), {sw_0.output_false, sw_1.output_true}); + ops::Merge m3(root.WithOpName("m3"), {sw_1.output_true, sw_0.output_false}); + + Output live0 = ops::Add(root.WithOpName("live0"), m0.output, m1.output); + Output live1 = ops::Add(root.WithOpName("live1"), m2.output, m3.output); + + Output halfdead0 = + ops::Add(root.WithOpName("halfdead0"), m0.output, m2.output); + Output halfdead1 = + ops::Add(root.WithOpName("halfdead1"), m1.output, m3.output); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node())); + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node())); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node())); + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node())); +} + +TEST(DeadnessAnalysisTest, OrIsAssociative) { + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "0"); + ops::Switch sw_1 = CreateSwitch(root, "1"); + ops::Switch sw_2 = CreateSwitch(root, "2"); + + ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false}); + ops::Merge m1(root.WithOpName("m1"), {m0.output, sw_2.output_false}); + ops::Merge m2(root.WithOpName("m2"), {sw_1.output_false, sw_2.output_false}); + ops::Merge m3(root.WithOpName("m3"), {sw_0.output_false, m2.output}); + + Output add = ops::Add(root.WithOpName("add"), m1.output, m3.output); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, AndOfOr) { + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "0"); + ops::Switch sw_1 = CreateSwitch(root, "1"); + ops::Switch sw_2 = CreateSwitch(root, "2"); + ops::Switch sw_3 = CreateSwitch(root, "3"); + + ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false}); + ops::Merge m1(root.WithOpName("m1"), {sw_2.output_false, sw_3.output_false}); + + Output add0 = ops::Add(root.WithOpName("add0"), m0.output, m1.output); + Output add1 = ops::Add(root.WithOpName("add1"), m0.output, m1.output); + + Output add2 = ops::Add(root.WithOpName("add2"), add0, add1); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node())); +} + +TEST(DeadnessAnalysisTest, OrOfAnd) { + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "0"); + ops::Switch sw_1 = CreateSwitch(root, "1"); + ops::Switch sw_2 = CreateSwitch(root, "2"); + ops::Switch sw_3 = CreateSwitch(root, "3"); + + Output add0 = + ops::Add(root.WithOpName("add0"), sw_0.output_false, sw_1.output_false); + Output add1 = + ops::Add(root.WithOpName("add1"), sw_2.output_false, sw_3.output_false); + + ops::Merge m0(root.WithOpName("m0"), {add0, add1}); + ops::Merge m1(root.WithOpName("m1"), {add0, add1}); + + Output add2 = ops::Add(root.WithOpName("add2"), m0.output, m1.output); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node())); +} + +TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) { + // This demonstrates one of the weaknesses in the current approach -- since we + // only do some basic simplifications we can't see that "(A|B)&C" == + // "(A&C)|(B&C)". + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "0"); + ops::Switch sw_1 = CreateSwitch(root, "1"); + ops::Switch sw_2 = CreateSwitch(root, "2"); + + ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false}); + Output add0 = ops::Add(root.WithOpName("add0"), m0.output, sw_2.output_false); + + Output add1 = + ops::Add(root.WithOpName("add1"), sw_0.output_false, sw_2.output_false); + Output add2 = + ops::Add(root.WithOpName("add2"), sw_1.output_false, sw_2.output_false); + ops::Merge m1(root.WithOpName("m1"), {add1, add2}); + + Output add3 = ops::Add(root.WithOpName("add3"), add0, m1.output); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add2.node())); +} + +TEST(DeadnessAnalysisTest, Ternary) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output predicate = ops::Placeholder(root.WithOpName("predicate"), DT_BOOL); + Output true_value = ops::Placeholder(root.WithOpName("true_value"), DT_FLOAT); + Output false_value = + ops::Placeholder(root.WithOpName("false_value"), DT_FLOAT); + + ops::Switch predicated_true(root.WithOpName("predicated_true"), true_value, + predicate); + + ops::Switch predicated_false(root.WithOpName("predicated_false"), true_value, + predicate); + ops::Merge merge(root.WithOpName("ternary"), {predicated_true.output_true, + predicated_false.output_false}); + Output addend = ops::Placeholder(root.WithOpName("addend"), DT_FLOAT); + Output add = ops::Add(root.WithOpName("add"), merge.output, addend); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, Recv) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_FLOAT, "tensor_a", + "sender", 0, "receiver"); + Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_FLOAT, "tensor_b", + "sender", 0, "receiver"); + Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, HostRecv) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output recv_a = ops::_HostRecv(root.WithOpName("recv_a"), DT_FLOAT, + "tensor_a", "sender", 0, "receiver"); + Output recv_b = ops::_HostRecv(root.WithOpName("recv_b"), DT_FLOAT, + "tensor_b", "sender", 0, "receiver"); + Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, Loop) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output iv0 = CreateInductionVariable(root, "iv0", "fr0", 0); + Output iv1 = CreateInductionVariable(root, "iv1", "fr0", 0); + Output iv2 = CreateInductionVariable(root, "iv2", "fr0", 1); + Output add0 = ops::Add(root.WithOpName("add0"), iv0, iv1); + Output add1 = ops::Add(root.WithOpName("add1"), iv1, iv2); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + // NB! iv0 and iv1 are equivalent and a smarter deadness analysis would have + // noticed that. Today we are pessimistic here because we assign an + // uninterpreted symbol to merges with backedges. + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node())); + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add1.node())); +} + +TEST(DeadnessAnalysisTest, ControlInputs) { + Scope root = Scope::NewRootScope().ExitOnError(); + ops::Switch sw = CreateSwitch(root, "0"); + + Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false); + Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true); + + Output const0 = ops::Const(root.WithOpName("const0"), 1); + Output const1 = ops::Const(root.WithOpName("const1"), 2); + + Output add = ops::Add(root.WithOpName("add"), const0, const1); + + root.graph()->AddControlEdge(id0.node(), const0.node()); + root.graph()->AddControlEdge(id1.node(), const1.node()); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, ControlTrigger) { + Scope root = Scope::NewRootScope().ExitOnError(); + ops::Switch sw = CreateSwitch(root, "0"); + + Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false); + Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true); + + ops::ControlTrigger ctrl_trigger0(root.WithOpName("ctrl_trigger0")); + ops::ControlTrigger ctrl_trigger1(root.WithOpName("ctrl_trigger1")); + + Output const0 = ops::Const(root.WithOpName("const0"), 1); + Output const1 = ops::Const(root.WithOpName("const1"), 2); + + Output add = ops::Add(root.WithOpName("add"), const0, const1); + + root.graph()->AddControlEdge(id0.node(), ctrl_trigger0.operation.node()); + root.graph()->AddControlEdge(ctrl_trigger0.operation.node(), const0.node()); + + root.graph()->AddControlEdge(id1.node(), ctrl_trigger1.operation.node()); + root.graph()->AddControlEdge(ctrl_trigger1.operation.node(), const1.node()); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, ControlInputsToMerge) { + Scope root = Scope::NewRootScope().ExitOnError(); + ops::Switch sw = CreateSwitch(root, "0"); + + Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false); + Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true); + + Output constant = ops::Const(root.WithOpName("constant"), 5); + ops::Merge m0(root.WithOpName("m0"), {constant}); + ops::Merge m1(root.WithOpName("m0"), {constant}); + Output add = ops::Add(root.WithOpName("add"), m0.output, m1.output); + + root.graph()->AddControlEdge(id0.node(), m0.output.node()); + root.graph()->AddControlEdge(id1.node(), m1.output.node()); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node())); +} + +TEST(DeadnessAnalysisTest, RecvVsSwitch) { + // Demonstrates why we need the must_be_true bit on SymbolP. + Scope root = Scope::NewRootScope().ExitOnError(); + + Output recv = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor", "sender", + 0, "receiver"); + Output value = ops::Placeholder(root.WithOpName("value"), DT_BOOL); + ops::Switch sw(root.WithOpName("switch"), value, recv); + Output logical_and = + ops::LogicalAnd(root.WithOpName("and"), recv, sw.output_true); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*logical_and.node())); +} + +TEST(DeadnessAnalysisTest, RecvVsSwitchText) { + // Demonstrates why we need the must_be_true bit on SymbolP. + Scope root = Scope::NewRootScope().ExitOnError(); + + Output recv = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor", "sender", + 0, "receiver"); + Output value = ops::Placeholder(root.WithOpName("value"), DT_BOOL); + ops::Switch sw(root.WithOpName("switch"), value, recv); + Output logical_and = + ops::LogicalAnd(root.WithOpName("and"), recv, sw.output_true); + + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + deadness_analysis_internal::PredicateMapTy predicate_map; + TF_ASSERT_OK(deadness_analysis_internal::ComputePredicates(*root.graph(), + &predicate_map)); + + TensorId logical_and_output_0 = {logical_and.node()->name(), + Graph::kControlSlot}; + EXPECT_EQ(predicate_map[logical_and_output_0], "(recv:0 & *recv:0)"); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index e786d41887f1d539fe1ae122275d1c14c77309e8..fdd71c6a588ad96301f543651c8531e6f9c3ca05 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -60,9 +60,9 @@ const char* const kXlaHostTransferSequencerAttr = namespace { -bool AreAllParentsConst(const Node& n, - const gtl::FlatSet& runtime_const_nodes) { - if (n.type_string() == "GuaranteeConst" || n.type_string() == "Const") { +bool AreAllParentsGuaranteedConst( + const Node& n, const gtl::FlatSet& runtime_const_nodes) { + if (n.type_string() == "GuaranteeConst") { // If the current node is itself a cast-to-const, no need // to look at the incoming edges. return true; @@ -93,7 +93,8 @@ void MarkGuaranteedConstants( ReverseDFSFrom(graph, srcs, /*enter=*/nullptr, /*leave=*/[&guaranteed_const_nodes](const Node* n) { // TODO(vinuraja): Doesn't work in the presence of loops. - if (AreAllParentsConst(*n, guaranteed_const_nodes)) { + if (AreAllParentsGuaranteedConst(*n, + guaranteed_const_nodes)) { guaranteed_const_nodes.insert(n); } }); @@ -137,7 +138,7 @@ class Encapsulator { // Find subgraphs marked with 'group_attribute', and build a new // subgraph, one for each value of 'group_attribute'. - Status SplitIntoSubgraphs(); + Status SplitIntoSubgraphs(FunctionLibraryDefinition* library); // Build a FunctionDef for each subgraph, and add it 'library'. The values of // the 'group_attribute' annotations become the function names. @@ -1136,7 +1137,10 @@ Status Encapsulator::Subgraph::AddShapeInferenceInfo( GraphToFunctionDef(*inference_graph, inference_graph_name, &fdef)); host_compute->AddAttr("shape_inference_graph", inference_graph_name); host_compute->AddAttr("shapes", std::vector()); - TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + // TODO(sibyl-Aix6ihai): Understand why there are multiple calls to Encapsulator. + if (library->Find(inference_graph_name) == nullptr) { + TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + } } return Status::OK(); } @@ -1474,7 +1478,7 @@ Status Encapsulator::CopySubgraphEdges( return Status::OK(); } -Status Encapsulator::SplitIntoSubgraphs() { +Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) { Status s; // Map from input graph nodes to subgraph nodes. @@ -1509,6 +1513,15 @@ Status Encapsulator::SplitIntoSubgraphs() { TF_RETURN_IF_ERROR(BuildControlFlowInfo(subgraph.GetGraph(), &dummy)); } + if (VLOG_IS_ON(1)) { + // Dump subgraphs. + for (auto& entry : subgraphs_) { + dump_graph::DumpGraphToFile( + strings::StrCat("encapsulate_subgraphs_subgraph_", entry.first), + *entry.second.GetGraph(), library); + } + } + return s; } @@ -1932,6 +1945,8 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( // continue. TensorShapeProto proto; context->ShapeHandleToProto(shape, &proto); + VLOG(2) << "Node " << src_node->name() + << " has known shape: " << proto.DebugString(); if (dummy_node_images.find(src_node) == dummy_node_images.end()) { dummy_node_images[src_node] = AddDummyShapedNode(src_node, src_port, control_flow_info, @@ -1949,6 +1964,8 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( if (VLOG_IS_ON(2)) { TensorShapeProto proto; context->ShapeHandleToProto(shape, &proto); + VLOG(2) << "Node " << src_node->name() + << " has unknown shape: " << proto.DebugString(); } stack.push_back({src_node, false}); } @@ -2191,6 +2208,23 @@ Status Encapsulator::FindClusterDependencies() { } } } + if (VLOG_IS_ON(2)) { + // Print debug information. + VLOG(2) << "node_ancestors_map:"; + for (const auto& node_iter : node_ancestors_map) { + VLOG(2) << "\t" << node_iter.first->name() << ": subgraph = '" + << node_iter.second.subgraph + << "', outside_compilation_cluster = '" + << node_iter.second.outside_compilation_cluster + << "', ancestor_clusters: " + << (node_iter.second.ancestor_clusters.empty() ? "(empty)" : ""); + for (const auto& cluster_iter : node_iter.second.ancestor_clusters) { + VLOG(2) << "\t\tsubgraph = '" << cluster_iter.subgraph + << "', outside_compilation_cluster = '" + << cluster_iter.outside_compilation_cluster << "'"; + } + } + } return Status::OK(); } @@ -2398,7 +2432,7 @@ Status EncapsulateSubgraphsInFunctions( std::move(outside_compilation_attribute), &graph_in); TF_RETURN_IF_ERROR(encapsulator.FindClusterDependencies()); - TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs()); + TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs(library)); TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs( rewrite_subgraph_fn, reuse_existing_functions, library)); @@ -2447,7 +2481,7 @@ Status EncapsulateSubgraphsPass::Run( const GraphOptimizationPassOptions& options) { VLOG(1) << "EncapsulateSubgraphsPass::Run"; if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile("before_encapsulate_subgraphs", **options.graph, + dump_graph::DumpGraphToFile("encapsulate_subgraphs_before", **options.graph, options.flib_def); } @@ -2530,7 +2564,7 @@ Status EncapsulateSubgraphsPass::Run( "EncapsulateSubgraphsPass failed"); if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out, + dump_graph::DumpGraphToFile("encapsulate_subgraphs_after", *graph_out, options.flib_def); } diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 4eb389e0c653f2d32c17f448687f865a44a11b96..c0543a00792235c5dd090e81930d8c219dc7f1a3 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -742,10 +742,13 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { Scope root = Scope::NewRootScope().ExitOnError().WithDevice( "/job:localhost/replica:0/task:0/cpu:0"); auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT); - auto const_x2 = ops::Const(root.WithOpName("const_x2"), 10.0f); + auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT); + auto const_guarantee_x2 = + ops::GuaranteeConst(root.WithOpName("const_guarantee_x2"), x2); auto const_guarantee_x1 = ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1); - auto add1 = ops::Add(root.WithOpName("add1"), const_guarantee_x1, const_x2); + auto add1 = + ops::Add(root.WithOpName("add1"), const_guarantee_x1, const_guarantee_x2); add1.node()->AddAttr("_encapsulate", "encapsulate1"); Graph graph_before(OpRegistry::Global()); diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 251a07304eaeb21f1313d7a6ef6af668f99d8551..b313d48011b561eaab618692df49d1558c34a77c 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -51,7 +51,11 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, if (device_type_ == DeviceType(DEVICE_CPU)) { platform_id_ = se::host::kHostPlatformId; } else if (device_type_ == DeviceType(DEVICE_GPU)) { - platform_id_ = se::cuda::kCudaPlatformId; + platform_id_ = ctx->device() + ->tensorflow_gpu_device_info() + ->stream->parent() + ->platform() + ->id(); } else { platform_id_ = nullptr; } @@ -115,6 +119,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { const XlaDevice::Metadata* metadata = nullptr; Status s = XlaDevice::GetMetadata(ctx, &metadata); bool allocate_xla_tensors = s.ok(); + bool use_multiple_streams = s.ok() && metadata->UseMultipleStreams(); // Get the platform_id_ for XLA_* devices. if (platform_id_ == nullptr) { @@ -148,6 +153,10 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { XlaCompiler::Options options; options.client = client; + if (ctx->op_device_context() != nullptr) { + options.device_ordinal = + ctx->op_device_context()->stream()->parent()->device_ordinal(); + } options.device_type = cache->device_type(); options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.graph_def_version = ctx->function_library()->graph_def_version(); @@ -180,8 +189,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { VLOG(1) << "Executing XLA Computation..."; - XlaComputationLaunchContext launch_context(client, xla_allocator, - allocate_xla_tensors); + XlaComputationLaunchContext launch_context( + client, xla_allocator, allocate_xla_tensors, use_multiple_streams); launch_context.PopulateInputs(ctx, kernel, variables); // Execute the computation. diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 8c3882116dd4f048ea3e32c037bf4139c67a3eb9..45d422943c23f59823e6bfbcb355d4b58a6a225e 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" @@ -28,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/memory_types.h" @@ -460,20 +462,22 @@ Status MarkForCompilationPass::Run( VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit; VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only; + VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit; const FunctionLibraryDefinition* fld = options.flib_def; - auto is_compilable = [global_jit_level, cpu_global_jit, fusion_only, fld]( - const Node* node, const DeviceType& device_type) { + std::unique_ptr deadness; + { + XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1); + TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(**options.graph, &deadness)); + } + + auto is_compilable = [&](const Node* node, const DeviceType& device_type) { const XlaOpRegistry::DeviceRegistration* registration; if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { return false; } - // Don't compile control trigger nodes. We won't preserve their deadness - // semantics correctly, so it's safest not to compile them. - if (node->IsControlTrigger()) return false; - // If this device requires a JIT, we must say yes. if (registration->requires_compilation) return true; @@ -485,6 +489,14 @@ Status MarkForCompilationPass::Run( status = fld->GetAttr(*node, kXlaCompileAttr, &compile); if (status.ok()) return compile; + // If inputs to `node` can have conflicting deadness (i.e. some are alive + // and some are dead) then don't compile it. XLA cannot represent the + // deadness semantics of these nodes correctly and auto-clustering these + // nodes can cause deadness to propagate to nodes that should be live. + if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) { + return false; + } + // Check for fusable ops only if requested. if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) { return false; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 772c92d369e67f431b5d030d1d5cdc5ae2700d39..2c5f4fb774fcab082c0d0d316cdc6757cacc1e96 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -680,5 +681,37 @@ TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) { EXPECT_EQ(clusters, expected_clusters); } +TEST(XlaCompilationTest, ClusterControlTrigger) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_BOOL, "tensor_a", + "sender", 0, "receiver"); + Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_BOOL, "tensor_b", + "sender", 0, "receiver"); + Output const_a = ops::Const(root.WithOpName("const_a"), 42); + + ops::ControlTrigger ctrl_trigger_a(root.WithOpName("ctrl_trigger_a")); + ops::ControlTrigger ctrl_trigger_b(root.WithOpName("ctrl_trigger_b")); + root.graph()->AddControlEdge(recv_a.node(), ctrl_trigger_a.operation.node()); + root.graph()->AddControlEdge(recv_b.node(), ctrl_trigger_a.operation.node()); + root.graph()->AddControlEdge(ctrl_trigger_b.operation.node(), const_a.node()); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + ASSERT_FALSE(clusters.empty()); + string cluster_name = clusters.begin()->second; + + // ctrl_trigger_a has inputs with mismatching deadness so it won't be + // clustered. ctrl_trigger_b is okay to cluster. + std::unordered_map expected_clusters( + {{"const_a", cluster_name}, {"ctrl_trigger_b", cluster_name}}); + EXPECT_EQ(clusters, expected_clusters); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 54a41a4daa790401c797277e7eaab531dd34ac80..7140d47a9421ec73d0144e855b490f89569e6ae9 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -209,7 +209,9 @@ Status XlaCompilationCache::BuildExecutable( argument_layouts[i] = &result.xla_input_shapes[i]; } xla::ExecutableBuildOptions build_options; - build_options.set_device_ordinal(client_->default_device_ordinal()); + build_options.set_device_ordinal(options.device_ordinal != -1 + ? options.device_ordinal + : client_->default_device_ordinal()); build_options.set_result_layout(result.xla_output_shape); build_options.set_device_allocator(options.device_allocator); @@ -256,6 +258,7 @@ Status XlaCompilationCache::CompileImpl( xla::LocalExecutable** executable, const XlaCompiler::CompileOptions* compile_options, bool compile_single_op) { + CHECK_NE(executable, nullptr); VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { @@ -293,7 +296,7 @@ Status XlaCompilationCache::CompileImpl( // protect the contents of the cache entry. Entry* entry; { - mutex_lock lock(mu_); + mutex_lock lock(compile_cache_mu_); // Find or create a cache entry. std::unique_ptr& e = cache_[signature]; if (!e) { @@ -309,6 +312,8 @@ Status XlaCompilationCache::CompileImpl( if (!entry->compiled) { VLOG(1) << "Compilation cache miss for signature: " << SignatureDebugString(signature); + tensorflow::Env* env = tensorflow::Env::Default(); + const uint64 compile_start_us = env->NowMicros(); // Do the actual JIT compilation without holding the lock (it can take // a long time.) std::vector args; @@ -327,18 +332,35 @@ Status XlaCompilationCache::CompileImpl( compile_options ? *compile_options : XlaCompiler::CompileOptions(), function, args, &entry->compilation_result); } - } - *compilation_result = &entry->compilation_result; - if (entry->compilation_status.ok() && executable) { - if (entry->executable == nullptr) { - entry->compilation_status = BuildExecutable( - options, entry->compilation_result, &entry->executable); + TF_RETURN_IF_ERROR(entry->compilation_status); + CHECK_EQ(entry->executable.get(), nullptr); + entry->compilation_status = + BuildExecutable(options, entry->compilation_result, &entry->executable); + + const uint64 compile_end_us = env->NowMicros(); + const uint64 compile_time_us = compile_end_us - compile_start_us; + { + mutex_lock lock(compile_stats_mu_); + auto it = compile_stats_.emplace(function.name(), CompileStats{}).first; + it->second.compile_count++; + it->second.cumulative_compile_time_us += compile_time_us; + VLOG(1) << "compiled " << function.name() << " " + << it->second.compile_count + << " times, compile time: " << compile_time_us + << " us, cumulative: " << it->second.cumulative_compile_time_us + << " us (" + << tensorflow::strings::HumanReadableElapsedTime(compile_time_us / + 1.0e6) + << " / " + << tensorflow::strings::HumanReadableElapsedTime( + it->second.cumulative_compile_time_us / 1.0e6) + << ")"; } - *executable = entry->executable.get(); } - - Status status = entry->compilation_status; - return status; + TF_RETURN_IF_ERROR(entry->compilation_status); + *compilation_result = &entry->compilation_result; + *executable = entry->executable.get(); + return Status::OK(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index be1043d8c3fc0573922837e541615114a6d7a1a5..fc5f008f4f52c32d97e680784082d0e7bcb7d8eb 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -150,9 +151,22 @@ class XlaCompilationCache : public ResourceBase { std::unique_ptr executable GUARDED_BY(mu); }; - mutex mu_; - std::unordered_map, Signature::Hash> cache_ - GUARDED_BY(mu_); + mutex compile_cache_mu_; + gtl::FlatMap, Signature::Hash> cache_ + GUARDED_BY(compile_cache_mu_); + + struct CompileStats { + // Number of times the cluster has been (re-)compiled. + int64 compile_count = 0; + + // Cumulative time spent compiling the cluster. + int64 cumulative_compile_time_us = 0; + }; + mutex compile_stats_mu_; + + // Maps cluster names to compilation statistics for said cluster. + gtl::FlatMap compile_stats_ + GUARDED_BY(compile_stats_mu_); TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache); }; diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index baccea2d6a793df8c5cf8c8941706d41d2c044ca..d288d37bc75380168a31937024dd41bdbe7dce9d 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -53,7 +53,9 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, // Builds an XLA allocator for the device. XlaComputationLaunchContext launch_context( - client, client->backend().memory_allocator(), true); + client, client->backend().memory_allocator(), + /*allocate_xla_tensors=*/true, + /*use_multiple_streams=*/metadata.UseMultipleStreams()); launch_context.PopulateInputs(ctx, result, variables); diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 43648402f65c656b6b4eb2e83e61ce45f1c73669..7e159e3171113b0d53f03bb676ac9c21db7fe77a 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -54,6 +54,7 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options, DEVICE_CPU_XLA_JIT, options, name_prefix, registration, /*transfer_as_literal=*/false, + /*use_multiple_streams=*/false, /*shape_representation_fn=*/{}, /*padded_shape_fn=*/{}, &device)); devices->push_back(device.release()); diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index ed007d603ea1b3d27dd25f00726261cdd029c20c..4ddeaebd3e42e96d46857a278451d8c97e49a725 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -130,7 +130,7 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { const string& jit_device_name, const SessionOptions& options, const string& name_prefix, const XlaOpRegistry::DeviceRegistration& registration, - bool transfer_as_literal, + bool transfer_as_literal, bool use_multiple_streams, const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, const PaddedShapeFn& padded_shape_fn, std::unique_ptr* device) { VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" @@ -151,22 +151,24 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(), strings::StrCat("device: ", device_name, " device")); - device->reset(new XlaDevice( - options, attrs, device_ordinal, DeviceType(jit_device_name), - platform.ValueOrDie(), transfer_as_literal, shape_representation_fn, - padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn)); + device->reset( + new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name), + platform.ValueOrDie(), transfer_as_literal, + use_multiple_streams, shape_representation_fn, + padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn)); return Status::OK(); } XlaDevice::Metadata::Metadata( int device_ordinal, se::Platform* platform, const DeviceType& device_type, XlaCompiler::ShapeRepresentationFn shape_representation_fn, - PaddedShapeFn padded_shape_fn) + PaddedShapeFn padded_shape_fn, bool use_multiple_streams) : device_ordinal_(device_ordinal), device_type_(device_type), platform_(platform), shape_representation_fn_(std::move(shape_representation_fn)), - padded_shape_fn_(std::move(padded_shape_fn)) {} + padded_shape_fn_(std::move(padded_shape_fn)), + use_multiple_streams_(use_multiple_streams) {} int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; } @@ -200,24 +202,27 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { XlaDevice::XlaDevice( const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, - se::Platform* platform, bool transfer_as_literal, + se::Platform* platform, bool transfer_as_literal, bool use_multiple_streams, const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, const PaddedShapeFn& padded_shape_fn) : LocalDevice(options, attrs), xla_metadata_(device_ordinal, platform, jit_device_name, - shape_representation_fn, padded_shape_fn), + shape_representation_fn, padded_shape_fn, + use_multiple_streams), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), - xla_allocator_(nullptr), platform_(platform), + use_multiple_streams_(use_multiple_streams), transfer_as_literal_(transfer_as_literal), shape_representation_fn_(shape_representation_fn) { - VLOG(1) << "Created XLA device " << jit_device_name; + VLOG(1) << "Created XLA device " << jit_device_name << " " << this; } XlaDevice::~XlaDevice() { - if (gpu_device_info_ != nullptr) { - gpu_device_info_->default_context->Unref(); + VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this; + mutex_lock lock(mu_); + if (device_context_) { + device_context_->Unref(); } } @@ -233,6 +238,11 @@ xla::LocalClient* XlaDevice::client() const { } Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) { + mutex_lock lock(mu_); + return GetAllocatorLocked(attr); +} + +Allocator* XlaDevice::GetAllocatorLocked(AllocatorAttributes attr) { if (attr.on_host()) { return cpu_allocator(); } @@ -245,52 +255,105 @@ Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) { return xla_allocator_; } -xla::StatusOr XlaDevice::GetStream() { - if (!stream_) { - xla::Backend* backend = client()->mutable_backend(); - TF_ASSIGN_OR_RETURN(stream_, backend->BorrowStream(device_ordinal_)); +Status XlaDevice::EnsureDeviceContextOk() { + mutex_lock lock(mu_); + return GetDeviceContextLocked().status(); +} + +Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend, + const string& name, + xla::StreamPool::Ptr* stream, + bool* stream_was_changed) { + if (!(*stream) || !(*stream)->ok()) { + TF_ASSIGN_OR_RETURN(*stream, backend->BorrowStream(device_ordinal_)); + VLOG(1) << "XlaDevice " << this << " new " << name << " " + << (*stream)->DebugStreamPointers(); + *stream_was_changed = true; } - return stream_.get(); + return Status::OK(); } -Status XlaDevice::CreateAndSetGpuDeviceInfo() { - if (gpu_device_info_ == nullptr) { - TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - // Call GetAllocator for the side-effect of ensuring the allocator - // is created. - GetAllocator({}); - // XlaDevice owns both gpu_device_info_ and - // gpu_device_info_->default_context. - gpu_device_info_ = MakeUnique(); - gpu_device_info_->stream = stream; - gpu_device_info_->default_context = new XlaDeviceContext( - stream, client(), transfer_as_literal_, shape_representation_fn_); - set_tensorflow_gpu_device_info(gpu_device_info_.get()); +xla::StatusOr XlaDevice::GetDeviceContextLocked() { + xla::Backend* backend = client()->mutable_backend(); + + // Ensure all our streams are valid, borrowing new streams if necessary. + bool need_new_device_context = !device_context_; + TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_, + &need_new_device_context)); + + se::Stream* host_to_device_stream = stream_.get(); + se::Stream* device_to_host_stream = stream_.get(); + if (use_multiple_streams_) { + TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream", + &host_to_device_stream_, + &need_new_device_context)); + TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream", + &device_to_host_stream_, + &need_new_device_context)); + host_to_device_stream = host_to_device_stream_.get(); + device_to_host_stream = device_to_host_stream_.get(); } - return Status::OK(); + if (!need_new_device_context) { + return device_context_; + } + + // At this point we know we need a new device context. + // Call GetAllocator for the side-effect of ensuring the allocator is created. + GetAllocatorLocked({}); + if (device_context_) { + device_context_->Unref(); + } + device_context_ = new XlaDeviceContext( + stream_.get(), host_to_device_stream, device_to_host_stream, client(), + transfer_as_literal_, shape_representation_fn_); + VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext " + << device_context_; + + // Create and set a new GpuDeviceInfo, if necessary. + // + // TODO(b/78232898): This isn't thread-safe; there is a race between the call + // to set_tensorflow_gpu_device_info() with ops that call the getter + // tensorflow_gpu_device_info(). This isn't trivially fixed by adding locking + // to those methods; see the bug for details. Our only saving grace at the + // moment is that this race doesn't seem to occur in practice. + if (use_gpu_device_info_) { + auto gpu_device_info = MakeUnique(); + gpu_device_info->stream = stream_.get(); + gpu_device_info->default_context = device_context_; + set_tensorflow_gpu_device_info(gpu_device_info.get()); + gpu_device_info_ = std::move(gpu_device_info); + VLOG(1) << "XlaDevice " << this << " new GpuDeviceInfo " + << gpu_device_info_.get(); + } + + return device_context_; +} + +Status XlaDevice::UseGpuDeviceInfo() { + mutex_lock lock(mu_); + use_gpu_device_info_ = true; + return GetDeviceContextLocked().status(); } Status XlaDevice::FillContextMap(const Graph* graph, DeviceContextMap* device_context_map) { VLOG(1) << "XlaDevice::FillContextMap"; + mutex_lock lock(mu_); + TF_ASSIGN_OR_RETURN(XlaDeviceContext * device_context, + GetDeviceContextLocked()); + device_context_map->resize(graph->num_node_ids()); - TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - // Call GetAllocator for the side-effect of ensuring the allocator is created. - GetAllocator({}); - auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_, - shape_representation_fn_); for (Node* n : graph->nodes()) { VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name(); - ctx->Ref(); - (*device_context_map)[n->id()] = ctx; + device_context->Ref(); + (*device_context_map)[n->id()] = device_context; } - ctx->Unref(); return Status::OK(); } void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { - VLOG(1) << "XlaDevice::Compute " << op_kernel->name() << ":" + VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":" << op_kernel->type_string(); // When Xprof profiling is off (which is the default), constructing the // activity is simple enough that its overhead is negligible. @@ -301,7 +364,7 @@ void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) { - VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":" + VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":" << op_kernel->type_string(); tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(), op_kernel->IsExpensive()); @@ -323,16 +386,17 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, if (alloc_attrs.on_host()) { *tensor = parsed; } else { - Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); + mutex_lock lock(mu_); + TF_ASSIGN_OR_RETURN(XlaDeviceContext * device_context, + GetDeviceContextLocked()); + Allocator* allocator = GetAllocatorLocked(alloc_attrs); + Tensor copy(allocator, parsed.dtype(), parsed.shape()); Notification n; - TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - XlaTransferManager manager(stream, client(), transfer_as_literal_, - shape_representation_fn_); - manager.CopyCPUTensorToDevice(&parsed, this, ©, - [&n, &status](const Status& s) { - status = s; - n.Notify(); - }); + device_context->CopyCPUTensorToDevice(&parsed, this, ©, + [&n, &status](const Status& s) { + status = s; + n.Notify(); + }); n.WaitForNotification(); *tensor = copy; } diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 02e88ee6793e984a7b782790f8011cbcbc5a5026..d8906419b0c406026bb7e10007b2f0a2b4832d01 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -25,10 +25,12 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ +#include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/allocator.h" @@ -39,6 +41,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace tensorflow { @@ -57,7 +60,7 @@ class XlaDevice : public LocalDevice { Metadata(int device_ordinal, se::Platform* platform, const DeviceType& device_type, XlaCompiler::ShapeRepresentationFn shape_representation_fn, - PaddedShapeFn padded_shape_fn); + PaddedShapeFn padded_shape_fn, bool use_multiple_streams); // The index of the device on this host. int device_ordinal() const; @@ -70,12 +73,15 @@ class XlaDevice : public LocalDevice { } const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; } + bool UseMultipleStreams() const { return use_multiple_streams_; } + private: const int device_ordinal_; const DeviceType device_type_; se::Platform* platform_; // Not owned. XlaCompiler::ShapeRepresentationFn shape_representation_fn_; PaddedShapeFn padded_shape_fn_; + const bool use_multiple_streams_; TF_DISALLOW_COPY_AND_ASSIGN(Metadata); }; @@ -89,6 +95,8 @@ class XlaDevice : public LocalDevice { // 'transfer_as_literal' is true if device<->host transfers must be done using // XLA's TransferLiteral{To,From}Device interface. If false, we can use // ThenMemcpy instead. + // If 'use_multiple_streams' is true, we create separate streams for + // host-to-device and device-to-host communication. // If padded_shape_fn is empty, a default implementation that returns // the on-host shape is used. static Status Create( @@ -96,7 +104,7 @@ class XlaDevice : public LocalDevice { int device_ordinal, const string& jit_device_name, const SessionOptions& options, const string& name_prefix, const XlaOpRegistry::DeviceRegistration& registration, - bool transfer_as_literal, + bool transfer_as_literal, bool use_multiple_streams, const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, const PaddedShapeFn& padded_shape_fn, std::unique_ptr* device); @@ -106,54 +114,90 @@ class XlaDevice : public LocalDevice { XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, se::Platform* platform, bool transfer_as_literal, + bool use_multiple_streams, const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, const PaddedShapeFn& padded_shape_fn); ~XlaDevice() override; - Allocator* GetAllocator(AllocatorAttributes attr) override; + Allocator* GetAllocator(AllocatorAttributes attr) override + LOCKS_EXCLUDED(mu_); void Compute(OpKernel* op_kernel, OpKernelContext* context) override; void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) override; Status Sync() override { return Status::OK(); } Status FillContextMap(const Graph* graph, - DeviceContextMap* device_context_map) override; + DeviceContextMap* device_context_map) override + LOCKS_EXCLUDED(mu_); Status MakeTensorFromProto(const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, - Tensor* tensor) override; + Tensor* tensor) override LOCKS_EXCLUDED(mu_); - xla::LocalClient* client() const; const Metadata& metadata() { return xla_metadata_; } - xla::StatusOr GetStream(); - // If not already set, create and set GpuDeviceInfo. - // Not thread-safe - Status CreateAndSetGpuDeviceInfo(); + // Ensures the DeviceContext associated with this XlaDevice is created and + // valid (i.e. all streams are ok). If any state is not valid, a new + // DeviceContext will be created. + // + // TODO(b/111859745): The Eager context needs to call this method to recover + // from failures. + Status EnsureDeviceContextOk() LOCKS_EXCLUDED(mu_); + + // Instructs this XlaDevice to set a GpuDeviceInfo, which holds extra + // information for GPU and TPU devices. + Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_); private: + xla::LocalClient* client() const; + Allocator* GetAllocatorLocked(AllocatorAttributes attr) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + Status EnsureStreamOkLocked(xla::Backend* backend, const string& name, + xla::StreamPool::Ptr* stream, + bool* stream_was_changed) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + xla::StatusOr GetDeviceContextLocked() + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + mutex mu_; // The metadata of this XlaDevice. const Metadata xla_metadata_; // Which hardware device in the client's platform this XlaDevice controls. const int device_ordinal_; // The name of the device that is used to compile Ops for this XlaDevice. - DeviceType jit_device_name_; + const DeviceType jit_device_name_; + // The platform for this device. + se::Platform* const platform_; // Not owned. // Memory allocator associated with this device. - Allocator* xla_allocator_; // Not owned. - se::Platform* platform_; // Not owned. + Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr; // Not owned. // Stream associated with this device. Operations enqueued on this // stream are executed on the device. Operations include data // copying back and forth between CPU and the device, and // computations enqueued by XLA. - xla::Backend::StreamPtr stream_; + xla::StreamPool::Ptr stream_ GUARDED_BY(mu_); + // If false, only stream_ is valid and all computation and transfers use + // stream_. If true, computation is performed by stream_ and transfers are + // performed by host_to_device/device_to_host_stream. + const bool use_multiple_streams_; + // If use_multiple_streams_, host to device transfers are performed using this + // stream. + xla::StreamPool::Ptr host_to_device_stream_ GUARDED_BY(mu_); + // If use_multiple_streams_, device to host transfers are performed using this + // stream. + xla::StreamPool::Ptr device_to_host_stream_ GUARDED_BY(mu_); // Must we use XLA's transfer manager for correct host<->device transfers? if // false, we can use ThenMemcpy() instead. - bool transfer_as_literal_; - XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + const bool transfer_as_literal_; + const XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + + // The device context accessed by all users of the XlaDevice, set by calls to + // EnsureDeviceContextOk. If gpu_device_info_ is non-null, this pointer is + // also filled in to that struct. XlaDeviceContext is a ref-counted object. + XlaDeviceContext* device_context_ GUARDED_BY(mu_) = nullptr; - // If set, holds default device context (that we must Unref) - // and its stream. - std::unique_ptr gpu_device_info_; + // Holds extra information for GPU and TPU devices, e.g. the device context. + bool use_gpu_device_info_ GUARDED_BY(mu_) = false; + std::unique_ptr gpu_device_info_ GUARDED_BY(mu_); }; // Builds OpKernel registrations on 'device' for the JIT operators diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index e20f5aa83766ccbdf4c19269cfbb00f9e077c2ef..0100bf51ed2a66f6d110dacd30bcdf9f48a8f64f 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -48,17 +48,24 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) { void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } XlaTransferManager::XlaTransferManager( - se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + se::Stream* compute_stream, se::Stream* host_to_device_stream, + se::Stream* device_to_host_stream, xla::LocalClient* client, + bool transfer_as_literal, XlaCompiler::ShapeRepresentationFn shape_representation_fn) - : stream_(stream), + : stream_(compute_stream), + host_to_device_stream_(host_to_device_stream), + device_to_host_stream_(device_to_host_stream), client_(client), transfer_manager_(client->backend().transfer_manager()), transfer_as_literal_(transfer_as_literal), shape_representation_fn_(std::move(shape_representation_fn)) { + CHECK(host_to_device_stream_ != nullptr); + CHECK(device_to_host_stream_ != nullptr); + CHECK(stream_ != nullptr); if (!shape_representation_fn_) { - shape_representation_fn_ = [](const TensorShape& shape, DataType dtype) { - return shape; - }; + shape_representation_fn_ = + [](const TensorShape& shape, + DataType dtype) -> xla::StatusOr { return shape; }; } } @@ -74,43 +81,47 @@ Status XlaTransferManager::TransferLiteralToDevice( auto literal = std::make_shared( static_cast(DMAHelper::base(&host_tensor)), xla_shape); - const xla::ShapedBuffer& shaped_buffer = - XlaTensor::FromTensor(device_tensor)->shaped_buffer(); + XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); + const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer(); VLOG(1) << "Transfer to device as literal: " << literal->ToString() << " " << shaped_buffer.ToString(); + if (UseMultipleStreams()) { + // Initially wait for the compute stream so that memory allocations are + // synchronized. + host_to_device_stream_->ThenWaitFor(stream_); + } TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync( - stream_, *literal, shaped_buffer)); + host_to_device_stream_, *literal, shaped_buffer)); + if (UseMultipleStreams()) { + se::Event event(stream_->parent()); + TF_RET_CHECK(event.Init()) << "Event failed to initialize!"; + host_to_device_stream_->ThenRecordEvent(&event); + xla_tensor->SetDefinedOn(host_to_device_stream_, std::move(event)); + } // Unref the host tensor, and capture the literal shared_ptr too so it goes // out of scope when the lambda completes. - stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); }); + host_to_device_stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); }); + return Status::OK(); } void XlaTransferManager::TransferLiteralFromDevice( Tensor* host_tensor, const Tensor& device_tensor, const StatusCallback& done) const { + xla::MutableBorrowingLiteral literal; + TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(host_tensor, &literal)); + const xla::ShapedBuffer& shaped_buffer = XlaTensor::FromTensor(&device_tensor)->shaped_buffer(); TensorReference ref(device_tensor); transfer_manager_->TransferLiteralFromDevice( - stream_, shaped_buffer, - [=, &shaped_buffer]( - xla::StatusOr > literal_or) { + device_to_host_stream_, shaped_buffer, literal, + [=, &shaped_buffer, &literal](xla::Status status) { ref.Unref(); done([&]() -> Status { - TF_ASSIGN_OR_RETURN(auto literal, std::move(literal_or)); - VLOG(1) << "Transfer from device as literal: " << literal->ToString() + VLOG(1) << "Transfer from device as literal: " << literal.ToString() << " " << shaped_buffer.ToString(); - Tensor tensor; - TF_RETURN_IF_ERROR( - LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor)); - // Reshape the tensor back to its declared shape. - Status status; - if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) { - status = errors::Internal( - "Tensor::CopyFrom failed when copying from XLA device to CPU"); - } return status; }()); }); @@ -120,62 +131,73 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const { - if (cpu_tensor->NumElements() > 0) { - VLOG(2) << "CopyCPUTensorToDevice " - << reinterpret_cast(cpu_tensor->tensor_data().data()) - << " " - << reinterpret_cast( - device_tensor->tensor_data().data()) - << " " << cpu_tensor->NumElements() << " " - << cpu_tensor->shape().DebugString() << " " - << device_tensor->shape().DebugString(); - - void* src_ptr = const_cast(DMAHelper::base(cpu_tensor)); - const int64 total_bytes = cpu_tensor->TotalBytes(); - - XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); - CHECK(xla_tensor); - - TensorShape shape = shape_representation_fn_(device_tensor->shape(), - device_tensor->dtype()); - Status status; - if (!xla_tensor->has_shaped_buffer()) { - status = xla_tensor->AllocateShapedBuffer( - device_tensor->dtype(), shape, client_, - stream_->parent()->device_ordinal()); - if (!status.ok()) { - return done(status); - } - } + if (cpu_tensor->NumElements() == 0) { + VLOG(2) << "CopyCPUTensorToDevice empty tensor"; + done(Status::OK()); + return; + } - if (transfer_as_literal_) { - Tensor reshaped_cpu_tensor; - if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) { - done(errors::Internal( - "Tensor::CopyFrom failed when copying from CPU to XLA device")); - return; - } - status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); - } else { - se::DeviceMemoryBase dev_dst_ptr = - XlaTensor::DeviceMemoryFromTensor(*device_tensor); - stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); - // TODO(hpucha): Make this asynchronous. - Status block_status = stream_->BlockHostUntilDone(); - if (!block_status.ok()) { - status = xla::InternalError( - "Failed to complete data transfer on stream %p: %s", stream_, - block_status.error_message().c_str()); - } - } - xla_tensor->set_host_tensor(*cpu_tensor); + VLOG(2) << "CopyCPUTensorToDevice " + << reinterpret_cast(cpu_tensor->tensor_data().data()) + << " " + << reinterpret_cast(device_tensor->tensor_data().data()) + << " " << cpu_tensor->NumElements() << " " + << cpu_tensor->shape().DebugString() << " " + << device_tensor->shape().DebugString(); + + void* src_ptr = const_cast(DMAHelper::base(cpu_tensor)); + const int64 total_bytes = cpu_tensor->TotalBytes(); + + XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); + CHECK(xla_tensor); - done(status); + xla::StatusOr shape_or_status = + shape_representation_fn_(device_tensor->shape(), device_tensor->dtype()); + if (!shape_or_status.ok()) { + done(shape_or_status.status()); return; } + TensorShape shape = shape_or_status.ValueOrDie(); + if (!xla_tensor->has_shaped_buffer()) { + Status s = + xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_, + stream_->parent()->device_ordinal()); + if (!s.ok()) { + done(s); + return; + } + } - VLOG(2) << "CopyCPUTensorToDevice empty tensor"; - done(Status::OK()); + Status status; + if (transfer_as_literal_) { + Tensor reshaped_cpu_tensor; + if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) { + done(errors::Internal( + "Tensor::CopyFrom failed when copying from CPU to XLA device")); + return; + } + status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); + if (status.ok()) { + xla_tensor->set_host_tensor(*cpu_tensor); + host_to_device_stream_->ThenDoHostCallback( + [done]() { done(Status::OK()); }); + return; + } + } else { + se::DeviceMemoryBase dev_dst_ptr = + XlaTensor::DeviceMemoryFromTensor(*device_tensor); + host_to_device_stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); + // TODO(hpucha): Make this asynchronous. + Status block_status = host_to_device_stream_->BlockHostUntilDone(); + if (!block_status.ok()) { + status = xla::InternalError( + "Failed to complete data transfer on stream %p: %s", + host_to_device_stream_, block_status.error_message().c_str()); + } + } + xla_tensor->set_host_tensor(*cpu_tensor); + + done(status); } void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, @@ -183,68 +205,102 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, Device* device, Tensor* cpu_tensor, StatusCallback done) { - if (device_tensor->NumElements() > 0) { - VLOG(2) << "CopyDeviceTensorToCPU " - << reinterpret_cast( - device_tensor->tensor_data().data()) - << " " - << reinterpret_cast(cpu_tensor->tensor_data().data()) - << " " << device_tensor->NumElements() << " " - << cpu_tensor->shape().DebugString() << " " - << device_tensor->shape().DebugString(); - - const int64 total_bytes = cpu_tensor->TotalBytes(); - se::DeviceMemoryBase dev_src_ptr = - XlaTensor::DeviceMemoryFromTensor(*device_tensor); - void* dst_ptr = DMAHelper::base(cpu_tensor); + if (device_tensor->NumElements() == 0) { + VLOG(2) << "CopyDeviceTensorToCPU empty tensor"; + done(Status::OK()); + return; + } + VLOG(2) << "CopyDeviceTensorToCPU " + << reinterpret_cast(device_tensor->tensor_data().data()) + << " " + << reinterpret_cast(cpu_tensor->tensor_data().data()) + << " " << device_tensor->NumElements() << " " + << cpu_tensor->shape().DebugString() << " " + << device_tensor->shape().DebugString(); + + const int64 total_bytes = cpu_tensor->TotalBytes(); + se::DeviceMemoryBase dev_src_ptr = + XlaTensor::DeviceMemoryFromTensor(*device_tensor); + void* dst_ptr = DMAHelper::base(cpu_tensor); + XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); + + if (se::Event* event = + xla_tensor->GetDefinitionEvent(device_to_host_stream_)) { + device_to_host_stream_->ThenWaitFor(event); + xla_tensor->SetDefinedOn(device_to_host_stream_); + } - Status status; - if (transfer_as_literal_) { - TransferLiteralFromDevice(cpu_tensor, *device_tensor, done); - return; - } else { - stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); - // TODO(hpucha): Make this asynchronous. - Status block_status = stream_->BlockHostUntilDone(); - if (!block_status.ok()) { - status = xla::InternalError( - "Failed to complete data transfer on stream %p: %s", stream_, - block_status.error_message().c_str()); - } - done(status); - } + Status status; + if (transfer_as_literal_) { + TransferLiteralFromDevice(cpu_tensor, *device_tensor, done); return; + } else { + device_to_host_stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); + // TODO(hpucha): Make this asynchronous. + Status block_status = device_to_host_stream_->BlockHostUntilDone(); + if (!block_status.ok()) { + status = xla::InternalError( + "Failed to complete data transfer on stream %p: %s", stream_, + block_status.error_message().c_str()); + } } - VLOG(2) << "CopyDeviceTensorToCPU empty tensor"; - done(Status::OK()); + done(status); } void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, const StatusCallback& done) { + VLOG(2) << "CopyDeviceTensorToDevice " + << reinterpret_cast(src_tensor.tensor_data().data()) + << " " + << reinterpret_cast(dst_tensor->tensor_data().data()); // Perform memory allocation now, and enqueue the device-to-device transfer. Status status = [&]() -> Status { if (src_tensor.NumElements() == 0) { return Status::OK(); } + // TODO(jmolloy): We co-opt the device_to_host stream for device to device + // transfers; perhaps we should have a dedicated device to device stream? or + // one per device? + auto device_to_device_stream = stream_; XlaTensor* xla_src = XlaTensor::FromTensor(&src_tensor); XlaTensor* xla_dst = XlaTensor::FromTensor(dst_tensor); CHECK(xla_src && xla_dst) << "Missing destination tensor for device-to-device copy"; if (!xla_dst->has_shaped_buffer()) { - TensorShape shape = - shape_representation_fn_(src_tensor.shape(), src_tensor.dtype()); + TF_ASSIGN_OR_RETURN( + TensorShape shape, + shape_representation_fn_(src_tensor.shape(), src_tensor.dtype())); TF_RETURN_IF_ERROR( xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_, stream_->parent()->device_ordinal())); + if (stream_ != device_to_device_stream) { + // Initially wait for the compute stream so that memory allocations are + // synchronized. + device_to_device_stream->ThenWaitFor(stream_); + } } + + if (se::Event* event = + xla_src->GetDefinitionEvent(device_to_device_stream)) { + device_to_device_stream->ThenWaitFor(event); + xla_src->SetDefinedOn(device_to_device_stream); + } + auto from_iter = xla_src->shaped_buffer().buffers().begin(); auto to_iter = xla_dst->shaped_buffer().buffers().begin(); for (auto end_iter = xla_src->shaped_buffer().buffers().end(); from_iter != end_iter; ++from_iter, ++to_iter) { - stream_->ThenMemcpyD2D(&to_iter->second, from_iter->second, - to_iter->second.size()); + device_to_device_stream->ThenMemcpyD2D( + &to_iter->second, from_iter->second, to_iter->second.size()); + } + + if (UseMultipleStreams()) { + se::Event event(stream_->parent()); + CHECK(event.Init()); + device_to_device_stream->ThenRecordEvent(&event); + xla_dst->SetDefinedOn(device_to_device_stream, std::move(event)); } return Status::OK(); }(); @@ -256,9 +312,12 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, } XlaDeviceContext::XlaDeviceContext( - se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + se::Stream* compute_stream, se::Stream* host_to_device_stream, + se::Stream* device_to_host_stream, xla::LocalClient* client, + bool transfer_as_literal, XlaCompiler::ShapeRepresentationFn shape_representation_fn) - : manager_(stream, client, transfer_as_literal, + : manager_(compute_stream, host_to_device_stream, device_to_host_stream, + client, transfer_as_literal, std::move(shape_representation_fn)) {} void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index c5c81d65fe0f4a2774aab9f742454467e052071e..912f8d779e72f44821bc4fb25efa30bd35d01412 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -47,7 +47,9 @@ class XlaDeviceAllocator : public Allocator { class XlaTransferManager { public: explicit XlaTransferManager( - se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + se::Stream* compute_stream, se::Stream* host_to_device_stream, + se::Stream* device_to_host_stream, xla::LocalClient* client, + bool transfer_as_literal, XlaCompiler::ShapeRepresentationFn shape_representation_fn); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, @@ -67,10 +69,17 @@ class XlaTransferManager { void TransferLiteralFromDevice(Tensor* host_tensor, const Tensor& device_tensor, const StatusCallback& done) const; + bool UseMultipleStreams() const { return stream_ != host_to_device_stream_; } - // Stream obtained from a Device, used to transfer tensors between - // CPU and device. + // The main compute stream of the device, used to synchronize the transfer + // streams if they are set. se::Stream* stream_; + // The stream to use for transferring data from host to device. Can be + // idential to stream_, but must not be nullptr. + se::Stream* host_to_device_stream_; + // The stream to use for transferring data from device to host. Can be + // idential to stream_, but must not be nullptr. + se::Stream* device_to_host_stream_; // For the underlying memory allocator and XLA's TransferManager. xla::LocalClient* client_; // Transfer manager, for marshalling data to and from the device. @@ -86,7 +95,9 @@ class XlaTransferManager { class XlaDeviceContext : public DeviceContext { public: explicit XlaDeviceContext( - se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + se::Stream* compute_stream, se::Stream* host_to_device_stream, + se::Stream* device_to_host_stream, xla::LocalClient* client, + bool transfer_as_literal, XlaCompiler::ShapeRepresentationFn shape_representation_fn); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 11e45d2823da2b623bd3cd45f7147686b05fdb2f..da3e329247e825d4a33a53dc310899d6ba6ce9cf 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -23,9 +23,15 @@ limitations under the License. #include "tensorflow/core/kernels/cast_op.h" #include "tensorflow/core/kernels/constant_op.h" #include "tensorflow/core/kernels/control_flow_ops.h" +#include "tensorflow/core/kernels/data/generator_dataset_op.h" +#include "tensorflow/core/kernels/data/iterator_ops.h" +#include "tensorflow/core/kernels/data/prefetch_dataset_op.h" +#include "tensorflow/core/kernels/fifo_queue.h" +#include "tensorflow/core/kernels/function_ops.h" #include "tensorflow/core/kernels/identity_n_op.h" #include "tensorflow/core/kernels/identity_op.h" #include "tensorflow/core/kernels/no_op.h" +#include "tensorflow/core/kernels/queue_op.h" #include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/kernels/sendrecv_ops.h" #include "tensorflow/core/kernels/shape_ops.h" @@ -75,9 +81,7 @@ class XlaAssignVariableOp : public AsyncOpKernel { ConstantOp); \ REGISTER_KERNEL_BUILDER( \ Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("IdentityN").Device(DEVICE).TypeConstraint("T", TYPES), \ - IdentityNOp); \ + REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE), IdentityNOp); \ REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), PlaceholderOp); \ REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \ PlaceholderOp); \ @@ -88,6 +92,9 @@ class XlaAssignVariableOp : public AsyncOpKernel { REGISTER_KERNEL_BUILDER( \ Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \ ReadVariableOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("DestroyResourceOp").Device(DEVICE).HostMemory("resource"), \ + DestroyResourceOp); \ REGISTER_KERNEL_BUILDER(Name("Shape") \ .Device(DEVICE) \ .HostMemory("output") \ @@ -145,7 +152,94 @@ class XlaAssignVariableOp : public AsyncOpKernel { .Device(DEVICE) \ .HostMemory("input") \ .HostMemory("output"), \ - LoopCondOp); + LoopCondOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueEnqueueV2").Device(DEVICE).HostMemory("handle"), EnqueueOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueDequeueV2").Device(DEVICE).HostMemory("handle"), DequeueOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueCloseV2").Device(DEVICE).HostMemory("handle"), QueueCloseOp); \ + REGISTER_KERNEL_BUILDER(Name("QueueSizeV2") \ + .Device(DEVICE) \ + .HostMemory("size") \ + .HostMemory("handle"), \ + QueueSizeOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("QueueIsClosedV2").Device(DEVICE).HostMemory("handle"), \ + QueueIsClosedOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name(kArgOp).Device(DEVICE).HostMemory("output").TypeConstraint("T", \ + TYPES), \ + ArgOp); \ + REGISTER_KERNEL_BUILDER(Name(kArgOp) \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + ArgOp); \ + \ + REGISTER_KERNEL_BUILDER(Name(kRetOp) \ + .Device(DEVICE) \ + .TypeConstraint("T", TYPES) \ + .HostMemory("input"), \ + RetvalOp); \ + REGISTER_KERNEL_BUILDER(Name(kRetOp) \ + .Device(DEVICE) \ + .TypeConstraint("T") \ + .HostMemory("input"), \ + RetvalOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("RemoteCall").Device(DEVICE).HostMemory("target"), RemoteCallOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("GeneratorDataset").Device(DEVICE).HostMemory("handle"), \ + GeneratorDatasetOp); \ + REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") \ + .Device(DEVICE) \ + .HostMemory("buffer_size") \ + .HostMemory("input_dataset") \ + .HostMemory("handle"), \ + PrefetchDatasetOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE), \ + IteratorHandleOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("MakeIterator").Device(DEVICE).HostMemory("dataset"), \ + MakeIteratorOp); \ + REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \ + AnonymousIteratorHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \ + IteratorGetNextOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \ + .Device(DEVICE) \ + .HostMemory("string_handle"), \ + IteratorToStringHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2") \ + .Device(DEVICE) \ + .HostMemory("string_handle"), \ + IteratorFromStringHandleOp); \ + REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp) \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + ArgOp); \ + REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kRetOp) \ + .Device(DEVICE) \ + .TypeConstraint("T") \ + .HostMemory("input"), \ + RetvalOp); + +// TODO(phawkins): currently we do not register the QueueEnqueueMany, +// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read +// and write the tensors they access in order to concatenate them into a batch. +// We would need either to call out to an XLA computation to perform the +// concatenation, or we would need to refactor those kernels so the splitting +// or merging is done in a separate operator that can be compiled. } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc index 74257b09a808a39454eace3b1a9bf57a2e071360..4b499b161371ecece14447b29fbf809b6e8857db 100644 --- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/union_find.h" @@ -146,6 +147,9 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, TF_RETURN_IF_ERROR( ImportGraphDef(options, item.graph, &graph, &shape_refiner)); + std::unique_ptr deadness; + TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(graph, &deadness)); + // Collect nodes that can be fused via XLA, while ignoring those that // explicitly ask for XLA: (*) nodes that are marked to be compiled // explicitly. (*) nodes assigned to XLA device. @@ -185,6 +189,14 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, continue; } + // If inputs to `node` can have conflicting deadness (i.e. some are alive + // and some are dead) then don't compile it. XLA cannot represent the + // deadness semantics of these nodes correctly and auto-clustering these + // nodes can cause deadness to propagate to nodes that should be live. + if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) { + continue; + } + compilation_candidates.insert(node); } diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index c0d86a28c7698c302e28bab972bb2f847cc00ca4..ef4466f0056ea98adc1ae6774105466af0d14293 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -49,6 +49,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, name_prefix, registration, /*transfer_as_literal=*/false, + /*use_multiple_streams=*/false, /*shape_representation_fn=*/{}, /*padded_shape_fn=*/{}, &device); if (!status.ok()) { @@ -58,7 +59,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, } // TODO(b/78468222): Uncomment after fixing this bug - // status = device->CreateAndSetGpuDeviceInfo(); + // status = device->UseGpuDeviceInfo(); // if (!status.ok()) { // errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT, // " device"); diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 661187f4a873b03b8d013aa74cb6b6315bb4e2eb..45745596749207189c60ee1e3dcf19b6ecb7eb5b 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -52,6 +52,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices( DEVICE_INTERPRETER_XLA_JIT, options, name_prefix, registration, /*transfer_as_literal=*/false, + /*use_multiple_streams=*/false, /*shape_representation_fn=*/{}, /*padded_shape_fn=*/{}, &device)); devices->push_back(device.release()); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 5ceccc769fa2e95d4cf4d2b4ebd8dbf312ebdfd0..6134b8c6946429918a5ca37188cbff13a6cd1c79 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -64,11 +64,13 @@ xla::StatusOr XlaAllocator::Allocate( int device_ordinal, uint64 size, bool retry_on_failure) { AllocationAttributes attrs; attrs.no_retry_on_failure = !retry_on_failure; - void* data = - wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs); - if (data == nullptr) { - return errors::ResourceExhausted("Out of memory while trying to allocate ", - size, " bytes."); + void* data = nullptr; + if (size != 0) { + data = wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs); + if (data == nullptr) { + return errors::ResourceExhausted( + "Out of memory while trying to allocate ", size, " bytes."); + } } return xla::OwningDeviceMemory(se::DeviceMemoryBase(data, size), device_ordinal, this); @@ -115,14 +117,22 @@ using internal::ExtractSubShapedBuffer; XlaComputationLaunchContext::XlaComputationLaunchContext( xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, - bool allocate_xla_tensors) + bool allocate_xla_tensors, bool use_multiple_streams) : client_(client), xla_allocator_(xla_allocator), - allocate_xla_tensors_(allocate_xla_tensors) {} + allocate_xla_tensors_(allocate_xla_tensors), + use_multiple_streams_(use_multiple_streams) { + if (use_multiple_streams_) { + CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must " + "be allocating XLA tensors!"; + } +} void XlaComputationLaunchContext::PopulateInputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, const std::map& variables) { + se::Stream* stream = + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; // Build ShapedBuffers that point directly to the Tensor buffers. arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1); arg_buffers_.resize(kernel->xla_input_shapes.size()); @@ -140,6 +150,16 @@ void XlaComputationLaunchContext::PopulateInputs( t = &(ctx->input(arg_num)); } + if (use_multiple_streams_) { + CHECK(stream) << "Must have a stream available when using XLA tensors!"; + XlaTensor* xla_tensor = XlaTensor::FromTensor(t); + CHECK(xla_tensor); + if (se::Event* event = xla_tensor->GetDefinitionEvent(stream)) { + stream->ThenWaitFor(event); + xla_tensor->SetDefinedOn(stream); + } + } + const xla::Shape on_device_shape = client_->backend().transfer_manager()->HostShapeToDeviceShape(shape); if (xla::ShapeUtil::IsTuple(on_device_shape)) { @@ -248,6 +268,12 @@ void XlaComputationLaunchContext::PopulateOutputs( if (xla_tensor) { xla_tensor->set_shaped_buffer(ScopedShapedBuffer( ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); + if (use_multiple_streams_) { + se::Event event(stream->parent()); + CHECK(event.Init()); + stream->ThenRecordEvent(&event); + xla_tensor->SetDefinedOn(stream, std::move(event)); + } } else { // xla_tensor wasn't valid, which must mean this is a zero-element // tensor. @@ -302,6 +328,12 @@ void XlaComputationLaunchContext::PopulateOutputs( CHECK(xla_tensor); xla_tensor->set_shaped_buffer( ExtractSubShapedBuffer(&output, output_num, xla_allocator_)); + if (use_multiple_streams_) { + se::Event event(stream->parent()); + CHECK(event.Init()); + stream->ThenRecordEvent(&event); + xla_tensor->SetDefinedOn(stream, std::move(event)); + } *variable->tensor() = output_tensor; } else { Tensor output_tensor = XlaTensorBuffer::MakeTensor( diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 4390701ccbd0bc3971413ddcd917c11019990087..1ea3fa4cf29266e8c452385226e56bd0b82622d9 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -76,9 +76,15 @@ class XlaComputationLaunchContext { // Create a new launch context. 'allocate_xla_tensors' is true if allocated // output tensors and variables are always XlaTensors. If false they are // assumed to be "normal" device pointers. + // If 'use_multiple_streams' is true, tensors may be defined and used on + // multiple streams and so se::Events must be defined and waited for. If + // 'use_multiple_streams' is true, 'allocate_xla_tensors' must also be true + // because we track inter-stream dependencies through events inside XlaTensor + // objects. XlaComputationLaunchContext(xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, - bool allocate_xla_tensors); + bool allocate_xla_tensors, + bool use_multiple_streams); // Add all inputs within `ctx` as XLA arguments (returned by arguments()). // `variables` is a map from TensorFlow argument number to resource variable. @@ -99,6 +105,7 @@ class XlaComputationLaunchContext { xla::LocalClient* client_; xla::DeviceMemoryAllocator* xla_allocator_; bool allocate_xla_tensors_; + bool use_multiple_streams_; std::vector> arg_buffers_; std::vector arg_ptrs_; }; @@ -115,7 +122,11 @@ class XlaTensorBuffer : public TensorBuffer { data_ = const_cast(ptr); } - ~XlaTensorBuffer() override { allocator_->DeallocateRaw(data_); } + ~XlaTensorBuffer() override { + if (data_) { + allocator_->DeallocateRaw(data_); + } + } void* data() const override { return data_; } size_t size() const override { return expected_size_; } diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index 3c44c4ae6df7f3e2d60d8933561c0c71888e8c3f..d777dfa5a34fb9615ddcf393ed53be1491cb70af 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -73,6 +73,34 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, return Status::OK(); } +se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) { + mutex_lock lock(mu_); + if (!definition_event_.has_value()) { + return nullptr; + } + + // The set of defined streams is expected to be very small indeed (usually + // 1-2), so a simple linear scan should be fast enough. + if (std::find(streams_defined_on_.begin(), streams_defined_on_.end(), + stream) != streams_defined_on_.end()) { + // stream is in streams_defined_on_; it doesn't need to be waited on. + return nullptr; + } + + return &*definition_event_; +} + +void XlaTensor::SetDefinedOn(se::Stream* stream, se::Event event) { + mutex_lock lock(mu_); + definition_event_ = std::move(event); + streams_defined_on_ = {stream}; +} + +void XlaTensor::SetDefinedOn(se::Stream* stream) { + mutex_lock lock(mu_); + streams_defined_on_.push_back(stream); +} + // The pointer tag, OR-ed into the XlaTensor's address to distinguish it from // device-side tensors, which are either CPU or GPU memory pointers. This works // because we're guaranteed that CPU and GPU pointers are aligned to > 1 bits. diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index c54001a999998f45c0cdacd752ca4036f0792857..f7e401c731163200c518074f2caa6907efb1f684 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -85,6 +85,24 @@ class XlaTensor { host_tensor_.reset(new Tensor(tensor)); } + // If the tensor's content is not yet defined on 'stream', and there exists an + // se::Event declaring when the tensor's content is defined, return it. + // Otherwise, return nullptr. If this function returns nullptr then the + // tensor's content can be read on 'stream' without additional + // synchronization. + se::Event* GetDefinitionEvent(se::Stream* stream); + + // Assert that the tensor's content is defined on 'stream' by the time 'event' + // triggers. + void SetDefinedOn(se::Stream* stream, se::Event event); + + // Assert that the tensor's content is defined on 'stream'. This version does + // not provide an event, and must be called *after* SetDefinedOn(Stream, + // Event). This call can be read as an assertion that the definition event has + // been waited on by 'stream', so further calls to GetDefinitionEvent(stream) + // do not need to also wait on the event. + void SetDefinedOn(se::Stream* stream); + // Convert from a raw pointer to an XlaTensor, removing the pointer tag. static XlaTensor* FromOpaquePointer(void* ptr); // Convert to a raw pointer from an XlaTensor, adding the pointer tag. @@ -95,6 +113,14 @@ class XlaTensor { std::unique_ptr shaped_buffer_; // An optional host tensor value. std::unique_ptr host_tensor_; + // An optional event that is triggered when the tensor's content has been + // defined. If this event is nullptr, it is assumed that the tensor's content + // is always defined. + gtl::optional definition_event_; + // A list of all streams for which the tensor's content is defined for any + // newly enqueued command. + gtl::InlinedVector streams_defined_on_ GUARDED_BY(mu_); + mutex mu_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index c1f65416b44b7d8e014aba1338fc3ba5dd6fba5b..ae98b3f0f9d5dac66b9716ad84a9f0371511e9b6 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -70,6 +70,19 @@ py_test( ], ) +tf_xla_py_test( + name = "adadelta_test", + size = "medium", + srcs = ["adadelta_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "adagrad_test", size = "small", @@ -84,6 +97,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "adagrad_da_test", + size = "small", + srcs = ["adagrad_da_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "adam_test", size = "small", @@ -98,6 +124,48 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "adamax_test", + size = "small", + srcs = ["adamax_test.py"], + deps = [ + ":xla_test", + "//tensorflow/contrib/opt:opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:training", + ], +) + +tf_xla_py_test( + name = "addsign_test", + size = "small", + srcs = ["addsign_test.py"], + deps = [ + ":xla_test", + "//tensorflow/contrib/opt:opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:training", + ], +) + +tf_xla_py_test( + name = "powersign_test", + size = "small", + srcs = ["powersign_test.py"], + deps = [ + ":xla_test", + "//tensorflow/contrib/opt:opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "argminmax_test", size = "small", @@ -167,7 +235,7 @@ tf_xla_py_test( tf_xla_py_test( name = "cholesky_op_test", - size = "small", + size = "medium", srcs = ["cholesky_op_test.py"], tags = ["optonly"], deps = [ @@ -350,7 +418,7 @@ tf_xla_py_test( tf_xla_py_test( name = "eager_test", - size = "small", + size = "large", srcs = ["eager_test.py"], disabled_backends = [ # TODO(b/78199195) Support XLA CPU devices in eager runtime @@ -371,6 +439,20 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "fifo_queue_test", + size = "medium", + srcs = ["fifo_queue_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "fft_test", size = "medium", @@ -556,16 +638,60 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "proximal_adagrad_test", + size = "medium", + srcs = ["proximal_adagrad_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:training", + ], +) + +tf_xla_py_test( + name = "proximal_gradient_descent_test", + size = "medium", + srcs = ["proximal_gradient_descent_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:training", + ], +) + +tf_xla_py_test( + name = "qr_op_test", + size = "medium", + srcs = ["qr_op_test.py"], + disabled_backends = [ + # Test is very slow on CPU. + "cpu", + "cpu_ondemand", + ], + shard_count = 5, + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + "@absl_py//absl/testing:parameterized", + ], +) + tf_xla_py_test( name = "random_ops_test", size = "small", srcs = ["random_ops_test.py"], disabled_backends = [ - # TODO(b/110300529): RngNormal doesn't return values with the expected variance - "cpu", "cpu_ondemand", - # TODO(b/31361304): enable RNG ops on GPU when parallelized. - "gpu", ], deps = [ ":xla_test", @@ -871,8 +997,11 @@ tf_xla_py_test( tf_xla_py_test( name = "sort_ops_test", - size = "small", + size = "medium", srcs = ["sort_ops_test.py"], + shard_count = 5, + # Times out in fastbuild mode. + tags = ["optonly"], deps = [ "//tensorflow/compiler/tests:xla_test", "//tensorflow/compiler/tf2xla/python:xla", diff --git a/tensorflow/compiler/tests/adadelta_test.py b/tensorflow/compiler/tests/adadelta_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3e3c09c66e72c4de141b64cea3c4693fabb7b2a2 --- /dev/null +++ b/tensorflow/compiler/tests/adadelta_test.py @@ -0,0 +1,134 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Adadelta Optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import adadelta + + +class AdadeltaOptimizerTest(xla_test.XLATestCase): + + def testBasic(self): + num_updates = 4 # number of ADADELTA steps to perform + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + for grad in [0.2, 0.1, 0.01]: + for lr in [1.0, 0.5, 0.1]: + var0_init = [1.0, 2.0] + var1_init = [3.0, 4.0] + var0 = resource_variable_ops.ResourceVariable( + var0_init, dtype=dtype) + var1 = resource_variable_ops.ResourceVariable( + var1_init, dtype=dtype) + + grads = constant_op.constant([grad, grad], dtype=dtype) + + accum = 0.0 + accum_update = 0.0 + + # ADADELTA gradient optimizer + rho = 0.95 + epsilon = 1e-8 + adadelta_opt = adadelta.AdadeltaOptimizer( + learning_rate=lr, rho=rho, epsilon=epsilon) + adadelta_update = adadelta_opt.apply_gradients( + zip([grads, grads], [var0, var1])) + self.evaluate(variables.global_variables_initializer()) + opt_vars = adadelta_opt.variables() + self.assertStartsWith(opt_vars[0].name, var0._shared_name) + self.assertStartsWith(opt_vars[1].name, var0._shared_name) + self.assertStartsWith(opt_vars[2].name, var1._shared_name) + self.assertStartsWith(opt_vars[3].name, var1._shared_name) + self.assertEqual(4, len(opt_vars)) + # Assign slots + slot = [None] * 2 + slot_update = [None] * 2 + self.assertEqual(["accum", "accum_update"], + adadelta_opt.get_slot_names()) + slot[0] = adadelta_opt.get_slot(var0, "accum") + self.assertEquals(slot[0].get_shape(), var0.get_shape()) + self.assertFalse(slot[0] in variables.trainable_variables()) + + slot_update[0] = adadelta_opt.get_slot(var0, "accum_update") + self.assertEquals(slot_update[0].get_shape(), var0.get_shape()) + self.assertFalse(slot_update[0] in variables.trainable_variables()) + + slot[1] = adadelta_opt.get_slot(var1, "accum") + self.assertEquals(slot[1].get_shape(), var1.get_shape()) + self.assertFalse(slot[1] in variables.trainable_variables()) + + slot_update[1] = adadelta_opt.get_slot(var1, "accum_update") + self.assertEquals(slot_update[1].get_shape(), var1.get_shape()) + self.assertFalse(slot_update[1] in variables.trainable_variables()) + + # Fetch params to validate initial values + self.assertAllClose(var0_init, self.evaluate(var0)) + self.assertAllClose(var1_init, self.evaluate(var1)) + + update = [None] * num_updates + tot_update = 0 + for step in range(num_updates): + # Run adadelta update for comparison + self.evaluate(adadelta_update) + + # Perform initial update without previous accum values + accum = accum * rho + (grad**2) * (1 - rho) + update[step] = ( + np.sqrt(accum_update + epsilon) * + (1. / np.sqrt(accum + epsilon)) * grad) + accum_update = ( + accum_update * rho + (update[step]**2) * (1.0 - rho)) + tot_update += update[step] * lr + + # Check that the accumulators have been updated + for slot_idx in range(2): + self.assertAllCloseAccordingToType( + np.array([accum, accum], dtype=dtype), + self.evaluate(slot[slot_idx]), + rtol=1e-5) + + self.assertAllCloseAccordingToType( + np.array([accum_update, accum_update], dtype=dtype), + self.evaluate(slot_update[slot_idx]), + rtol=1e-5) + + # Check that the parameters have been updated + self.assertAllCloseAccordingToType( + np.array( + [var0_init[0] - tot_update, var0_init[1] - tot_update], + dtype=dtype), + self.evaluate(var0), + rtol=1e-5) + + self.assertAllCloseAccordingToType( + np.array( + [var1_init[0] - tot_update, var1_init[1] - tot_update], + dtype=dtype), + self.evaluate(var1), + rtol=1e-5) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/adagrad_da_test.py b/tensorflow/compiler/tests/adagrad_da_test.py new file mode 100644 index 0000000000000000000000000000000000000000..dc1625793aa44b96d3b96e175237caf96e7d7e74 --- /dev/null +++ b/tensorflow/compiler/tests/adagrad_da_test.py @@ -0,0 +1,165 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for AdagradDA optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import adagrad_da + + +class AdagradDAOptimizerTest(xla_test.XLATestCase): + + def testAdagradDAWithoutRegularizationBasic1(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + global_step = resource_variable_ops.ResourceVariable( + 0, dtype=dtypes.int64) + var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.02], dtype=dtype) + opt = adagrad_da.AdagradDAOptimizer( + 3.0, + global_step, + initial_gradient_squared_accumulator_value=0.1, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0) + update = opt.apply_gradients( + zip([grads0, grads1], [var0, var1]), global_step=global_step) + variables.global_variables_initializer().run() + + self.assertAllClose([0.0, 0.0], var0.eval()) + self.assertAllClose([0.0, 0.0], var1.eval()) + + # Run a step of AdagradDA + update.run() + + # Let g to be gradient accumulator, gg to be gradient squared + # accumulator, T be the global step, lr is the learning rate, and k the + # initial gradient squared accumulator value. + # w = \dfrac{sign(-g)*lr*|g - l1*T|_{+}}{l2*T*lr + \sqrt{k+gg})} + # For -0.1*3.0*(0.1 - 0)/(0 + sqrt(0.1 + 0.1*0.1)) = -0.904534 + # similarly for others. + self.assertAllCloseAccordingToType( + np.array([-0.904534, -1.603567]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([-0.094821, -0.189358]), var1.eval()) + + def testAdagradDAwithoutRegularizationBasic2(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + global_step = resource_variable_ops.ResourceVariable( + 0, dtype=dtypes.int64) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.02], dtype=dtype) + + opt = adagrad_da.AdagradDAOptimizer( + 3.0, + global_step, + initial_gradient_squared_accumulator_value=0.1, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0) + update = opt.apply_gradients( + zip([grads0, grads1], [var0, var1]), global_step=global_step) + variables.global_variables_initializer().run() + + self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) + self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + + # Run a step of AdagradDA + update.run() + + self.assertAllCloseAccordingToType( + np.array([-0.904534, -1.603567]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([-0.094821, -0.189358]), var1.eval()) + + def testAdagradDAWithL1(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + global_step = resource_variable_ops.ResourceVariable( + 0, dtype=dtypes.int64) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.02], dtype=dtype) + + opt = adagrad_da.AdagradDAOptimizer( + 3.0, + global_step, + initial_gradient_squared_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=0.0) + update = opt.apply_gradients( + zip([grads0, grads1], [var0, var1]), global_step=global_step) + variables.global_variables_initializer().run() + + self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) + self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + + # Run a step of AdagradDA + update.run() + + self.assertAllCloseAccordingToType( + np.array([-0.895489, -1.59555]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([-0.085339, -0.17989]), var1.eval()) + + def testAdagradDAWithL1_L2(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + global_step = resource_variable_ops.ResourceVariable( + 0, dtype=dtypes.int64) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.02], dtype=dtype) + + opt = adagrad_da.AdagradDAOptimizer( + 3.0, + global_step, + initial_gradient_squared_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=2.0) + update = opt.apply_gradients( + zip([grads0, grads1], [var0, var1]), global_step=global_step) + variables.global_variables_initializer().run() + + self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) + self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + + # Run a step of AdagradDA + update.run() + + self.assertAllCloseAccordingToType( + np.array([-0.046907, -0.093659]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([-0.004275, -0.009023]), var1.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py index 9a93b3216404d8ed21fd6c57757bec1730c119b4..d775850a80e9f83f7b2c9f1cf8997dd50e229635 100644 --- a/tensorflow/compiler/tests/adagrad_test.py +++ b/tensorflow/compiler/tests/adagrad_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables @@ -28,7 +28,7 @@ from tensorflow.python.platform import test from tensorflow.python.training import adagrad -class AdagradOptimizerTest(XLATestCase): +class AdagradOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py index 3215dc36e5b2d517aa951db1b0d41188185ef93a..0d2e4d029636577adc74784d9a8b3494b94dc67d 100644 --- a/tensorflow/compiler/tests/adam_test.py +++ b/tensorflow/compiler/tests/adam_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops @@ -48,10 +48,13 @@ def adam_update_numpy(param, return param_t, m_t, v_t -class AdamOptimizerTest(XLATestCase): +class AdamOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: + # TODO: test fails for float16 due to excessive precision requirements. + if dtype == np.float16: + continue with self.test_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) @@ -91,6 +94,9 @@ class AdamOptimizerTest(XLATestCase): def testTensorLearningRate(self): for dtype in self.float_types: + # TODO: test fails for float16 due to excessive precision requirements. + if dtype == np.float16: + continue with self.test_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) @@ -130,6 +136,9 @@ class AdamOptimizerTest(XLATestCase): def testSharing(self): for dtype in self.float_types: + # TODO: test fails for float16 due to excessive precision requirements. + if dtype == np.float16: + continue with self.test_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) diff --git a/tensorflow/compiler/tests/adamax_test.py b/tensorflow/compiler/tests/adamax_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c4fdbc5974319db9243eb2c323746cbaaea795f6 --- /dev/null +++ b/tensorflow/compiler/tests/adamax_test.py @@ -0,0 +1,139 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for AdaMax optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.contrib.opt.python.training import adamax +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def adamax_update_numpy(param, + g_t, + t, + m, + v, + alpha=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8): + m_t = beta1 * m + (1 - beta1) * g_t + v_t = np.maximum(beta2 * v, np.abs(g_t)) + param_t = param - (alpha / (1 - beta1**t)) * (m_t / (v_t + epsilon)) + return param_t, m_t, v_t + + +class AdaMaxOptimizerTest(xla_test.XLATestCase): + + def testBasic(self): + for i, dtype in enumerate(self.float_types): + with self.test_session(), self.test_scope(): + variable_scope.get_variable_scope().set_use_resource(True) + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype) + + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + opt = adamax.AdaMaxOptimizer() + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + opt_variables = opt.variables() + beta1_power = opt._get_beta_accumulators() + self.assertTrue(beta1_power is not None) + self.assertIn(beta1_power, opt_variables) + + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power = opt._get_beta_accumulators() + + # Run 3 steps of AdaMax + for t in range(1, 4): + update.run() + + self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval()) + + var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval(), rtol=1e-2) + self.assertAllCloseAccordingToType(var1_np, var1.eval(), rtol=1e-2) + self.assertEqual("var0_%d/AdaMax:0" % (i,), + opt.get_slot(var=var0, name="m").name) + + def testTensorLearningRate(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + variable_scope.get_variable_scope().set_use_resource(True) + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype) + + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = adamax.AdaMaxOptimizer(constant_op.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power = opt._get_beta_accumulators() + + # Run 3 steps of AdaMax + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + update.run() + + var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/addsign_test.py b/tensorflow/compiler/tests/addsign_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9ec5a964cbb4dd98d2ef2d0b684872292118800f --- /dev/null +++ b/tensorflow/compiler/tests/addsign_test.py @@ -0,0 +1,142 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for AddSign.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.contrib.opt.python.training import addsign +from tensorflow.contrib.opt.python.training import sign_decay +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def py_linear_decay_fn(decay_steps): + def linear_decay(step): + step = min(step, decay_steps) + return float(decay_steps - step) / decay_steps + return linear_decay + + +def addsign_update_numpy(params, + g_t, + m, + lr, + alpha=1.0, + beta=0.9, + py_sign_decay_fn=None, + t=None): + m_t = beta * m + (1 - beta) * g_t + if py_sign_decay_fn is None: + sign_decayed = 1.0 + else: + sign_decayed = py_sign_decay_fn(t-1) + multiplier = alpha + sign_decayed * np.sign(g_t) * np.sign(m_t) + params_t = params - lr * multiplier * g_t + return params_t, m_t + + +class AddSignTest(xla_test.XLATestCase): + + def _testDense(self, + learning_rate=0.1, + sign_decay_fn=None, + py_sign_decay_fn=None, + alpha=1.0, + beta=0.9): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + # Initialize variables for numpy implementation. + m0, m1 = 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype) + + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + global_step = resource_variable_ops.ResourceVariable(0, trainable=False) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + opt = addsign.AddSignOptimizer( + learning_rate=learning_rate, + alpha=alpha, + beta=beta, + sign_decay_fn=sign_decay_fn, + ) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + neg_update = opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 7 steps of AddSign + # first 4 steps with positive gradient + # last 3 steps with negative gradient (sign(gm) should be -1) + for t in range(1, 8): + if t < 5: + update.run() + else: + neg_update.run() + + var0_np, m0 = addsign_update_numpy( + var0_np, + grads0_np if t < 5 else -grads0_np, + m0, + learning_rate, + alpha=alpha, + beta=beta, + py_sign_decay_fn=py_sign_decay_fn, + t=t, + ) + var1_np, m1 = addsign_update_numpy( + var1_np, + grads1_np if t < 5 else -grads1_np, + m1, + learning_rate, + alpha=alpha, + beta=beta, + py_sign_decay_fn=py_sign_decay_fn, + t=t, + ) + + # Validate updated params + self.assertAllCloseAccordingToType( + var0_np, var0.eval(), half_rtol=1e-2) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testDense(self): + decay_steps = 10 + sign_decay_fn = sign_decay.get_linear_decay_fn(decay_steps) + py_sign_decay_fn = py_linear_decay_fn(decay_steps) + self._testDense() + self._testDense(learning_rate=0.01, alpha=0.1, beta=0.8) + self._testDense( + sign_decay_fn=sign_decay_fn, py_sign_decay_fn=py_sign_decay_fn) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index afef36d9d202a822d81ace185a9112bb83daae8c..0aafda7fb4d710f154157ee352d6616e5aa8935f 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops @@ -32,7 +32,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest -class BinaryOpsTest(XLATestCase): +class BinaryOpsTest(xla_test.XLATestCase): """Test cases for binary operators.""" def _testBinary(self, op, a, b, expected, equality_test=None): @@ -691,11 +691,13 @@ class BinaryOpsTest(XLATestCase): np.array([[10], [7], [2]], dtype=np.float32), np.float32(7), expected=np.array([[False], [False], [True]], dtype=np.bool)) - self._testBinary( - less_op, - np.array([[10], [7], [2], [-1]], dtype=np.int64), - np.int64(7), - expected=np.array([[False], [False], [True], [True]], dtype=np.bool)) + if np.int64 in self.numeric_types: + self._testBinary( + less_op, + np.array([[10], [7], [2], [-1]], dtype=np.int64), + np.int64(7), + expected=np.array( + [[False], [False], [True], [True]], dtype=np.bool)) for less_equal_op in [math_ops.less_equal, (lambda x, y: x <= y)]: self._testBinary( diff --git a/tensorflow/compiler/tests/bucketize_op_test.py b/tensorflow/compiler/tests/bucketize_op_test.py index fde9759a1c209844caac99d5f303cd3e406e5370..ef4d5f6322b7ae79b051795b5af7e6f7f1e55550 100644 --- a/tensorflow/compiler/tests/bucketize_op_test.py +++ b/tensorflow/compiler/tests/bucketize_op_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.ops import array_ops @@ -26,7 +26,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class BucketizationOpTest(XLATestCase): +class BucketizationOpTest(xla_test.XLATestCase): def testInt(self): with self.test_session() as sess: diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index 035cdea1786d39f3d21bb63be5c8ccffe1608bdf..a4e7f75081dfd07fd4b5c94c33908aab8e7d8aa9 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -22,7 +22,7 @@ import collections import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops @@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest # TODO(srvasude): Merge this with # third_party/tensorflow/python/kernel_tests/random/multinomial_op_test.py. -class CategoricalTest(XLATestCase): +class CategoricalTest(xla_test.XLATestCase): """Test cases for random-number generating operators.""" def output_dtypes(self): diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py index 1a8989d7c2f617525c301f30fd899a01362310bf..ed532db0ee5553a275192e6cc3ebf394075fa0e1 100644 --- a/tensorflow/compiler/tests/cholesky_op_test.py +++ b/tensorflow/compiler/tests/cholesky_op_test.py @@ -18,12 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import unittest - import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -32,7 +30,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class CholeskyOpTest(XLATestCase): +class CholeskyOpTest(xla_test.XLATestCase): # Cholesky defined for float64, float32, complex64, complex128 # (https://www.tensorflow.org/api_docs/python/tf/cholesky) @@ -103,9 +101,8 @@ class CholeskyOpTest(XLATestCase): with self.assertRaises(ValueError): linalg_ops.cholesky(tensor3) - @unittest.skip("Test is slow") - def testLarge(self): - n = 200 + def testLarge2000x2000(self): + n = 2000 shape = (n, n) data = np.ones(shape).astype(np.float32) / (2.0 * n) + np.diag( np.ones(n).astype(np.float32)) @@ -128,6 +125,5 @@ class CholeskyOpTest(XLATestCase): matrix = np.dot(np.dot(w, np.diag(v)), w.T).astype(dtype) self._verifyCholesky(matrix, atol=1e-4) - if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/clustering_test.py b/tensorflow/compiler/tests/clustering_test.py index 574f82fc717818334ac5d72ebef2191f1c18e669..e42ebf8f9e01dab13cde15979ffc42b7c0fbc57b 100644 --- a/tensorflow/compiler/tests/clustering_test.py +++ b/tensorflow/compiler/tests/clustering_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" -class ClusteringTest(XLATestCase): +class ClusteringTest(xla_test.XLATestCase): def testAdd(self): val1 = np.array([4, 3, 2, 1], dtype=np.float32) diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index f10973e19f1945515b776cf86349445ed7334629..d9ad4281477e87f79f2ecb52989ae86a5030d0cc 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -30,7 +30,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class ConcatTest(XLATestCase): +class ConcatTest(xla_test.XLATestCase): def testHStack(self): with self.test_session(): @@ -292,7 +292,7 @@ class ConcatTest(XLATestCase): array_ops.concat([scalar, scalar, scalar], dim) -class ConcatOffsetTest(XLATestCase): +class ConcatOffsetTest(xla_test.XLATestCase): def testBasic(self): with self.test_session() as sess: @@ -306,7 +306,7 @@ class ConcatOffsetTest(XLATestCase): self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]]) -class PackTest(XLATestCase): +class PackTest(xla_test.XLATestCase): def testBasic(self): with self.test_session() as sess: diff --git a/tensorflow/compiler/tests/conv2d_test.py b/tensorflow/compiler/tests/conv2d_test.py index d12e1ff1e8f4564f39642bd0b64fc40d8dca8ef0..f9db103f6d0f9ea0e393a0971593552ec5c14079 100644 --- a/tensorflow/compiler/tests/conv2d_test.py +++ b/tensorflow/compiler/tests/conv2d_test.py @@ -26,23 +26,20 @@ from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import test_utils -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest - DATA_FORMATS = ( ("_data_format_NHWC", "NHWC"), ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), ) -class Conv2DTest(XLATestCase, parameterized.TestCase): +class Conv2DTest(xla_test.XLATestCase, parameterized.TestCase): def _VerifyValues(self, input_sizes=None, @@ -236,7 +233,7 @@ class Conv2DTest(XLATestCase, parameterized.TestCase): expected=np.reshape([108, 128], [1, 1, 1, 2])) -class Conv2DBackpropInputTest(XLATestCase, parameterized.TestCase): +class Conv2DBackpropInputTest(xla_test.XLATestCase, parameterized.TestCase): def _VerifyValues(self, input_sizes=None, @@ -534,7 +531,7 @@ class Conv2DBackpropInputTest(XLATestCase, parameterized.TestCase): expected=[5, 0, 11, 0, 0, 0, 17, 0, 23]) -class Conv2DBackpropFilterTest(XLATestCase, parameterized.TestCase): +class Conv2DBackpropFilterTest(xla_test.XLATestCase, parameterized.TestCase): def _VerifyValues(self, input_sizes=None, diff --git a/tensorflow/compiler/tests/conv3d_test.py b/tensorflow/compiler/tests/conv3d_test.py index 3bebf46511cbc471d3fbbbe92d28511fcc717387..31ee41f04f27d387415e9fa2c4fa70b33cab7b04 100644 --- a/tensorflow/compiler/tests/conv3d_test.py +++ b/tensorflow/compiler/tests/conv3d_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -33,7 +33,7 @@ from tensorflow.python.platform import googletest # Test cloned from # tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py -class Conv3DBackpropFilterV2GradTest(XLATestCase): +class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase): def testGradient(self): with self.test_session(), self.test_scope(): @@ -66,7 +66,7 @@ class Conv3DBackpropFilterV2GradTest(XLATestCase): # Test cloned from tensorflow/python/kernel_tests/conv3d_transpose_test.py -class Conv3DTransposeTest(XLATestCase): +class Conv3DTransposeTest(xla_test.XLATestCase): def testConv3DTransposeSingleStride(self): with self.test_session(), self.test_scope(): diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py index 03d96a2cd8ab22a472a67f092e36224820405fa8..98dc73e189f99b7b811487756659d89dacb97d8a 100644 --- a/tensorflow/compiler/tests/depthwise_conv_op_test.py +++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -114,7 +114,7 @@ def CheckGradConfigsToTest(): yield i, f, o, s, p -class DepthwiseConv2DTest(XLATestCase): +class DepthwiseConv2DTest(xla_test.XLATestCase): # This is testing that depthwise_conv2d and depthwise_conv2d_native # produce the same results. It also tests that NCHW and NWHC diff --git a/tensorflow/compiler/tests/dynamic_slice_ops_test.py b/tensorflow/compiler/tests/dynamic_slice_ops_test.py index 6a46d2ec3e7aee3a4ecfbf1ab9f622d8eb659e3c..154e36b10e6da409606ae6022aaf53e34c8e37cc 100644 --- a/tensorflow/compiler/tests/dynamic_slice_ops_test.py +++ b/tensorflow/compiler/tests/dynamic_slice_ops_test.py @@ -20,14 +20,14 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class DynamicUpdateSliceOpsTest(XLATestCase): +class DynamicUpdateSliceOpsTest(xla_test.XLATestCase): def _assertOpOutputMatchesExpected(self, op, args, expected): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/dynamic_stitch_test.py b/tensorflow/compiler/tests/dynamic_stitch_test.py index c109c27abe2f145685f83251e1d21ec8ddad563a..edd78153b56bb5bf1c268936fb82a60581389733 100644 --- a/tensorflow/compiler/tests/dynamic_stitch_test.py +++ b/tensorflow/compiler/tests/dynamic_stitch_test.py @@ -20,14 +20,14 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.platform import googletest -class DynamicStitchTest(XLATestCase): +class DynamicStitchTest(xla_test.XLATestCase): def _AssertDynamicStitchResultIs(self, indices, data, expected): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index e438832a23a670596d12cbc67d71a9f561b82193..422f36d43bf38d26f057c18da716d7e281c286af 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import backprop from tensorflow.python.eager import context @@ -40,7 +40,7 @@ from tensorflow.python.platform import googletest from tensorflow.python.training import adam -class EagerTest(XLATestCase): +class EagerTest(xla_test.XLATestCase): def testBasic(self): with self.test_scope(): @@ -286,7 +286,7 @@ class EagerTest(XLATestCase): [2.0, 2.0]], embedding_matrix.numpy()) -class EagerFunctionTest(XLATestCase): +class EagerFunctionTest(xla_test.XLATestCase): def testBasic(self): with self.test_scope(): @@ -400,10 +400,25 @@ class EagerFunctionTest(XLATestCase): self.assertEqual(75, y.numpy()) self.assertEqual(30, dy.numpy()) + def testGradientTapeInDefun(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + + @function.defun + def f(): + x = constant_op.constant(1.0) + with backprop.GradientTape() as tape: + y = v0 * x + dy = tape.gradient(y, v0) + return dy + + dy = f() + self.assertEqual(1.0, dy.numpy()) + def testSliceInDefun(self): with self.test_scope(): - @function.defun(compiled=True) + @function.defun def f(x, y): return x[0::2, y:, ...] @@ -418,8 +433,24 @@ class EagerFunctionTest(XLATestCase): self.assertAllEqual(np.ones([1, 2, 4]), z.numpy()) self.assertAllEqual((2, 3, 4), dz.shape.as_list()) + def testNestedDefun(self): + self.skipTest('Nested defuns do not work on TPU at the moment') + with self.test_scope(): + + @function.defun + def times_two(x): + return 2 * x + + @function.defun + def two_x_plus_1(x): + return times_two(x) + 1 + + x = constant_op.constant([2, 3, 4]) + y = two_x_plus_1(x) + self.assertAllEqual([5, 7, 9], y.numpy()) + -class ExcessivePaddingTest(XLATestCase): +class ExcessivePaddingTest(xla_test.XLATestCase): """Test that eager execution works with TPU flattened tensors. Tensors that would normally be excessively padded when written @@ -470,6 +501,36 @@ class ExcessivePaddingTest(XLATestCase): self.assertAllEqual(100 * [[36.0]], reduced) +def multiple_tpus(): + devices = context.context().devices() + return len([d for d in devices if 'device:TPU:' in d]) > 1 + + +class MultiDeviceTest(xla_test.XLATestCase): + """Test running TPU computation on more than one core.""" + + def testBasic(self): + if not multiple_tpus(): + self.skipTest('MultiDeviceTest requires multiple TPU devices.') + + # Compute 10 on TPU core 0 + with ops.device('device:TPU:0'): + two = constant_op.constant(2) + five = constant_op.constant(5) + ten = two * five + self.assertAllEqual(10, ten) + + # Compute 6 on TPU core 1 + with ops.device('device:TPU:1'): + two = constant_op.constant(2) + three = constant_op.constant(3) + six = two * three + self.assertAllEqual(6, six) + + # Copy 10 and 6 to CPU and sum them + self.assertAllEqual(16, ten + six) + + if __name__ == '__main__': ops.enable_eager_execution( config=config_pb2.ConfigProto(log_device_placement=True)) diff --git a/tensorflow/compiler/tests/extract_image_patches_op_test.py b/tensorflow/compiler/tests/extract_image_patches_op_test.py index 0361702e7af778176daed941d64e61198090daf2..5529fdbb090315e1d7f47589777d8a538c90db2b 100644 --- a/tensorflow/compiler/tests/extract_image_patches_op_test.py +++ b/tensorflow/compiler/tests/extract_image_patches_op_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class ExtractImagePatches(XLATestCase): +class ExtractImagePatches(xla_test.XLATestCase): """Functional tests for ExtractImagePatches op.""" def _VerifyValues(self, image, ksizes, strides, rates, padding, patches): diff --git a/tensorflow/compiler/tests/fake_quant_ops_test.py b/tensorflow/compiler/tests/fake_quant_ops_test.py index dfe9400ef0f55ca011d4e23ba5d735899ca2e054..c48ab178bf53558084fb500b2811c6f0b77a7943 100644 --- a/tensorflow/compiler/tests/fake_quant_ops_test.py +++ b/tensorflow/compiler/tests/fake_quant_ops_test.py @@ -17,14 +17,14 @@ from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.platform import googletest -class FakeQuantWithMinMaxArgsTest(XLATestCase): +class FakeQuantWithMinMaxArgsTest(xla_test.XLATestCase): """Test cases for FakeQuantWithMinMaxArgs operation.""" # 8 bits, wide range. @@ -122,7 +122,7 @@ class FakeQuantWithMinMaxArgsTest(XLATestCase): result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03) -class FakeQuantWithMinMaxArgsGradientTest(XLATestCase): +class FakeQuantWithMinMaxArgsGradientTest(xla_test.XLATestCase): """Test cases for FakeQuantWithMinMaxArgsGradient operation.""" # 8 bits, wide range. @@ -223,7 +223,7 @@ class FakeQuantWithMinMaxArgsGradientTest(XLATestCase): bfloat16_rtol=0.03) -class FakeQuantWithMinMaxVarsTest(XLATestCase): +class FakeQuantWithMinMaxVarsTest(xla_test.XLATestCase): """Test cases for FakeQuantWithMinMaxVars operation.""" # 8 bits, wide range. @@ -328,7 +328,7 @@ class FakeQuantWithMinMaxVarsTest(XLATestCase): result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03) -class FakeQuantWithMinMaxVarsGradientTest(XLATestCase): +class FakeQuantWithMinMaxVarsGradientTest(xla_test.XLATestCase): """Test cases for FakeQuantWithMinMaxVarsGradient operation.""" # 8 bits, wide range. diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py index b2360dd009ba6e705efe51b5d4c408cbf1e13ac9..c64ea249ecb97991952a960a6d16e1bb3be35b17 100644 --- a/tensorflow/compiler/tests/fft_test.py +++ b/tensorflow/compiler/tests/fft_test.py @@ -23,7 +23,7 @@ import itertools import numpy as np import scipy.signal as sps -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.contrib.signal.python.ops import spectral_ops as signal from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -58,7 +58,7 @@ INNER_DIMS_2D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2)) INNER_DIMS_3D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2, POWS_OF_2)) -class FFTTest(XLATestCase): +class FFTTest(xla_test.XLATestCase): def _VerifyFftMethod(self, inner_dims, complex_to_input, input_to_expected, tf_method): diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0f64cc87cde77fbbef6c4e570879e992bc34bafa --- /dev/null +++ b/tensorflow/compiler/tests/fifo_queue_test.py @@ -0,0 +1,201 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.data_flow_ops.FIFOQueue.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes as dtypes_lib +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.platform import test + + +class FIFOQueueTest(xla_test.XLATestCase): + + def testEnqueue(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + enqueue_op = q.enqueue((10.0,)) + enqueue_op.run() + + def testEnqueueWithShape(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2)) + enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],)) + enqueue_correct_op.run() + with self.assertRaises(ValueError): + q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],)) + self.assertEqual(1, q.size().eval()) + + def testMultipleDequeues(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q.enqueue([1])) + self.evaluate(q.enqueue([2])) + self.evaluate(q.enqueue([3])) + a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()]) + self.assertAllEqual(set([1, 2, 3]), set([a, b, c])) + + def testQueuesDontShare(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q.enqueue(1)) + q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) + self.evaluate(q2.enqueue(2)) + self.assertAllEqual(self.evaluate(q2.dequeue()), 2) + self.assertAllEqual(self.evaluate(q.dequeue()), 1) + + def testEnqueueDictWithoutNames(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + with self.assertRaisesRegexp(ValueError, "must have names"): + q.enqueue({"a": 12.0}) + + def testParallelEnqueue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + # Run one producer thread for each element in elems. + def enqueue(enqueue_op): + sess.run(enqueue_op) + + threads = [ + self.checkedThread(target=enqueue, args=(e,)) for e in enqueue_ops + ] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # Dequeue every element using a single thread. + results = [] + for _ in xrange(len(elems)): + results.append(dequeued_t.eval()) + self.assertItemsEqual(elems, results) + + def testParallelDequeue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + # Enqueue every element using a single thread. + for enqueue_op in enqueue_ops: + enqueue_op.run() + + # Run one consumer thread for each element in elems. + results = [] + + def dequeue(): + results.append(sess.run(dequeued_t)) + + threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + self.assertItemsEqual(elems, results) + + def testDequeue(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + for enqueue_op in enqueue_ops: + enqueue_op.run() + + for i in xrange(len(elems)): + vals = dequeued_t.eval() + self.assertEqual([elems[i]], vals) + + def testEnqueueAndBlockingDequeue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32) + elems = [10.0, 20.0, 30.0] + enqueue_ops = [q.enqueue((x,)) for x in elems] + dequeued_t = q.dequeue() + + def enqueue(): + # The enqueue_ops should run after the dequeue op has blocked. + # TODO(mrry): Figure out how to do this without sleeping. + time.sleep(0.1) + for enqueue_op in enqueue_ops: + sess.run(enqueue_op) + + results = [] + + def dequeue(): + for _ in xrange(len(elems)): + results.append(sess.run(dequeued_t)) + + enqueue_thread = self.checkedThread(target=enqueue) + dequeue_thread = self.checkedThread(target=dequeue) + enqueue_thread.start() + dequeue_thread.start() + enqueue_thread.join() + dequeue_thread.join() + + for elem, result in zip(elems, results): + self.assertEqual([elem], result) + + def testMultiEnqueueAndDequeue(self): + with self.test_session() as sess, self.test_scope(): + q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32)) + elems = [(5, 10.0), (10, 20.0), (15, 30.0)] + enqueue_ops = [q.enqueue((x, y)) for x, y in elems] + dequeued_t = q.dequeue() + + for enqueue_op in enqueue_ops: + enqueue_op.run() + + for i in xrange(len(elems)): + x_val, y_val = sess.run(dequeued_t) + x, y = elems[i] + self.assertEqual([x], x_val) + self.assertEqual([y], y_val) + + def testQueueSizeEmpty(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + self.assertEqual([0], q.size().eval()) + + def testQueueSizeAfterEnqueueAndDequeue(self): + with self.test_session(), self.test_scope(): + q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) + enqueue_op = q.enqueue((10.0,)) + dequeued_t = q.dequeue() + size = q.size() + self.assertEqual([], size.get_shape()) + + enqueue_op.run() + self.assertEqual(1, size.eval()) + dequeued_t.op.run() + self.assertEqual(0, size.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index 8e6407dffdac3adbcda8cbca2109ef9196defa8c..1da97fd51217a0f28d4b3ba2ccfae3f6b094e65b 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables @@ -30,7 +30,7 @@ from tensorflow.python.training import ftrl from tensorflow.python.training import gradient_descent -class FtrlOptimizerTest(XLATestCase): +class FtrlOptimizerTest(xla_test.XLATestCase): def initVariableAndGradient(self, dtype): var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index 8a3f4b0bdc7a61d6cfa2ba7474ce8579e293a5c7..04fba444460e714ce96205361ac02ed492206b04 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function @@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -class FunctionTest(XLATestCase): +class FunctionTest(xla_test.XLATestCase): def testFunction(self): """Executes a simple TensorFlow function.""" diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index 5782e76734eb2bd01e377528fcfddfeeb8e7b6c2..132e42ac7a28d0769b0de12ea0cee6eae752b245 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -22,7 +22,7 @@ from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import test_utils -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gradient_checker @@ -30,7 +30,7 @@ from tensorflow.python.ops import nn from tensorflow.python.platform import test -class FusedBatchNormTest(XLATestCase, parameterized.TestCase): +class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): def _reference_training(self, x, scale, offset, epsilon, data_format): if data_format != "NHWC": diff --git a/tensorflow/compiler/tests/gather_nd_op_test.py b/tensorflow/compiler/tests/gather_nd_op_test.py index 9378b1db7245c0da3e8298e7dcd972491616b0cd..23b0aed34fb460f50c241e5a920cb4f6f613b947 100644 --- a/tensorflow/compiler/tests/gather_nd_op_test.py +++ b/tensorflow/compiler/tests/gather_nd_op_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class GatherNdTest(XLATestCase): +class GatherNdTest(xla_test.XLATestCase): def _runGather(self, params, indices): with self.test_session(): diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py index 1a8c4519118f69ce51ca9a5eb95a9d706c7766cc..e9c8ef7c91a728b7dfc948fd9b315e6c9102f6a3 100644 --- a/tensorflow/compiler/tests/gather_test.py +++ b/tensorflow/compiler/tests/gather_test.py @@ -136,6 +136,20 @@ class GatherTest(xla_test.XLATestCase): self.assertAllEqual( [[7]], gather.eval(feed_dict={params: [4, 7, 2], indices: [[1]]})) + def testGatherPrecision(self): + with self.test_session() as session, self.test_scope(): + data = np.array([[0, 0, 0, 0], [0, 2 * (1 + np.exp2(-8)), 0, 0], + [0, 0, 0, 0], [0.015789, 0.0985, 0.55789, 0.3842]]) + indices = np.array([1, 2, 3, 1]) + dtype = dtypes.float32 + params_np = self._buildParams(data, dtype) + params = array_ops.placeholder(dtype=dtype) + indices_tf = constant_op.constant(indices) + gather_t = array_ops.gather(params, indices_tf) + gather_val = session.run(gather_t, feed_dict={params: params_np}) + np_val = params_np[indices] + self.assertAllEqual(np_val, gather_val) + class GatherBenchmark(test.Benchmark): """Microbenchmarks for the gather op.""" diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 7cf953ef25ef5daf8a6d4fc9985ed8dbfb2081e5..bf986ade06b11358552ee92df3169f965ce3f534 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -25,7 +25,8 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test +from tensorflow.python.compat import compat from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -41,7 +42,7 @@ def GenerateNumpyRandomRGB(shape): return np.random.randint(0, 256, shape) / 256. -class RGBToHSVTest(XLATestCase): +class RGBToHSVTest(xla_test.XLATestCase): def testBatch(self): # Build an arbitrary RGB image @@ -104,7 +105,7 @@ class RGBToHSVTest(XLATestCase): self.assertAllCloseAccordingToType(hsv_tf, hsv_np) -class AdjustContrastTest(XLATestCase): +class AdjustContrastTest(xla_test.XLATestCase): def _testContrast(self, x_np, y_np, contrast_factor): with self.test_session(): @@ -168,7 +169,7 @@ class AdjustContrastTest(XLATestCase): self.assertAllClose(y_tf, y_np, rtol=1e-5, atol=1e-5) -class AdjustHueTest(XLATestCase): +class AdjustHueTest(xla_test.XLATestCase): def testAdjustNegativeHue(self): x_shape = [2, 2, 3] @@ -303,7 +304,7 @@ class AdjustHueTest(XLATestCase): self._adjustHueTf(x_np, delta_h) -class AdjustSaturationTest(XLATestCase): +class AdjustSaturationTest(xla_test.XLATestCase): def _adjust_saturation(self, image, saturation_factor): image = ops.convert_to_tensor(image, name="image") @@ -403,7 +404,7 @@ class AdjustSaturationTest(XLATestCase): self.assertAllClose(y_fused, y_baseline, rtol=2e-5, atol=1e-5) -class ResizeBilinearTest(XLATestCase): +class ResizeBilinearTest(xla_test.XLATestCase): def _assertForwardOpMatchesExpected(self, image_np, @@ -579,5 +580,140 @@ class ResizeBilinearTest(XLATestCase): large_tolerance=True) +class NonMaxSuppressionTest(xla_test.XLATestCase): + + def testNMS128From1024(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + with compat.forward_compatibility_horizon(2018, 8, 8): + num_boxes = 1024 + boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4") + scores_np = np.random.normal(0.5, 0.1, (num_boxes,)).astype("f4") + + max_output_size = 128 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.0, dtype=np.float32) + + with self.test_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + score_threshold: score_threshold_np, + iou_threshold: iou_threshold_np + } + (indices_tf, _) = sess.run(selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + + def testNMS3From6Boxes(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + with compat.forward_compatibility_horizon(2018, 8, 8): + # Three boxes are selected based on IOU. + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + + max_output_size = 3 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.0, dtype=np.float32) + + with self.test_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + score_threshold: score_threshold_np, + iou_threshold: iou_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 3) + self.assertAllClose(indices_tf[:num_valid], [3, 0, 5]) + + def testNMS3Then2WithScoreThresh(self): + # Three boxes are selected based on IOU. + # One is filtered out by score threshold. + + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + with compat.forward_compatibility_horizon(2018, 8, 8): + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + max_output_size = 3 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.4, dtype=np.float32) + + with self.test_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + iou_threshold: iou_threshold_np, + score_threshold: score_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 2) + self.assertAllClose(indices_tf[:num_valid], [3, 0]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py index 69bd8f7230d4394c45764d02a88fb0ec097c5756..253b45902fba2df64e5234f135b373cd2a0a7e2a 100644 --- a/tensorflow/compiler/tests/lrn_ops_test.py +++ b/tensorflow/compiler/tests/lrn_ops_test.py @@ -22,7 +22,7 @@ import copy import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -36,7 +36,7 @@ CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" # Local response normalization tests. The forward tests are copied from # tensorflow/python/kernel_tests/lrn_op_test.py -class LRNTest(XLATestCase): +class LRNTest(xla_test.XLATestCase): def _LRN(self, input_image, lrn_depth_radius=5, bias=1.0, alpha=1.0, beta=0.5): diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py index 29394f9ea5139b30f88f53de0469b27e37d79195..0d9f99f8a6803ecae5f9233518a1768109161ac0 100644 --- a/tensorflow/compiler/tests/matrix_band_part_test.py +++ b/tensorflow/compiler/tests/matrix_band_part_test.py @@ -19,14 +19,14 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class MatrixBandPartTest(XLATestCase): +class MatrixBandPartTest(xla_test.XLATestCase): def _testMatrixBandPart(self, dtype, shape): with self.test_session(): diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py index 5819b2bf2b55b9213a039c0ba82dd0bf1c738b00..2bb8a97bdaf5836a05501ab9754433e29ae34675 100644 --- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -22,7 +22,7 @@ import itertools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -35,7 +35,7 @@ def MakePlaceholder(x): return array_ops.placeholder(dtypes.as_dtype(x.dtype), shape=x.shape) -class MatrixTriangularSolveOpTest(XLATestCase): +class MatrixTriangularSolveOpTest(xla_test.XLATestCase): # MatrixTriangularSolve defined for float64, float32, complex64, complex128 # (https://www.tensorflow.org/api_docs/python/tf/matrix_triangular_solve) diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py index af9394e7d7dc9cf7dd009420ff9c845aec8785bd..c2592c54cf83d41f0e3bdbc1f4dc9ff276ddb078 100644 --- a/tensorflow/compiler/tests/momentum_test.py +++ b/tensorflow/compiler/tests/momentum_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -30,7 +30,7 @@ from tensorflow.python.platform import test from tensorflow.python.training import momentum as momentum_lib -class MomentumOptimizerTest(XLATestCase): +class MomentumOptimizerTest(xla_test.XLATestCase): def _update_nesterov_momentum_numpy(self, var, accum, g, lr, momentum): var += accum * lr * momentum diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py index e4843b169b943b63346b783ddc50039030988ca5..da08225e9fc0d5a8ec21ee9961c4758fa38628b4 100644 --- a/tensorflow/compiler/tests/nary_ops_test.py +++ b/tensorflow/compiler/tests/nary_ops_test.py @@ -22,14 +22,14 @@ import unittest import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class NAryOpsTest(XLATestCase): +class NAryOpsTest(xla_test.XLATestCase): def _testNAry(self, op, args, expected, equality_fn=None): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py index 6f588d8ab562cb24f33c4c2987df22264aede027..2f9122645d3c5ccabc8130ac30a3f09cf4bc2de7 100644 --- a/tensorflow/compiler/tests/nullary_ops_test.py +++ b/tensorflow/compiler/tests/nullary_ops_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import googletest -class NullaryOpsTest(XLATestCase): +class NullaryOpsTest(xla_test.XLATestCase): def _testNullary(self, op, expected): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py index 5e6d1313bd0336eba71fcf3658d949bd3342ae11..a75d99189b5b673261c9e48f1c5998ea0c575594 100644 --- a/tensorflow/compiler/tests/placeholder_test.py +++ b/tensorflow/compiler/tests/placeholder_test.py @@ -18,14 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest -class PlaceholderTest(XLATestCase): +class PlaceholderTest(xla_test.XLATestCase): def test_placeholder_with_default_default(self): with self.test_session() as sess, self.test_scope(): diff --git a/tensorflow/compiler/tests/pooling_ops_3d_test.py b/tensorflow/compiler/tests/pooling_ops_3d_test.py index d9285186baa9007e485ab916e573ad0de5e26e56..17f860db61aeda98326a6820771d67ee948b6dda 100644 --- a/tensorflow/compiler/tests/pooling_ops_3d_test.py +++ b/tensorflow/compiler/tests/pooling_ops_3d_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -41,7 +41,7 @@ def _AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding): padding=padding) -class Pooling3DTest(XLATestCase): +class Pooling3DTest(xla_test.XLATestCase): def _VerifyValues(self, pool_func, input_sizes, window, strides, padding, expected): diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py index fe270af3d636c0824621f36360ce9e7d14d8fc91..9fc94752ea660f7fb8b2c792180f01485ad04419 100644 --- a/tensorflow/compiler/tests/pooling_ops_test.py +++ b/tensorflow/compiler/tests/pooling_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -69,7 +69,7 @@ def GetTestConfigs(): return test_configs -class PoolingTest(XLATestCase): +class PoolingTest(xla_test.XLATestCase): def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding, data_format, expected): @@ -288,7 +288,7 @@ class PoolingTest(XLATestCase): expected=expected_output) -class PoolGradTest(XLATestCase): +class PoolGradTest(xla_test.XLATestCase): CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" diff --git a/tensorflow/compiler/tests/powersign_test.py b/tensorflow/compiler/tests/powersign_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5fa7706d7294f2cffb7d24a56851be02d759335a --- /dev/null +++ b/tensorflow/compiler/tests/powersign_test.py @@ -0,0 +1,142 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for PowerSign.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.contrib.opt.python.training import powersign +from tensorflow.contrib.opt.python.training import sign_decay +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def py_linear_decay_fn(decay_steps): + def linear_decay(step): + step = min(step, decay_steps) + return float(decay_steps - step) / decay_steps + return linear_decay + + +def powersign_update_numpy(params, + g_t, + m, + lr, + base=math.e, + beta=0.9, + py_sign_decay_fn=None, + t=None): + m_t = beta * m + (1 - beta) * g_t + if py_sign_decay_fn is None: + sign_decayed = 1.0 + else: + sign_decayed = py_sign_decay_fn(t-1) + multiplier = base ** (sign_decayed * np.sign(g_t) * np.sign(m_t)) + params_t = params - lr * multiplier * g_t + return params_t, m_t + + +class PowerSignTest(xla_test.XLATestCase): + + def _testDense(self, + learning_rate=0.1, + sign_decay_fn=None, + py_sign_decay_fn=None, + base=math.e, + beta=0.9): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + # Initialize variables for numpy implementation. + m0, m1 = 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype) + + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + global_step = resource_variable_ops.ResourceVariable(0, trainable=False) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + opt = powersign.PowerSignOptimizer( + learning_rate=learning_rate, + base=base, + beta=beta, + sign_decay_fn=sign_decay_fn, + ) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]), + global_step=global_step) + neg_update = opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]), + global_step=global_step) + + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 7 steps of powersign + # first 4 steps with positive gradient + # last 3 steps with negative gradient (sign(gm) should be -1) + for t in range(1, 8): + if t < 5: + update.run() + else: + neg_update.run() + + var0_np, m0 = powersign_update_numpy( + var0_np, + grads0_np if t < 5 else -grads0_np, + m0, + learning_rate, + base=base, + beta=beta, + py_sign_decay_fn=py_sign_decay_fn, + t=t, + ) + var1_np, m1 = powersign_update_numpy( + var1_np, + grads1_np if t < 5 else -grads1_np, + m1, + learning_rate, + base=base, + beta=beta, + py_sign_decay_fn=py_sign_decay_fn, + t=t, + ) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testDense(self): + decay_steps = 10 + sign_decay_fn = sign_decay.get_linear_decay_fn(decay_steps) + py_sign_decay_fn = py_linear_decay_fn(decay_steps) + self._testDense() + self._testDense(learning_rate=0.1, base=10.0, beta=0.8) + self._testDense( + sign_decay_fn=sign_decay_fn, py_sign_decay_fn=py_sign_decay_fn) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/tests/proximal_adagrad_test.py b/tensorflow/compiler/tests/proximal_adagrad_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cde87db63dbfd7c8d823c6fd0e41eee8b23735bb --- /dev/null +++ b/tensorflow/compiler/tests/proximal_adagrad_test.py @@ -0,0 +1,172 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Proximal Adagrad optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import adagrad +from tensorflow.python.training import proximal_adagrad + + +class ProximalAdagradOptimizerTest(xla_test.XLATestCase): + + def testResourceProximalAdagradwithoutRegularization(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([0.0, 0.0]) + var1 = resource_variable_ops.ResourceVariable([0.0, 0.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + opt = proximal_adagrad.ProximalAdagradOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([0.0, 0.0], var0.eval()) + self.assertAllClose([0.0, 0.0], var1.eval()) + + # Run 3 steps Proximal Adagrad. + for _ in range(3): + update.run() + + self.assertAllClose(np.array([-2.60260963, -4.29698515]), var0.eval()) + self.assertAllClose(np.array([-0.28432083, -0.56694895]), var1.eval()) + opt_vars = opt.variables() + self.assertStartsWith(opt_vars[0].name, var0._shared_name) + self.assertStartsWith(opt_vars[1].name, var1._shared_name) + self.assertEqual(2, len(opt_vars)) + + def testProximalAdagradwithoutRegularization2(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + opt = proximal_adagrad.ProximalAdagradOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 3 steps Proximal Adagrad. + for _ in range(3): + update.run() + self.assertAllClose(np.array([-1.60261, -2.296985]), var0.eval()) + self.assertAllClose(np.array([3.715679, 2.433051]), var1.eval()) + + def testProximalAdagradWithL1(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + opt = proximal_adagrad.ProximalAdagradOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=0.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 10 steps Proximal Adagrad + for _ in range(10): + update.run() + self.assertAllClose(np.array([-6.663634, -9.190331]), var0.eval()) + self.assertAllClose(np.array([2.959304, 1.029232]), var1.eval()) + + def testProximalAdagradWithL1_L2(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + opt = proximal_adagrad.ProximalAdagradOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=2.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 10 steps Proximal Adagrad. + for _ in range(10): + update.run() + + self.assertAllClose(np.array([-0.0495, -0.0995]), var0.eval()) + self.assertAllClose(np.array([-0.0045, -0.0095]), var1.eval()) + + def applyOptimizer(self, opt, steps=5): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run ProximalAdagrad for a few steps + for _ in range(steps): + update.run() + + return var0.eval(), var1.eval() + + def testEquivAdagradwithoutRegularization(self): + with self.test_session(), self.test_scope(): + val0, val1 = self.applyOptimizer( + proximal_adagrad.ProximalAdagradOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0)) + + with self.test_session(), self.test_scope(): + val2, val3 = self.applyOptimizer( + adagrad.AdagradOptimizer( + 3.0, initial_accumulator_value=0.1)) + + self.assertAllClose(val0, val2) + self.assertAllClose(val1, val3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/proximal_gradient_descent_test.py b/tensorflow/compiler/tests/proximal_gradient_descent_test.py new file mode 100644 index 0000000000000000000000000000000000000000..11eb76871133eba8fcd24621afb03e16614fb005 --- /dev/null +++ b/tensorflow/compiler/tests/proximal_gradient_descent_test.py @@ -0,0 +1,156 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Proximal Gradient Descent optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import proximal_gradient_descent + + +class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): + + def testResourceProximalGradientDescentwithoutRegularization(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([0.0, 0.0]) + var1 = resource_variable_ops.ResourceVariable([0.0, 0.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + opt = proximal_gradient_descent.ProximalGradientDescentOptimizer( + 3.0, l1_regularization_strength=0.0, l2_regularization_strength=0.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([0.0, 0.0], var0.eval()) + self.assertAllClose([0.0, 0.0], var1.eval()) + + # Run 3 steps Proximal Gradient Descent. + for _ in range(3): + update.run() + + self.assertAllClose(np.array([-0.9, -1.8]), var0.eval()) + self.assertAllClose(np.array([-0.09, -0.18]), var1.eval()) + + def testProximalGradientDescentwithoutRegularization2(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + opt = proximal_gradient_descent.ProximalGradientDescentOptimizer( + 3.0, l1_regularization_strength=0.0, l2_regularization_strength=0.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 3 steps Proximal Gradient Descent + for _ in range(3): + update.run() + + self.assertAllClose(np.array([0.1, 0.2]), var0.eval()) + self.assertAllClose(np.array([3.91, 2.82]), var1.eval()) + + def testProximalGradientDescentWithL1(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + opt = proximal_gradient_descent.ProximalGradientDescentOptimizer( + 3.0, l1_regularization_strength=0.001, l2_regularization_strength=0.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 10 steps proximal gradient descent. + for _ in range(10): + update.run() + + self.assertAllClose(np.array([-1.988, -3.988001]), var0.eval()) + self.assertAllClose(np.array([3.67, 2.37]), var1.eval()) + + def testProximalGradientDescentWithL1_L2(self): + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + opt = proximal_gradient_descent.ProximalGradientDescentOptimizer( + 3.0, l1_regularization_strength=0.001, l2_regularization_strength=2.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([4.0, 3.0], var1.eval()) + + # Run 10 steps Proximal Gradient Descent + for _ in range(10): + update.run() + + self.assertAllClose(np.array([-0.0495, -0.0995]), var0.eval()) + self.assertAllClose(np.array([-0.0045, -0.0095]), var1.eval()) + + def applyOptimizer(self, opt, steps=5): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0]) + grads0 = constant_op.constant([0.1, 0.2]) + grads1 = constant_op.constant([0.01, 0.02]) + + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run ProximalAdagrad for a few steps + for _ in range(steps): + update.run() + + return var0.eval(), var1.eval() + + def testEquivGradientDescentwithoutRegularization(self): + with self.test_session(), self.test_scope(): + val0, val1 = self.applyOptimizer( + proximal_gradient_descent.ProximalGradientDescentOptimizer( + 3.0, + l1_regularization_strength=0.0, + l2_regularization_strength=0.0)) + + with self.test_session(), self.test_scope(): + val2, val3 = self.applyOptimizer( + gradient_descent.GradientDescentOptimizer(3.0)) + + self.assertAllClose(val0, val2) + self.assertAllClose(val1, val3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1b969ee2b3886fca6ec9951d1621ca5af6a673d8 --- /dev/null +++ b/tensorflow/compiler/tests/qr_op_test.py @@ -0,0 +1,115 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.math_ops.matrix_inverse.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +from absl.testing import parameterized +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): + + def AdjustedNorm(self, x): + """Computes the norm of matrices in 'x', adjusted for dimension and type.""" + norm = np.linalg.norm(x, axis=(-2, -1)) + return norm / (max(x.shape[-2:]) * np.finfo(x.dtype).eps) + + def CompareOrthogonal(self, x, y, rank): + # We only compare the first 'rank' orthogonal vectors since the + # remainder form an arbitrary orthonormal basis for the + # (row- or column-) null space, whose exact value depends on + # implementation details. Notice that since we check that the + # matrices of singular vectors are unitary elsewhere, we do + # implicitly test that the trailing vectors of x and y span the + # same space. + x = x[..., 0:rank] + y = y[..., 0:rank] + # Q is only unique up to sign (complex phase factor for complex matrices), + # so we normalize the sign first. + sum_of_ratios = np.sum(np.divide(y, x), -2, keepdims=True) + phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios)) + x *= phases + self.assertTrue(np.all(self.AdjustedNorm(x - y) < 30.0)) + + def CheckApproximation(self, a, q, r): + # Tests that a ~= q*r. + precision = self.AdjustedNorm(a - np.matmul(q, r)) + self.assertTrue(np.all(precision < 10.0)) + + def CheckUnitary(self, x): + # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity. + xx = math_ops.matmul(x, x, adjoint_a=True) + identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0) + precision = self.AdjustedNorm(xx.eval() - identity.eval()) + self.assertTrue(np.all(precision < 5.0)) + + def _test(self, dtype, shape, full_matrices): + np.random.seed(1) + x_np = np.random.uniform( + low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype) + + with self.test_session() as sess: + x_tf = array_ops.placeholder(dtype) + with self.test_scope(): + q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices) + q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np}) + + q_dims = q_tf_val.shape + np_q = np.ndarray(q_dims, dtype) + np_q_reshape = np.reshape(np_q, (-1, q_dims[-2], q_dims[-1])) + new_first_dim = np_q_reshape.shape[0] + + x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1])) + for i in range(new_first_dim): + if full_matrices: + np_q_reshape[i, :, :], _ = np.linalg.qr( + x_reshape[i, :, :], mode="complete") + else: + np_q_reshape[i, :, :], _ = np.linalg.qr( + x_reshape[i, :, :], mode="reduced") + np_q = np.reshape(np_q_reshape, q_dims) + self.CompareOrthogonal(np_q, q_tf_val, min(shape[-2:])) + self.CheckApproximation(x_np, q_tf_val, r_tf_val) + self.CheckUnitary(q_tf_val) + + SIZES = [1, 2, 5, 10, 32, 100, 300] + DTYPES = [np.float32] + PARAMS = itertools.product(SIZES, SIZES, DTYPES) + + @parameterized.parameters(*PARAMS) + def testQR(self, rows, cols, dtype): + # TODO(b/111317468): implement full_matrices=False, test other types. + for full_matrices in [True]: + # Only tests the (3, 2) case for small numbers of rows/columns. + for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10): + self._test(dtype, batch_dims + (rows, cols), full_matrices) + + def testLarge2000x2000(self): + self._test(np.float32, (2000, 2000), full_matrices=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 2e71b00ba66dba93c87e565e3a372111de1f362d..cc0e9b2f98dc2cdb0382140d5172ed51d8ab2b53 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -22,7 +22,7 @@ import math import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -31,7 +31,7 @@ from tensorflow.python.ops.distributions import special_math from tensorflow.python.platform import googletest -class RandomOpsTest(XLATestCase): +class RandomOpsTest(xla_test.XLATestCase): """Test cases for random-number generating operators.""" def _random_types(self): @@ -57,7 +57,8 @@ class RandomOpsTest(XLATestCase): def testRandomUniformIsNotConstant(self): def rng(dtype): - return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=1000000) + dtype = dtypes.as_dtype(dtype) + return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=dtype.max) for dtype in self._random_types(): self._testRngIsNotConstant(rng, dtype) @@ -73,6 +74,11 @@ class RandomOpsTest(XLATestCase): def testRandomUniformIsInRange(self): for dtype in self._random_types(): + # TODO (b/112272078): enable bfloat16 for CPU and GPU when the bug is + # fixed. + if (self.device in ["XLA_GPU", "XLA_CPU" + ]) and (dtype in [dtypes.bfloat16, dtypes.half]): + continue with self.test_session() as sess: with self.test_scope(): x = random_ops.random_uniform( @@ -124,26 +130,35 @@ class RandomOpsTest(XLATestCase): # Department of Scientific Computing website. Florida State University. expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma actual_mean = np.mean(y) - self.assertAllClose(actual_mean, expected_mean, atol=2e-4) + atol = 2e-4 + if self.device in ["XLA_GPU", "XLA_CPU"]: + atol = 2.2e-4 + self.assertAllClose(actual_mean, expected_mean, atol=atol) expected_median = mu + probit( (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma actual_median = np.median(y) - self.assertAllClose(actual_median, expected_median, atol=8e-4) + self.assertAllClose(actual_median, expected_median, atol=1e-3) expected_variance = sigma**2 * (1 + ( (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - ( (normal_pdf(alpha) - normal_pdf(beta)) / z)**2) actual_variance = np.var(y) - self.assertAllClose(actual_variance, expected_variance, rtol=3e-4) + rtol = 1e-3 + if self.device in ["XLA_GPU", "XLA_CPU"]: + rtol = 4e-4 + self.assertAllClose(actual_variance, expected_variance, rtol=rtol) def testShuffle1d(self): + # TODO(b/26783907): this test requires the CPU backend to implement sort. + if self.device in ["XLA_CPU"]: + return with self.test_session() as sess: with self.test_scope(): - x = math_ops.range(20) + x = math_ops.range(1 << 16) shuffle = random_ops.random_shuffle(x) result = sess.run(shuffle) - expected = range(20) + expected = range(1 << 16) # Compare sets to avoid randomness behavior changes but make sure still # have all the values. self.assertAllEqual(set(result), set(expected)) diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 16f293891d56d78885dd515bb7b9899faf0690f7..c0ea242044540b1cef44186880ba3cd92b8849d6 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -62,6 +62,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" @@ -101,6 +102,9 @@ class OpTestBuilder { OpTestBuilder& RandomInput(DataType type); OpTestBuilder& RandomInput(DataType type, std::vector dims); + // As RandomInput but the values are unique. + OpTestBuilder& RandomUniqueInput(DataType type, std::vector dims); + // Sets an attribute. template OpTestBuilder& Attr(StringPiece attr_name, T&& value); @@ -126,6 +130,7 @@ class OpTestBuilder { DataType type = DT_INVALID; bool has_dims = false; + bool needs_unique_values = false; std::vector dims; }; @@ -167,6 +172,18 @@ OpTestBuilder& OpTestBuilder::RandomInput(DataType type, return *this; } +OpTestBuilder& OpTestBuilder::RandomUniqueInput(DataType type, + std::vector dims) { + VLOG(1) << "Adding input: " << type << " " << TensorShape(dims).DebugString(); + InputDescription input; + input.type = type; + input.has_dims = true; + input.needs_unique_values = true; + input.dims = std::move(dims); + inputs_.push_back(input); + return *this; +} + template OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, T&& value) { AddNodeAttr(attr_name, std::forward(value), &node_def_); @@ -289,7 +306,8 @@ class OpTest : public ::testing::Test { // Returns a tensor filled with random but "reasonable" values from the middle // of the type's range. If the shape is omitted, a random shape is used. // TODO(phawkins): generalize this code to a caller-supplied distribution. - Tensor RandomTensor(DataType dtype, gtl::ArraySlice shape); + Tensor RandomTensor(DataType dtype, bool needs_unique_values, + gtl::ArraySlice shape); Tensor RandomTensor(DataType dtype); // Like RandomTensor, but uses values >= 0. @@ -432,49 +450,90 @@ std::vector OpTest::RandomDims(int min_rank, int max_rank, return dims; } -Tensor OpTest::RandomTensor(DataType dtype, gtl::ArraySlice shape) { +Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, + gtl::ArraySlice shape) { Tensor tensor(dtype, TensorShape(shape)); switch (dtype) { case DT_FLOAT: { + gtl::FlatSet already_generated; std::uniform_real_distribution distribution(-1.0f, 1.0f); - test::FillFn(&tensor, [this, &distribution](int i) -> float { - return distribution(generator()); + test::FillFn(&tensor, [&](int i) -> float { + float generated; + do { + generated = distribution(generator()); + } while (needs_unique_values && + !already_generated.insert(generated).second); + return generated; }); break; } case DT_DOUBLE: { + gtl::FlatSet already_generated; std::uniform_real_distribution distribution(-1.0, 1.0); - test::FillFn(&tensor, [this, &distribution](int i) -> double { - return distribution(generator()); + test::FillFn(&tensor, [&](int i) -> double { + double generated; + do { + generated = distribution(generator()); + } while (needs_unique_values && + !already_generated.insert(generated).second); + return generated; }); break; } case DT_COMPLEX64: { + gtl::FlatSet> already_generated; std::uniform_real_distribution distribution(-1.0f, 1.0f); - test::FillFn(&tensor, [this, &distribution](int i) { - return complex64(distribution(generator()), distribution(generator())); + test::FillFn(&tensor, [&](int i) { + complex64 generated; + do { + generated = + complex64(distribution(generator()), distribution(generator())); + } while ( + needs_unique_values && + !already_generated + .insert(std::make_pair(generated.real(), generated.imag())) + .second); + return generated; }); break; } case DT_INT32: { + gtl::FlatSet already_generated; std::uniform_int_distribution distribution(-(1 << 20), 1 << 20); - test::FillFn(&tensor, [this, &distribution](int i) -> int32 { - return distribution(generator()); + test::FillFn(&tensor, [&](int i) -> int32 { + int32 generated; + do { + generated = distribution(generator()); + } while (needs_unique_values && + !already_generated.insert(generated).second); + return generated; }); break; } case DT_INT64: { + gtl::FlatSet already_generated; std::uniform_int_distribution distribution(-(1LL << 40), 1LL << 40); - test::FillFn(&tensor, [this, &distribution](int i) -> int64 { - return distribution(generator()); + test::FillFn(&tensor, [&](int i) -> int64 { + int64 generated; + do { + generated = distribution(generator()); + } while (needs_unique_values && + !already_generated.insert(generated).second); + return generated; }); break; } case DT_BOOL: { + gtl::FlatSet already_generated; std::bernoulli_distribution distribution; - test::FillFn(&tensor, [this, &distribution](int i) -> bool { - return distribution(generator()); + test::FillFn(&tensor, [&](int i) -> bool { + bool generated; + do { + generated = distribution(generator()); + } while (needs_unique_values && + !already_generated.insert(generated).second); + return generated; }); break; } @@ -485,7 +544,7 @@ Tensor OpTest::RandomTensor(DataType dtype, gtl::ArraySlice shape) { } Tensor OpTest::RandomTensor(DataType dtype) { - return RandomTensor(dtype, RandomDims()); + return RandomTensor(dtype, /*needs_unique_values=*/false, RandomDims()); } Tensor OpTest::RandomNonNegativeTensor(DataType dtype, @@ -761,7 +820,8 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( VLOG(1) << "Ignoring oversize dims."; return kInvalid; } - input_tensors.push_back(RandomTensor(input.type, dims)); + input_tensors.push_back( + RandomTensor(input.type, input.needs_unique_values, dims)); } VLOG(1) << "Input: " << input_tensors.back().DebugString(); } @@ -960,7 +1020,7 @@ TEST_F(OpTest, ArgMax) { std::uniform_int_distribution(-num_dims, num_dims)(generator()); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("ArgMax") - .RandomInput(DT_FLOAT, dims) + .RandomUniqueInput(DT_FLOAT, dims) .Input(test::AsScalar(reduce_dim)) .Attr("T", DT_FLOAT) .Attr("Tidx", DT_INT32) @@ -976,7 +1036,7 @@ TEST_F(OpTest, ArgMin) { std::uniform_int_distribution(-num_dims, num_dims)(generator()); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("ArgMin") - .RandomInput(DT_FLOAT, dims) + .RandomUniqueInput(DT_FLOAT, dims) .Input(test::AsScalar(reduce_dim)) .Attr("T", DT_FLOAT) .Attr("Tidx", DT_INT32) diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index 7420724bdbeab63b39542ada59328621febad895..cea2ec816f85e88b11e6e80c91c14fca9015f45c 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -22,7 +22,7 @@ import functools import itertools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.ops import array_ops @@ -30,7 +30,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class ReduceOpsTest(XLATestCase): +class ReduceOpsTest(xla_test.XLATestCase): def _testReduction(self, tf_reduce_fn, @@ -156,7 +156,7 @@ class ReduceOpsTest(XLATestCase): self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA) -class ReduceOpPrecisionTest(XLATestCase): +class ReduceOpPrecisionTest(xla_test.XLATestCase): def _testReduceSum(self, expected_result, diff --git a/tensorflow/compiler/tests/reduce_window_test.py b/tensorflow/compiler/tests/reduce_window_test.py index e78a63465b80644d8810d9fa7433653bc4639fed..c69b6837b0f88ced844faf3713a29a1c14c8790d 100644 --- a/tensorflow/compiler/tests/reduce_window_test.py +++ b/tensorflow/compiler/tests/reduce_window_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.python.framework import dtypes from tensorflow.python.framework import function @@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -class ReduceWindowTest(XLATestCase): +class ReduceWindowTest(xla_test.XLATestCase): """Test cases for xla.reduce_window.""" def _reduce_window(self, operand, init, reducer, **kwargs): diff --git a/tensorflow/compiler/tests/reverse_ops_test.py b/tensorflow/compiler/tests/reverse_ops_test.py index 18fabca28c9817fc8517595fa1694a18399f54b0..d01c676e7c2fe705344f26818350c46c30451c67 100644 --- a/tensorflow/compiler/tests/reverse_ops_test.py +++ b/tensorflow/compiler/tests/reverse_ops_test.py @@ -21,14 +21,14 @@ from __future__ import print_function import itertools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -class ReverseOpsTest(XLATestCase): +class ReverseOpsTest(xla_test.XLATestCase): def testReverseOneDim(self): shape = (7, 5, 9, 11) diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py index 1a5d05094e53cfecd9476d7d87f023e8a02d7458..ccfa63001653537c4d1b7140e3d745c126f9034b 100644 --- a/tensorflow/compiler/tests/reverse_sequence_op_test.py +++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class ReverseSequenceTest(XLATestCase): +class ReverseSequenceTest(xla_test.XLATestCase): def _testReverseSequence(self, x, diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py index ecdce4f052bbe3eeae8697c02c891105103f4f69..ff8bbac911abe73f946464663984ff1626302882 100644 --- a/tensorflow/compiler/tests/rmsprop_test.py +++ b/tensorflow/compiler/tests/rmsprop_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables @@ -28,33 +28,104 @@ from tensorflow.python.platform import test from tensorflow.python.training import rmsprop -class RmspropTest(XLATestCase): +class RmspropTest(xla_test.XLATestCase): + + def _rmsprop_update_numpy(self, + var, + g, + mg, + rms, + mom, + lr, + decay=0.9, + momentum=0.0, + epsilon=1e-10, + centered=False): + rms_t = rms * decay + (1 - decay) * g * g + denom_t = rms_t + epsilon + if centered: + mg_t = mg * decay + (1 - decay) * g + denom_t -= mg_t * mg_t + else: + mg_t = mg + mom_t = momentum * mom + lr * g / np.sqrt(denom_t, dtype=denom_t.dtype) + var_t = var - mom_t + return var_t, mg_t, rms_t, mom_t def testBasic(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - rms_opt = rmsprop.RMSPropOptimizer(3.0) - rms_update = rms_opt.apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) - - # Run 3 steps of RMSProp - for _ in range(3): - rms_update.run() - - # Validate updated params - self.assertAllCloseAccordingToType( - np.array([2.91705132e-04, 1.00029182e+00]), var0.eval()) - self.assertAllCloseAccordingToType( - np.array([2.89990854, 3.89990854]), var1.eval()) + for centered in [False, True]: + with self.test_session(), self.test_scope(): + # Initialize variables for numpy implementation. + var0_np = np.array([1.0, 2.0], dtype=dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype) + mg0_np = np.array([0.0, 0.0], dtype=dtype) + mg1_np = np.array([0.0, 0.0], dtype=dtype) + rms0_np = np.array([1.0, 1.0], dtype=dtype) + rms1_np = np.array([1.0, 1.0], dtype=dtype) + mom0_np = np.array([0.0, 0.0], dtype=dtype) + mom1_np = np.array([0.0, 0.0], dtype=dtype) + + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + learning_rate = 3.0 + rms_opt = rmsprop.RMSPropOptimizer(learning_rate, centered=centered) + rms_update = rms_opt.apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + mg0 = rms_opt.get_slot(var0, "mg") + self.assertEqual(mg0 is not None, centered) + mg1 = rms_opt.get_slot(var1, "mg") + self.assertEqual(mg1 is not None, centered) + rms0 = rms_opt.get_slot(var0, "rms") + self.assertTrue(rms0 is not None) + rms1 = rms_opt.get_slot(var1, "rms") + self.assertTrue(rms1 is not None) + mom0 = rms_opt.get_slot(var0, "momentum") + self.assertTrue(mom0 is not None) + mom1 = rms_opt.get_slot(var1, "momentum") + self.assertTrue(mom1 is not None) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 3 steps of RMSProp + for _ in range(3): + rms_update.run() + + var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy( + var0_np, + grads0_np, + mg0_np, + rms0_np, + mom0_np, + learning_rate, + centered=centered) + var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy( + var1_np, + grads1_np, + mg1_np, + rms1_np, + mom1_np, + learning_rate, + centered=centered) + + # Validate updated params + if centered: + self.assertAllCloseAccordingToType(mg0_np, mg0.eval()) + self.assertAllCloseAccordingToType(mg1_np, mg1.eval()) + self.assertAllCloseAccordingToType(rms0_np, rms0.eval()) + self.assertAllCloseAccordingToType(rms1_np, rms1.eval()) + self.assertAllCloseAccordingToType(mom0_np, mom0.eval()) + self.assertAllCloseAccordingToType(mom1_np, mom1.eval()) + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py index 3260e63b23226d736a7ddc0f21a94a8c791e0442..4292352e76ebcef7dbf41df7b857d2604a468117 100644 --- a/tensorflow/compiler/tests/scan_ops_test.py +++ b/tensorflow/compiler/tests/scan_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops @@ -69,7 +69,7 @@ def handle_options(func, x, axis, exclusive, reverse): return x -class CumsumTest(XLATestCase): +class CumsumTest(xla_test.XLATestCase): valid_dtypes = [np.float32] @@ -147,7 +147,7 @@ class CumsumTest(XLATestCase): math_ops.cumsum(input_tensor, [0]).eval() -class CumprodTest(XLATestCase): +class CumprodTest(xla_test.XLATestCase): valid_dtypes = [np.float32] diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py index 638946e234daf28dc4a34e6c33fc0f78b8e8699b..f606f88545d0b6f0b52cee9b93083a6bd91169bc 100644 --- a/tensorflow/compiler/tests/scatter_nd_op_test.py +++ b/tensorflow/compiler/tests/scatter_nd_op_test.py @@ -22,7 +22,7 @@ import functools import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -68,7 +68,7 @@ def _NumpyUpdate(indices, updates, shape): return _NumpyScatterNd(ref, indices, updates, lambda p, u: u) -class ScatterNdTest(XLATestCase): +class ScatterNdTest(xla_test.XLATestCase): def _VariableRankTest(self, np_scatter, diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py index 305ca0c6b78d3ef985deb38816f9388e7983906b..6c4890565d2083a9493abc59bd563c4dd9fdb186 100644 --- a/tensorflow/compiler/tests/slice_ops_test.py +++ b/tensorflow/compiler/tests/slice_ops_test.py @@ -18,14 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -class SliceTest(XLATestCase): +class SliceTest(xla_test.XLATestCase): def test1D(self): for dtype in self.numeric_types: @@ -110,7 +110,7 @@ class SliceTest(XLATestCase): self.assertAllEqual([[[1, 1, 1, 1], [6, 5, 4, 3]]], result) -class StridedSliceTest(XLATestCase): +class StridedSliceTest(xla_test.XLATestCase): def test1D(self): for dtype in self.numeric_types: diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index 8ae579abda9854079ee491a7254eb4d09183594a..7ff01be3cb4848d6bb85b8ab96b3ee1db6889791 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -64,20 +64,61 @@ class XlaSortOpTest(xla_test.XLATestCase): if self.device in ["XLA_CPU", "XLA_GPU"]: return - # Only bfloat16 is implemented. - bfloat16 = dtypes.bfloat16.as_numpy_dtype - if bfloat16 in self.numeric_types: - for x in [np.arange(20)]: + supported_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) + for dtype in supported_types.intersection(self.numeric_types): + # Use small input size for bfloat16. Otherwise, we'll get duplicate values + # after conversion to bfloat16, so the possible resulting index array is + # no longer unique. + if dtype == dtypes.bfloat16.as_numpy_dtype: + array_size = 20 + k_options = [0, 1, 2, 10, 20] + else: + array_size = 200 * 1000 + k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000] + for x in [np.arange(array_size)]: np.random.shuffle(x) - for k in [0, 1, 2, 10, 20]: + for k in k_options: indices = x.argsort()[::-1][:k] def topk(v, k=k): return nn_ops.top_k(v, k=k, sorted=True) self._assertOpOutputMatchesExpected( - topk, [x.astype(bfloat16)], - expected=[x[indices].astype(bfloat16), indices]) + topk, [x.astype(dtype)], + expected=[x[indices].astype(dtype), indices]) + + def testTopK2D(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + supported_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) + for dtype in supported_types.intersection(self.numeric_types): + # Use small input size for bfloat16. Otherwise, we'll get duplicate values + # after conversion to bfloat16, so the possible resulting index array is + # no longer unique. + if dtype == dtypes.bfloat16.as_numpy_dtype: + array_size = 10 + k_options = [0, 1, 2, 10] + else: + array_size = 200 * 1000 + k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000] + batch = 16 + for x in [np.arange(batch * array_size)]: + np.random.shuffle(x) + x = np.reshape(x, [batch, array_size]) + for k in k_options: + indices = x.argsort(axis=1)[::, -1:-k - 1:-1] + expected = np.sort(x, axis=1)[::, -1:-k - 1:-1] + + def topk(v, k=k): + return nn_ops.top_k(v, k=k, sorted=True) + + self._assertOpOutputMatchesExpected( + topk, [x.astype(dtype)], + expected=[expected.astype(dtype), indices]) def testTopKZeros(self): """Tests that positive and negative zeros sort correctly.""" @@ -99,7 +140,7 @@ class XlaSortOpTest(xla_test.XLATestCase): {p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)}) self.assertAllEqual( np.array([3., 0., 0., 0.], dtype=bfloat16), results[0]) - self.assertEqual(list([3, 0, 1, 2]), list(results[1])) + self.assertEqual(list([3, 0, 2, 6]), list(results[1])) def testTopKInfinities(self): """Tests that positive and negative infinity sort correctly.""" diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py index f37c34156f96761632247be4bc1b62fca54f666e..c685bc548f9f6f8f7723c6f94dfd45f5420b4a67 100644 --- a/tensorflow/compiler/tests/spacetobatch_op_test.py +++ b/tensorflow/compiler/tests/spacetobatch_op_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops @@ -68,7 +68,7 @@ def space_to_batch_direct(input_array, block_shape, paddings): return permuted_reshaped_padded.reshape(output_shape) -class SpaceToBatchTest(XLATestCase): +class SpaceToBatchTest(xla_test.XLATestCase): """Tests input-output pairs for the SpaceToBatch and BatchToSpace ops.""" def _testPad(self, inputs, paddings, block_size, outputs): @@ -149,7 +149,7 @@ class SpaceToBatchTest(XLATestCase): self._testOne(x_np, block_size, x_out) -class SpaceToBatchNDTest(XLATestCase): +class SpaceToBatchNDTest(xla_test.XLATestCase): """Tests input-output pairs for the SpaceToBatchND and BatchToSpaceND ops.""" def _testPad(self, inputs, block_shape, paddings, outputs): diff --git a/tensorflow/compiler/tests/stack_ops_test.py b/tensorflow/compiler/tests/stack_ops_test.py index 94342f9567ca71274609e63b0482d55637c98d51..b7dd787feff2b22a9cfb5d43a4ba6ceb6eb0b301 100644 --- a/tensorflow/compiler/tests/stack_ops_test.py +++ b/tensorflow/compiler/tests/stack_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -28,7 +28,7 @@ from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.platform import test -class StackOpTest(XLATestCase): +class StackOpTest(xla_test.XLATestCase): def testStackPushPop(self): with self.test_session(), self.test_scope(): diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index abce190d831b25b364e393788aeeaf7dd1f2c5e1..d162675ef840131485128414b4a29e3cd89c8761 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -22,7 +22,7 @@ import math import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.contrib import stateless from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -30,7 +30,7 @@ from tensorflow.python.ops.distributions import special_math from tensorflow.python.platform import test -class StatelessRandomOpsTest(XLATestCase): +class StatelessRandomOpsTest(xla_test.XLATestCase): """Test cases for stateless random-number generator operators.""" def _random_types(self): diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index ef047005b60bd156a677050368ef67ae030d6c3a..effa5a59fee7dda543b2c409dfaa27a972a55808 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_math_ops @@ -28,7 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class TernaryOpsTest(XLATestCase): +class TernaryOpsTest(xla_test.XLATestCase): def _testTernary(self, op, a, b, c, expected): with self.test_session() as session: diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index a24abd7547a423bd87ae46761d37c29caa67055d..73adb0d243b3b27e6c6ba669b2fd134a5976a2ec 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -23,7 +23,7 @@ import unittest import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import bitwise_ops @@ -44,7 +44,7 @@ def nhwc_to_format(x, data_format): raise ValueError("Unknown format {}".format(data_format)) -class UnaryOpsTest(XLATestCase): +class UnaryOpsTest(xla_test.XLATestCase): """Test cases for unary operators.""" def _assertOpOutputMatchesExpected(self, @@ -361,6 +361,12 @@ class UnaryOpsTest(XLATestCase): np.array([[-0.05, 6.05, 5]], dtype=dtype), expected=np.array([[0, 6, 5]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + nn_ops.softmax, + np.array([1, 2, 3, 4], dtype=dtype), + expected=np.array([0.032058604, 0.087144323, 0.23688284, 0.64391428], + dtype=dtype)) + self._assertOpOutputMatchesExpected( nn_ops.softmax, np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), @@ -369,6 +375,14 @@ class UnaryOpsTest(XLATestCase): [0.032058604, 0.087144323, 0.23688284, 0.64391428]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + nn_ops.softmax, + np.array([[[1, 1], [1, 1]], [[1, 2], [3, 4]]], dtype=dtype), + expected=np.array( + [[[0.5, 0.5], [0.5, 0.5]], + [[0.26894142, 0.73105858], [0.26894142, 0.73105858]]], + dtype=dtype)) + self._assertOpOutputMatchesExpected( nn_ops.softsign, np.array([[-2, -1, 0, 1, 2]], dtype=dtype), @@ -382,6 +396,62 @@ class UnaryOpsTest(XLATestCase): expected=np.array( [[True, False, True], [False, True, True]], dtype=np.bool)) + self._assertOpOutputMatchesExpected( + math_ops.lgamma, + np.array( + [[1, 2, 3], [4, 5, 6], [1 / 2, 3 / 2, 5 / 2], + [-3 / 2, -7 / 2, -11 / 2]], + dtype=dtype), + expected=np.array( + [ + [0, 0, np.log(2.0)], + [np.log(6.0), np.log(24.0), + np.log(120)], + [ + np.log(np.pi) / 2, + np.log(np.pi) / 2 - np.log(2), + np.log(np.pi) / 2 - np.log(4) + np.log(3) + ], + [ + np.log(np.pi) / 2 - np.log(3) + np.log(4), + np.log(np.pi) / 2 - np.log(105) + np.log(16), + np.log(np.pi) / 2 - np.log(10395) + np.log(64), + ], + ], + dtype=dtype)) + + self._assertOpOutputMatchesExpected( + math_ops.digamma, + np.array( + [[1.0, 0.5, 1 / 3.0], [0.25, 1 / 6.0, 0.125], [2.0, 3.0, 4.0], + [6.0, 8.0, 9.0]], + dtype=dtype), + expected=np.array( + [ + [ + -np.euler_gamma, -2 * np.log(2) - np.euler_gamma, + -np.pi / 2 / np.sqrt(3) - 3 * np.log(3) / 2 - + np.euler_gamma + ], + [ + -np.pi / 2 - 3 * np.log(2) - np.euler_gamma, + -np.pi * np.sqrt(3) / 2 - 2 * np.log(2) - + 3 * np.log(3) / 2 - np.euler_gamma, + -np.pi / 2 - 4 * np.log(2) - + (np.pi + np.log(2 + np.sqrt(2)) - np.log(2 - np.sqrt(2))) + / np.sqrt(2) - np.euler_gamma + ], + [ + 1 - np.euler_gamma, 1.5 - np.euler_gamma, + 11 / 6.0 - np.euler_gamma + ], + [ + 137 / 60.0 - np.euler_gamma, 363 / 140.0 - np.euler_gamma, + 761 / 280.0 - np.euler_gamma + ], + ], + dtype=dtype)) + def quantize_and_dequantize_v2(x): return array_ops.quantize_and_dequantize_v2( x, -127, 127, signed_input=True, num_bits=8) diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index bd616f2a20cabfe1e85d325f592565171a1297c2..dd2c252d383bca9c59033ac07e442b487e4975a6 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -37,7 +37,7 @@ from tensorflow.python.platform import googletest from tensorflow.python.training.gradient_descent import GradientDescentOptimizer -class VariableOpsTest(XLATestCase): +class VariableOpsTest(xla_test.XLATestCase): """Test cases for resource variable operators.""" def testOneWriteOneOutput(self): @@ -435,7 +435,7 @@ class StridedSliceAssignChecker(object): self.test.assertAllEqual(val, valnp) -class SliceAssignTest(XLATestCase): +class SliceAssignTest(xla_test.XLATestCase): def testSliceAssign(self): for dtype in self.numeric_types: diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py index f79eb27435cc954cebde4357c1d946a320f4ed75..b637cf31cfc303ebe84ce8307ef4ad8b0b5cd720 100644 --- a/tensorflow/compiler/tests/while_test.py +++ b/tensorflow/compiler/tests/while_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -29,7 +29,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class WhileTest(XLATestCase): +class WhileTest(xla_test.XLATestCase): def testSingletonLoopHandrolled(self): # Define a function for the loop body diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index f0b010fa67f2ffb3f81fd14d4d89585f716b4890..85084bb1240cf05f6eabfbea772df113cabe613c 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -20,14 +20,16 @@ from __future__ import print_function import numpy as np -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.platform import test -class XlaDeviceTest(XLATestCase): +class XlaDeviceTest(xla_test.XLATestCase): def testCopies(self): """Tests that copies onto and off XLA devices work.""" @@ -47,6 +49,34 @@ class XlaDeviceTest(XLATestCase): result = sess.run(z, {x: inputs}) self.assertAllCloseAccordingToType(result, inputs + inputs) + def testCopiesOfUnsupportedTypesFailGracefully(self): + """Tests that copies of unsupported types don't crash.""" + test_types = set([ + np.uint8, np.uint16, np.uint32, np.uint64, np.int8, np.int16, np.int32, + np.int64, np.float16, np.float32, np.float16, + dtypes.bfloat16.as_numpy_dtype + ]) + shape = (10, 10) + for unsupported_dtype in test_types - self.all_types: + with self.test_session() as sess: + with ops.device("CPU"): + x = array_ops.placeholder(unsupported_dtype, shape) + with self.test_scope(): + y, = array_ops.identity_n([x]) + with ops.device("CPU"): + z = array_ops.identity(y) + + inputs = np.random.randint(-100, 100, shape) + inputs = inputs.astype(unsupported_dtype) + # Execution should either succeed or raise an InvalidArgumentError, + # but not crash. Even "unsupported types" may succeed here since some + # backends (e.g., the CPU backend) are happy to handle buffers of + # unsupported types, even if they cannot compute with them. + try: + sess.run(z, {x: inputs}) + except errors.InvalidArgumentError: + pass + def testControlTrigger(self): with self.test_session() as sess: with self.test_scope(): diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index a7b9cc6c811098e7af95d10ec739b26508720548..61759fd2764205fab7fce11c4003e84be1be813a 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -81,7 +81,7 @@ cc_library( "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -91,6 +91,18 @@ cc_library( ], ) +cc_library( + name = "cpu_function_runtime", + srcs = ["cpu_function_runtime.cc"], + hdrs = ["cpu_function_runtime.h"], + deps = [ + # Keep dependencies to a minimum here; this library is used in every AOT + # binary produced by tfcompile. + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:framework_lite", + ], +) + cc_library( name = "xla_compiled_cpu_function", srcs = ["xla_compiled_cpu_function.cc"], @@ -99,12 +111,23 @@ cc_library( deps = [ # Keep dependencies to a minimum here; this library is used in every AOT # binary produced by tfcompile. - "//tensorflow/compiler/aot:runtime", + ":cpu_function_runtime", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", ], ) +tf_cc_test( + name = "cpu_function_runtime_test", + srcs = ["cpu_function_runtime_test.cc"], + deps = [ + ":cpu_function_runtime", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "xla_jit_compiled_cpu_function", srcs = ["xla_jit_compiled_cpu_function.cc"], @@ -119,6 +142,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/service/cpu:cpu_executable", "//tensorflow/core:lib", @@ -162,15 +186,19 @@ cc_library( ":sharding_util", ":tf2xla_util", "//tensorflow/compiler/tf2xla/lib:util", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -198,7 +226,7 @@ cc_library( ], visibility = [":friends"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:core_cpu_internal", @@ -281,10 +309,12 @@ tf_cc_test( deps = [ ":tf2xla", ":tf2xla_proto", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -323,7 +353,7 @@ tf_cc_test( "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:client_library", @@ -360,6 +390,7 @@ tf_cc_test( ], deps = [ ":common", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/core:framework", "//tensorflow/core:test", diff --git a/tensorflow/compiler/aot/runtime.cc b/tensorflow/compiler/tf2xla/cpu_function_runtime.cc similarity index 83% rename from tensorflow/compiler/aot/runtime.cc rename to tensorflow/compiler/tf2xla/cpu_function_runtime.cc index 5e74079fc158379b8977ada6412141e39142c3d3..2ffad2af8cfe621f0cbbdd8a9484ef2dfdf1b129 100644 --- a/tensorflow/compiler/aot/runtime.cc +++ b/tensorflow/compiler/tf2xla/cpu_function_runtime.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,22 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/aot/runtime.h" - -#include +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/core/platform/dynamic_annotations.h" namespace tensorflow { -namespace tfcompile { -namespace runtime { - namespace { - // Inline memory allocation routines here, because depending on '//base' brings // in libraries which use c++ streams, which adds considerable code size on // android. -inline void* aligned_malloc(size_t size, int minimum_alignment) { +void* aligned_malloc(size_t size, int minimum_alignment) { #if defined(__ANDROID__) || defined(OS_ANDROID) || defined(OS_CYGWIN) return memalign(minimum_alignment, size); #elif defined(_WIN32) @@ -47,7 +41,7 @@ inline void* aligned_malloc(size_t size, int minimum_alignment) { #endif } -inline void aligned_free(void* aligned_memory) { +void aligned_free(void* aligned_memory) { #if defined(_WIN32) _aligned_free(aligned_memory); #else @@ -58,13 +52,13 @@ inline void aligned_free(void* aligned_memory) { size_t align_to(size_t n, size_t align) { return (((n - 1) / align) + 1) * align; } - } // namespace -size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n) { +namespace cpu_function_runtime { +size_t AlignedBufferBytes(const intptr_t* sizes, size_t n) { size_t total = 0; for (size_t i = 0; i < n; ++i) { - if (sizes[i] != -1) { + if (sizes[i] > 0) { total += align_to(sizes[i], kAlign); } } @@ -73,7 +67,7 @@ size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n) { void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, bool annotate_initialized) { - const size_t total = aligned_buffer_bytes(sizes, n); + const size_t total = AlignedBufferBytes(sizes, n); void* contiguous = nullptr; if (total > 0) { contiguous = aligned_malloc(total, kAlign); @@ -85,7 +79,9 @@ void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, } uintptr_t pos = reinterpret_cast(contiguous); for (size_t i = 0; i < n; ++i) { - if (sizes[i] == -1) { + if (sizes[i] < 0) { + // bufs[i] is either a constant, an entry parameter or a thread local + // allocation. bufs[i] = nullptr; } else { bufs[i] = reinterpret_cast(pos); @@ -100,7 +96,5 @@ void FreeContiguous(void* contiguous) { aligned_free(contiguous); } } - -} // namespace runtime -} // namespace tfcompile +} // namespace cpu_function_runtime } // namespace tensorflow diff --git a/tensorflow/compiler/aot/runtime.h b/tensorflow/compiler/tf2xla/cpu_function_runtime.h similarity index 70% rename from tensorflow/compiler/aot/runtime.h rename to tensorflow/compiler/tf2xla/cpu_function_runtime.h index d1a669ceb17b9fd71d26e978035283f8824b0376..c7b4559c65731d1c4f4ea41e8be173ba89fe359c 100644 --- a/tensorflow/compiler/aot/runtime.h +++ b/tensorflow/compiler/tf2xla/cpu_function_runtime.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,25 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file contains utilities to make it easier to invoke functions generated -// by tfcompile. Usage of these utilities is optional. - -#ifndef TENSORFLOW_COMPILER_AOT_RUNTIME_H_ -#define TENSORFLOW_COMPILER_AOT_RUNTIME_H_ +#ifndef TENSORFLOW_COMPILER_TF2XLA_CPU_FUNCTION_RUNTIME_H_ +#define TENSORFLOW_COMPILER_TF2XLA_CPU_FUNCTION_RUNTIME_H_ #include "tensorflow/core/platform/types.h" namespace tensorflow { -namespace tfcompile { -namespace runtime { +namespace cpu_function_runtime { // Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment. -static constexpr size_t kAlign = 64; +constexpr size_t kAlign = 64; -// aligned_buffer_bytes returns the sum of each size in `sizes`, skipping -1 -// values. There are `n` entries in `sizes`. Each buffer is aligned to kAlign -// byte boundaries. -size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n); +// AlignedBufferBytes returns the sum of each size in `sizes`, skipping -1 +// values. There are `n` entries in `sizes`. Each buffer is aligned to +// kAlign byte boundaries. +size_t AlignedBufferBytes(const intptr_t* sizes, size_t n); // MallocContiguousBuffers allocates buffers for use by the entry point // generated by tfcompile. `sizes` is an array of byte sizes for each buffer, @@ -41,8 +37,8 @@ size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n); // temporary buffers. // // A single contiguous block of memory is allocated, and portions of it are -// parceled out into `bufs`, which must have space for `n` entries. Returns the -// head of the allocated contiguous block, which should be passed to +// parceled out into `bufs`, which must have space for `n` entries. Returns +// the head of the allocated contiguous block, which should be passed to // FreeContiguous when the buffers are no longer in use. void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, bool annotate_initialized); @@ -50,9 +46,7 @@ void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs, // FreeContiguous frees the contiguous block of memory allocated by // MallocContiguousBuffers. void FreeContiguous(void* contiguous); - -} // namespace runtime -} // namespace tfcompile +} // namespace cpu_function_runtime } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_AOT_RUNTIME_H_ +#endif // TENSORFLOW_COMPILER_TF2XLA_CPU_FUNCTION_RUNTIME_H_ diff --git a/tensorflow/compiler/aot/runtime_test.cc b/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc similarity index 71% rename from tensorflow/compiler/aot/runtime_test.cc rename to tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc index 06ec623eb2dce5f8dc7156fb7e7b9ad57d90c8ee..f4f27a156261ea6872777cef76ecaf7dd7eebe0d 100644 --- a/tensorflow/compiler/aot/runtime_test.cc +++ b/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc @@ -13,39 +13,37 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/aot/runtime.h" +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { -namespace tfcompile { -namespace runtime { namespace { -TEST(Runtime, AlignmentValue) { +TEST(XlaCompiledCpuFunctionTest, AlignmentValue) { // We've chosen 64 byte alignment for the tfcompile runtime to mimic the // regular tensorflow allocator, which was chosen to play nicely with Eigen. // The tfcompile runtime also has a requirement that comes from the xla // generated code, on the relation: buffer_size >= 16 ? 2 * sizeof(void*) : 8 // So any value that we choose must abide by that constraint as well. - EXPECT_EQ(kAlign, Allocator::kAllocatorAlignment); + EXPECT_EQ(cpu_function_runtime::kAlign, Allocator::kAllocatorAlignment); } -TEST(Runtime, AlignedBufferBytes) { - EXPECT_EQ(aligned_buffer_bytes(nullptr, 0), 0); +TEST(XlaCompiledCpuFunctionTest, AlignedBufferBytes) { + EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(nullptr, 0), 0); static constexpr intptr_t sizesA[1] = {-1}; - EXPECT_EQ(aligned_buffer_bytes(sizesA, 1), 0); + EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesA, 1), 0); static constexpr intptr_t sizesB[1] = {3}; - EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 64); + EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesB, 1), 64); static constexpr intptr_t sizesC[1] = {32}; - EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 64); + EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesC, 1), 64); static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3}; - EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 320); + EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesD, 7), 320); } void* add_ptr(void* base, uintptr_t delta) { @@ -56,48 +54,49 @@ void* add_ptr(void* base, uintptr_t delta) { // expected nullptrs, and write to each byte of allocated memory. We rely on // the leak checker to tell us if there's an inconsistency between malloc and // free. We also check the contiguous property. -TEST(Runtime, MallocFreeContiguousBuffers) { +TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) { // Test empty sizes. - void* base = MallocContiguousBuffers(nullptr, 0, nullptr, false); + void* base = + cpu_function_runtime::MallocContiguousBuffers(nullptr, 0, nullptr, false); EXPECT_EQ(base, nullptr); - FreeContiguous(base); + cpu_function_runtime::FreeContiguous(base); // Test non-empty sizes with 0 sum. static constexpr intptr_t sizesA[1] = {-1}; void* bufA[1]; - base = MallocContiguousBuffers(sizesA, 1, bufA, false); + base = cpu_function_runtime::MallocContiguousBuffers(sizesA, 1, bufA, false); EXPECT_EQ(base, nullptr); EXPECT_EQ(bufA[0], nullptr); - FreeContiguous(base); + cpu_function_runtime::FreeContiguous(base); // Test non-empty sizes with non-0 sum. static constexpr intptr_t sizesB[1] = {3}; void* bufB[1]; - base = MallocContiguousBuffers(sizesB, 1, bufB, false); + base = cpu_function_runtime::MallocContiguousBuffers(sizesB, 1, bufB, false); EXPECT_NE(base, nullptr); EXPECT_EQ(bufB[0], add_ptr(base, 0)); char* bufB0_bytes = static_cast(bufB[0]); bufB0_bytes[0] = 'A'; bufB0_bytes[1] = 'B'; bufB0_bytes[2] = 'C'; - FreeContiguous(base); + cpu_function_runtime::FreeContiguous(base); // Test non-empty sizes with non-0 sum, and annotate_initialized. static constexpr intptr_t sizesC[1] = {3}; void* bufC[1]; - base = MallocContiguousBuffers(sizesC, 1, bufC, true); + base = cpu_function_runtime::MallocContiguousBuffers(sizesC, 1, bufC, true); EXPECT_NE(base, nullptr); EXPECT_EQ(bufC[0], add_ptr(base, 0)); char* bufC0_bytes = static_cast(bufC[0]); bufC0_bytes[0] = 'A'; bufC0_bytes[1] = 'B'; bufC0_bytes[2] = 'C'; - FreeContiguous(base); + cpu_function_runtime::FreeContiguous(base); // Test mixed sizes. static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3}; void* bufD[7]; - base = MallocContiguousBuffers(sizesD, 7, bufD, false); + base = cpu_function_runtime::MallocContiguousBuffers(sizesD, 7, bufD, false); EXPECT_NE(base, nullptr); EXPECT_EQ(bufD[0], add_ptr(base, 0)); EXPECT_EQ(bufD[1], nullptr); @@ -115,10 +114,8 @@ TEST(Runtime, MallocFreeContiguousBuffers) { } } } - FreeContiguous(base); + cpu_function_runtime::FreeContiguous(base); } } // namespace -} // namespace runtime -} // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index 03603ee9baefd1d20d220faf63c9c1c427ebdf31..24616c01c7e54b2e8662457ca6af23a0bc563e08 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -33,7 +33,7 @@ struct NameCounts { std::unordered_map counts; }; -string MakeUniquePath(string name) { +string MakeUniqueFilename(string name) { static NameCounts& instance = *new NameCounts; // Remove illegal characters from `name`. @@ -50,26 +50,41 @@ string MakeUniquePath(string name) { count = instance.counts[name]++; } - legacy_flags::DumpGraphFlags* flags = legacy_flags::GetDumpGraphFlags(); - string path = strings::StrCat(flags->tf_dump_graph_prefix, "/", name); + string filename = name; if (count > 0) { - strings::StrAppend(&path, "_", count); + strings::StrAppend(&filename, "_", count); } - strings::StrAppend(&path, ".pbtxt"); - return path; + strings::StrAppend(&filename, ".pbtxt"); + return filename; +} + +string WriteTextProtoToUniqueFile( + Env* env, const string& name, const char* proto_type, + const ::tensorflow::protobuf::Message& proto) { + const string& dirname = + legacy_flags::GetDumpGraphFlags()->tf_dump_graph_prefix; + Status status = env->RecursivelyCreateDir(dirname); + if (!status.ok()) { + LOG(WARNING) << "Failed to create " << dirname << " for dumping " + << proto_type << ": " << status; + return "(unavailable)"; + } + string filepath = strings::StrCat(dirname, "/", MakeUniqueFilename(name)); + status = WriteTextProto(Env::Default(), filepath, proto); + if (!status.ok()) { + LOG(WARNING) << "Failed to dump " << proto_type << " to file: " << filepath + << " : " << status; + return "(unavailable)"; + } + LOG(INFO) << "Dumped " << proto_type << " to " << filepath; + return filepath; } } // anonymous namespace string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) { - string path = MakeUniquePath(name); - Status status = WriteTextProto(Env::Default(), path, graph_def); - if (!status.ok()) { - VLOG(1) << "Failed to dump GraphDef to file: " << path << " : " << status; - path.clear(); - path = "(unavailable)"; - } - return path; + return WriteTextProtoToUniqueFile(Env::Default(), name, "GraphDef", + graph_def); } string DumpGraphToFile(const string& name, Graph const& graph, @@ -83,15 +98,7 @@ string DumpGraphToFile(const string& name, Graph const& graph, } string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) { - string path = MakeUniquePath(name); - Status status = WriteTextProto(Env::Default(), path, fdef); - if (!status.ok()) { - VLOG(1) << "Failed to dump FunctionDef to file: " << path << " : " - << status; - path.clear(); - path = "(unavailable)"; - } - return path; + return WriteTextProtoToUniqueFile(Env::Default(), name, "FunctionDef", fdef); } } // namespace dump_graph diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 6cc95149a16a59fce8486c5d103ad09e3e262765..0904778f97c95628c81054cd4bc2ff32ff440a33 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -177,8 +177,8 @@ Status CheckNoCycleContains(const Node* node, const int num_nodes) { visited[current_node->id()] = true; for (const Edge* out : current_node->out_edges()) { if (out->dst() == node) { - return errors::Internal("Detect a cycle: Node \"", node->name(), "\"(", - node->def().op(), ") feeds into itself."); + return errors::Internal("Detected a cycle: ", FormatNodeForError(*node), + "(", node->def().op(), ") feeds into itself."); } else if (!visited[out->dst()->id()]) { ready.push_back(out->dst()); } @@ -324,7 +324,7 @@ Status AddMissingFunctionDef(const FunctionDef& fdef, if (library->Find(node.op())) { continue; } - // The function refered by 'SymbolicGradient' node is specified in its + // The function referred by 'SymbolicGradient' node is specified in its // attribute 'f'. if (node.op() == FunctionLibraryDefinition::kGradientOp) { const AttrValue* attr = @@ -437,22 +437,24 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, continue; } if (enter_merge != nullptr) { - return errors::Internal( - "Enter node for loop-varying argument ", arg.enter->name(), - " has multiple successors: ", enter_merge->dst()->name(), " and ", - e->dst()->name()); + return errors::Internal("Enter node for loop-varying argument ", + FormatNodeForError(*arg.enter), + " has multiple successors: ", + FormatNodeForError(*enter_merge->dst()), + " and ", FormatNodeForError(*e->dst())); } enter_merge = e; } if (enter_merge == nullptr) { return errors::Internal("Enter node for loop-varying argument ", - arg.enter->name(), " has zero successors"); + FormatNodeForError(*arg.enter), + " has zero successors"); } arg.merge = enter_merge->dst(); if (!IsMerge(arg.merge)) { return errors::InvalidArgument( "Successor of Enter node for loop-varying argument ", - arg.merge->name(), + FormatNodeForError(*arg.merge), " is not a Merge node; got: ", arg.merge->type_string()); } @@ -462,7 +464,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, return errors::InvalidArgument( "Unexpected number of inputs to Merge node for loop-varying " "argument ", - arg.merge->name(), "; expected 2, got ", + FormatNodeForError(*arg.merge), "; expected 2, got ", arg.merge->input_types().size()); } TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(), @@ -470,7 +472,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, if (!IsNextIteration(arg.next_iteration)) { return errors::InvalidArgument( "Expected NextIteration node as input to Merge node; got node ", - arg.next_iteration->name(), " with kind ", + FormatNodeForError(*arg.next_iteration), " with kind ", arg.next_iteration->type_string()); } @@ -481,14 +483,14 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, switches.find(edge->dst()) != switches.end()) { if (arg.switch_node != nullptr) { return errors::InvalidArgument("Duplicate Switch successors to ", - arg.merge->name()); + FormatNodeForError(*arg.merge)); } arg.switch_node = edge->dst(); } } if (arg.switch_node == nullptr) { return errors::InvalidArgument("Missing Switch successor to ", - arg.merge->name()); + FormatNodeForError(*arg.merge)); } // Update the device on the Identity outputs of the switch to match their @@ -516,14 +518,15 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, possible_exit.pop_front(); if (IsExit(edge->dst())) { if (arg.exit != nullptr) { - return errors::InvalidArgument("Duplicate Exit successors to ", - arg.switch_node->name()); + return errors::InvalidArgument( + "Duplicate Exit successors to ", + FormatNodeForError(*arg.switch_node)); } arg.exit = edge->dst(); } else { if (!IsIdentity(edge->dst())) { return errors::Unimplemented("General graph between switch (", - arg.switch_node->name(), + FormatNodeForError(*arg.switch_node), ") and exit node of frame ", frame->name, " not supported yet."); } @@ -1470,7 +1473,7 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, if (!unreachable_nodes.empty()) { return errors::InvalidArgument( "The following nodes are unreachable from the source in the graph: ", - tensorflow::str_util::Join(unreachable_nodes, ", ")); + errors::FormatNodeNamesForError(unreachable_nodes)); } // Builds Frames, indexed by name. diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index aae2f8ee5acd6249f8b6002d94c877f18064f936..ccf249b35d66861888ad5e5e904b5f63b8ac50a1 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -1064,7 +1064,10 @@ TEST(FunctionalizeControlFlow, Cycle) { // less -> XlaIf <--> identity. Status status = FunctionalizeControlFlow(graph.get(), &library); EXPECT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "Detect a cycle")) + EXPECT_TRUE(str_util::StrContains(status.error_message(), "Detected a cycle")) + << status.error_message(); + EXPECT_TRUE( + str_util::StrContains(status.error_message(), "{{node cond/Less_5_If}}")) << status.error_message(); } diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 4900af6df17f360630abb1e64b7f144ccd4a0289..e4fdf0a6186eb69a2e3413838c91616b992ef2d6 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -29,7 +29,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" @@ -161,9 +161,8 @@ Status GraphCompiler::Compile() { outputs.resize(n->num_outputs()); for (int o = 0; o < n->num_outputs(); ++o) { outputs[o] = op_context.release_output(o); - if (*op_context.is_output_dead() || outputs[o].tensor == nullptr) { + if (outputs[o].tensor == nullptr) { return errors::Internal("Missing xla_context ", o, "-th output from ", - (*op_context.is_output_dead() ? "(dead)" : ""), SummarizeNode(*n)); } } diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 45657bb150cec77f9a03cae0ae6dc8b6f302533d..3bfe74521fb30639cb08495c729cbaf6232dd996 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -58,6 +58,7 @@ tf_kernel_library( "pack_op.cc", "pad_op.cc", "pooling_ops.cc", + "qr_op.cc", "quantize_and_dequantize_op.cc", "random_ops.cc", "reduce_window_op.cc", @@ -107,6 +108,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:batch_dot", "//tensorflow/compiler/tf2xla/lib:cholesky", + "//tensorflow/compiler/tf2xla/lib:qr", "//tensorflow/compiler/tf2xla/lib:random", "//tensorflow/compiler/tf2xla/lib:scatter", "//tensorflow/compiler/tf2xla/lib:triangular_solve", @@ -114,14 +116,22 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/lib:while_loop", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/compiler/xla/client/lib:pooling", + "//tensorflow/compiler/xla/client/lib:prng", + "//tensorflow/compiler/xla/client/lib:sorting", "//tensorflow/core:framework", "//tensorflow/core:image_ops_op_lib", "//tensorflow/core:lib", @@ -156,8 +166,9 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -172,8 +183,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -207,10 +218,11 @@ tf_kernel_library( ":index_ops_kernel_argmax_float_2d", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/kernels:argmax_op", diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc index e33532828040123243f839ab1aa655b4bbc72520..41a453da80dec6b6f57a4d222e2c33ef6b786a10 100644 --- a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 26fc1620a4f032b3af28de6e3a5af0e965e82341..276d744c096f8996c774964204feaa3762bdb844 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -65,6 +65,6 @@ class XlaArgOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(XlaArgOp); }; -REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes(), XlaArgOp); +REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes().CompilationOnly(), XlaArgOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index b0ba25b9983c3a9af26728ce4b1c263c844327db..4cfe946b2e6146f034867c06e996ffae42b90705 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -28,11 +28,10 @@ class BatchMatMulOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto result = BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1), + auto result = BatchDot(ctx->Input(0), ctx->Input(1), /*transpose_x=*/adj_x_, /*transpose_y=*/adj_y_, /*conjugate_x=*/adj_x_, /*conjugate_y=*/adj_y_); - OP_REQUIRES_OK(ctx, result.status()); - ctx->SetOutput(0, result.ValueOrDie()); + ctx->SetOutput(0, result); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index c4af79281d2162b1dbfb0a7881720892f4bc49d2..b3ad0aea84eef601de08909f760699b8700d28f4 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 26130fd9e7fce75c6d2a5a53cfc85842cf762b35..48f2a005ab16651fe29d0f6f9d881f95693da461 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index ee2c920453c3bbaef2c145df743fddf999167c39..ba3b1c9dab79a387c48e8e25e4804917f328f8a0 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/bcast.h" diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index e9b2c0b16d39cb3b747c0316621fb01de709b12e..41f540506ba41fbe7f91393e7b8e26a89e72ef0a 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/util/tensor_format.h" diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index d6d4ae89376b67c14af8ef4f3a608fcc83b6fb59..2c328102e0bd84709707f102272691b6aec9a577 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc index efbdb76eaaf78904fe783a018940b1b096ec39bd..5078f8662bd397eaa51274ec816c130b8ced92cc 100644 --- a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index 62eebf762b3e063da8ec456cc4726d3cc9b77d1d..8cc2479dd555380da7500abe6b2aca380110333b 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index c137d026bda7d9263d6bec85b13d5ce1dc040038..e7fef77edcba0ea5a521956a704225ac4f7fcb22 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -74,16 +74,14 @@ class CategoricalOp : public XlaOpKernel { // See: // https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/ // TODO(b/68769470): Switch to using a cumulative sum approach. - auto softmax_entries = - xla::Sub(logits, xla::Log(xla::Neg(xla::Log(uniforms))), - /*broadcast_dimensions=*/{0, 2}); - - TensorShape softmax_shape(uniform_shape_array); - xla::XlaOp argmax; - OP_REQUIRES_OK( - ctx, - XlaHelpers::ArgMax(builder, ctx, softmax_entries, softmax_shape, - input_type(0), output_type(0), /*axis=*/2, &argmax)); + auto softmax_entries = xla::Sub(logits, xla::Log(-xla::Log(uniforms)), + /*broadcast_dimensions=*/{0, 2}); + + xla::PrimitiveType xla_output_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(output_type(0), &xla_output_type)); + xla::XlaOp argmax = + XlaHelpers::ArgMax(softmax_entries, xla_output_type, /*axis=*/2); ctx->SetOutput(0, argmax); } diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc index fe6651793dc763d13f4a4b0ac294ec3ecf64af8f..9fcbc86adc0967cbb7fb73da8bdabc58b60953da 100644 --- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc @@ -24,12 +24,7 @@ class CholeskyOp : public XlaOpKernel { public: explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - auto result = Cholesky(ctx->builder(), ctx->Input(0)); - if (!result.ok()) { - ctx->SetStatus(result.status()); - return; - } - ctx->SetOutput(0, result.ValueOrDie()); + ctx->SetOutput(0, Cholesky(ctx->Input(0))); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc index 4e6d33304c4ae08a0fd1e0a8373267a527087528..547fe48046e8c934e3bc14d02c8448e107c1a406 100644 --- a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index e3a32a5c0e2f93237c8c7ebeea3668b5d1ab6c23..f4106051043859a6786705009d76b02a64cd3ff1 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index f4360d8c3f6fc4007c31fdcfd7f7634de15c76d4..da8cf3fc6fa694f592280f8c249d317827d9cd09 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/tensor.pb.h" diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 5d41fc708ac3c1ea8bfb72e0e57f9a12bb99407a..5da7972397b32fb4a2f216913e065c04131a3773 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -18,7 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -96,14 +97,9 @@ xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape, // Create a M sized linspace and an M*N sized linspace that will be // broadcasted into perpendicular dimensions and compared. - xla::XlaOp input_feature_iota; - // DT_INT32 Iota will always return status::OK(). - TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature, - &input_feature_iota)); - xla::XlaOp expanded_feature_iota; - TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, - input_feature * depthwise_multiplier, - &expanded_feature_iota)); + xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature); + xla::XlaOp expanded_feature_iota = + xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier); // Divide the M*N sized linspace by the depthwise_multiplier to create // [0 0 1 1 2 2] in the example in the function comment. diff --git a/tensorflow/compiler/tf2xla/kernels/cross_op.cc b/tensorflow/compiler/tf2xla/kernels/cross_op.cc index 500a564f3f0489a42dbc9d5b70ae7708a7a43973..db579a5b35d69deb3dca578e31c1b54fada76342 100644 --- a/tensorflow/compiler/tf2xla/kernels/cross_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cross_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index 9ff3e0222831cb4339943966810eeae451e47a2c..ef1015552d181a183d412f9c269dd5ec608b388f 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/util/bcast.h" diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index 4f92dbc8740b697322424058530b8477c35d809a..a5b870f8dbf70bcee331992345d63fd5d986bdca 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/util/bcast.h" diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index f3149200250935629a6e4bf67bff0c048135ce3e..12b0e38288e8f222ed506a75ec2575f27141c859 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 17bf0c069c658c3d513ed3fd85e14b0dd5ef0d33..ed44ad218b6dc073583ec339da082b6881ad672d 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -18,7 +18,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" @@ -26,10 +28,10 @@ namespace tensorflow { namespace { // Create a diagonal / batch diagonal matrix with 'input' on the diagonal. -xla::StatusOr CreateDiagonal( - const xla::XlaOp& input, int64 last_dim_size, - tensorflow::gtl::ArraySlice other_dims, XlaOpKernelContext* ctx, - xla::XlaBuilder* builder) { +xla::XlaOp CreateDiagonal(xla::XlaOp input, int64 last_dim_size, + gtl::ArraySlice other_dims, + xla::PrimitiveType element_type) { + xla::XlaBuilder* builder = input.builder(); // Create two matrices that have the following forms, and compare them: // // [[0, 0, 0, 0] [[0, 1, 2, 3] @@ -39,9 +41,7 @@ xla::StatusOr CreateDiagonal( // // This produces a predicate matrix of the right size, with "true" on the // diagonal. - xla::XlaOp iota; - TF_RETURN_IF_ERROR( - XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota)); + xla::XlaOp iota = xla::Iota(builder, xla::S32, last_dim_size); xla::XlaOp iota_broadcast = xla::Broadcast(iota, {last_dim_size}); xla::XlaOp mask = xla::Eq(iota_broadcast, iota, {0}); @@ -68,12 +68,9 @@ xla::StatusOr CreateDiagonal( xla::XlaOp input_broadcast = xla::Reshape(input, broadcast_dims); broadcast_dims[broadcast_dims.size() - 2] = last_dim_size; - xla::PrimitiveType element_type; - TF_RETURN_IF_ERROR( - DataTypeToPrimitiveType(ctx->input_type(0), &element_type)); auto broadcast_shape = xla::ShapeUtil::MakeShape(element_type, broadcast_dims); - xla::XlaOp zeros = Zeros(builder, broadcast_shape); + xla::XlaOp zeros = xla::Zeros(builder, broadcast_shape); input_broadcast = xla::Add(input_broadcast, zeros); return xla::Select(mask, input_broadcast, zeros); @@ -84,8 +81,6 @@ class DiagOp : public XlaOpKernel { explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* builder = ctx->builder(); - OP_REQUIRES(ctx, ctx->num_inputs() >= 1, errors::InvalidArgument("Diag op must have at an input")); const TensorShape input_shape = ctx->InputShape(0); @@ -108,10 +103,8 @@ class DiagOp : public XlaOpKernel { input = xla::Reshape(input, {size}); // Create an R2 with the R1 diagonal. - auto diag_or_status = - CreateDiagonal(input, size, /*other_dims=*/{}, ctx, builder); - OP_REQUIRES_OK(ctx, diag_or_status.status()); - xla::XlaOp diag = diag_or_status.ValueOrDie(); + xla::XlaOp diag = + CreateDiagonal(input, size, /*other_dims=*/{}, ctx->input_xla_type(0)); // Reshapes to the final shape. std::vector new_dims(dims.size() * 2); @@ -130,8 +123,6 @@ class DiagPartOp : public XlaOpKernel { explicit DiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* builder = ctx->builder(); - const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -157,37 +148,13 @@ class DiagPartOp : public XlaOpKernel { new_dims.push_back(dims[i]); } - xla::XlaOp diag = ctx->Input(0); - - // TODO(b/30878775): use Slice with strides when supported, in place of - // the Pad -> Reshape -> Slice. - - // Picture: - // [[1, 0, 0, 0] pad and reshape to [[1, 0, 0, 0, 0], - // [0, 2, 0, 0] =================> [2, 0, 0, 0, 0], - // [0, 0, 3, 0] [3, 0, 0, 0, 0], - // [0, 0, 0, 4]] [4, 0, 0, 0, 0]] - // and then slice out the first column. - - // Flattens the input to 1D. - int64 size = input_shape.num_elements(); - diag = xla::Reshape(diag, {size}); - - // Adds padding after the last element of 'new_size'. - xla::PaddingConfig config; - auto* dim = config.add_dimensions(); - dim->set_edge_padding_high(new_size); - auto zero = XlaHelpers::Zero(builder, input_type(0)); - diag = xla::Pad(diag, zero, config); - - // Reshapes so the diagonal is now in the first column. - diag = xla::Reshape(diag, {new_size, new_size + 1}); + xla::XlaOp input = ctx->Input(0); - // Slices out the first column and reshapes to the final shape. - diag = xla::Slice(diag, {0, 0}, {new_size, 1}, {1, 1}); - diag = xla::Reshape(diag, new_dims); + xla::XlaOp output = xla::Reshape( + xla::GetMatrixDiagonal(xla::Reshape(input, {new_size, new_size})), + new_dims); - ctx->SetOutput(0, diag); + ctx->SetOutput(0, output); } }; @@ -198,8 +165,6 @@ class MatrixDiagOp : public XlaOpKernel { explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* builder = ctx->builder(); - OP_REQUIRES(ctx, ctx->num_inputs() >= 1, errors::InvalidArgument("MatrixDiag op must have at an input")); const TensorShape input_shape = ctx->InputShape(0); @@ -209,17 +174,15 @@ class MatrixDiagOp : public XlaOpKernel { errors::InvalidArgument("Expected 1 <= dims, got shape ", input_shape.DebugString())); - xla::XlaOp diag = ctx->Input(0); int last_dim = dims.size() - 1; int64 last_dim_size = input_shape.dim_size(last_dim); tensorflow::gtl::ArraySlice other_dims(dims); other_dims.pop_back(); - auto diag_or_status = - CreateDiagonal(diag, last_dim_size, other_dims, ctx, builder); - OP_REQUIRES_OK(ctx, diag_or_status.status()); - diag = diag_or_status.ValueOrDie(); + xla::XlaOp input = ctx->Input(0); + xla::XlaOp diag = CreateDiagonal(input, last_dim_size, other_dims, + ctx->input_xla_type(0)); ctx->SetOutput(0, diag); } }; @@ -231,8 +194,6 @@ class MatrixDiagPartOp : public XlaOpKernel { explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* builder = ctx->builder(); - const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -240,71 +201,8 @@ class MatrixDiagPartOp : public XlaOpKernel { errors::InvalidArgument("Expected 2 <= dims, got shape ", input_shape.DebugString())); - xla::XlaOp diag = ctx->Input(0); - - int last_dim = dims.size() - 1; - int64 last_dim_size = dims[last_dim]; - - // The smaller of the last two dimension sizes. - int64 smaller_dim_size = std::min(dims[last_dim - 1], dims[last_dim]); - - // TODO(b/30878775): use Slice with strides when supported, in place of - // the Pad -> Reshape -> Slice. - - // Picture: for each 2D matrix in the tensor's last two dimensions: - // [[1, 0, 0, 0] pad and reshape to [[1, 0, 0, 0, 0], - // [0, 2, 0, 0] =================> [2, 0, 0, 0, 0], - // [0, 0, 3, 0]] [3, 0, 0, 0, 0], - // and then slice out the first column. - // - // Another example, with tall and narrow input. - // [[1, 0] pad and reshape to [[1, 0, 0], - // [0, 2] =================> [2, 0, 0]] - // [0, 0] - // [0, 0]] - - // Collapses the last two dimensions. - std::vector flattened_dims(dims.begin(), dims.end() - 1); - flattened_dims.back() *= dims.back(); - diag = xla::Reshape(diag, flattened_dims); - - // Slices or pads the last dimension to 'target_size'. - int64 actual_size = flattened_dims.back(); - int64 target_size = smaller_dim_size * (last_dim_size + 1); - if (actual_size < target_size) { - xla::PaddingConfig config = - xla::MakeNoPaddingConfig(flattened_dims.size()); - auto* dim = config.mutable_dimensions(flattened_dims.size() - 1); - dim->set_edge_padding_high(target_size - actual_size); - auto zero = XlaHelpers::Zero(builder, input_type(0)); - diag = xla::Pad(diag, zero, config); - } else if (actual_size > target_size) { - std::vector start(flattened_dims.size(), 0); - std::vector limits(flattened_dims.begin(), flattened_dims.end()); - std::vector strides(flattened_dims.size(), 1); - limits[flattened_dims.size() - 1] = target_size; - diag = xla::Slice(diag, start, limits, strides); - } - - // Reshape so the target values are in the first position of the last - // dimension. - std::vector unflattened_dims(dims.begin(), dims.end()); - dims[last_dim - 1] = smaller_dim_size; - dims[last_dim] = last_dim_size + 1; - diag = xla::Reshape(diag, dims); - - // Slices out the first column and reshapes to the final shape. - std::vector start(dims.size(), 0); - std::vector limits(dims.begin(), dims.end()); - std::vector strides(dims.size(), 1); - limits[last_dim] = 1; - diag = xla::Slice(diag, start, limits, strides); - - // Collapses away the last dimension. - dims.pop_back(); - diag = xla::Reshape(diag, dims); - - ctx->SetOutput(0, diag); + xla::XlaOp input = ctx->Input(0); + ctx->SetOutput(0, xla::GetMatrixDiagonal(input)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc index 3b86ea34c9e7d943eb9c7de222e0a2be049ebc68..a3389d5b905bf3ee15744ab4fcee193d312e2ae0 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/compiler/tf2xla/type_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index 958231505b50431b9bb267b0a3cc5ed56e3aeb21..cb73053666d4c32bc0a2ef19b174aee1a29f101e 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc index 2c76bcee2593b820eafe09af3a52736ed8a92f86..5fdb1d972c55efb876972d3f472b53a1f7cde1c2 100644 --- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index b2451236de83583e9915967f420d1235acd92600..c68b0bfd7961892294c2931e5c4c44de534a7740 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -17,7 +17,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { @@ -111,9 +112,7 @@ class ExtractImagePatchesOp : public XlaOpKernel { // Builds an identity matrix as a broadcast equality of iotas. // iota = np.arange(np.prod(ksize), depth) // filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32) - xla::XlaOp iota; - TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, - kernel_size * depth, &iota)); + xla::XlaOp iota = xla::Iota(builder, xla::S32, kernel_size * depth); auto lhs = xla::Reshape(iota, lhs_shape); auto filter = xla::ConvertElementType( diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index 2fd1a34741e1c7235397f9a69dd8444b4679fa22..cdba6680dee3fade5bdf0c453ed672b653072b0d 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index b2b00e51e3b00fa93c258af489cf0f4a3e6e764b..80bcef966360ec9a1ca63a02741108ce41b31846 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index 95faa1d058f4c0d3fa802b157c6daba1e1adaf41..54b21a278229024e3e54e9135548be6b69b077e1 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 5f041be5df226ed996b21844c0cf92b6dfac005c..35de96e0aab847fa39ef26d5f3052c392062fd7d 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h index d898e43b858bac706d524c7c271f48b1b5fa258f..92346283c31dfe1d638526ac4b26ef762cd7fd14 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/util/bcast.h" diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index f5fcf3cacdbff8297bc42fcb0cf79c2bc83a4e11..462e0e439583e6e2a622e44eecd4874ad5da7562 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { @@ -246,6 +246,8 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Done building If"; } +REGISTER_XLA_OP(Name("If").AllowResourceTypes(), XlaIfOp); +REGISTER_XLA_OP(Name("StatelessIf").AllowResourceTypes(), XlaIfOp); REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index cb4caf7bcb4caaa1bf7e0e79e52bb966a8838db3..33a73fe5fdf403e513be085dd7bcea3255277b4a 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -17,7 +17,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { namespace { @@ -311,5 +316,150 @@ class AdjustHueOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp); +class NonMaxSuppressionOp : public XlaOpKernel { + public: + explicit NonMaxSuppressionOp(OpKernelConstruction* context) + : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size", + &pad_to_max_output_size_)); + } + + void Compile(XlaOpKernelContext* context) override { + // TODO(b/111646731): Improve scalability of this op, using blocking. + int num_boxes_dim = 0; + int coords_dim = 1; + const TensorShape& boxes_shape = context->InputShape("boxes"); + OP_REQUIRES(context, TensorShapeUtils::IsMatrix(boxes_shape), + errors::InvalidArgument("boxes must be 2-D, currently: ", + boxes_shape.DebugString())); + const int64 num_boxes = boxes_shape.dim_size(num_boxes_dim); + OP_REQUIRES(context, boxes_shape.dim_size(coords_dim) == 4, + errors::InvalidArgument("boxes must have 4 columns", + boxes_shape.DebugString())); + const TensorShape& scores_shape = context->InputShape("scores"); + OP_REQUIRES(context, TensorShapeUtils::IsVector(scores_shape), + errors::InvalidArgument("scores must be 1-D, currently: ", + scores_shape.DebugString())); + OP_REQUIRES( + context, scores_shape.dim_size(0) == num_boxes, + errors::InvalidArgument("scores size must equal number of boxes", + scores_shape.DebugString())); + OP_REQUIRES(context, pad_to_max_output_size_, + errors::InvalidArgument( + "XLA compilation requires pad_to_max_output_size == True")); + + xla::XlaOp boxes = context->Input("boxes"); + xla::XlaOp scores = context->Input("scores"); + int64 output_size; + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &output_size)); + OP_REQUIRES( + context, output_size >= 0, + errors::InvalidArgument("Need output_size >= 0, got ", output_size)); + xla::XlaOp score_thresh = context->Input("score_threshold"); + xla::XlaOp iou_thresh = context->Input("iou_threshold"); + + xla::XlaBuilder* const builder = context->builder(); + + // Choose a more convenient layout. + xla::XlaOp boxes_t = xla::Transpose(boxes, {1, 0}); + coords_dim = 0; + num_boxes_dim = 1; + + // Shapes are henceforth [1, num_boxes]. + xla::XlaOp coord_y0 = xla::SliceInDim(boxes_t, + /*start_index=*/0, + /*limit_index=*/1, + /*stride=*/1, + /*dimno=*/coords_dim); + xla::XlaOp coord_x0 = xla::SliceInDim(boxes_t, + /*start_index=*/1, + /*limit_index=*/2, + /*stride=*/1, + /*dimno=*/coords_dim); + xla::XlaOp coord_y1 = xla::SliceInDim(boxes_t, + /*start_index=*/2, + /*limit_index=*/3, + /*stride=*/1, + /*dimno=*/coords_dim); + xla::XlaOp coord_x1 = xla::SliceInDim(boxes_t, + /*start_index=*/3, + /*limit_index=*/4, + /*stride=*/1, + /*dimno=*/coords_dim); + xla::XlaOp y1 = + xla::Select(xla::Le(coord_y0, coord_y1), coord_y0, coord_y1); + xla::XlaOp y2 = + xla::Select(xla::Le(coord_y0, coord_y1), coord_y1, coord_y0); + xla::XlaOp x1 = + xla::Select(xla::Le(coord_x0, coord_x1), coord_x0, coord_x1); + xla::XlaOp x2 = + xla::Select(xla::Le(coord_x0, coord_x1), coord_x1, coord_x0); + xla::XlaOp area = (y2 - y1) * (x2 - x1); + + // Transpose the 1xN tensors, instead of the NxN tensors. + xla::XlaOp y1_t = xla::Transpose(y1, {1, 0}); + xla::XlaOp y2_t = xla::Transpose(y2, {1, 0}); + xla::XlaOp x1_t = xla::Transpose(x1, {1, 0}); + xla::XlaOp x2_t = xla::Transpose(x2, {1, 0}); + xla::XlaOp area_t = xla::Transpose(area, {1, 0}); + + // Shapes are henceforth [num_boxes, num_boxes]. + xla::XlaOp i_xmin = xla::Max(x1, x1_t); + xla::XlaOp i_ymin = xla::Max(y1, y1_t); + xla::XlaOp i_xmax = xla::Min(x2, x2_t); + xla::XlaOp i_ymax = xla::Min(y2, y2_t); + auto square_zero = xla::ZerosLike(i_xmin); + + xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) * + xla::Max(i_ymax - i_ymin, square_zero); + xla::XlaOp u_area = area + area_t - i_area; + xla::XlaOp iou = i_area / u_area; + + xla::XlaOp iou_thresh_mask = xla::Gt(iou, iou_thresh + square_zero); + xla::XlaOp scores_2d = xla::Reshape(scores, {num_boxes, 1}); + xla::XlaOp score_cmp_mask = + xla::Gt(scores_2d, xla::Transpose(scores_2d, {1, 0})); + xla::XlaOp suppress = xla::And(iou_thresh_mask, score_cmp_mask); + + // Shapes are [num_boxes] after the reduce. + xla::XlaOp included_iou = xla::Not(xla::Reduce( + suppress, + /*init_value=*/xla::ConstantR0(builder, false), + /*computation=*/CreateScalarOrComputation(xla::PRED, builder), + /*dimensions_to_reduce=*/{0})); + xla::XlaOp included_score = + xla::Gt(scores, xla::Broadcast(score_thresh, {num_boxes})); + xla::XlaOp included = xla::And(included_iou, included_score); + xla::XlaOp neg_inf = + xla::Broadcast(xla::MinValue(builder, xla::F32), {num_boxes}); + xla::XlaOp scores_included = xla::Select(included, scores, neg_inf); + + xla::XlaOp ones_included = xla::Select( + included, + xla::Broadcast(xla::ConstantR0(builder, 1), {num_boxes}), + xla::Broadcast(xla::ConstantR0(builder, 0), {num_boxes})); + + // num_valid is scalar. + xla::XlaOp num_valid = xla::Reduce( + ones_included, + /*init_value=*/xla::ConstantR0(builder, 0), + /*computation=*/CreateScalarAddComputation(xla::S32, builder), + /*dimensions_to_reduce=*/{0}); + + xla::XlaOp output_tuple = TopK(scores_included, output_size); + xla::XlaOp selected_indices = xla::GetTupleElement(output_tuple, 1); + + context->SetOutput(0, selected_indices); + context->SetOutput(1, num_valid); + } + + private: + bool pad_to_max_output_size_; +}; + +REGISTER_XLA_OP( + Name("NonMaxSuppressionV4").CompileTimeConstInput("max_output_size"), + NonMaxSuppressionOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index de971ce4acd424cb7194d328db69dc4846ced2db..8d75624e74028ea083c3facc4f9578ec14c50e6d 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -18,7 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/lib/math/math_util.h" @@ -128,10 +129,7 @@ const int64 kMax2DKernelSize = 16; xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, gtl::ArraySlice kernel_size, int64 channels) { - xla::XlaOp channels_iota; - // DT_INT32 Iota will always return status::OK(). - TF_CHECK_OK( - XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); + xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); auto diag = xla::ConvertElementType( xla::Eq(xla::Broadcast(channels_iota, {2 * kernel_size[0] - 1, @@ -149,10 +147,7 @@ xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder, gtl::ArraySlice kernel_size, int64 channels, int64 dim) { - xla::XlaOp channels_iota; - // DT_INT32 Iota will always return status::OK(). - TF_CHECK_OK( - XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); + xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); auto diag = xla::ConvertElementType( xla::Eq( diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index 36eb4c75454ed82804c40b82e5dbaec2eef0a719..f3964748587c1b31cf8b1b76643ff19a9044bf44 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -60,19 +60,15 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { input_shape.DebugString())); DataType index_type = output_type(0); + xla::PrimitiveType index_xla_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(index_type, &index_xla_type)); - xla::XlaBuilder* b = ctx->builder(); xla::XlaOp input = ctx->Input(0); - xla::XlaOp output; if (is_min_) { - OP_REQUIRES_OK(ctx, - XlaHelpers::ArgMin(b, ctx, input, input_shape, input_type(0), - index_type, axis, &output)); + output = XlaHelpers::ArgMin(input, index_xla_type, axis); } else { - OP_REQUIRES_OK(ctx, - XlaHelpers::ArgMax(b, ctx, input, input_shape, input_type(0), - index_type, axis, &output)); + output = XlaHelpers::ArgMax(input, index_xla_type, axis); } ctx->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index a020ebc729e4c07d1b182cc0585ba0f2bca46403..22a45b2a11e8ecb688f8e773ef4b286eafe68f4f 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -78,14 +78,14 @@ class ArgMaxCustomCallOp : public XlaOpKernel { std::vector args; args.push_back(ctx->Input(0)); args.push_back(xla::ConstantLiteral( - &b, *xla::Literal::CreateR1(input_shape.dim_sizes()))); + &b, *xla::LiteralUtil::CreateR1(input_shape.dim_sizes()))); if (input_shape.dims() > 1) { // Don't bother passing the output shape and dim for the 1d case, since // the shape is always a scalar and the dim is always 0. args.push_back(xla::ConstantLiteral( - &b, *xla::Literal::CreateR1(output_shape.dim_sizes()))); + &b, *xla::LiteralUtil::CreateR1(output_shape.dim_sizes()))); args.push_back( - xla::ConstantLiteral(&b, *xla::Literal::CreateR0(dim))); + xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0(dim))); } xla::Shape xla_shape = diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index 9e64711051d31107db1bf6f1966f9ed6f5630c34..f028e361bccd51de0bd69a1d2227c7afaed53455 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc index 2fb072f827906d40dcf410f0312394c4f568a28d..a11bbe918f7f8eb050aaa40d4344f9cc9e9a10a4 100644 --- a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc index dc934543cb2f94fbe1e8f1f865156eb082d6a127..87ee2d3aede50eb24e65570f106d49030e1d4236 100644 --- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index 844080b8cf5462da201ce7671e4f9d02fa52c861..6440770c29894c951f010f6c1deb929f4fe79bbf 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { @@ -54,10 +54,14 @@ class MatMulOp : public XlaOpKernel { const TensorShape b_shape = ctx->InputShape(1); // Check that the dimensions of the two matrices are valid. - OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a_shape), - errors::InvalidArgument("In[0] is not a matrix")); - OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b_shape), - errors::InvalidArgument("In[1] is not a matrix")); + OP_REQUIRES( + ctx, TensorShapeUtils::IsMatrix(a_shape), + errors::InvalidArgument("In[0] is not a matrix. Instead it has shape ", + a_shape.DebugString())); + OP_REQUIRES( + ctx, TensorShapeUtils::IsMatrix(b_shape), + errors::InvalidArgument("In[1] is not a matrix. Instead it has shape ", + b_shape.DebugString())); int first_index = transpose_a_ ? 0 : 1; int second_index = transpose_b_ ? 1 : 0; diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc index 9d3575e331cc188340afc5bbb1f4468a9d7fbde7..8dfd7de591c4a3c4768dd60b41e03d294ad49397 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc @@ -16,7 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -51,6 +52,7 @@ class MatrixBandPartOp : public XlaOpKernel { xla::XlaOp num_upper = context->Input(2); DataType input_type = context->input_type(0); DataType index_type = context->input_type(1); + xla::PrimitiveType index_xla_type = context->input_xla_type(1); TensorShape batch_shape = input_shape; batch_shape.RemoveLastDims(2); @@ -59,11 +61,8 @@ class MatrixBandPartOp : public XlaOpKernel { // Compute 'offset', which is how many diagonals we are above/below the // diagonal. - xla::XlaOp iota_m; - OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m)); - - xla::XlaOp iota_n; - OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n)); + xla::XlaOp iota_m = xla::Iota(builder, index_xla_type, m); + xla::XlaOp iota_n = xla::Iota(builder, index_xla_type, n); auto offset = xla::Sub(xla::Broadcast(iota_n, {m}), iota_m, /*broadcast_dimensions=*/{0}); diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc index 7bf1894ea03988f608080fd051b3d6c528b6443b..c0ca881ff82cee04e0c5e35f9a2d5732fabdd8a6 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc @@ -16,7 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { @@ -62,10 +63,8 @@ class MatrixSetDiagOp : public XlaOpKernel { auto zero = XlaHelpers::Zero(builder, context->input_type(0)); // Create an indicator tensor that is true only on the diagonal. - xla::XlaOp iota_m; - OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m)); - xla::XlaOp iota_n; - OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n)); + xla::XlaOp iota_m = xla::Iota(builder, xla::S32, m); + xla::XlaOp iota_n = xla::Iota(builder, xla::S32, n); auto indicator = xla::Eq(iota_m, xla::Broadcast(iota_n, {m}), /*broadcast_dimensions=*/{0}); indicator = xla::Broadcast(indicator, batch_shape.dim_sizes()); diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index eaed93146460de5a6e8328432302cc75bf36a534..f4def11d08c31513aec5aad15187016a7294c2fd 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -30,13 +30,9 @@ class MatrixTriangularSolveOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { auto result = TriangularSolve( - ctx->builder(), ctx->Input(0), ctx->Input(1), /*left_side=*/true, + ctx->Input(0), ctx->Input(1), /*left_side=*/true, /*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_); - if (!result.ok()) { - ctx->SetStatus(result.status()); - return; - } - ctx->SetOutput(0, result.ValueOrDie()); + ctx->SetOutput(0, result); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index 529959dbd90b05f8860360f70e087ef225150600..eedfc3c9140d7b1ccc1944611de98c1d49fbdaf2 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/mirror_pad_mode.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index 3aed47de2603f3e187ad515d4db3f884da4c6cc8..a9b519d8928cc2807831fd6b4f12e60b7d58ea55 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 89fd610bc63349d008836c3c4e6ec8927c232a54..e5937b56c17d01892928b073da09f38941ea1bbb 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 771dcbab21691ff1f018e4d65815cd5a53c9447a..d4d180aff806f12875f0e43f111ee090f6607ef6 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -20,8 +20,11 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/pooling.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -62,63 +65,60 @@ class PoolingOp : public XlaOpKernel { Padding padding; OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding)); padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame; + + OP_REQUIRES_OK( + ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_)); } int num_dims() const { return num_spatial_dims_ + 2; } - // Method that builds an initial value to use in reductions. - virtual xla::XlaOp InitValue(xla::XlaBuilder* b) = 0; - - // The reduction operation to apply to each window. - virtual const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) = 0; - - // A post-processing operation to apply on the outputs of the ReduceWindow. - virtual xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, - const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape) = 0; - - void Compile(XlaOpKernelContext* ctx) override { - std::vector ksize = ksize_; - std::vector stride = stride_; - if (ctx->num_inputs() != 1) { - const TensorShape ksize_shape = ctx->InputShape(1); - // Validate input sizes. - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape), - errors::InvalidArgument("ksize must be a vector, not shape ", - ksize_shape.DebugString())); - OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(), - errors::InvalidArgument("Sliding window ksize field must " - "specify ", - num_dims(), " dimensions")); - ksize.clear(); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize)); - - const TensorShape stride_shape = ctx->InputShape(2); - // Validate input sizes. - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape), - errors::InvalidArgument("stride must be a vector, not shape ", - stride_shape.DebugString())); - OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(), - errors::InvalidArgument("Sliding window stride field must " - "specify ", - num_dims(), " dimensions")); - stride.clear(); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride)); + protected: + xla::StatusOr> GetKernelSize(XlaOpKernelContext* ctx) { + if (ctx->num_inputs() == 1) { + return ksize_; } - const TensorShape input_shape = ctx->InputShape(0); - OP_REQUIRES(ctx, input_shape.dims() == num_dims(), - errors::InvalidArgument("Input to ", type_string(), - " operator must have ", num_dims(), - " dimensions")); + const TensorShape ksize_shape = ctx->InputShape(1); + // Validate input sizes. + if (!TensorShapeUtils::IsVector(ksize_shape)) { + return errors::InvalidArgument("ksize must be a vector, not shape ", + ksize_shape.DebugString()); + } + if (ksize_shape.num_elements() != num_dims()) { + return errors::InvalidArgument( + "Sliding window ksize field must " + "specify ", + num_dims(), " dimensions"); + } + std::vector ksize; + auto status = ctx->ConstantInputAsIntVector(1, &ksize); + if (!status.ok()) { + return status; + } + return ksize; + } - xla::XlaBuilder* const b = ctx->builder(); - auto input = - XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_); - auto reduce = xla::ReduceWindow(input, InitValue(b), *Reduction(ctx), ksize, - stride, padding_); - auto pooled = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); - ctx->SetOutput(0, - PostProcessOutput(ctx, pooled, input_type(0), input_shape)); + xla::StatusOr> GetStride(XlaOpKernelContext* ctx) { + if (ctx->num_inputs() == 1) { + return stride_; + } + const TensorShape stride_shape = ctx->InputShape(2); + // Validate input sizes. + if (!TensorShapeUtils::IsVector(stride_shape)) { + return errors::InvalidArgument("stride must be a vector, not shape ", + stride_shape.DebugString()); + } + if (stride_shape.num_elements() != num_dims()) { + return errors::InvalidArgument( + "Sliding window stride field must " + "specify ", + num_dims(), " dimensions"); + } + std::vector stride; + auto status = ctx->ConstantInputAsIntVector(2, &stride); + if (!status.ok()) { + return status; + } + return stride; } protected: @@ -128,26 +128,51 @@ class PoolingOp : public XlaOpKernel { xla::Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; DataType reduction_type_; + xla::PrimitiveType xla_reduction_type_; }; +// Converts the tensor data format to the one required by the XLA pooling +// library. +xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format, + int num_spatial_dims) { + int num_dims = num_spatial_dims + 2; + int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format); + int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format); + gtl::InlinedVector spatial_dimensions(num_spatial_dims); + for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) { + spatial_dimensions[spatial_dim] = + GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim); + } + return xla::TensorFormat(/*batch_dimension=*/batch_dimension, + /*feature_dimension=*/feature_dimension, + /*spatial_dimensions=*/spatial_dimensions); +} + class MaxPoolOp : public PoolingOp { public: MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, /*reduction_type=*/ctx->input_type(0)) {} - xla::XlaOp InitValue(xla::XlaBuilder* b) override { - return XlaHelpers::MinValue(b, reduction_type_); - } + void Compile(XlaOpKernelContext* ctx) override { + auto ksize_or_error = GetKernelSize(ctx); + OP_REQUIRES_OK(ctx, ksize_or_error.status()); + std::vector ksize = ksize_or_error.ValueOrDie(); - const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { - return ctx->GetOrCreateMax(reduction_type_); - } + auto stride_or_error = GetStride(ctx); + OP_REQUIRES_OK(ctx, stride_or_error.status()); + std::vector stride = stride_or_error.ValueOrDie(); - xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, - const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape) override { - return output; + const TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, input_shape.dims() == num_dims(), + errors::InvalidArgument("Input to ", type_string(), + " operator must have ", num_dims(), + " dimensions")); + + auto pooling = + xla::MaxPool(ctx->Input(0), ksize, stride, padding_, + XlaTensorFormat(data_format_, input_shape.dims() - 2)); + ctx->SetOutput(0, pooling); } }; @@ -174,9 +199,8 @@ class MaxPool3DOp : public MaxPoolOp { }; REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp); -// Common computation shared between AvgPool and AvgPoolGrad. Divide each -// element of an image by the count of elements that contributed to that -// element during pooling. +// Divide each element of an image by the count of elements that contributed to +// that element during pooling. static xla::XlaOp AvgPoolDivideByCount( XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype, const TensorShape& input_shape, xla::Padding padding, @@ -235,20 +259,34 @@ class AvgPoolOp : public PoolingOp { /*reduction_type=*/ XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} - xla::XlaOp InitValue(xla::XlaBuilder* b) override { - return XlaHelpers::Zero(b, reduction_type_); - } + void Compile(XlaOpKernelContext* ctx) override { + auto ksize_or_error = GetKernelSize(ctx); + OP_REQUIRES_OK(ctx, ksize_or_error.status()); + std::vector ksize = ksize_or_error.ValueOrDie(); - const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { - return ctx->GetOrCreateAdd(reduction_type_); - } + auto stride_or_error = GetStride(ctx); + OP_REQUIRES_OK(ctx, stride_or_error.status()); + std::vector stride = stride_or_error.ValueOrDie(); + + const TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, input_shape.dims() == num_dims(), + errors::InvalidArgument("Input to ", type_string(), + " operator must have ", num_dims(), + " dimensions")); - xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, - const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape) override { - return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_, - ksize_, stride_, num_spatial_dims_, - data_format_); + auto xla_data_format = + XlaTensorFormat(data_format_, input_shape.dims() - 2); + auto spatial_padding = MakeSpatialPadding( + input_shape.dim_sizes(), ksize, stride, padding_, xla_data_format); + + // Convert the input to the reduction type. + auto converted_input = + ConvertElementType(ctx->Input(0), xla_reduction_type_); + auto pooling = + xla::AvgPool(converted_input, ksize, stride, spatial_padding, + xla_data_format, padding_ == xla::Padding::kValid); + // Convert the pooling result back to the input type before returning it. + ctx->SetOutput(0, ConvertElementType(pooling, ctx->input_xla_type(0))); } }; @@ -628,7 +666,7 @@ class MaxPoolGradGradOp : public XlaOpKernel { auto in_hi_bp_hi = xla::Add(in_hi, bp_hi); // Want an unsigned add. auto in_hi_bp_lo = xla::Add(in_hi, bp_lo); // Want an unsigned add. - auto init_value = XlaHelpers::MinValue(b, DT_FLOAT); + auto init_value = xla::MinValue(b, xla::F32); // We will reduce by taking the maximal value up to 16 bits (ignoring the lo // 16 bits of packed-in hi/lo backprop value). auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits"); diff --git a/tensorflow/compiler/tf2xla/kernels/qr_op.cc b/tensorflow/compiler/tf2xla/kernels/qr_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..de9068a640dc03b141b6954eaa1629dd6c8c1f3a --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/qr_op.cc @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/qr.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +class QROp : public XlaOpKernel { + public: + explicit QROp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + bool full_matrices; + OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices)); + OP_REQUIRES( + ctx, full_matrices, + errors::Unimplemented("full_matrices=False case of QR decomposition is " + "not implemented in TF/XLA")); + } + void Compile(XlaOpKernelContext* ctx) override { + auto result = QRDecomposition(ctx->Input(0)); + if (!result.ok()) { + ctx->SetStatus(result.status()); + return; + } + ctx->SetOutput(0, result.ValueOrDie().q); + ctx->SetOutput(1, result.ValueOrDie().r); + } +}; + +REGISTER_XLA_OP(Name("Qr").TypeConstraint("T", kFloatTypes), QROp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index 02293796e47063b81a9ff46c8b911461e3a5f5e5..6f4ed496a1774dde68dd9d5fbd37995d615b678c 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -18,7 +18,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -50,8 +52,8 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { } else { const xla::XlaComputation* fmax = ctx->GetOrCreateMax(data_type); const xla::XlaComputation* fmin = ctx->GetOrCreateMin(data_type); - min_range = ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin); - max_range = ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax); + min_range = ReduceAll(input, xla::MaxValue(b, xla_type), *fmin); + max_range = ReduceAll(input, xla::MinValue(b, xla_type), *fmax); } xla::XlaOp num_bits; @@ -93,10 +95,10 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { // while keeping 0 unchanged. xla::XlaOp scale_from_min_side = Select(Gt(min_quantized * min_range, zero), min_quantized / min_range, - XlaHelpers::MaxFiniteValue(b, data_type)); + xla::MaxFiniteValue(b, xla_type)); xla::XlaOp scale_from_max_side = Select(Gt(max_quantized * max_range, zero), max_quantized / max_range, - XlaHelpers::MaxFiniteValue(b, data_type)); + xla::MaxFiniteValue(b, xla_type)); // Note: Avoids changing the side of the range that determines scale. xla::XlaOp cond = Lt(scale_from_min_side, scale_from_max_side); diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 51f2cdc9f44daf81e9b56ec71a7ff81ed6cbb2ad..2da9340625db08b14b78340c471f096baf15689d 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -26,7 +26,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -73,57 +74,121 @@ class RandomShuffleOp : public XlaOpKernel { for (tensorflow::TensorShapeDim dimension : input_shape) { num_elements *= dimension.size; } + if (num_elements <= 1 || n <= 1) { // No shuffling is required, so copy input directly to output ctx->SetOutput(0, input); - } else { - // Generate the random swaps for the indices. - auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n}); - auto swaps = - xla::RngUniform(xla::ConstantR0(builder, 0), - xla::ConstantR0(builder, n), swaps_shape); - - // Generate range(n) as the initial value for the indices to be swapped. - xla::XlaOp indices; - TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, n, &indices)); - - // Swap the indices at i and swaps[i]. - auto swap_body_fn = [&](xla::XlaOp i, - gtl::ArraySlice loop_vars, - xla::XlaBuilder* builder) - -> xla::StatusOr> { - auto swaps = loop_vars[0]; - auto indices = loop_vars[1]; - i = xla::Reshape(i, {1}); - // temp = indices[i] - auto temp = xla::DynamicSlice(indices, i, {1}); - // swap_index = swaps[i] - auto swap_index = xla::DynamicSlice(swaps, i, {1}); - // swap_value = indices[swaps[i]] - auto swap_value = xla::DynamicSlice(indices, swap_index, {1}); - // indices[i] = indices[swaps[i]] - indices = xla::DynamicUpdateSlice(indices, swap_value, i); - // indices[swaps[i]] = temp - indices = xla::DynamicUpdateSlice(indices, temp, swap_index); - return std::vector{swaps, indices}; - }; - // for i in range(n): - auto swap_loop_result = - XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, - "indices_swap_loop", builder) - .ValueOrDie(); - auto swapped_indices = swap_loop_result[1]; - - // Gather the data using the swapped indices as the shuffled order. - auto indices_tensor_shape = TensorShape({n}); - DataType type = ctx->expected_output_dtype(0); - xla::XlaOp gather; - OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices, - indices_tensor_shape, - /*axis=*/0, /*indices_are_nd=*/false, type, - DT_INT32, builder, &gather)); - ctx->SetOutput(0, gather); + return; + } + + if (input_shape.dims() == 1) { + // For R1s, shuffle values by sorting instead of the obvious Fisher-Yates + // algorithm. Fisher-Yates is simple to implement and correct, but not + // easily parallelizable. For a sufficiently parallel architecture, it is + // faster to sort many times, than Fisher-Yates shuffle once. + + // Shuffle values by assigning each value a random key and sorting the + // keys. Keys can collide causing detectable patterns in the shuffled + // output. Collisions translates into more ascending sub-sequences in the + // shuffled output than would be expected by chance. To avoid collisions, + // the number of possible key values must be sufficiently large. + + // How are more than 2^32 keys created? In each loop iteration, the + // algorithm sorts by random keys. Conceptually, the earlier iterations + // are sorting on the lower-order bits of larger keys that are never + // actually assembled. + + // The expected number of collisions is n - d + d(1 - 1/d)^n, where d is + // the number of possible keys and n is the number of values. If d = n^2, + // then the limit as n goes to infinity is 1/2. If d = n^3, then the limit + // as n goes to infinity is zero. + + // This implementation ensures that the key-space is greater than or equal + // to the cube of the number of values. The risk of collisions can be + // further reduced by increasing Exponent at the expense of + // performance. + + // For Exponent = 2, the expected number of collisions per shuffle is + // maximized at n = floor((2^32-1)^(1/2)) = 65535 where the expectation is + // about 1/2. + + // For Exponent = 3, the expected number of collisions per shuffle is + // maximized at n = floor((2^32-1)^(1/3)) = 1625 where the expectation is + // about 1/3255. + + // For Exponent = 4, the expected number of collisions per shuffle is + // maximized at n = floor((2^32-1)^(1/4)) = 255 where the expectation is + // about 1/132622. + constexpr int Exponent = 3; + const int rounds = static_cast( + std::ceil(Exponent * std::log(num_elements) / std::log(kuint32max))); + + const xla::Shape key_shape = + xla::ShapeUtil::MakeShape(xla::U32, {num_elements}); + xla::XlaOp zero = xla::ConstantR0(builder, 0U); + + // Unfortunately, xla::RngUniform gives values in the half open interval + // rather than the closed interval, so instead of 2^32 possible keys there + // are only 2^32 - 1 (kuint32max). + xla::XlaOp max_value = xla::ConstantR0(builder, kuint32max); + + xla::XlaOp curr = input; + for (int i = 0; i < rounds; ++i) { + xla::XlaOp keys = xla::RngUniform(zero, max_value, key_shape); + xla::XlaOp sorted = xla::Sort(keys, curr); + curr = xla::GetTupleElement(sorted, 1); + } + + ctx->SetOutput(0, curr); + return; } + + // The Fisher-Yates algorithm. + + // Generate the random swaps for the indices. + auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n}); + auto swaps = + xla::RngUniform(xla::ConstantR0(builder, 0), + xla::ConstantR0(builder, n), swaps_shape); + + // Generate range(n) as the initial value for the indices to be swapped. + xla::XlaOp indices = xla::Iota(builder, xla::S32, n); + + // Swap the indices at i and swaps[i]. + auto swap_body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + xla::XlaBuilder* builder) + -> xla::StatusOr> { + auto swaps = loop_vars[0]; + auto indices = loop_vars[1]; + i = xla::Reshape(i, {1}); + // temp = indices[i] + auto temp = xla::DynamicSlice(indices, i, {1}); + // swap_index = swaps[i] + auto swap_index = xla::DynamicSlice(swaps, i, {1}); + // swap_value = indices[swaps[i]] + auto swap_value = xla::DynamicSlice(indices, swap_index, {1}); + // indices[i] = indices[swaps[i]] + indices = xla::DynamicUpdateSlice(indices, swap_value, i); + // indices[swaps[i]] = temp + indices = xla::DynamicUpdateSlice(indices, temp, swap_index); + return std::vector{swaps, indices}; + }; + // for i in range(n): + auto swap_loop_result = + XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, + "indices_swap_loop", builder) + .ValueOrDie(); + auto swapped_indices = swap_loop_result[1]; + + // Gather the data using the swapped indices as the shuffled order. + auto indices_tensor_shape = TensorShape({n}); + DataType type = ctx->expected_output_dtype(0); + xla::XlaOp gather; + OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices, + indices_tensor_shape, + /*axis=*/0, /*indices_are_nd=*/false, type, + DT_INT32, builder, &gather)); + ctx->SetOutput(0, gather); } private: @@ -211,7 +276,7 @@ class TruncatedNormalOp : public XlaOpKernel { xla::XlaOp min_positive = XlaHelpers::FloatLiteral(b, dtype, std::numeric_limits::min()); auto uniform = xla::RngUniform(min_positive, one, xla_shape); - ctx->SetOutput(0, TruncatedNormal(dtype, uniform)); + ctx->SetOutput(0, TruncatedNormal(uniform)); } }; @@ -220,5 +285,5 @@ REGISTER_XLA_OP(Name("TruncatedNormal") .TypeConstraint("dtype", DT_FLOAT), TruncatedNormalOp); -} // anonymous namespace +} // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc index 76bd1e62aa1efd85d6ed489b9a6d22a2bacf2a8b..b11a4ce36da9907ce8fe377c075023a4540797fa 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -19,7 +19,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index d3573bac3d7641128fbfc2122336a7c4347836c0..0d260fa8fcaa513d7854c1e9215952404d555c70 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -19,8 +19,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -32,7 +33,7 @@ class SumOp : public XlaReductionOp { : XlaReductionOp(ctx, XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::Zero(builder, reduction_type_); + return xla::Zero(builder, xla_reduction_type_); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { @@ -49,7 +50,7 @@ class ProdOp : public XlaReductionOp { XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::One(builder, reduction_type_); + return xla::One(builder, xla_reduction_type_); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, @@ -67,7 +68,7 @@ class MinOp : public XlaReductionOp { : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::MaxValue(builder, reduction_type_); + return xla::MaxValue(builder, xla_reduction_type_); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, @@ -84,7 +85,7 @@ class MaxOp : public XlaReductionOp { : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::MinValue(builder, reduction_type_); + return xla::MinValue(builder, xla_reduction_type_); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, @@ -102,7 +103,7 @@ class MeanOp : public XlaReductionOp { XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::Zero(builder, reduction_type_); + return xla::Zero(builder, xla_reduction_type_); } void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, const xla::XlaOp& scalar_rhs) override { @@ -114,7 +115,7 @@ class MeanOp : public XlaReductionOp { int64 num_elements_reduced) override { auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0), num_elements_reduced); - return xla::Div(reduce_output, divisor); + return reduce_output / divisor; } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index 2ecfb854a1c8625524d4f1199af3927edd204926..466e79828d111ee7cadcf713703e8f252c63e62c 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -19,7 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_ #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { @@ -64,6 +64,7 @@ class XlaReductionOp : public XlaOpKernel { protected: DataType reduction_type_; + xla::PrimitiveType xla_reduction_type_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 14506d65c4db4cea5bd9fc037536a894aea4330e..b52f0a0ab6290f2019bb58120be5c2364ec15bb6 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -19,8 +19,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -32,6 +33,8 @@ XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx, OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt})); OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_)); + OP_REQUIRES_OK( + ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_)); } // Unless BuildFinalizer is overridden the reduction has no diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index a4ba6c748a73f161ea252e2adf4050eb5dda7df5..d35777ccb1271ec6a7c9972c714d06b2415d9c34 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index e0ca8dd8e27914ad60d0b97e8ac5f0b91a4fd9a6..121750a82a8c5cbe940068555ad273b7e0d22dfc 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index db7ea775e23e86bdbd9259e73dfa2412ef10ac6c..64900e4709fd3e16d21096b0cfff8922906cb0d4 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -16,7 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -62,10 +63,20 @@ class RetvalOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal)); } else { TensorShape shape = ctx->InputShape(0); - TensorShape representation_shape = - tc.is_entry_computation() - ? tc.RepresentationShape(shape, ctx->input_type(0)) - : shape; + ctx->SetStatus(is_constant.status()); + TensorShape representation_shape; + if (tc.is_entry_computation()) { + xla::StatusOr shape_or_status = + tc.RepresentationShape(shape, ctx->input_type(0)); + if (!shape_or_status.ok()) { + ctx->SetStatus(shape_or_status.status()); + return; + } else { + representation_shape = shape_or_status.ValueOrDie(); + } + } else { + representation_shape = shape; + } xla::XlaOp output = input; if (tc.is_entry_computation()) { @@ -93,7 +104,7 @@ class RetvalOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp); }; -REGISTER_XLA_OP(Name("_Retval"), RetvalOp); +REGISTER_XLA_OP(Name("_Retval").CompilationOnly(), RetvalOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index 037c422258555289711b8754f2277d077d0cd6a7..d962ef4a5f53470838643541f8a1e693d2f4011c 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 16491002b46c4ee6d653a0ab55b8485ae7a62ab7..03a50ef8a059e5a005c4cc2e5e98acedfea8619a 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -17,7 +17,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -165,9 +166,8 @@ class ReverseSequenceOp : public XlaOpKernel { auto output = xla::GetTupleElement(loop_output, 2); // Mask out elements after the sequence length. - xla::XlaOp iota; - OP_REQUIRES_OK( - context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota)); + xla::XlaOp iota = + xla::Iota(builder, seq_lens_xla_shape.element_type(), max_seq_len); std::vector dims(input_shape.dims(), 1); dims[batch_dim_] = batch_size; auto mask = xla::Lt(iota, xla::Reshape(seq_lens, dims), {seq_dim_}); diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 76924c6a01a44e7a723b8c8895e8decbdd466c79..ab094d7dd1ce9856a3c2854fd2776827d6c4b76f 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -20,8 +20,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index 14709bb6cbce4b3ae0f7ff859b0fa622c6eda293..f1f32699fee5f03f603f830722fe65622dee5d3e 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index db7e55942012142297f6a4d6afa1065eb0bb24f6..b22ecb7c6dbb42a33a4f4d90b18b20816df16a50 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -14,10 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/lib/scatter.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { @@ -25,15 +27,16 @@ namespace { class UnsortedSegmentReduce : public XlaOpKernel { public: explicit UnsortedSegmentReduce(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + DataType dtype; + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype)); + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &type_)); } // The initial value to initialize elements of the output to. virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0; // A function to combine two scalars with the same index (e.g., sum). - virtual xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, - xla::XlaBuilder* builder) = 0; + virtual xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) = 0; void Compile(XlaOpKernelContext* ctx) override { // output = unsorted_segment_sum(data, indices, num_segments) @@ -78,9 +81,7 @@ class UnsortedSegmentReduce : public XlaOpKernel { xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes()); auto combiner = [this](xla::XlaOp a, xla::XlaOp b, - xla::XlaBuilder* builder) { - return Combine(a, b, builder); - }; + xla::XlaBuilder* builder) { return Combine(a, b); }; auto result = XlaScatter(buffer, /*updates=*/data, indices, /*indices_are_vectors=*/false, combiner, builder); @@ -89,7 +90,7 @@ class UnsortedSegmentReduce : public XlaOpKernel { } protected: - DataType dtype_; + xla::PrimitiveType type_; }; class UnsortedSegmentSum : public UnsortedSegmentReduce { @@ -98,12 +99,9 @@ class UnsortedSegmentSum : public UnsortedSegmentReduce { : UnsortedSegmentReduce(ctx) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::Zero(builder, dtype_); - }; - xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, - xla::XlaBuilder* builder) override { - return xla::Add(a, b); + return xla::Zero(builder, type_); }; + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a + b; }; }; REGISTER_XLA_OP( @@ -116,12 +114,9 @@ class UnsortedSegmentProd : public UnsortedSegmentReduce { : UnsortedSegmentReduce(ctx) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::One(builder, dtype_); - }; - xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, - xla::XlaBuilder* builder) override { - return xla::Mul(a, b); + return xla::One(builder, type_); }; + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a * b; }; }; REGISTER_XLA_OP( @@ -134,10 +129,9 @@ class UnsortedSegmentMin : public UnsortedSegmentReduce { : UnsortedSegmentReduce(ctx) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::MaxFiniteValue(builder, dtype_); + return xla::MaxFiniteValue(builder, type_); }; - xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, - xla::XlaBuilder* builder) override { + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return xla::Min(a, b); }; }; @@ -152,10 +146,9 @@ class UnsortedSegmentMax : public UnsortedSegmentReduce { : UnsortedSegmentReduce(ctx) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { - return XlaHelpers::MinFiniteValue(builder, dtype_); + return xla::MinFiniteValue(builder, type_); }; - xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b, - xla::XlaBuilder* builder) override { + xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return xla::Max(a, b); }; }; diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index 5c010c9df23ba6c7732d87fa014879d93ff586ce..6ce50efb4aa6e3434a7c6009cf9f52f6cff9cc9f 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/kernels/bounds_check.h" diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index 6281d6c6533f7f49a269f5c7e52226ba0f1d29f6..a7f5a8f1698b9d02560de427d356e9e6be5caa7c 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index bc3d0bf5dfe9e5af8e50a25e27db7148e05e0cfd..25a5bcbe1dd27d741ce3b74125ba9ce425ee78f3 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 5798823cd54c66dd179e3611c0041f7c5a1ff2b5..4e0cf99d8e7ff45ed9145981b5e2e637ce4d4e4b 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/kernels/bounds_check.h" diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 1864584adee357ce35a3e8a38a4e3c58c356bfca..6adc3c58de63ee70c26bed47eebef955893df4a5 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index d1c69f08b0bc85fc47c03015054dd18a65eeedec..025ba827410f1a9f993a8a1855558a2daa86609b 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -15,10 +15,13 @@ limitations under the License. // XLA-specific Ops for softmax. +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -35,38 +38,46 @@ class SoftmaxOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape logits_shape = ctx->InputShape(0); - OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape), - errors::InvalidArgument("logits must be 2-dimensional")); + OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(logits_shape), + errors::InvalidArgument("logits must have >= 1 dimension, got ", + logits_shape.DebugString())); - const int kBatchDim = 0; - const int kClassDim = 1; + // Major dimensions are batch dimensions, minor dimension is the class + // dimension. + std::vector batch_dims(logits_shape.dims() - 1); + std::iota(batch_dims.begin(), batch_dims.end(), 0); + const int kClassDim = logits_shape.dims() - 1; const DataType type = input_type(0); + const xla::PrimitiveType xla_type = ctx->input_xla_type(0); auto logits = ctx->Input(0); xla::XlaBuilder* const b = ctx->builder(); const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type); // Find the max in each batch, resulting in a tensor of shape [batch] - auto logits_max = xla::Reduce(logits, XlaHelpers::MinValue(b, type), - max_func, {kClassDim}); + auto logits_max = + xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim}); // Subtract the max in batch b from every element in batch b. Broadcasts // along the batch dimension. - auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim}); + auto shifted_logits = xla::Sub(logits, logits_max, batch_dims); auto exp_shifted = xla::Exp(shifted_logits); const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); + xla::PrimitiveType xla_accumulation_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(accumulation_type, + &xla_accumulation_type)); auto converted = - XlaHelpers::ConvertElementType(b, exp_shifted, accumulation_type); + xla::ConvertElementType(exp_shifted, xla_accumulation_type); auto reduce = - xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + xla::Reduce(converted, xla::Zero(b, xla_accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); auto sum = XlaHelpers::ConvertElementType(b, reduce, type); auto softmax = log_ // softmax = shifted_logits - log(sum(exp(shifted_logits))) - ? xla::Sub(shifted_logits, xla::Log(sum), {kBatchDim}) + ? xla::Sub(shifted_logits, xla::Log(sum), batch_dims) // softmax = exp(shifted_logits) / sum(exp(shifted_logits)) - : xla::Div(exp_shifted, sum, {kBatchDim}); + : xla::Div(exp_shifted, sum, batch_dims); ctx->SetOutput(0, softmax); } @@ -78,8 +89,8 @@ REGISTER_XLA_OP(Name("Softmax"), SoftmaxOp); REGISTER_XLA_OP(Name("LogSoftmax"), SoftmaxOp); std::pair CrossEntropyWithLogits( - XlaOpKernelContext* ctx, DataType type, const xla::XlaOp& logits, - const xla::XlaOp& labels) { + XlaOpKernelContext* ctx, DataType type, xla::PrimitiveType xla_type, + xla::XlaOp logits, xla::XlaOp labels) { const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type); const int kBatchDim = 0; @@ -88,7 +99,7 @@ std::pair CrossEntropyWithLogits( xla::XlaBuilder* b = ctx->builder(); // Find the max in each batch, resulting in a tensor of shape [batch] auto logits_max = - xla::Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim}); + xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim}); // Subtract the max in batch b from every element in batch b. // Broadcasts along the batch dimension. @@ -148,12 +159,13 @@ class SoftmaxXentWithLogitsOp : public XlaOpKernel { // check that "labels" is a matrix too. const DataType type = input_type(0); + const xla::PrimitiveType xla_type = ctx->input_xla_type(0); auto logits = ctx->Input(0); auto labels = ctx->Input(1); xla::XlaOp loss, backprop; std::tie(loss, backprop) = - CrossEntropyWithLogits(ctx, type, logits, labels); + CrossEntropyWithLogits(ctx, type, xla_type, logits, labels); ctx->SetOutput(0, loss); ctx->SetOutput(1, backprop); } @@ -189,8 +201,9 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { int64 batch_size = logits_shape.dim_size(0); int64 depth = logits_shape.dim_size(1); - DataType logits_type = input_type(0); - DataType indices_type = input_type(1); + const DataType logits_type = input_type(0); + const xla::PrimitiveType xla_logits_type = ctx->input_xla_type(0); + const DataType indices_type = input_type(1); xla::XlaOp indices = ctx->Input(1); @@ -218,8 +231,8 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { labels = xla::Add(labels, nan_or_zero, {0}); xla::XlaOp loss, backprop; - std::tie(loss, backprop) = - CrossEntropyWithLogits(ctx, logits_type, ctx->Input(0), labels); + std::tie(loss, backprop) = CrossEntropyWithLogits( + ctx, logits_type, xla_logits_type, ctx->Input(0), labels); ctx->SetOutput(0, loss); ctx->SetOutput(1, backprop); } diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc index faaf8964ff7c40d75a493b03e6b400632117cb45..aaeeae01ccb303091a6d37d1aeb4b2a3377dc638 100644 --- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index 8a8525efa186ed4aa02c494f7505f6245677e96e..7327258c31f21f45ff7ffffbc9db7a2a70b4a14c 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 47d282fe9ec664bbc424793e93f778ebb13c6877..4493539fe34f0ce635fdc58660d4ff90af9c9379 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index ca74cf24507e1666070751a17fb940a3ad594695..93fc14e9efca868e84444dd0e07d7f0dfa84c042 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 591e61b4c82836bc1995cd11c4c0314c9d854e50..df91900570107609c0f1c2281faaab8a5e65b98b 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 3b19f8d872c22b41ff12dd66133202126e2e39d2..5412e135478361d08965e4621ec52cfb4a792f1d 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -20,8 +20,11 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/prng.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -31,134 +34,6 @@ limitations under the License. namespace tensorflow { namespace { -// Rotates a 32-bit integer 'v' left by 'distance' bits. -xla::XlaOp RotateLeftS32(xla::XlaBuilder* builder, const xla::XlaOp& v, - int distance) { - return xla::Or( - xla::ShiftLeft(v, xla::ConstantR0(builder, distance)), - xla::ShiftRightLogical(v, xla::ConstantR0(builder, 32 - distance))); -} - -using ThreeFry2x32State = std::array; - -// Implements the ThreeFry counter-based PRNG algorithm. -// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. -// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf -ThreeFry2x32State ThreeFry2x32(xla::XlaBuilder* builder, - ThreeFry2x32State input, ThreeFry2x32State key) { - // Rotation distances specified by the Threefry2x32 algorithm. - constexpr std::array rotations = {13, 15, 26, 6, 17, 29, 16, 24}; - ThreeFry2x32State x; - - std::array ks; - // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm. - ks[2] = xla::ConstantR0(builder, 0x1BD11BDA); - for (int i = 0; i < 2; ++i) { - ks[i] = key[i]; - x[i] = input[i]; - ks[2] = xla::Xor(ks[2], key[i]); - } - - x[0] = xla::Add(x[0], ks[0]); - x[1] = xla::Add(x[1], ks[1]); - - // Performs a single round of the Threefry2x32 algorithm, with a rotation - // amount 'rotation'. - auto round = [builder](ThreeFry2x32State v, int rotation) { - v[0] = xla::Add(v[0], v[1]); - v[1] = RotateLeftS32(builder, v[1], rotation); - v[1] = xla::Xor(v[0], v[1]); - return v; - }; - - // There are no known statistical flaws with 13 rounds of Threefry2x32. - // We are conservative and use 20 rounds. - x = round(x, rotations[0]); - x = round(x, rotations[1]); - x = round(x, rotations[2]); - x = round(x, rotations[3]); - x[0] = xla::Add(x[0], ks[1]); - x[1] = xla::Add(xla::Add(x[1], ks[2]), xla::ConstantR0(builder, 1)); - - x = round(x, rotations[4]); - x = round(x, rotations[5]); - x = round(x, rotations[6]); - x = round(x, rotations[7]); - x[0] = xla::Add(x[0], ks[2]); - x[1] = xla::Add(xla::Add(x[1], ks[0]), xla::ConstantR0(builder, 2)); - - x = round(x, rotations[0]); - x = round(x, rotations[1]); - x = round(x, rotations[2]); - x = round(x, rotations[3]); - x[0] = xla::Add(x[0], ks[0]); - x[1] = xla::Add(xla::Add(x[1], ks[1]), xla::ConstantR0(builder, 3)); - - x = round(x, rotations[4]); - x = round(x, rotations[5]); - x = round(x, rotations[6]); - x = round(x, rotations[7]); - x[0] = xla::Add(x[0], ks[1]); - x[1] = xla::Add(xla::Add(x[1], ks[2]), xla::ConstantR0(builder, 4)); - - x = round(x, rotations[0]); - x = round(x, rotations[1]); - x = round(x, rotations[2]); - x = round(x, rotations[3]); - x[0] = xla::Add(x[0], ks[2]); - x[1] = xla::Add(xla::Add(x[1], ks[0]), xla::ConstantR0(builder, 5)); - - return x; -} - -// Returns a tensor of 'shape' random values uniformly distributed in the range -// [minval, maxval) -xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed, - const TensorShape& shape, double minval, - double maxval) { - // Split the seed into two 32-bit scalars to form a key. - auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); - auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); - ThreeFry2x32State key = {seed0, seed1}; - const int64 size = shape.num_elements(); - - const int64 half_size = MathUtil::CeilOfRatio(size, 2); - const bool size_is_odd = (half_size * 2 != size); - - // Fill the generator inputs with unique counter values. - ThreeFry2x32State inputs; - TF_CHECK_OK(XlaHelpers::Iota(builder, DT_INT32, half_size, &inputs[0])); - inputs[1] = xla::Add(inputs[0], xla::ConstantR0(builder, half_size)); - ThreeFry2x32State outputs = ThreeFry2x32(builder, inputs, key); - - if (size_is_odd) { - outputs[1] = xla::Slice(outputs[1], {0}, {half_size - 1}, {1}); - } - - auto bits = - xla::Reshape(xla::ConcatInDim(builder, outputs, 0), shape.dim_sizes()); - - // Form 22 random mantissa bits, with a leading 1 bit. The leading 1 bit - // forces the random bits into the mantissa. - constexpr int kFloatBits = 32; - constexpr int kMantissaBits = 23; - bits = xla::Or( - xla::ShiftRightLogical( - bits, xla::ConstantR0(builder, kFloatBits - kMantissaBits)), - xla::ConstantR0(builder, bit_cast(1.0f))); - auto floats = xla::BitcastConvertType(bits, xla::F32); - - // We have a floating point number in the range [1.0, 2.0). - // Subtract 1.0f to shift to the range [0.0, 1.0) - floats = xla::Sub(floats, xla::ConstantR0(builder, 1.0f)); - // Multiply and add to shift to the range [minval, maxval). - floats = xla::Mul(floats, xla::ConstantR0(builder, maxval - minval)); - floats = xla::Add(floats, xla::ConstantR0(builder, minval)); - return floats; -} - -} // namespace - class StatelessRandomUniformOp : public XlaOpKernel { public: explicit StatelessRandomUniformOp(OpKernelConstruction* ctx) @@ -175,7 +50,17 @@ class StatelessRandomUniformOp : public XlaOpKernel { errors::InvalidArgument("seed must have shape [2], not ", seed_shape.DebugString())); xla::XlaOp seed = ctx->Input(1); - ctx->SetOutput(0, RandomUniform(builder, seed, shape, 0.0, 1.0)); + + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); + + auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); + + auto uniform = xla::StatelessRngUniform( + {seed0, seed1}, xla_shape, xla::ConstantR0(builder, 0.0), + xla::ConstantR0(builder, 1.0)); + ctx->SetOutput(0, uniform); } private: @@ -204,12 +89,20 @@ class StatelessRandomNormalOp : public XlaOpKernel { seed_shape.DebugString())); xla::XlaOp seed = ctx->Input(1); xla::XlaBuilder* builder = ctx->builder(); - auto uniform = - RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0); + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); + + auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); + + auto uniform = xla::StatelessRngUniform( + {seed0, seed1}, xla_shape, + xla::ConstantR0(builder, std::nextafter(-1.0f, 0.0f)), + xla::ConstantR0(builder, 1.0)); // Convert uniform distribution to normal distribution by computing // sqrt(2) * erfinv(x) - auto normal = xla::Mul(xla::ConstantR0(builder, std::sqrt(2.0)), - ErfInv(uniform)); + auto normal = + xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform); ctx->SetOutput(0, normal); } @@ -230,8 +123,6 @@ class StatelessTruncatedNormalOp : public XlaOpKernel { : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - const DataType dtype = output_type(0); - TensorShape shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); @@ -240,11 +131,19 @@ class StatelessTruncatedNormalOp : public XlaOpKernel { errors::InvalidArgument("seed must have shape [2], not ", seed_shape.DebugString())); xla::XlaOp seed = ctx->Input(1); - xla::XlaBuilder* b = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); + + auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); - auto uniform = - RandomUniform(b, seed, shape, std::numeric_limits::min(), 1.0); - ctx->SetOutput(0, TruncatedNormal(dtype, uniform)); + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); + auto uniform = xla::StatelessRngUniform( + {seed0, seed1}, xla_shape, + xla::ConstantR0(builder, std::numeric_limits::min()), + xla::ConstantR0(builder, 1.0)); + + ctx->SetOutput(0, TruncatedNormal(uniform)); } private: @@ -257,4 +156,5 @@ REGISTER_XLA_OP(Name("StatelessTruncatedNormal") .TypeConstraint("Tseed", DT_INT32), StatelessTruncatedNormalOp); +} // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index c2165ccd86dfa1c119790beb20af0844fb1bbda8..1062399d91bd9a9bf8c3820c5ecac534c110746d 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 2f650ce3052ee4502912891cd3f60cfaec8b1d7c..be1814d8e3ae2c0ddad0134b9288e0ea084aa81b 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -25,8 +25,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index c9e56942625a009fb3660f413a845547192460d5..1233a37565d3a40c6dd2882b3139dedbf690a7b6 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index beb7cf263df83ca7b3fb2bd0b0603a34dd08dcbb..183879c7602ccbbd74fca6cb9fa3fc94c066c37d 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" @@ -40,119 +41,27 @@ class TopKOp : public XlaOpKernel { OP_REQUIRES(context, input_shape.dims() >= 1, errors::InvalidArgument("input must be >= 1-D, got shape ", input_shape.DebugString())); + int last_dim = input_shape.dims() - 1; + int last_dim_size = input_shape.dim_size(last_dim); OP_REQUIRES( - context, input_shape.dim_size(input_shape.dims() - 1) >= k, + context, last_dim_size >= k, errors::InvalidArgument("input must have at least k columns. Had ", - input_shape.dim_size(input_shape.dims() - 1), - ", needed ", k)); - - OP_REQUIRES( - context, input_shape.dims() == 1, - errors::Unimplemented("TopK is implemented for 1-D inputs, got shape ", - input_shape.DebugString())); - - const int64 n = input_shape.dim_size(0); - OP_REQUIRES(context, n < (1 << 16), - errors::Unimplemented( - "TopK is implemented for sizes up to 2**16, got shape ", - input_shape.DebugString())); - - xla::XlaBuilder* const b = context->builder(); - if (input_shape.dim_size(0) < k) { - k = input_shape.dim_size(0); + last_dim_size, ", needed ", k)); + if (last_dim_size < k) { + k = last_dim_size; } - const xla::XlaOp input_bf16 = context->Input(0); - xla::XlaOp iota_s32; - OP_REQUIRES_OK(context, XlaHelpers::Iota(b, DT_INT32, n, &iota_s32)); - - // TODO(b/73891930): add a key-value sort to HLO, rather than using - // bit-packing tricks here. - - xla::XlaOp zero = xla::ConstantR0(b, 0); - - // max can either be 0x7FFFFFFF or 0x8000000. Neither choice is totally - // ideal. The implications of the choice are: - // - // 0x7FFFFFFF - // 1. +0.0 > -0.0 - // 2. The elements of the inputs and outputs are bitwise identical. - // 3. The sort is unstable since a later +0.0 will appear before an earlier - // -0.0. - // - // 0x8000000 - // 1. +0.0 == -0.0 - // 2. All -0.0 in the input are replaced with +0.0 in the output. - // 3. The sort is stable. - xla::XlaOp max = xla::ConstantR0(b, 0x80000000); - xla::XlaOp index_mask = xla::ConstantR0(b, 0x0000FFFF); - xla::XlaOp value_mask = xla::ConstantR0(b, 0xFFFF0000); - - // Convert to from bf16 to f32. The lower 16-bits are zero due to the - // definition of bf16. - xla::XlaOp input_f32 = xla::ConvertElementType(input_bf16, xla::F32); - - // Negate the input to reverse sort it. The lower 16-bits are zero, because - // negating a float is just inverting the high-bit. - xla::XlaOp negative_input_f32 = xla::Neg(input_f32); - - // Convert to a sign magnitude integer. The lower 16-bits are zero, since - // bitcast convert doesn't change any bits. - xla::XlaOp negative_input_sm32 = - xla::BitcastConvertType(negative_input_f32, xla::S32); - - // Convert from sign magnitude integer to two's complement integer. The - // lower 16-bits are zero on both sides of the select. On the false side, - // the value is unchanged, and on the true side, the lower 16-bits of max - // are all zero, so the lower 16-bits of the result of the subtraction will - // also be zero. - xla::XlaOp negative_input_s32 = - xla::Select(xla::Lt(negative_input_sm32, zero), - xla::Sub(max, negative_input_sm32), negative_input_sm32); - - // In order for the Or with iota_s32 to to work properly, the lower 16-bits - // of negative_input_32 must be zero. - - // Pack elements as: - // * upper 16 bits are the value - // * lower 16 bits are the index. - xla::XlaOp packed_s32 = xla::Or(negative_input_s32, iota_s32); - - // TODO(phawkins): use a more efficient algorithm that does not require a - // full sort. - xla::XlaOp sorted_s32 = xla::Slice(xla::Sort(packed_s32), - /*start_indices=*/{0}, - /*limit_indices=*/{k}, - /*strides=*/{1}); - - // Unpack the value/index. - xla::XlaOp indices_s32 = xla::And(sorted_s32, index_mask); - xla::XlaOp negative_values_s32 = xla::And(sorted_s32, value_mask); - - // Convert from two's complement integer to sign magnitude integer. - xla::XlaOp negative_values_sm32 = - xla::Select(xla::Lt(negative_values_s32, zero), - xla::Sub(max, negative_values_s32), negative_values_s32); - - xla::XlaOp negative_values_f32 = - xla::BitcastConvertType(negative_values_sm32, xla::F32); - - // Negate the values to get back the original inputs. - xla::XlaOp values_f32 = xla::Neg(negative_values_f32); - - // Convert from f32 to bf16. - xla::XlaOp values_bf16 = xla::ConvertElementType(values_f32, xla::BF16); - - context->SetOutput(0, values_bf16); - context->SetOutput(1, indices_s32); + xla::XlaOp output_tuple = TopK(context->Input(0), k); + context->SetOutput(0, xla::GetTupleElement(output_tuple, 0)); + context->SetOutput(1, xla::GetTupleElement(output_tuple, 1)); } private: bool sorted_; }; -REGISTER_XLA_OP( - Name("TopKV2").CompileTimeConstInput("k").TypeConstraint("T", DT_BFLOAT16), - TopKOp); +REGISTER_XLA_OP(Name("TopKV2").CompileTimeConstInput("k").TypeConstraint( + "T", {DT_UINT32, DT_INT32, DT_FLOAT, DT_BFLOAT16}), + TopKOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 2e5d61e111c068a0e26dba62f29e7e268291dd1d..be5e91138656716daddcc3c7a68dbb78ecb69103 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -16,8 +16,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" @@ -47,7 +49,7 @@ class ResourceApplyGradientDescent : public XlaOpKernel { var_shape.DebugString(), " vs ", delta_shape.DebugString())); - handle = xla::Sub(handle, xla::Mul(ctx->Input(1), ctx->Input(2))); + handle = handle - ctx->Input(1) * ctx->Input(2); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; @@ -55,6 +57,64 @@ REGISTER_XLA_OP( Name("ResourceApplyGradientDescent").TypeConstraint("T", kFloatTypes), ResourceApplyGradientDescent); +xla::XlaOp ProximalGradientDescentUpdate(xla::XlaOp var, xla::XlaOp lr, + xla::XlaOp l1, xla::XlaOp l2, + xla::XlaOp grad) { + xla::XlaOp one = xla::ScalarLike(lr, 1.0); + xla::XlaOp zero = xla::ScalarLike(lr, 0.0); + xla::XlaOp prox_var = var - grad * lr; + xla::XlaOp l1_gt_zero = xla::Sign(prox_var) * + xla::Max(xla::Abs(prox_var) - lr * l1, zero) / + (one + lr * l2); + xla::XlaOp l1_le_zero = prox_var / (one + lr * l2); + return xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero); +} + +class ResourceApplyProximalGradientDescent : public XlaOpKernel { + public: + explicit ResourceApplyProximalGradientDescent(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp var; + TensorShape var_shape; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); + + TensorShape alpha_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape), + errors::InvalidArgument("alpha is not a scalar: ", + alpha_shape.DebugString())); + TensorShape l1_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape), + errors::InvalidArgument("l1 is not a scalar: ", + l1_shape.DebugString())); + TensorShape l2_shape = ctx->InputShape(3); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape), + errors::InvalidArgument("l2 is not a scalar: ", + l2_shape.DebugString())); + TensorShape delta_shape = ctx->InputShape(4); + OP_REQUIRES( + ctx, var_shape.IsSameSize(delta_shape), + errors::InvalidArgument("var and delta do not have the same shape: ", + var_shape.DebugString(), " vs ", + delta_shape.DebugString())); + xla::XlaOp alpha = ctx->Input(1); + xla::XlaOp l1 = ctx->Input(2); + xla::XlaOp l2 = ctx->Input(3); + xla::XlaOp delta = ctx->Input(4); + var = ProximalGradientDescentUpdate(var, alpha, l1, l2, delta); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); + } + + private: + DataType dtype_; +}; +REGISTER_XLA_OP(Name("ResourceApplyProximalGradientDescent") + .TypeConstraint("T", kFloatTypes), + ResourceApplyProximalGradientDescent); + class ResourceApplyMomentum : public XlaOpKernel { public: explicit ResourceApplyMomentum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -94,14 +154,13 @@ class ResourceApplyMomentum : public XlaOpKernel { xla::XlaOp grad = ctx->Input(3); xla::XlaOp momentum = ctx->Input(4); - accum = xla::Add(xla::Mul(accum, momentum), grad); + accum = accum * momentum + grad; if (use_nesterov_) { // See https://github.com/tensorflow/tensorflow/pull/2798 for an // explanation of the reparameterization used here. - var = xla::Sub(var, xla::Add(xla::Mul(grad, lr), - xla::Mul(xla::Mul(accum, momentum), lr))); + var = var - (grad * lr + accum * momentum * lr); } else { - var = xla::Sub(var, xla::Mul(accum, lr)); + var = var - accum * lr; } OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum)); @@ -118,8 +177,6 @@ class ResourceApplyAdagrad : public XlaOpKernel { explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); - DataType type = ctx->input_type(2); TensorShape var_shape, accum_shape; @@ -146,12 +203,8 @@ class ResourceApplyAdagrad : public XlaOpKernel { xla::XlaOp lr = ctx->Input(2); xla::XlaOp grad = ctx->Input(3); - accum = - xla::Add(accum, xla::Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0))); - var = xla::Sub( - var, - xla::Mul(xla::Mul(grad, lr), - xla::Pow(accum, XlaHelpers::FloatLiteral(b, type, -0.5)))); + accum = accum + xla::Square(grad); + var = var - grad * lr * xla::Rsqrt(accum); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum)); } @@ -159,6 +212,139 @@ class ResourceApplyAdagrad : public XlaOpKernel { REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes), ResourceApplyAdagrad); +class ResourceApplyProximalAdagrad : public XlaOpKernel { + public: + explicit ResourceApplyProximalAdagrad(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape var_shape, accum_shape; + xla::XlaOp var, accum; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum)); + + OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), + errors::InvalidArgument( + "var and accum do not have the same shape", + var_shape.DebugString(), " ", accum_shape.DebugString())); + + TensorShape lr_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", + lr_shape.DebugString())); + TensorShape l1_shape = ctx->InputShape(3); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape), + errors::InvalidArgument("l1 is not a scalar: ", + l1_shape.DebugString())); + TensorShape l2_shape = ctx->InputShape(4); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape), + errors::InvalidArgument("l2 is not a scalar: ", + l2_shape.DebugString())); + TensorShape grad_shape = ctx->InputShape(5); + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape: ", + var_shape.DebugString(), " vs ", grad_shape.DebugString())); + + xla::XlaOp lr = ctx->Input(2); + xla::XlaOp l1 = ctx->Input(3); + xla::XlaOp l2 = ctx->Input(4); + xla::XlaOp grad = ctx->Input(5); + accum = accum + xla::Square(grad); + // Adagrad learning rate. + xla::XlaOp adagrad_lr = lr * xla::Rsqrt(accum); + var = ProximalGradientDescentUpdate(var, adagrad_lr, l1, l2, grad); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum)); + } + + private: + DataType dtype_; +}; +REGISTER_XLA_OP( + Name("ResourceApplyProximalAdagrad").TypeConstraint("T", kFloatTypes), + ResourceApplyProximalAdagrad); + +class ResourceApplyAdagradDA : public XlaOpKernel { + public: + explicit ResourceApplyAdagradDA(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape var_shape, accum_shape, squared_accum_shape; + xla::XlaOp var, accum, squared_accum; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &squared_accum_shape, + &squared_accum)); + OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), + errors::InvalidArgument( + "var and accum do not have the same shape", + var_shape.DebugString(), " ", accum_shape.DebugString())); + OP_REQUIRES( + ctx, var_shape.IsSameSize(squared_accum_shape), + errors::InvalidArgument( + "var and squared accum do not have the same shape", + var_shape.DebugString(), " ", squared_accum_shape.DebugString())); + + TensorShape grad_shape = ctx->InputShape(3); + TensorShape lr_shape = ctx->InputShape(4); + TensorShape l1_shape = ctx->InputShape(5); + TensorShape l2_shape = ctx->InputShape(6); + TensorShape global_step_shape = ctx->InputShape(7); + + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", + lr_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape), + errors::InvalidArgument("l1 is not a scalar: ", + l1_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape), + errors::InvalidArgument("l2 is not a scalar: ", + l2_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(global_step_shape), + errors::InvalidArgument("global step is not a scalar: ", + global_step_shape.DebugString())); + + xla::XlaOp grad = ctx->Input(3); + xla::XlaOp lr = ctx->Input(4); + xla::XlaOp l1 = ctx->Input(5); + xla::XlaOp l2 = ctx->Input(6); + xla::XlaBuilder* const b = ctx->builder(); + xla::XlaOp global_step = + XlaHelpers::ConvertElementType(b, ctx->Input(7), dtype_); + + accum = accum + grad; + squared_accum = squared_accum + xla::Square(grad); + xla::XlaOp zero = xla::ScalarLike(lr, 0.0); + xla::XlaOp denominator = global_step * lr * l2 + xla::Sqrt(squared_accum); + xla::XlaOp l1_le_zero = -lr * accum / denominator; + xla::XlaOp l1_gt_zero = -lr * xla::Sign(accum) * + xla::Max(xla::Abs(accum) - global_step * l1, zero) / + denominator; + + var = xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, squared_accum)); + } + + private: + DataType dtype_; +}; +REGISTER_XLA_OP(Name("ResourceApplyAdagradDA").TypeConstraint("T", kFloatTypes), + ResourceApplyAdagradDA); + class ResourceApplyAdam : public XlaOpKernel { public: explicit ResourceApplyAdam(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -226,18 +412,12 @@ class ResourceApplyAdam : public XlaOpKernel { // variable <- variable - alpha * m_t / (sqrt(v_t) + epsilon) xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5); xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0); - xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); - xla::XlaOp alpha = - xla::Div(xla::Mul(lr, xla::Pow(xla::Sub(one, beta2_power), half)), - xla::Sub(one, beta1_power)); - m = xla::Add(m, xla::Mul(xla::Sub(grad, m), xla::Sub(one, beta1))); - v = xla::Add( - v, xla::Mul(xla::Sub(xla::Pow(grad, two), v), xla::Sub(one, beta2))); - var = xla::Sub(var, xla::Div(xla::Mul(m, alpha), - xla::Add(xla::Pow(v, half), epsilon))); + xla::XlaOp alpha = lr * xla::Sqrt(one - beta2_power) / (one - beta1_power); + m = m + (grad - m) * (one - beta1); + v = v + (xla::Square(grad) - v) * (one - beta2); + var = var - m * alpha / (xla::Sqrt(v) + epsilon); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m)); @@ -250,38 +430,112 @@ class ResourceApplyAdam : public XlaOpKernel { REGISTER_XLA_OP(Name("ResourceApplyAdam").TypeConstraint("T", kFloatTypes), ResourceApplyAdam); -class ResourceApplyRMSProp : public XlaOpKernel { +class ResourceApplyAdaMax : public XlaOpKernel { public: - explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit ResourceApplyAdaMax(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); + TensorShape var_shape, m_shape, v_shape; + xla::XlaOp var, m, v; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v)); - DataType type = ctx->input_type(3); + TensorShape beta1_power_shape = ctx->InputShape(3); + TensorShape lr_shape = ctx->InputShape(4); + TensorShape beta1_shape = ctx->InputShape(5); + TensorShape beta2_shape = ctx->InputShape(6); + TensorShape epsilon_shape = ctx->InputShape(7); + TensorShape grad_shape = ctx->InputShape(8); - TensorShape var_shape, ms_shape, mom_shape; - xla::XlaOp var, ms, mom; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom)); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power_shape), + errors::InvalidArgument("beta1_power is not a scalar: ", + beta1_power_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar : ", + lr_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_shape), + errors::InvalidArgument("beta1 is not a scalar: ", + beta1_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_shape), + errors::InvalidArgument("beta2 is not a scalar: ", + beta2_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape), + errors::InvalidArgument("epsilon is not a scalar: ", + epsilon_shape.DebugString())); + OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape), + errors::InvalidArgument("var and m do not have the same shape", + var_shape.DebugString(), " ", + m_shape.DebugString())); + OP_REQUIRES(ctx, var_shape.IsSameSize(v_shape), + errors::InvalidArgument("var and v do not have the same shape", + var_shape.DebugString(), " ", + v_shape.DebugString())); + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); - TensorShape lr_shape = ctx->InputShape(3); + xla::XlaOp beta1_power = ctx->Input(3); + xla::XlaOp lr = ctx->Input(4); + xla::XlaOp beta1 = ctx->Input(5); + xla::XlaOp beta2 = ctx->Input(6); + xla::XlaOp epsilon = ctx->Input(7); + xla::XlaOp grad = ctx->Input(8); + + xla::XlaOp one = xla::ScalarLike(lr, 1.0); + m = beta1 * m + (one - beta1) * grad; + v = xla::Max(beta2 * v, xla::Abs(grad)); + var = var - lr / (one - beta1_power) * (m / (v + epsilon)); + + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, v)); + } + + private: + DataType dtype_; +}; +REGISTER_XLA_OP(Name("ResourceApplyAdaMax").TypeConstraint("T", kFloatTypes), + ResourceApplyAdaMax); + +class ResourceApplyRMSProp : public XlaOpKernel { + public: + explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape var_shape, ms_shape, mom_shape, mg_shape; + xla::XlaOp var, ms, mom, mg; + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput("var", dtype_, &var_shape, &var)); + if (centered_) { + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput("mg", dtype_, &mg_shape, &mg)); + } + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput("ms", dtype_, &ms_shape, &ms)); + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput("mom", dtype_, &mom_shape, &mom)); + + TensorShape lr_shape = ctx->InputShape("lr"); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), errors::InvalidArgument("lr is not a scalar: ", lr_shape.DebugString())); - TensorShape rho_shape = ctx->InputShape(4); + TensorShape rho_shape = ctx->InputShape("rho"); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape), errors::InvalidArgument("rho is not a scalar: ", rho_shape.DebugString())); - TensorShape momentum_shape = ctx->InputShape(5); + TensorShape momentum_shape = ctx->InputShape("momentum"); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape), errors::InvalidArgument("momentum is not a scalar: ", momentum_shape.DebugString())); - TensorShape epsilon_shape = ctx->InputShape(6); + TensorShape epsilon_shape = ctx->InputShape("epsilon"); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape), errors::InvalidArgument("epsilon is not a scalar: ", epsilon_shape.DebugString())); - TensorShape grad_shape = ctx->InputShape(7); + TensorShape grad_shape = ctx->InputShape("grad"); // var should be the same shape as mom and ms. OP_REQUIRES(ctx, var_shape.IsSameSize(ms_shape), @@ -297,11 +551,11 @@ class ResourceApplyRMSProp : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - xla::XlaOp lr = ctx->Input(3); - xla::XlaOp rho = ctx->Input(4); - xla::XlaOp momentum = ctx->Input(5); - xla::XlaOp epsilon = ctx->Input(6); - xla::XlaOp grad = ctx->Input(7); + xla::XlaOp lr = ctx->Input("lr"); + xla::XlaOp rho = ctx->Input("rho"); + xla::XlaOp momentum = ctx->Input("momentum"); + xla::XlaOp epsilon = ctx->Input("epsilon"); + xla::XlaOp grad = ctx->Input("grad"); // ms <- rho * ms_{t-1} + (1-rho) * grad * grad // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) @@ -320,26 +574,46 @@ class ResourceApplyRMSProp : public XlaOpKernel { // ms <- grad**2 (1 - rho) + ms * rho // // Which is the equation listed above. - xla::XlaOp new_ms = xla::Add( - ms, xla::Mul( - xla::Sub(xla::Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)), - ms), - xla::Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho))); - xla::XlaOp new_mom = - xla::Add(xla::Mul(mom, momentum), - xla::Mul(xla::Mul(grad, lr), - xla::Pow(xla::Add(new_ms, epsilon), - XlaHelpers::FloatLiteral(b, type, -0.5)))); - xla::XlaOp new_var = xla::Sub(var, new_mom); - - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var)); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, new_ms)); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, type, new_mom)); + xla::XlaOp one = xla::ScalarLike(ms, 1.0); + xla::XlaOp new_ms = xla::Square(grad) * (one - rho) + ms * rho; + xla::XlaOp denominator; + if (centered_) { + mg = grad * (one - rho) + mg * rho; + denominator = new_ms - xla::Square(mg) + epsilon; + } else { + denominator = new_ms + epsilon; + } + xla::XlaOp new_mom = mom * momentum + grad * lr * xla::Rsqrt(denominator); + xla::XlaOp new_var = var - new_mom; + + OP_REQUIRES_OK(ctx, ctx->AssignVariable("var", dtype_, new_var)); + if (centered_) { + OP_REQUIRES_OK(ctx, ctx->AssignVariable("mg", dtype_, mg)); + } + OP_REQUIRES_OK(ctx, ctx->AssignVariable("ms", dtype_, new_ms)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable("mom", dtype_, new_mom)); } + + protected: + bool centered_ = false; + + private: + DataType dtype_; }; REGISTER_XLA_OP(Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatTypes), ResourceApplyRMSProp); +class ResourceApplyCenteredRMSProp : public ResourceApplyRMSProp { + public: + explicit ResourceApplyCenteredRMSProp(OpKernelConstruction* ctx) + : ResourceApplyRMSProp(ctx) { + centered_ = true; + } +}; +REGISTER_XLA_OP( + Name("ResourceApplyCenteredRMSProp").TypeConstraint("T", kFloatTypes), + ResourceApplyCenteredRMSProp); + void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, bool has_l2_shrinkage) { xla::XlaBuilder* b = ctx->builder(); @@ -425,23 +699,18 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0); xla::XlaOp grad_to_use; if (has_l2_shrinkage) { - grad_to_use = xla::Add(grad, xla::Mul(two, xla::Mul(l2_shrinkage, var))); + grad_to_use = grad + two * l2_shrinkage * var; } else { grad_to_use = grad; } - xla::XlaOp new_accum = xla::Add(accum, xla::Pow(grad_to_use, two)); - xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, xla::Neg(lr_power)); - xla::XlaOp accum_lr_pow = xla::Pow(accum, xla::Neg(lr_power)); - linear = xla::Add( - linear, - xla::Sub(grad_to_use, - xla::Mul(xla::Div(xla::Sub(new_accum_lr_pow, accum_lr_pow), lr), - var))); - xla::XlaOp linear_clipped = xla::Clamp(xla::Neg(l1), linear, l1); - xla::XlaOp quadratic = - xla::Add(xla::Div(new_accum_lr_pow, lr), xla::Mul(two, l2)); - var = xla::Div(xla::Sub(linear_clipped, linear), quadratic); + xla::XlaOp new_accum = accum + xla::Square(grad_to_use); + xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, -lr_power); + xla::XlaOp accum_lr_pow = xla::Pow(accum, -lr_power); + linear = linear + grad_to_use - (new_accum_lr_pow - accum_lr_pow) / lr * var; + xla::XlaOp linear_clipped = xla::Clamp(-l1, linear, l1); + xla::XlaOp quadratic = new_accum_lr_pow / lr + two * l2; + var = (linear_clipped - linear) / quadratic; accum = new_accum; OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype, var)); @@ -481,5 +750,176 @@ class ResourceApplyFtrlV2 : public XlaOpKernel { REGISTER_XLA_OP(Name("ResourceApplyFtrlV2").TypeConstraint("T", kFloatTypes), ResourceApplyFtrlV2); +class ResourceApplyAdadelta : public XlaOpKernel { + public: + explicit ResourceApplyAdadelta(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape var_shape, accum_shape, accum_update_shape; + xla::XlaOp var, accum, accum_update; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &accum_update_shape, + &accum_update)); + + TensorShape lr_shape = ctx->InputShape(3); + TensorShape rho_shape = ctx->InputShape(4); + TensorShape epsilon_shape = ctx->InputShape(5); + TensorShape grad_shape = ctx->InputShape(6); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", + lr_shape.DebugString())); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape), + errors::InvalidArgument("rho is not a scalar: ", + rho_shape.DebugString())); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape), + errors::InvalidArgument("epsilon is not a scalar: ", + epsilon_shape.DebugString())); + + OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), + errors::InvalidArgument( + "var and accum do not have the same shape", + var_shape.DebugString(), " ", accum_shape.DebugString())); + + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); + + xla::XlaOp lr = ctx->Input(3); + xla::XlaOp rho = ctx->Input(4); + xla::XlaOp epsilon = ctx->Input(5); + xla::XlaOp grad = ctx->Input(6); + + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp neg_half = XlaHelpers::FloatLiteral(b, dtype_, -0.5); + xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5); + xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0); + xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); + + accum = rho * accum + (one - rho) * xla::Pow(grad, two); + xla::XlaOp update = xla::Pow(accum_update + epsilon, half) * + xla::Pow(accum + epsilon, neg_half) * grad; + accum_update = rho * accum_update + (one - rho) * xla::Pow(update, two); + var = var - update * lr; + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, accum_update)); + } + + private: + DataType dtype_; +}; +REGISTER_XLA_OP(Name("ResourceApplyAdadelta").TypeConstraint("T", kFloatTypes), + ResourceApplyAdadelta); + +class ResourceApplySignBase : public XlaOpKernel { + public: + explicit ResourceApplySignBase(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape var_shape, m_shape; + xla::XlaOp var, m; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m)); + OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape), + errors::InvalidArgument("var and m do not have the same shape", + var_shape.DebugString(), " ", + m_shape.DebugString())); + TensorShape grad_shape = ctx->InputShape(6); + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); + CheckScalarParams(ctx); + + xla::XlaOp lr = ctx->Input(2); + xla::XlaOp alpha = ctx->Input(3); + xla::XlaOp sign_decay = ctx->Input(4); + xla::XlaOp beta = ctx->Input(5); + xla::XlaOp grad = ctx->Input(6); + + m = m * beta + grad * (xla::ScalarLike(beta, 1.0) - beta); + xla::XlaOp decay = xla::Sign(grad) * xla::Sign(m) * sign_decay; + + xla::XlaOp grad_scale = ComputeGradientScale(alpha, decay); + var = var - lr * grad_scale * grad; + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m)); + } + + virtual void CheckScalarParams(XlaOpKernelContext* ctx) { + TensorShape lr_shape = ctx->InputShape(2); + TensorShape sign_decay_shape = ctx->InputShape(4); + TensorShape beta_shape = ctx->InputShape(5); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", + lr_shape.DebugString())); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sign_decay_shape), + errors::InvalidArgument("sign_decay is not a scalar: ", + sign_decay_shape.DebugString())); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta_shape), + errors::InvalidArgument("beta is not a scalar: ", + beta_shape.DebugString())); + } + + virtual xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, + xla::XlaOp decay) = 0; + + private: + DataType dtype_; +}; + +class ResourceApplyAddSign : public ResourceApplySignBase { + public: + explicit ResourceApplyAddSign(OpKernelConstruction* ctx) + : ResourceApplySignBase(ctx) {} + + void CheckScalarParams(XlaOpKernelContext* ctx) override { + ResourceApplySignBase::CheckScalarParams(ctx); + TensorShape alpha_shape = ctx->InputShape(3); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape), + errors::InvalidArgument("alpha is not a scalar: ", + alpha_shape.DebugString())); + } + + xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override { + return alpha + decay; + } +}; +REGISTER_XLA_OP(Name("ResourceApplyAddSign").TypeConstraint("T", kFloatTypes), + ResourceApplyAddSign); + +class ResourceApplyPowerSign : public ResourceApplySignBase { + public: + explicit ResourceApplyPowerSign(OpKernelConstruction* ctx) + : ResourceApplySignBase(ctx) {} + + void CheckScalarParams(XlaOpKernelContext* ctx) override { + ResourceApplySignBase::CheckScalarParams(ctx); + TensorShape logbase_shape = ctx->InputShape(3); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(logbase_shape), + errors::InvalidArgument("logbase is not a scalar: ", + logbase_shape.DebugString())); + } + + xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override { + return xla::Exp(alpha * decay); + } +}; +REGISTER_XLA_OP(Name("ResourceApplyPowerSign").TypeConstraint("T", kFloatTypes), + ResourceApplyPowerSign); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index 6c721c48fe3af45aff5cd0bd5e74e2693faf9f97..f9148b394212777271f9eba51313ee17b19819af 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/bounds_check.h" diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 3823f5c08740cb567db023516d72f0ebc610f7a6..0bdfc05726105e2d18362a691cbe2aab00bf77f3 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -21,7 +21,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -49,118 +51,66 @@ XLAJIT_MAKE_UNARY(Conj, xla::Conj(x)); // Return x if x>0, otherwise -x. XLAJIT_MAKE_UNARY(Abs, xla::Abs(x)); - -// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) -XLAJIT_MAKE_UNARY( - Acos, - xla::Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0), - xla::Atan2(xla::Pow(xla::Sub(XlaHelpers::One(b, input_type(0)), - xla::Mul(x, x)), - XlaHelpers::FloatLiteral(b, input_type(0), - 0.5)), - xla::Add(XlaHelpers::One(b, input_type(0)), x)))); - -// acosh(x) = log(x + sqrt(x^2 - 1)) -// = log(x + sqrt((x+1)*(x-1))) -XLAJIT_MAKE_UNARY( - Acosh, - xla::Log(xla::Add( - x, xla::Pow(xla::Mul(xla::Add(x, XlaHelpers::One(b, input_type(0))), - xla::Sub(x, XlaHelpers::One(b, input_type(0)))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); - -// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) -XLAJIT_MAKE_UNARY( - Asin, - xla::Mul( - XlaHelpers::FloatLiteral(b, input_type(0), 2.0), - xla::Atan2(x, - xla::Add(XlaHelpers::One(b, input_type(0)), - xla::Pow(xla::Sub(XlaHelpers::One(b, input_type(0)), - xla::Mul(x, x)), - XlaHelpers::FloatLiteral(b, input_type(0), - 0.5)))))); - -// asinh(x) = log(x + sqrt(x^2 + 1)) -XLAJIT_MAKE_UNARY( - Asinh, - xla::Log(xla::Add( - x, xla::Pow(xla::Add(xla::Mul(x, x), XlaHelpers::One(b, input_type(0))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); - -XLAJIT_MAKE_UNARY(Atan, xla::Atan2(x, XlaHelpers::One(b, input_type(0)))); - -// atanh(x) = 0.5 * log((1 + x) / (1 - x)) -XLAJIT_MAKE_UNARY( - Atanh, - xla::Mul(xla::Log(xla::Div(xla::Add(XlaHelpers::One(b, input_type(0)), x), - xla::Sub(XlaHelpers::One(b, input_type(0)), x))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); +XLAJIT_MAKE_UNARY(Acos, xla::Acos(x)); +XLAJIT_MAKE_UNARY(Acosh, xla::Acosh(x)); +XLAJIT_MAKE_UNARY(Asin, xla::Asin(x)) +XLAJIT_MAKE_UNARY(Asinh, xla::Asinh(x)); +XLAJIT_MAKE_UNARY(Atan, xla::Atan(x)); +XLAJIT_MAKE_UNARY(Atanh, xla::Atanh(x)); XLAJIT_MAKE_UNARY(Ceil, xla::Ceil(x)); XLAJIT_MAKE_UNARY(Cos, xla::Cos(x)); -XLAJIT_MAKE_UNARY(Cosh, - xla::Mul(xla::Add(xla::Exp(x), xla::Exp(xla::Neg(x))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); +XLAJIT_MAKE_UNARY(Cosh, xla::Cosh(x)); XLAJIT_MAKE_UNARY(Sin, xla::Sin(x)); XLAJIT_MAKE_UNARY(Exp, xla::Exp(x)); - XLAJIT_MAKE_UNARY(Expm1, xla::Expm1(x)); - XLAJIT_MAKE_UNARY(Floor, xla::Floor(x)); XLAJIT_MAKE_UNARY(IsFinite, xla::IsFinite(x)); -XLAJIT_MAKE_UNARY(IsInf, xla::Eq(xla::Abs(x), - XlaHelpers::FloatLiteral( - b, input_type(0), - std::numeric_limits::infinity()))); +XLAJIT_MAKE_UNARY( + IsInf, + xla::Eq(xla::Abs(x), + xla::ScalarLike(x, std::numeric_limits::infinity()))); XLAJIT_MAKE_UNARY(IsNan, xla::Ne(x, x)); // Return 1/x -XLAJIT_MAKE_UNARY(Inv, xla::Div(XlaHelpers::One(b, input_type(0)), x)); -XLAJIT_MAKE_UNARY(Reciprocal, xla::Div(XlaHelpers::One(b, input_type(0)), x)); +XLAJIT_MAKE_UNARY(Inv, xla::ScalarLike(x, 1.0) / x); +XLAJIT_MAKE_UNARY(Reciprocal, xla::ScalarLike(x, 1.0) / x); XLAJIT_MAKE_UNARY(Log, xla::Log(x)); - -XLAJIT_MAKE_UNARY(Log1p, b->Log1p(x)); +XLAJIT_MAKE_UNARY(Log1p, xla::Log1p(x)); XLAJIT_MAKE_UNARY(Invert, xla::Not(x)); XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x)); -XLAJIT_MAKE_UNARY(Neg, xla::Neg(x)); +XLAJIT_MAKE_UNARY(Neg, -x); // Implements Banker's rounding: numbers that are equidistant between two // integers are rounded towards even. -static xla::XlaOp Round(xla::XlaBuilder* b, DataType dtype, - const xla::XlaOp& x) { - auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); - auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0); - auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0); +xla::XlaOp RoundToEven(xla::XlaOp x) { + auto half = xla::ScalarLike(x, 0.5); + auto one = xla::ScalarLike(x, 1.0); + auto two = xla::ScalarLike(x, 2.0); auto round_val = xla::Floor(x); - auto fraction = xla::Sub(x, round_val); - auto nearest_even_int = - xla::Sub(round_val, xla::Mul(two, xla::Floor(xla::Mul(half, x)))); + auto fraction = x - round_val; + auto nearest_even_int = round_val - two * xla::Floor(half * x); auto is_odd = xla::Eq(nearest_even_int, one); return xla::Select(xla::Or(xla::Gt(fraction, half), xla::And(xla::Eq(fraction, half), is_odd)), - xla::Add(round_val, one), round_val); + round_val + one, round_val); } -XLAJIT_MAKE_UNARY(Rint, Round(b, input_type(0), x)); -XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x)); +XLAJIT_MAKE_UNARY(Rint, RoundToEven(x)); +XLAJIT_MAKE_UNARY(Round, RoundToEven(x)); -XLAJIT_MAKE_UNARY(Rsqrt, xla::Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), - -0.5))); +XLAJIT_MAKE_UNARY(Rsqrt, xla::Rsqrt(x)); // Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2. -static xla::XlaOp Sigmoid(xla::XlaBuilder* b, DataType dtype, - const xla::XlaOp& x) { - auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); - return xla::Add(half, xla::Mul(half, xla::Tanh(xla::Mul(half, x)))); +xla::XlaOp Sigmoid(xla::XlaOp x) { + auto half = xla::ScalarLike(x, 0.5); + return half + half * xla::Tanh(half * x); } -XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(b, input_type(0), x)); +XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(x)); // Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0. XLAJIT_MAKE_UNARY(Sign, xla::Sign(x)); -XLAJIT_MAKE_UNARY(Sinh, - xla::Mul(xla::Sub(xla::Exp(x), xla::Exp(xla::Neg(x))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); +XLAJIT_MAKE_UNARY(Sinh, xla::Sinh(x)); // softplus(x) = log(1 + exp(x)) // @@ -170,18 +120,14 @@ XLAJIT_MAKE_UNARY(Sinh, // // This is equivalent to: // max(x, 0) + log1p(exp(-abs(x))) -XLAJIT_MAKE_UNARY(Softplus, - xla::Add(xla::Max(x, XlaHelpers::Zero(b, input_type(0))), - b->Log1p(xla::Exp(xla::Neg(xla::Abs(x)))))); +XLAJIT_MAKE_UNARY(Softplus, xla::Max(x, xla::ScalarLike(x, 0.0)) + + xla::Log1p(xla::Exp(-xla::Abs(x)))); // softsign(x) = x / (abs(x) + 1) -XLAJIT_MAKE_UNARY(Softsign, - xla::Div(x, xla::Add(xla::Abs(x), - XlaHelpers::One(b, input_type(0))))); -XLAJIT_MAKE_UNARY(Sqrt, - xla::Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -XLAJIT_MAKE_UNARY(Square, xla::Mul(x, x)); -XLAJIT_MAKE_UNARY(Tan, xla::Div(xla::Sin(x), xla::Cos(x))); +XLAJIT_MAKE_UNARY(Softsign, x / (xla::Abs(x) + xla::ScalarLike(x, 1.0))); +XLAJIT_MAKE_UNARY(Sqrt, xla::Sqrt(x)); +XLAJIT_MAKE_UNARY(Square, x* x); +XLAJIT_MAKE_UNARY(Tan, xla::Tan(x)); XLAJIT_MAKE_UNARY(Tanh, xla::Tanh(x)); XLAJIT_MAKE_UNARY(Real, xla::Real(x)); @@ -195,18 +141,10 @@ class ErfOp : public XlaOpKernel { public: explicit ErfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); - xla::PrimitiveType primitive_type; - xla::XlaOp one = XlaHelpers::One(b, input_type(0)); xla::XlaOp x = ctx->Input(0); - xla::XlaOp abs_x = xla::Abs(x); - - OP_REQUIRES_OK(ctx, - DataTypeToPrimitiveType(input_type(0), &primitive_type)); - + xla::XlaOp one = xla::ScalarLike(x, 1.0); auto y = - xla::Select(xla::Gt(abs_x, one), xla::Sub(one, Erfc(x, primitive_type)), - Erf(x, primitive_type)); + xla::Select(xla::Gt(xla::Abs(x), one), one - xla::Erfc(x), xla::Erf(x)); ctx->SetOutput(0, y); } }; @@ -216,22 +154,60 @@ class ErfcOp : public XlaOpKernel { public: explicit ErfcOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp one = XlaHelpers::One(b, input_type(0)); xla::XlaOp x = ctx->Input(0); - xla::XlaOp abs_x = xla::Abs(x); - - xla::PrimitiveType primitive_type; - OP_REQUIRES_OK(ctx, - DataTypeToPrimitiveType(input_type(0), &primitive_type)); - + xla::XlaOp one = xla::ScalarLike(x, 1.0); auto y = - xla::Select(xla::Lt(abs_x, one), xla::Sub(one, Erf(x, primitive_type)), - Erfc(x, primitive_type)); + xla::Select(xla::Lt(xla::Abs(x), one), one - xla::Erf(x), xla::Erfc(x)); ctx->SetOutput(0, y); } }; REGISTER_XLA_OP(Name("Erfc"), ErfcOp); +class LgammaOp : public XlaOpKernel { + public: + explicit LgammaOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Calculate lgamma using the Lanczos approximation + // (https://en.wikipedia.org/wiki/Lanczos_approximation). + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp input = ctx->Input(0); + xla::PrimitiveType input_type = ctx->input_xla_type(0); + + if (input_type == xla::F16 || input_type == xla::BF16) { + // The approximation works better with at least 32-bits of accuracy. + xla::XlaOp input_f32 = xla::ConvertElementType(input, xla::F32); + xla::XlaOp result_f32 = xla::Lgamma(input_f32); + xla::XlaOp result_x16 = xla::ConvertElementType(result_f32, input_type); + ctx->SetOutput(0, result_x16); + } else { + xla::XlaOp result = xla::Lgamma(input); + ctx->SetOutput(0, result); + } + } +}; // namespace +REGISTER_XLA_OP(Name("Lgamma"), LgammaOp); + +class DigammaOp : public XlaOpKernel { + public: + explicit DigammaOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Calculate lgamma using the Lanczos approximation + // (https://en.wikipedia.org/wiki/Lanczos_approximation). + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp input = ctx->Input(0); + xla::PrimitiveType input_type = ctx->input_xla_type(0); + + if (input_type == xla::F16 || input_type == xla::BF16) { + // The approximation works better with at least 32-bits of accuracy. + xla::XlaOp input_f32 = xla::ConvertElementType(input, xla::F32); + xla::XlaOp result_f32 = xla::Digamma(input_f32); + xla::XlaOp result_x16 = xla::ConvertElementType(result_f32, input_type); + ctx->SetOutput(0, result_x16); + } else { + xla::XlaOp result = xla::Digamma(input); + ctx->SetOutput(0, result); + } + } +}; // namespace +REGISTER_XLA_OP(Name("Digamma"), DigammaOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc index 0e5d58ecbaeb13571f82a1311e29dc0ba91c11ac..8671632976023fded04c26a9780c1a67638b0916 100644 --- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc @@ -22,8 +22,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index febac8287350e32fccfd4cb5613f21b9a5fbcb95..2c92a585f5679242d672d0402e617ff199b94f17 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 340165bac6a2a214d8f84d5a116a4197b1df2c7b..296518229ebf0ba46717afc4f26d5ae1551c2862 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -21,8 +21,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" @@ -299,6 +300,8 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Done building while loop"; } +REGISTER_XLA_OP(Name("While").AllowResourceTypes(), XlaWhileOp); +REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes(), XlaWhileOp); REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 04c600698c7d86808238f29cbeed6aa66acaee70..cb7a40e23d539f758d963791f1c2b4d37374ade5 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -25,8 +25,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", ], ) @@ -40,12 +40,13 @@ cc_library( ":triangular_solve", ":util", ":while_loop", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/core:lib", ], ) @@ -58,12 +59,35 @@ cc_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/core:protos_all_cc", ], ) +cc_library( + name = "qr", + srcs = ["qr.cc"], + hdrs = ["qr.h"], + deps = [ + ":batch_dot", + ":util", + ":while_loop", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/core:lib", + ], +) + cc_library( name = "scatter", srcs = ["scatter.cc"], @@ -71,14 +95,14 @@ cc_library( deps = [ ":util", ":while_loop", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -90,13 +114,15 @@ cc_library( deps = [ ":batch_dot", ":util", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/core:lib", ], ) @@ -108,7 +134,7 @@ xla_test( deps = [ ":triangular_solve", "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -116,7 +142,7 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -130,13 +156,14 @@ cc_library( srcs = ["util.cc"], hdrs = ["util.h"], deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", ], ) @@ -148,7 +175,7 @@ xla_test( ":batch_dot", ":util", "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -173,8 +200,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index dd29bafcd96a4371ca33e896391008b410701ea6..f666d22ea44216beef74608bb4d9f33fb2fe82c6 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -26,92 +26,94 @@ limitations under the License. namespace tensorflow { -xla::StatusOr BatchDot(xla::XlaBuilder* builder, xla::XlaOp x, - xla::XlaOp y, bool transpose_x, - bool transpose_y, bool conjugate_x, - bool conjugate_y) { - TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); - TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y)); - - // Check that both tensors have the same number of dimensions. There must be - // at least two (the batch dimensions can be empty). - if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) { - return errors::InvalidArgument( - "Arguments to BatchedDot have different ranks: ", - xla::ShapeUtil::HumanString(x_shape), " vs. ", - xla::ShapeUtil::HumanString(y_shape)); - } - const int ndims = xla::ShapeUtil::Rank(x_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to BatchedDot must have rank >= 2: ", ndims); - } - - // The batch dimensions must be equal and the matrix dimensions must be - // valid. - std::vector batch_dimension_numbers; - for (int i = 0; i < ndims - 2; ++i) { - if (x_shape.dimensions(i) != y_shape.dimensions(i)) { +xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, + bool transpose_y, bool conjugate_x, bool conjugate_y) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y)); + + // Check that both tensors have the same number of dimensions. There must be + // at least two (the batch dimensions can be empty). + if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) { return errors::InvalidArgument( - "Dimension ", i, " of inputs to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " vs ", + "Arguments to BatchedDot have different ranks: ", + xla::ShapeUtil::HumanString(x_shape), " vs. ", xla::ShapeUtil::HumanString(y_shape)); } - batch_dimension_numbers.push_back(i); - } - - int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); - int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2); - if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { - return errors::InvalidArgument( - "Dimensions ", x_inner_dim, " and ", y_inner_dim, - " of arguments to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x, - " vs. ", xla::ShapeUtil::HumanString(y_shape), - " transpose: ", transpose_y); - } - - // Check for zero lhs/rhs dim size. - if (xla::ShapeUtil::IsZeroElementArray(x_shape) || - xla::ShapeUtil::IsZeroElementArray(y_shape)) { - std::vector dimensions(batch_dimension_numbers.size()); - for (int i = 0; i < batch_dimension_numbers.size(); ++i) { - dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); + const int ndims = xla::ShapeUtil::Rank(x_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to BatchedDot must have rank >= 2: ", ndims); + } + + // The batch dimensions must be equal and the matrix dimensions must be + // valid. + std::vector batch_dimension_numbers; + for (int i = 0; i < ndims - 2; ++i) { + if (x_shape.dimensions(i) != y_shape.dimensions(i)) { + return errors::InvalidArgument( + "Dimension ", i, " of inputs to BatchedDot must be equal: ", + xla::ShapeUtil::HumanString(x_shape), " vs ", + xla::ShapeUtil::HumanString(y_shape)); + } + batch_dimension_numbers.push_back(i); + } + + int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); + int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2); + if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { + return errors::InvalidArgument( + "Dimensions ", x_inner_dim, " and ", y_inner_dim, + " of arguments to BatchedDot must be equal: ", + xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x, + " vs. ", xla::ShapeUtil::HumanString(y_shape), + " transpose: ", transpose_y); + } + + // Check for zero lhs/rhs dim size. + if (xla::ShapeUtil::IsZeroElementArray(x_shape) || + xla::ShapeUtil::IsZeroElementArray(y_shape)) { + std::vector dimensions(batch_dimension_numbers.size()); + for (int i = 0; i < batch_dimension_numbers.size(); ++i) { + dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); + } + int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); + int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); + dimensions.push_back(x_shape.dimensions(x_outer_dim)); + dimensions.push_back(y_shape.dimensions(y_outer_dim)); + return xla::Broadcast( + xla::ConstantLiteral(builder, + xla::LiteralUtil::Zero(x_shape.element_type())), + dimensions); + } + + if (x_shape.element_type() == xla::C64 && conjugate_x) { + x = xla::Conj(x); + } + if (y_shape.element_type() == xla::C64 && conjugate_y) { + y = xla::Conj(y); + } + + // If there are no batch dimensions, use a regular Dot. + // TODO(b/69062148) Remove this code when Dot emitters can be passed + // dimensions to transpose directly (i.e. without requiring a Transpose + // HLO). + if (batch_dimension_numbers.empty()) { + auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x; + auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y; + return xla::Dot(lhs, rhs); + } + + xla::DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); + dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); + for (auto batch_dimension_number : batch_dimension_numbers) { + dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); + dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); } - int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); - int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); - dimensions.push_back(x_shape.dimensions(x_outer_dim)); - dimensions.push_back(y_shape.dimensions(y_outer_dim)); - return xla::Broadcast( - xla::ConstantLiteral(builder, - xla::Literal::Zero(x_shape.element_type())), - dimensions); - } - - if (x_shape.element_type() == xla::C64 && conjugate_x) { - x = xla::Conj(x); - } - if (y_shape.element_type() == xla::C64 && conjugate_y) { - y = xla::Conj(y); - } - - // If there are no batch dimensions, use a regular Dot. - // TODO(b/69062148) Remove this code when Dot emitters can be passed - // dimensions to transpose directly (i.e. without requiring a Transpose HLO). - if (batch_dimension_numbers.empty()) { - auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x; - auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y; - return xla::Dot(lhs, rhs); - } - - xla::DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); - dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); - for (auto batch_dimension_number : batch_dimension_numbers) { - dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); - dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); - } - return xla::DotGeneral(x, y, dot_dnums); + return xla::DotGeneral(x, y, dot_dnums); + }); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h index 1acc72033b05e73b0f5f88907df20cde5cfffbf0..8757b16a1ca6a8cec5e3c801c885e7bbbb2f2c76 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" namespace tensorflow { @@ -43,10 +43,9 @@ namespace tensorflow { // It is computed as: // // output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -xla::StatusOr BatchDot(xla::XlaBuilder* builder, xla::XlaOp x, - xla::XlaOp y, bool transpose_x, - bool transpose_y, bool conjugate_x = false, - bool conjugate_y = false); +xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, + bool transpose_y = false, bool conjugate_x = false, + bool conjugate_y = false); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index 397f0e3a7286ac46030ae602a4c059cd8aaa1ae1..87d73eb3f07ebd7dfa4fef50ebe76cad0c4ed117 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -22,8 +22,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -48,173 +49,163 @@ namespace { // l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) / // l[..., j, j] // return l -xla::StatusOr CholeskyUnblocked(xla::XlaBuilder* builder, - const xla::XlaOp& a) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int n_dims = xla::ShapeUtil::Rank(a_shape); - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - gtl::ArraySlice major_dims(xla::AsInt64Slice(a_shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - 2); - - xla::XlaOp l = Zeros(builder, a_shape); - - // Construct the for loop body to iterate over rows. - auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, - xla::XlaBuilder* body_builder) - -> xla::StatusOr> { - xla::Shape col_shape; - xla::Shape row_shape; - for (int64 d : major_dims) { - row_shape.add_dimensions(d); - col_shape.add_dimensions(d); - } - row_shape.add_dimensions(1); - row_shape.add_dimensions(n); - row_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_row = Zeros(body_builder, row_shape); - - col_shape.add_dimensions(n); - col_shape.add_dimensions(1); - col_shape.set_element_type(a_shape.element_type()); - auto mask_zeros_col = Zeros(body_builder, col_shape); - - std::vector mask_vector(n); - std::iota(mask_vector.begin(), mask_vector.end(), 0); - auto mask_range = xla::ConstantR1(body_builder, mask_vector); - auto mask_range_row = - xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims); - auto mask_range_col = - xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims); - auto body_a = loop_vars[0]; - auto body_l = loop_vars[1]; - - // row = l[..., i, :i] - // select the whole i-th row, then mask out all columns past i-1 - auto zero = xla::ConstantR0(body_builder, 0); - TF_ASSIGN_OR_RETURN(auto l_i, DynamicSliceInMinorDims(body_builder, body_l, - {i, zero}, {1, n})); - auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i); - // a[..., i, i] - TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(body_builder, body_a, - {i, i}, {1, 1})); - // np.dot(row, np.swapaxes(row, -1, -2)) - xla::XlaOp diag_dot; - TF_ASSIGN_OR_RETURN(diag_dot, BatchDot(body_builder, row, row, - /*transpose_x=*/false, - /*transpose_y=*/true)); - // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, - // np.swapaxes(row, -1, -2))) - auto l_ii = - xla::Pow(xla::Sub(a_ii, diag_dot), - FloatLiteral(body_builder, a_shape.element_type(), 0.5)); - - // a[..., i+1:, i] - // select the whole i-th column, then mask out all rows above i+1 +xla::XlaOp CholeskyUnblocked(xla::XlaOp a) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + const int n_dims = xla::ShapeUtil::Rank(a_shape); + const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + gtl::ArraySlice major_dims(xla::AsInt64Slice(a_shape.dimensions()), + /*pos=*/0, + /*len=*/n_dims - 2); + + xla::XlaOp l = xla::ZerosLike(a); + + // Construct the for loop body to iterate over rows. + auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + xla::XlaBuilder* body_builder) + -> xla::StatusOr> { + xla::Shape col_shape; + xla::Shape row_shape; + for (int64 d : major_dims) { + row_shape.add_dimensions(d); + col_shape.add_dimensions(d); + } + row_shape.add_dimensions(1); + row_shape.add_dimensions(n); + row_shape.set_element_type(a_shape.element_type()); + auto mask_zeros_row = xla::Zeros(body_builder, row_shape); + + col_shape.add_dimensions(n); + col_shape.add_dimensions(1); + col_shape.set_element_type(a_shape.element_type()); + auto mask_zeros_col = xla::Zeros(body_builder, col_shape); + + std::vector mask_vector(n); + std::iota(mask_vector.begin(), mask_vector.end(), 0); + auto mask_range = xla::ConstantR1(body_builder, mask_vector); + auto mask_range_row = + xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims); + auto mask_range_col = + xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims); + auto body_a = loop_vars[0]; + auto body_l = loop_vars[1]; + + // row = l[..., i, :i] + // select the whole i-th row, then mask out all columns past i-1 + auto zero = xla::ConstantR0(body_builder, 0); + auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n}); + auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i); + // a[..., i, i] + auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); + // np.dot(row, np.swapaxes(row, -1, -2)) + auto diag_dot = BatchDot(row, row, + /*transpose_x=*/false, + /*transpose_y=*/true); + // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, + // np.swapaxes(row, -1, -2))) + auto l_ii = + xla::Pow(a_ii - diag_dot, + FloatLiteral(body_builder, a_shape.element_type(), 0.5)); + + // a[..., i+1:, i] + // select the whole i-th column, then mask out all rows above i+1 + auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1}); + auto a_ip1i = + xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i); + + // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / + // l[..., i, i] + // The columns in [i, n] are zeroed out in `row`, so we just have to + // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i], + // r.T) + auto dot = BatchDot(body_l, row, + /*transpose_x=*/false, + /*transpose_y=*/true); + // np.dot(l[..., i+1:, :i], r.T) + auto dot_ip1 = + xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); + + body_l = + DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i}); + // Assign the diagonal after the rest of the column because otherwise the + // column assign will wrap around and overwrite the diagonal assign. + body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i}); + + return std::vector{body_a, body_l}; + }; + TF_ASSIGN_OR_RETURN( - auto a_0i, DynamicSliceInMinorDims(body_builder, body_a, {i}, {1})); - auto a_ip1i = xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i); - - // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / - // l[..., i, i] - // The columns in [i, n] are zeroed out in `row`, so we just have to - // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i], - // r.T) - TF_ASSIGN_OR_RETURN(auto dot, BatchDot(body_builder, body_l, row, - /*transpose_x=*/false, - /*transpose_y=*/true)); - // np.dot(l[..., i+1:, :i], r.T) - auto dot_ip1 = xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); - - auto col_update = xla::Div(xla::Sub(a_ip1i, dot_ip1), l_ii); - TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims( - body_builder, body_l, col_update, {i})); - // Assign the diagonal after the rest of the column because otherwise the - // column assign will wrap around and overwrite the diagonal assign. - TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims( - body_builder, body_l, l_ii, {i, i})); - - return std::vector{body_a, body_l}; - }; - - TF_ASSIGN_OR_RETURN( - auto cholesky_while, - XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder)); - - return cholesky_while[1]; + auto cholesky_while, + XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder)); + + return cholesky_while[1]; + }); } } // namespace -xla::StatusOr Cholesky(xla::XlaBuilder* builder, xla::XlaOp a, - int64 block_size) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - const int ndims = xla::ShapeUtil::Rank(a_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to Cholesky must have rank >= 2: ", ndims); - } - - const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - "Arguments to Cholesky must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); - } - - if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to Cholesky must be >= 1; got ", block_size); - } - - // Blocked left-looking Cholesky factorization. - // Algorithm 1 from - // Haidar, Azzam, et al. "High-performance Cholesky factorization for GPU-only - // execution." Proceedings of General Purpose GPUs. ACM, 2017. - xla::XlaOp l = Zeros(builder, a_shape); - for (int64 i = 0; i < n; i += block_size) { - int64 k = std::min(block_size, n - i); - if (i > 0) { - // TODO(phawkins): consider implementing SYRK for the diagonal part of - // the panel. - // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) - TF_ASSIGN_OR_RETURN(auto lhs, - SliceInMinorDims(builder, l, {i, 0}, {n, i})); - TF_ASSIGN_OR_RETURN(auto rhs, - SliceInMinorDims(builder, l, {i, 0}, {i + k, i})); - TF_ASSIGN_OR_RETURN(auto delta, - BatchDot(builder, lhs, rhs, /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN(auto before, - SliceInMinorDims(builder, a, {i, i}, {n, i + k})); - TF_ASSIGN_OR_RETURN(a, UpdateSliceInMinorDims( - builder, a, xla::Sub(before, delta), {i, i})); +xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + const int ndims = xla::ShapeUtil::Rank(a_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to Cholesky must have rank >= 2: ", ndims); + } + + const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) { + return errors::InvalidArgument( + "Arguments to Cholesky must be square matrices: ", + xla::ShapeUtil::HumanString(a_shape)); + } + + if (block_size < 1) { + return errors::InvalidArgument( + "block_size argument to Cholesky must be >= 1; got ", block_size); } - // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k]) - TF_ASSIGN_OR_RETURN(auto x, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto factorized, CholeskyUnblocked(builder, x)); - TF_ASSIGN_OR_RETURN(l, - UpdateSliceInMinorDims(builder, l, factorized, {i, i})); - - if (i + k < n) { - // l[i+k:, i:i+k] = trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k]) - TF_ASSIGN_OR_RETURN(auto panel, - SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); - TF_ASSIGN_OR_RETURN(auto update, - TriangularSolve(builder, factorized, panel, - /*left_side=*/false, - /*lower=*/true, - /*transpose_a=*/true, - /*conjugate_a=*/false, - /*block_size=*/block_size)); - TF_ASSIGN_OR_RETURN( - l, UpdateSliceInMinorDims(builder, l, update, {i + k, i})); + // Blocked left-looking Cholesky factorization. + // Algorithm 1 from + // Haidar, Azzam, et al. "High-performance Cholesky factorization for + // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017. + xla::XlaOp l = xla::ZerosLike(a); + for (int64 i = 0; i < n; i += block_size) { + int64 k = std::min(block_size, n - i); + if (i > 0) { + // TODO(phawkins): consider implementing SYRK for the diagonal part of + // the panel. + // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) + auto lhs = SliceInMinorDims(l, {i, 0}, {n, i}); + auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i}); + auto delta = BatchDot(lhs, rhs, /*transpose_x=*/false, + /*transpose_y=*/true); + auto before = SliceInMinorDims(a, {i, i}, {n, i + k}); + a = UpdateSliceInMinorDims(a, before - delta, {i, i}); + } + + // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k]) + auto x = SliceInMinorDims(a, {i, i}, {i + k, i + k}); + auto factorized = CholeskyUnblocked(x); + l = UpdateSliceInMinorDims(l, factorized, {i, i}); + + if (i + k < n) { + // l[i+k:, i:i+k] = + // trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k]) + auto panel = SliceInMinorDims(a, {i + k, i}, {n, i + k}); + auto update = TriangularSolve(factorized, panel, + /*left_side=*/false, + /*lower=*/true, + /*transpose_a=*/true, + /*conjugate_a=*/false, + /*block_size=*/block_size); + l = UpdateSliceInMinorDims(l, update, {i + k, i}); + } } - } - return l; + return l; + }); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index 20fca7969ece2729a44933fd3ef3f87230ab6cad..1bef9bb166c576ec665bb48265b4da200ddca2a0 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" namespace tensorflow { @@ -30,8 +30,7 @@ namespace tensorflow { // TODO(phawkins): check for negative values on the diagonal and return an // error, instead of silently yielding NaNs. // TODO(znado): handle the complex Hermitian case -xla::StatusOr Cholesky(xla::XlaBuilder* builder, xla::XlaOp a, - int64 block_size = 256); +xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc new file mode 100644 index 0000000000000000000000000000000000000000..fc0c1ee838190b1f1a7ca5b901c97e0a35232a97 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/qr.cc @@ -0,0 +1,387 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/qr.h" + +#include +#include + +#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +namespace { + +// Computes a Householder reflection of the form: +// H = I - tau v v.T. +// such that +// H . ( x1 ) = ( x1 ) +// ( x2 ) = ( x2 ) +// ( ... ) = ( ... ) +// ( xk ) = ( beta ) +// ( ... ) ( 0 ) +// ( ... ) ( 0 ) +// Unlike the usual formulation, we allow the caller to supply 'k' rather than +// only providing the relevant part of 'x' to maintain XLA's static shape +// invariant. In addition, the implementation supports batching. +// Pseudo-code, without batching: +// alpha = x[k] +// x_copy = np.copy(x) +// x_copy[:k+1] = 0 +// xnorm = norm2(x_copy) +// if xnorm == 0: +// beta = alpha +// tau = 0 +// v = np.zeros_like(x) +// else: +// beta = - np.sign(alpha) * dlapy2(alpha, xnorm) +// tau = (beta - alpha) / beta +// v = x / (alpha - beta) +// v[k] = 1 +// return (v, tau, beta) +// TODO(phawkins): LAPACK's xLARFG implementation has code for handling +// overflows in the norm/beta calculations. Perhaps do the same here. +xla::Status House(xla::XlaOp x, xla::XlaOp k, gtl::ArraySlice batch_dims, + const int64 m, xla::XlaOp* v, xla::XlaOp* tau, + xla::XlaOp* beta) { + xla::XlaBuilder* const builder = x.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); + const xla::PrimitiveType type = x_shape.element_type(); + + std::vector batch_dim_ids(batch_dims.size()); + std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0); + const int64 minor_dim = batch_dims.size(); + + xla::XlaOp zero = xla::ScalarLike(x, 0.0); + xla::XlaOp one = xla::ScalarLike(x, 1.0); + + // alpha = x[k] + xla::XlaOp alpha = + xla::Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims); + + // Compute x[k+1:] (padded with zeros in elements 0..k) + xla::XlaOp iota = xla::Iota(builder, xla::S32, m); + xla::XlaOp x_after_k = + xla::Mul(x, xla::ConvertElementType(xla::Gt(iota, k), type), + /*broadcast_dimensions=*/{minor_dim}); + + // sigma = np.dot(x[k+1:], x[k+1:]) + auto sigma = + xla::Reduce(x_after_k * x_after_k, zero, + xla::CreateScalarAddComputation(type, builder), {minor_dim}); + // mu = np.sqrt(x[k]*x[k] + sigma) + auto mu = xla::Sqrt(xla::Square(alpha) + sigma); + + auto sigma_is_zero = xla::Eq(sigma, zero); + + *beta = xla::Select(sigma_is_zero, alpha, -xla::Sign(alpha) * mu); + *tau = xla::Select(sigma_is_zero, xla::Broadcast(zero, batch_dims), + (*beta - alpha) / *beta); + auto divisor = xla::Select(sigma_is_zero, xla::Broadcast(one, batch_dims), + alpha - *beta); + + auto e_k = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, k), type), + std::vector(batch_dims.size(), 1)); + + // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor + // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor. + *v = e_k + + xla::Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids); + return Status::OK(); +} + +// Householder QR decomposition. Algorithm 5.2.1 from Golub and Van +// Loan "Matrix Computations", 4th Edition. This is an unblocked implementation +// used as an inner routine of the blocked implementation. +// Algorithm is adapted slightly so the shapes inside the loop are static, at +// the cost of some redundant computation. Since this is used as an inner block +// kernel, accumulates the Householder transformations (vs, taus) rather than +// the matrix q. +// Equivalent Python code, without batching: +// def qr(a): +// m = a.shape[0] +// n = a.shape[1] +// vs = np.zeros([m, n]) +// taus = np.zeros([n]) +// for j in xrange(min(m, n)): +// v, tau, beta = house(a[:, j], j) +// # Unusually, we apply the Householder transformation to the entirety of +// # a, wasting FLOPs to maintain the static shape invariant that XLA +// # requires. For columns that precede j this has no effect. +// a[:, :] -= tau * np.dot(v[:, np.newaxis], +// np.dot(v[np.newaxis, :], a[:, :])) +// # Form column j explicitly rather than relying on the precision of the +// # Householder update. +// a[j, j] = beta +// a[j+1:, j] = np.zeros([m - j - 1], dtype=a.dtype) +// vs[:, j] = v +// taus[j] = tau +// return (q, vs, taus) +struct QRBlockResult { + // The factored R value + xla::XlaOp r; + + // Representation of the Householder matrices I - beta v v.T + xla::XlaOp taus; // Shape: [..., n] + xla::XlaOp vs; // Shape: [..., m, n] +}; +xla::StatusOr QRBlock(xla::XlaOp a) { + xla::XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + const int num_dims = xla::ShapeUtil::Rank(a_shape); + if (num_dims < 2) { + return errors::InvalidArgument("Arguments to QR must have rank >= 2: ", + num_dims); + } + xla::PrimitiveType type = a_shape.element_type(); + + const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + + const int64 num_batch_dims = num_dims - 2; + std::vector batch_dims(num_batch_dims); + for (int i = 0; i < num_batch_dims; ++i) { + batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i); + } + + std::vector batch_dim_indices(num_batch_dims); + std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); + + auto qr_body_fn = + [&](xla::XlaOp j, gtl::ArraySlice values, + xla::XlaBuilder* builder) -> xla::StatusOr> { + auto a = values[0]; + auto vs = values[1]; + auto taus = values[2]; + + // v, beta = house(a[:, j], j) + auto x = DynamicSliceInMinorDims(a, {j}, {1}); + xla::XlaOp v, tau, beta; + TF_RETURN_IF_ERROR(House(xla::Collapse(x, {num_dims - 2, num_dims - 1}), j, + batch_dims, m, &v, &tau, &beta)); + + std::vector shape = batch_dims; + shape.push_back(1); + shape.push_back(m); + auto v_broadcast = xla::Reshape(v, shape); + // a[:, :] -= tau * np.dot(v[:, np.newaxis], + // np.dot(v[np.newaxis, :], a[:, :])) + auto vva = BatchDot(v_broadcast, a); + vva = BatchDot(v_broadcast, vva, /*transpose_x=*/true); + a = a - xla::Mul(tau, vva, + /*broadcast_dimensions=*/batch_dim_indices); + + // It is more precise to populate column 'k' explicitly, rather than + // computing it implicitly by applying the Householder transformation. + // a[k,k] = beta + // a[k+1:,k] = np.zeros([m-k-1], dtype=a.dtype) + auto iota = xla::Reshape(xla::Iota(a.builder(), xla::S32, m), {m, 1}); + auto predecessor_mask = xla::ConvertElementType(xla::Lt(iota, j), type); + auto mask = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, j), type), + std::vector(batch_dims.size(), 1)); + auto new_x = + xla::Mul(x, predecessor_mask, + /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) + + xla::Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices); + a = DynamicUpdateSliceInMinorDims(a, new_x, {j}); + + // vs[:, j] = v + vs = DynamicUpdateSliceInMinorDims( + vs, xla::Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j}); + // taus[j] = tau + taus = DynamicUpdateSliceInMinorDims( + taus, xla::Reshape(tau, ConcatVectors(batch_dims, {1})), {j}); + return std::vector{a, vs, taus}; + }; + + auto vs = xla::Zeros(builder, xla::ShapeUtil::MakeShape( + type, ConcatVectors(batch_dims, {m, n}))); + auto taus = xla::Zeros( + builder, xla::ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n}))); + + TF_ASSIGN_OR_RETURN(auto values, + XlaForEachIndex(std::min(m, n), xla::S32, qr_body_fn, + {a, vs, taus}, "qr", builder)); + + QRBlockResult result; + result.r = values[0]; + result.vs = values[1]; + result.taus = values[2]; + return result; +} + +// Computes W and Y such that I-WY is equivalent to the sequence of Householder +// transformations given by vs and taus. +// Golub and van Loan, "Matrix Computations", algorithm 5.1.2. +// Y = np.zeros([m, n]) +// W = np.zeros([m, n]) +// Y[:, 0] = vs[:, 0] +// W[:, 0] = -taus[0] * vs[:, 0] +// for j in xrange(1, n): +// v = vs[:, j] +// z = -taus[j] * v - taus[j] * np.dot(W, np.dot(Y.T, v)) +// W[:, j] = z +// Y[:, j] = v +// return W +// There is no need to return Y since at termination of the loop it is equal to +// vs. +xla::StatusOr ComputeWYRepresentation( + xla::PrimitiveType type, gtl::ArraySlice batch_dims, xla::XlaOp vs, + xla::XlaOp taus, int64 m, int64 n) { + std::vector batch_dim_indices(batch_dims.size()); + std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); + int64 n_index = batch_dims.size() + 1; + + auto body_fn = + [&](xla::XlaOp j, gtl::ArraySlice values, + xla::XlaBuilder* builder) -> xla::StatusOr> { + auto w = values[0]; + auto y = values[1]; + const auto vs = values[2]; + const auto taus = values[3]; + + // Want j values in range [1, ... n). + j = j + xla::ConstantR0(builder, 1); + // vs has shape [..., m, 1] + auto v = DynamicSliceInMinorDims(vs, {j}, {1}); + // beta has shape [..., 1] + auto beta = DynamicSliceInMinorDims(taus, {j}, {1}); + + // yv has shape [..., n, 1] + auto yv = BatchDot(y, v, /*transpose_x=*/true); + // wyv has shape [..., m, 1] + auto wyv = BatchDot(w, yv); + + auto z = xla::Mul( + -beta, v + wyv, + /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); + + w = DynamicUpdateSliceInMinorDims(w, z, {j}); + y = DynamicUpdateSliceInMinorDims(y, v, {j}); + + return std::vector{w, y, vs, taus}; + }; + + xla::XlaBuilder* builder = vs.builder(); + auto w = xla::Zeros(builder, xla::ShapeUtil::MakeShape( + type, ConcatVectors(batch_dims, {m, n}))); + auto y = w; + auto v = SliceInMinorDims(vs, {0}, {1}); + auto beta = SliceInMinorDims(taus, {0}, {1}); + y = UpdateSliceInMinorDims(y, v, {0}); + auto bv = xla::Mul( + -beta, v, + /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); + w = UpdateSliceInMinorDims(w, bv, {0}); + + TF_ASSIGN_OR_RETURN( + auto values, XlaForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus}, + "wy", builder)); + return values[0]; +} + +} // namespace + +// Block Householder QR Factorization. Algorithm 5.2.2 of Golub and van Loan. +// def qr_blocked(a, block_size): +// m = a.shape[0] +// n = a.shape[1] +// q = np.eye(m) +// for i in xrange(0, min(m, n), block_size): +// k = min(block_size, min(m, n) - s) +// (a, vs, taus) = qr(a[i:, i:i+k]) +// y = vs +// w = ComputeWYRepresentation(vs, taus, m-i, k) +// a[i:, i+r:] += np.dot(y, np.dot(w.T, a[i:, i+k:])) +// q[:, i:] += np.dot(q[:, i:], np.dot(w, y.T)) +// return (q, a) +// TODO(phawkins): consider using UT transformations (in the form I - V U V') +// rather than WY transformations. +xla::StatusOr QRDecomposition(xla::XlaOp a, + int64 block_size) { + xla::XlaBuilder* builder = a.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + const int num_dims = xla::ShapeUtil::Rank(a_shape); + if (num_dims < 2) { + return errors::InvalidArgument("Arguments to QR must have rank >= 2: ", + num_dims); + } + xla::PrimitiveType type = a_shape.element_type(); + + const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + const int64 p = std::min(m, n); + + if (block_size < 1) { + return errors::InvalidArgument( + "block_size argument to QR must be >= 1; got ", block_size); + } + + const int64 num_batch_dims = num_dims - 2; + std::vector batch_dims(num_batch_dims); + for (int i = 0; i < num_batch_dims; ++i) { + batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i); + } + + auto q = xla::Broadcast(xla::IdentityMatrix(builder, type, m, m), batch_dims); + for (int64 i = 0; i < p; i += block_size) { + int64 k = std::min(block_size, p - i); + + auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k}); + TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block)); + + a = UpdateSliceInMinorDims(a, qr_block.r, {i, i}); + + // Compute the I-WY block representation of a product of Householder + // matrices. + TF_ASSIGN_OR_RETURN(auto w, + ComputeWYRepresentation(type, batch_dims, qr_block.vs, + qr_block.taus, m - i, k)); + auto y = qr_block.vs; + + // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:])) + auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n}); + auto a_update = BatchDot(w, a_panel, /*transpose_x=*/true); + a_update = BatchDot(y, a_update); + a_panel = a_panel + a_update; + a = UpdateSliceInMinorDims(a, a_panel, {i, i + k}); + + // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)) + auto q_panel = SliceInMinorDims(q, {0, i}, {m, m}); + auto q_update = BatchDot(q_panel, w); + q_update = + BatchDot(q_update, y, /*transpose_x=*/false, /*transpose_y=*/true); + q_panel = q_panel + q_update; + q = UpdateSliceInMinorDims(q, q_panel, {0, i}); + } + QRDecompositionResult result; + result.q = q; + result.r = a; + return result; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h new file mode 100644 index 0000000000000000000000000000000000000000..abd2316ac961f583dd29f90f43cf6209de30bd6a --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/qr.h @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" + +namespace tensorflow { + +// Computes the QR decompositions of a batch of matrices. That is, +// given a (batched) matrix a, computes an orthonormal matrix Q and an +// upper-triangular matrix R such that a = QR. +// `a` must be a (batched) matrix of size [..., m, n]. +// The algorithm implements a blocked QR decomposition; `block_size` is +// the block size to use. +// TODO(phawkins): handle the complex case. +struct QRDecompositionResult { + xla::XlaOp q; + xla::XlaOp r; +}; + +xla::StatusOr QRDecomposition(xla::XlaOp a, + int64 block_size = 128); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ diff --git a/tensorflow/compiler/tf2xla/lib/random.cc b/tensorflow/compiler/tf2xla/lib/random.cc index 3dfa66029ca84fad9c511e7b32a906ee41d37812..5e7cf00ee5e063aef36a9531ff87d8fe6928ca1f 100644 --- a/tensorflow/compiler/tf2xla/lib/random.cc +++ b/tensorflow/compiler/tf2xla/lib/random.cc @@ -19,14 +19,14 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/status_macros.h" namespace tensorflow { -xla::XlaOp TruncatedNormal(const DataType dtype, xla::XlaOp uniform) { - xla::XlaBuilder* builder = uniform.builder(); +xla::XlaOp TruncatedNormal(xla::XlaOp uniform) { auto normal_cdf = [](double x) { return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0; }; @@ -41,18 +41,15 @@ xla::XlaOp TruncatedNormal(const DataType dtype, xla::XlaOp uniform) { const double kBetaNormalCdf = normal_cdf(kBeta); const double kZ = kBetaNormalCdf - kAlphaNormalCdf; - xla::XlaOp one = XlaHelpers::FloatLiteral(builder, dtype, 1.0); - xla::XlaOp two = XlaHelpers::FloatLiteral(builder, dtype, 2.0); - xla::XlaOp sqrt_2 = XlaHelpers::FloatLiteral(builder, dtype, std::sqrt(2.0)); - - xla::XlaOp z = XlaHelpers::FloatLiteral(builder, dtype, kZ); - xla::XlaOp alpha_normal_cdf = - XlaHelpers::FloatLiteral(builder, dtype, kAlphaNormalCdf); + xla::XlaOp one = xla::ScalarLike(uniform, 1.0); + xla::XlaOp two = xla::ScalarLike(uniform, 2.0); + xla::XlaOp sqrt_2 = xla::ScalarLike(uniform, std::sqrt(2.0)); + xla::XlaOp z = xla::ScalarLike(uniform, kZ); + xla::XlaOp alpha_normal_cdf = xla::ScalarLike(uniform, kAlphaNormalCdf); + auto p = alpha_normal_cdf + z * uniform; // probit(p) = sqrt(2) * erfinv(2*p-1) - auto p = xla::Add(alpha_normal_cdf, xla::Mul(z, uniform)); - auto erfinv_input = xla::Sub(xla::Mul(p, two), one); - return xla::Mul(sqrt_2, ErfInv(erfinv_input)); + return sqrt_2 * xla::ErfInv(two * p - one); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/random.h b/tensorflow/compiler/tf2xla/lib/random.h index 39cbcf9c5eccffa0035ff4c5e3d9afdb129f05cc..59fc5d0433a51328bc78006ab1c3495d908b44ac 100644 --- a/tensorflow/compiler/tf2xla/lib/random.h +++ b/tensorflow/compiler/tf2xla/lib/random.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/framework/types.pb.h" @@ -28,7 +28,7 @@ namespace tensorflow { // // The "uniform" parameter must be an array of random numbers distributed in // (0,1). -xla::XlaOp TruncatedNormal(DataType dtype, xla::XlaOp uniform); +xla::XlaOp TruncatedNormal(xla::XlaOp uniform); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 85e3d3ab85a89615cc5a01bdb4ec8f7fec30d58e..ba22eff73abab11abeb57283c63318b2e50a9ca1 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -21,8 +21,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -114,7 +114,7 @@ xla::StatusOr XlaScatter( auto buffer = loop_vars[2]; auto zero_index = xla::ConstantLiteral( - body_builder, xla::Literal::Zero(indices_shape.element_type())); + body_builder, xla::LiteralUtil::Zero(indices_shape.element_type())); // Slice the i-th index from the indices array. xla::XlaOp index; @@ -132,7 +132,7 @@ xla::StatusOr XlaScatter( // Discard updates with negative indices, since some users expect this. auto index_in_range = xla::ReduceAll( xla::Le(zero_index, index), xla::ConstantR0(body_builder, true), - xla::CreateScalarAndComputation(body_builder)); + xla::CreateScalarAndComputation(xla::PRED, body_builder)); // Make the index in bounds to prevent implementation defined behavior. index = xla::Max(index, zero_index); diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h index 87309e10ede320a81d173cd0a64492f88a2c7376..13a5f1b850a612bddeeac39bef431c19925351ca 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.h +++ b/tensorflow/compiler/tf2xla/lib/scatter.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index b9f695ac4b0d57f0fdaa5076a4a4bf5a5b989cb1..04fa10108cef66f429392951eea70e59643a2d29 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -20,631 +20,383 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/math/math_util.h" namespace tensorflow { -xla::StatusOr TriangularSolve(xla::XlaBuilder* builder, - const xla::XlaOp& a, xla::XlaOp b, - bool left_side, bool lower, - bool transpose_a, bool conjugate_a, - int64 block_size) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have different ranks: ", - xla::ShapeUtil::HumanString(a_shape), " vs. ", - xla::ShapeUtil::HumanString(b_shape)); - } - const int ndims = xla::ShapeUtil::Rank(a_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to TriangularSolve must have rank >= 2: ", ndims); - } - // The batch dimensions must be equal. - std::vector batch_dimensions; - for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape.dimensions(i); - int64 b_size = b_shape.dimensions(i); - if (a_size != b_size) { - return errors::InvalidArgument( - "Batch dimensions of arguments to TriangularSolve must be equal: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", - xla::ShapeUtil::HumanString(b_shape)); +// Get the diagonal blocks of the coefficient matrix +xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(a)); + int ndims = xla::ShapeUtil::Rank(shape); + int64 n = xla::ShapeUtil::GetDimension(shape, -1); + int64 num_blocks = n / block_size; + + xla::XlaOp diag_blocks; + + // If the coefficient matrix is exactly the block size, we just add a + // singleton dimension i.e. [..., n, n] -> [..., 1, n, n] + if (n == block_size) { + std::vector permutation(ndims); + std::iota(permutation.begin(), permutation.end(), 1); + permutation.insert(permutation.end() - 2, 0); + return Transpose(Broadcast(a, /*broadcast_sizes=*/{1}), permutation); } - batch_dimensions.push_back(a_size); - } - - if (xla::ShapeUtil::GetDimension(a_shape, -1) != - xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - "The 'a' arguments to TriangularSolve must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); - } - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have incompatible matrix shapes: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", - xla::ShapeUtil::HumanString(b_shape)); - } - - if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to TriangularSolve must be >= 1; got ", - block_size); - } - - std::map base_computations; - auto get_base_triangular_solve = - [&](int k) -> xla::StatusOr { - xla::XlaComputation& computation = base_computations[k]; - if (computation.IsNull()) { - std::unique_ptr sub = builder->CreateSubBuilder( - tensorflow::strings::StrCat("trsm_base_", k)); - - auto a_param = xla::Parameter( - sub.get(), 0, - xla::ShapeUtil::MakeShape( - b_shape.element_type(), - PrependMajorDims(sub.get(), batch_dimensions, {k, k})), - "a"); - - std::array b_lastd; - if (left_side) { - b_lastd = {k, n}; - } else { - b_lastd = {m, k}; - } - auto b_param = xla::Parameter( - sub.get(), 1, - xla::ShapeUtil::MakeShape( - b_shape.element_type(), - PrependMajorDims(sub.get(), batch_dimensions, b_lastd)), - "b"); - - // We use a left-looking or right-looking subroutine on the block diagonal - // in the lower=true cases, while falling back to a recursive call in - // others. The left-looking and right-looking subroutines are written with - // a While loop and so yields much faster compile times. Moreover, they - // can give higher performance on smaller (sub)problems. - if (left_side && lower) { - TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param, - b_param, transpose_a, - conjugate_a) - .status()); - } else if (!left_side && lower) { - TF_RETURN_IF_ERROR(TriangularSolveRightLooking(sub.get(), a_param, - b_param, transpose_a, - conjugate_a) - .status()); - } else { - TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param, - left_side, lower, transpose_a, - conjugate_a, - /*block_size=*/1) - .status()); - } - TF_ASSIGN_OR_RETURN(computation, sub->Build()); + // We can grab entire blocks using gather + if (n > block_size) { + // Construct the starting indices of the diagonal blocks + auto gather_indices = + Transpose(Broadcast(Mul(Iota(builder, xla::S32, num_blocks), + xla::ConstantR0(builder, block_size)), + /*broadcast_sizes=*/{2}), + /*permutation=*/{1, 0}); + + // Gather the diagonal blocks + xla::GatherDimensionNumbers dim_numbers; + dim_numbers.add_output_window_dims(ndims - 1); + dim_numbers.add_output_window_dims(ndims); + dim_numbers.add_gather_dims_to_operand_dims(ndims - 2); + dim_numbers.add_gather_dims_to_operand_dims(ndims - 1); + dim_numbers.set_index_vector_dim(1); + diag_blocks = Gather(a, gather_indices, dim_numbers, + /*window_bounds=*/{block_size, block_size}); } - return &computation; - }; - - xla::XlaOp output = Zeros(builder, b_shape); - - // Right-looking blocked triangular solve. - // For an explanation of the algorithm, see the TRSM discussion in: - // Goto, Kazushige, and Robert Van De Geijn. "High-performance implementation - // of the level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 - // (2008): 4. - - // In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if - // conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if - // conjugate_a is True. - - if (!left_side && lower == transpose_a) { - // for i in range(0, a.shape[-1], block_size): - for (int64 i = 0; i < n; i += block_size) { - int64 k = std::min(block_size, n - i); - - // output[..., :, i:i+k] = triangular_solve( - // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {0, i}, {m, i + k})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = xla::Call(builder, *solve, {a_slice, b_slice}); - } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = xla::Div(b_slice, a_slice_conj); - } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {0, i})); - - // if i + k < a.shape[-1]: - // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2) - if (i + k < n) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); - } else { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, n})); - } - TF_ASSIGN_OR_RETURN(auto b_update, - BatchDot(builder, update, a_slice_2, - /*transpose_x=*/false, - /*transpose_y=*/transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/conjugate_a)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {0, i + k}, {m, n})); - b_update = xla::Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k})); + // The last block might be smaller than the block size, + // so we will need to pad it + if (n % block_size != 0) { + // Pad with zeros + auto last_blocks = + SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n}); + xla::PaddingConfig config = xla::MakeNoPaddingConfig(ndims); + int64 padding = block_size - n % block_size; + config.mutable_dimensions(ndims - 1)->set_edge_padding_high(padding); + config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding); + last_blocks = + Pad(last_blocks, Zero(builder, shape.element_type()), config); + + // Add a singleton dimension + // i.e. [..., block_size, block_size] -> [..., 1, block_size, block_size] + TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, + builder->GetShape(last_blocks)); + auto shape_dims = xla::AsInt64Slice(blocks_shape.dimensions()); + auto last_blocks_dims = std::vector(ndims); + std::copy(shape_dims.begin(), shape_dims.end(), last_blocks_dims.begin()); + last_blocks_dims.insert(last_blocks_dims.end() - 2, 1); + last_blocks = Reshape(last_blocks, last_blocks_dims); + + // Concatenate with the other blocks if necessary + if (n > block_size) { + diag_blocks = + xla::ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2); + } else { + diag_blocks = last_blocks; } } - } else if (left_side && lower != transpose_a) { - // for i in range(0, a.shape[-1], block_size): - for (int64 i = 0; i < m; i += block_size) { - int64 k = std::min(block_size, m - i); - - // output[..., i:i+k, :] = triangular_solve( - // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = xla::Call(builder, *solve, {a_slice, b_slice}); - } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = xla::Div(b_slice, a_slice_conj); - } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); - - // if i + k < a.shape[-1]: - // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :]) - if (i + k < m) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k})); - } else { - TF_ASSIGN_OR_RETURN( - a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, m})); - } + return diag_blocks; + }); +} - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update, - /*transpose_x=*/transpose_a, - /*transpose_y=*/false, - /*conjugate_x=*/conjugate_a, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {i + k, 0}, {m, n})); - b_update = xla::Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {i + k, 0})); - } +xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, + bool transpose_a, bool conjugate_a) { + xla::XlaBuilder* builder = diag_blocks.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + // Input is a batch of square lower triangular square matrices. Its shape is + // (..., size, size). We resize this to (num_blocks, size, size). + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(diag_blocks)); + int64 block_size = xla::ShapeUtil::GetDimension(shape, -1); + int64 num_blocks = xla::ShapeUtil::ElementsIn(shape) / + tensorflow::MathUtil::IPow(block_size, 2); + diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size}); + + // The input must be triangular because we rely on that when doing + // multiplications later on + diag_blocks = Triangle(diag_blocks, /*lower=*/lower); + + // Rescale blocks to be unit triangular, but avoid dividing by + // zero (which can happen if the last block was padded) otherwise it will + // introduce nans which will propagate + auto diags = GetMatrixDiagonal(diag_blocks); + TF_ASSIGN_OR_RETURN(xla::Shape diags_shape, builder->GetShape(diags)); + auto one = ScalarLike(diags, 1); + auto ones = Broadcast(one, xla::AsInt64Slice(diags_shape.dimensions())); + diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags); + auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2}); + + // We can now use the fact that for an upper triangular matrix + // [[L11, 0], [L21, L22]], given the inverses L11' and L22', we have + // L22' = -L22' * L21 * L11'. In our case, L21 is a vector and our blocks + // have been rescaled to be unit triangular, so L22 = L22' = 1. + + // Initialize the output matrix with -1s on the diagonal. We use -1 instead + // of 1 because we cannot do matrix-vector multiplies with variable shapes + // inside of a loop, or do irregularly shaped in-place updates. Hence, + // L21 <- -L22 * L21 * L11 cannot be done naively. Instead, we update the + // entire row i.e. we calculate + // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I]) + // which means [L21 L22 0] <- [-L21 * L11', L22, 0]. + auto identity = + IdentityMatrix(builder, shape.element_type(), block_size, block_size); + auto neg_identity = -identity; + + // The first or last diagonal element should be set to 1 instead of -1 + // though, since we never update it + auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1}); + auto start_index = (lower) ? 0 : block_size - 1; + auto output_block = DynamicUpdateSlice( + neg_identity, pos_one, + /*start_indices=*/xla::ConstantR1(builder, 2, start_index)); + + // Broadcast diag([1, -1, -1, ...]) to every block + xla::XlaOp output = Broadcast(output_block, + /*broadcast_sizes=*/{num_blocks}); + + // Now we construct a loop that performs matrix-vector multiplications + // inverting the blocks one row at a time + std::vector tuple_shapes = { + // The loop iteration counter is a scalar, incremented each iteration. + xla::ShapeUtil::MakeShape(xla::S32, {}), + // The output has the shape of A, with one row updated each iteration. + xla::ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size}), + // The input is a loop invariant. + xla::ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size})}; + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); + + auto init_i = One(builder, xla::S32); + auto init = xla::Tuple(builder, {init_i, output, scaled_diag_blocks}); + + // Construct the loop condition function. + std::unique_ptr condb = + builder->CreateSubBuilder("InvertDiagCond"); + { + auto i = GetTupleElement( + Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0); + Lt(i, xla::ConstantR0(condb.get(), block_size)); } - } else if (!left_side && lower != transpose_a) { - // for i in reversed(range(0, a.shape[-1], block_size)): - const int64 last_blk_ix = xla::RoundUpToNearest(n, block_size) - block_size; - for (int64 i = last_blk_ix; i >= 0; i -= block_size) { - int64 k = std::min(block_size, n - i); - - // output[..., :, i:i+k] triangular_solve( - // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {0, i}, {m, i + k})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = xla::Call(builder, *solve, {a_slice, b_slice}); - } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = xla::Div(b_slice, a_slice_conj); - } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {0, i})); - - // if i - k >= 0: - // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2) - if (i - k >= 0) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); - } else { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {0, i}, {i, i + k})); - } + TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); + + // Construct the loop body function. + std::unique_ptr bodyb = + builder->CreateSubBuilder("InvertDiagBody"); + { + auto input_tuple = + Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple"); + + auto i = GetTupleElement(input_tuple, 0); + auto body_out = GetTupleElement(input_tuple, 1); + auto body_input = GetTupleElement(input_tuple, 2); + + auto zero = xla::ConstantR1(bodyb.get(), 1, 0); + auto j = (lower) ? i : ScalarLike(i, block_size - 1) - i; + auto start_indices = + xla::ConcatInDim(bodyb.get(), {zero, Reshape(j, {1}), zero}, 0); + auto input_row = + DynamicSlice(body_input, start_indices, + /*slice_sizes=*/{num_blocks, 1, block_size}); + + // We want -L21 L11^{-1} + xla::DotDimensionNumbers dnums; + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + dnums.add_lhs_contracting_dimensions(2); + dnums.add_rhs_contracting_dimensions(1); + auto update = -DotGeneral(input_row, body_out, dnums); + + body_out = DynamicUpdateSlice(body_out, update, start_indices); + + auto next_i = i + ScalarLike(i, 1); + xla::Tuple(bodyb.get(), {next_i, body_out, body_input}); + } + TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); + + // Construct the While loop and return the result, + // return while_loop(cond_fun, body_fun, init)[1] + auto invert_while = While(cond, body, init); + auto inv_diag_blocks = GetTupleElement(invert_while, 1); + + // Undo the scaling + inv_diag_blocks = Div(inv_diag_blocks, diags, + /*broadcast_dimensions=*/{0, 1}); + + // Reshape back to original batch major dimensions + return Reshape(inv_diag_blocks, xla::AsInt64Slice(shape.dimensions())); + }); +} - TF_ASSIGN_OR_RETURN(auto b_update, - BatchDot(builder, update, a_slice_2, - /*transpose_x=*/false, - /*transpose_y=*/transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/conjugate_a)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {0, 0}, {m, i})); - b_update = xla::Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0})); +xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b, + xla::XlaOp inv_diag_blocks, + bool left_side, bool lower, + bool transpose_a, bool conjugate_a) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, + builder->GetShape(inv_diag_blocks)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + int64 block_size = xla::ShapeUtil::GetDimension(blocks_shape, -1); + + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + int64 ndims = xla::ShapeUtil::Rank(a_shape); + int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + int64 num_blocks = n / block_size + (n % block_size != 0); + int64 m_dim = (left_side) ? -1 : -2; + int64 m = xla::ShapeUtil::GetDimension(b_shape, m_dim); + + // Initialize the solution + auto x = ZerosLike(b); + + // This loop is unrolled for performance reasons, but it could be expressed + // rolled as well since the matrices are of the same size each iteration + for (int i = 0; i < num_blocks; i++) { + // High-level intuition: We have B[i] = L[i] @ X. Since L is upper + // triangular this means B[i] = L[i, :i + 1] @ X[:i + 1]. We can split + // this into two parts: B[i] = L[i, :i] @ X[:i] + L[i, i] @ X[i] which + // can be solved for X[i] as X[i] = inv(L[i, i]) @ B[i] - L[i, :i] @ X[:i] + + // Decide whether we go from first block to last or vice versa + auto j = (left_side ^ lower ^ transpose_a) ? num_blocks - 1 - i : i; + + // Get the size of the inverse blocks (the last one might be smaller) + int64 block = (n % block_size != 0 && j + 1 == num_blocks) + ? n % block_size + : block_size; + auto inv_block = + MaybeConjugate(Collapse(SliceInMinorDims(inv_diag_blocks, {j, 0, 0}, + {j + 1, block, block}), + /*dimensions=*/{ndims - 2, ndims - 1}), + conjugate_a); + + // Get the corresponding row of B + int64 k = std::min((j + 1) * block_size, n); + std::vector start = {j * block_size, 0}; + std::vector end = {k, m}; + if (!left_side) { + std::swap(start[0], start[1]); + std::swap(end[0], end[1]); } - } - } else { // left_side && lower == transpose_a - // for i in reversed(range(0, a.shape[-1], block_size)): - const int64 last_blk_ix = xla::RoundUpToNearest(m, block_size) - block_size; - for (int64 i = last_blk_ix; i >= 0; i -= block_size) { - int64 k = std::min(block_size, m - i); - - // output[..., i:i+k, :] triangular_solve( - // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); - xla::XlaOp update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, - get_base_triangular_solve(k)); - update = xla::Call(builder, *solve, {a_slice, b_slice}); + auto b_row = SliceInMinorDims(b, start, end); + + xla::XlaOp remainder; + if (i == 0) { + remainder = b_row; } else { - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - update = xla::Div(b_slice, a_slice_conj); - } - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); - - // if i - k >= 0: - // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] - // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 - // b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :]) - if (i - k >= 0) { - xla::XlaOp a_slice_2; - if (lower) { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); + // This matrix multiply involves a lot of multiplying with zero (namely, + // X[i * block_size:] = 0), but this is faster than slicing... + end = {k, n}; + if (!left_side) { + std::swap(end[0], end[1]); + } + if (transpose_a) { + std::swap(start[0], start[1]); + std::swap(end[0], end[1]); + } + auto a_row = + MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a); + if (left_side) { + remainder = b_row - BatchDot(a_row, x, transpose_a, false); } else { - TF_ASSIGN_OR_RETURN(a_slice_2, - SliceInMinorDims(builder, a, {0, i}, {i, i + k})); + remainder = b_row - BatchDot(x, a_row, false, transpose_a); } + } - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update, - /*transpose_x=*/transpose_a, - /*transpose_y=*/false, - /*conjugate_x=*/conjugate_a, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {0, 0}, {i, n})); - b_update = xla::Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0})); + xla::XlaOp x_update; + auto zero = Zero(builder, xla::S32); + auto start_index = + xla::ConstantR0WithType(builder, xla::S32, j * block_size); + std::vector update_starts = {start_index, zero}; + if (left_side) { + x_update = BatchDot(inv_block, remainder, transpose_a, false); + } else { + x_update = BatchDot(remainder, inv_block, false, transpose_a); + std::swap(update_starts[0], update_starts[1]); } + x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts); } - } - return output; + return x; + }); } -xla::StatusOr TriangularSolveLeftLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - const int64 ndims = xla::ShapeUtil::Rank(a_shape); - - std::vector batch_dimensions; - for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape.dimensions(i); - batch_dimensions.push_back(a_size); - } - - // The main computation is performed in a While loop. - - // Allocate the output and set its first or last row, - // output = np.zeros_like(b) - // if transpose_a: - // output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:] - // else: - // output[..., :1, :] = b[..., :1, :] / a[..., :1, :1] - xla::XlaOp output = Zeros(builder, b_shape); - { - auto i = transpose_a ? m - 1 : 0; - TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1})); - TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {i, 0}, {i + 1, n})); - TF_ASSIGN_OR_RETURN(auto a_slice_conj, - MaybeConjugate(builder, a_slice, conjugate_a)); - auto update = xla::Div(b_slice, a_slice_conj); - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); - } - - // Construct the initial loop carry tuple, - // if transpose_a: - // init = (m-2, output, a, b) - // else: - // init = (1, output, a, b) - std::vector tuple_shapes = { - // The loop iteration counter is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(xla::S32, {}), - // The output has the shape of b, with one row updated each iteration. - b_shape, - // The coefficient matrix a is a loop invariant. - a_shape, - // The right-hand-side matrix b is a loop invariant. - b_shape}; - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); - auto init_i = xla::ConstantR0(builder, transpose_a ? m - 2 : 1); - auto init = xla::Tuple(builder, {init_i, output, a, b}); - - // Construct the loop condition function, - // def cond_fun(loop_carry): - // i, output, a, b = loop_carry - // return i >= 0 if transpose_a else i < m - std::unique_ptr condb = - builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond"); - { - auto i = xla::GetTupleElement( - xla::Parameter(condb.get(), 0, tuple_shape, - "TriangularSolveLeftLookingWhileTuple"), - 0); - if (transpose_a) { - xla::Ge(i, xla::ConstantR0(condb.get(), 0)); - } else { - xla::Lt(i, xla::ConstantR0(condb.get(), m)); +xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, + bool lower, bool transpose_a, bool conjugate_a, + int64 block_size) { + xla::XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) { + return errors::InvalidArgument( + "Arguments to TriangularSolve have different ranks: ", + xla::ShapeUtil::HumanString(a_shape), " vs. ", + xla::ShapeUtil::HumanString(b_shape)); } - } - TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); - - // Construct the loop body function, - // def body_fun(loop_carry): - // i, output, a, b = loop_carry - // if transpose_a: - // a_row = np.swapaxes(a[..., i+1:, i:i+1], -1 -2) - // else: - // a_row = a[..., i:i+1, :i] - // result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :]) - // output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] - // if transpose_a: - // return (i - 1, output, a, b) - // else: - // return (i + 1, output, a, b) - // We have to do some extra FLOPs propagating zeros in the matrix multiply - // because we can't have the size of its arguments depend on the loop counter. - std::unique_ptr bodyb = - builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody"); - { - auto input_tuple = xla::Parameter(bodyb.get(), 0, tuple_shape, - "TriangularSolveLeftLookingWhileTuple"); - - // i, output, a, b = loop_carry - auto i = xla::GetTupleElement(input_tuple, 0); - auto body_out = xla::GetTupleElement(input_tuple, 1); - auto body_a = xla::GetTupleElement(input_tuple, 2); - auto body_b = xla::GetTupleElement(input_tuple, 3); - auto zero = xla::ConstantR0(bodyb.get(), 0); - - // We'd like to implement this: - // if transpose_a: - // a_row = T(a[..., i+1:, i:i+1]) - // result_row = (b[..., i:i+1, :] - // - np.matmul(a_row, body_out[..., i+1:, :])) - // else: - // result_row = (b[..., i:i+1, :] - // - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :])) - // But since we can't have intermediate array sizes depend on the loop - // counter, we instead exploit the fact that we initialized the output to - // all zeros and use that as zero-padding (doing unnecessary FLOPs). - xla::XlaOp a_row; - if (transpose_a) { - TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a, - {zero, i}, {m, 1})); - } else { - TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a, - {i, zero}, {1, m})); + const int64 ndims = xla::ShapeUtil::Rank(a_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to TriangularSolve must have rank >= 2: ", ndims); + } + // The batch dimensions must be equal. + std::vector batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape.dimensions(i); + int64 b_size = b_shape.dimensions(i); + if (a_size != b_size) { + return errors::InvalidArgument( + "Batch dimensions of arguments to TriangularSolve must be equal: ", + xla::ShapeUtil::HumanString(a_shape), " vs ", + xla::ShapeUtil::HumanString(b_shape)); + } + batch_dimensions.push_back(a_size); } - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out, - /*transpose_x=*/transpose_a, - /*transpose_y=*/false, - /*conjugate_x=*/conjugate_a, - /*conjugate_y=*/false)); - TF_ASSIGN_OR_RETURN( - auto result_row_slice, - DynamicSliceInMinorDims(bodyb.get(), body_b, {i, zero}, {1, n})); - auto result_row = xla::Sub(result_row_slice, b_update); - - // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] - TF_ASSIGN_OR_RETURN(auto a_elt, DynamicSliceInMinorDims(bodyb.get(), body_a, - {i, i}, {1, 1})); - TF_ASSIGN_OR_RETURN(auto a_elt_conj, - MaybeConjugate(bodyb.get(), a_elt, conjugate_a)); - auto div_result = xla::Div(result_row, a_elt_conj); - TF_ASSIGN_OR_RETURN(body_out, - DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, - div_result, {i, zero})); - - // if transpose_a: - // return (i - 1, body_out, a, b) - // else: - // return (i + 1, body_out, a, b) - auto next_i = - xla::Add(i, xla::ConstantR0(bodyb.get(), transpose_a ? -1 : 1)); - xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b}); - } - TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); - - // Construct the While loop and return the result, - // return while_loop(cond_fun, body_fun, init)[1] - auto triangular_solve_left_looking_while = xla::While(cond, body, init); - return xla::GetTupleElement(triangular_solve_left_looking_while, 1); -} -xla::StatusOr TriangularSolveRightLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a) { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - const int64 ndims = xla::ShapeUtil::Rank(a_shape); - - std::vector batch_dimensions; - for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape.dimensions(i); - batch_dimensions.push_back(a_size); - } - - // The main computation is performed in a While loop. - xla::XlaOp output = Zeros(builder, b_shape); - - // Construct the initial loop carry tuple, - // if transpose_a: - // init = (0, output, a, b) - // else: - // init = (n-1, output, a, b) - std::vector tuple_shapes = { - // The loop iteration counter is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(xla::S32, {}), - // The output has the shape of b, with one row updated each iteration. - b_shape, - // The coefficient matrix a is a loop invariant. - a_shape, - // The right-hand-side matrix b is a loop invariant. - b_shape}; - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); - auto init_i = xla::ConstantR0(builder, transpose_a ? 0 : n - 1); - auto init = xla::Tuple(builder, {init_i, output, a, b}); - - // Construct the loop condition function, - // def cond_fun(loop_carry): - // i, output, a, b = loop_carry - // return i < n if transpose_a else i >= 0 - std::unique_ptr condb = - builder->CreateSubBuilder("TriangularSolveRightLookingWhileCond"); - { - auto i = xla::GetTupleElement( - xla::Parameter(condb.get(), 0, tuple_shape, - "TriangularSolveRightLookingWhileTuple"), - 0); - if (transpose_a) { - xla::Lt(i, xla::ConstantR0(condb.get(), n)); - } else { - xla::Ge(i, xla::ConstantR0(condb.get(), 0)); + if (xla::ShapeUtil::GetDimension(a_shape, -1) != + xla::ShapeUtil::GetDimension(a_shape, -2)) { + return errors::InvalidArgument( + "The 'a' arguments to TriangularSolve must be square matrices: ", + xla::ShapeUtil::HumanString(a_shape)); } - } - TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); - - // Construct the loop body function, - // def body_fun(loop_carry): - // i, output, a, b = loop_carry - // if transpose_a: - // a_row = np.swapaxes(a[..., :, i:i+1], -1 -2) - // else: - // a_row = a[..., :, i:i+1] - // result_row = b[..., :, i:i+1] - np.matmul(output, a_row) - // output[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] - // if transpose_a: - // return (i - 1, output, a, b) - // else: - // return (i + 1, output, a, b) - // We have to do some extra FLOPs propagating zeros in the matrix multiply - // because we can't have the size of its arguments depend on the loop counter. - std::unique_ptr bodyb = - builder->CreateSubBuilder("TriangularSolveRightLookingWhileBody"); - { - auto input_tuple = xla::Parameter(bodyb.get(), 0, tuple_shape, - "TriangularSolveRightLookingWhileTuple"); - - // i, output, a, b = loop_carry - auto i = xla::GetTupleElement(input_tuple, 0); - auto body_out = xla::GetTupleElement(input_tuple, 1); - auto body_a = xla::GetTupleElement(input_tuple, 2); - auto body_b = xla::GetTupleElement(input_tuple, 3); - auto zero = xla::ConstantR0(bodyb.get(), 0); - - // We'd like to implement b[..., :, i:i+1] - np.matmul(output, a[..., :, - // i:i+1]) But since we can't have intermediate array sizes depend on the - // loop counter, we instead exploit the fact that we initialized the output - // to all zeros and use that as zero-padding (doing unnecessary FLOPs). - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), body_out, body_a, - /*transpose_x=*/false, - /*transpose_y=*/transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/conjugate_a)); - // result = b - np.matmul(output, a) - auto result = xla::Sub(body_b, b_update); - // result_row = result[..., :, i:i+1] - TF_ASSIGN_OR_RETURN( - auto result_row, - DynamicSliceInMinorDims(bodyb.get(), result, {zero, i}, {m, 1})); - - // body_out[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] - TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(bodyb.get(), body_a, - {i, i}, {1, 1})); - TF_ASSIGN_OR_RETURN(auto a_ii_conj, - MaybeConjugate(bodyb.get(), a_ii, conjugate_a)); - auto div_result = xla::Div(result_row, a_ii_conj); - TF_ASSIGN_OR_RETURN(body_out, - DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, - div_result, {zero, i})); - - // if transpose_a: - // return (i + 1, body_out, a, b) - // else: - // return (i - 1, body_out, a, b) - auto next_i = - xla::Add(i, xla::ConstantR0(bodyb.get(), transpose_a ? 1 : -1)); - xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b}); - } - TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); - - // Construct the While loop and return the result, - // return while_loop(cond_fun, body_fun, init)[1] - auto triangular_solve_left_looking_while = xla::While(cond, body, init); - return xla::GetTupleElement(triangular_solve_left_looking_while, 1); + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) { + return errors::InvalidArgument( + "Arguments to TriangularSolve have incompatible matrix shapes: ", + xla::ShapeUtil::HumanString(a_shape), " vs ", + xla::ShapeUtil::HumanString(b_shape)); + } + + if (block_size < 1) { + return errors::InvalidArgument( + "block_size argument to TriangularSolve must be >= 1; got ", + block_size); + } + + // We find the diagonal blocks of the coefficient matrix + auto diag_blocks = DiagonalBlocks(a, block_size); + + // We invert these blocks in parallel using batched matrix-vector products + auto inv_diag_blocks = + InvertDiagonalBlocks(diag_blocks, lower, transpose_a, conjugate_a); + + // We now find the solution using GEMMs + auto x = SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, + lower, transpose_a, conjugate_a); + + return x; + }); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index 540c26b2473df9e7885f4e549b3e516a3d8a0d43..555760b7efabddfb25c9135b109a1c48b487415e 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" namespace tensorflow { @@ -57,23 +57,9 @@ namespace tensorflow { // // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no // blocking is used. -xla::StatusOr TriangularSolve(xla::XlaBuilder* builder, - const xla::XlaOp& a, xla::XlaOp b, - bool left_side, bool lower, - bool transpose_a, bool conjugate_a, - int64 block_size = 256); - -xla::StatusOr TriangularSolveLeftLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a); - -xla::StatusOr TriangularSolveRightLooking(xla::XlaBuilder* builder, - const xla::XlaOp& a, - const xla::XlaOp& b, - bool transpose_a, - bool conjugate_a); +xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, + bool lower, bool transpose_a, bool conjugate_a, + int64 block_size = 128); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc index 87ea4763f7c2357ae179b68ade3715b24c46432f..aeebf16028d40189203cdfd815f06a339ee72902 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -85,11 +85,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, @@ -107,11 +106,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, @@ -129,11 +127,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, @@ -151,11 +148,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/false, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/false, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {0.5, 0.08333334, 0.04629629, 0.03367003}, @@ -173,11 +169,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {-0.89646465, -0.69444444, -0.49242424}, @@ -196,11 +191,32 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/true, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); + + xla::Array2D expected({ + {0.5, 1.0, 1.5}, + {0.41666667, 0.33333333, 0.25}, + {0.23148148, 0.18518519, 0.13888889}, + {0.16835017, 0.13468013, 0.1010101}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/3); xla::Array2D expected({ {0.5, 1.0, 1.5}, @@ -219,11 +235,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {0.5, 1.0, 1.5}, @@ -242,11 +257,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/false, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {-0.89646465, -0.69444444, -0.49242424}, @@ -267,11 +281,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { CreateR2Parameter(AValsLowerComplex(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRightComplex(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/false, /*lower=*/true, - /*transpose_a=*/true, /*conjugate_a=*/true, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/true, + /*block_size=*/2); xla::Array2D expected({ {0.5, complex64(0.08333333, 0.08333333), @@ -295,11 +308,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { CreateR2Parameter(AValsUpperComplex(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeftComplex(), 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, - /*left_side=*/true, /*lower=*/false, - /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/2); - TF_ASSERT_OK(result.status()); + TriangularSolve(a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); xla::Array2D expected({ {0.5, 1., 1.5}, @@ -317,49 +329,5 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { xla::ErrorSpec(1e-2, 1e-2)); } -XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) { - xla::XlaBuilder builder(TestName()); - - xla::XlaOp a, b; - auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); - auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolveLeftLooking(&builder, a, b, - /*transpose_a=*/false, - /*conjugate_a=*/false); - TF_ASSERT_OK(result.status()); - - xla::Array2D expected({ - {0.5, 1.0, 1.5}, - {0.41666667, 0.33333333, 0.25}, - {0.23148148, 0.18518519, 0.13888889}, - {0.16835017, 0.13468013, 0.1010101}, - }); - - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); -} - -XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) { - xla::XlaBuilder builder(TestName()); - - xla::XlaOp a, b; - auto a_data = CreateR2Parameter(AValsFull(), 0, "a", &builder, &a); - auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); - auto result = TriangularSolveLeftLooking(&builder, a, b, - /*transpose_a=*/false, - /*conjugate_a=*/false); - TF_ASSERT_OK(result.status()); - - xla::Array2D expected({ - {0.5, 1.0, 1.5}, - {0.41666667, 0.33333333, 0.25}, - {0.23148148, 0.18518519, 0.13888889}, - {0.16835017, 0.13468013, 0.1010101}, - }); - - ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); -} - } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 11774dde08876158583cfddb87224a5ebf94c01e..8b5beba383cda45d36e2ee27ca5e3b3c5988b6b7 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -18,7 +18,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -30,7 +31,8 @@ namespace tensorflow { xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) { return xla::Broadcast( - xla::ConstantLiteral(builder, xla::Literal::Zero(shape.element_type())), + xla::ConstantLiteral(builder, + xla::LiteralUtil::Zero(shape.element_type())), xla::AsInt64Slice(shape.dimensions())); } @@ -62,31 +64,31 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, xla::Literal literal; switch (type) { case xla::U8: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::U32: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::U64: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::S8: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::S32: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::S64: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::F32: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::F64: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::C64: - literal = std::move(*xla::Literal::CreateR0(value)); + literal = std::move(*xla::LiteralUtil::CreateR0(value)); break; case xla::PRED: LOG(FATAL) << "pred element type is not integral"; @@ -95,11 +97,11 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, LOG(FATAL) << "u16/s16 literals not yet implemented"; case xla::BF16: literal = std::move( - *xla::Literal::CreateR0(static_cast(value))); + *xla::LiteralUtil::CreateR0(static_cast(value))); break; case xla::F16: - literal = std::move( - *xla::Literal::CreateR0(static_cast(value))); + literal = std::move(*xla::LiteralUtil::CreateR0( + static_cast(value))); break; case xla::TUPLE: LOG(FATAL) << "tuple element type is not integral"; @@ -111,130 +113,137 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, return xla::ConstantLiteral(builder, literal); } -xla::StatusOr SliceInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x, - gtl::ArraySlice start, - gtl::ArraySlice end) { - TF_RET_CHECK(start.size() == end.size()); - int64 n_minor_dims = start.size(); - - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_minor_dims <= n_dims); - gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - n_minor_dims); - - // Prepends 0s in the major dim - std::vector padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - padded_start.begin() + major_dims.size()); - - // Prepends the shape of the major dims. - std::vector padded_end(n_dims); - std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); - std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); - - std::vector strides(n_dims, 1); - return xla::Slice(x, padded_start, padded_end, strides); +xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, + gtl::ArraySlice end) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_RET_CHECK(start.size() == end.size()); + int64 n_minor_dims = start.size(); + + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + + const int64 n_dims = xla::ShapeUtil::Rank(shape); + TF_RET_CHECK(n_minor_dims <= n_dims); + gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), + /*pos=*/0, + /*len=*/n_dims - n_minor_dims); + + // Prepends 0s in the major dim + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + major_dims.size()); + + // Prepends the shape of the major dims. + std::vector padded_end(n_dims); + std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); + std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); + + std::vector strides(n_dims, 1); + return xla::Slice(x, padded_start, padded_end, strides); + }); } -std::vector PrependMajorDims(xla::XlaBuilder* builder, - const gtl::ArraySlice& major_dims, - const gtl::ArraySlice& indices) { - std::vector output(indices.size() + major_dims.size()); - std::copy(major_dims.begin(), major_dims.end(), output.begin()); - std::copy(indices.begin(), indices.end(), output.begin() + major_dims.size()); +std::vector ConcatVectors(gtl::ArraySlice xs, + gtl::ArraySlice ys) { + std::vector output(xs.size() + ys.size()); + std::copy(xs.begin(), xs.end(), output.begin()); + std::copy(ys.begin(), ys.end(), output.begin() + xs.size()); return output; } -xla::StatusOr DynamicSliceInMinorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, - const std::vector& starts, - const gtl::ArraySlice& sizes) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - int64 n_minor_dims = starts.size(); - TF_RET_CHECK(n_minor_dims == sizes.size()); - TF_RET_CHECK(n_minor_dims <= n_dims); - gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - sizes.size()); - TF_ASSIGN_OR_RETURN(auto padded_starts, - PrependZerosInMajorDims(builder, x, starts)); - auto padded_sizes = PrependMajorDims(builder, major_dims, sizes); - return xla::DynamicSlice(x, padded_starts, padded_sizes); +xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, + gtl::ArraySlice starts, + gtl::ArraySlice sizes) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + int64 n_minor_dims = starts.size(); + TF_RET_CHECK(n_minor_dims == sizes.size()); + TF_RET_CHECK(n_minor_dims <= n_dims); + gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), + /*pos=*/0, + /*len=*/n_dims - sizes.size()); + auto padded_starts = PrependZerosInMajorDims(x, starts); + auto padded_sizes = ConcatVectors(major_dims, sizes); + return xla::DynamicSlice(x, padded_starts, padded_sizes); + }); } -xla::StatusOr UpdateSlice(xla::XlaBuilder* builder, - const xla::XlaOp& x, - const xla::XlaOp& update, - gtl::ArraySlice start) { - // TODO(phawkins): make int64 work on all backends, remove the int32 cast. - std::vector start_as_int32(start.begin(), start.end()); - auto start_constant = xla::ConstantR1(builder, start_as_int32); - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape, - builder->GetShape(start_constant)); - const int64 start_length = - xla::ShapeUtil::GetDimension(start_constant_shape, -1); - TF_RET_CHECK(start_length == n_dims); - return xla::DynamicUpdateSlice(x, update, start_constant); +xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice start) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + // TODO(phawkins): make int64 work on all backends, remove the int32 cast. + std::vector start_as_int32(start.begin(), start.end()); + auto start_constant = xla::ConstantR1(builder, start_as_int32); + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape, + builder->GetShape(start_constant)); + const int64 start_length = + xla::ShapeUtil::GetDimension(start_constant_shape, -1); + TF_RET_CHECK(start_length == n_dims); + return xla::DynamicUpdateSlice(x, update, start_constant); + }); } -xla::StatusOr UpdateSliceInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x, - const xla::XlaOp& update, - gtl::ArraySlice start) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - const int64 n_minor_dims = start.size(); - TF_RET_CHECK(n_minor_dims <= n_dims); - std::vector padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - padded_start.begin() + (n_dims - n_minor_dims)); - return UpdateSlice(builder, x, update, padded_start); +xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice start) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + const int64 n_minor_dims = start.size(); + TF_RET_CHECK(n_minor_dims <= n_dims); + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + (n_dims - n_minor_dims)); + return UpdateSlice(x, update, padded_start); + }); } -xla::StatusOr DynamicUpdateSliceInMinorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update, - const std::vector& starts) { - TF_ASSIGN_OR_RETURN(auto padded_starts, - PrependZerosInMajorDims(builder, x, starts)); +xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice starts) { + auto padded_starts = PrependZerosInMajorDims(x, starts); return xla::DynamicUpdateSlice(x, update, padded_starts); } -xla::StatusOr PrependZerosInMajorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, - const std::vector& starts) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - auto zero = xla::Reshape(xla::ConstantR0(builder, 0), {1}); - std::vector padded_starts(n_dims, zero); - for (int i = 0; i < starts.size(); ++i) { - padded_starts[n_dims - starts.size() + i] = xla::Reshape(starts[i], {1}); - } - return xla::ConcatInDim(builder, padded_starts, 0); +xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, + gtl::ArraySlice starts) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + auto zero = xla::Reshape(xla::ConstantR0(builder, 0), {1}); + std::vector padded_starts(n_dims, zero); + for (int i = 0; i < starts.size(); ++i) { + padded_starts[n_dims - starts.size() + i] = xla::Reshape(starts[i], {1}); + } + return xla::ConcatInDim(builder, padded_starts, 0); + }); } -xla::StatusOr TransposeInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_dims >= 2); - std::vector permutation(n_dims); - std::iota(permutation.begin(), permutation.end(), 0); - std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); - return xla::Transpose(x, permutation); +xla::XlaOp TransposeInMinorDims(xla::XlaOp x) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + std::vector permutation(n_dims); + std::iota(permutation.begin(), permutation.end(), 0); + std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); + return xla::Transpose(x, permutation); + }); } -xla::StatusOr MaybeConjugate(xla::XlaBuilder* builder, - const xla::XlaOp& x, bool conjugate) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - auto perform_conj = shape.element_type() == xla::C64 && conjugate; - return perform_conj ? xla::Conj(x) : x; +xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate) { + xla::XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + auto perform_conj = shape.element_type() == xla::C64 && conjugate; + return perform_conj ? xla::Conj(x) : x; + }); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 3c120a2548576d6ad46870583ca65beea63507a3..b4905c952820a45371e090aa98466654e2db9661 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -16,16 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { -// Returns a zero-filled tensor with shape `shape`. -xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape); - // Returns a floating point scalar constant of 'type' with 'value'. // If 'type' is complex, returns a real value with zero imaginary component. xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, @@ -33,7 +30,7 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, // Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros // prepended until the array is length n_dims. -xla::XlaOp PrependZerosInMajorDims(xla::XlaBuilder* builder, +xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, gtl::ArraySlice starts); // Returns a integer scalar constant of 'type' with 'value'. @@ -41,54 +38,43 @@ xla::XlaOp PrependZerosInMajorDims(xla::XlaBuilder* builder, xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, int64 value); -// Builds a vector of zeros of length rank(x) with the last two values being +// Builds a vector of zeros of length rank(x) with the last values being // those in `starts`. -xla::StatusOr PrependZerosInMajorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, - const std::vector& starts); +xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, + gtl::ArraySlice starts); // Performs a slice in the minor dimensions of a Tensor. -xla::StatusOr SliceInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x, - gtl::ArraySlice start, - gtl::ArraySlice end); +xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, + gtl::ArraySlice end); -// Builds a 1-d vector out of a concatenation of `major_dims` and `starts`. -std::vector PrependMajorDims(xla::XlaBuilder* builder, - const gtl::ArraySlice& major_dims, - const gtl::ArraySlice& indices); +// Returns the concatenation of `xs` and `ys`. +std::vector ConcatVectors(gtl::ArraySlice xs, + gtl::ArraySlice ys); // Performs a dynamic slice in the minor dimensions of a Tensor. -xla::StatusOr DynamicSliceInMinorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, - const std::vector& starts, const gtl::ArraySlice& sizes); +xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, + gtl::ArraySlice starts, + gtl::ArraySlice sizes); // Updates a slice of 'x', i.e., // x[start[0], ..., start[n]] = update -xla::StatusOr UpdateSlice(xla::XlaBuilder* builder, - const xla::XlaOp& x, - const xla::XlaOp& update, - gtl::ArraySlice start); +xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice start); // Updates a slice of 'x', where 'start' contains a list of minor dimensions: // x[..., start[0], ..., start[n]] = update -xla::StatusOr UpdateSliceInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x, - const xla::XlaOp& update, - gtl::ArraySlice start); +xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice start); -xla::StatusOr DynamicUpdateSliceInMinorDims( - xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update, - const std::vector& starts); +xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, + gtl::ArraySlice starts); // Transposes a stack of matrices `x` by swapping the last two dimensions. -xla::StatusOr TransposeInMinorDims(xla::XlaBuilder* builder, - const xla::XlaOp& x); +xla::XlaOp TransposeInMinorDims(xla::XlaOp x); // Applies a complex conjugation operation if `a` is complex and `conjugate_a` // is true, otherwise returns its argument. -xla::StatusOr MaybeConjugate(xla::XlaBuilder* builder, - const xla::XlaOp& x, bool conjugate); +xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc index 2a332c933f49db92564650828d9505095cf8436d..442fe92c34ca26cb1a854cc90da8dc034bca79bb 100644 --- a/tensorflow/compiler/tf2xla/lib/util_test.cc +++ b/tensorflow/compiler/tf2xla/lib/util_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -70,8 +70,7 @@ XLA_TEST_F(UtilTest, Simple2dLookup) { auto a_data = CreateR2Parameter(BValsRight(), 0, "a", &builder, &a); auto x_data = CreateR0Parameter(2, 1, "x", &builder, &x); auto y_data = CreateR0Parameter(1, 2, "y", &builder, &y); - auto result = DynamicSliceInMinorDims(&builder, a, {x, y}, {1, 1}); - TF_ASSERT_OK(result.status()); + DynamicSliceInMinorDims(a, {x, y}, {1, 1}); ComputeAndCompareR2(&builder, {{10}}, {a_data.get(), x_data.get(), y_data.get()}, @@ -86,10 +85,8 @@ XLA_TEST_F(UtilTest, Simple3dLookup) { CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); auto index_data = CreateR0Parameter(1, 1, "index", &builder, &index); - TF_ASSERT_OK( - DynamicSliceInMinorDims( - &builder, a, {index, xla::ConstantR0(&builder, 0)}, {1, 4}) - .status()); + DynamicSliceInMinorDims(a, {index, xla::ConstantR0(&builder, 0)}, + {1, 4}); ComputeAndCompareR3(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}}, {a_data.get(), index_data.get()}); @@ -104,8 +101,7 @@ XLA_TEST_F(UtilTest, SimpleSliceUpdate) { auto x_data = CreateR0Parameter(2, 2, "x", &builder, &x); auto y_data = CreateR0Parameter(1, 3, "y", &builder, &y); - auto result = DynamicUpdateSliceInMinorDims(&builder, a, b, {x, y}); - TF_ASSERT_OK(result.status()); + DynamicUpdateSliceInMinorDims(a, b, {x, y}); xla::Array2D expected( {{{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 9, 1, -10}, {5, 8, 10, 11}}}); @@ -128,13 +124,9 @@ XLA_TEST_F(UtilTest, RowBatchDot) { // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); - TF_ASSERT_OK_AND_ASSIGN( - auto l_index, - DynamicSliceInMinorDims( - &builder, a, {index, xla::ConstantR0(&builder, 0)}, {1, n})); - TF_ASSERT_OK(BatchDot(&builder, l_index, row, - /*transpose_x=*/false, /*transpose_y=*/true) - .status()); + auto l_index = DynamicSliceInMinorDims( + a, {index, xla::ConstantR0(&builder, 0)}, {1, n}); + BatchDot(l_index, row, /*transpose_x=*/false, /*transpose_y=*/true); ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, {a_data.get(), row_data.get(), index_data.get()}); diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc index 7cc88f34d291f25814fba9f802c93117973120e7..d64394f1401d7ceea004a59c991ef6f4a1c58b41 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -100,8 +100,9 @@ xla::StatusOr> XlaForEachIndex( std::vector updated_values; updated_values.reserve(values.size()); updated_values.push_back(xla::Add( - iteration, xla::ConstantLiteral( - body_builder, xla::Literal::One(num_iterations_type)))); + iteration, + xla::ConstantLiteral(body_builder, + xla::LiteralUtil::One(num_iterations_type)))); values.remove_prefix(1); TF_ASSIGN_OR_RETURN(std::vector body_outputs, @@ -113,8 +114,8 @@ xla::StatusOr> XlaForEachIndex( std::vector values; values.reserve(initial_values.size() + 1); - values.push_back( - xla::ConstantLiteral(builder, xla::Literal::Zero(num_iterations_type))); + values.push_back(xla::ConstantLiteral( + builder, xla::LiteralUtil::Zero(num_iterations_type))); values.insert(values.end(), initial_values.begin(), initial_values.end()); TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values, diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h index 5b6684c995889efbb1378c7ac4903548891d090a..9493b1f109be0725f7f733b9f9da664264275a69 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.h +++ b/tensorflow/compiler/tf2xla/lib/while_loop.h @@ -19,8 +19,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index b43405a1a407b5fa98dd740c62af91e048cc9490..77da1bf29ced60e490f07abad41cf8ce96232982 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/common_runtime/dma_helper.h" namespace tensorflow { @@ -32,6 +32,23 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, return Status::OK(); } +Status HostTensorToMutableBorrowingLiteral( + Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor->dtype(), + host_tensor->shape(), &xla_shape)); + return HostTensorToMutableBorrowingLiteral(xla_shape, host_tensor, literal); +} + +Status HostTensorToMutableBorrowingLiteral( + const xla::Shape& xla_shape, Tensor* host_tensor, + xla::MutableBorrowingLiteral* literal) { + *literal = xla::MutableBorrowingLiteral( + static_cast(DMAHelper::base(host_tensor)), xla_shape); + + return Status::OK(); +} + Status HostTensorsToBorrowingLiteralTuple( tensorflow::gtl::ArraySlice host_tensors, xla::BorrowingLiteral* literal) { diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index ab7e861f3336097d2ea52487092f16edb5c14531..09d6fa811669b422532673540e4da47f47e6be4e 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -18,7 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" @@ -30,6 +30,16 @@ namespace tensorflow { // 'host_tensor'. Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, xla::BorrowingLiteral* literal); +// Returns a MutableBorrowingLiteral that utilizes the same underlying buffer +// owned by 'host_tensor', but is mutable via the xla::Literal methods. +Status HostTensorToMutableBorrowingLiteral( + Tensor* host_tensor, xla::MutableBorrowingLiteral* literal); +// Similar as above, except the literal shape is explicitly provided and used +// instead of obtaining it from the 'host_tensor'. The provided literal shape +// 'xla_shape' must be compatible with the shape of 'host_tensor'. +Status HostTensorToMutableBorrowingLiteral( + const xla::Shape& xla_shape, Tensor* host_tensor, + xla::MutableBorrowingLiteral* literal); // Returns a BorrowingLiteral tuple that utilizes the same underlying buffers // owned by 'host_tensors'. diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index f3d6787daaa1165b28ce63dfd501533fa0963edd..a3404c2b3df7bf25011359d1f5f5b88c29a3f83b 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/tensor_testutil.h" @@ -27,7 +28,7 @@ TEST(LiteralUtil, LiteralToHostTensor) { { std::vector int64_values = {1, 2, 3}; std::unique_ptr int64_values_literal = - xla::Literal::CreateR1(gtl::ArraySlice(int64_values)); + xla::LiteralUtil::CreateR1(gtl::ArraySlice(int64_values)); Tensor host_tensor; EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor) @@ -48,7 +49,7 @@ TEST(LiteralUtil, LiteralToHostTensor) { Tensor host_tensor; std::vector int32_values = {10, 11}; std::unique_ptr int32_values_literal = - xla::Literal::CreateR1(gtl::ArraySlice(int32_values)); + xla::LiteralUtil::CreateR1(gtl::ArraySlice(int32_values)); EXPECT_TRUE( LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor) .ok()); diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index ac768b206e2a8d163a4253432a1911152f89ce86..48568c825b7a0f13011d3d6e8e62ec5db026760f 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" diff --git a/tensorflow/compiler/tf2xla/tf2xla.h b/tensorflow/compiler/tf2xla/tf2xla.h index d02fc56c5b8f58f0e4cfe1779ad34fe3b79324c7..432a12a51622b56ae74a677420da321c58960ee6 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.h +++ b/tensorflow/compiler/tf2xla/tf2xla.h @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/framework/graph.pb.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index 84c133ffabe20dbdaa4d5a64e035efb5e4c4c44b..56f7045a98201ed398244f9e3f5ff23788135b75 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -73,8 +75,8 @@ TEST(ConvertGraphDefToXla, Sum) { TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); // Set up arguments. - auto x_literal = xla::Literal::CreateR0(10); - auto y_literal = xla::Literal::CreateR0(32); + auto x_literal = xla::LiteralUtil::CreateR0(10); + auto y_literal = xla::LiteralUtil::CreateR0(32); auto x_global_or = client->TransferToServer(*x_literal); auto y_global_or = client->TransferToServer(*y_literal); TF_EXPECT_OK(x_global_or.status()); diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index fe7ec633eca2504faf6cbb2f5fd7f59780ab7976..e89f4733281194f0263ae8cc4907caa0ad781165 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/platform/mem.h" diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index d0b9e34e162f3412cd6662a2e2bbfe3df213c4c2..a6e78825334fec748be5fee80669649df699d2fb 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index 672e19bd93449ccc31f4af5ded23257b197a3c39..334459138b55a201c15cb87ad9feb6a03a13c5ab 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" +#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include -#include "tensorflow/compiler/aot/runtime.h" namespace tensorflow { @@ -26,20 +26,29 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, result_index_(static_data.result_index), args_(new void*[static_data.num_args]), temps_(new void*[static_data.num_temps]), + arg_index_to_temp_index_(new int32[static_data.num_args]), + num_args_(static_data.num_args), arg_names_(static_data.arg_names), result_names_(static_data.result_names), program_shape_(static_data.program_shape), hlo_profile_printer_data_(static_data.hlo_profile_printer_data) { // Allocate arg and temp buffers. if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) { - alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( + alloc_args_ = cpu_function_runtime::MallocContiguousBuffers( static_data.arg_sizes, static_data.num_args, args_, /*annotate_initialized=*/false); } - alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( + alloc_temps_ = cpu_function_runtime::MallocContiguousBuffers( static_data.temp_sizes, static_data.num_temps, temps_, /*annotate_initialized=*/true); + for (int i = 0; i < static_data.num_temps; i++) { + if (static_data.temp_sizes[i] < -1) { + int32 param_number = -(static_data.temp_sizes[i] + 2); + arg_index_to_temp_index_[param_number] = i; + } + } + // If Hlo profiling is enabled the generated code expects an appropriately // sized buffer to be passed in as the last argument. If Hlo profiling is // disabled the last function argument is still present in the function @@ -50,11 +59,24 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, } } +bool XlaCompiledCpuFunction::Run() { + // Propagate pointers to the argument buffers into the temps array. Code + // generated by XLA discovers the incoming argument pointers from the temps + // array. + for (int32 i = 0; i < num_args_; i++) { + temps_[arg_index_to_temp_index_[i]] = args_[i]; + } + raw_function_(temps_[result_index_], &run_options_, nullptr, temps_, + profile_counters_); + return true; +} + XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { - tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_); - tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); + cpu_function_runtime::FreeContiguous(alloc_args_); + cpu_function_runtime::FreeContiguous(alloc_temps_); delete[] args_; delete[] temps_; + delete[] arg_index_to_temp_index_; delete[] profile_counters_; } diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index 48a8c083cacf2f6ecf9dc1817b6174c01385d035..27cfb354bf5f8ede2dcca85065411006c352a575 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -60,9 +60,19 @@ class XlaCompiledCpuFunction { // The raw function to call. RawFunction raw_function; - // Cardinality and sizes of arg and temp buffers. + // Cardinality and size of arg buffers. const intptr_t* arg_sizes = nullptr; size_t num_args = 0; + + // Cardinality and size of temp buffers. + // + // If temp_sizes[i] >= 0 then the i'th temp is a regular temporary buffer. + // + // If temp_sizes[i] == -1 then the i'th temp is a constant buffer. The + // corresponding entry in the temp buffer array needs to be set to null. + // + // If temp_sizes[i] < -1 then the i'th temp is the entry parameter + // -(temp_sizes[i] + 2). const intptr_t* temp_sizes = nullptr; size_t num_temps = 0; @@ -113,11 +123,7 @@ class XlaCompiledCpuFunction { // Runs the computation, with inputs read from arg buffers, and outputs // written to result buffers. Returns true on success and false on failure. - bool Run() { - raw_function_(temps_[result_index_], &run_options_, - const_cast(args_), temps_, profile_counters_); - return true; - } + bool Run(); // Returns the error message from the previous failed Run call. // @@ -224,6 +230,17 @@ class XlaCompiledCpuFunction { void** args_ = nullptr; void** temps_ = nullptr; + // Argument i needs to be placed in temps_[arg_index_to_temp_index_[i]] for + // XLA generated code to be able to find it. + // + // For now we need to keep around the args_ array because there is code that + // depends on args() returning a void**. However, in the future we may remove + // args_ in favor of using temps_ as the sole storage for the arguments. + int32* arg_index_to_temp_index_; + + // The number of incoming arguments. + int32 num_args_; + // Backing memory for individual arg and temp buffers. void* alloc_args_ = nullptr; void* alloc_temps_ = nullptr; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 0c98c208053b47f4f92cc46f2280271847e88b61..226c89bcf1e66b5afb43cddb03db39b931ca55a8 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -28,12 +28,14 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" @@ -231,10 +233,13 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, case XlaCompiler::Argument::kConstant: LOG(FATAL) << "Unreachable case"; case XlaCompiler::Argument::kParameter: { - TensorShape shape = - is_entry_computation - ? options_.shape_representation_fn(arg.shape, arg.type) - : arg.shape; + TensorShape shape; + if (is_entry_computation) { + TF_ASSIGN_OR_RETURN( + shape, options_.shape_representation_fn(arg.shape, arg.type)); + } else { + shape = arg.shape; + } return TensorShapeToXLAShape(arg.type, shape, xla_shape); } case XlaCompiler::Argument::kResource: { @@ -242,8 +247,9 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, switch (arg.resource_kind) { case XlaResource::kVariable: { - TensorShape representation_shape = - options_.shape_representation_fn(arg.shape, arg.type); + TF_ASSIGN_OR_RETURN( + TensorShape representation_shape, + options_.shape_representation_fn(arg.shape, arg.type)); return TensorShapeToXLAShape(arg.type, representation_shape, xla_shape); } @@ -418,16 +424,18 @@ Status BuildComputation( // assignment will be placed on this value, which will cause the resource // update to be returned from the same device that provided the resource. handle = xla::GetTupleElement(xla::Tuple(builder, {handle}), 0); - elems.push_back(handle); } } *num_computation_outputs = elems.size(); - // Builds the XLA computation. - if (always_return_tuple || elems.size() != 1) { - xla::Tuple(builder, elems); + // Builds the XLA computation. We *always* form a tuple here to ensure that + // the output value is the last thing added into the XLA computation, even + // if there is only one output value. + auto tuple = xla::Tuple(builder, elems); + if (!always_return_tuple && elems.size() == 1) { + xla::GetTupleElement(tuple, 0); } builder->ClearOpMetadata(); @@ -664,20 +672,17 @@ Status XlaCompiler::CompileSingleOp( namespace { // Check that the ops of all non-functional nodes have been registered. -string ValidateFunctionDef(const FunctionDef* fdef, +Status ValidateFunctionDef(const FunctionDef* fdef, const FunctionLibraryDefinition& flib_def) { - std::vector invalid_ops; for (const NodeDef& node : fdef->node_def()) { const string& op = node.op(); if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) { continue; } const OpDef* op_def; - if (!OpRegistry::Global()->LookUpOpDef(op, &op_def).ok()) { - invalid_ops.push_back(op); - } + TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(op, &op_def)); } - return tensorflow::str_util::Join(invalid_ops, ", "); + return Status::OK(); } // Check that the graph doesn't have any invalid nodes (e.g. incompatible with @@ -685,35 +690,33 @@ string ValidateFunctionDef(const FunctionDef* fdef, Status ValidateGraph(const Graph* graph, const FunctionLibraryDefinition& flib_def, const DeviceType& device_type, const string& name) { - std::set invalid_ops; + auto maybe_error = [&](const Node* node, const Status& s) -> Status { + if (!s.ok()) { + return errors::InvalidArgument(strings::StrCat( + "Detected unsupported operations when trying to compile graph ", name, + " on ", device_type.type_string(), ": ", node->def().op(), " (", + s.error_message(), ")", FormatNodeForError(*node))); + } + return Status::OK(); + }; + for (const Node* node : graph->nodes()) { if (node->type_string() == FunctionLibraryDefinition::kGradientOp) { continue; } const FunctionDef* fdef = flib_def.Find(node->def().op()); + Status s; if (fdef) { - string error_msg = ValidateFunctionDef(fdef, flib_def); - if (!error_msg.empty()) { - invalid_ops.insert( - strings::StrCat(node->def().op(), ":{", error_msg, "}")); - } + s = ValidateFunctionDef(fdef, flib_def); + TF_RETURN_IF_ERROR(maybe_error(node, s)); continue; } const OpDef* op_def; - if (!OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def).ok()) { - invalid_ops.insert(node->def().op()); - continue; - } + s = OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def); + TF_RETURN_IF_ERROR(maybe_error(node, s)); TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def)); - if (!FindKernelDef(device_type, node->def(), nullptr, nullptr).ok()) { - invalid_ops.insert(node->def().op()); - } - } - if (!invalid_ops.empty()) { - return errors::InvalidArgument(strings::StrCat( - "Detected unsupported operations when trying to compile graph ", name, - " on ", device_type.type_string(), ":", - tensorflow::str_util::Join(invalid_ops, ", "))); + s = FindKernelDef(device_type, node->def(), nullptr, nullptr); + TF_RETURN_IF_ERROR(maybe_error(node, s)); } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 80593eaca5e695cd93f14d52d4af88e7624bf105..25332c8d8e3210a0217a1ba3f5767115fe6b1d93 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -20,6 +20,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function.h" @@ -242,13 +244,20 @@ class XlaCompiler { std::shared_ptr computation; }; - typedef std::function + typedef std::function(const TensorShape&, + DataType)> ShapeRepresentationFn; struct Options { // Name of the compilation device to use. It must be set by the caller. // The default empty value is invalid. DeviceType device_type = DeviceType(""); + // The device to use during compilation to execute instructions on, for + // example for auto-tuning. + // Valid values are defined by `xla::Backend::devices_ordinal_supported()`. + // -1 indicates the default device should be used. + int device_ordinal = -1; + xla::Client* client = nullptr; // Function library in which to find function definitions. Must be non-null. diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 613230452b74755ce7543ec2ab82861aa0dfeb7a..be00ed8813fdf2778d6af81556001ef51538dd34 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -206,9 +206,9 @@ TEST_F(XlaCompilerTest, Simple) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::Literal::CreateR1({7, 42}); + xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param1_literal = - xla::Literal::CreateR1({-3, 101}); + xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -222,12 +222,64 @@ TEST_F(XlaCompilerTest, Simple) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr expected0 = - xla::Literal::CreateR1({4, 143}); + xla::LiteralUtil::CreateR1({4, 143}); std::unique_ptr expected_literal = - xla::Literal::MakeTuple({expected0.get()}); + xla::LiteralUtil::MakeTuple({expected0.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } +// Tests compilation of a graph where the _Retval node is not necessarily last +// amongst the graph nodes in construction order, and always_return_tuple is +// false. Regression test for bug where the wrong value was returned. +TEST_F(XlaCompilerTest, OutOfOrderGraph) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1); + // The _Retval node is not last in construction order. + auto d = ops::_Retval(scope.WithOpName("D"), a, 0); + auto c = ops::Add(scope.WithOpName("C"), a, b); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + args[1].kind = XlaCompiler::Argument::kParameter; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2}); + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompileOptions compile_options; + compile_options.always_return_tuple = false; + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), + args, &result)); + + // Tests that the generated computation works. + std::unique_ptr param0_literal = + xla::LiteralUtil::CreateR1({7, 42}); + std::unique_ptr param1_literal = + xla::LiteralUtil::CreateR1({-3, 101}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + std::unique_ptr actual = + client_ + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr actual_literal = + client_->Transfer(*actual).ConsumeValueOrDie(); + + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal)); +} + TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { // Builds a graph that adds reshapes a tensor, but with the shape not // statically known. @@ -260,7 +312,7 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { str_util::StrContains(status.error_message(), "depends on a parameter")) << status.error_message(); EXPECT_TRUE( - str_util::StrContains(status.error_message(), "[[Node: C = Reshape")) + str_util::StrContains(status.error_message(), "[[{{node C}} = Reshape")) << status.error_message(); } @@ -306,7 +358,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::Literal::CreateR1({7, 42}); + xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -317,9 +369,9 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr expected0 = - xla::Literal::CreateR1({-7, -42}); + xla::LiteralUtil::CreateR1({-7, -42}); std::unique_ptr expected_literal = - xla::Literal::MakeTuple({expected0.get()}); + xla::LiteralUtil::MakeTuple({expected0.get()}); EXPECT_TRUE( xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -341,7 +393,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::Literal::CreateR1({7, 42}); + xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -351,11 +403,12 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { std::unique_ptr actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = xla::Literal::CreateR0(7); + std::unique_ptr expected0 = + xla::LiteralUtil::CreateR0(7); std::unique_ptr expected1 = - xla::Literal::CreateR1({-7, -42}); + xla::LiteralUtil::CreateR1({-7, -42}); std::unique_ptr expected = - xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal)); } } @@ -569,11 +622,11 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { // Tests that the generated computation works. std::unique_ptr input_base = - xla::Literal::CreateR1({7, 42}); + xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr input_grad2 = - xla::Literal::CreateR1({-3, 101}); + xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr input = - xla::Literal::MakeTuple({input_base.get(), input_grad2.get()}); + xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()}); std::unique_ptr param0_data = client_->TransferToServer(*input).ConsumeValueOrDie(); @@ -583,17 +636,18 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { std::unique_ptr actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr output_read = xla::Literal::CreateR0(42); + std::unique_ptr output_read = + xla::LiteralUtil::CreateR0(42); std::unique_ptr output_base = - xla::Literal::CreateR1({7, 42}); + xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr output_grad1 = - xla::Literal::CreateR1({0, 1}); + xla::LiteralUtil::CreateR1({0, 1}); std::unique_ptr output_grad2 = - xla::Literal::CreateR1({-3, 101}); - std::unique_ptr output_resource = xla::Literal::MakeTuple( + xla::LiteralUtil::CreateR1({-3, 101}); + std::unique_ptr output_resource = xla::LiteralUtil::MakeTuple( {output_base.get(), output_grad1.get(), output_grad2.get()}); std::unique_ptr expected_literal = - xla::Literal::MakeTuple({output_read.get(), output_resource.get()}); + xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -796,9 +850,9 @@ TEST_F(XlaCompilerTest, Variables) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::Literal::CreateR1({7, 42}); + xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param1_literal = - xla::Literal::CreateR1({-3, 101}); + xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -812,11 +866,11 @@ TEST_F(XlaCompilerTest, Variables) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr expected0 = - xla::Literal::CreateR1({5, 144}); + xla::LiteralUtil::CreateR1({5, 144}); std::unique_ptr expected1 = - xla::Literal::CreateR1({4, 143}); + xla::LiteralUtil::CreateR1({4, 143}); std::unique_ptr expected_literal = - xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -884,9 +938,9 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::Literal::CreateR2({{4, 55}, {1, -3}}); + xla::LiteralUtil::CreateR2({{4, 55}, {1, -3}}); std::unique_ptr param1_literal = - xla::Literal::CreateR1({22, 11, 33, 404}); + xla::LiteralUtil::CreateR1({22, 11, 33, 404}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -900,11 +954,11 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr expected0 = - xla::Literal::CreateR2({{27, 67}, {35, 402}}); + xla::LiteralUtil::CreateR2({{27, 67}, {35, 402}}); std::unique_ptr expected1 = - xla::Literal::CreateR1({26, 66, 34, 401}); + xla::LiteralUtil::CreateR1({26, 66, 34, 401}); std::unique_ptr expected_literal = - xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -953,9 +1007,9 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::Literal::CreateR1({4, 55, 1, -3}); + xla::LiteralUtil::CreateR1({4, 55, 1, -3}); std::unique_ptr param1_literal = - xla::Literal::CreateR1({22, 11, 33, 404}); + xla::LiteralUtil::CreateR1({22, 11, 33, 404}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -969,11 +1023,11 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr expected0 = - xla::Literal::CreateR1({27, 67, 35, 402}); + xla::LiteralUtil::CreateR1({27, 67, 35, 402}); std::unique_ptr expected1 = - xla::Literal::CreateR1({26, 66, 34, 401}); + xla::LiteralUtil::CreateR1({26, 66, 34, 401}); std::unique_ptr expected_literal = - xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -1021,8 +1075,9 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", std::move(graph), args, &result); ASSERT_FALSE(status.ok()); - EXPECT_TRUE( - str_util::StrContains(status.error_message(), "FillFn:{InvalidOp}")) + EXPECT_TRUE(str_util::StrContains(status.error_message(), "InvalidOp")) + << status.error_message(); + EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node fill_fn}}")) << status.error_message(); } @@ -1048,6 +1103,8 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) { EXPECT_TRUE(str_util::StrContains(status.error_message(), "is not in the list of allowed values")) << status.error_message(); + EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node Shape}}")) + << status.error_message(); } TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { @@ -1069,9 +1126,10 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", std::move(graph_copy), args, &result); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), - "The following nodes are unreachable " - "from the source in the graph: NoOp")) + EXPECT_TRUE( + str_util::StrContains(status.error_message(), + "The following nodes are unreachable " + "from the source in the graph: {{node NoOp}}")) << status.error_message(); } diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index d0b560690758a4d73c4836ad97470d52e45fc59e..b24e3aabbe6ba858a8bfb4dd435726984cc7b0f5 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -25,9 +25,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -66,8 +67,8 @@ XlaContext::XlaContext( XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, bool is_entry_computation, - const std::function* - shape_representation_fn) + const std::function( + const TensorShape&, DataType)>* shape_representation_fn) : compiler_(compiler), builder_(builder), allow_cpu_custom_calls_(allow_cpu_custom_calls), @@ -119,8 +120,8 @@ Status XlaContext::CreateResource( return Status::OK(); } -TensorShape XlaContext::RepresentationShape(const TensorShape& shape, - DataType type) const { +xla::StatusOr XlaContext::RepresentationShape( + const TensorShape& shape, DataType type) const { return (*shape_representation_fn_)(shape, type); } diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 5960daaefd625a0b4daf00d7b8c929f3c856575f..3db37afdba71342cfb20af8841a40cb54709ca73 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -22,8 +22,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -47,8 +48,8 @@ class XlaContext : public ResourceBase { XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, bool is_entry_computation, - const std::function* - shape_representation_fn); + const std::function( + const TensorShape&, DataType)>* shape_representation_fn); // Virtual method defined by ResourceBase. string DebugString() override; @@ -101,8 +102,8 @@ class XlaContext : public ResourceBase { // Returns the XLA shape to be used to represent a variable of TF `shape` // and `type`, or of an argument or return value of a top-level computation. - TensorShape RepresentationShape(const TensorShape& shape, - DataType type) const; + xla::StatusOr RepresentationShape(const TensorShape& shape, + DataType type) const; // Get an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a @@ -160,7 +161,7 @@ class XlaContext : public ResourceBase { // should be represented in XLA. Parameters/return values will be shaped // according to this function, and reshaped back to/from their declared shapes // for computations. Must be non-null. - const std::function* + const std::function(const TensorShape&, DataType)>* shape_representation_fn_; // Cache of prebuilt computations indexed by their type. diff --git a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc index ead229aaccc292d4944db0c1eaf98c82583533cd..23d04d43b358e858ad1ab2463322ce0ab93b23c2 100644 --- a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc @@ -31,6 +31,10 @@ bool CpuOpFilter(KernelDef* kdef) { DT_FLOAT); return true; } + // TODO(b/26783907): The CPU backend currently does not implement sort. + if (kdef->op() == "XlaSort" || kdef->op() == "TopKV2") { + return false; + } if (kdef->op() == "Const") { AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef); } diff --git a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc index 62168b648331844bfe2db1a4d5dcad895c8726f3..1398e9ee536a9675e5b703ec3fabf4a8b9d89cbf 100644 --- a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc @@ -20,12 +20,6 @@ limitations under the License. namespace tensorflow { bool GpuOpFilter(KernelDef* kdef) { - // TODO(b/31361304): The GPU backend does not parallelize PRNG ops, leading to - // slow code. - if (kdef->op() == "RandomStandardNormal" || kdef->op() == "RandomUniform" || - kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal") { - return false; - } if (kdef->op() == "Const") { AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef); } diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 917ef4037d5ae60bd3085ac6cd465ffcbd43e59b..8efb3d55c88757b9366bdf9622287bdd0a72e295 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -23,7 +23,11 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" @@ -33,140 +37,71 @@ namespace tensorflow { namespace { -Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, - const xla::XlaOp& input, const TensorShape& input_shape, - DataType input_type, DataType output_type, int axis, - bool is_min, xla::XlaOp* argminmax) { - xla::XlaOp init_value; - const xla::XlaComputation* reducer; - if (is_min) { - init_value = XlaHelpers::MaxValue(builder, input_type); - reducer = ctx->GetOrCreateMin(input_type); - } else { - init_value = XlaHelpers::MinValue(builder, input_type); - reducer = ctx->GetOrCreateMax(input_type); - } - - xla::PrimitiveType xla_output_type; - TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(output_type, &xla_output_type)); - - xla::XlaOp input_max = xla::Reduce(input, init_value, *reducer, - /*dimensions_to_reduce=*/{axis}); - std::vector broadcast_dims(input_shape.dims() - 1); - std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); - std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - // Compute a mask that has 1s for elements equal to the maximum. - xla::XlaOp partial_mask = xla::ConvertElementType( - xla::Eq(input, input_max, broadcast_dims), xla_output_type); - - // In order to make identity elements for a bitwise And, we: - // Left shift the 1 to the leftmost bit, yielding 0x10...0 - // Arithmetic right shift the 1 back to the rightmost bit, yielding - // 0xFF...F - int32 bits_in_type = - xla::ShapeUtil::ByteSizeOfPrimitiveType(xla_output_type) * 8 - 1; - xla::XlaOp shift_amount = - XlaHelpers::IntegerLiteral(builder, output_type, bits_in_type); - xla::XlaOp full_mask = xla::ShiftRightArithmetic( - xla::ShiftLeft(partial_mask, shift_amount), shift_amount); - - // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its - // index. - xla::XlaOp iota; - - const int64 axis_size = input_shape.dim_size(axis); - TF_RETURN_IF_ERROR(XlaHelpers::Iota(builder, output_type, axis_size, &iota)); - xla::XlaOp product = - xla::And(full_mask, iota, /*broadcast_dimensions=*/{axis}); - - // If there are multiple maximum elements, choose the one with the highest - // index. - xla::XlaOp output = - xla::Reduce(product, XlaHelpers::MinValue(builder, output_type), - *ctx->GetOrCreateMax(output_type), - /*dimensions_to_reduce=*/{axis}); - *argminmax = output; - return Status::OK(); +xla::XlaOp ArgMinMax(xla::XlaOp input, xla::PrimitiveType output_type, int axis, + bool is_min) { + xla::XlaBuilder* builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input)); + xla::XlaOp init_value; + xla::XlaComputation reducer; + if (is_min) { + init_value = xla::MaxValue(builder, input_shape.element_type()); + reducer = + xla::CreateScalarMinComputation(input_shape.element_type(), builder); + } else { + init_value = xla::MinValue(builder, input_shape.element_type()); + reducer = + xla::CreateScalarMaxComputation(input_shape.element_type(), builder); + } + + xla::XlaOp input_max = xla::Reduce(input, init_value, reducer, + /*dimensions_to_reduce=*/{axis}); + std::vector broadcast_dims(xla::ShapeUtil::Rank(input_shape) - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); + std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); + // Compute a mask that has 1s for elements equal to the maximum. + xla::XlaOp partial_mask = xla::ConvertElementType( + xla::Eq(input, input_max, broadcast_dims), output_type); + + // In order to make identity elements for a bitwise And, we: + // Left shift the 1 to the leftmost bit, yielding 0x10...0 + // Arithmetic right shift the 1 back to the rightmost bit, yielding + // 0xFF...F + int32 bits_in_type = + xla::ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1; + xla::XlaOp shift_amount = + xla::ConstantR0WithType(builder, output_type, bits_in_type); + xla::XlaOp full_mask = xla::ShiftRightArithmetic( + xla::ShiftLeft(partial_mask, shift_amount), shift_amount); + + // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its + // index. + + const int64 axis_size = xla::ShapeUtil::GetDimension(input_shape, axis); + xla::XlaOp iota = xla::Iota(builder, output_type, axis_size); + xla::XlaOp product = + xla::And(full_mask, iota, /*broadcast_dimensions=*/{axis}); + + // If there are multiple maximum elements, choose the one with the highest + // index. + return xla::Reduce(product, xla::MinValue(builder, output_type), + xla::CreateScalarMaxComputation(output_type, builder), + /*dimensions_to_reduce=*/{axis}); + }); } } // namespace -xla::XlaOp XlaHelpers::MinValue(xla::XlaBuilder* b, DataType data_type) { - xla::PrimitiveType type; - TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return xla::ConstantLiteral(b, xla::Literal::MinValue(type)); -} - -xla::XlaOp XlaHelpers::MinFiniteValue(xla::XlaBuilder* b, DataType data_type) { - xla::PrimitiveType type; - TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - switch (type) { - case xla::F16: - return xla::ConstantR0( - b, Eigen::NumTraits::lowest()); - case xla::BF16: - return xla::ConstantR0(b, bfloat16::lowest()); - case xla::F32: - return xla::ConstantR0(b, -std::numeric_limits::max()); - case xla::F64: - return xla::ConstantR0(b, -std::numeric_limits::max()); - default: - return xla::ConstantLiteral(b, xla::Literal::MinValue(type)); - } -} - -xla::XlaOp XlaHelpers::MaxValue(xla::XlaBuilder* b, DataType data_type) { - xla::PrimitiveType type; - TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return xla::ConstantLiteral(b, xla::Literal::MaxValue(type)); -} - -xla::XlaOp XlaHelpers::MaxFiniteValue(xla::XlaBuilder* b, DataType data_type) { - xla::PrimitiveType type; - TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - switch (type) { - case xla::F16: - return xla::ConstantR0( - b, Eigen::NumTraits::highest()); - case xla::BF16: - return xla::ConstantR0(b, bfloat16::highest()); - case xla::F32: - return xla::ConstantR0(b, std::numeric_limits::max()); - case xla::F64: - return xla::ConstantR0(b, std::numeric_limits::max()); - default: - return xla::ConstantLiteral(b, xla::Literal::MaxValue(type)); - } -} - xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return xla::ConstantLiteral(b, xla::Literal::Zero(type)); + return xla::ConstantLiteral(b, xla::LiteralUtil::Zero(type)); } xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return xla::ConstantLiteral(b, xla::Literal::One(type)); -} - -xla::XlaOp XlaHelpers::Epsilon(xla::XlaBuilder* b, DataType data_type) { - switch (data_type) { - case DT_HALF: - return xla::ConstantR0( - b, - static_cast(Eigen::NumTraits::epsilon())); - case DT_BFLOAT16: - return xla::ConstantR0(b, bfloat16::epsilon()); - case DT_FLOAT: - return xla::ConstantR0(b, std::numeric_limits::epsilon()); - case DT_DOUBLE: - return xla::ConstantR0(b, std::numeric_limits::epsilon()); - default: - LOG(FATAL) << "Unsupported type in XlaHelpers::Epsilon: " - << DataTypeString(data_type); - } + return xla::ConstantLiteral(b, xla::LiteralUtil::One(type)); } xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type, @@ -214,45 +149,14 @@ static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) { return linspace; } -Status XlaHelpers::ArgMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, - const xla::XlaOp& input, - const TensorShape& input_shape, DataType input_type, - DataType output_type, int axis, xla::XlaOp* argmax) { - return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type, - axis, /*is_min=*/false, argmax); -} - -Status XlaHelpers::ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, - const xla::XlaOp& input, - const TensorShape& input_shape, DataType input_type, - DataType output_type, int axis, xla::XlaOp* argmin) { - return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type, - axis, /*is_min=*/true, argmin); +xla::XlaOp XlaHelpers::ArgMax(xla::XlaOp input, xla::PrimitiveType output_type, + int axis) { + return ArgMinMax(input, output_type, axis, /*is_min=*/false); } -Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, - xla::XlaOp* iota) { - TensorShape linspace_shape({size}); - Tensor linspace; - switch (dtype) { - case DT_UINT8: - linspace = MakeLinspaceTensor(linspace_shape, size); - break; - case DT_INT32: - linspace = MakeLinspaceTensor(linspace_shape, size); - break; - case DT_INT64: - linspace = MakeLinspaceTensor(linspace_shape, size); - break; - default: - return errors::InvalidArgument("Invalid argument type ", - DataTypeString(dtype)); - } - xla::BorrowingLiteral linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); - - *iota = xla::ConstantLiteral(builder, linspace_literal); - return Status::OK(); +xla::XlaOp XlaHelpers::ArgMin(xla::XlaOp input, xla::PrimitiveType output_type, + int axis) { + return ArgMinMax(input, output_type, axis, /*is_min=*/true); } Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index c320016998c77f4df7b00745ca45afd9cc01ea77..e6522157a535fc3e4ec96cb0496b6be2e525c336 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -19,7 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ #include "tensorflow/compiler/tf2xla/xla_context.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -28,22 +28,6 @@ namespace tensorflow { // Helper methods for building XLA computations. class XlaHelpers { public: - // Returns a handle representing the minimum value of a scalar - // element of data_type. -inf for floating-point types. - static xla::XlaOp MinValue(xla::XlaBuilder* b, DataType data_type); - - // Returns a handle representing the minimum finite value of a scalar - // element of data_type. - static xla::XlaOp MinFiniteValue(xla::XlaBuilder* b, DataType data_type); - - // Returns a handle representing the maximum value of a scalar - // element of data_type. inf for floating point types. - static xla::XlaOp MaxValue(xla::XlaBuilder* b, DataType data_type); - - // Returns a handle representing the maximum finite value of a scalar - // element of data_type. - static xla::XlaOp MaxFiniteValue(xla::XlaBuilder* b, DataType data_type); - // Returns a handle representing the zero value of a scalar // element of data_type. static xla::XlaOp Zero(xla::XlaBuilder* b, DataType data_type); @@ -52,10 +36,6 @@ class XlaHelpers { // element of data_type. static xla::XlaOp One(xla::XlaBuilder* b, DataType data_type); - // Returns the machine epsilon for floating-point type `data_type`, i.e., - // the difference between 1.0 and the next representable value. - static xla::XlaOp Epsilon(xla::XlaBuilder* b, DataType data_type); - // Returns a handle representing the given value of an integer scalar // element of data_type. // Note that unlike One and Zero, does not work on boolean types. @@ -73,25 +53,15 @@ class XlaHelpers { gtl::ArraySlice shape, xla::Literal* output); - // Sets `argmax` to the argmax of `input` along `axis`. `input_shape` and - // `input_dtype` are the shape and dtype of `input` respectively, and - // `output_type` is the dtype to use for `argmax`. - static Status ArgMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, - const xla::XlaOp& input, const TensorShape& input_shape, - DataType input_type, DataType output_type, int axis, - xla::XlaOp* argmax); - - // Sets `argmin` to the argmin of `input` along `axis`. `input_shape` and - // `input_dtype` are the shape and dtype of `input` respectively, and - // `output_type` is the dtype to use for `argmin`. - static Status ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, - const xla::XlaOp& input, const TensorShape& input_shape, - DataType input_type, DataType output_type, int axis, - xla::XlaOp* argmin); - - // Sets *iota to a rank 1 tensor with values [0, 1, 2, ...] of `dtype`. - static Status Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, - xla::XlaOp* iota); + // Returns the argmax of `input` along `axis`. `output_type` is the type to + // use for the output. + static xla::XlaOp ArgMax(xla::XlaOp input, xla::PrimitiveType output_type, + int axis); + + // Returns the argmin of `input` along `axis`. `output_type` is the type to + // use for the output. + static xla::XlaOp ArgMin(xla::XlaOp input, xla::PrimitiveType output_type, + int axis); // Converts `indices` into a one-hot representation. `depth` is the size // of the new axis to add. `axis` is the position at which to add the new diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 9e17756b27733e2453ea1688d13e1d718c25cfc8..114a9241bdb00526df76478b030a9efa506dd29c 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -57,11 +58,15 @@ xla::StatusOr> ComputeTempSizes( std::vector temp_sizes; temp_sizes.reserve(allocations.size()); for (const xla::BufferAllocation& allocation : allocations) { - // Callers don't allocate temporary buffers for parameters. Nor for - // thread-local buffers, which are lowered to alloca. - if (allocation.is_entry_computation_parameter() || - allocation.is_thread_local()) { + if (allocation.is_constant() || allocation.is_thread_local()) { + // Constants are lowered to globals. Thread locals are lowered to + // allocas. temp_sizes.push_back(-1); + } else if (allocation.is_entry_computation_parameter()) { + // Entry computation parameters need some preprocessing in + // XlaCompiledCpuFunction::Run. See the comment on + // XlaCompiledCpuFunction::StaticData::temp_sizes. + temp_sizes.push_back(-allocation.parameter_number() - 2); } else { temp_sizes.push_back(allocation.size()); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index c2298b97e1d73c217df36466fdadfd36ac78741a..82028c8b9ca9f65a73f8b50edc0a47c7068aba9a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -19,8 +19,11 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/dma_helper.h" namespace tensorflow { @@ -64,10 +67,32 @@ const xla::XlaOp& XlaOpKernelContext::Input(int index) { return GetComputationFromTensor(context_->input(index)); } +const xla::XlaOp& XlaOpKernelContext::Input(StringPiece name) { + return GetComputationFromTensor(GetInputTensorByName(name)); +} + TensorShape XlaOpKernelContext::InputShape(int index) { return context_->input(index).shape(); } +TensorShape XlaOpKernelContext::InputShape(StringPiece name) { + return GetInputTensorByName(name).shape(); +} + +DataType XlaOpKernelContext::input_type(int index) const { + return context_->input(index).dtype(); +} + +xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) { + xla::PrimitiveType type; + Status status = DataTypeToPrimitiveType(input_type(index), &type); + if (!status.ok()) { + SetStatus(status); + return xla::PRIMITIVE_TYPE_INVALID; + } + return type; +} + Status XlaOpKernelContext::ConstantInput(int index, xla::Literal* constant_literal) { return ConstantInputReshaped( @@ -316,10 +341,11 @@ Status XlaOpKernelContext::ConstantInputList( return Status::OK(); } -Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, - TensorShape* shape, - xla::XlaOp* value) { - const Tensor& tensor = context_->input(index); +namespace { + +Status ReadVariableInputTensor(const Tensor& tensor, DataType type, + const OpKernelContext* ctx, TensorShape* shape, + xla::XlaOp* value) { const XlaExpression* expression = CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); @@ -337,9 +363,10 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, *shape = variable->shape(); } - XlaContext& xla_context = XlaContext::Get(context_); - TensorShape representation_shape = - xla_context.RepresentationShape(variable->shape(), variable->type()); + XlaContext& xla_context = XlaContext::Get(ctx); + TF_ASSIGN_OR_RETURN( + TensorShape representation_shape, + xla_context.RepresentationShape(variable->shape(), variable->type())); if (representation_shape == variable->shape()) { *value = variable->value(); } else { @@ -348,6 +375,22 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, return Status::OK(); } +} // namespace + +Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, + TensorShape* shape, + xla::XlaOp* value) { + return ReadVariableInputTensor(context_->input(index), type, context_, shape, + value); +} + +Status XlaOpKernelContext::ReadVariableInput(StringPiece name, DataType type, + TensorShape* shape, + xla::XlaOp* value) { + return ReadVariableInputTensor(GetInputTensorByName(name), type, context_, + shape, value); +} + Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, TensorShape* shape) const { const Tensor& tensor = context_->input(index); @@ -438,17 +481,17 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { return Status::OK(); } -Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, - xla::XlaOp handle) { - TF_RET_CHECK(handle.valid()); +namespace { - const XlaExpression* expression = - CastExpressionFromTensor(context_->input(input_index)); +Status AssignVariableTensor(const Tensor& tensor, DataType type, + const OpKernelContext* ctx, xla::XlaOp handle, + xla::XlaBuilder* builder) { + const XlaExpression* expression = CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); TF_RET_CHECK(variable->kind() == XlaResource::kVariable); - auto shape_or_status = builder()->GetShape(handle); + auto shape_or_status = builder->GetShape(handle); if (!shape_or_status.ok()) { return shape_or_status.status(); } @@ -458,15 +501,31 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); - XlaContext& xla_context = XlaContext::Get(context_); - TensorShape representation_shape = - xla_context.RepresentationShape(shape, type); + XlaContext& xla_context = XlaContext::Get(ctx); + TF_ASSIGN_OR_RETURN(TensorShape representation_shape, + xla_context.RepresentationShape(shape, type)); if (shape != representation_shape) { handle = xla::Reshape(handle, representation_shape.dim_sizes()); } return variable->SetValue(handle); } +} // namespace + +Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, + xla::XlaOp handle) { + TF_RET_CHECK(handle.valid()); + return AssignVariableTensor(context_->input(input_index), type, context_, + handle, builder()); +} + +Status XlaOpKernelContext::AssignVariable(StringPiece name, DataType type, + xla::XlaOp handle) { + TF_RET_CHECK(handle.valid()); + return AssignVariableTensor(GetInputTensorByName(name), type, context_, + handle, builder()); +} + XlaCompiler* XlaOpKernelContext::compiler() const { return XlaContext::Get(context_).compiler(); } @@ -506,6 +565,12 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul( return XlaContext::Get(context_).GetOrCreateMul(type); } +const Tensor& XlaOpKernelContext::GetInputTensorByName(StringPiece name) { + const Tensor* tensor; + CHECK(context_->input(name, &tensor).ok()); + return *tensor; +} + XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {} void XlaOpKernel::Compute(OpKernelContext* context) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 667dc262ca03ca716ffbf015a78fc14c7a8b7c1a..ac9dfe3369078df7392a4ef04679f7d7beacf8bb 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -17,7 +17,9 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/macros.h" @@ -66,16 +68,26 @@ class XlaOpKernelContext { // Returns the number of inputs to the operator. int num_inputs() const { return context_->num_inputs(); } - // Returns the type of input 'index'. - DataType input_type(int index) { return context_->input(index).dtype(); } + // Returns the type of input `index`. + DataType input_type(int index) const; - // Returns the shape of input 'index'. + // Returns the type of input `index` as an xla::PrimitiveType. If the type + // is not representable as an XLA type, sets an error status and returns + // xla::PRIMITIVE_TYPE_INVALID. + xla::PrimitiveType input_xla_type(int index); + + // Returns the shape of input `index`. TensorShape InputShape(int index); - // Returns input 'index' as a XlaOp. Unlike + // Returns the shape of input `name`. + TensorShape InputShape(StringPiece name); + + // Returns input `index` as a XlaOp. Unlike // OpKernelContext::Input returns a symbolic value rather than a concrete // Tensor. const xla::XlaOp& Input(int index); + // Returns input `name` as a XlaOp. + const xla::XlaOp& Input(StringPiece name); // Returns true if all inputs are the same shape, otherwise sets the // status to a non-OK value and returns false. @@ -90,13 +102,13 @@ class XlaOpKernelContext { // Helper methods for constant inputs. - // Evaluates input 'index' and stores it in '*constant_literal'. If the + // Evaluates input `index` and stores it in `*constant_literal`. If the // expression cannot be evaluated, e.g., because it depends on unbound // parameters, returns a non-OK status. Status ConstantInput(int index, xla::Literal* constant_literal); - // Evaluates input 'index', reshapes it to 'new_shape' if new_shape != - // InputShape(index), and stores it in '*constant_literal'. If the input + // Evaluates input `index`, reshapes it to `new_shape` if new_shape != + // InputShape(index), and stores it in `*constant_literal`. If the input // cannot be evaluated, e.g., because it depends on unbound parameters, // returns a non-Ok status. If InputShape(index).num_elements() != // new_shape.num_elements(), returns an error status. @@ -131,17 +143,17 @@ class XlaOpKernelContext { return context_->expected_output_dtype(index); } - // Sets output 'index' to the XlaOp 'handle'. + // Sets output `index` to the XlaOp `handle`. // All outputs should be set using SetOutput and SetConstantOutput, not // via the underlying OpKernelContext. void SetOutput(int index, const xla::XlaOp& handle); - // Sets output 'index' to compile-time constant 'host_tensor', where - // 'host_tensor' is a tensor in host memory. It is preferable to use + // Sets output `index` to compile-time constant `host_tensor`, where + // `host_tensor` is a tensor in host memory. It is preferable to use // SetConstantOutput where possible. void SetConstantOutput(int index, const Tensor& host_tensor); - // Sets output 'index' to an invalid value. + // Sets output `index` to an invalid value. // Any subsequent attempt to consume this output will cause an error. void SetInvalidOutput(int index); @@ -151,10 +163,10 @@ class XlaOpKernelContext { // Variables - // Sets '*resource' to the resource associated with input `index`. + // Sets `*resource` to the resource associated with input `index`. Status GetResourceInput(int index, XlaResource** resource); - // Sets output 'index' to be a reference to resource 'resource'. + // Sets output `index` to be a reference to resource `resource`. void SetResourceOutput(int index, XlaResource* resource); // Sets `*type` and `*shape` to the current type and shape of a variable's @@ -163,17 +175,23 @@ class XlaOpKernelContext { TensorShape* shape) const; // Reads the current value of the resouce variable referred to by input - // 'index'. If `shape` is not nullptr, sets `*shape` to the shape of the + // `index`. If `shape` is not nullptr, sets `*shape` to the shape of the // variable. Returns an error if the variable has not been initialized, or if // its type does not match `type`. Status ReadVariableInput(int index, DataType type, TensorShape* shape, xla::XlaOp* value); + // Reads the current value of the resouce variable referred to by input + // `name`. + Status ReadVariableInput(StringPiece name, DataType type, TensorShape* shape, + xla::XlaOp* value); // Assigns the value `handle` to the variable referenced by input // `input_index`. The variable must be of `type`. Returns an error if the // variable has been initialized with a different type or with a // different shape. Status AssignVariable(int input_index, DataType type, xla::XlaOp handle); + // Assigns the value `handle` to the variable referenced by input `name`. + Status AssignVariable(StringPiece name, DataType type, xla::XlaOp handle); // Helper routines for the OP_REQUIRES macros void CtxFailure(const Status& s); @@ -221,6 +239,9 @@ class XlaOpKernelContext { const xla::XlaComputation* GetOrCreateMul(const DataType type); private: + // Returns the tensor of input `name`. + const Tensor& GetInputTensorByName(StringPiece name); + OpKernelContext* const context_; }; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 2d4593ea4999ad6d8cd0f0e2eec9c6d69c3020b8..fc14834ca6441ea785eacc57e1f502086f36657e 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -279,7 +279,7 @@ class XlaOpRegistrar { #define REGISTER_XLA_OP_UNIQ(CTR, BUILDER, OP) \ static ::tensorflow::XlaOpRegistrar xla_op_registrar__body__##CTR##__object( \ - XlaOpRegistrationBuilder::BUILDER.Build( \ + ::tensorflow::XlaOpRegistrationBuilder::BUILDER.Build( \ [](::tensorflow::OpKernelConstruction* context) \ -> ::tensorflow::OpKernel* { return new OP(context); })); diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index baea8149658ec0849ebb570931ca68518ec5284e..7928fa034725206a752cbfe086d01f15cd235df9 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 4de18a77887496d30e3b1407ecd9042e619653af..2438490be13809b9f3571a362900b44cb838e76b 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 03e542855ba0e3ae81e0b754eb319cadbd5079ba..fdf13bb18c2567d2994612d15119ae87cbfa9137 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -254,6 +254,7 @@ tf_cc_test( ":types", ":util", ":xla_data_proto", + "//tensorflow/core:lib", "//tensorflow/core:test_main", ], ) @@ -281,9 +282,9 @@ tf_cc_test( ) cc_library( - name = "literal_util", - srcs = ["literal_util.cc"], - hdrs = ["literal_util.h"], + name = "literal", + srcs = ["literal.cc"], + hdrs = ["literal.h"], visibility = ["//visibility:public"], deps = [ ":array2d", @@ -300,11 +301,12 @@ cc_library( ) tf_cc_test( - name = "literal_util_test", - srcs = ["literal_util_test.cc"], + name = "literal_test", + srcs = ["literal_test.cc"], deps = [ ":array3d", ":array4d", + ":literal", ":literal_util", ":shape_util", ":test", @@ -316,6 +318,26 @@ tf_cc_test( ], ) +cc_library( + name = "literal_util", + srcs = ["literal_util.cc"], + hdrs = ["literal_util.h"], + visibility = ["//visibility:public"], + deps = [ + ":array2d", + ":array3d", + ":array4d", + ":literal", + ":shape_util", + ":sparse_index_array", + ":status_macros", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:lib", + ], +) + cc_library( name = "error_spec", hdrs = ["error_spec.h"], @@ -327,6 +349,7 @@ cc_library( hdrs = ["literal_comparison.h"], deps = [ ":error_spec", + ":literal", ":literal_util", ":util", "//tensorflow/core:lib", @@ -458,7 +481,7 @@ cc_library( hdrs = ["packed_literal_reader.h"], visibility = [":internal"], deps = [ - ":literal_util", + ":literal", ":shape_util", ":status_macros", ":statusor", @@ -489,7 +512,7 @@ cc_library( hdrs = ["text_literal_reader.h"], visibility = [":internal"], deps = [ - ":literal_util", + ":literal", ":shape_util", ":status_macros", ":statusor", @@ -505,7 +528,7 @@ tf_cc_test( name = "text_literal_reader_test", srcs = ["text_literal_reader_test.cc"], deps = [ - ":literal_util", + ":literal", ":shape_util", ":test", ":text_literal_reader", @@ -522,7 +545,7 @@ cc_library( hdrs = ["text_literal_writer.h"], visibility = [":internal"], deps = [ - ":literal_util", + ":literal", ":shape_util", ":status_macros", ":types", @@ -535,6 +558,7 @@ tf_cc_test( name = "text_literal_writer_test", srcs = ["text_literal_writer_test.cc"], deps = [ + ":literal", ":literal_util", ":test", ":test_helpers", @@ -607,11 +631,12 @@ cc_library( ":array2d", ":array3d", ":array4d", + ":literal_util", ":util", ":window_util", ":xla_data_proto", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_evaluator", "//tensorflow/compiler/xla/service:shape_inference", @@ -627,7 +652,7 @@ tf_cc_test( ":array2d", ":array3d", ":array4d", - ":literal_util", + ":literal", ":reference_util", ":test", ":util", diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index ea75ad32d5df7bbadd37e89de6144b264ab6d5d1..2d5d078aa77423cc18bab053b80a7576acbd849e 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -409,7 +409,7 @@ class Array { // Returns the total number of elements in the array. int64 num_elements() const { - return std::accumulate(sizes_.begin(), sizes_.end(), 1, + return std::accumulate(sizes_.begin(), sizes_.end(), 1LL, std::multiplies()); } diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 8f08d3b2e04670ad6590aca1db0fd9d25faed83f..ad3fcee05b80181369bfdf3cdcdb5452ec9e7e89 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -64,8 +64,9 @@ cc_library( hdrs = ["client.h"], deps = [ ":global_data", + ":xla_computation", "//tensorflow/compiler/xla:execution_options_util", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:service_interface", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -73,7 +74,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", @@ -100,12 +100,12 @@ cc_library( deps = [ ":client", ":executable_build_options", + ":xla_computation", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", @@ -114,6 +114,7 @@ cc_library( "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:source_map_util", + "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@llvm//:support", @@ -126,11 +127,11 @@ cc_library( hdrs = ["compile_only_client.h"], deps = [ ":client", + ":xla_computation", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:compile_only_service", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/core:stream_executor_no_cuda", @@ -174,3 +175,60 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", ], ) + +cc_library( + name = "xla_computation", + srcs = ["xla_computation.cc"], + hdrs = ["xla_computation.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", + ], +) + +cc_library( + name = "xla_builder", + srcs = ["xla_builder.cc"], + hdrs = ["xla_builder.h"], + visibility = ["//visibility:public"], + deps = [ + ":padding", + ":sharding_builder", + ":xla_computation", + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/service:shape_inference", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "xla_builder_test", + srcs = ["xla_builder_test.cc"], + deps = [ + ":xla_builder", + ":xla_computation", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 3d596a6e65430b6e9692aabd65fc8aa84b7b873d..d0ce5e8a6afa262d4cffdfe8431aab570ffd28df 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -18,9 +18,10 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -409,8 +410,10 @@ StatusOr Client::ExecutionStatsAsString( return string("[Execution Statistics] not available."); } -StatusOr Client::CreateChannelHandle() { +StatusOr Client::CreateChannelHandleByType( + ChannelHandle::ChannelType type) { CreateChannelHandleRequest request; + request.set_channel_type(type); CreateChannelHandleResponse response; VLOG(1) << "making create channel handle request"; @@ -424,4 +427,16 @@ StatusOr Client::CreateChannelHandle() { return response.channel(); } +StatusOr Client::CreateChannelHandle() { + return CreateChannelHandleByType(ChannelHandle::DEVICE_TO_DEVICE); +} + +StatusOr Client::CreateHostToDeviceChannelHandle() { + return CreateChannelHandleByType(ChannelHandle::HOST_TO_DEVICE); +} + +StatusOr Client::CreateDeviceToHostChannelHandle() { + return CreateChannelHandleByType(ChannelHandle::DEVICE_TO_HOST); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index 68f0d0ac78c859fde7a6a007cd250b047a7bfcda..be50cebfcc0e3c19002635dbd280b14048aa0c93 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -20,8 +20,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/statusor.h" @@ -178,10 +178,15 @@ class Client { StatusOr> GetComputationShape( const XlaComputation& computation); - // Creates a channel handle that can be used to transfer data between - // two computations via a pair of Send and Recv instructions. + // Creates a channel handle that can be used to transfer data between two + // computations on different devices via a pair of Send and Recv instructions. StatusOr CreateChannelHandle(); + // Create a channel for communicating with the host via a SendtoHost or + // RecvFromHost operation. + StatusOr CreateHostToDeviceChannelHandle(); + StatusOr CreateDeviceToHostChannelHandle(); + StatusOr LoadSnapshot(const HloSnapshot& module); ServiceInterface* stub() { return stub_; } @@ -192,6 +197,9 @@ class Client { StatusOr ExecutionStatsAsString(const XlaComputation& computation, const ExecutionProfile& profile); + StatusOr CreateChannelHandleByType( + ChannelHandle::ChannelType type); + ServiceInterface* stub_; // Stub that this client is connected on. TF_DISALLOW_COPY_AND_ASSIGN(Client); diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h index 332c96503637344d56e363e19db4880c37ca9684..a551edeab0943ec5213c5cb035644c02c3cf54d7 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.h +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/compile_only_service.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index d49d959a6c8112d3701857a70cecb24701c7b6d9..a2f32ab97eab10294a607f35fc79ded1cc2c5792 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -13,28 +13,191 @@ filegroup( ]), ) +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") + +# Generate test_suites for all backends, named "${backend}_tests". +generate_backend_suites() + cc_library( name = "arithmetic", srcs = ["arithmetic.cc"], hdrs = ["arithmetic.h"], deps = [ + ":constants", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", ], ) +cc_library( + name = "constants", + srcs = ["constants.cc"], + hdrs = ["constants.h"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + ], +) + +xla_test( + name = "constants_test", + srcs = ["constants_test.cc"], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":constants", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "math", + srcs = ["math.cc"], + hdrs = ["math.h"], + deps = [ + ":constants", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/client:xla_builder", + ], +) + +xla_test( + name = "math_test", + srcs = ["math_test.cc"], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":math", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "numeric", + srcs = ["numeric.cc"], + hdrs = ["numeric.h"], + deps = [ + ":arithmetic", + ":constants", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "numeric_test", + srcs = ["numeric_test.cc"], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":numeric", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "pooling", + srcs = ["pooling.cc"], + hdrs = ["pooling.h"], + deps = [ + ":arithmetic", + ":constants", + "//tensorflow/compiler/tf2xla/lib:util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "pooling_test", + srcs = ["pooling_test.cc"], + deps = [ + ":pooling", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "prng", + srcs = ["prng.cc"], + hdrs = ["prng.h"], + deps = [ + ":constants", + ":math", + ":numeric", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "sorting", + srcs = ["sorting.cc"], + hdrs = ["sorting.h"], + deps = [ + ":numeric", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + ], +) + +xla_test( + name = "sorting_test", + srcs = ["sorting_test.cc"], + blacklisted_backends = [ + "cpu", + "gpu", + ], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":sorting", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "testing", srcs = ["testing.cc"], hdrs = ["testing.h"], deps = [ "//tensorflow/compiler/xla:execution_options_util", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -42,8 +205,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", ], diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 8c314fa61bbd67774c91e7e34e93730dbe77eb8d..9225b1acd69c214d6f08a45372a8082ed789c18c 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -17,8 +17,9 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -93,16 +94,18 @@ XlaComputation CreateScalarMinComputation(PrimitiveType type, }); } -XlaComputation CreateScalarAndComputation(XlaBuilder* builder) { +XlaComputation CreateScalarAndComputation(PrimitiveType type, + XlaBuilder* builder) { return CreateScalarComputation( - "and", PRED, builder, + "and", type, builder, [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { return And(lhs, rhs); }); } -XlaComputation CreateScalarOrComputation(XlaBuilder* builder) { - return CreateScalarComputation("or", PRED, builder, +XlaComputation CreateScalarOrComputation(PrimitiveType type, + XlaBuilder* builder) { + return CreateScalarComputation("or", type, builder, [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { return Or(lhs, rhs); }); } @@ -111,7 +114,7 @@ XlaOp Any(XlaOp predicates) { XlaBuilder* builder = predicates.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { auto f = ConstantR0(builder, false); - XlaComputation logical_or = CreateScalarOrComputation(builder); + XlaComputation logical_or = CreateScalarOrComputation(PRED, builder); TF_ASSIGN_OR_RETURN(const Shape& predicates_shape, builder->GetShape(predicates)); std::vector all_dimensions(ShapeUtil::Rank(predicates_shape)); @@ -120,134 +123,4 @@ XlaOp Any(XlaOp predicates) { }); } -namespace { -XlaOp FloatLiteral(XlaBuilder* b, PrimitiveType data_type, float value) { - return ConvertElementType(ConstantR0(b, value), data_type); -} - -// Polynomials for computing erf/erfc. Originally from cephes. -// Note we use float for compatibility across devices, at the cost of some -// precision for 64 bit computations. -// -// Coefficients are in descending order. -std::array kErfcPCoefficient = { - 2.46196981473530512524E-10, 5.64189564831068821977E-1, - 7.46321056442269912687E0, 4.86371970985681366614E1, - 1.96520832956077098242E2, 5.26445194995477358631E2, - 9.34528527171957607540E2, 1.02755188689515710272E3, - 5.57535335369399327526E2}; -std::array kErfcQCoefficient = { - 1.00000000000000000000E0, 1.32281951154744992508E1, - 8.67072140885989742329E1, 3.54937778887819891062E2, - 9.75708501743205489753E2, 1.82390916687909736289E3, - 2.24633760818710981792E3, 1.65666309194161350182E3, - 5.57535340817727675546E2}; -std::array kErfcRCoefficient = { - 5.64189583547755073984E-1, 1.27536670759978104416E0, - 5.01905042251180477414E0, 6.16021097993053585195E0, - 7.40974269950448939160E0, 2.97886665372100240670E0}; -std::array kErfcSCoefficient = { - 1.00000000000000000000E0, 2.26052863220117276590E0, - 9.39603524938001434673E0, 1.20489539808096656605E1, - 1.70814450747565897222E1, 9.60896809063285878198E0, - 3.36907645100081516050E0}; -std::array kErfTCoefficient = { - 9.60497373987051638749E0, 9.00260197203842689217E1, - 2.23200534594684319226E3, 7.00332514112805075473E3, - 5.55923013010394962768E4}; -std::array kErfUCoefficient = { - 1.00000000000000000000E0, 3.35617141647503099647E1, - 5.21357949780152679795E2, 4.59432382970980127987E3, - 2.26290000613890934246E4, 4.92673942608635921086E4}; -} // namespace - -// Evaluate the polynomial given coefficients and `x`. -// N.B. Coefficients should be supplied in decreasing order. -XlaOp EvaluatePolynomial(XlaOp x, - tensorflow::gtl::ArraySlice coefficients, - PrimitiveType data_type) { - XlaBuilder* b = x.builder(); - XlaOp poly = FloatLiteral(b, data_type, 0.0); - for (float c : coefficients) { - poly = Add(Mul(poly, x), FloatLiteral(b, data_type, c)); - } - return poly; -} - -// Compute an approximation of the error function complement (1 - erf(x)). -XlaOp Erfc(XlaOp x, PrimitiveType data_type) { - XlaBuilder* b = x.builder(); - XlaOp zero = FloatLiteral(b, data_type, 0.0); - XlaOp two = FloatLiteral(b, data_type, 2.0); - XlaOp eight = FloatLiteral(b, data_type, 8.0); - - XlaOp abs_x = Abs(x); - XlaOp z = Exp(Mul(Neg(x), x)); - - XlaOp pp = EvaluatePolynomial(abs_x, kErfcPCoefficient, data_type); - XlaOp pq = EvaluatePolynomial(abs_x, kErfcQCoefficient, data_type); - XlaOp pr = EvaluatePolynomial(abs_x, kErfcRCoefficient, data_type); - XlaOp ps = EvaluatePolynomial(abs_x, kErfcSCoefficient, data_type); - - XlaOp y = Select(Lt(abs_x, eight), Div(Mul(z, pp), pq), Div(Mul(z, pr), ps)); - - return Select(Lt(x, zero), Sub(two, y), y); -} - -// Compute a polynomial approximation of the error function. -XlaOp Erf(XlaOp x, PrimitiveType data_type) { - XlaOp z = Mul(x, x); - XlaOp pt = EvaluatePolynomial(z, kErfTCoefficient, data_type); - XlaOp pu = EvaluatePolynomial(z, kErfUCoefficient, data_type); - return Div(Mul(x, pt), pu); -} - -// Approximation for the inverse error function from -// Giles, M., "Approximating the erfinv function". -// The approximation has the form: -// w = -log((1 - x) * (1 + x)) -// if ( w < 5 ) { -// w = w - 2.5 -// p = sum_{i=1}^n lq[i]*w^i -// } else { -// w = sqrt(w) - 3 -// p = sum_{i=1}^n gq[i]*w^i -// } -// return p*x -XlaOp ErfInv(XlaOp x) { - XlaBuilder* b = x.builder(); - return b->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x)); - constexpr int kDegree = 9; - constexpr std::array w_less_than_5_constants = { - 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, - -4.39150654e-06f, 0.00021858087f, -0.00125372503f, - -0.00417768164f, 0.246640727f, 1.50140941f}; - constexpr std::array w_greater_than_5_constants = { - -0.000200214257f, 0.000100950558f, 0.00134934322f, - -0.00367342844f, 0.00573950773f, -0.0076224613f, - 0.00943887047f, 1.00167406f, 2.83297682f}; - - auto one = ConstantR0(b, 1.0); - auto w = Neg(Log(Mul(Sub(one, x), Add(one, x)))); - - auto lt = Lt(w, ConstantR0(b, 5.0)); - auto coefficient = [&](int i) { - return Select( - lt, - Broadcast(ConstantR0(b, w_less_than_5_constants[i]), - AsInt64Slice(shape.dimensions())), - Broadcast(ConstantR0(b, w_greater_than_5_constants[i]), - AsInt64Slice(shape.dimensions()))); - }; - w = Select(lt, Sub(w, ConstantR0(b, 2.5f)), - Sub(SqrtF32(w), ConstantR0(b, 3.0f))); - auto p = coefficient(0); - for (int i = 1; i < kDegree; ++i) { - p = Add(coefficient(i), Mul(p, w)); - } - return Mul(p, x); - }); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index d0e04bbb5eb5365ab3f45dcaf4d8c389d2e77fa1..632e8cc8bc64fad236a0226c6e93079aadde7050 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -45,31 +45,18 @@ XlaComputation CreateScalarMinComputation(PrimitiveType type, XlaBuilder* builder); // Creates a scalar logical AND computation and returns it. -XlaComputation CreateScalarAndComputation(XlaBuilder* builder); +XlaComputation CreateScalarAndComputation(PrimitiveType type, + XlaBuilder* builder); // Creates a scalar logical OR computation and returns it. -XlaComputation CreateScalarOrComputation(XlaBuilder* builder); +XlaComputation CreateScalarOrComputation(PrimitiveType type, + XlaBuilder* builder); // Returns whether any predicate in "predicates" is set. // // Note: if predicates is zero-sized, Any() vacuously returns false. XlaOp Any(XlaOp predicates); -// Evaluate the polynomial given coefficients and `x`. -// N.B. Coefficients should be supplied in decreasing order. -XlaOp EvaluatePolynomial(XlaOp x, - tensorflow::gtl::ArraySlice coefficients, - PrimitiveType data_type); - -// Compute an approximation of the error function complement (1 - erf(x)). -XlaOp Erfc(XlaOp x, PrimitiveType data_type); - -// Compute an approximation of the error function. -XlaOp Erf(XlaOp x, PrimitiveType data_type); - -// Compute an approximation of the inverse of the error function. -XlaOp ErfInv(XlaOp x); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_ diff --git a/tensorflow/compiler/xla/client/lib/constants.cc b/tensorflow/compiler/xla/client/lib/constants.cc new file mode 100644 index 0000000000000000000000000000000000000000..031d62e4ffef188082303a28866bbc72a154e9b1 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/constants.cc @@ -0,0 +1,103 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/constants.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { + +XlaOp Zero(XlaBuilder* builder, PrimitiveType type) { + return ConstantLiteral(builder, LiteralUtil::Zero(type)); +} + +XlaOp Zeros(XlaBuilder* builder, const Shape& shape) { + return Broadcast(Zero(builder, shape.element_type()), + AsInt64Slice(shape.dimensions())); +} + +XlaOp ZerosLike(XlaOp prototype) { + XlaBuilder* builder = prototype.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); + return Zeros(builder, shape); + }); +} + +XlaOp One(XlaBuilder* builder, PrimitiveType type) { + return ConstantLiteral(builder, LiteralUtil::One(type)); +} + +XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) { + switch (type) { + case F16: + return ConstantR0( + builder, + static_cast(Eigen::NumTraits::epsilon())); + case BF16: + return ConstantR0(builder, bfloat16::epsilon()); + case F32: + return ConstantR0(builder, std::numeric_limits::epsilon()); + case F64: + return ConstantR0(builder, + std::numeric_limits::epsilon()); + default: + return builder->ReportError(InvalidArgument( + "Invalid type for Epsilon (%s).", PrimitiveType_Name(type).c_str())); + } +} + +XlaOp MinValue(XlaBuilder* builder, PrimitiveType type) { + return ConstantLiteral(builder, LiteralUtil::MinValue(type)); +} + +XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) { + switch (type) { + case F16: + return ConstantR0(builder, + Eigen::NumTraits::lowest()); + case BF16: + return ConstantR0(builder, bfloat16::lowest()); + case F32: + return ConstantR0(builder, -std::numeric_limits::max()); + case F64: + return ConstantR0(builder, -std::numeric_limits::max()); + default: + return MinValue(builder, type); + } +} + +XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type) { + return ConstantLiteral(builder, LiteralUtil::MaxValue(type)); +} + +XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) { + switch (type) { + case F16: + return ConstantR0(builder, + Eigen::NumTraits::highest()); + case BF16: + return ConstantR0(builder, bfloat16::highest()); + case F32: + return ConstantR0(builder, std::numeric_limits::max()); + case F64: + return ConstantR0(builder, std::numeric_limits::max()); + default: + return MaxValue(builder, type); + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h new file mode 100644 index 0000000000000000000000000000000000000000..0c8a9b8cc02ba0c1ebdf6a060d4b99262dceb178 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/constants.h @@ -0,0 +1,124 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_ + +#include + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Returns scalar 'value' as a scalar of 'type'. Unlike ConstantR0, 'type' is +// determined at C++ run-time, rather than C++ compile-time. +// If 'value' is floating point but 'type' is not, or if 'value' is complex but +// 'type' is not, an error will be returned. This is to catch accidental +// truncation; in such cases, use an explicit cast. +template +XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { + if (std::is_floating_point::value && + !(primitive_util::IsFloatingPointType(type) || + primitive_util::IsComplexType(type))) { + return builder->ReportError(InvalidArgument( + "Invalid cast from floating point type to %s in ConstantR0WithType.", + PrimitiveType_Name(type).c_str())); + } + if (std::is_same::value && + !primitive_util::IsComplexType(type)) { + return builder->ReportError(InvalidArgument( + "Invalid cast from complex type to %s in ConstantR0WithType.", + PrimitiveType_Name(type).c_str())); + } + switch (type) { + case F16: + return ConstantR0(builder, static_cast(value)); + case BF16: + return ConstantR0(builder, static_cast(value)); + case F32: + return ConstantR0(builder, static_cast(value)); + case F64: + return ConstantR0(builder, static_cast(value)); + case C64: + return ConstantR0(builder, static_cast(value)); + case U8: + return ConstantR0(builder, static_cast(value)); + case U32: + return ConstantR0(builder, static_cast(value)); + case U64: + return ConstantR0(builder, static_cast(value)); + case S8: + return ConstantR0(builder, static_cast(value)); + case S32: + return ConstantR0(builder, static_cast(value)); + case S64: + return ConstantR0(builder, static_cast(value)); + default: + return builder->ReportError( + InvalidArgument("Invalid type for ConstantR0WithType (%s).", + PrimitiveType_Name(type).c_str())); + } +} + +// Returns a scalar containing 'value' cast to the same run-time type as +// 'prototype'. +// If 'value' is floating point but 'prototype' is not, or if 'value' is complex +// 'prototype' is not, an error will be returned. +template +XlaOp ScalarLike(XlaOp prototype, T value) { + XlaBuilder* builder = prototype.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); + return ConstantR0WithType(builder, shape.element_type(), value); + }); +} + +// Returns a scalar with value '0' of 'type'. +XlaOp Zero(XlaBuilder* builder, PrimitiveType type); + +// Returns a zero-filled tensor with shape `shape`. +XlaOp Zeros(XlaBuilder* builder, const Shape& shape); + +// Returns a zero-filled tensor with the same shape as `prototype`. +XlaOp ZerosLike(XlaOp prototype); + +// Returns a scalar with value '1' of 'type'. +XlaOp One(XlaBuilder* builder, PrimitiveType type); + +// Returns the machine epsilon for floating-point type `type`, i.e., +// the difference between 1.0 and the next representable value. +XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type); + +// Returns the minimum representable finite or infinite value for 'type'. +// Returns '-inf' for floating-point types. +XlaOp MinValue(XlaBuilder* builder, PrimitiveType type); + +// Returns the minimum representable finite value for 'type'. For a floating +// point type, this is equal to -MaxFiniteValue(). +XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type); + +// Returns the maximum representable finite or infinite value for 'type'. +// Returns 'inf' for floating-point types. +XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type); + +// Returns the maximum representable finite value for 'type'. +XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_ diff --git a/tensorflow/compiler/xla/client/lib/constants_test.cc b/tensorflow/compiler/xla/client/lib/constants_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f4320f65c1f76d4d4c384110b39d6606773aaf01 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/constants_test.cc @@ -0,0 +1,159 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +using ConstantsTest = ClientLibraryTestBase; + +using ::testing::HasSubstr; + +XLA_TEST_F(ConstantsTest, ConstantR0WithTypeS32) { + XlaBuilder builder(TestName()); + ConstantR0WithType(&builder, xla::S32, 4); + ComputeAndCompareR0(&builder, 4, {}); +} + +XLA_TEST_F(ConstantsTest, ConstantR0WithTypeS32DoesNotAcceptFloats) { + XlaBuilder builder(TestName()); + ConstantR0WithType(&builder, xla::S32, 4.5); + auto statusor = builder.Build(); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), HasSubstr("Invalid cast")); +} + +XLA_TEST_F(ConstantsTest, ConstantR0WithTypeF32) { + XlaBuilder builder(TestName()); + ConstantR0WithType(&builder, xla::F32, -7); + ComputeAndCompareR0(&builder, -7, {}); + ConstantR0WithType(&builder, xla::F32, 0.5); + ComputeAndCompareR0(&builder, 0.5, {}); +} + +XLA_TEST_F(ConstantsTest, ScalarLikeS32) { + XlaBuilder builder(TestName()); + ScalarLike(ConstantR0(&builder, 42), -3); + ComputeAndCompareR0(&builder, -3, {}); +} + +XLA_TEST_F(ConstantsTest, ScalarLikeF32) { + XlaBuilder builder(TestName()); + ScalarLike(ConstantR0(&builder, 42.75), -3.2); + ComputeAndCompareR0(&builder, -3.2, {}); +} + +XLA_TEST_F(ConstantsTest, ZeroS32) { + XlaBuilder builder(TestName()); + Zero(&builder, S32); + ComputeAndCompareR0(&builder, 0, {}); +} + +XLA_TEST_F(ConstantsTest, ZeroF32) { + XlaBuilder builder(TestName()); + Zero(&builder, F32); + ComputeAndCompareR0(&builder, 0.0, {}); +} + +XLA_TEST_F(ConstantsTest, ZerosS32) { + XlaBuilder builder(TestName()); + Zeros(&builder, ShapeUtil::MakeShape(S32, {2, 2})); + ComputeAndCompareR2(&builder, {{0, 0}, {0, 0}}, {}); +} + +XLA_TEST_F(ConstantsTest, ZerosLikeF32) { + XlaBuilder builder(TestName()); + ZerosLike(ConstantR1(&builder, {1., 2., 3.})); + ComputeAndCompareR1(&builder, {0., 0., 0.}, {}); +} + +XLA_TEST_F(ConstantsTest, OneS32) { + XlaBuilder builder(TestName()); + One(&builder, S32); + ComputeAndCompareR0(&builder, 1, {}); +} + +XLA_TEST_F(ConstantsTest, OneF32) { + XlaBuilder builder(TestName()); + One(&builder, F32); + ComputeAndCompareR0(&builder, 1., {}); +} + +XLA_TEST_F(ConstantsTest, EpsilonF32) { + XlaBuilder builder(TestName()); + Epsilon(&builder, F32); + ComputeAndCompareR0(&builder, std::numeric_limits::epsilon(), + {}); +} + +XLA_TEST_F(ConstantsTest, MinFiniteValueS32) { + XlaBuilder builder(TestName()); + MinFiniteValue(&builder, S32); + ComputeAndCompareR0(&builder, std::numeric_limits::min(), {}); +} + +XLA_TEST_F(ConstantsTest, MaxFiniteValueS32) { + XlaBuilder builder(TestName()); + MaxFiniteValue(&builder, S32); + ComputeAndCompareR0(&builder, std::numeric_limits::max(), {}); +} + +XLA_TEST_F(ConstantsTest, MinFiniteValueF32) { + XlaBuilder builder(TestName()); + MinFiniteValue(&builder, F32); + ComputeAndCompareR0(&builder, -std::numeric_limits::max(), {}); +} + +XLA_TEST_F(ConstantsTest, MaxFiniteValueF32) { + XlaBuilder builder(TestName()); + MaxFiniteValue(&builder, F32); + ComputeAndCompareR0(&builder, std::numeric_limits::max(), {}); +} + +XLA_TEST_F(ConstantsTest, MinValueS32) { + XlaBuilder builder(TestName()); + MinValue(&builder, S32); + ComputeAndCompareR0(&builder, std::numeric_limits::min(), {}); +} + +XLA_TEST_F(ConstantsTest, MaxValueS32) { + XlaBuilder builder(TestName()); + MaxValue(&builder, S32); + ComputeAndCompareR0(&builder, std::numeric_limits::max(), {}); +} + +XLA_TEST_F(ConstantsTest, MinValueF32) { + XlaBuilder builder(TestName()); + MinValue(&builder, F32); + ComputeAndCompareR0(&builder, -std::numeric_limits::infinity(), + {}); +} + +XLA_TEST_F(ConstantsTest, MaxValueF32) { + XlaBuilder builder(TestName()); + MaxValue(&builder, F32); + ComputeAndCompareR0(&builder, std::numeric_limits::infinity(), + {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc new file mode 100644 index 0000000000000000000000000000000000000000..0221de7672c7b7c02b1f8b9c7ff4f92151e567c6 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -0,0 +1,304 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/math.h" + +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { + +XlaOp Sqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, 0.5)); } + +XlaOp Rsqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, -0.5)); } + +XlaOp Square(XlaOp operand) { return operand * operand; } + +XlaOp Reciprocal(XlaOp operand) { return ScalarLike(operand, 1.0) / operand; } + +namespace { + +// Polynomials for computing erf/erfc. Originally from cephes. +// Note we use float for compatibility across devices, at the cost of some +// precision for 64 bit computations. +// +// Coefficients are in descending order. +std::array kErfcPCoefficient = { + 2.46196981473530512524E-10, 5.64189564831068821977E-1, + 7.46321056442269912687E0, 4.86371970985681366614E1, + 1.96520832956077098242E2, 5.26445194995477358631E2, + 9.34528527171957607540E2, 1.02755188689515710272E3, + 5.57535335369399327526E2}; +std::array kErfcQCoefficient = { + 1.00000000000000000000E0, 1.32281951154744992508E1, + 8.67072140885989742329E1, 3.54937778887819891062E2, + 9.75708501743205489753E2, 1.82390916687909736289E3, + 2.24633760818710981792E3, 1.65666309194161350182E3, + 5.57535340817727675546E2}; +std::array kErfcRCoefficient = { + 5.64189583547755073984E-1, 1.27536670759978104416E0, + 5.01905042251180477414E0, 6.16021097993053585195E0, + 7.40974269950448939160E0, 2.97886665372100240670E0}; +std::array kErfcSCoefficient = { + 1.00000000000000000000E0, 2.26052863220117276590E0, + 9.39603524938001434673E0, 1.20489539808096656605E1, + 1.70814450747565897222E1, 9.60896809063285878198E0, + 3.36907645100081516050E0}; +std::array kErfTCoefficient = { + 9.60497373987051638749E0, 9.00260197203842689217E1, + 2.23200534594684319226E3, 7.00332514112805075473E3, + 5.55923013010394962768E4}; +std::array kErfUCoefficient = { + 1.00000000000000000000E0, 3.35617141647503099647E1, + 5.21357949780152679795E2, 4.59432382970980127987E3, + 2.26290000613890934246E4, 4.92673942608635921086E4}; +} // namespace + +// Evaluate the polynomial given coefficients and `x`. +// N.B. Coefficients should be supplied in decreasing order. +XlaOp EvaluatePolynomial(XlaOp x, + tensorflow::gtl::ArraySlice coefficients) { + XlaOp poly = ScalarLike(x, 0.0); + for (float c : coefficients) { + poly = poly * x + ScalarLike(x, c); + } + return poly; +} + +// Compute an approximation of the error function complement (1 - erf(x)). +XlaOp Erfc(XlaOp x) { + XlaOp abs_x = Abs(x); + XlaOp z = Exp(-x * x); + + XlaOp pp = EvaluatePolynomial(abs_x, kErfcPCoefficient); + XlaOp pq = EvaluatePolynomial(abs_x, kErfcQCoefficient); + XlaOp pr = EvaluatePolynomial(abs_x, kErfcRCoefficient); + XlaOp ps = EvaluatePolynomial(abs_x, kErfcSCoefficient); + + XlaOp y = Select(Lt(abs_x, ScalarLike(x, 8.0)), z * pp / pq, z * pr / ps); + + return Select(Lt(x, ScalarLike(x, 0.0)), ScalarLike(x, 2.0) - y, y); +} + +// Compute a polynomial approximation of the error function. +XlaOp Erf(XlaOp x) { + XlaOp z = x * x; + XlaOp pt = EvaluatePolynomial(z, kErfTCoefficient); + XlaOp pu = EvaluatePolynomial(z, kErfUCoefficient); + return x * pt / pu; +} + +// Approximation for the inverse error function from +// Giles, M., "Approximating the erfinv function". +// The approximation has the form: +// w = -log((1 - x) * (1 + x)) +// if ( w < 5 ) { +// w = w - 2.5 +// p = sum_{i=1}^n lq[i]*w^i +// } else { +// w = sqrt(w) - 3 +// p = sum_{i=1}^n gq[i]*w^i +// } +// return p*x +XlaOp ErfInv(XlaOp x) { + XlaBuilder* b = x.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x)); + constexpr int kDegree = 9; + constexpr std::array w_less_than_5_constants = { + 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, + -4.39150654e-06f, 0.00021858087f, -0.00125372503f, + -0.00417768164f, 0.246640727f, 1.50140941f}; + constexpr std::array w_greater_than_5_constants = { + -0.000200214257f, 0.000100950558f, 0.00134934322f, + -0.00367342844f, 0.00573950773f, -0.0076224613f, + 0.00943887047f, 1.00167406f, 2.83297682f}; + + auto one = ScalarLike(x, 1.0); + auto w = -Log((one - x) * (one + x)); + + auto lt = Lt(w, ScalarLike(x, 5.0)); + auto coefficient = [&](int i) { + return Select(lt, + Broadcast(ScalarLike(x, w_less_than_5_constants[i]), + AsInt64Slice(shape.dimensions())), + Broadcast(ScalarLike(x, w_greater_than_5_constants[i]), + AsInt64Slice(shape.dimensions()))); + }; + w = Select(lt, w - ScalarLike(x, 2.5), Sqrt(w) - ScalarLike(x, 3.0)); + auto p = coefficient(0); + for (int i = 1; i < kDegree; ++i) { + p = coefficient(i) + p * w; + } + return p * x; + }); +} + +namespace { +// Coefficients for the Lanczos approximation of the gamma function. The +// coefficients are uniquely determined by the choice of g and n (kLanczosGamma +// and kLanczosCoefficients.size() + 1). The coefficients below correspond to +// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and [7, +// 9] seemed to be the least sensitive to the quality of the log function. In +// particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5 +// for a particularly inaccurate log function. +static constexpr double kLanczosGamma = 7; // aka g +static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478; +static constexpr std::array kLanczosCoefficients = { + 676.520368121885098567009190444019, -1259.13921672240287047156078755283, + 771.3234287776530788486528258894, -176.61502916214059906584551354, + 12.507343278686904814458936853, -0.13857109526572011689554707, + 9.984369578019570859563e-6, 1.50563273514931155834e-7}; +} // namespace + +// Compute the Lgamma function using Lanczos' approximation from "A Precision +// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis +// series B. Vol. 1: +// lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + A(z) +// t(z) = z + kLanczosGamma + 1/2 +// A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k)) +XlaOp Lgamma(XlaOp input) { + XlaOp one_half = ScalarLike(input, 0.5); + XlaOp one = ScalarLike(input, 1); + + XlaOp pi = ScalarLike(input, M_PI); + XlaOp log_pi = ScalarLike(input, std::log(M_PI)); + XlaOp log_sqrt_two_pi = ScalarLike(input, (std::log(2) + std::log(M_PI)) / 2); + + XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5); + XlaOp log_lanczos_gamma_plus_one_half = + ScalarLike(input, std::log(kLanczosGamma + 0.5)); + + XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff); + + // If the input is less than 0.5 use Gauss's reflection formula: + // gamma(x) = pi / sin(pi * x) * gamma(1 - x) + XlaOp need_to_reflect = Lt(Real(input), one_half); + XlaOp z = Select(need_to_reflect, -input, input - one); + + XlaOp x = base_lanczos_coeff; + for (int i = 0; i < kLanczosCoefficients.size(); ++i) { + XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); + XlaOp index = ScalarLike(input, i); + x = x + lanczos_coefficient / (z + index + one); + } + + // To improve accuracy on platforms with less-precise log implementations, + // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on + // the device. + // log(t) = log(kLanczosGamma + 0.5 + z) + // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) + XlaOp t = lanczos_gamma_plus_one_half + z; + XlaOp log_t = + log_lanczos_gamma_plus_one_half + Log1p(z / lanczos_gamma_plus_one_half); + + XlaOp log_y = log_sqrt_two_pi + (z + one_half) * log_t - t + Log(x); + + XlaOp reflection = log_pi - Log(Sin(pi * input)) - log_y; + XlaOp result = Select(need_to_reflect, reflection, log_y); + return result; +} + +// Compute the Digamma function using Lanczos' approximation from "A Precision +// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis +// series B. Vol. 1: +// digamma(z + 1) = log(t(z)) + A'(z) / A(z) - kLanczosGamma / t(z) +// t(z) = z + kLanczosGamma + 1/2 +// A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[i] / (z + k)) +// A'(z) = sigma(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k)) +XlaOp Digamma(XlaOp input) { + XlaOp zero = ScalarLike(input, 0); + XlaOp one_half = ScalarLike(input, 0.5); + XlaOp one = ScalarLike(input, 1); + + XlaOp pi = ScalarLike(input, M_PI); + + XlaOp lanczos_gamma = ScalarLike(input, kLanczosGamma); + XlaOp lanczos_gamma_plus_one_half = ScalarLike(input, kLanczosGamma + 0.5); + XlaOp log_lanczos_gamma_plus_one_half = + ScalarLike(input, std::log(kLanczosGamma + 0.5)); + + XlaOp base_lanczos_coeff = ScalarLike(input, kBaseLanczosCoeff); + + // If the input is less than 0.5 use Gauss's reflection formula: + // digamma(x) = digamma(1 - x) - pi * cot(pi * x) + XlaOp need_to_reflect = Lt(Real(input), one_half); + XlaOp z = Select(need_to_reflect, -input, input - one); + + XlaOp num = zero; + XlaOp denom = base_lanczos_coeff; + for (int i = 0; i < kLanczosCoefficients.size(); ++i) { + XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); + XlaOp index = ScalarLike(input, i); + num = num - lanczos_coefficient / ((z + index + one) * (z + index + one)); + denom = denom + lanczos_coefficient / (z + index + one); + } + + // To improve accuracy on platforms with less-precise log implementations, + // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on + // the device. + // log(t) = log(kLanczosGamma + 0.5 + z) + // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5)) + XlaOp t = lanczos_gamma_plus_one_half + z; + XlaOp log_t = + log_lanczos_gamma_plus_one_half + Log1p(z / lanczos_gamma_plus_one_half); + + XlaOp y = log_t + num / denom - lanczos_gamma / t; + XlaOp reflection = y - pi * Cos(pi * input) / Sin(pi * input); + XlaOp result = Select(need_to_reflect, reflection, y); + return result; +} + +// Trigonometric functions. + +// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) +XlaOp Acos(XlaOp x) { + return ScalarLike(x, 2.0) * + Atan2(Sqrt(ScalarLike(x, 1.0) - x * x), ScalarLike(x, 1.0) + x); +} + +// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) +XlaOp Asin(XlaOp x) { + return ScalarLike(x, 2.0) * + Atan2(x, ScalarLike(x, 1.0) + Sqrt(ScalarLike(x, 1.0) - x * x)); +} + +XlaOp Atan(XlaOp x) { return Atan2(x, ScalarLike(x, 1.0)); } + +XlaOp Tan(XlaOp x) { return Sin(x) / Cos(x); } + +// Hyperbolic trigonometric functions. + +// acosh(x) = log(x + sqrt(x^2 - 1)) +// = log(x + sqrt((x+1)*(x-1))) +XlaOp Acosh(XlaOp x) { + return Log(x + Sqrt((x + ScalarLike(x, 1.0)) * (x - ScalarLike(x, 1.0)))); +} + +// asinh(x) = log(x + sqrt(x^2 + 1)) +XlaOp Asinh(XlaOp x) { return Log(x + Sqrt(x * x + ScalarLike(x, 1.0))); } + +// atanh(x) = 0.5 * log((1 + x) / (1 - x)) +XlaOp Atanh(XlaOp x) { + return Log((ScalarLike(x, 1.0) + x) / (ScalarLike(x, 1.0) - x)) * + ScalarLike(x, 0.5); +} + +XlaOp Cosh(XlaOp x) { return (Exp(x) + Exp(-x)) * ScalarLike(x, 0.5); } + +XlaOp Sinh(XlaOp x) { return (Exp(x) - Exp(-x)) * ScalarLike(x, 0.5); } + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h new file mode 100644 index 0000000000000000000000000000000000000000..13db2325569cf2e25e3ff1200adf4b2544dc2f73 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" + +namespace xla { + +// Computes the square root of 'operand'. +XlaOp Sqrt(XlaOp operand); + +// Computes the reciprocal of the square root of 'operand'. +XlaOp Rsqrt(XlaOp operand); + +// Computes the square of 'operand'. +XlaOp Square(XlaOp operand); + +// Computes the reciprocal of 'operand'. +XlaOp Reciprocal(XlaOp operand); + +// Evaluates a polynomial given coefficients and `x`. +// N.B. Coefficients should be supplied in decreasing order. +XlaOp EvaluatePolynomial(XlaOp x, + tensorflow::gtl::ArraySlice coefficients); + +// Computes an approximation of the error function complement (1 - erf(x)). +XlaOp Erfc(XlaOp x); + +// Computes an approximation of the error function. +XlaOp Erf(XlaOp x); + +// Computes an approximation of the inverse of the error function. +XlaOp ErfInv(XlaOp x); + +// Computes an approximation of the lgamma function. +XlaOp Lgamma(XlaOp input); + +// Computes an approximation of the digamma function. +XlaOp Digamma(XlaOp input); + +// Trigonometric functions + +// Computes the arc cosine of 'x'. +XlaOp Acos(XlaOp x); + +// Computes the arc sine of 'x'. +XlaOp Asin(XlaOp x); + +// Computes the arc tangent of 'x'. +XlaOp Atan(XlaOp x); + +// Computes the tangent of 'x'. +XlaOp Tan(XlaOp x); + +// Hyperbolic trigonometric functions + +// Computes the inverse hyperbolic cosine of 'x'. +XlaOp Acosh(XlaOp x); + +// Computes the inverse hyperbolic sine of 'x'. +XlaOp Asinh(XlaOp x); + +// Computes the inverse hyperbolic tangent of 'x'. +XlaOp Atanh(XlaOp x); + +// Computes the hyperbolic cosine of 'x'. +XlaOp Cosh(XlaOp x); + +// Computes the hyperbolic sine of 'x'. +XlaOp Sinh(XlaOp x); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..14c259a7fa2a47642663b65d2785e5bbdc040cfd --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -0,0 +1,140 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class MathTest : public ClientLibraryTestBase { + public: + ErrorSpec error_spec_{0.0001}; +}; + +XLA_TEST_F(MathTest, SqrtF32) { + XlaBuilder builder(TestName()); + Literal zero_literal = LiteralUtil::Zero(PrimitiveType::F32); + + std::unique_ptr zero_data = + client_->TransferToServer(zero_literal).ConsumeValueOrDie(); + + XlaOp zero = Parameter(&builder, 0, zero_literal.shape(), "zero"); + Sqrt(zero); + + ComputeAndCompareR0(&builder, 0.0f, {zero_data.get()}, error_spec_); +} + +XLA_TEST_F(MathTest, SquareTenValues) { + XlaBuilder builder(TestName()); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Square(x); + + std::vector expected = {4.41, 6.76, 6.76, 16., 4.41, + 5.29, 25., 0.81, 5.76, 2.56}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(MathTest, ReciprocalTenValues) { + XlaBuilder builder(TestName()); + auto x = ConstantR1( + &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); + Reciprocal(x); + + std::vector expected = { + 0.47619048, -0.38461538, 0.38461538, -0.25, 0.47619048, + 0.43478261, -0.2, -1.11111111, -0.41666667, 0.625}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(MathTest, SqrtZeroes) { + XlaBuilder builder(TestName()); + auto x = ConstantR1(&builder, {0.0, -0.0}); + Sqrt(x); + + ComputeAndCompareR1(&builder, {0, 0}, {}, error_spec_); +} + +XLA_TEST_F(MathTest, SqrtSixValues) { + XlaBuilder builder(TestName()); + auto x = ConstantR1(&builder, {16.0, 1.0, 1024.0, 0.16, 0.2, 12345}); + Sqrt(x); + + std::vector expected = {4, 1, 32, 0.4, 0.4472, 111.1080}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(MathTest, Lgamma) { + XlaBuilder builder(TestName()); + auto x = ConstantR1(&builder, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.5, 1.5, + 2.5, -1.5, -3.5, -5.5}); + Lgamma(x); + + std::vector expected = { + 0, + 0, + static_cast(std::log(2)), + static_cast(std::log(6)), + static_cast(std::log(24)), + static_cast(std::log(120)), + static_cast(std::log(M_PI) / 2), + static_cast(std::log(M_PI) / 2 - std::log(2)), + static_cast(std::log(M_PI) / 2 - std::log(4) + std::log(3)), + static_cast(std::log(M_PI) / 2 - std::log(3) + std::log(4)), + static_cast(std::log(M_PI) / 2 - std::log(105) + std::log(16)), + static_cast(std::log(M_PI) / 2 - std::log(10395) + std::log(64))}; + error_spec_ = ErrorSpec{0.001}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(MathTest, Digamma) { + XlaBuilder builder(TestName()); + auto x = ConstantR1(&builder, {1.0, 0.5, 1 / 3.0, 0.25, 1 / 6.0, 0.125, + 2.0, 3.0, 4.0, 6.0, 8.0, 9.0}); + Digamma(x); + + constexpr double euler_mascheroni = + 0.57721566490153286060651209008240243104215933593992; + std::vector expected = { + static_cast(-euler_mascheroni), + static_cast(-2 * std::log(2) - euler_mascheroni), + static_cast(-M_PI / 2 / std::sqrt(3) - 3 * std::log(3) / 2 - + euler_mascheroni), + static_cast(-M_PI / 2 - 3 * std::log(2) - euler_mascheroni), + static_cast(-M_PI * std::sqrt(3) / 2 - 2 * std::log(2) - + 3 * std::log(3) / 2 - euler_mascheroni), + static_cast( + -M_PI / 2 - 4 * std::log(2) - + (M_PI + std::log(2 + std::sqrt(2)) - std::log(2 - std::sqrt(2))) / + std::sqrt(2) - + euler_mascheroni), + static_cast(1 - euler_mascheroni), + static_cast(1.5 - euler_mascheroni), + static_cast(11 / 6.0 - euler_mascheroni), + static_cast(137 / 60.0 - euler_mascheroni), + static_cast(363 / 140.0 - euler_mascheroni), + static_cast(761 / 280.0 - euler_mascheroni)}; + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc new file mode 100644 index 0000000000000000000000000000000000000000..1c91237ae1574f92cda78c9bddc6f4ac1d68f47c --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/numeric.cc @@ -0,0 +1,137 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +namespace { + +template +XlaOp MakeIota(XlaBuilder* builder, int64 size) { + std::vector values(size); + for (int64 i = 0; i < size; ++i) { + values[i] = static_cast(i); + } + return ConstantR1(builder, values); +} + +} // namespace + +XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) { + switch (type) { + case S8: + return MakeIota(builder, size); + case S16: + return MakeIota(builder, size); + case S32: + return MakeIota(builder, size); + case S64: + return MakeIota(builder, size); + case U8: + return MakeIota(builder, size); + case U16: + return MakeIota(builder, size); + case U32: + return MakeIota(builder, size); + case U64: + return MakeIota(builder, size); + case BF16: + return MakeIota(builder, size); + case F16: + return MakeIota(builder, size); + case F32: + return MakeIota(builder, size); + case F64: + return MakeIota(builder, size); + case C64: + return MakeIota(builder, size); + default: + return builder->ReportError( + InvalidArgument("Unimplemented type for Iota: %s.", + PrimitiveType_Name(type).c_str())); + } +} + +XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, + int64 n) { + auto a = Iota(builder, type, m); + auto b = Iota(builder, type, n); + auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0}); + return ConvertElementType(indicator, type); +} + +XlaOp GetMatrixDiagonal(XlaOp x) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + const int64 m = shape.dimensions(n_dims - 2); + const int64 n = shape.dimensions(n_dims - 1); + tensorflow::gtl::ArraySlice major_dims( + AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - 2); + auto a = Iota(builder, U32, n); + auto b = Iota(builder, U32, m); + auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + auto mask = Broadcast(indicator, major_dims); + + // TPUs don't support S64 add reduction at the moment. But fortunately + // OR-reductions work just as well for integers. + XlaComputation reducer = + primitive_util::IsIntegralType(shape.element_type()) + ? CreateScalarOrComputation(shape.element_type(), builder) + : CreateScalarAddComputation(shape.element_type(), builder); + + return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0), + reducer, {m >= n ? n_dims - 2 : n_dims - 1}); + }); +} + +XlaOp Triangle(XlaOp x, bool lower) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + const int64 m = shape.dimensions(n_dims - 2); + const int64 n = shape.dimensions(n_dims - 1); + tensorflow::gtl::ArraySlice major_dims( + AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - 2); + auto a = Iota(builder, U32, n); + auto b = Iota(builder, U32, m); + xla::XlaOp indicator; + if (lower) { + indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + } else { + indicator = Le(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + } + auto mask = Broadcast(indicator, major_dims); + + return Select(mask, x, Zeros(builder, shape)); + }); +} + +XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); } + +XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/numeric.h b/tensorflow/compiler/xla/client/lib/numeric.h new file mode 100644 index 0000000000000000000000000000000000000000..efd8cdc25724198633e0bf1c48c4e7d9e4b4c9e1 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/numeric.h @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Returns a rank 1 tensor of `type` containing values [0, 1, 2, ...]. +XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size); + +// Returns an m x n matrix with 1s on the diagonal elements, zeros everywhere +// else. +XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n); + +// Get the diagonals of the last two dimensions. If 'x' has shape +// [..., M, N], then the output has shape [..., min(M, N)], containing the +// diagonal elements (i.e., with indices [..., i, i]). +XlaOp GetMatrixDiagonal(XlaOp x); + +// Get the upper or lower triangle part of the last two dimensions +XlaOp Triangle(XlaOp x, bool lower); + +// Get the upper triangle part of the last two dimensions +XlaOp UpperTriangle(XlaOp x); + +// Get the lower triangle part of the last two dimensions +XlaOp LowerTriangle(XlaOp x); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/numeric_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8a96ec68d2dca8485215258b1f6731b934e6f2a8 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/numeric_test.cc @@ -0,0 +1,78 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class NumericTest : public ClientLibraryTestBase { + protected: + template + void TestMatrixDiagonal(); +}; + +// TODO(b/64798317): Delete this test case once xla::IotaGen is converted to +// xla::Iota. This test is already implemented for xla::IotaGen in +// xla/tests/iota_test.cc. +XLA_TEST_F(NumericTest, Iota) { + XlaBuilder builder(TestName()); + Iota(&builder, S32, 10); + + ComputeAndCompareR1(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {}); +} + +XLA_TEST_F(NumericTest, Triangle) { + XlaBuilder builder(TestName()); + Array3D input(2, 3, 4); + input.FillIota(0); + + XlaOp a; + auto a_data = CreateR3Parameter(input, 0, "a", &builder, &a); + LowerTriangle(a); + Array3D expected({{{0, 0, 0, 0}, {4, 5, 0, 0}, {8, 9, 10, 0}}, + {{12, 0, 0, 0}, {16, 17, 0, 0}, {20, 21, 22, 0}}}); + + ComputeAndCompareR3(&builder, expected, {a_data.get()}); +} + +template +void NumericTest::TestMatrixDiagonal() { + XlaBuilder builder("GetMatrixDiagonal"); + Array3D input(2, 3, 4); + input.FillIota(0); + + XlaOp a; + auto a_data = CreateR3Parameter(input, 0, "a", &builder, &a); + GetMatrixDiagonal(a); + Array2D expected({{0, 5, 10}, {12, 17, 22}}); + + ComputeAndCompareR2(&builder, expected, {a_data.get()}); +} + +XLA_TEST_F(NumericTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal(); } + +XLA_TEST_F(NumericTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal(); } + +XLA_TEST_F(NumericTest, GetMatrixDiagonal_F32) { TestMatrixDiagonal(); } + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/pooling.cc b/tensorflow/compiler/xla/client/lib/pooling.cc new file mode 100644 index 0000000000000000000000000000000000000000..7199269a6c889f3589c1148687faf0bb2aaae90a --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/pooling.cc @@ -0,0 +1,183 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/pooling.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" + +namespace xla { + +namespace { + +// Common computation shared between AvgPool and AvgPoolGrad. Divide each +// element of an image by the count of elements that contributed to that +// element during pooling. +XlaOp AvgPoolDivideByCountWithGeneralPadding( + XlaOp sums, PrimitiveType dtype, + tensorflow::gtl::ArraySlice input_shape, + tensorflow::gtl::ArraySlice> spatial_padding, + tensorflow::gtl::ArraySlice ksize, + tensorflow::gtl::ArraySlice stride, + const TensorFormat& data_format) { + // The padding shouldn't be included in the counts. We use another + // ReduceWindow to find the right counts. + const int num_spatial_dims = spatial_padding.size(); + + std::vector input_dim_sizes(num_spatial_dims); + std::vector window_dims(num_spatial_dims); + std::vector window_ksize(num_spatial_dims); + std::vector window_stride(num_spatial_dims); + CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims) + << "Invalid number of spatial dimentions in data format specification"; + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + input_dim_sizes[i] = input_shape[dim]; + window_dims[i] = dim; + window_ksize[i] = ksize[dim]; + window_stride[i] = stride[dim]; + } + + XlaBuilder* b = sums.builder(); + // Build a matrix of all 1s, with the same width/height as the input. + auto ones = Broadcast(One(b, dtype), input_dim_sizes); + PaddingConfig padding_config; + for (int i = 0; i < num_spatial_dims; ++i) { + auto dims = padding_config.add_dimensions(); + dims->set_edge_padding_low(spatial_padding[i].first); + dims->set_edge_padding_high(spatial_padding[i].second); + } + auto zero = Zero(b, dtype); + auto padded_ones = Pad(ones, zero, padding_config); + + // Perform a ReduceWindow with the same window size, strides, and padding + // to count the number of contributions to each result element. + auto counts = + ReduceWindow(padded_ones, zero, CreateScalarAddComputation(dtype, b), + window_ksize, window_stride, Padding::kValid); + + return Div(sums, counts, window_dims); +} + +// Sums all elements in the window specified by 'kernel_size' and 'stride'. +XlaOp ComputeSums(XlaOp operand, XlaOp init_value, + tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, + const TensorFormat& data_format) { + XlaBuilder* b = operand.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); + TF_ASSIGN_OR_RETURN(Shape init_shape, b->GetShape(init_value)); + PrimitiveType accumulation_type = init_shape.element_type(); + auto add_computation = CreateScalarAddComputation(accumulation_type, b); + return ReduceWindow(operand, init_value, add_computation, kernel_size, + stride, Padding::kValid); + }); +} + +// Creates a padding configuration out of spatial padding values. +PaddingConfig MakeSpatialPaddingConfig( + tensorflow::gtl::ArraySlice> spatial_padding, + tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, + const TensorFormat& data_format) { + const int num_spatial_dims = kernel_size.size() - 2; + PaddingConfig padding_config; + for (int i = 0; i < 2 + num_spatial_dims; ++i) { + padding_config.add_dimensions(); + } + CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims) + << "Invalid number of spatial dimentions in data format specification"; + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + auto padding_dimension = padding_config.mutable_dimensions(dim); + padding_dimension->set_edge_padding_low(spatial_padding[i].first); + padding_dimension->set_edge_padding_high(spatial_padding[i].second); + } + return padding_config; +} + +} // namespace + +XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, Padding padding, + const TensorFormat& data_format) { + XlaBuilder* b = operand.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); + PrimitiveType dtype = operand_shape.element_type(); + auto max_computation = CreateScalarMaxComputation(dtype, b); + auto init_value = MinValue(b, dtype); + return ReduceWindow(operand, init_value, max_computation, kernel_size, + stride, padding); + }); +} + +XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, + tensorflow::gtl::ArraySlice> padding, + const TensorFormat& data_format, + const bool counts_include_padding) { + XlaBuilder* b = operand.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape operand_shape, b->GetShape(operand)); + PrimitiveType dtype = operand_shape.element_type(); + auto init_value = Zero(b, dtype); + std::vector input_size(operand_shape.dimensions().begin(), + operand_shape.dimensions().end()); + auto padding_config = + MakeSpatialPaddingConfig(padding, kernel_size, stride, data_format); + auto padded_operand = Pad(operand, Zero(b, dtype), padding_config); + auto pooled = ComputeSums(padded_operand, init_value, kernel_size, stride, + data_format); + if (counts_include_padding) { + // If counts include padding, all windows have the same number of elements + // contributing to each average. Divide by the window size everywhere to + // get the average. + int64 window_size = + std::accumulate(kernel_size.begin(), kernel_size.end(), 1, + [](int64 x, int64 y) { return x * y; }); + + auto divisor = ConstantR0WithType(b, dtype, window_size); + return pooled / divisor; + } else { + return AvgPoolDivideByCountWithGeneralPadding( + pooled, dtype, input_size, padding, kernel_size, stride, data_format); + } + }); +} + +std::vector> MakeSpatialPadding( + tensorflow::gtl::ArraySlice input_size, + tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, Padding padding, + const TensorFormat& data_format) { + const int num_spatial_dims = kernel_size.size() - 2; + std::vector input_spatial_dimensions; + std::vector kernel_size_spatial_dimensions; + std::vector stride_spatial_dimensions; + CHECK_EQ(data_format.num_spatial_dims(), num_spatial_dims) + << "Invalid number of spatial dimentions in data format specification"; + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + input_spatial_dimensions.push_back(input_size[dim]); + kernel_size_spatial_dimensions.push_back(kernel_size[dim]); + stride_spatial_dimensions.push_back(stride[dim]); + } + return MakePadding(input_spatial_dimensions, kernel_size_spatial_dimensions, + stride_spatial_dimensions, padding); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/pooling.h b/tensorflow/compiler/xla/client/lib/pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..1699c585d3b09a306c21cfa797a9023a8463bd1f --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/pooling.h @@ -0,0 +1,73 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace xla { + +// Tensor format for reduce window operations. +class TensorFormat { + public: + TensorFormat(int batch_dimension, int feature_dimension, + tensorflow::gtl::ArraySlice spatial_dimensions) + : batch_dimension_(batch_dimension), + feature_dimension_(feature_dimension), + spatial_dimensions_(spatial_dimensions.begin(), + spatial_dimensions.end()) {} + + int batch_dimension() const { return batch_dimension_; } + + int feature_dimension() const { return feature_dimension_; } + + int spatial_dimension(int dim) const { return spatial_dimensions_[dim]; } + + int num_spatial_dims() const { return spatial_dimensions_.size(); } + + private: + // The number of the dimension that represents the batch. + int batch_dimension_; + // The number of the dimension that represents the features. + int feature_dimension_; + // The dimension numbers for the spatial dimensions. + tensorflow::gtl::InlinedVector spatial_dimensions_; +}; + +// Computes the max pool of 'operand'. +XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, Padding padding, + const TensorFormat& data_format); + +// Computes the average pool of 'operand'. +XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, + tensorflow::gtl::ArraySlice> padding, + const TensorFormat& data_format, + const bool counts_include_padding); + +// Returns the list of low and high padding elements in each spatial dimension +// for the given 'padding' specification. +std::vector> MakeSpatialPadding( + tensorflow::gtl::ArraySlice input_size, + tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, Padding padding, + const TensorFormat& data_format); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_ diff --git a/tensorflow/compiler/xla/client/lib/pooling_test.cc b/tensorflow/compiler/xla/client/lib/pooling_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4b4553b60db555ad7c2ab6b695236df745e30683 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/pooling_test.cc @@ -0,0 +1,185 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/pooling.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +TensorFormat MakeNCHWFormat(int num_spatial_dims) { + tensorflow::gtl::InlinedVector spatial_dimensions; + for (int i = 0; i < num_spatial_dims; ++i) { + spatial_dimensions.push_back(i + 2); + } + return TensorFormat(/*batch_dimension=*/0, /*feature_dimension=*/1, + /*spatial_dimensions=*/spatial_dimensions); +} + +std::vector> MakeGeneralPadding( + XlaOp input, tensorflow::gtl::ArraySlice kernel_size, + tensorflow::gtl::ArraySlice stride, Padding padding, + const xla::TensorFormat& data_format) { + XlaBuilder* b = input.builder(); + Shape operand_shape = b->GetShape(input).ValueOrDie(); + std::vector input_size(operand_shape.dimensions().begin(), + operand_shape.dimensions().end()); + return MakeSpatialPadding(input_size, kernel_size, stride, padding, + data_format); +} + +// Add singleton batch and feature dimensions to spatial dimensions, according +// to 'data_format' specification. +std::vector ExpandWithBatchAndFeatureDimensions( + tensorflow::gtl::ArraySlice spatial_dim_sizes, + const xla::TensorFormat& data_format) { + const int num_spatial_dims = spatial_dim_sizes.size(); + std::vector tensor_sizes(num_spatial_dims + 2, 1); + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + tensor_sizes[dim] = spatial_dim_sizes[i]; + } + return tensor_sizes; +} + +class PoolingTest : public ClientLibraryTestBase { + public: + ErrorSpec error_spec_{0.0001}; +}; + +XLA_TEST_F(PoolingTest, MaxPool2D) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = kernel_size; + MaxPool(input, kernel_size, stride, Padding::kValid, data_format); + + ComputeAndCompareR4(&builder, {{{{5, 4}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, MaxPool2DWithPadding) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = kernel_size; + MaxPool(input, kernel_size, stride, Padding::kSame, data_format); + + ComputeAndCompareR4(&builder, {{{{5, 4, 5}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, MaxPool2DWithPaddingAndStride) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format); + MaxPool(input, kernel_size, stride, Padding::kSame, data_format); + + ComputeAndCompareR4(&builder, {{{{5, 4, 4, 5, 5}, {5, 4, 3, 2, 1}}}}, + {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2D) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = kernel_size; + auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kValid, + data_format); + AvgPool(input, kernel_size, stride, padding, data_format, + /*counts_include_padding=*/true); + + ComputeAndCompareR4(&builder, {{{{3, 3}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2DWithPadding) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = kernel_size; + auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kSame, + data_format); + AvgPool(input, kernel_size, stride, padding, data_format, + /*counts_include_padding=*/false); + + ComputeAndCompareR4(&builder, {{{{3, 3, 3}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2DWithPaddingAndStride) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format); + auto padding = MakeGeneralPadding(input, kernel_size, stride, Padding::kSame, + data_format); + AvgPool(input, kernel_size, stride, padding, data_format, + /*counts_include_padding=*/false); + + ComputeAndCompareR4(&builder, + {{{{3, 3, 3, 3, 3}, {4.5, 3.5, 2.5, 1.5, 1}}}}, {}, + error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2DWithGeneralPaddingCountNotIncludePadding) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({3, 3}, data_format); + auto stride = kernel_size; + AvgPool(input, kernel_size, stride, {{1, 1}, {2, 1}}, data_format, + /*counts_include_padding=*/false); + + ComputeAndCompareR4(&builder, {{{{3, 3}}}}, {}, error_spec_); +} + +XLA_TEST_F(PoolingTest, + AvgPool2DWithGeneralPaddingCountNotIncludePaddingAndStride) { + XlaBuilder builder(TestName()); + + XlaOp input = ConstantR4FromArray4D( + &builder, {{{{1, 2, 3, 4, 5}, {5, 4, 3, 2, 1}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({3, 3}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + AvgPool(input, kernel_size, stride, {{2, 1}, {1, 1}}, data_format, + /*counts_include_padding=*/false); + + ComputeAndCompareR4(&builder, {{{{1.5, 3, 4.5}, {3, 3, 3}}}}, {}, + error_spec_); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc new file mode 100644 index 0000000000000000000000000000000000000000..6ef81689489d8117d5951bcb75693c2e3413e4d6 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -0,0 +1,150 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/casts.h" + +namespace xla { +namespace { + +// Rotates a 32-bit integer 'v' left by 'distance' bits. +XlaOp RotateLeftS32(XlaOp v, int distance) { + return (v << ConstantR0(v.builder(), distance)) | + ShiftRightLogical(v, ConstantR0(v.builder(), 32 - distance)); +} + +using ThreeFry2x32State = std::array; + +// Implements the ThreeFry counter-based PRNG algorithm. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf +ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) { + XlaBuilder* builder = input[0].builder(); + // Rotation distances specified by the Threefry2x32 algorithm. + constexpr std::array rotations = {13, 15, 26, 6, 17, 29, 16, 24}; + ThreeFry2x32State x; + + std::array ks; + // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm. + ks[2] = ConstantR0(builder, 0x1BD11BDA); + for (int i = 0; i < 2; ++i) { + ks[i] = key[i]; + x[i] = input[i]; + ks[2] = ks[2] ^ key[i]; + } + + x[0] = x[0] + ks[0]; + x[1] = x[1] + ks[1]; + + // Performs a single round of the Threefry2x32 algorithm, with a rotation + // amount 'rotation'. + auto round = [](ThreeFry2x32State v, int rotation) { + v[0] = v[0] + v[1]; + v[1] = RotateLeftS32(v[1], rotation); + v[1] = v[0] ^ v[1]; + return v; + }; + + // There are no known statistical flaws with 13 rounds of Threefry2x32. + // We are conservative and use 20 rounds. + x = round(x, rotations[0]); + x = round(x, rotations[1]); + x = round(x, rotations[2]); + x = round(x, rotations[3]); + x[0] = x[0] + ks[1]; + x[1] = x[1] + ks[2] + ConstantR0(builder, 1); + + x = round(x, rotations[4]); + x = round(x, rotations[5]); + x = round(x, rotations[6]); + x = round(x, rotations[7]); + x[0] = x[0] + ks[2]; + x[1] = x[1] + ks[0] + ConstantR0(builder, 2); + + x = round(x, rotations[0]); + x = round(x, rotations[1]); + x = round(x, rotations[2]); + x = round(x, rotations[3]); + x[0] = x[0] + ks[0]; + x[1] = x[1] + ks[1] + ConstantR0(builder, 3); + + x = round(x, rotations[4]); + x = round(x, rotations[5]); + x = round(x, rotations[6]); + x = round(x, rotations[7]); + x[0] = x[0] + ks[1]; + x[1] = x[1] + ks[2] + ConstantR0(builder, 4); + + x = round(x, rotations[0]); + x = round(x, rotations[1]); + x = round(x, rotations[2]); + x = round(x, rotations[3]); + x[0] = x[0] + ks[2]; + x[1] = x[1] + ks[0] + ConstantR0(builder, 5); + + return x; +} + +} // namespace + +XlaOp StatelessRngUniform(std::array seeds, const Shape& shape, + XlaOp minval, XlaOp maxval) { + XlaBuilder* builder = seeds[0].builder(); + if (shape.element_type() != F32) { + return builder->ReportError(Unimplemented( + "Types other than F32 are not implemented by StatelessRngUniform.")); + } + ThreeFry2x32State key = seeds; + const int64 size = ShapeUtil::ElementsIn(shape); + + const int64 half_size = CeilOfRatio(size, 2); + const bool size_is_odd = (half_size * 2 != size); + + // Fill the generator inputs with unique counter values. + ThreeFry2x32State inputs; + inputs[0] = Iota(builder, S32, half_size); + inputs[1] = inputs[0] + ConstantR0(builder, half_size); + ThreeFry2x32State outputs = ThreeFry2x32(inputs, key); + + if (size_is_odd) { + outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1}); + } + + auto bits = Reshape(ConcatInDim(builder, outputs, 0), + AsInt64Slice(shape.dimensions())); + + // Form 23 random mantissa bits, with a leading 1 bit. The leading 1 bit + // forces the random bits into the mantissa. + constexpr int kFloatBits = 32; + constexpr int kMantissaBits = 23; + bits = ShiftRightLogical( + bits, ConstantR0(builder, kFloatBits - kMantissaBits)) | + ConstantR0(builder, tensorflow::bit_cast(1.0f)); + auto floats = BitcastConvertType(bits, F32); + + // We have a floating point number in the range [1.0, 2.0). + // Subtract 1.0f to shift to the range [0.0, 1.0) + floats = floats - ConstantR0(builder, 1.0f); + // Multiply and add to shift to the range [minval, maxval). + return floats * (maxval - minval) + minval; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/prng.h b/tensorflow/compiler/xla/client/lib/prng.h new file mode 100644 index 0000000000000000000000000000000000000000..ad000b1fa1d0655c8fccc0bb33379f2499b77f26 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/prng.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_ + +#include + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Returns a tensor containing 'shape' random values uniformly distributed in +// the range [minval, maxval). Requires 2 32-bit integer seeds. +// Currently only 'shape's of type F32 are implemented. +XlaOp StatelessRngUniform(std::array seeds, const Shape& shape, + XlaOp minval, XlaOp maxval); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_PRNG_H_ diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc new file mode 100644 index 0000000000000000000000000000000000000000..a904be259a3870a679b2c4699ec01e2a11b1ce46 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" + +namespace xla { + +XlaOp TopK(XlaOp input, int64 k) { + XlaBuilder* const builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); + int last_dim = input_shape.dimensions_size() - 1; + int last_dim_size = input_shape.dimensions(last_dim); + + XlaOp iota_s32 = Iota(builder, S32, last_dim_size); + auto input_dims = input_shape.dimensions(); + std::vector broadcast_dims(input_dims.begin(), input_dims.end() - 1); + XlaOp broadcast_s32 = Broadcast(iota_s32, broadcast_dims); + XlaOp sort_result = Sort(Neg(input), broadcast_s32); + std::vector start_indices(input_shape.dimensions_size(), 0); + std::vector limit_indices(input_dims.begin(), input_dims.end()); + limit_indices[last_dim] = k; + std::vector strides(input_shape.dimensions_size(), 1); + + XlaOp values = Neg(Slice(GetTupleElement(sort_result, 0), start_indices, + limit_indices, strides)); + XlaOp indices = Slice(GetTupleElement(sort_result, 1), start_indices, + limit_indices, strides); + return Tuple(builder, {values, indices}); + }); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/sorting.h b/tensorflow/compiler/xla/client/lib/sorting.h new file mode 100644 index 0000000000000000000000000000000000000000..b9dfafdd6f957ae050e0f5dbd076d5288235b490 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/sorting.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Returns a tuple composed of the top `k` values and corresponding indices in +// `input`. Output values are in descending order, from largest to smallest. +XlaOp TopK(XlaOp input, int64 k); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SORTING_H_ diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fef98c9923096e21a755c6d730de2c7c10852b2d --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace { + +using SortingTest = ClientLibraryTestBase; + +XLA_TEST_F(SortingTest, TopK3From8Values) { + XlaBuilder builder(TestName()); + auto x = + ConstantR1(&builder, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}); + xla::GetTupleElement(xla::TopK(x, 3), 0); + ComputeAndCompareR1(&builder, {7.0, 6.0, 5.0}, {}); +} + +XLA_TEST_F(SortingTest, TopK3From8Indices) { + XlaBuilder builder(TestName()); + auto x_rev = + ConstantR1(&builder, {7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0}); + xla::GetTupleElement(xla::TopK(x_rev, 3), 1); + ComputeAndCompareR1(&builder, {0, 1, 2}, {}); +} + +XLA_TEST_F(SortingTest, TopKFullSort) { + XlaBuilder builder(TestName()); + const int kSize = 16; + std::mt19937 eng; + std::uniform_real_distribution u_dist(0.0, 100.0); + auto gen = std::bind(u_dist, eng); + std::vector inputs(kSize); + std::generate(inputs.begin(), inputs.end(), gen); + auto x = ConstantR1(&builder, inputs); + xla::GetTupleElement(xla::TopK(x, kSize), 0); + + std::sort(inputs.begin(), inputs.end(), std::greater()); + ComputeAndCompareR1(&builder, inputs, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 731ad13b8d0e5d65acc316e72be9fe7d35e826a4..081fec7ad92958aa285e4be41394d7b1876e0815 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/testing.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/test_utils.h" @@ -49,7 +49,7 @@ int64 DataSizeOfShape(const Shape& shape) { XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) { if (ShapeUtil::IsArray(shape)) { return Broadcast( - ConstantLiteral(builder, Literal::One(shape.element_type())), + ConstantLiteral(builder, LiteralUtil::One(shape.element_type())), AsInt64Slice(shape.dimensions())); } std::vector parts; @@ -98,14 +98,13 @@ std::vector> MakeFakeArgumentsOrDie( << "Computation should have progran shape."; auto program_shape = computation.proto().program_shape(); - // For every (unbound) parameter that the computation wants, we manufacture - // some arbitrary data so that we can invoke the computation. - std::vector> fake_arguments; - for (const Shape& parameter : program_shape.parameters()) { - fake_arguments.push_back(MakeFakeDataOrDie(parameter, client)); - } - - return fake_arguments; + // Create and run a program which produces a tuple with one element per + // parameter, then return the tuple's constituent buffers. + std::vector param_shapes(program_shape.parameters().begin(), + program_shape.parameters().end()); + auto fake_input_tuple = + MakeFakeDataOrDie(ShapeUtil::MakeTupleShape(param_shapes), client); + return client->DeconstructTuple(*fake_input_tuple).ValueOrDie(); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h index dc613099e2b42a60d0c11a654ab5cd41f8bd4f6f..03695ce2a339735e3e49522f4fe1bbf2d83a3834 100644 --- a/tensorflow/compiler/xla/client/lib/testing.h +++ b/tensorflow/compiler/xla/client/lib/testing.h @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 5f9710914bd0ceff55f5b0a2db05e553ce8bd637..4d96316d3b247c4f428cf44c9ee4361fbcb2844a 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -18,10 +18,12 @@ limitations under the License. #include #include "llvm/ADT/Triple.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/source_map_util.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/status_macros.h" using xla::source_map_util::InvalidParameterArgument; @@ -29,8 +31,8 @@ using xla::source_map_util::InvalidParameterArgument; namespace xla { namespace { -StatusOr BorrowStreamForDevice(int device_ordinal, - Backend* backend) { +StatusOr BorrowStreamForDevice(int device_ordinal, + Backend* backend) { if (device_ordinal < 0) { device_ordinal = backend->default_device_ordinal(); } @@ -99,11 +101,14 @@ Status LocalExecutable::ValidateExecutionOptions( } } - // Verify that the device the executable was built for is equivalent to the - // device it will run on. - int run_device_ordinal = run_options.device_ordinal() == -1 - ? backend_->default_device_ordinal() - : run_options.device_ordinal(); + // Verify that the device the executable was built for is equivalent + // to the device it will run on. + int run_device_ordinal = run_options.device_ordinal(); + if (run_device_ordinal == -1) { + run_device_ordinal = run_options.stream() != nullptr + ? run_options.stream()->parent()->device_ordinal() + : backend_->default_device_ordinal(); + } TF_ASSIGN_OR_RETURN(bool devices_equivalent, backend_->devices_equivalent( run_device_ordinal, build_options_.device_ordinal())); @@ -141,7 +146,7 @@ StatusOr LocalExecutable::Run( TF_RETURN_IF_ERROR( ValidateExecutionOptions(arguments, run_options, *backend_)); - Backend::StreamPtr stream; + StreamPool::Ptr stream; if (run_options.stream() == nullptr) { // NB! The lifetime of `stream` needs to match the lifetime of // `actual_options` (otherwise we will end up using a returned stream in @@ -298,7 +303,7 @@ StatusOr> LocalClient::TransferFromOutfeedLocal( const Shape& shape, int device_ordinal) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); - auto literal = MakeUnique(); + auto literal = MakeUnique(shape); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed( executor, shape, literal.get())); return std::move(literal); diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 4d9e0d7cd9d6ddebead1e12b23e94b529038039b..ae23809261757c637ab4aec036750c371ac60cdc 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc similarity index 82% rename from tensorflow/compiler/xla/client/xla_client/xla_builder.cc rename to tensorflow/compiler/xla/client/xla_builder.cc index 0145f60483bd3c02d0ecf9c70b784a917dc95232..b3b00e2fffe1196b36190ec72d1425bae4e4e276 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include #include @@ -22,6 +22,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/sharding_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -44,52 +45,20 @@ int64 GetUniqueId() { return id; } -// Returns true if an instruction with the given opcode can be the root of the -// computation. -bool CanBeRoot(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kSend: - case HloOpcode::kSendDone: - case HloOpcode::kOutfeed: - case HloOpcode::kTrace: - return false; - default: - return true; - } -} - } // namespace -XlaOp operator-(const XlaOp& x) { return x.builder()->Neg(x); } -XlaOp operator+(const XlaOp& x, const XlaOp& y) { - return x.builder()->Add(x, y); -} -XlaOp operator-(const XlaOp& x, const XlaOp& y) { - return x.builder()->Sub(x, y); -} -XlaOp operator*(const XlaOp& x, const XlaOp& y) { - return x.builder()->Mul(x, y); -} -XlaOp operator/(const XlaOp& x, const XlaOp& y) { - return x.builder()->Div(x, y); -} -XlaOp operator%(const XlaOp& x, const XlaOp& y) { - return x.builder()->Rem(x, y); -} +XlaOp operator-(const XlaOp& x) { return Neg(x); } +XlaOp operator+(const XlaOp& x, const XlaOp& y) { return Add(x, y); } +XlaOp operator-(const XlaOp& x, const XlaOp& y) { return Sub(x, y); } +XlaOp operator*(const XlaOp& x, const XlaOp& y) { return Mul(x, y); } +XlaOp operator/(const XlaOp& x, const XlaOp& y) { return Div(x, y); } +XlaOp operator%(const XlaOp& x, const XlaOp& y) { return Rem(x, y); } -XlaOp operator~(const XlaOp& x) { return x.builder()->Not(x); } -XlaOp operator&(const XlaOp& x, const XlaOp& y) { - return x.builder()->And(x, y); -} -XlaOp operator|(const XlaOp& x, const XlaOp& y) { - return x.builder()->Or(x, y); -} -XlaOp operator^(const XlaOp& x, const XlaOp& y) { - return x.builder()->Xor(x, y); -} -XlaOp operator<<(const XlaOp& x, const XlaOp& y) { - return x.builder()->ShiftLeft(x, y); -} +XlaOp operator~(const XlaOp& x) { return Not(x); } +XlaOp operator&(const XlaOp& x, const XlaOp& y) { return And(x, y); } +XlaOp operator|(const XlaOp& x, const XlaOp& y) { return Or(x, y); } +XlaOp operator^(const XlaOp& x, const XlaOp& y) { return Xor(x, y); } +XlaOp operator<<(const XlaOp& x, const XlaOp& y) { return ShiftLeft(x, y); } XlaOp operator>>(const XlaOp& x, const XlaOp& y) { XlaBuilder* builder = x.builder(); @@ -101,9 +70,9 @@ XlaOp operator>>(const XlaOp& x, const XlaOp& y) { ShapeUtil::HumanString(shape).c_str()); } if (ShapeUtil::ElementIsSigned(shape)) { - return builder->ShiftRightArithmetic(x, y); + return ShiftRightArithmetic(x, y); } else { - return builder->ShiftRightLogical(x, y); + return ShiftRightLogical(x, y); } }); } @@ -158,28 +127,13 @@ XlaOp XlaBuilder::ReportErrorOrReturn( return ReportErrorOrReturn(op_creator()); } -StatusOr XlaBuilder::GetProgramShape(int64* root_id) const { +StatusOr XlaBuilder::GetProgramShape(int64 root_id) const { TF_RETURN_IF_ERROR(first_error_); - - TF_RET_CHECK(root_id != nullptr); + TF_RET_CHECK((root_id >= 0) && (root_id < instructions_.size())); ProgramShape program_shape; - // Not all instructions can be roots. Walk backwards from the last added - // instruction until a valid root is found. - int64 index = instructions_.size() - 1; - for (; index >= 0; index--) { - TF_ASSIGN_OR_RETURN(HloOpcode opcode, - StringToHloOpcode(instructions_[index].opcode())); - if (CanBeRoot(opcode)) { - break; - } - } - if (index < 0) { - return FailedPrecondition("no root instruction was found"); - } - *root_id = instructions_[index].id(); - *program_shape.mutable_result() = instructions_[index].shape(); + *program_shape.mutable_result() = instructions_[root_id].shape(); // Check that the parameter numbers are continuous from 0, and add parameter // shapes and names to the program shape. @@ -204,8 +158,15 @@ StatusOr XlaBuilder::GetProgramShape(int64* root_id) const { } StatusOr XlaBuilder::GetProgramShape() const { - int64 root; - return GetProgramShape(&root); + TF_RET_CHECK(!instructions_.empty()); + return GetProgramShape(instructions_.back().id()); +} + +StatusOr XlaBuilder::GetProgramShape(XlaOp root) const { + if (root.builder_ != this) { + return InvalidArgument("Given root operation is not in this computation."); + } + return GetProgramShape(root.handle()); } void XlaBuilder::IsConstantVisitor(const int64 op_handle, @@ -273,17 +234,29 @@ StatusOr XlaBuilder::Build() { first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); return AppendStatus(first_error_, backtrace); } + return Build(instructions_.back().id()); +} + +StatusOr XlaBuilder::Build(XlaOp root) { + if (root.builder_ != this) { + return InvalidArgument("Given root operation is not in this computation."); + } + return Build(root.handle()); +} + +StatusOr XlaBuilder::Build(int64 root_id) { + if (!first_error_.ok()) { + string backtrace; + first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); + return AppendStatus(first_error_, backtrace); + } HloComputationProto entry; entry.set_id(GetUniqueId()); // Give the computation a global unique id. entry.set_name(StrCat(name_, entry.id())); // Ensure that the name is unique. - { - int64 root_id; - TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), - GetProgramShape(&root_id)); - entry.set_root_id(root_id); - } + TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), GetProgramShape(root_id)); + entry.set_root_id(root_id); for (auto& instruction : instructions_) { // Ensures that the instruction names are unique among the whole graph. @@ -550,6 +523,14 @@ XlaOp XlaBuilder::Broadcast( }); } +XlaOp XlaBuilder::BroadcastInDim( + const XlaOp& operand, const Shape& shape, + const tensorflow::gtl::ArraySlice broadcast_dimensions) { + return ReportErrorOrReturn([&]() -> StatusOr { + return InDimBroadcast(shape, operand, broadcast_dimensions); + }); +} + StatusOr XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) { TF_RETURN_IF_ERROR(first_error_); @@ -745,14 +726,22 @@ void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = ShapeUtil::MakeNil(); - *instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto(); + *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag)->ToProto(); return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand}); }); } XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false) { - return TernaryOp(HloOpcode::kSelect, pred, on_true, on_false); + return ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& true_shape, GetShape(on_true)); + TF_ASSIGN_OR_RETURN(const Shape& false_shape, GetShape(on_false)); + TF_RET_CHECK(ShapeUtil::IsTuple(true_shape) == + ShapeUtil::IsTuple(false_shape)); + HloOpcode opcode = ShapeUtil::IsTuple(true_shape) ? HloOpcode::kTupleSelect + : HloOpcode::kSelect; + return TernaryOp(opcode, pred, on_true, on_false); + }); } XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice elements) { @@ -1099,11 +1088,11 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { sharding_builder::AssignDevice(0); XlaScopedShardingAssignment scoped_sharding(this, infeed_instruction_sharding); - TF_ASSIGN_OR_RETURN(infeed, - AddInstruction(std::move(instr), HloOpcode::kInfeed)); + TF_ASSIGN_OR_RETURN( + infeed, AddInstruction(std::move(instr), HloOpcode::kInfeed, {})); } else { - TF_ASSIGN_OR_RETURN(infeed, - AddInstruction(std::move(instr), HloOpcode::kInfeed)); + TF_ASSIGN_OR_RETURN( + infeed, AddInstruction(std::move(instr), HloOpcode::kInfeed, {})); } // The infeed instruction produces a tuple of the infed data and a token @@ -1118,6 +1107,35 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { }); } +XlaOp XlaBuilder::InfeedWithToken(const XlaOp& token, const Shape& shape, + const string& config) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument("Given shape to Infeed must have a layout"); + } + const Shape infeed_instruction_shape = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); + *instr.mutable_shape() = infeed_instruction_shape; + instr.set_infeed_config(config); + + if (ShapeUtil::IsArray(shape) && sharding() && + sharding()->type() == OpSharding::Type::OpSharding_Type_OTHER) { + // TODO(b/110793772): Support tiled array-shaped infeeds. + return InvalidArgument( + "Tiled sharding is not yet supported for array-shaped infeeds"); + } + + if (sharding() && + sharding()->type() == OpSharding::Type::OpSharding_Type_REPLICATED) { + return InvalidArgument( + "Replicated sharding is not yet supported for infeeds"); + } + + return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token}); + }); +} + void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, const string& outfeed_config) { ReportErrorOrReturn([&]() -> StatusOr { @@ -1163,6 +1181,53 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, }); } +XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const string& outfeed_config) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + + // Check and set outfeed shape. + if (!LayoutUtil::HasLayout(shape_with_layout)) { + return InvalidArgument("Given shape to Outfeed must have a layout"); + } + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { + return InvalidArgument( + "Outfeed shape %s must be compatible with operand shape %s", + ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), + ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + } + *instr.mutable_outfeed_shape() = shape_with_layout; + + instr.set_outfeed_config(outfeed_config); + + return AddInstruction(std::move(instr), HloOpcode::kOutfeed, + {operand, token}); + }); +} + +XlaOp XlaBuilder::CreateToken() { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + return AddInstruction(std::move(instr), HloOpcode::kAfterAll); + }); +} + +XlaOp XlaBuilder::AfterAll(tensorflow::gtl::ArraySlice tokens) { + return ReportErrorOrReturn([&]() -> StatusOr { + if (tokens.empty()) { + return InvalidArgument("AfterAll requires at least one operand"); + } + HloInstructionProto instr; + *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + return AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens); + }); +} + XlaOp XlaBuilder::CustomCall(const string& call_target_name, tensorflow::gtl::ArraySlice operands, const Shape& shape) { @@ -1366,13 +1431,31 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand, }); } -XlaOp XlaBuilder::Sort(const XlaOp& operand) { - return UnaryOp(HloOpcode::kSort, operand); -} - -XlaOp XlaBuilder::SqrtF32(const XlaOp& operand) { - return BinaryOp(HloOpcode::kPower, operand, ConstantR0(0.5), - /*broadcast_dimensions=*/{}); +XlaOp XlaBuilder::Sort(XlaOp keys, tensorflow::gtl::optional values, + int64 dimension) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + std::vector operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); + operand_shape_ptrs.push_back(&keys_shape); + Shape values_shape; + if (values.has_value()) { + TF_ASSIGN_OR_RETURN(values_shape, GetShape(*values)); + operand_shape_ptrs.push_back(&values_shape); + } + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferVariadicOpShape( + HloOpcode::kSort, operand_shape_ptrs)); + if (dimension == -1) { + TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); + dimension = ShapeUtil::Rank(keys_shape) - 1; + } + instr.add_dimensions(dimension); + return values.has_value() + ? AddInstruction(std::move(instr), HloOpcode::kSort, + {keys, *values}) + : AddInstruction(std::move(instr), HloOpcode::kSort, {keys}); + }); } XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs, @@ -1405,16 +1488,6 @@ XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand, }); } -XlaOp XlaBuilder::SquareF32(const XlaOp& operand) { - return BinaryOp(HloOpcode::kPower, operand, ConstantR0(2.0), - /*broadcast_dimensions=*/{}); -} - -XlaOp XlaBuilder::ReciprocalF32(const XlaOp& operand) { - return BinaryOp(HloOpcode::kPower, operand, ConstantR0(-1.0), - /*broadcast_dimensions=*/{}); -} - XlaOp XlaBuilder::Neg(const XlaOp& operand) { return UnaryOp(HloOpcode::kNegate, operand); } @@ -1551,6 +1624,32 @@ XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices, }); } +XlaOp XlaBuilder::Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); + TF_ASSIGN_OR_RETURN(const Shape& scatter_indices_shape, + GetShape(scatter_indices)); + TF_ASSIGN_OR_RETURN(const Shape& updates_shape, GetShape(updates)); + TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape, + update_computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferScatterShape( + input_shape, scatter_indices_shape, updates_shape, + to_apply_shape, dimension_numbers)); + + *instr.mutable_scatter_dimension_numbers() = dimension_numbers; + + AddCalledComputation(update_computation, &instr); + return AddInstruction(std::move(instr), HloOpcode::kScatter, + {input, scatter_indices, updates}); + }); +} + XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, const XlaComputation& true_computation, const XlaOp& false_operand, @@ -1594,9 +1693,10 @@ XlaOp XlaBuilder::Reduce( TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value)); TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferReduceShape( - operand_shape, init_shape, dimensions_to_reduce, + {&operand_shape, &init_shape}, dimensions_to_reduce, called_program_shape)); for (int64 dim : dimensions_to_reduce) { @@ -1761,10 +1861,6 @@ XlaOp XlaBuilder::CrossReplicaSum( tensorflow::gtl::ArraySlice replica_group_ids, const tensorflow::gtl::optional& channel_id) { return ReportErrorOrReturn([&]() -> StatusOr { - if (channel_id.has_value()) { - return Unimplemented("channel_id is not supported in AllReduce"); - } - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -1774,6 +1870,10 @@ XlaOp XlaBuilder::CrossReplicaSum( instr.add_replica_group_ids(replica_group_id); } + if (channel_id.has_value()) { + instr.set_all_reduce_id(channel_id->handle()); + } + AddCalledComputation(computation, &instr); return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum, @@ -1781,6 +1881,61 @@ XlaOp XlaBuilder::CrossReplicaSum( }); } +XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups) { + return ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + + // The HloInstruction for Alltoall currently only handles the data + // communication: it accepts N already split parts and scatters them to N + // cores, and each core gathers the N received parts into a tuple as the + // output. So here we explicitly split the operand before the hlo alltoall, + // and concat the tuple elements. + // + // First, run shape inference to make sure the shapes are valid. + TF_RETURN_IF_ERROR( + ShapeInference::InferAllToAllShape(operand_shape, split_dimension, + concat_dimension, split_count) + .status()); + + // Split into N parts. + std::vector slices; + slices.reserve(split_count); + const int64 block_size = + operand_shape.dimensions(split_dimension) / split_count; + for (int i = 0; i < split_count; i++) { + slices.push_back(SliceInDim(operand, /*start_index=*/i * block_size, + /*limit_index=*/(i + 1) * block_size, + /*stride=*/1, /*dimno=*/split_dimension)); + } + + // Handle data communication. + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices)); + std::vector slice_shape_ptrs; + c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs)); + for (const ReplicaGroup& group : replica_groups) { + *instr.add_replica_groups() = group; + } + TF_ASSIGN_OR_RETURN( + XlaOp alltoall, + AddInstruction(std::move(instr), HloOpcode::kAllToAll, slices)); + + // Concat the N received parts. + std::vector received; + received.reserve(split_count); + for (int i = 0; i < split_count; i++) { + received.push_back(this->GetTupleElement(alltoall, i)); + } + return this->ConcatInDim(received, concat_dimension); + }); +} + XlaOp XlaBuilder::SelectAndScatter( const XlaOp& operand, const XlaComputation& select, tensorflow::gtl::ArraySlice window_dimensions, @@ -1847,19 +2002,39 @@ XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits, void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; + // Send HLO takes two operands: a data operand and a token. Generate the + // token to pass into the send. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto token_instr; + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), + HloOpcode::kAfterAll, {})); - // Send instruction produces a tuple of {aliased operand, U32 context}. + return SendWithToken(operand, token, handle); + }); +} + +XlaOp XlaBuilder::SendWithToken(const XlaOp& operand, const XlaOp& token, + const ChannelHandle& handle) { + return ReportErrorOrReturn([&]() -> StatusOr { + if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) { + return InvalidArgument("Send must use a device-to-device channel"); + } + + // Send instruction produces a tuple of {aliased operand, U32 context, + // token}. + HloInstructionProto send_instr; TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); - *instr.mutable_shape() = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); - instr.set_channel_id(handle.handle()); - TF_ASSIGN_OR_RETURN( - XlaOp send, - AddInstruction(std::move(instr), HloOpcode::kSend, {operand})); + *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); + send_instr.set_channel_id(handle.handle()); + TF_ASSIGN_OR_RETURN(XlaOp send, + AddInstruction(std::move(send_instr), HloOpcode::kSend, + {operand, token})); HloInstructionProto send_done_instr; - *send_done_instr.mutable_shape() = ShapeUtil::MakeNil(); + *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); send_done_instr.set_channel_id(handle.handle()); return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone, {send}); @@ -1868,18 +2043,132 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; + // Recv HLO takes a single token operand. Generate the token to pass into + // the Recv and RecvDone instructions. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto token_instr; + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), + HloOpcode::kAfterAll, {})); - // Recv instruction produces a tuple of {receive buffer, U32 context}. - *instr.mutable_shape() = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); - instr.set_channel_id(handle.handle()); - TF_ASSIGN_OR_RETURN(XlaOp recv, - AddInstruction(std::move(instr), HloOpcode::kRecv, {})); + XlaOp recv = RecvWithToken(token, shape, handle); + + // The RecvDone instruction produces a tuple of the data and a token + // type. Return XLA op containing the data. + // TODO(b/80000000): Remove this when clients have been updated to handle + // tokens. + HloInstructionProto recv_data; + *recv_data.mutable_shape() = shape; + recv_data.set_tuple_index(0); + return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement, + {recv}); + }); +} + +XlaOp XlaBuilder::RecvWithToken(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle) { + return ReportErrorOrReturn([&]() -> StatusOr { + if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) { + return InvalidArgument("Recv must use a device-to-device channel"); + } + + // Recv instruction produces a tuple of {receive buffer, U32 context, + // token}. + HloInstructionProto recv_instr; + *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); + recv_instr.set_channel_id(handle.handle()); + TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr), + HloOpcode::kRecv, {token})); + + HloInstructionProto recv_done_instr; + *recv_done_instr.mutable_shape() = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); + recv_done_instr.set_channel_id(handle.handle()); + return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone, + {recv}); + }); +} + +XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const ChannelHandle& handle) { + return ReportErrorOrReturn([&]() -> StatusOr { + if (!LayoutUtil::HasLayout(shape_with_layout)) { + return InvalidArgument("Shape passed to SendToHost must have a layout"); + } + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { + return InvalidArgument( + "SendToHost shape %s must be compatible with operand shape %s", + ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), + ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + } + // TODO(b/111544877): Support tuple shapes. + if (!ShapeUtil::IsArray(operand_shape)) { + return InvalidArgument("SendToHost only supports array shapes, shape: %s", + ShapeUtil::HumanString(operand_shape).c_str()); + } + + if (handle.type() != ChannelHandle::DEVICE_TO_HOST) { + return InvalidArgument("SendToHost must use a device-to-host channel"); + } + + // Send instruction produces a tuple of {aliased operand, U32 context, + // token}. + HloInstructionProto send_instr; + *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape( + {shape_with_layout, ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeTokenShape()}); + send_instr.set_channel_id(handle.handle()); + send_instr.set_is_host_transfer(true); + TF_ASSIGN_OR_RETURN(XlaOp send, + AddInstruction(std::move(send_instr), HloOpcode::kSend, + {operand, token})); + + HloInstructionProto send_done_instr; + *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + send_done_instr.set_channel_id(handle.handle()); + send_done_instr.set_is_host_transfer(true); + return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone, + {send}); + }); +} + +XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle) { + return ReportErrorOrReturn([&]() -> StatusOr { + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument("Shape passed to RecvFromHost must have a layout"); + } + + // TODO(b/111544877): Support tuple shapes. + if (!ShapeUtil::IsArray(shape)) { + return InvalidArgument( + "RecvFromHost only supports array shapes, shape: %s", + ShapeUtil::HumanString(shape).c_str()); + } + + if (handle.type() != ChannelHandle::HOST_TO_DEVICE) { + return InvalidArgument("RecvFromHost must use a host-to-device channel"); + } + + // Recv instruction produces a tuple of {receive buffer, U32 context, + // token}. + HloInstructionProto recv_instr; + *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); + recv_instr.set_channel_id(handle.handle()); + recv_instr.set_is_host_transfer(true); + TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr), + HloOpcode::kRecv, {token})); HloInstructionProto recv_done_instr; - *recv_done_instr.mutable_shape() = shape; + *recv_done_instr.mutable_shape() = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); recv_done_instr.set_channel_id(handle.handle()); + recv_done_instr.set_is_host_transfer(true); return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone, {recv}); }); @@ -1918,11 +2207,6 @@ StatusOr XlaBuilder::BuildConstantSubGraph( TF_ASSIGN_OR_RETURN(const HloInstructionProto* root, LookUpInstruction(root_op)); - TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode())); - if (!CanBeRoot(opcode)) { - return InvalidArgument("the operand with opcode %s cannot be root", - root->opcode().c_str()); - } HloComputationProto entry; entry.set_id(GetUniqueId()); // Give the computation a global unique id. @@ -2140,6 +2424,13 @@ XlaOp Broadcast(const XlaOp& operand, return operand.builder()->Broadcast(operand, broadcast_sizes); } +XlaOp BroadcastInDim( + const XlaOp& operand, const Shape& shape, + const tensorflow::gtl::ArraySlice broadcast_dimensions) { + return operand.builder()->BroadcastInDim(operand, shape, + broadcast_dimensions); +} + XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, const PaddingConfig& padding_config) { return operand.builder()->Pad(operand, padding_value, padding_config); @@ -2441,6 +2732,13 @@ XlaOp CrossReplicaSum( replica_group_ids, channel_id); } +XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups) { + return operand.builder()->AllToAll(operand, split_dimension, concat_dimension, + split_count, replica_groups); +} + XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, @@ -2498,14 +2796,6 @@ XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); } XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); } -XlaOp SqrtF32(const XlaOp& operand) { - return operand.builder()->SqrtF32(operand); -} - -XlaOp SquareF32(const XlaOp& operand) { - return operand.builder()->SquareF32(operand); -} - XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions); @@ -2523,10 +2813,6 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) { return operand.builder()->BitcastConvertType(operand, new_element_type); } -XlaOp ReciprocalF32(const XlaOp& operand) { - return operand.builder()->ReciprocalF32(operand); -} - XlaOp Neg(const XlaOp& operand) { return operand.builder()->Neg(operand); } XlaOp Transpose(const XlaOp& operand, @@ -2538,7 +2824,10 @@ XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions) { return operand.builder()->Rev(operand, dimensions); } -XlaOp Sort(const XlaOp& operand) { return operand.builder()->Sort(operand); } +XlaOp Sort(XlaOp keys, tensorflow::gtl::optional values, + int64 dimension) { + return keys.builder()->Sort(keys, std::move(values), dimension); +} XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) { return min.builder()->Clamp(min, operand, max); @@ -2586,6 +2875,13 @@ XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, window_bounds); } +XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers) { + return input.builder()->Scatter(input, scatter_indices, updates, + update_computation, dimension_numbers); +} + void Send(const XlaOp& operand, const ChannelHandle& handle) { return operand.builder()->Send(operand, handle); } @@ -2595,6 +2891,45 @@ XlaOp Recv(XlaBuilder* builder, const Shape& shape, return builder->Recv(shape, handle); } +XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, + const ChannelHandle& handle) { + return operand.builder()->SendWithToken(operand, token, handle); +} + +XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle) { + return token.builder()->RecvWithToken(token, shape, handle); +} + +XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, const ChannelHandle& handle) { + return operand.builder()->SendToHost(operand, token, shape_with_layout, + handle); +} + +XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle) { + return token.builder()->RecvFromHost(token, shape, handle); +} + +XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, + const string& config) { + return token.builder()->InfeedWithToken(token, shape, config); +} + +XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const string& outfeed_config) { + return operand.builder()->OutfeedWithToken(operand, token, shape_with_layout, + outfeed_config); +} + +XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); } + +XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice tokens) { + return builder->AfterAll(tokens); +} + XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, const XlaOp& offset, float epsilon, int64 feature_index) { @@ -2618,4 +2953,11 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, grad_output, epsilon, feature_index); } +XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size) { + HloInstructionProto instr; + *instr.mutable_shape() = ShapeUtil::MakeShape(type, {size}); + return builder->ReportErrorOrReturn( + builder->AddInstruction(std::move(instr), HloOpcode::kIota)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h similarity index 71% rename from tensorflow/compiler/xla/client/xla_client/xla_builder.h rename to tensorflow/compiler/xla/client/xla_builder.h index fe31774b86b53dbc09fccd40b2d98ab988b934dc..9403d7ca8dabc80a3964b50d29f158a98091f843 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_ #include #include @@ -22,7 +22,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -53,7 +54,16 @@ class XlaOp { } ~XlaOp() = default; - XlaBuilder* builder() const { return builder_; } + // Precondition: !IsUninitialized(). + // + // It's very common to do foo.builder()->bar(). Without this precondition, if + // foo.builder() is null, the call to bar will segfault at some point possibly + // deep in the callstack when we finally dereference `this`. The precondition + // lets us avoid this tricky-to-debug problem. + XlaBuilder* builder() const { + CHECK(builder_ != nullptr); + return builder_; + } // Returns true if the XlaOp represents valid, non-erroneous value. bool valid() const { return handle_ >= 0; } @@ -158,6 +168,106 @@ class XlaBuilder { die_immediately_on_error_ = enabled; } + // Default dimension numbers used for a 2D convolution. + static constexpr int64 kConvBatchDimension = 0; + static constexpr int64 kConvFeatureDimension = 1; + static constexpr int64 kConvFirstSpatialDimension = 2; + static constexpr int64 kConvSecondSpatialDimension = 3; + static constexpr int64 kConvKernelOutputDimension = 0; + static constexpr int64 kConvKernelInputDimension = 1; + static constexpr int64 kConvKernelFirstSpatialDimension = 2; + static constexpr int64 kConvKernelSecondSpatialDimension = 3; + + // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for + // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for + // the kernel operand + // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. + static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( + int num_spatial_dims = 2); + + // Returns an error if the convolution dimension numbers have conflicts. + static Status Validate(const ConvolutionDimensionNumbers& dnum); + + // Returns a new XlaBuilder whose resultant Computation is used only by this + // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error + // behavior as the parent. + std::unique_ptr CreateSubBuilder(const string& computation_name); + + // Builds the computation with the requested operations, or returns a non-ok + // status. Note that all ops that have been enqueued will be moved to the + // computation being returned. The root of the computation will be the last + // added operation. + StatusOr Build(); + + // Overload of Build which specifies a particular root instruction for the + // computation. + StatusOr Build(XlaOp root); + + // Builds the computation with the requested operations, or notes an error in + // the parent XlaBuilder and returns an empty computation if building failed. + // This function is intended to be used where the returned XlaComputation is + // only used by the parent XlaBuilder and hence further operation on the + // returned XlaComputation will simply be error'ed out if an error occurred + // while building this computation. If the built computation is to be used by + // a XlaBuilder other than the parent XlaBuilder then Build() should be used + // instead. + XlaComputation BuildAndNoteError(); + + // Returns a subgraph that roots on the given root. If the root is not a + // compile-time constant (see `IsConstant`), returns an error. + // + // This will copy the needed ops/computations to the subgraph. + StatusOr BuildConstantSubGraph(const XlaOp& root_op) const; + + // Returns the first error that was encountered while building the + // computation. When an error is encountered, by default we return a vacuous + // XlaOp and inform the user of the error that occurred while + // building the computation when they make a final call to Build(). + // + // See also set_die_immediately_on_error(). + Status first_error() const { return first_error_; } + + // Returns the shape of the given op. + StatusOr GetShape(const XlaOp& op) const; + + // Returns the (inferred) result for the current computation's shape. This + // assumes the root instruction is the last added instruction. + StatusOr GetProgramShape() const; + + // Returns the (inferred) result for the current computation's shape using the + // given operation as the root. + StatusOr GetProgramShape(XlaOp root) const; + + // Reports an error to the builder, by + // * storing it internally and capturing a backtrace if it's the first error + // (this deferred value will be produced on the call to + // Build()/GetShape()/...) + // * dying if die_immediately_on_error_ is true. + // Returns an XlaOp with an invalid handle but a valid builder. This value can + // be returned in place of a value in APIs that return an XlaOp. + XlaOp ReportError(const Status& error); + + // A helper function that converts a StatusOr into an XlaOp. + // If the Status was an error, reports the error to builder and returns an + // invalid XlaOp handle. + XlaOp ReportErrorOrReturn(const StatusOr& op); + + // A helper function that runs a function that returns a StatusOr and + // returns an XlaOp. + XlaOp ReportErrorOrReturn(const std::function()>& op_creator); + + // Returns true if 'operand' is a compile-time constant. A compile-time + // constant does not depend on any parameters, or on stateful operators such + // as `RngNormal` or `Infeed`. + // + // This tests whether a computation is a compile-time constant without + // evaluating the computation. + StatusOr IsConstant(const XlaOp& operand) const; + + private: + // Build helper which takes the id of the root operation.. + StatusOr Build(int64 root_id); + // Enqueues a "retrieve parameter value" instruction for a parameter that was // passed to the computation. XlaOp Parameter(int64 parameter_number, const Shape& shape, @@ -230,6 +340,27 @@ class XlaBuilder { XlaOp Broadcast(const XlaOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes); + // Performs in-dimension-style broadcast. + // + // Operand specifies the input to be broadcast. "shape" is expected output + // shape. "broadcast_dimensions" are the dimensions to be broadcasting into. + // Dimension numbers in broadcast_dimensions map to individual dimensions + // of the operand, and specify what dimension of the output shape they + // should be broadcast. + // e.g. + // Say operand = [1, 2], i.e., a 1D tensor with 2 elements. + // and dimension of shape is [2,2]. + // Specifying {1} as brodcast_dimension will generate output + // [1 , 2] + // [1 , 2] + // On the other hand, specifying {0} as broadcast_dimension + // will generate output + // [1 , 1] + // [2 , 2] + XlaOp BroadcastInDim( + const XlaOp& operand, const Shape& shape, + const tensorflow::gtl::ArraySlice broadcast_dimensions); + // Enqueues a pad operation onto the computation that pads the given value on // the edges as well as between the elements of the input. padding_config // specifies the padding amount for each dimension. @@ -378,26 +509,6 @@ class XlaBuilder { XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers); - // Default dimension numbers used for a 2D convolution. - static constexpr int64 kConvBatchDimension = 0; - static constexpr int64 kConvFeatureDimension = 1; - static constexpr int64 kConvFirstSpatialDimension = 2; - static constexpr int64 kConvSecondSpatialDimension = 3; - static constexpr int64 kConvKernelOutputDimension = 0; - static constexpr int64 kConvKernelInputDimension = 1; - static constexpr int64 kConvKernelFirstSpatialDimension = 2; - static constexpr int64 kConvKernelSecondSpatialDimension = 3; - - // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for - // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for - // the kernel operand - // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. - static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( - int num_spatial_dims = 2); - - // Returns an error if the convolution dimension numbers have conflicts. - static Status Validate(const ConvolutionDimensionNumbers& dnum); - // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, @@ -444,6 +555,8 @@ class XlaBuilder { // Enqueues an infeed instruction onto the computation, which writes data of // the given shape to the infeed buffer of the device. XlaOp Infeed(const Shape& shape, const string& config = ""); + XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, + const string& config = ""); // Enqueues an outfeed instruction onto the computation. This instruction // generates outgoing data transfers for the given data. @@ -453,6 +566,9 @@ class XlaBuilder { // will occur. void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, const string& outfeed_config); + XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const string& outfeed_config); // Enqueues a call instruction onto the computation. XlaOp Call(const XlaComputation& computation, @@ -583,9 +699,9 @@ class XlaBuilder { // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. // - // - `channel_id`: for Allreduce nodes from different models, if they have the - // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be - // applied cross models. + // - `channel_id`: for Allreduce nodes from different modules, if they have + // the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will + // not be applied cross modules. // // TODO(b/79737069): Rename this to AllReduce when it's ready to use. XlaOp CrossReplicaSum( @@ -594,6 +710,13 @@ class XlaBuilder { const tensorflow::gtl::optional& channel_id = tensorflow::gtl::nullopt); + // Enqueues an operation that do an Alltoall of the operand cross cores. + // + // TODO(b/110096724): This is NOT YET ready to use. + XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups); + // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, @@ -663,16 +786,6 @@ class XlaBuilder { // Enqueues an imaginary-part instruction onto the computation. XlaOp Imag(const XlaOp& operand); - // Enqueues a float32 sqrt instruction onto the computation. - // (float32 is specified as there is an implicit float32 0.5f constant - // exponent). - XlaOp SqrtF32(const XlaOp& operand); - - // Enqueues a float32 square instruction onto the computation. - // (float32 is specified as there is an implicit float32 2.0f constant - // exponent). - XlaOp SquareF32(const XlaOp& operand); - // Enqueues a lhs^rhs computation onto the computation. XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions = {}); @@ -695,14 +808,6 @@ class XlaBuilder { XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); - // Enqueues a float32 reciprocal instruction onto the computation. - // (float32 is specified as there is an implicit float32 -1.0f constant - // exponent). - // - // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the - // shape of the operand. - XlaOp ReciprocalF32(const XlaOp& operand); - // Enqueues a negate instruction onto the computation. XlaOp Neg(const XlaOp& operand); @@ -717,7 +822,24 @@ class XlaBuilder { tensorflow::gtl::ArraySlice dimensions); // Enqueues a sort (as increasing order) instruction onto the computation. - XlaOp Sort(const XlaOp& operand); + // If only keys are provided: + // * If the keys are an rank-1 tensor (an array), the result is a sorted array + // of keys, in ascending order. + // * If the keys have higher rank, the keys are sorted along the provided + // dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension + // value of 0 will indepenently sort every column, and a dimension value of 1 + // will independently sort each row. If no dimension number is provided, then + // the last dimension is chosen by default. + // + // If both keys and values are provided: + // * The keys and the values must tensors with the same dimensions. The + // element types of the tensors may be different. + // * The result is a tuple that consists of a sorted tensor of keys (along the + // provided dimension, as above) as the first element, and a tensor with their + // corresponding values as the second element. + XlaOp Sort(XlaOp keys, + tensorflow::gtl::optional values = tensorflow::gtl::nullopt, + int64 dimension = -1); // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); @@ -755,22 +877,40 @@ class XlaBuilder { const GatherDimensionNumbers& dimension_numbers, tensorflow::gtl::ArraySlice window_bounds); - // Enqueues a Send node onto the computation, to send the given operand to - // a Recv instruction that shares the same channel handle. + // Enqueues a Scatter node onto the computation. + XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers); + + // Enqueues a Send node onto the computation for device-to-device + // communication, to send the given operand to a Recv instruction that shares + // the same channel handle. void Send(const XlaOp& operand, const ChannelHandle& handle); + XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, + const ChannelHandle& handle); + + // Enqueues a Send node which sends data to the host. + XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, const ChannelHandle& handle); + + // Enqueues a Recv node which receives data from the host. + XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle); + + // Enqueues an AfterAll operation with no operands producing a token-shaped + // value. + XlaOp CreateToken(); + + // Enqueues an AfterAll operation with no operands producing a token-shaped + // value. + XlaOp AfterAll(tensorflow::gtl::ArraySlice tokens); // Enqueues a Recv node onto the computation. The data comes from a Send // instruction that shares the same channel handle and its shape must // be the same as the given shape. XlaOp Recv(const Shape& shape, const ChannelHandle& handle); - - // Returns true if 'operand' is a compile-time constant. A compile-time - // constant does not depend on any parameters, or on stateful operators such - // as `RngNormal` or `Infeed`. - // - // This tests whether a computation is a compile-time constant without - // evaluating the computation. - StatusOr IsConstant(const XlaOp& operand) const; + XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle); // Normalizes operand across spatial and batch dimensions for each feature. // @@ -810,65 +950,6 @@ class XlaBuilder { const XlaOp& grad_output, float epsilon, int64 feature_index); - // Returns a new XlaBuilder whose resultant Computation is used only by this - // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error - // behavior as the parent. - std::unique_ptr CreateSubBuilder(const string& computation_name); - - // Builds the computation with the requested operations, or returns a non-ok - // status. Note that all ops that have been enqueued will be moved to the - // computation being returned. - StatusOr Build(); - - // Builds the computation with the requested operations, or notes an error in - // the parent XlaBuilder and returns an empty computation if building failed. - // This function is intended to be used where the returned XlaComputation is - // only used by the parent XlaBuilder and hence further operation on the - // returned XlaComputation will simply be error'ed out if an error occurred - // while building this computation. If the built computation is to be used by - // a XlaBuilder other than the parent XlaBuilder then Build() should be used - // instead. - XlaComputation BuildAndNoteError(); - - // Returns a subgraph that roots on the given root. If the root is not a - // compile-time constant (see `IsConstant`), returns an error. - // - // This will copy the needed ops/computations to the subgraph. - StatusOr BuildConstantSubGraph(const XlaOp& root_op) const; - - // Returns the first error that was encountered while building the - // computation. When an error is encountered, by default we return a vacuous - // XlaOp and inform the user of the error that occurred while - // building the computation when they make a final call to Build(). - // - // See also set_die_immediately_on_error(). - Status first_error() const { return first_error_; } - - // Returns the shape of the given op. - StatusOr GetShape(const XlaOp& op) const; - - // Returns the (inferred) result for the current computation's shape. - StatusOr GetProgramShape() const; - - // Reports an error to the builder, by - // * storing it internally and capturing a backtrace if it's the first error - // (this deferred value will be produced on the call to - // Build()/GetShape()/...) - // * dying if die_immediately_on_error_ is true. - // Returns an XlaOp with an invalid handle but a valid builder. This value can - // be returned in place of a value in APIs that return an XlaOp. - XlaOp ReportError(const Status& error); - - // A helper function that converts a StatusOr into an XlaOp. - // If the Status was an error, reports the error to builder and returns an - // invalid XlaOp handle. - XlaOp ReportErrorOrReturn(const StatusOr& op); - - // A helper function that runs a function that returns a StatusOr and - // returns an XlaOp. - XlaOp ReportErrorOrReturn(const std::function()>& op_creator); - - private: StatusOr AddInstruction( HloInstructionProto&& instr, HloOpcode opcode, tensorflow::gtl::ArraySlice operands = {}); @@ -908,9 +989,8 @@ class XlaBuilder { // shape. StatusOr Reshape(const Shape& shape, const XlaOp& operand); - // Returns the (inferred) result for the program shape for the current - // computation and fills the root_id in the pointer. - StatusOr GetProgramShape(int64* root_id) const; + // Returns the (inferred) result for the program shape using the given root. + StatusOr GetProgramShape(int64 root_id) const; // Returns shapes for the operands. StatusOr> GetOperandShapes( @@ -971,6 +1051,313 @@ class XlaBuilder { bool die_immediately_on_error_ = false; XlaBuilder* parent_builder_{nullptr}; + + friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, + const Shape& shape, const string& name); + friend XlaOp ConstantLiteral(XlaBuilder* builder, + const LiteralSlice& literal); + template + friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value); + template + friend XlaOp ConstantR1(XlaBuilder* builder, + tensorflow::gtl::ArraySlice values); + friend XlaOp ConstantR1(XlaBuilder* builder, + const tensorflow::core::Bitmap& values); + template + friend XlaOp ConstantR2( + XlaBuilder* builder, + std::initializer_list> values); + template + friend XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, + const Array& values, + const Layout& layout); + template + friend XlaOp ConstantFromArray(XlaBuilder* builder, + const Array& values); + template + friend XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, + const Array2D& values, + const Layout& layout); + template + friend XlaOp ConstantR2FromArray2D(XlaBuilder* builder, + const Array2D& values); + template + friend XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, + const Array3D& values, + const Layout& layout); + template + friend XlaOp ConstantR3FromArray3D(XlaBuilder* builder, + const Array3D& values); + template + friend XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, + const Array4D& values, + const Layout& layout); + template + friend XlaOp ConstantR4FromArray4D(XlaBuilder* builder, + const Array4D& values); + + template + friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); + + friend XlaOp Broadcast(const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); + + friend XlaOp BroadcastInDim( + const XlaOp& operand, const Shape& shape, + const tensorflow::gtl::ArraySlice broadcast_dimensions); + + friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config); + + friend XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); + + friend XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice new_sizes); + + friend XlaOp Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions); + + friend XlaOp Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + + friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index, + int64 limit_index, int64 stride, int64 dimno); + + friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + + friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices); + + friend XlaOp ConcatInDim(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, + int64 dimension); + + friend void Trace(const string& tag, const XlaOp& operand); + + friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true, + const XlaOp& on_false); + friend XlaOp Tuple(XlaBuilder* builder, + tensorflow::gtl::ArraySlice elements); + friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); + friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); + friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers); + friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + Padding padding); + friend XlaOp ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + friend XlaOp ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers); + friend XlaOp ConvGeneral( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ConvolutionDimensionNumbers& dimension_numbers); + friend XlaOp ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers); + friend XlaOp Fft(const XlaOp& operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, + const string& config); + friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config); + friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, + tensorflow::gtl::ArraySlice operands); + friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, + tensorflow::gtl::ArraySlice operands, + const Shape& shape); + friend XlaOp HostCompute(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, + const string& channel_name, int64 cost_estimate_ns, + const Shape& shape); + friend XlaOp Complex(const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Conj(const XlaOp& operand); + friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Not(const XlaOp& operand); + friend XlaOp ShiftLeft( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); + friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation); + friend XlaOp ReduceWindow( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, Padding padding); + friend XlaOp ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + friend XlaOp CrossReplicaSum( + const XlaOp& operand, + tensorflow::gtl::ArraySlice replica_group_ids); + friend XlaOp CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids, + const tensorflow::gtl::optional& channel_id); + friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups); + friend XlaOp SelectAndScatter( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); + friend XlaOp SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); + friend XlaOp Abs(const XlaOp& operand); + friend XlaOp Atan2(const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp Exp(const XlaOp& operand); + friend XlaOp Expm1(const XlaOp& operand); + friend XlaOp Floor(const XlaOp& operand); + friend XlaOp Ceil(const XlaOp& operand); + friend XlaOp Round(const XlaOp& operand); + friend XlaOp Log(const XlaOp& operand); + friend XlaOp Log1p(const XlaOp& operand); + friend XlaOp Sign(const XlaOp& operand); + friend XlaOp Clz(const XlaOp& operand); + friend XlaOp Cos(const XlaOp& operand); + friend XlaOp Sin(const XlaOp& operand); + friend XlaOp Tanh(const XlaOp& operand); + friend XlaOp Real(const XlaOp& operand); + friend XlaOp Imag(const XlaOp& operand); + friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp IsFinite(const XlaOp& operand); + // TODO(b/64798317): Finish CPU & GPU implementation, then replace xla::Iota + // in xla/client/lib/numeric.h with this (renamed to xla::Iota). + friend XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size); + friend XlaOp ConvertElementType(const XlaOp& operand, + PrimitiveType new_element_type); + friend XlaOp BitcastConvertType(const XlaOp& operand, + PrimitiveType new_element_type); + friend XlaOp Neg(const XlaOp& operand); + friend XlaOp Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice permutation); + friend XlaOp Rev(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions); + friend XlaOp Sort(XlaOp keys, tensorflow::gtl::optional values, + int64 dimension); + friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); + friend XlaOp Map(XlaBuilder* builder, + tensorflow::gtl::ArraySlice operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands); + friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, + const Shape& shape); + friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); + friend XlaOp While(const XlaComputation& condition, + const XlaComputation& body, const XlaOp& init); + friend XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation); + friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits); + friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds); + friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers); + friend void Send(const XlaOp& operand, const ChannelHandle& handle); + friend XlaOp Recv(XlaBuilder* builder, const Shape& shape, + const ChannelHandle& handle); + friend XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index); + friend XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index); + friend XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index); + friend XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, + const ChannelHandle& handle); + friend XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle); + friend XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const ChannelHandle& handle); + friend XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle); + friend XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, + const string& config); + friend XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const string& outfeed_config); + friend XlaOp CreateToken(XlaBuilder* builder); + friend XlaOp AfterAll(XlaBuilder* builder, + tensorflow::gtl::ArraySlice tokens); }; // RAII-style object: sets the current sharding assignment in builder on @@ -1087,6 +1474,27 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); XlaOp Broadcast(const XlaOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes); +// Performs in-dimension-style broadcast. +// +// Operand specifies the input to be broadcast. "shape" is expected output +// shape. "broadcast_dimensions" are the dimensions to be broadcasting into. +// Dimension numbers in broadcast_dimensions map to individual dimensions +// of the operand, and specify what dimension of the output shape they +// should be broadcast. +// e.g. +// Say operand = [1, 2], i.e., a 1D tensor with 2 elements. +// and dimension of shape is [2,2]. +// Specifying {1} as brodcast_dimension will generate output +// [1 , 2] +// [1 , 2] +// On the other hand, specifying {0} as broadcast_dimension +// will generate output +// [1 , 1] +// [2 , 2] +XlaOp BroadcastInDim( + const XlaOp& operand, const Shape& shape, + const tensorflow::gtl::ArraySlice broadcast_dimensions); + // Enqueues a pad operation onto the computation that pads the given value on // the edges as well as between the elements of the input. padding_config // specifies the padding amount for each dimension. @@ -1281,6 +1689,13 @@ XlaOp Fft(const XlaOp& operand, FftType fft_type, XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config = ""); +// Variant of Infeed which takes a token-shaped operand and produces a +// two-element tuple containing the data value and a token-shaped value. +// Tokens are used for ordering side-effecting operations. +// TODO(b/110532604): Replace all uses of the non-token form with this variant. +XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, + const string& config = ""); + // Enqueues an outfeed instruction onto the computation. This instruction // generates outgoing data transfers for the given data. // @@ -1290,6 +1705,13 @@ XlaOp Infeed(XlaBuilder* builder, const Shape& shape, void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, const string& outfeed_config); +// Variant of Outfeed which takes a token-shaped operand and produces a +// token-shaped value. Tokens are used for ordering side-effecting operations. +// TODO(b/110532604): Replace all uses of the non-token form with this variant. +XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, + const string& outfeed_config); + // Enqueues a call instruction onto the computation. XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, tensorflow::gtl::ArraySlice operands); @@ -1420,9 +1842,9 @@ XlaOp CrossReplicaSum( // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. // -// - `channel_id`: for Allreduce nodes from different models, if they have the +// - `channel_id`: for Allreduce nodes from different modules, if they have the // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be -// applied cross models. +// applied cross modules. // // TODO(b/79737069): Rename this to AllReduce when it's ready to use. XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation, @@ -1430,6 +1852,13 @@ XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation, const tensorflow::gtl::optional& channel_id = tensorflow::gtl::nullopt); +// Enqueues an operation that do an Alltoall of the operand cross cores. +// +// TODO(b/110096724): This is NOT YET ready to use. +XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, + int64 concat_dimension, int64 split_count, + const std::vector& replica_groups = {}); + // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, @@ -1498,16 +1927,6 @@ XlaOp Real(const XlaOp& operand); // Enqueues an imaginary-part instruction onto the computation. XlaOp Imag(const XlaOp& operand); -// Enqueues a float32 sqrt instruction onto the computation. -// (float32 is specified as there is an implicit float32 0.5f constant -// exponent). -XlaOp SqrtF32(const XlaOp& operand); - -// Enqueues a float32 square instruction onto the computation. -// (float32 is specified as there is an implicit float32 2.0f constant -// exponent). -XlaOp SquareF32(const XlaOp& operand); - // Enqueues a lhs^rhs computation onto the computation. XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions = {}); @@ -1528,14 +1947,6 @@ XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); // identical. XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); -// Enqueues a float32 reciprocal instruction onto the computation. -// (float32 is specified as there is an implicit float32 -1.0f constant -// exponent). -// -// TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the -// shape of the operand. -XlaOp ReciprocalF32(const XlaOp& operand); - // Enqueues a negate instruction onto the computation. XlaOp Neg(const XlaOp& operand); @@ -1549,7 +1960,24 @@ XlaOp Transpose(const XlaOp& operand, XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions); // Enqueues a sort (as increasing order) instruction onto the computation. -XlaOp Sort(const XlaOp& operand); +// If only keys are provided: +// * If the keys are an rank-1 tensor (an array), the result is a sorted array +// of keys, in ascending order. +// * If the keys have higher rank, the keys are sorted along the provided +// dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension +// value of 0 will indepenently sort every column, and a dimension value of 1 +// will independently sort each row. If no dimension number is provided, then +// the last dimension is chosen by default. +// +// If both keys and values are provided: +// * The keys and the values must tensors with the same dimensions. The +// element types of the tensors may be different. +// * The result is a tuple that consists of a sorted tensor of keys (along the +// provided dimension, as above) as the first element, and a tensor with their +// corresponding values as the second element. +XlaOp Sort(XlaOp keys, + tensorflow::gtl::optional values = tensorflow::gtl::nullopt, + int64 dimension = -1); // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); @@ -1587,16 +2015,64 @@ XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, const GatherDimensionNumbers& dimension_numbers, tensorflow::gtl::ArraySlice window_bounds); -// Enqueues a Send node onto the computation, to send the given operand to -// a Recv instruction that shares the same channel handle. +// Enqueues a Scatter node onto the computation. +XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers); + +// Enqueues a Send node onto the computation for device-to-device +// communication. This operation sends the given operand to +// a Recv instruction in a different computation that shares the same channel +// handle. void Send(const XlaOp& operand, const ChannelHandle& handle); -// Enqueues a Recv node onto the computation. The data comes from a Send -// instruction that shares the same channel handle and its shape must -// be the same as the given shape. +// Variant of Send which takes a token-shaped operand and produces a +// token-shaped value. Tokens are used for ordering side-effecting operations. +// TODO(b/110532604): Replace all uses of the non-token form with this variant. +XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, + const ChannelHandle& handle); + +// Enqueues a Recv node onto the computation for device-to-device +// communication. The data comes from a Send instruction in a different +// computation that shares the same channel handle and its shape must be the +// same as the given shape. XlaOp Recv(XlaBuilder* builder, const Shape& shape, const ChannelHandle& handle); +// Variant of Recv which takes a token-shaped operand and produces a two-element +// tuple containing the data value and a token-shaped value. Tokens are used +// for ordering side-effecting operations. +// TODO(b/110532604): Replace all uses of the non-token form with this variant. +XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle); + +// Enqueues a Send node which transfers data from the device to the host. The +// 'shape_with_layout' argument defines the layout of the data transferred; its +// shape must be compatible with the shape of the operand. The operand must be +// array-shaped. +// TODO(b/111544877): Support tuple shapes. +XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, + const Shape& shape_with_layout, const ChannelHandle& handle); + +// Enqueues a Recv node which transfers data from the host to the device. The +// given shape must contain a layout and must be an array. +// TODO(b/111544877): Support tuple shapes. +XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, + const ChannelHandle& handle); + +// Enqueues an operation (AfterAll) with no operands that produces a +// token-shaped value. Tokens are used for ordering side-effecting operations. +// This is a separate method from AfterAll to facility the removal of +// operand-less AfterAll instructions. +// TODO(b/110532604): Remove this function when all tokens are derived from a +// single token generated or passed into the entry computation. +XlaOp CreateToken(XlaBuilder* builder); + +// Enqueues an AfterAll instruction which produces a token-shaped value and +// takes a variadic number of token-shaped operands. The number of operands must +// be greater than zero. Used for joining tokens. +XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice tokens); + // Normalizes operand across spatial and batch dimensions for each feature. // // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` @@ -1639,12 +2115,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, template XlaOp XlaBuilder::ConstantR0(NativeT value) { - return ConstantLiteral(*Literal::CreateR0(value)); + return ConstantLiteral(*LiteralUtil::CreateR0(value)); } template XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice values) { - return ConstantLiteral(*Literal::CreateR1(values)); + return ConstantLiteral(*LiteralUtil::CreateR1(values)); } template @@ -1656,44 +2132,44 @@ XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) { } inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) { - return ConstantLiteral(*Literal::CreateR1(values)); + return ConstantLiteral(*LiteralUtil::CreateR1(values)); } template XlaOp XlaBuilder::ConstantR2( std::initializer_list> values) { - return ConstantLiteral(*Literal::CreateR2(values)); + return ConstantLiteral(*LiteralUtil::CreateR2(values)); } template XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array& values, const Layout& layout) { return ConstantLiteral( - *Literal::CreateFromArrayWithLayout(values, layout)); + *LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp XlaBuilder::ConstantFromArray(const Array& values) { - return ConstantLiteral(*Literal::CreateFromArray(values)); + return ConstantLiteral(*LiteralUtil::CreateFromArray(values)); } template XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout( const Array2D& values, const Layout& layout) { return ConstantLiteral( - *Literal::CreateFromArrayWithLayout(values, layout)); + *LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D& values) { - return ConstantLiteral(*Literal::CreateR2FromArray2D(values)); + return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D(values)); } template XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout( const Array3D& values, const Layout& layout) { return ConstantLiteral( - *Literal::CreateR3FromArray3DWithLayout(values, layout)); + *LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); } template @@ -1716,13 +2192,13 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { template XlaOp ConstantR0(XlaBuilder* builder, NativeT value) { - return ConstantLiteral(builder, *Literal::CreateR0(value)); + return ConstantLiteral(builder, *LiteralUtil::CreateR0(value)); } template XlaOp ConstantR1(XlaBuilder* builder, tensorflow::gtl::ArraySlice values) { - return ConstantLiteral(builder, *Literal::CreateR1(values)); + return ConstantLiteral(builder, *LiteralUtil::CreateR1(values)); } template @@ -1735,13 +2211,13 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) { inline XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values) { - return ConstantLiteral(builder, *Literal::CreateR1(values)); + return ConstantLiteral(builder, *LiteralUtil::CreateR1(values)); } template XlaOp ConstantR2(XlaBuilder* builder, std::initializer_list> values) { - return ConstantLiteral(builder, *Literal::CreateR2(values)); + return ConstantLiteral(builder, *LiteralUtil::CreateR2(values)); } template @@ -1749,12 +2225,14 @@ XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, const Array& values, const Layout& layout) { return ConstantLiteral( - builder, *Literal::CreateFromArrayWithLayout(values, layout)); + builder, + *LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp ConstantFromArray(XlaBuilder* builder, const Array& values) { - return ConstantLiteral(builder, *Literal::CreateFromArray(values)); + return ConstantLiteral(builder, + *LiteralUtil::CreateFromArray(values)); } template @@ -1762,14 +2240,15 @@ XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, const Array2D& values, const Layout& layout) { return ConstantLiteral( - builder, *Literal::CreateFromArrayWithLayout(values, layout)); + builder, + *LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp ConstantR2FromArray2D(XlaBuilder* builder, const Array2D& values) { return ConstantLiteral(builder, - *Literal::CreateR2FromArray2D(values)); + *LiteralUtil::CreateR2FromArray2D(values)); } template @@ -1778,7 +2257,7 @@ XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, const Layout& layout) { return ConstantLiteral( builder, - *Literal::CreateR3FromArray3DWithLayout(values, layout)); + *LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); } template @@ -1802,4 +2281,4 @@ XlaOp ConstantR4FromArray4D(XlaBuilder* builder, } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_ diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc similarity index 81% rename from tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc rename to tensorflow/compiler/xla/client/xla_builder_test.cc index 3b8beb2c7840e23752b5f47bbc5f55d89751884d..49a15ec3b449bdec07aa6ecfbc40b7b9f62c3f4e 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -13,16 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -45,6 +47,17 @@ class XlaBuilderTest : public ::testing::Test { return HloModule::CreateFromProto(proto, config); } + // Overload which explicitly specifies the root instruction. + StatusOr> BuildHloModule(XlaBuilder* b, + XlaOp root) { + TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build(root)); + const HloModuleProto& proto = computation.proto(); + TF_ASSIGN_OR_RETURN(const auto& config, + HloModule::CreateModuleConfigFromProto( + proto, legacy_flags::GetDebugOptionsFromFlags())); + return HloModule::CreateFromProto(proto, config); + } + // Returns the name of the test currently being run. string TestName() const { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); @@ -292,6 +305,21 @@ TEST_F(XlaBuilderTest, Transpose) { EXPECT_THAT(root, op::Transpose(op::Parameter())); } +TEST_F(XlaBuilderTest, AllToAll) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x"); + AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, + /*split_count=*/2); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + + // AllToAll is decomposed into slices -> all-to-all -> gte -> concat. + EXPECT_EQ(root->opcode(), HloOpcode::kConcatenate); + EXPECT_EQ(root->operand(0)->operand(0)->opcode(), HloOpcode::kAllToAll); + EXPECT_TRUE( + ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8}))); +} + TEST_F(XlaBuilderTest, ReportError) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); @@ -319,5 +347,45 @@ TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) { EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error")); } +TEST_F(XlaBuilderTest, BuildWithSpecificRoot) { + XlaBuilder b(TestName()); + XlaOp constant = ConstantR0(&b, 1.0); + Add(constant, ConstantR0(&b, 2.0)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/constant)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Constant()); +} + +TEST_F(XlaBuilderTest, BuildWithSpecificRootAndMultipleParameters) { + // Specifying a particular root in Build should still include all entry + // parameters. + XlaBuilder b(TestName()); + const Shape shape = ShapeUtil::MakeShape(F32, {42, 123}); + XlaOp x = Parameter(&b, 0, shape, "x"); + XlaOp y = Parameter(&b, 1, shape, "y"); + XlaOp z = Parameter(&b, 2, shape, "z"); + Add(x, Sub(y, z)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/x)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Parameter()); + EXPECT_EQ(module->entry_computation()->num_parameters(), 3); + EXPECT_EQ(module->entry_computation()->instruction_count(), 5); +} + +TEST_F(XlaBuilderTest, BuildWithSpecificRootWithWrongBuilder) { + XlaBuilder b(TestName()); + XlaBuilder other_b(TestName()); + const Shape shape = ShapeUtil::MakeShape(F32, {42, 123}); + + Parameter(&b, 0, shape, "param"); + XlaOp other_param = Parameter(&other_b, 0, shape, "other_param"); + + Status status = b.Build(other_param).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("root operation is not in this computation")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD deleted file mode 100644 index b0f41ac1d3790e60eed03bb7b5fd07a90b0be9c9..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/xla_client/BUILD +++ /dev/null @@ -1,79 +0,0 @@ -# Description: -# The new XLA client libraries. -# -# This is NOT YET ready to use. - -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = [":friends"]) - -package_group( - name = "friends", - includes = [ - "//tensorflow/compiler/xla:friends", - ], -) - -# Filegroup used to collect source files for dependency checking. -filegroup( - name = "c_srcs", - data = glob([ - "**/*.cc", - "**/*.h", - ]), -) - -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -cc_library( - name = "xla_computation", - srcs = ["xla_computation.cc"], - hdrs = ["xla_computation.h"], - deps = [ - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo_proto", - ], -) - -cc_library( - name = "xla_builder", - srcs = ["xla_builder.cc"], - hdrs = ["xla_builder.h"], - deps = [ - ":xla_computation", - "//tensorflow/compiler/xla:execution_options_util", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client:sharding_builder", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/compiler/xla/service:shape_inference", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "xla_builder_test", - srcs = ["xla_builder_test.cc"], - deps = [ - ":xla_builder", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_matchers", - "//tensorflow/core:test", - ], -) diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_computation.cc similarity index 94% rename from tensorflow/compiler/xla/client/xla_client/xla_computation.cc rename to tensorflow/compiler/xla/client/xla_computation.cc index 72e3935696e0c44ae3893fc8f1ceb261fa5e2646..3543d41fc2656ec028646edebc0bf5b6af7f67a5 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_computation.cc +++ b/tensorflow/compiler/xla/client/xla_computation.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.h b/tensorflow/compiler/xla/client/xla_computation.h similarity index 90% rename from tensorflow/compiler/xla/client/xla_client/xla_computation.h rename to tensorflow/compiler/xla/client/xla_computation.h index 0ffba208b1f8683fe1d26107cbfd096b856267f1..71598ef8b296a760b0ee818fce0a59aed5cfc6b4 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_computation.h +++ b/tensorflow/compiler/xla/client/xla_computation.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_COMPUTATION_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_COMPUTATION_H_ #include @@ -64,4 +64,4 @@ class XlaComputation { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py index abd10b164eaef8e75ed304483861baf250c5b954..fb135f5ceda67ce6c001de15b8f3f084ca164826 100644 --- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py +++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py @@ -20,7 +20,7 @@ from __future__ import print_function import math -import numpy as np +import numpy as _np # Avoids becoming a part of public Tensorflow API. from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python_api import xla_shape @@ -85,7 +85,7 @@ class Sharding(object): something we really want to expose to users (especially as the contract for tile_assignment is very strict). """ - if not isinstance(tile_assignment, np.ndarray): + if not isinstance(tile_assignment, _np.ndarray): raise TypeError('Tile assignment must be of type np.ndarray') if not isinstance(tile_shape, xla_shape.Shape): raise TypeError('Tile shape must be of type xla_shape.Shape') diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 15eeb2ea13607d43c995197f8f0e3c58abd4d94a..b72d190d54591384392e79e73e90cf52df04a902 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -297,7 +297,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { shape.layout().padded_dimensions_size() == 0) { return false; } - CHECK(IsDenseArray(shape)); + CHECK(IsDenseArray(shape)) << shape.ShortDebugString(); CHECK_EQ(shape.dimensions_size(), shape.layout().padded_dimensions_size()); for (int64 i = 0; i < shape.dimensions_size(); ++i) { if (shape.layout().padded_dimensions(i) > shape.dimensions(i)) { diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc new file mode 100644 index 0000000000000000000000000000000000000000..36e472568ecfdb97c828817ed339260ee7878723 --- /dev/null +++ b/tensorflow/compiler/xla/literal.cc @@ -0,0 +1,2090 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/literal.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::strings::Printf; +using tensorflow::strings::StrCat; + +namespace xla { + +namespace { + +constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; + +// Converts between little and big endian. +// +// Precondition: size % 2 == 0 (elements in the array are 16 bits long) +void ConvertEndianShort(string* bytes) { + CHECK_EQ(bytes->size() / 2, 0); + for (int64 i = 0; i < bytes->size(); i += 2) { + std::swap((*bytes)[i], (*bytes)[i + 1]); + } +} + +void ConvertEndianShort(char* bytes, int64 size) { + CHECK_EQ(size / 2, 0); + for (int64 i = 0; i < size; i += 2) { + std::swap(bytes[i], bytes[i + 1]); + } +} + +} // namespace + +LiteralBase::~LiteralBase() {} + +std::ostream& operator<<(std::ostream& out, const Literal& literal) { + out << literal.ToString(); + return out; +} + +MutableLiteralBase::StrideConfig::StrideConfig( + const Shape& source_shape, const Shape& dest_shape, + tensorflow::gtl::ArraySlice dimensions) + : dimensions(dimensions), + base(dimensions.size(), 0), + step(dimensions.size(), 1) { + if (!dimensions.empty()) { + // Selects the shape with the largest minor dimension as the one upon + // which to run the tight stride loop. + if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >= + dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) { + minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0); + dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension); + } else { + minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0); + source_stride = + IndexUtil::GetDimensionStride(source_shape, minor_dimension); + } + minor_loop_size = dimensions[minor_dimension]; + step[minor_dimension] = minor_loop_size; + } +} + +Literal::Literal(const Shape& shape) + : Literal(shape, /*allocate_arrays=*/true) {} + +void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); + + SetPiece(subshape, &child_piece, allocate_arrays); + + piece->emplace_back(std::move(child_piece)); + } + } else if (ShapeUtil::IsArray(shape)) { + if (allocate_arrays) { + if (LayoutUtil::IsSparseArray(shape)) { + // For sparse arrays, the buffer must be of the size of the maximum + // number of sparse elements possible. + const int64 max_sparse_elements = + LayoutUtil::MaxSparseElements(shape.layout()); + piece->set_buffer( + new char[max_sparse_elements * + ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]); + piece->set_sparse_indices( + new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape))); + } else { + piece->set_buffer(new char[piece->size_bytes()]); + } + } + } else { + // If the shape is neither an array nor tuple, then it must be + // zero-sized. Otherwise, some memory needs to be allocated for it. + CHECK_EQ(piece->size_bytes(), 0); + } +} + +Literal::Literal(const Shape& shape, bool allocate_arrays) + : MutableLiteralBase() { + shape_ = MakeUnique(shape); + CHECK(LayoutUtil::HasLayout(*shape_)); + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + CHECK(&root_piece_->subshape() == shape_.get()); + + SetPiece(*shape_, root_piece_, allocate_arrays); +} + +Literal::~Literal() { + if (root_piece_ != nullptr) { + DeallocateBuffers(); + delete root_piece_; + } +} + +void Literal::DeallocateBuffers() { + root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (piece->buffer() != nullptr) { + delete[] piece->buffer(); + delete piece->sparse_indices(); + } + }); +} + +Literal::Literal(Literal&& other) : MutableLiteralBase() { + *this = std::move(other); +} + +Literal& Literal::operator=(Literal&& other) { + DCHECK(&other.root_piece_->subshape() == other.shape_.get()); + using std::swap; + swap(shape_, other.shape_); + swap(root_piece_, other.root_piece_); + DCHECK(&root_piece_->subshape() == shape_.get()); + + return *this; +} + +std::unique_ptr LiteralBase::CreateFromShape(const Shape& shape) { + auto literal = MakeUnique(shape); + literal->root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (ShapeUtil::IsArray(piece->subshape())) { + memset(piece->untyped_data(), 0, piece->size_bytes()); + } + }); + return literal; +} + +const SparseIndexArray* LiteralBase::sparse_indices( + const ShapeIndex& shape_index) const { + return piece(shape_index).sparse_indices(); +} + +SparseIndexArray* MutableLiteralBase::sparse_indices( + const ShapeIndex& shape_index) { + return piece(shape_index).sparse_indices(); +} + +template +Status MutableLiteralBase::CopySliceFromInternal( + const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { + TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); + TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size()); + + auto linear_index = [](const Shape& shape, + tensorflow::gtl::ArraySlice multi_index) { + return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index); + }; + + if (ShapeUtil::Rank(src_literal.shape()) == 0 || + ShapeUtil::Rank(shape()) == 0) { + // If any of the two shapes are scalars, we can just call the StridedCopy() + // directly, and we know we will be copying only one value. + TF_RET_CHECK(copy_size.empty()); + StridedCopy(data(), linear_index(shape(), dest_base), 0, + src_literal.data(), + linear_index(src_literal.shape(), src_base), 0, 1); + } else if (!ShapeUtil::IsZeroElementArray(shape()) && + !ShapeUtil::IsZeroElementArray(src_literal.shape())) { + // Perform copy if neither src nor dest has dimensions with zero element, + // otherwise it's a no-op. + TF_RET_CHECK(src_base.size() == dest_base.size()); + TF_RET_CHECK(src_base.size() == copy_size.size()); + + // Scan the source from minor, stepping in copy size blocks, then within + // the index enumaration functor, do a strided copy advancing source index + // by one (walking through the minor dimension), and destination index by + // proper stride size at the matching dimension. + DimensionVector src_indexes(src_base.size(), 0); + DimensionVector dest_indexes(dest_base.size(), 0); + MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(), + copy_size); + + auto copy_proc = [&](tensorflow::gtl::ArraySlice indexes) { + // Map from multi-dimensional index, to source index. + std::transform(indexes.begin(), indexes.end(), src_base.begin(), + src_indexes.begin(), std::plus()); + // Map from multi-dimensional index, to destination index. + std::transform(indexes.begin(), indexes.end(), dest_base.begin(), + dest_indexes.begin(), std::plus()); + + int64 src_index = linear_index(src_literal.shape(), src_indexes); + int64 dest_index = linear_index(shape(), dest_indexes); + + // `this->` is needed to workaround MSVC bug: #16882 + StridedCopy(this->data(), dest_index, stride_config.dest_stride, + src_literal.data(), src_index, + stride_config.source_stride, stride_config.minor_loop_size); + return true; + }; + + ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base, + stride_config.dimensions, stride_config.step, + copy_proc); + } + return Status::OK(); +} + +Status MutableLiteralBase::CopyElementFrom( + const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_index, + tensorflow::gtl::ArraySlice dest_index) { + DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); + const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex( + src_literal.shape(), src_index); + const int64 dest_linear_index = + IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index); + const int64 primitive_size = + ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); + + char* dest_address = + static_cast(untyped_data()) + dest_linear_index * primitive_size; + const char* source_address = + static_cast(src_literal.untyped_data()) + + src_linear_index * primitive_size; + if (dest_address != source_address) { + memcpy(dest_address, source_address, primitive_size); + } + return Status::OK(); +} + +/* static */ StatusOr> +MutableLiteralBase::CreateFromProto(const LiteralProto& proto) { + if (!proto.has_shape()) { + return InvalidArgument("LiteralProto has no shape"); + } + if (!LayoutUtil::HasLayout(proto.shape())) { + return InvalidArgument("LiteralProto has no layout"); + } + + auto literal = MakeUnique(proto.shape()); + + TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + const LiteralProto* proto_element = &proto; + for (int64 i : index) { + CHECK(i < proto_element->tuple_literals_size()); + proto_element = &proto_element->tuple_literals(i); + } + + if (ShapeUtil::IsTuple(piece->subshape())) { + if (proto_element->tuple_literals_size() != + ShapeUtil::TupleElementCount(piece->subshape())) { + return InvalidArgument( + "Expected %lld tuple elements in LiteralProto, has %d", + ShapeUtil::TupleElementCount(piece->subshape()), + proto_element->tuple_literals_size()); + } + return Status::OK(); + } + if (piece->subshape().element_type() == TOKEN) { + return Status::OK(); + } + + CHECK(ShapeUtil::IsArray(piece->subshape())); + TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); + + return Status::OK(); + })); + + return std::move(literal); +} + +std::vector Literal::DecomposeTuple() { + CHECK(ShapeUtil::IsTuple(shape())); + std::vector elements; + for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { + elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}), + /*allocate_arrays=*/false)); + Literal& element = elements.back(); + element.root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* dest_piece) { + ShapeIndex src_index = {i}; + for (int64 j : index) { + src_index.push_back(j); + } + Piece& src_piece = piece(src_index); + + // Move the respective buffer and sparse indices over to the element + // Literal. + dest_piece->set_buffer(src_piece.buffer()); + src_piece.set_buffer(nullptr); + dest_piece->set_sparse_indices(src_piece.sparse_indices()); + src_piece.set_sparse_indices(nullptr); + }); + } + // Set this literal to be nil-shaped. + *this = Literal(); + return elements; +} + +namespace { + +// Copies the elements in 'src' to 'dest'. The shape and layout of the data in +// the array slices are indicated by dest_shape and src_shape respectively. +template +void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, + tensorflow::gtl::ArraySlice src, + const Shape& dest_shape, const Shape& src_shape) { + CHECK(ShapeUtil::Compatible(dest_shape, src_shape)); + if (ShapeUtil::IsZeroElementArray(dest_shape)) { + return; + } + std::vector index(ShapeUtil::Rank(dest_shape)); + do { + dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] = + src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)]; + } while (IndexUtil::BumpIndices(dest_shape, &index)); +} + +} // namespace + +Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { + CHECK(subshape_ != nullptr); + CHECK(src.subshape_ != nullptr); + if (ShapeUtil::Equal(subshape(), src.subshape())) { + // If the layouts are equal it's faster just to memcpy. + memcpy(buffer(), src.buffer(), src.size_bytes()); + } else { + TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape())); + std::vector origin(ShapeUtil::Rank(subshape()), 0); + switch (subshape().element_type()) { +#define COPY_ELEMENTS(XLA_T, NATIVE_T) \ + case (XLA_T): \ + CopyElementsBetween(data(), src.data(), \ + subshape(), src.subshape()); \ + break; + COPY_ELEMENTS(U8, uint8); + COPY_ELEMENTS(U16, uint16); + COPY_ELEMENTS(U32, uint32); + COPY_ELEMENTS(U64, uint64); + COPY_ELEMENTS(S8, int8); + COPY_ELEMENTS(S16, int16); + COPY_ELEMENTS(S32, int32); + COPY_ELEMENTS(S64, int64); + COPY_ELEMENTS(F16, half); + COPY_ELEMENTS(BF16, bfloat16); + COPY_ELEMENTS(F32, float); + COPY_ELEMENTS(F64, double); + COPY_ELEMENTS(C64, complex64); + COPY_ELEMENTS(PRED, bool); +#undef COPY_ELEMENTS + default: + return Unimplemented( + "Copying a Literal object with element type %s is not implemented.", + PrimitiveType_Name(subshape().element_type()).c_str()); + } + } + return Status::OK(); +} + +Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal, + const ShapeIndex& dest_shape_index, + const ShapeIndex& src_shape_index) { + const Shape& dest_subshape = + ShapeUtil::GetSubshape(shape(), dest_shape_index); + const Shape& src_subshape = + ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index); + if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) { + return InvalidArgument( + "Destination subshape incompatible with source subshape: %s vs %s", + ShapeUtil::HumanString(dest_subshape).c_str(), + ShapeUtil::HumanString(src_subshape).c_str()); + } + return root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + if (!ShapeUtil::IsArray(piece->subshape())) { + return Status::OK(); + } + + // Determine if this index is in the part of this literal that we want + // to copy over from src_literal. + bool in_subtree_to_copy = true; + for (int i = 0; i < dest_shape_index.size(); ++i) { + if (index[i] != dest_shape_index[i]) { + in_subtree_to_copy = false; + break; + } + } + if (!in_subtree_to_copy) { + return Status::OK(); + } + // Construct the index of the corresponding piece in the source literal. + ShapeIndex src_piece_index = src_shape_index; + for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { + src_piece_index.push_back(index[i]); + } + TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index))); + return Status::OK(); + }); +} + +Status Literal::MoveFrom(Literal&& src_literal, + const ShapeIndex& dest_shape_index) { + const Shape& dest_subshape = + ShapeUtil::GetSubshape(shape(), dest_shape_index); + if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) { + return InvalidArgument( + "Destination subshape not equal to source shape: %s vs %s", + ShapeUtil::HumanString(dest_subshape).c_str(), + ShapeUtil::HumanString(src_literal.shape()).c_str()); + } + + src_literal.root_piece_->ForEachSubpiece( + [&](const ShapeIndex& src_index, const Piece& src_piece) { + if (!ShapeUtil::IsArray(src_piece.subshape())) { + return; + } + + ShapeIndex dest_index = dest_shape_index; + for (int64 i : src_index) { + dest_index.push_back(i); + } + Piece& dest_piece = piece(dest_index); + delete[] dest_piece.buffer(); + dest_piece.set_buffer(src_piece.buffer()); + delete dest_piece.sparse_indices(); + dest_piece.set_sparse_indices(src_piece.sparse_indices()); + }); + + src_literal.shape_ = MakeUnique(ShapeUtil::MakeNil()); + delete src_literal.root_piece_; + src_literal.root_piece_ = new LiteralBase::Piece(); + src_literal.root_piece_->set_subshape(src_literal.shape_.get()); + + return Status::OK(); +} + +Status MutableLiteralBase::CopySliceFrom( + const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { + TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape()); + TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape())) + << ShapeUtil::HumanString(src_literal.shape()); + TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape())); + + switch (shape().element_type()) { + case U8: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case U16: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case U32: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case U64: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case S8: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case S16: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case S32: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case S64: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case F16: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case BF16: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case F32: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case F64: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case C64: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + case PRED: + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); + default: + break; + } + return Unimplemented( + "Copying a slice from a Literal object with element type %d is not " + "implemented.", + shape().element_type()); +} + +void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK_EQ(element_count(), values.bits()); + CHECK_EQ(shape().element_type(), PRED); + for (int64 i = 0; i < static_cast(values.bits()); ++i) { + Set({i}, values.get(i)); + } +} + +std::unique_ptr LiteralBase::Relayout( + const Layout& new_layout, const ShapeIndex& shape_index) const { + // Create new shape with 'new_layout' set at the given shape index. + Shape new_shape = shape(); + Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index); + TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape)); + *subshape->mutable_layout() = new_layout; + auto result = MakeUnique(new_shape); + TF_CHECK_OK(result->CopyFrom(*this)); + return result; +} + +std::unique_ptr LiteralBase::Relayout( + const Shape& shape_with_layout) const { + CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) + << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) + << " not compatible with literal shape " + << ShapeUtil::HumanString(shape()); + std::unique_ptr result = CreateFromShape(shape_with_layout); + ShapeUtil::ForEachSubshape( + result->shape(), + [this, &result](const Shape& subshape, const ShapeIndex& index) { + if (ShapeUtil::IsArray(subshape)) { + TF_CHECK_OK(result->CopyFrom(*this, + /*dest_shape_index=*/index, + /*src_shape_index=*/index)); + } + }); + return result; +} + +StatusOr> LiteralBase::Broadcast( + const Shape& result_shape, + tensorflow::gtl::ArraySlice dimensions) const { + if (!ShapeUtil::IsArray(shape())) { + return InvalidArgument("Broadcast only supports arrays."); + } + + for (int64 i = 0; i < dimensions.size(); i++) { + TF_RET_CHECK(shape().dimensions(i) == + result_shape.dimensions(dimensions[i])); + } + + std::unique_ptr result = MakeUnique(result_shape); + + // scratch_source_index is temporary storage space for the computed index into + // the input literal. We put it here to avoid allocating an std::vector in + // every iteration of ShapeUtil::ForEachIndex. + std::vector scratch_source_index(shape().dimensions_size()); + + char* dest_data = static_cast(result->untyped_data()); + const char* source_data = static_cast(untyped_data()); + const int64 primitive_size = + ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); + + ShapeUtil::ForEachIndex( + result_shape, [&](tensorflow::gtl::ArraySlice output_index) { + for (int64 i = 0; i < dimensions.size(); ++i) { + scratch_source_index[i] = output_index[dimensions[i]]; + } + int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex( + result_shape, output_index); + int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex( + shape(), scratch_source_index); + memcpy(dest_data + primitive_size * dest_index, + source_data + primitive_size * source_index, primitive_size); + return true; + }); + + return std::move(result); +} + +StatusOr> LiteralBase::Reshape( + tensorflow::gtl::ArraySlice dimensions) const { + if (!ShapeUtil::IsArray(shape())) { + return InvalidArgument("Reshape does not support tuples."); + } + std::unique_ptr output; + if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { + output = + Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape()))); + } else { + output = CloneToUnique(); + } + // Because the layout is monotonic, we can simply reuse the same sequence of + // values without changing their order. + *output->mutable_shape_do_not_use() = + ShapeUtil::MakeShape(shape().element_type(), dimensions); + + int64 elements_before = ShapeUtil::ElementsIn(shape()); + int64 elements_after = ShapeUtil::ElementsIn(output->shape()); + if (elements_before != elements_after) { + return InvalidArgument( + "Shapes before and after Literal::Reshape have different numbers " + "of elements: %s vs %s.", + ShapeUtil::HumanString(shape()).c_str(), + ShapeUtil::HumanString(output->shape()).c_str()); + } + return std::move(output); +} + +std::unique_ptr LiteralBase::Transpose( + tensorflow::gtl::ArraySlice permutation) const { + CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; + CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) + << "Given permutation is not a permutation of dimension numbers"; + // To transpose the array, we just permute the dimensions and layout, and + // do a straight memory copy of the raw data set. + // This is considerably faster than iterating over every array element using + // the EachCell<>() and Set<>() APIs. + std::vector inverse_permutation = InversePermutation(permutation); + Shape permuted_shape = + ShapeUtil::PermuteDimensions(inverse_permutation, shape()); + // Replace the layout with one affine to this shape, such that a + // transpose operation can be performed by leaving the flat values + // representation intact. + // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation. + // The shape with affine layout resulting from that operation will be + // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the + // most minor. + // + // Essentially, given MinMaj(Di) the position of the Di dimension within the + // minor to major vector, and given T(Di) the index that the original Di + // dimension has within the transposed array, a layout is affine if + // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major + // vector of the affine layout. + CHECK(LayoutUtil::IsDenseArray(permuted_shape)); + Layout* layout = permuted_shape.mutable_layout(); + layout->clear_minor_to_major(); + for (auto index : LayoutUtil::MinorToMajor(shape())) { + layout->add_minor_to_major(inverse_permutation[index]); + } + auto new_literal = MakeUnique(permuted_shape); + DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), + ShapeUtil::ByteSizeOf(shape())); + std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); + return new_literal; +} + +template +std::unique_ptr LiteralBase::SliceInternal( + const Shape& result_shape, + tensorflow::gtl::ArraySlice start_indices) const { + auto result_literal = MakeUnique(result_shape); + DimensionVector new_indices(ShapeUtil::Rank(result_shape)); + result_literal->EachCell( + [&](tensorflow::gtl::ArraySlice indices, NativeT /*value*/) { + for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { + new_indices[i] = indices[i] + start_indices[i]; + } + NativeT value = Get(new_indices); + result_literal->Set(indices, value); + }); + return result_literal; +} + +std::unique_ptr LiteralBase::Slice( + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) const { + CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; + + DimensionVector result_dimensions; + for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) { + CHECK_GE(start_indices[dnum], 0); + CHECK_LE(limit_indices[dnum], shape().dimensions(dnum)) + << "dnum = " << dnum; + int64 dimension = limit_indices[dnum] - start_indices[dnum]; + CHECK_GE(dimension, 0) << "dnum = " << dnum; + result_dimensions.push_back(dimension); + } + const auto result_shape = + ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, + LayoutUtil::MinorToMajor(shape())); + switch (result_shape.element_type()) { + case F32: + return SliceInternal(result_shape, start_indices); + case BF16: + return SliceInternal(result_shape, start_indices); + case C64: + return SliceInternal(result_shape, start_indices); + case S32: + return SliceInternal(result_shape, start_indices); + case U32: + return SliceInternal(result_shape, start_indices); + default: + LOG(FATAL) << "not yet implemented: " + << PrimitiveType_Name(result_shape.element_type()); + } +} + +Literal LiteralBase::Clone() const { + Literal result(shape()); + TF_CHECK_OK(result.CopyFrom(*this)); + return result; +} + +std::unique_ptr LiteralBase::CloneToUnique() const { + auto result = MakeUnique(shape()); + TF_CHECK_OK(result->CopyFrom(*this)); + return result; +} + +string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const { + const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); + CHECK(LayoutUtil::IsDenseArray(subshape)); + switch (subshape.element_type()) { + case PRED: + return Get(multi_index, shape_index) ? "true" : "false"; + case S8: + return StrCat(Get(multi_index, shape_index)); + case S16: + return StrCat(Get(multi_index, shape_index)); + case S32: + return StrCat(Get(multi_index, shape_index)); + case S64: + return StrCat(Get(multi_index, shape_index)); + case U8: + return StrCat(Get(multi_index, shape_index)); + case U16: + return StrCat(Get(multi_index, shape_index)); + case U32: + return StrCat(Get(multi_index, shape_index)); + case U64: + return StrCat(Get(multi_index, shape_index)); + case F16: + return StrCat(static_cast(Get(multi_index, shape_index))); + case F32: + return StrCat(Get(multi_index, shape_index)); + case BF16: + return StrCat( + static_cast(Get(multi_index, shape_index))); + case F64: + return StrCat(Get(multi_index, shape_index)); + case C64: { + complex64 c = Get(multi_index, shape_index); + return StrCat("(", c.real(), ", ", c.imag(), ")"); + } + default: + LOG(FATAL) << PrimitiveType_Name(subshape.element_type()); + } +} + +string LiteralBase::GetSparseElementAsString( + int64 sparse_element_number, const ShapeIndex& shape_index) const { + const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); + CHECK(LayoutUtil::IsSparseArray(subshape)); + switch (subshape.element_type()) { + case PRED: + return GetSparseElement(sparse_element_number, shape_index) + ? "true" + : "false"; + case S8: + return StrCat(GetSparseElement(sparse_element_number, shape_index)); + case S16: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case S32: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case S64: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case U8: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case U16: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case U32: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case U64: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case F16: + return StrCat(static_cast( + GetSparseElement(sparse_element_number, shape_index))); + case F32: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case BF16: + return StrCat(static_cast( + GetSparseElement(sparse_element_number, shape_index))); + case F64: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case C64: { + complex64 c = + GetSparseElement(sparse_element_number, shape_index); + return StrCat("(", c.real(), ", ", c.imag(), ")"); + } + default: + LOG(FATAL) << "Invalid element type for sparse arrays: " + << PrimitiveType_Name(subshape.element_type()); + } +} + +StatusOr LiteralBase::GetIntegralAsS64( + tensorflow::gtl::ArraySlice multi_index) const { + CHECK(LayoutUtil::IsDenseArray(shape())); + switch (shape().element_type()) { + case PRED: + return Get(multi_index); + case U8: + return Get(multi_index); + case S32: + return Get(multi_index); + case S64: + return Get(multi_index); + case U32: + return Get(multi_index); + case U64: + return Get(multi_index); + default: + return FailedPrecondition( + "Array element type is not integral: %s", + PrimitiveType_Name(shape().element_type()).c_str()); + } +} + +size_t LiteralBase::Hash() const { + using tensorflow::Hash64; + using tensorflow::Hash64Combine; + + size_t hash_value = ShapeUtil::Hash(shape()); + + ShapeUtil::ForEachSubshape( + shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (!ShapeUtil::IsArray(subshape)) { + return; + } + + CHECK(LayoutUtil::IsDense(subshape.layout())); + hash_value = Hash64Combine( + hash_value, Hash64(static_cast(untyped_data(index)), + size_bytes(index))); + }); + + return hash_value; +} + +Status MutableLiteralBase::SetIntegralAsS64( + tensorflow::gtl::ArraySlice multi_index, int64 value) { + CHECK(LayoutUtil::IsDenseArray(shape())); + switch (shape().element_type()) { + case PRED: + Set(multi_index, value); + break; + case U8: + Set(multi_index, value); + break; + case S32: + Set(multi_index, value); + break; + case S64: + Set(multi_index, value); + break; + case U32: + Set(multi_index, value); + break; + case U64: + Set(multi_index, value); + break; + default: + return FailedPrecondition( + "Array element type is not integral: %s", + PrimitiveType_Name(shape().element_type()).c_str()); + } + return Status::OK(); +} + +tensorflow::gtl::ArraySlice LiteralBase::GetSparseIndex( + int64 sparse_element_number, const ShapeIndex& shape_index) const { + const Piece& p = piece(shape_index); + CHECK_GE(sparse_element_number, 0); + CHECK_LT(sparse_element_number, p.sparse_indices()->index_count()); + return p.sparse_indices()->At(sparse_element_number); +} + +void MutableLiteralBase::SortSparseElements(const ShapeIndex& shape_index) { + piece(shape_index).SortSparseElements(); +} + +void LiteralBase::Piece::SortSparseElements() { + switch (subshape().element_type()) { + case PRED: + SortSparseElementsInternal(); + break; + case S8: + SortSparseElementsInternal(); + break; + case U8: + SortSparseElementsInternal(); + break; + case S16: + SortSparseElementsInternal(); + break; + case U16: + SortSparseElementsInternal(); + break; + case S32: + SortSparseElementsInternal(); + break; + case U32: + SortSparseElementsInternal(); + break; + case S64: + SortSparseElementsInternal(); + break; + case U64: + SortSparseElementsInternal(); + break; + case F32: + SortSparseElementsInternal(); + break; + case F64: + SortSparseElementsInternal(); + break; + case C64: + SortSparseElementsInternal(); + break; + case F16: + SortSparseElementsInternal(); + break; + case BF16: + SortSparseElementsInternal(); + break; + default: + LOG(FATAL) << "Element type not valid for sparse array: " + << PrimitiveType_Name(subshape().element_type()); + } +} + +template +void LiteralBase::Piece::SortSparseElementsInternal() { + CHECK(LayoutUtil::IsSparseArray(subshape())); + int64 num_elements = sparse_indices()->index_count(); + auto values = data(); + CHECK_LE(num_elements, values.size()); + sparse_indices()->SortWithValues( + tensorflow::gtl::MutableArraySlice(values.data(), num_elements)); +} + +namespace { + +void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, + bool print_layout, std::vector* pieces) { + const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + CHECK(LayoutUtil::HasLayout(literal.shape())); + CHECK(LayoutUtil::HasLayout(subshape)); + + auto shape_to_string = [print_layout](const Shape& shape) { + if (print_layout) { + return ShapeUtil::HumanStringWithLayout(shape); + } else { + return ShapeUtil::HumanString(shape); + } + }; + + // TODO(b/32894291): refactor this code to reduce code duplication. + if (ShapeUtil::IsTuple(subshape)) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" (\n"); + std::vector tuple_pieces; + for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { + ShapeIndex element_index = shape_index; + element_index.push_back(i); + std::vector element_pieces; + ToStringHelper(literal, element_index, print_layout, &element_pieces); + tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, "")); + } + pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n")); + pieces->push_back("\n)"); + return; + } + + if (ShapeUtil::IsToken(subshape)) { + pieces->push_back("token"); + return; + } + + if (LayoutUtil::IsSparseArray(subshape)) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back("{"); + int64 rank = ShapeUtil::Rank(subshape); + int64 num_elements = literal.sparse_element_count(); + for (int64 i = 0; i < num_elements; ++i) { + if (i > 0) { + pieces->push_back(", "); + } + if (rank == 1) { + pieces->push_back(StrCat(literal.GetSparseIndex(i)[0])); + pieces->push_back(": "); + } else { + pieces->push_back("["); + pieces->push_back( + tensorflow::str_util::Join(literal.GetSparseIndex(i), ", ")); + pieces->push_back("]: "); + } + pieces->push_back(literal.GetSparseElementAsString(i)); + } + pieces->push_back("}"); + return; + } + + CHECK(LayoutUtil::IsDenseArray(subshape)); + + auto element_to_string = + [&](tensorflow::gtl::ArraySlice indices) -> string { + PrimitiveType element_type = subshape.element_type(); + if (element_type == PRED) { + // We display predicates in a densely packed form. + return literal.Get(indices, shape_index) ? "1" : "0"; + } + return ((!indices.empty() && indices.back() > 0) ? ", " : "") + + literal.GetAsString(indices, shape_index); + }; + + if (ShapeUtil::Rank(subshape) == 0) { + pieces->push_back(literal.GetAsString({}, shape_index)); + } else if (ShapeUtil::Rank(subshape) == 1) { + pieces->push_back("{"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(element_to_string({i0})); + } + pieces->push_back("}"); + } else if (ShapeUtil::Rank(subshape) == 2) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(" { "); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(element_to_string({i0, i1})); + } + pieces->push_back(" "); + pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "}\n" : "},\n"); + } + pieces->push_back("}"); + } else if (ShapeUtil::Rank(subshape) == 3) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(i0 > 0 ? ",\n{" : "{"); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(i1 > 0 ? ",\n { " : " { "); + for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { + pieces->push_back(element_to_string({i0, i1, i2})); + } + pieces->push_back(" }"); + } + pieces->push_back(" }"); + } + pieces->push_back("\n}"); + } else if (ShapeUtil::Rank(subshape) == 4) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); + for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { + pieces->push_back(" {"); + for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { + pieces->push_back(element_to_string({i0, i1, i2, i3})); + } + pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "}\n" : "},\n"); + } + pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" + : " },\n"); + } + pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n"); + } + pieces->push_back("}"); + } else if (ShapeUtil::Rank(subshape) == 5) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); + for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { + pieces->push_back(Printf(" { /*i2=%lld*/\n", i2)); + for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { + pieces->push_back(" {"); + for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) { + pieces->push_back(element_to_string({i0, i1, i2, i3, i4})); + } + pieces->push_back(i3 == subshape.dimensions(3) - 1 ? "}\n" + : "},\n"); + } + pieces->push_back(i2 == subshape.dimensions(2) - 1 ? " }\n" + : " },\n"); + } + pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" + : " },\n"); + } + pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n"); + } + pieces->push_back("}"); + } else { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {"); + literal.EachCellAsString( + [&](tensorflow::gtl::ArraySlice indices, const string& value) { + pieces->push_back(" "); + pieces->push_back(value); + }); + pieces->push_back("}"); + } +} + +} // namespace + +int64 LiteralBase::sparse_element_count() const { + CHECK(LayoutUtil::IsSparseArray(shape())); + return sparse_indices()->index_count(); +} + +string LiteralBase::ToString(bool print_layout) const { + std::vector pieces; + CHECK(LayoutUtil::HasLayout(this->shape())); + ToStringHelper(*this, {}, print_layout, &pieces); + return tensorflow::str_util::Join(pieces, ""); +} + +void LiteralBase::EachCellAsString( + const std::function indices, + const string& value)>& per_cell) const { + if (ShapeUtil::IsZeroElementArray(shape())) { + return; + } + std::vector indices = IndexUtil::LinearIndexToMultidimensionalIndex( + shape(), /*linear_index=*/0); + do { + per_cell(indices, GetAsString(indices)); + } while (IndexUtil::BumpIndices(shape(), &indices)); +} + +namespace { +template +std::unique_ptr ConvertBetweenNativeTypesWithConverter( + const LiteralBase& src_literal, const ConverterType& converter) { + CHECK(ShapeUtil::IsArray(src_literal.shape())); + auto result_literal = MakeUnique(ShapeUtil::ChangeElementType( + src_literal.shape(), + primitive_util::NativeToPrimitiveType())); + auto src_data = src_literal.data(); + auto dest_data = result_literal->template data(); + int64 num_elements = src_literal.element_count(); + + for (int64 i = 0; i < num_elements; ++i) { + dest_data[i] = converter(src_data[i]); + } + return result_literal; +} + +template +std::unique_ptr ConvertBetweenNativeTypes( + const LiteralBase& src_literal) { + auto converter = [](NativeSrcT src) { return static_cast(src); }; + return ConvertBetweenNativeTypesWithConverter( + src_literal, converter); +} + +template +typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), + std::unique_ptr>::type +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { + auto converter = [](NativeSrcT src) { + return tensorflow::bit_cast(src); + }; + return ConvertBetweenNativeTypesWithConverter( + src_literal, converter); +} + +// This template specialization is here to make the compiler happy. bit_cast has +// a static check that the types are the same size. This specialization should +// never be used because the source and destination types are checked for +// identical sizes higher up. +template +typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), + std::unique_ptr>::type +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { + LOG(FATAL) << "Invalid bitcast between types of different sizes."; +} + +template +std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { + CHECK(ShapeUtil::IsArray(src_literal.shape())); + auto result_literal = MakeUnique( + ShapeUtil::ChangeElementType(src_literal.shape(), C64)); + using NativeSrcT = + typename primitive_util::PrimitiveTypeToNative::type; + tensorflow::gtl::ArraySlice src_data = + src_literal.data(); + tensorflow::gtl::MutableArraySlice dest_data = + result_literal->data(); + int64 num_elements = src_literal.element_count(); + for (int64 i = 0; i < num_elements; ++i) { + dest_data[i] = complex64(static_cast(src_data[i]), 0); + } + return result_literal; +} + +template +std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, + bool bitcast) { + CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); + if (bitcast) { + return BitcastBetweenNativeTypes< + typename primitive_util::PrimitiveTypeToNative< + primitive_src_type>::type, + typename primitive_util::PrimitiveTypeToNative< + primitive_dest_type>::type>(src_literal); + } else { + return ConvertBetweenNativeTypes< + typename primitive_util::PrimitiveTypeToNative< + primitive_src_type>::type, + typename primitive_util::PrimitiveTypeToNative< + primitive_dest_type>::type>(src_literal); + } +} + +template +StatusOr> ConvertIfDestTypeMatches( + const LiteralBase& src_literal, PrimitiveType primitive_dest_type, + bool bitcast) { + switch (primitive_dest_type) { +#define CONVERT_IF_TYPES_MATCH(type) \ + case (type): \ + return ConvertIfTypesMatch(src_literal, \ + bitcast); + CONVERT_IF_TYPES_MATCH(PRED) + CONVERT_IF_TYPES_MATCH(S8) + CONVERT_IF_TYPES_MATCH(S32) + CONVERT_IF_TYPES_MATCH(S64) + CONVERT_IF_TYPES_MATCH(U8) + CONVERT_IF_TYPES_MATCH(U32) + CONVERT_IF_TYPES_MATCH(U64) + CONVERT_IF_TYPES_MATCH(F16) + CONVERT_IF_TYPES_MATCH(F32) + CONVERT_IF_TYPES_MATCH(F64) + CONVERT_IF_TYPES_MATCH(BF16) +#undef CONVERT_IF_TYPES_MATCH + case C64: + if (!bitcast) { + return ConvertToC64(src_literal); + } + break; + // Other types are not yet supported. + default: + break; + } + return Unimplemented( + "Converting from type %s to type %s is not implemented.", + PrimitiveType_Name(src_literal.shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str()); +} + +StatusOr> ConvertSwitch( + const LiteralBase& literal, PrimitiveType primitive_dest_type, + bool bitcast) { + TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); + if (literal.shape().element_type() == primitive_dest_type) { + return literal.CloneToUnique(); + } + switch (literal.shape().element_type()) { +#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ + case (type): \ + return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \ + bitcast); + CONVERT_IF_DEST_TYPE_MATCHES(PRED) + CONVERT_IF_DEST_TYPE_MATCHES(S8) + CONVERT_IF_DEST_TYPE_MATCHES(S32) + CONVERT_IF_DEST_TYPE_MATCHES(S64) + CONVERT_IF_DEST_TYPE_MATCHES(U8) + CONVERT_IF_DEST_TYPE_MATCHES(U32) + CONVERT_IF_DEST_TYPE_MATCHES(U64) + CONVERT_IF_DEST_TYPE_MATCHES(F16) + CONVERT_IF_DEST_TYPE_MATCHES(F32) + CONVERT_IF_DEST_TYPE_MATCHES(F64) + CONVERT_IF_DEST_TYPE_MATCHES(BF16) +#undef CONVERT_IF_DEST_TYPE_MATCHES + // Other types are not yet supported. + default: + return Unimplemented( + "%s from type %s to type %s is not implemented.", + (bitcast ? "Bitcast converting" : "Converting"), + PrimitiveType_Name(literal.shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str()); + } +} + +} // namespace + +StatusOr> LiteralBase::Convert( + PrimitiveType primitive_dest_type) const { + return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false); +} + +StatusOr> LiteralBase::BitcastConvert( + PrimitiveType primitive_dest_type) const { + if (primitive_util::BitWidth(shape().element_type()) != + primitive_util::BitWidth(primitive_dest_type)) { + return InvalidArgument( + "Cannot bitcast convert from %s to %s, bit widths are different: %d != " + "%d", + PrimitiveType_Name(shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str(), + primitive_util::BitWidth(shape().element_type()), + primitive_util::BitWidth(primitive_dest_type)); + } + return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true); +} + +StatusOr> LiteralBase::ConvertToShape( + const Shape& dest_shape, bool round_f32_to_bf16) const { + if (!ShapeUtil::IsTuple(dest_shape)) { + if (round_f32_to_bf16 && shape().element_type() == F32 && + dest_shape.element_type() == BF16) { + auto converter = [](float src) { + return tensorflow::bfloat16::round_to_bfloat16(src); + }; + return ConvertBetweenNativeTypesWithConverter(*this, + converter); + } + return Convert(dest_shape.element_type()); + } + std::vector elements; + for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { + auto element = LiteralSlice(*this, {i}); + TF_ASSIGN_OR_RETURN( + auto new_element, + element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); + elements.push_back(std::move(*new_element)); + } + auto converted = MakeUnique(); + *converted = MutableLiteralBase::MoveIntoTuple(&elements); + return std::move(converted); +} + +/* static */ Literal MutableLiteralBase::MoveIntoTuple( + tensorflow::gtl::MutableArraySlice elements) { + std::vector element_shapes; + for (const Literal& element : elements) { + element_shapes.push_back(element.shape()); + } + Literal literal(ShapeUtil::MakeTupleShape(element_shapes), + /*allocate_arrays=*/false); + for (int i = 0; i < elements.size(); ++i) { + TF_CHECK_OK( + literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i})); + } + return literal; +} + +template +bool LiteralBase::Piece::EqualElementsInternal( + const LiteralBase::Piece& other, std::vector* multi_index) const { + if (multi_index->size() == ShapeUtil::Rank(subshape())) { + return (Get(*multi_index) == other.Get(*multi_index)); + } + for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) { + multi_index->push_back(i); + if (!EqualElementsInternal(other, multi_index)) { + return false; + } + multi_index->pop_back(); + } + return true; +} + +bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { + DCHECK(ShapeUtil::Compatible(subshape(), other.subshape())); + + std::vector multi_index; + switch (subshape().element_type()) { + case PRED: + return EqualElementsInternal(other, &multi_index); + case U8: + return EqualElementsInternal(other, &multi_index); + case S32: + return EqualElementsInternal(other, &multi_index); + case S64: + return EqualElementsInternal(other, &multi_index); + case U32: + return EqualElementsInternal(other, &multi_index); + case U64: + return EqualElementsInternal(other, &multi_index); + case F32: + return EqualElementsInternal(other, &multi_index); + case F64: + return EqualElementsInternal(other, &multi_index); + case F16: + return EqualElementsInternal(other, &multi_index); + case BF16: + return EqualElementsInternal(other, &multi_index); + case C64: + return EqualElementsInternal(other, &multi_index); + default: + LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type " + << PrimitiveType_Name(subshape().element_type()); + } +} + +bool LiteralBase::operator==(const LiteralBase& other) const { + if (!ShapeUtil::Compatible(shape(), other.shape())) { + return false; + } + + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + const Piece& other_piece = other.piece(index); + if (!piece.EqualElements(other_piece)) { + return false; + } + return true; + }); +} + +namespace { + +template +static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice data, + NativeT value) { + for (int64 i = 0; i < data.size(); ++i) { + if (data[i] != value) { + return false; + } + } + return true; +} + +} // namespace + +bool LiteralBase::IsAll(int8 value) const { + return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index, + const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + auto piece_is_all = [&]() { + switch (shape().element_type()) { + case U8: + if (value >= 0) { + return AllElementsEqualValue(piece.data(), value); + } + return false; + case U32: + if (value >= 0) { + return AllElementsEqualValue(piece.data(), value); + } + return false; + case U64: + if (value >= 0) { + return AllElementsEqualValue(piece.data(), value); + } + return false; + case S8: + return AllElementsEqualValue(piece.data(), value); + case S32: + return AllElementsEqualValue(piece.data(), value); + case S64: + return AllElementsEqualValue(piece.data(), value); + case F32: + return AllElementsEqualValue(piece.data(), value); + case F64: + return AllElementsEqualValue(piece.data(), value); + case F16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case BF16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case PRED: + if (value == 0) { + return AllElementsEqualValue(piece.data(), false); + } + if (value == 1) { + return AllElementsEqualValue(piece.data(), true); + } + return false; + default: + return false; + } + return false; + }; + + if (!piece_is_all()) { + return false; + } + return true; + }); +} + +bool LiteralBase::IsAllFloat(float value) const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + auto piece_is_all = [&]() { + switch (shape().element_type()) { + case F32: + return AllElementsEqualValue(piece.data(), value); + case F64: + return AllElementsEqualValue(piece.data(), value); + case F16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case BF16: + return AllElementsEqualValue( + piece.data(), static_cast(value)); + default: + return false; + } + }; + if (!piece_is_all()) { + return false; + } + return true; + }); +} + +bool LiteralBase::IsAllComplex(complex64 value) const { + switch (shape().element_type()) { + case C64: + return AllElementsEqualValue(root_piece().data(), + value); + default: + return false; + } +} + +bool LiteralBase::IsAllFirst() const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + // Empty shapes are not all the first element since there is no first + // element. + if (ShapeUtil::IsZeroElementArray(piece.subshape())) { + return false; + } + auto piece_is_all = [&]() { + switch (piece.subshape().element_type()) { + case PRED: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 8 bit types + case S8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 16 bit types + case BF16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 32 bit types + case F32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 64 bit types + case C64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + default: + return false; + } + }; + + if (!piece_is_all()) { + return false; + } + return true; + }); +} + +bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice indices) const { + CHECK(ShapeUtil::IsArray(shape())); + switch (shape().element_type()) { + case U8: + return Get(indices) == 0; + case U32: + return Get(indices) == 0; + case U64: + return Get(indices) == 0; + case S8: + return Get(indices) == 0; + case S32: + return Get(indices) == 0; + case S64: + return Get(indices) == 0; + case F32: + return Get(indices) == 0.0f; + case F64: + return Get(indices) == 0.0; + case C64: + return Get(indices) == complex64(0.0f, 0.0f); + case F16: + return Get(indices) == static_cast(0.0f); + case BF16: + return Get(indices) == static_cast(0.0f); + case PRED: + return Get(indices) == false; + default: + LOG(FATAL) << "Input literal must be an array."; + } +} + +namespace { + +template +void CopyToRepeatedField(RepeatedFieldT* dest, + const tensorflow::gtl::ArraySlice src) { + *dest = RepeatedFieldT(src.begin(), src.end()); +} + +} // namespace + +void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { + *proto->mutable_shape() = subshape(); + switch (subshape().element_type()) { + case PRED: + CopyToRepeatedField(proto->mutable_preds(), data()); + break; + case U8: + proto->set_u8s(static_cast(data().data()), + element_count()); + break; + case U32: + CopyToRepeatedField(proto->mutable_u32s(), data()); + break; + case U64: + CopyToRepeatedField(proto->mutable_u64s(), data()); + break; + case S32: + CopyToRepeatedField(proto->mutable_s32s(), data()); + break; + case S64: + CopyToRepeatedField(proto->mutable_s64s(), data()); + break; + case F16: + *proto->mutable_f16s() = string( + reinterpret_cast(data().data()), size_bytes()); + if (!kLittleEndian) { + ConvertEndianShort(proto->mutable_f16s()); + } + break; + case BF16: + *proto->mutable_bf16s() = string( + reinterpret_cast(data().data()), size_bytes()); + if (!kLittleEndian) { + ConvertEndianShort(proto->mutable_bf16s()); + } + break; + case F32: + CopyToRepeatedField(proto->mutable_f32s(), data()); + break; + case F64: + CopyToRepeatedField(proto->mutable_f64s(), data()); + break; + case C64: + for (complex64 value : data()) { + proto->add_c64s(value.real()); + proto->add_c64s(value.imag()); + } + break; + case TUPLE: + case TOKEN: + // Nothing to do but assign the shape which is done above. + return; + default: + // TODO(b/111551621): Support serializing more PrimitiveTypes. + LOG(FATAL) << "Unhandled primitive type " + << PrimitiveType_Name(subshape().element_type()); + } +} + +const void* LiteralBase::Piece::untyped_data() const { + CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + return buffer(); +} + +void* LiteralBase::Piece::untyped_data() { + CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + return buffer(); +} + +namespace { + +template +Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, + const RepeatedFieldT& src) { + if (dest.size() != src.size()) { + return InvalidArgument( + "Expected %lu elements in LiteralProto repeated field, has %d", + dest.size(), src.size()); + } + std::copy(src.begin(), src.end(), dest.begin()); + return Status::OK(); +} + +} // namespace + +Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { + // These conditions should have been checked in + // MutableLiteralBase::CreateFromProto. + TF_RET_CHECK(proto.has_shape()); + TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); + TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape())); + + switch (subshape().element_type()) { + case PRED: + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.preds())); + break; + case U8: { + auto u8_data = data(); + TF_RET_CHECK(proto.u8s().size() == u8_data.size()); + std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin()); + } break; + case S32: + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.s32s())); + break; + case S64: + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.s64s())); + break; + case U32: + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u32s())); + break; + case U64: + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u64s())); + break; + case F16: { + const string& s(proto.f16s()); + TF_RET_CHECK(data().size() * sizeof(half) == s.size()); + memcpy(untyped_data(), s.data(), s.size()); + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); + } + } break; + + case BF16: { + const string& s(proto.bf16s()); + TF_RET_CHECK(data().size() * sizeof(bfloat16) == s.size()); + memcpy(untyped_data(), s.data(), s.size()); + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); + } + } break; + case F32: + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.f32s())); + break; + case F64: + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.f64s())); + break; + case C64: { + auto complex_data = data(); + TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2); + for (int64 i = 0; i < complex_data.size(); ++i) { + complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)}; + } + } break; + case TUPLE: + LOG(FATAL) << "Should not be called on tuple shapes: " + << ShapeUtil::HumanString(subshape()); + break; + default: + LOG(FATAL) << "Unhandled primitive type " << subshape().element_type(); + } + return Status::OK(); +} + +LiteralProto LiteralBase::ToProto() const { + LiteralProto proto; + root_piece().ForEachSubpiece( + [&](const ShapeIndex& index, const Piece& piece) { + LiteralProto* proto_piece = &proto; + for (int64 i : index) { + while (proto_piece->tuple_literals_size() <= i) { + proto_piece->add_tuple_literals(); + } + proto_piece = proto_piece->mutable_tuple_literals(i); + } + piece.WriteToProto(proto_piece); + }); + + if (LayoutUtil::IsSparseArray(shape())) { + CopyToRepeatedField(proto.mutable_sparse_indices(), + sparse_indices()->data()); + } + + return proto; +} + +const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const { + return piece(shape_index).untyped_data(); +} + +void* MutableLiteralBase::untyped_data(const ShapeIndex& shape_index) { + return piece(shape_index).untyped_data(); +} + +int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const { + return piece(shape_index).size_bytes(); +} + +string LiteralBase::GetR1U8AsString() const { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK_EQ(shape().element_type(), U8); + return string(tensorflow::bit_cast(data().data()), + ShapeUtil::ElementsIn(shape())); +} + +void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape, + Piece* src_piece, + Piece* dest_piece) { + DCHECK(ShapeUtil::Equal(src_piece->subshape(), dest_piece->subshape())) + << "src_piece has shape: " + << ShapeUtil::HumanString(src_piece->subshape()) + << "dest_piece has shape: " + << ShapeUtil::HumanString(dest_piece->subshape()); + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); + + CopyPieceSubtree(subshape, &src_piece->child(i), &child_piece); + + dest_piece->emplace_back(std::move(child_piece)); + } + } else if (ShapeUtil::IsArray(shape)) { + dest_piece->set_buffer(src_piece->buffer()); + } else { + // If the shape is neither an array nor tuple, then it must be + // zero-sized. Otherwise, some memory needs to be allocated for it. + CHECK_EQ(dest_piece->size_bytes(), 0); + } +} + +MutableLiteralBase::~MutableLiteralBase() {} + +MutableBorrowingLiteral::MutableBorrowingLiteral( + const MutableBorrowingLiteral& literal) + : MutableLiteralBase() { + shape_ = MakeUnique(literal.shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_); +} + +MutableBorrowingLiteral& MutableBorrowingLiteral::operator=( + const MutableBorrowingLiteral& literal) { + shape_ = MakeUnique(literal.shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_); + + return *this; +} + +MutableBorrowingLiteral::MutableBorrowingLiteral( + const MutableLiteralBase& literal) + : MutableLiteralBase() { + shape_ = MakeUnique(literal.shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.root_piece(), root_piece_); +} + +MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal) + : MutableLiteralBase() { + shape_ = MakeUnique(literal->shape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal->root_piece(), root_piece_); +} + +MutableBorrowingLiteral::MutableBorrowingLiteral( + MutableBorrowingLiteral literal, const ShapeIndex& view_root) + : MutableLiteralBase() { + shape_ = MakeUnique(literal.piece(view_root).subshape()); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + CopyPieceSubtree(*shape_, &literal.piece(view_root), root_piece_); +} + +MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr, + const Shape& shape) + : MutableLiteralBase() { + shape_ = MakeUnique(shape); + CHECK(LayoutUtil::HasLayout(*shape_)); + CHECK(!ShapeUtil::IsTuple(*shape_)); + + root_piece_ = new Piece(); + root_piece_->set_buffer(const_cast(src_buf_ptr)); + root_piece_->set_subshape(shape_.get()); +} + +MutableBorrowingLiteral::~MutableBorrowingLiteral() { + if (root_piece_ != nullptr) { + root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (piece->buffer() != nullptr) { + delete piece->sparse_indices(); + } + }); + delete root_piece_; + } +} + +LiteralSlice::LiteralSlice(const LiteralBase& literal) + : LiteralBase(), root_piece_(&literal.root_piece()) {} + +LiteralSlice::LiteralSlice(const LiteralBase& literal, + const ShapeIndex& view_root) + : LiteralBase(), root_piece_(&literal.piece(view_root)) {} + +void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { + CHECK(ShapeUtil::IsTuple(shape)); + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); + + if (ShapeUtil::IsTuple(subshape)) { + BuildPieceSubtree(subshape, &child_piece); + } + + piece->emplace_back(std::move(child_piece)); + } +} + +BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(ShapeUtil::IsArray(*shape_)); + CHECK(LayoutUtil::HasLayout(*shape_)); + + root_piece_ = Piece(); + root_piece_.set_buffer(const_cast(src_buf_ptr)); + root_piece_.set_subshape(shape_.get()); +} + +BorrowingLiteral::BorrowingLiteral( + tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(ShapeUtil::IsTuple(*shape_)); + CHECK(!ShapeUtil::IsNestedTuple(*shape_)); + CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); + root_piece_ = Piece(); + root_piece_.set_subshape(shape_.get()); + BuildPieceSubtree(*shape_, &root_piece_); + + for (int i = 0; i < src_buf_ptrs.size(); ++i) { + const auto& src_shape = shape_->tuple_shapes(i); + CHECK(ShapeUtil::IsArray(src_shape)); + root_piece_.child(i).set_buffer(const_cast(src_buf_ptrs[i])); + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h new file mode 100644 index 0000000000000000000000000000000000000000..92c0f903cbe252a153103aa8514bb5531696bbfe --- /dev/null +++ b/tensorflow/compiler/xla/literal.h @@ -0,0 +1,1188 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LITERAL_H_ +#define TENSORFLOW_COMPILER_XLA_LITERAL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/sparse_index_array.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/bitmap.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Forward declare Literal and LiteralSlice class to be used by the creation +// methods in the base class. +class Literal; +class LiteralSlice; + +// Abstract base class for literals. +class LiteralBase { + public: + virtual ~LiteralBase() = 0; + + // Literals are equal if they have compatible shapes and the same data + // values. Layout is not compared. + bool operator==(const LiteralBase& other) const; + bool operator!=(const LiteralBase& other) const { return !(*this == other); } + + // Returns the shape of the literal. + const Shape& shape() const { return root_piece().subshape(); } + + // Serialize to proto. + LiteralProto ToProto() const; + + // Returns an ArraySlice of the array for this literal for the given NativeT + // (e.g., float). CHECKs if the subshape of the literal at the given + // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type + // to native type. + template + tensorflow::gtl::ArraySlice data( + const ShapeIndex& shape_index = {}) const; + + // Returns a const pointer to the sparse index array. Returns nullptr if the + // literal is not a sparse array. + const SparseIndexArray* sparse_indices( + const ShapeIndex& shape_index = {}) const; + + // Returns a const pointer to (or size of) the underlying buffer holding the + // array at the given shape index. CHECKs if the subshape of the literal at + // the given ShapeIndex is not array. + const void* untyped_data(const ShapeIndex& shape_index = {}) const; + int64 size_bytes(const ShapeIndex& shape_index = {}) const; + + // Returns this literal's data as a string. This literal must be a rank-1 U8 + // array. + string GetR1U8AsString() const; + + // Returns a string representation of the literal value. + // Warning: this function can take minutes for multi-million element Literals. + string ToString(bool print_layout = false) const; + + // Gets an element in the literal at the given index. The multi_index is + // CHECKed against the dimension sizes. + template + NativeT Get(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const; + // Overloads of Get for array literals. CHECKs if the literal is not + // array-shaped and dense. + template + NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; + + // Returns the element value at index (0, ..., 0), however many zeroes are + // required for that index. + template + NativeT GetFirstElement() const; + + // As Get(), but determines the correct type and converts the value + // into text. + string GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index = {}) const; + // As GetSparseElement(), but determines the correct type and converts the + // value into text. + string GetSparseElementAsString(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; + // As Get(), but determines the correct type and converts the value into + // int64. This literal must be an array. + StatusOr GetIntegralAsS64( + tensorflow::gtl::ArraySlice multi_index) const; + + // Returns the multi-index of the element in a sparse literal at the given + // sparse element number. The sparse element number is the position with in + // the sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. + tensorflow::gtl::ArraySlice GetSparseIndex( + int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; + + // Returns the value of the element in a sparse literal at the given sparse + // element number. The sparse element number is the position with in the + // sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. + template + NativeT GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; + + // Invokes the "per cell" callback for each element in the provided + // literal with the element's indices and a string representation of + // the element's value. + // + // This function is useful if you want a polymorphic representation + // of the tensor's elements (turning it to a string for something + // like representation in a protobuf). + // + // This literal must have a dense layout. + void EachCellAsString( + const std::function indices, + const string& value)>& per_cell) const; + template + void EachCell(std::function indices, + NativeT value)> + per_cell) const; + + // Returns whether every element in this literal is equal to value. + // + // value is an int8 because we expect this to be called with small + // compile-time constants (0, -1, etc.) and so that whatever value you pass + // can be represented exactly by floating-point types as small as 16 bits. + // + // If value doesn't fit in this literal's type, returns false. Values of 1/0 + // are considered equal to true/false; other values are not considered equal + // to true. Also if this literal is not array-shaped false is returned. + bool IsAll(int8 value) const; + + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular floating-point number. + // + // If the literal is not a floating-point value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for values that can be expressed precisely as a float, + // e.g. -0.5. Also if this literal is not array-shaped false is returned. + bool IsAllFloat(float value) const; + + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular complex number. + // + // If the literal is not a complex value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for complex values that can be expressed precisely as + // float pairs e.g. (-0.5, 1.0). + // + // This literal must have a dense layout. + bool IsAllComplex(complex64 value) const; + + // Literal consists entirely of the first element of the literal. + bool IsAllFirst() const; + + // Returns whether this literal is zero at the specified index. This literal + // must be an array with a dense layout. + bool IsZero(tensorflow::gtl::ArraySlice indices) const; + + // Returns the count of the elements in the array at the given shape index in + // this literal. + int64 element_count(const ShapeIndex& index = {}) const { + return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); + } + + // Returns the count of the elements in the sparse array at the given shape + // index in this literal, which will be no larger than + // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). + int64 sparse_element_count() const; + + // Compute a hash for this literal. This literal must not be a sparse tensor + // or a tuple containing a sparse tensor. + size_t Hash() const; + + // Converts this literal to the given shape. Returns an error is the + // conversion is not possible. + // + // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding + // instead of truncation; otherwise, truncation is used. + // + // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes + // the default behavior. + StatusOr> ConvertToShape( + const Shape& dest_shape, bool round_f32_to_bf16 = false) const; + + // Converts this literal to another primitive type using a bitcast + // conversion. The to and from primitive types must have the same bit + // width. Returns an error if the conversion is not possible. This literal + // must be array-shaped. + StatusOr> BitcastConvert( + PrimitiveType primitive_dest_type) const; + + // Converts this literal to another primitive type. Returns an error if the + // conversion is not possible. This literal must be array-shaped. + StatusOr> Convert( + PrimitiveType primitive_dest_type) const; + + // Clones the underlying buffers into a new Literal, or new + // std::unique_ptr. + Literal Clone() const; + std::unique_ptr CloneToUnique() const; + + // TODO(b/67651157): The methods below which perform computation on Literals + // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with + // evaluator code which operates on Literals. + // + // Creates a new value that has the equivalent value as this + // literal, but conforms to new_layout; e.g. a literal matrix that was in {0, + // 1} minor-to-major dimension layout can be re-layed-out as {1, 0} + // minor-to-major dimension layout and the value in the cell at any given + // logical index (i0, i1) will be the same. + // + // For tuple shaped literals, shape_index should be used to select the inner + // array that the new layout applies to. + // + // Note: this is useful when the client wants to ensure that a value placed in + // the XLA allocation tracker has a particular layout; for efficiency + // purposes or avoiding unimplemented operation/layout combinations. + std::unique_ptr Relayout(const Layout& new_layout, + const ShapeIndex& shape_index = {}) const; + + // An overload of Relayout which changes the layout of the entire shape rather + // than being limited to a single array within the shape. + std::unique_ptr Relayout(const Shape& shape_with_layout) const; + + // Creates a new literal by reshaping this literal to have the given + // dimensions. The total number of elements must not change; The + // implementation currently only supports monotonic dim0-major layouts. + // This literal must be an array. + StatusOr> Reshape( + tensorflow::gtl::ArraySlice dimensions) const; + + // Creates a new literal by broadcasting this literal with `dimensions` to + // yield a literal of shape `result_shape`. + StatusOr> Broadcast( + const Shape& result_shape, + tensorflow::gtl::ArraySlice dimensions) const; + + // Creates a new literal by reordering the dimensions of this literal. + // The given `permutation` must be a permutation of the dimension numbers + // in the original literal, and it specifies the order of the new dimensions + // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). + // For example, a transpose call on a literal of shape [3 x 8 x 4] and + // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. + // This literal must be an array. + std::unique_ptr Transpose( + tensorflow::gtl::ArraySlice permutation) const; + + // Creates a sub-array from this literal by extracting the indices + // [start_index, limit_index) of each dimension. The result literal has the + // same rank and layout as for the given literal. The number of indices in + // start_indices and limit_indices must be the rank of the literal, and the + // indices follow the order of the dimensions. + // This literal must be an array. + std::unique_ptr Slice( + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) const; + + // Creates a literal with a prepended dimension with bound "times"; e.g. a + // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this + // literal replicated four times. + // This literal must be an array. + template + std::unique_ptr Replicate(int64 times) const; + + // Creates a new Literal object with the shape specified as parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + // + // Note: It's an antipattern to use this method then immediately call + // MutableLiteralBase::Populate on the result (since that results in zero + // initialization, then reinitialization. Conside if a call to + // MakeUnique(shape), followed by the call to + // MutableLiteralBase::Populate can be used instead. + static std::unique_ptr CreateFromShape(const Shape& shape); + + protected: + // A data structure representing a subshape at a particular ShapeIndex within + // the literal. For array-shaped ShapeIndexes, this data structure holds the + // pointer to the memory allocated for the array data. + class Piece { + public: + // Returns the buffer holding the array data for this piece as an array + // slice. This piece must be array-shaped. + template + tensorflow::gtl::ArraySlice data() const; + template + tensorflow::gtl::MutableArraySlice data(); + + // Returns the buffer holding the array data for this piece as a void*. This + // piece must be array-shaped. + void* untyped_data(); + const void* untyped_data() const; + + // Gets or sets an element in the array at the given index. The multi_index + // is CHECKed against the dimension sizes of the array. This piece must be + // array-shaped. + template + NativeT Get(tensorflow::gtl::ArraySlice index) const; + template + void Set(tensorflow::gtl::ArraySlice index, NativeT value); + + // Gets/sets the buffer holding the array data. + char* buffer() const { return buffer_; } + void set_buffer(char* buffer) { buffer_ = buffer; } + + // The array of multi-indices that provide the locations of non-zero + // elements in a sparse array. Only used if + // LayoutUtil::IsSparseArray(shape()) is true. + SparseIndexArray* sparse_indices() const { return sparse_indices_; } + void set_sparse_indices(SparseIndexArray* sparse_indices) { + sparse_indices_ = sparse_indices; + } + + // Gets or sets the subshape of this piece. This reference points to a + // subshape within the shape in the containing Literal (Literal::shape_). + const Shape& subshape() const { return *subshape_; } + void set_subshape(const Shape* subshape) { subshape_ = subshape; } + + // Returns the size in bytes of the buffer holding the array data. + int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } + + // Returns the number of elements in this piece's array. + int64 element_count() const { + // If this is a sparse array, use the number of elements represented by + // the indices in the associated SparseIndexArray. + return LayoutUtil::IsSparseArray(subshape()) + ? sparse_indices()->index_count() + : ShapeUtil::ElementsIn(subshape()); + } + + // Returns the child piece at 'index' of this piece. + Piece& child(int64 index) { return children_[index]; } + + // Adds a child piece to this piece's children. + void emplace_back(Piece child_piece) { + children_.emplace_back(std::move(child_piece)); + } + + // Returns the size of children pieces of this piece. + int64 children_size() { return children_.size(); } + + // Visitor functions that recursively traverses the piece and calls the + // given function at each child piece. The function has the type: + // void (const ShapeIndex& index, const Piece& piece) + template + void ForEachSubpiece(const Fn& func) const { + ShapeIndex index; + return ForEachHelper( + [&func](const ShapeIndex& index, const Piece& piece) { + func(index, piece); + return Status::OK(); + }, + *this, &index) + .IgnoreError(); + } + // Same as above, but the function has the type: + // Status (const ShapeIndex& index, const Piece& piece) + // The first non-OK return value is returned by the function. + template + Status ForEachSubpieceWithStatus(const Fn& func) const { + ShapeIndex index; + return ForEachHelper(func, *this, &index); + } + // Same as above, but the function has the type: + // Bool (const ShapeIndex& index, const Piece& piece) + // The first non-true return value is returned by the function. + template + bool ForEachSubpieceWithBool(const Fn& func) const { + ShapeIndex index; + return ForEachHelperBool(func, *this, &index); + } + // Same as above, but the function has the type: + // Void (const ShapeIndex& index, Piece& piece) + template + void ForEachMutableSubpiece(const Fn& func) { + ShapeIndex index; + return ForEachMutableHelper( + [&func](const ShapeIndex& index, Piece* piece) { + func(index, piece); + return Status::OK(); + }, + const_cast(this), &index) + .IgnoreError(); + } + // Same as above, but the function has the type: + // Status (const ShapeIndex& index, Piece& piece) + // The first non-OK return value is returned by the function. + template + Status ForEachMutableSubpieceWithStatus(const Fn& func) { + ShapeIndex index; + return ForEachMutableHelper( + func, const_cast(this), &index); + } + + // Returns true if this piece and 'other' contain the same data. This piece + // and 'other' must be array-shaped and compatible. + bool EqualElements(const Piece& other) const; + + // Writes the shape and data (if array-shaped) into the given proto. + void WriteToProto(LiteralProto* proto) const; + + // Copy the data from 'src' into this piece's buffer. Shapes of this piece + // and src must be compatible. + Status CopyFrom(const Piece& src); + + // Copies the data from the given proto into this piece. The shape of this + // piece must be equal (not just compatible) to the shape of the proto. + Status CopyFromProto(const LiteralProto& proto); + + // Sorts the elements in a sparse array. + void SortSparseElements(); + + private: + // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'. + // The first non-OK (or non-true) value is returned by the function. + // The callable 'func' has the same signature as described above in + // ForEachSubpiece*. + template + Status ForEachHelper(const Fn& func, const Piece& piece, + ShapeIndex* index) const { + TF_RETURN_IF_ERROR(func(*index, piece)); + for (int64 i = 0; i < piece.children_.size(); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index)); + index->pop_back(); + } + return Status::OK(); + } + template + bool ForEachHelperBool(const Fn& func, const Piece& piece, + ShapeIndex* index) const { + if (!func(*index, piece)) { + return false; + } + for (int64 i = 0; i < piece.children_.size(); ++i) { + index->push_back(i); + if (!ForEachHelperBool(func, piece.children_[i], index)) { + return false; + } + index->pop_back(); + } + return true; + } + template + Status ForEachMutableHelper(const Fn& func, Piece* piece, + ShapeIndex* index) { + TF_RETURN_IF_ERROR(func(*index, piece)); + for (int64 i = 0; i < piece->children_.size(); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR( + ForEachMutableHelper(func, &piece->children_[i], index)); + index->pop_back(); + } + return Status::OK(); + } + + // Recursive helper for EqualElements. + template + bool EqualElementsInternal(const Piece& other, + std::vector* multi_index) const; + + // Helper for SortSparseElements that has the element type as a template + // parameter. + template + void SortSparseElementsInternal(); + + // For array-shaped pieces, this is the buffer holding the literal data. + char* buffer_ = nullptr; + + // For sparse arrays, this is the array of indices. + SparseIndexArray* sparse_indices_ = nullptr; + + // The shape of piece. This points into the shape of the containing Literal + // (Literal::shape_). + const Shape* subshape_ = nullptr; + + // Children pieces for tuple shaped pieces. + std::vector children_ = {}; + }; // class Piece + + const Piece& piece(const ShapeIndex& shape_index) const { + Piece* piece = &const_cast(root_piece()); + for (const auto i : shape_index) { + DCHECK_GE(i, 0); + DCHECK_LT(i, piece->children_size()); + piece = &piece->child(i); + } + return *piece; + } + + // Returns the piece at the root of the shape. + virtual const Piece& root_piece() const = 0; + + // LiteralSlice and Literal must access Pieces of other Literals. + friend class MutableLiteralBase; + friend class LiteralSlice; + friend class BorrowingLiteral; + + private: + template + std::unique_ptr SliceInternal( + const Shape& result_shape, + tensorflow::gtl::ArraySlice start_indices) const; +}; + +// Abstract base class representing a mutable literal in XLA. +class MutableLiteralBase : public LiteralBase { + public: + virtual ~MutableLiteralBase() = 0; + + // Returns a MutableArraySlice view of the array for this literal for the + // given NativeT (e.g., float). CHECKs if the subshape of the literal at the + // given ShapeIndex is not array. See primitive_util.h for the mapping from + // XLA type to native type. + template + tensorflow::gtl::MutableArraySlice data( + const ShapeIndex& shape_index = {}); + // Unhide const method from parent class. + using LiteralBase::data; + + // Returns a pointer to the sparse index array. Returns nullptr if the literal + // is not a sparse array. + SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + + // TODO(b/67651157): Remove this accessor. Literal users should not be able to + // mutate the shape as this can produce malformed Literals. + Shape* mutable_shape_do_not_use() { return shape_.get(); } + + // Returns a pointer to the underlying buffer holding the array at the given + // shape index. CHECKs if the subshape of the literal at the given ShapeIndex + // is not array. + void* untyped_data(const ShapeIndex& shape_index = {}); + // Unhide const method from parent class. + using LiteralBase::untyped_data; + + // Populates a literal with a sparse layout with the given indices and values. + // Each index in the indices array is CHECKed against the dimensions in the + // literal's shape. If sort is true, then the indices and values will be + // sorted. If sort is false, then the indices and values are assumed to + // already be in sorted order. See CreateSparse for an example of how data + // are populated. + template + void PopulateSparse(SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, + bool sort = true); + + // Copy values from 'src_literal' rooted at 'src_shape_index' into this + // literal rooted at 'dest_shape_index'. The subshape of this literal rooted + // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' + // rooted at 'src_shape_index', but need not be arrays. + Status CopyFrom(const LiteralSlice& src_literal, + const ShapeIndex& dest_shape_index = {}, + const ShapeIndex& src_shape_index = {}); + + // Copies the values from src_literal, starting at src_base shape indexes, + // to this literal, starting at dest_base, where the copy size in each + // dimension is specified by copy_size. + // The src_literal and this literal must have the same primitive type, + // src_base+copy_size must fit the source literal dimensions, as well as + // dest_base+copy_size must fit the destination literal dimensions. + // Note: if either src_literal or this literal contains dimensions with zero + // element, then copy_size must be 0 in these dimensions while the + // corresponding base indices being 0. + // This literal and 'src_literal' must be arrays. + Status CopySliceFrom(const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); + + // Copies one element from src_literal[src_index] to (*this)[dest_index]. + Status CopyElementFrom(const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_index, + tensorflow::gtl::ArraySlice dest_index); + + // Sets an element in the literal at the given index. The multi_index is + // CHECKed against the dimension sizes. + template + void Set(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index, NativeT value); + // Overloads of Set for array literals. CHECKs if the literal is not + // array-shaped and dense. + template + void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); + + // Appends the given element to the literal. If the elements are not appended + // in sorted order, then SortSparseElements should be called before calling + // other methods. This literal must have a sparse layout. + template + void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, + NativeT value, const ShapeIndex& shape_index = {}); + + // Sorts the elements in a sparse array. + void SortSparseElements(const ShapeIndex& shape_index = {}); + + // As Set(), but truncates `value` to the literal element type before storing. + // This literal must be an array. + Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, + int64 value); + + // Populate this literal with the given values. Examples: + // + // // Populate with floats. + // Array2D float_values = ... + // literal.PopulateR2FromArray2D(values); + // + // // Populate with int32s. + // literal.PopulateR2({{1, 2}, {3, 4}}); + // + // The shape and element type of this literal must match given values. For + // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 + // array of S32. + template + void PopulateR1(tensorflow::gtl::ArraySlice values); + void PopulateR1(const tensorflow::core::Bitmap& values); + template + void PopulateR2(std::initializer_list> values); + template + void PopulateFromArray(const Array& values); + template + void PopulateR2FromArray2D(const Array2D& values); + template + void PopulateR3FromArray3D(const Array3D& values); + template + void PopulateR4FromArray4D(const Array4D& values); + + // Populates literal values by calling the generator function for every cell + // in this literal object. + // + // generator must be a callable of the type + // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. + // + // This literal must have a dense layout. + template + Status Populate(const FnType& generator); + + // A parallel version of Populate(). This can be used if the generator is + // thread-safe and the values for the shape's different elements are + // independent. + template + Status PopulateParallel(const FnType& generator); + + // Fills this literal with the given value. + template + void PopulateWithValue(NativeT value); + + // This operation is the inverse of DecomposeTuple. The given elements are + // moved into the tuple elements of a new tuple-shaped Literal which is + // returned. Upon return, each of the Literals in 'elements' is set to a nil + // shape (empty tuple). + static Literal MoveIntoTuple( + tensorflow::gtl::MutableArraySlice elements); + + // Serialize from a proto. + static StatusOr> CreateFromProto( + const LiteralProto& proto); + + protected: + // Returns the piece at the given ShapeIndex. + Piece& piece(const ShapeIndex& shape_index) { + return const_cast(LiteralBase::piece(shape_index)); + } + + Piece& root_piece() const override { return *root_piece_; }; + + // Internal template helper for the Literal::CopySliceFrom(), matching its + // arguments one by one. + template + Status CopySliceFromInternal(const LiteralBase& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); + + // Utility structure which is used to create the optimal configuration for + // a ShapeUtil::ForEachIndex() scan across two literals. + struct StrideConfig { + StrideConfig(const Shape& source_shape, const Shape& dest_shape, + tensorflow::gtl::ArraySlice dimensions); + + // The dimensions of the stride operation. Essentially every dimension + // will be iterated from base[i] to base[i]+dimensions[i], in step[i] + // steps. + tensorflow::gtl::ArraySlice dimensions; + DimensionVector base; + DimensionVector step; + int64 minor_dimension = 0; + // The size of the strides for source and destination. One of the two + // (the one looping through its most minor dimension) will be 1, while + // the other will be the stride size at the dimension matching the other + // shape most minor dimension being scanned. + int64 dest_stride = 1; + int64 source_stride = 1; + // The size of the inner loop on the most minor dimension. + int64 minor_loop_size = 1; + }; + + // Literal class always owns the shape. The parent class borrows this shape. + std::unique_ptr shape_; + + Piece* root_piece_ = nullptr; + + // Implementation details shared between Populate() and PopulateParallel() + template + Status PopulateInternal(const FnType& generator, bool parallel); + + friend class LiteralBase; + friend class MutableBorrowingLiteral; +}; +std::ostream& operator<<(std::ostream& out, const Literal& literal); + +// The underlying buffer and shape is always owned by this class. +class Literal : public MutableLiteralBase { + public: + Literal() : Literal(ShapeUtil::MakeNil()) {} + + // Create a literal of the given shape. The literal is allocated sufficient + // memory to hold the shape. Memory is uninitialized. + explicit Literal(const Shape& shape); + virtual ~Literal(); + + // Literals are moveable, but not copyable. To copy a literal use + // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies + // of literals which can be expensive. + Literal(const Literal& other) = delete; + Literal& operator=(const Literal& other) = delete; + Literal(Literal&& other); + // 'allocate_arrays' indicates whether to allocate memory for the arrays in + // the shape. If false, buffer pointers inside of the Literal::Pieces are set + // to nullptr. + Literal(const Shape& shape, bool allocate_arrays); + Literal& operator=(Literal&& other); + + // Similar to CopyFrom, but with move semantincs. The subshape of this literal + // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' + // (layouts and shapes must match), but need not be arrays. The memory + // allocated in this literal for the subshape at dest_shape_index is + // deallocated, and the respective buffers are replaced with those in + // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). + virtual Status MoveFrom(Literal&& src_literal, + const ShapeIndex& dest_shape_index = {}); + + // Returns a vector containing the tuple elements of this Literal as separate + // Literals. This Literal must be tuple-shaped and can be a nested tuple. The + // elements are moved into the new Literals; no data is copied. Upon return + // this Literal is set to a nil shape (empty tuple) + std::vector DecomposeTuple(); + + private: + // Deallocate the buffers held by this literal. + void DeallocateBuffers(); + + // Recursively sets the subshapes and buffers of all subpieces rooted at + // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in + // the shape. + void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); +}; + +// The underlying buffer is not owned by this class and is always owned by +// others. The shape is not owned by this class and not mutable. +class MutableBorrowingLiteral : public MutableLiteralBase { + public: + virtual ~MutableBorrowingLiteral(); + + MutableBorrowingLiteral() : MutableLiteralBase() {} + + MutableBorrowingLiteral(const MutableBorrowingLiteral& literal); + MutableBorrowingLiteral& operator=(const MutableBorrowingLiteral& literal); + + // Implicit conversion constructors. + MutableBorrowingLiteral(const MutableLiteralBase& literal); + MutableBorrowingLiteral(MutableLiteralBase* literal); + MutableBorrowingLiteral(MutableBorrowingLiteral literal, + const ShapeIndex& view_root); + MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape); + + private: + // Recursively copies the subtree from the `src_piece` at the given child + // index to the `dest_piece`. For buffers only the pointers are copied, but + // not the content. + void CopyPieceSubtree(const Shape& shape, Piece* src_piece, + Piece* dest_piece); +}; + +// A read-only view of a Literal. A LiteralSlice contains pointers to shape and +// literal buffers always owned by others. +class LiteralSlice : public LiteralBase { + public: + LiteralSlice() : LiteralBase() {} + + // Implicit conversion constructors. + LiteralSlice(const LiteralBase& literal); + LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root); + + private: + const Piece& root_piece() const override { return *root_piece_; }; + + const Piece* root_piece_; // Not owned. +}; + +// A read-only Literal where the underlying buffers are never owned by this +// class. +class BorrowingLiteral : public LiteralBase { + public: + BorrowingLiteral() : LiteralBase() {} + + // 'src_buf_ptr' is not owned by this class and must outlive the + // lifetime of this class. It points to an appropirately sized buffer with + // data interpretered as indicated by 'shape'. + // This constructor is only used for array shapes. + BorrowingLiteral(const char* src_buf_ptr, const Shape& shape); + // Similar as above, except to be used for constructing non-nested tuples. + BorrowingLiteral(tensorflow::gtl::ArraySlice src_buf_ptrs, + const Shape& shape); + // TODO(b/79707221): adding constructors for nested tuples as well. + + private: + // Recursively builds the subtree for the given piece and sets the subshapes + // of the given piece with the given shape. + void BuildPieceSubtree(const Shape& shape, Piece* piece); + + // Accessor for the root piece of this literal. + const Piece& root_piece() const override { return root_piece_; }; + Piece root_piece_; + + // Shape of this literal. Stored as unique_ptr such that the (default) move + // construction of this class would be trivially correct: the pointer to Shape + // root_piece_ stores will still point to the correct address. + std::unique_ptr shape_; +}; + +template +tensorflow::gtl::ArraySlice LiteralBase::Piece::data() const { + CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + CHECK_EQ(subshape().element_type(), + primitive_util::NativeToPrimitiveType()) + << "Attempting to access " + << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) + << " type, but literal element type is " + << PrimitiveType_Name(subshape().element_type()); + return tensorflow::gtl::ArraySlice( + reinterpret_cast(buffer()), element_count()); +} + +template +tensorflow::gtl::MutableArraySlice LiteralBase::Piece::data() { + CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + CHECK_EQ(subshape().element_type(), + primitive_util::NativeToPrimitiveType()) + << "Attempting to access " + << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) + << " type, but literal element type is " + << PrimitiveType_Name(subshape().element_type()); + return tensorflow::gtl::MutableArraySlice( + reinterpret_cast(buffer()), element_count()); +} + +template +NativeT LiteralBase::Piece::Get( + tensorflow::gtl::ArraySlice multi_index) const { + CHECK(LayoutUtil::IsDenseArray(subshape())); + return data()[IndexUtil::MultidimensionalIndexToLinearIndex( + subshape(), multi_index)]; +} + +template +void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice multi_index, + NativeT value) { + CHECK(LayoutUtil::IsDenseArray(subshape())); + data()[IndexUtil::MultidimensionalIndexToLinearIndex( + subshape(), multi_index)] = value; +} + +template +tensorflow::gtl::ArraySlice LiteralBase::data( + const ShapeIndex& shape_index) const { + return piece(shape_index).data(); +} + +template +tensorflow::gtl::MutableArraySlice MutableLiteralBase::data( + const ShapeIndex& shape_index) { + return piece(shape_index).data(); +} + +template +inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const { + return piece(shape_index).Get(multi_index); +} + +template +inline NativeT LiteralBase::Get( + tensorflow::gtl::ArraySlice multi_index) const { + return root_piece().Get(multi_index); +} + +template +inline void MutableLiteralBase::Set( + tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index, NativeT value) { + return piece(shape_index).Set(multi_index, value); +} + +template +inline void MutableLiteralBase::Set( + tensorflow::gtl::ArraySlice multi_index, NativeT value) { + return root_piece().Set(multi_index, value); +} + +template +NativeT LiteralBase::GetFirstElement() const { + return data().at(0); +} + +template +NativeT LiteralBase::GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index) const { + CHECK( + LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index))); + return data(shape_index)[sparse_element_number]; +} + +template +void MutableLiteralBase::AppendSparseElement( + tensorflow::gtl::ArraySlice multi_index, NativeT value, + const ShapeIndex& shape_index) { + Piece& p = piece(shape_index); + const Shape& subshape = p.subshape(); + CHECK(LayoutUtil::IsSparseArray(subshape)); + int64 rank = ShapeUtil::Rank(subshape); + CHECK_EQ(multi_index.size(), rank); + int64 last_element = p.sparse_indices()->index_count(); + CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout())); + p.sparse_indices()->Append(multi_index); + CHECK_LT(last_element, p.data().size()); + p.data()[last_element] = value; +} + +template +void LiteralBase::EachCell( + std::function indices, + NativeT value)> + per_cell) const { + if (ShapeUtil::IsZeroElementArray(shape())) { + return; + } + std::vector indices(ShapeUtil::Rank(shape()), 0); + do { + per_cell(indices, Get(indices)); + } while (IndexUtil::BumpIndices(shape(), &indices)); +} + +template +inline void MutableLiteralBase::PopulateR1( + tensorflow::gtl::ArraySlice values) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); + CHECK_EQ(shape().element_type(), + primitive_util::NativeToPrimitiveType()); + for (int64 i = 0; i < values.size(); ++i) { + Set({i}, values[i]); + } +} + +template +void MutableLiteralBase::PopulateR2( + std::initializer_list> values) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(ShapeUtil::Rank(shape()), 2); + CHECK_EQ(shape().element_type(), + primitive_util::NativeToPrimitiveType()); + + const int64 dim0_size = values.size(); + const int64 dim1_size = values.begin()->size(); + CHECK_EQ(dim0_size, shape().dimensions(0)); + CHECK_EQ(dim1_size, shape().dimensions(1)); + + int64 dim0 = 0; + for (auto inner_list : values) { + int64 dim1 = 0; + for (auto value : inner_list) { + Set({dim0, dim1}, value); + ++dim1; + } + CHECK_EQ(dim1_size, dim1); + ++dim0; + } +} + +template +void MutableLiteralBase::PopulateFromArray(const Array& values) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(shape().element_type(), + primitive_util::NativeToPrimitiveType()); + CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions()); + for (int dim = 0; dim < values.num_dimensions(); ++dim) { + CHECK_EQ(values.dim(dim), shape().dimensions(dim)); + } + values.Each([this](tensorflow::gtl::ArraySlice indices, + NativeT value) { this->Set(indices, value); }); +} + +template +void MutableLiteralBase::PopulateR2FromArray2D(const Array2D& values) { + PopulateFromArray(values); +} + +template +void MutableLiteralBase::PopulateR3FromArray3D(const Array3D& values) { + PopulateFromArray(values); +} + +template +void MutableLiteralBase::PopulateR4FromArray4D(const Array4D& values) { + PopulateFromArray(values); +} + +template +void MutableLiteralBase::PopulateSparse( + SparseIndexArray indices, tensorflow::gtl::ArraySlice values, + bool sort) { + CHECK(LayoutUtil::IsSparseArray(shape())); + int rank = ShapeUtil::Rank(shape()); + CHECK_EQ(indices.rank(), rank); + int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout()); + CHECK_LE(indices.max_indices(), max_elements); + int64 num_elements = values.size(); + CHECK_LE(num_elements, max_elements); + CHECK_EQ(num_elements, indices.index_count()); + auto root_data = root_piece().data(); + // Piece::data() returns an ArraySlice of size equal to the number of indices + // in the SparseIndexArray. So there is no need to adjust the size of the data + // here. It is enough to just copy the incoming values into the data buffer. + std::copy(values.begin(), values.end(), root_data.begin()); + *this->root_piece().sparse_indices() = std::move(indices); + if (sort) { + auto root_data = this->root_piece().data(); + this->root_piece().sparse_indices()->SortWithValues(root_data); + } + DCHECK(this->root_piece().sparse_indices()->Validate(shape())); +} + +template +Status MutableLiteralBase::PopulateInternal(const FnType& generator, + bool parallel) { + const Shape& this_shape = shape(); + const int64 rank = ShapeUtil::Rank(this_shape); + TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); + TF_RET_CHECK(this_shape.element_type() == + primitive_util::NativeToPrimitiveType()); + tensorflow::gtl::MutableArraySlice literal_data = data(); + if (rank > 0) { + StrideConfig stride_config(this_shape, this_shape, + AsInt64Slice(this_shape.dimensions())); + int64 minor_dimension_size = + ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); + + auto init_function = [&](tensorflow::gtl::ArraySlice indexes) { + DimensionVector minor_scan_indexes(rank, 0); + const int64 index = + IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes); + std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin()); + for (int64 i = 0; i < minor_dimension_size; ++i) { + minor_scan_indexes[stride_config.minor_dimension] = i; + literal_data.at(index + i) = generator(minor_scan_indexes); + } + }; + if (parallel) { + ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base, + stride_config.dimensions, + stride_config.step, init_function); + } else { + ShapeUtil::ForEachIndex( + this_shape, stride_config.base, stride_config.dimensions, + stride_config.step, + [&init_function](tensorflow::gtl::ArraySlice indexes) { + init_function(indexes); + return true; + }); + } + } else { + // For scalars. + literal_data.at(0) = generator({}); + } + return Status::OK(); +} +template +Status MutableLiteralBase::Populate(const FnType& generator) { + return PopulateInternal(generator, /*parallel=*/false); +} + +template +Status MutableLiteralBase::PopulateParallel(const FnType& generator) { + return PopulateInternal(generator, /*parallel=*/true); +} + +template +void MutableLiteralBase::PopulateWithValue(NativeT value) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(shape().element_type(), + primitive_util::NativeToPrimitiveType()); + for (NativeT& element : data()) { + element = value; + } +} + +template +std::unique_ptr LiteralBase::Replicate(int64 times) const { + DimensionVector bounds = {times}; + bounds.reserve(shape().dimensions_size() + 1); + for (int64 bound : shape().dimensions()) { + bounds.push_back(bound); + } + auto literal = + MakeUnique(ShapeUtil::MakeShape(shape().element_type(), bounds)); + int64 elements = ShapeUtil::ElementsIn(literal->shape()); + if (elements == 0) { + return literal; + } + + DimensionVector output_indices(bounds.size(), 0); + tensorflow::gtl::ArraySlice input_indices = output_indices; + input_indices.remove_prefix(1); + + bool done = false; + while (!done) { + const auto element = Get(input_indices); + literal->Set(output_indices, element); + + done = true; + for (int n = 0; n < output_indices.size(); ++n) { + ++output_indices[n]; + if (output_indices[n] < bounds[n]) { + done = false; + break; + } + output_indices[n] = 0; + } + } + return literal; +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LITERAL_H_ diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 2125ab7c61ab5e30fe51e16994e0da4883d509c4..94993cc87443ba8c22fd7c2eacfc8756d3f48edc 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -217,7 +218,7 @@ class NearComparator { return Printf( "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g", FpValueToString(actual).c_str(), FpValueToString(expected).c_str(), - Literal::MultiIndexAsString( + LiteralUtil::MultiIndexAsString( IndexUtil::LinearIndexToMultidimensionalIndex(shape, linear_index)) .c_str(), @@ -722,7 +723,7 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { return AppendStatus(result, tensorflow::strings::Printf( "\nat index: %s\nexpected: %s\nactual: %s", - Literal::MultiIndexAsString(multi_index).c_str(), + LiteralUtil::MultiIndexAsString(multi_index).c_str(), ToStringTruncated(expected).c_str(), ToStringTruncated(actual).c_str())); } diff --git a/tensorflow/compiler/xla/literal_comparison.h b/tensorflow/compiler/xla/literal_comparison.h index 00a13e361932e74a9a1e614d5c851d3851208852..9e5bf7c1d062ef0f25d07a80d6ded8106df5dacc 100644 --- a/tensorflow/compiler/xla/literal_comparison.h +++ b/tensorflow/compiler/xla/literal_comparison.h @@ -20,7 +20,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_ #include "tensorflow/compiler/xla/error_spec.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/lib/core/status.h" namespace xla { diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_test.cc similarity index 76% rename from tensorflow/compiler/xla/literal_util_test.cc rename to tensorflow/compiler/xla/literal_test.cc index 493d807591dd3c425293e4ee796bca3036a3088c..e8f919950f0efc8b508f7ad4aee5233176bc0abd 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" @@ -76,11 +77,11 @@ class LiteralUtilTest : public ::testing::Test { layout_r4_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2, 3}); literal_r4_2x2x3x3_dim0major_ = - Literal::CreateR4FromArray4DWithLayout(arr4d, - layout_r4_dim0major_); + LiteralUtil::CreateR4FromArray4DWithLayout(arr4d, + layout_r4_dim0major_); literal_r4_2x2x3x3_dim0minor_ = - Literal::CreateR4FromArray4DWithLayout(arr4d, - layout_r4_dim0minor_); + LiteralUtil::CreateR4FromArray4DWithLayout(arr4d, + layout_r4_dim0minor_); } Layout layout_r2_dim0major_; @@ -94,47 +95,47 @@ class LiteralUtilTest : public ::testing::Test { }; TEST_F(LiteralUtilTest, LiteralScalarToString) { - auto true_lit = Literal::CreateR0(true); + auto true_lit = LiteralUtil::CreateR0(true); ASSERT_EQ("true", true_lit->ToString()); - auto false_lit = Literal::CreateR0(false); + auto false_lit = LiteralUtil::CreateR0(false); ASSERT_EQ("false", false_lit->ToString()); - auto u32_lit = Literal::CreateR0(42); + auto u32_lit = LiteralUtil::CreateR0(42); ASSERT_EQ("42", u32_lit->ToString()); - auto s32_lit = Literal::CreateR0(-999); + auto s32_lit = LiteralUtil::CreateR0(-999); ASSERT_EQ("-999", s32_lit->ToString()); - auto f32_lit = Literal::CreateR0(3.14f); + auto f32_lit = LiteralUtil::CreateR0(3.14f); ASSERT_EQ("3.14", f32_lit->ToString()); - auto f16_lit = Literal::CreateR0(static_cast(0.5f)); + auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); ASSERT_EQ("0.5", f16_lit->ToString()); - auto c64_lit = Literal::CreateR0({3.14f, 2.78f}); + auto c64_lit = LiteralUtil::CreateR0({3.14f, 2.78f}); ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString()); - auto bf16_lit = Literal::CreateR0(static_cast(0.5f)); + auto bf16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); ASSERT_EQ("0.5", bf16_lit->ToString()); // 3.14 will be truncated to 3.125 in bfloat16 format. auto bf16_lit_truncated = - Literal::CreateR0(static_cast(3.14f)); + LiteralUtil::CreateR0(static_cast(3.14f)); ASSERT_EQ("3.125", bf16_lit_truncated->ToString()); auto bf16_lit_truncated2 = - Literal::CreateR0(static_cast(9.001f)); + LiteralUtil::CreateR0(static_cast(9.001f)); ASSERT_EQ("9", bf16_lit_truncated2->ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { - auto pred_vec = Literal::CreateR1({true, false, true}); + auto pred_vec = LiteralUtil::CreateR1({true, false, true}); ASSERT_EQ("{101}", pred_vec->ToString()); } TEST_F(LiteralUtilTest, R2ToString) { - const auto literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); + const auto literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); const string expected = R"(s32[3,2] { { 1, 2 }, { 3, 4 }, @@ -144,7 +145,8 @@ TEST_F(LiteralUtilTest, R2ToString) { } TEST_F(LiteralUtilTest, R3ToString) { - const auto literal = Literal::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}}); + const auto literal = + LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}}); const string expected = R"(s32[3,2,1] { { { 1 }, { 2 } }, @@ -157,9 +159,9 @@ TEST_F(LiteralUtilTest, R3ToString) { } TEST_F(LiteralUtilTest, TupleToString) { - auto scalar = Literal::CreateR0(1.0); - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); + auto scalar = LiteralUtil::CreateR0(1.0); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); const string expected = R"((f32[], f32[2,2]) ( 1, f32[2,2] { @@ -182,7 +184,7 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { }); // clang-format on - auto literal = Literal::CreateR3FromArray3D(array_3d); + auto literal = LiteralUtil::CreateR3FromArray3D(array_3d); EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2)); string result = literal->ToString(); const string expected = R"(f32[2,3,2] { @@ -205,7 +207,7 @@ TEST_F(LiteralUtilTest, CreateSparse) { {3, 5, 6}, }; std::vector values = {7, 8, 9, 10}; - auto literal = Literal::CreateSparse( + auto literal = LiteralUtil::CreateSparse( dimensions, SparseIndexArray(indices.n1() + 3, indices), values); Array2D expected_indices = { @@ -224,7 +226,7 @@ TEST_F(LiteralUtilTest, CreateSparse) { TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { // clang-format off - auto literal = Literal::CreateR4Projected({ + auto literal = LiteralUtil::CreateR4Projected({ {1, 2}, {1001, 1002}, {2001, 2002}, @@ -284,7 +286,7 @@ TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { TEST_F(LiteralUtilTest, EachCellR2F32) { // clang-format off - auto literal = Literal::CreateR2({ + auto literal = LiteralUtil::CreateR2({ {3.1f, 4.2f}, {9.3f, 12.4f}, }); @@ -303,26 +305,27 @@ TEST_F(LiteralUtilTest, EachCellR2F32) { TEST_F(LiteralUtilTest, ScalarEquality) { // Test equality with scalars. - auto f32_42 = Literal::CreateR0(42.0); - auto f32_42_clone = Literal::CreateR0(42.0); + auto f32_42 = LiteralUtil::CreateR0(42.0); + auto f32_42_clone = LiteralUtil::CreateR0(42.0); EXPECT_EQ(*f32_42, *f32_42); EXPECT_EQ(*f32_42, *f32_42_clone); - auto f32_123 = Literal::CreateR0(123.0); + auto f32_123 = LiteralUtil::CreateR0(123.0); EXPECT_NE(*f32_42, *f32_123); - auto f64_42 = Literal::CreateR0(42.0); + auto f64_42 = LiteralUtil::CreateR0(42.0); EXPECT_NE(*f32_42, *f64_42); } TEST_F(LiteralUtilTest, NonScalarEquality) { // Test equality with nonscalars. - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto matrix_clone = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto matrix_different = Literal::CreateR2({{4.0, 3.0}, {1.0, 2.0}}); - auto vector_literal = Literal::CreateR1({1.0, 2.0, 3.0, 4.0}); - auto scalar = Literal::CreateR0(1.0); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix_clone = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix_different = + LiteralUtil::CreateR2({{4.0, 3.0}, {1.0, 2.0}}); + auto vector_literal = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0}); + auto scalar = LiteralUtil::CreateR0(1.0); Literal nil(ShapeUtil::MakeNil()); EXPECT_EQ(*matrix, *matrix); @@ -335,19 +338,19 @@ TEST_F(LiteralUtilTest, NonScalarEquality) { } TEST_F(LiteralUtilTest, TokenEquality) { - auto token0 = Literal::CreateToken(); - auto token1 = Literal::CreateToken(); - auto scalar = Literal::CreateR0(1.0); + auto token0 = LiteralUtil::CreateToken(); + auto token1 = LiteralUtil::CreateToken(); + auto scalar = LiteralUtil::CreateR0(1.0); EXPECT_EQ(*token0, *token1); EXPECT_NE(*token0, *scalar); - EXPECT_EQ(*Literal::MakeTuple({token0.get()}), - *Literal::MakeTuple({token0.get()})); - EXPECT_EQ(*Literal::MakeTuple({token0.get(), scalar.get()}), - *Literal::MakeTuple({token1.get(), scalar.get()})); - EXPECT_NE(*Literal::MakeTuple({token0.get(), scalar.get()}), - *Literal::MakeTuple({scalar.get(), token1.get()})); + EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get()}), + *LiteralUtil::MakeTuple({token0.get()})); + EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}), + *LiteralUtil::MakeTuple({token1.get(), scalar.get()})); + EXPECT_NE(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}), + *LiteralUtil::MakeTuple({scalar.get(), token1.get()})); } TEST_F(LiteralUtilTest, DifferentLayoutEquality) { @@ -371,43 +374,46 @@ TEST_F(LiteralUtilTest, DifferentLayoutEquality) { TEST_F(LiteralUtilTest, TupleEquality) { // Test equality with tuples. - auto scalar = Literal::CreateR0(1.0); - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple1 = Literal::MakeTuple({scalar.get(), matrix.get()}); + auto scalar = LiteralUtil::CreateR0(1.0); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple1 = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); // Tuple with the same elements. One element is shared with the original // tuple, the other is a clone of the element in the original tuple. - auto scalar_clone = Literal::CreateR0(1.0); - auto tuple2 = Literal::MakeTuple({scalar_clone.get(), matrix.get()}); + auto scalar_clone = LiteralUtil::CreateR0(1.0); + auto tuple2 = LiteralUtil::MakeTuple({scalar_clone.get(), matrix.get()}); EXPECT_EQ(*tuple1, *tuple2); // Tuple with elements reversed. - auto reversed_tuple = Literal::MakeTuple({matrix.get(), scalar.get()}); + auto reversed_tuple = LiteralUtil::MakeTuple({matrix.get(), scalar.get()}); EXPECT_NE(*tuple1, *reversed_tuple); // Tuple with different value. - auto scalar_42 = Literal::CreateR0(42.0); - auto different_tuple = Literal::MakeTuple({scalar_42.get(), matrix.get()}); + auto scalar_42 = LiteralUtil::CreateR0(42.0); + auto different_tuple = + LiteralUtil::MakeTuple({scalar_42.get(), matrix.get()}); EXPECT_NE(*tuple1, *different_tuple); } TEST_F(LiteralUtilTest, C64Equality) { // Test equality with tuples. - auto vector = Literal::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + auto vector = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); // Tuple with the same elements. One element is shared with the original // tuple, the other is a clone of the element in the original tuple. - auto vector_clone = Literal::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + auto vector_clone = + LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); EXPECT_EQ(*vector, *vector_clone); - auto vector_reversed = Literal::CreateR1({{3.0, 4.0}, {1.0, 2.0}}); + auto vector_reversed = + LiteralUtil::CreateR1({{3.0, 4.0}, {1.0, 2.0}}); EXPECT_NE(*vector, *vector_reversed); } TEST_F(LiteralUtilTest, IsAllTuple) { - auto element1 = Literal::CreateR0(0.0); - auto element2 = Literal::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); - auto tuple = Literal::MakeTuple({element1.get(), element1.get()}); + auto element1 = LiteralUtil::CreateR0(0.0); + auto element2 = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); + auto tuple = LiteralUtil::MakeTuple({element1.get(), element1.get()}); // Tuples should always return false for IsAll. EXPECT_FALSE(tuple->IsAll(0)); @@ -416,140 +422,141 @@ TEST_F(LiteralUtilTest, IsAllTuple) { // Verifies that CreateFromShape works for tuples. TEST_F(LiteralUtilTest, CreateFromShapeTuple) { - auto scalar = Literal::CreateR0(0.0); - auto matrix = Literal::CreateR2({{0, 0}, {0, 0}}); - auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); + auto scalar = LiteralUtil::CreateR0(0.0); + auto matrix = LiteralUtil::CreateR2({{0, 0}, {0, 0}}); + auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); auto x = Literal::CreateFromShape(tuple->shape()); EXPECT_EQ(*tuple, *x); } TEST_F(LiteralUtilTest, IsAll) { - EXPECT_TRUE(Literal::CreateR0(false)->IsAll(0)); - EXPECT_TRUE(Literal::CreateR0(true)->IsAll(1)); - EXPECT_FALSE(Literal::CreateR0(false)->IsAll(1)); - EXPECT_FALSE(Literal::CreateR0(false)->IsAll(2)); - EXPECT_FALSE(Literal::CreateR0(true)->IsAll(0)); - EXPECT_FALSE(Literal::CreateR0(true)->IsAll(2)); - EXPECT_FALSE(Literal::CreateR0(true)->IsAll(-1)); + EXPECT_TRUE(LiteralUtil::CreateR0(false)->IsAll(0)); + EXPECT_TRUE(LiteralUtil::CreateR0(true)->IsAll(1)); + EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAll(1)); + EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAll(2)); + EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(2)); + EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(-1)); // We shouldn't reinterpret int8_min as an unsigned type and then decide that // it is equal to 255. auto int8_min = std::numeric_limits::min(); - EXPECT_FALSE(Literal::CreateR0(255)->IsAll(int8_min)); + EXPECT_FALSE(LiteralUtil::CreateR0(255)->IsAll(int8_min)); - EXPECT_TRUE(Literal::CreateR0(42.0)->IsAll(42)); - EXPECT_FALSE(Literal::CreateR0(42.0001)->IsAll(42)); + EXPECT_TRUE(LiteralUtil::CreateR0(42.0)->IsAll(42)); + EXPECT_FALSE(LiteralUtil::CreateR0(42.0001)->IsAll(42)); - EXPECT_TRUE(Literal::CreateR1({100, 100, 100})->IsAll(100)); - EXPECT_FALSE(Literal::CreateR1({100, 100, 100.001})->IsAll(100)); + EXPECT_TRUE(LiteralUtil::CreateR1({100, 100, 100})->IsAll(100)); + EXPECT_FALSE(LiteralUtil::CreateR1({100, 100, 100.001})->IsAll(100)); - EXPECT_TRUE(Literal::CreateR2({{8, 8}, {8, 8}})->IsAll(8)); - EXPECT_FALSE(Literal::CreateR2({{8, 8}, {8, 9}})->IsAll(8)); - EXPECT_FALSE(Literal::CreateR2({{9, 8}, {8, 8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{8, 8}, {8, 8}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{8, 8}, {8, 9}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{9, 8}, {8, 8}})->IsAll(8)); half h8(8.0f); half h9(9.0f); - EXPECT_TRUE(Literal::CreateR2({{h8}, {h8}})->IsAll(8)); - EXPECT_FALSE(Literal::CreateR2({{h8}, {h9}})->IsAll(8)); - EXPECT_FALSE(Literal::CreateR2({{h9}, {h8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{h8}, {h8}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{h8}, {h9}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{h9}, {h8}})->IsAll(8)); bfloat16 b8(8.0f); bfloat16 b9(9.0f); - EXPECT_TRUE(Literal::CreateR2({{b8}, {b8}})->IsAll(8)); - EXPECT_FALSE(Literal::CreateR2({{b8}, {b9}})->IsAll(8)); - EXPECT_FALSE(Literal::CreateR2({{b9}, {b8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{b8}, {b8}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{b8}, {b9}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{b9}, {b8}})->IsAll(8)); // 9.001 will be truncated to 9.0 bfloat16 b91(9.001f); bfloat16 b90(9.00f); - EXPECT_TRUE(Literal::CreateR2({{b91}, {b90}})->IsAll(9.0)); + EXPECT_TRUE(LiteralUtil::CreateR2({{b91}, {b90}})->IsAll(9.0)); complex64 c8_9 = {8, 9}; - EXPECT_FALSE(Literal::CreateR2({{c8_9}, {c8_9}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}})->IsAll(8)); auto uint64_max = std::numeric_limits::max(); - EXPECT_FALSE(Literal::CreateR2( + EXPECT_FALSE(LiteralUtil::CreateR2( {{uint64_max, uint64_max}, {uint64_max, uint64_max}}) ->IsAll(-1)); } TEST_F(LiteralUtilTest, IsAllFloat) { // IsAllFloat always returns false when the literal is not floating-point. - EXPECT_FALSE(Literal::CreateR0(false)->IsAllFloat(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllFloat(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllFloat(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllFloat(0)); - - EXPECT_TRUE(Literal::CreateR0(0)->IsAllFloat(0)); - EXPECT_TRUE(Literal::CreateR0(.5)->IsAllFloat(.5)); - EXPECT_TRUE(Literal::CreateR0(-.5)->IsAllFloat(-.5)); - EXPECT_FALSE(Literal::CreateR0(-.5)->IsAllFloat(-.49)); + EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); + + EXPECT_TRUE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); + EXPECT_TRUE(LiteralUtil::CreateR0(.5)->IsAllFloat(.5)); + EXPECT_TRUE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.5)); + EXPECT_FALSE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.49)); EXPECT_FALSE( - Literal::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); - EXPECT_TRUE( - Literal::CreateR2({{.5, .5, .5}, {.5, .5, .5}})->IsAllFloat(.5)); - - EXPECT_TRUE(Literal::CreateR0(0)->IsAllFloat(0)); - EXPECT_TRUE(Literal::CreateR0(.5)->IsAllFloat(.5)); - EXPECT_TRUE(Literal::CreateR0(-.5)->IsAllFloat(-.5)); - EXPECT_FALSE(Literal::CreateR0(-.5)->IsAllFloat(-.49)); + LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); + EXPECT_TRUE(LiteralUtil::CreateR2({{.5, .5, .5}, {.5, .5, .5}}) + ->IsAllFloat(.5)); + + EXPECT_TRUE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); + EXPECT_TRUE(LiteralUtil::CreateR0(.5)->IsAllFloat(.5)); + EXPECT_TRUE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.5)); + EXPECT_FALSE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.49)); EXPECT_FALSE( - Literal::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); + LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); } TEST_F(LiteralUtilTest, IsAllComplex) { // IsAllComplex always returns false when the literal is not complex. - EXPECT_FALSE(Literal::CreateR0(false)->IsAllComplex(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); complex64 c8_9 = {8, 9}; complex64 c7_9 = {7, 9}; - EXPECT_TRUE(Literal::CreateR2({{c8_9}, {c8_9}}) + EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}) ->IsAllComplex({8.0f, 9.0f})); - EXPECT_FALSE(Literal::CreateR2({{c7_9}, {c8_9}}) + EXPECT_FALSE(LiteralUtil::CreateR2({{c7_9}, {c8_9}}) ->IsAllComplex({8.0f, 9.0f})); - EXPECT_FALSE(Literal::CreateR2({{c8_9}, {c7_9}}) + EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c7_9}}) ->IsAllComplex({8.0f, 9.0f})); } TEST_F(LiteralUtilTest, IsAllFirst) { // IsAllComplex always returns false when the literal is not complex. - EXPECT_FALSE(Literal::CreateR1({false, true})->IsAllFirst()); - EXPECT_TRUE(Literal::CreateR1({false, false})->IsAllFirst()); - EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(Literal::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(Literal::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(Literal::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({false, true})->IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({false, false})->IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); complex64 c8_9 = {8, 9}; complex64 c7_9 = {7, 9}; - EXPECT_TRUE(Literal::CreateR2({{c8_9}, {c8_9}})->IsAllFirst()); - EXPECT_FALSE(Literal::CreateR2({{c7_9}, {c8_9}})->IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}})->IsAllFirst()); + EXPECT_FALSE( + LiteralUtil::CreateR2({{c7_9}, {c8_9}})->IsAllFirst()); } TEST_F(LiteralUtilTest, IsZero) { - auto scalar_zero = Literal::CreateR0(0.0f); - auto scalar_one = Literal::CreateR0(1.0f); + auto scalar_zero = LiteralUtil::CreateR0(0.0f); + auto scalar_one = LiteralUtil::CreateR0(1.0f); EXPECT_TRUE(scalar_zero->IsZero({})); EXPECT_FALSE(scalar_one->IsZero({})); - auto array = Literal::CreateR2({{1, 2, 0, 3}, {1, 0, 1, 2}}); + auto array = LiteralUtil::CreateR2({{1, 2, 0, 3}, {1, 0, 1, 2}}); EXPECT_FALSE(array->IsZero({0, 1})); EXPECT_TRUE(array->IsZero({0, 2})); EXPECT_TRUE(array->IsZero({1, 1})); EXPECT_FALSE(array->IsZero({1, 2})); - auto complex_zero = Literal::CreateR0(0.0f); - auto complex_nonzero = Literal::CreateR0(0.5f); + auto complex_zero = LiteralUtil::CreateR0(0.0f); + auto complex_nonzero = LiteralUtil::CreateR0(0.5f); EXPECT_TRUE(complex_zero->IsZero({})); EXPECT_FALSE(complex_nonzero->IsZero({})); } @@ -563,7 +570,7 @@ TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes); TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { // Make a non-integer for floating point types. TypeParam half = TypeParam(1) / TypeParam(2); - auto data = Literal::CreateR2({{half, 2}, {3, 4}}); + auto data = LiteralUtil::CreateR2({{half, 2}, {3, 4}}); const Layout layout01 = LayoutUtil::MakeLayout({0, 1}); const Layout layout10 = LayoutUtil::MakeLayout({1, 0}); @@ -577,7 +584,7 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { } TEST_F(LiteralUtilTest, ReshapeR0) { - auto original = Literal::CreateR0(1.7f); + auto original = LiteralUtil::CreateR0(1.7f); auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie(); EXPECT_EQ(*original, *reshape); } @@ -585,13 +592,13 @@ TEST_F(LiteralUtilTest, ReshapeR0) { TEST_F(LiteralUtilTest, ReshapeR4) { // clang-format off // F32[1x3x2x4] - auto original = Literal::CreateR4WithLayout({{ + auto original = LiteralUtil::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0major_); // F32[1x3x4x2] - auto expected = Literal::CreateR3WithLayout({ + auto expected = LiteralUtil::CreateR3WithLayout({ {{10, 11}, {12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}, {24, 25}}, {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, @@ -605,13 +612,13 @@ TEST_F(LiteralUtilTest, ReshapeR4) { TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { // clang-format off // F32[1x3x2x4] - auto original = Literal::CreateR4WithLayout({{ + auto original = LiteralUtil::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0minor_); // F32[1x3x4x2] - auto expected = Literal::CreateR3WithLayout({ + auto expected = LiteralUtil::CreateR3WithLayout({ {{10, 11}, {12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}, {24, 25}}, {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, @@ -623,7 +630,7 @@ TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { } TEST_F(LiteralUtilTest, TransposeR0) { - auto original = Literal::CreateR0(1.7f); + auto original = LiteralUtil::CreateR0(1.7f); auto reshape = original->Transpose(/*permutation=*/{}); EXPECT_EQ(*original, *reshape); } @@ -631,7 +638,7 @@ TEST_F(LiteralUtilTest, TransposeR0) { TEST_F(LiteralUtilTest, TransposeR4) { // clang-format off // F32[1x3x2x4] - auto original = Literal::CreateR4({{ + auto original = LiteralUtil::CreateR4({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, @@ -659,7 +666,7 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { TEST_F(LiteralUtilTest, TestR2LinearLayout) { // Test expected memory layout of R2 dim0-minor (column-major) literal. - auto mat_dim0minor = Literal::CreateR2WithLayout( + auto mat_dim0minor = LiteralUtil::CreateR2WithLayout( {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_); EXPECT_EQ(mat_dim0minor->element_count(), 6); EXPECT_THAT(mat_dim0minor->data(), ElementsAre(1, 4, 2, 5, 3, 6)); @@ -670,7 +677,7 @@ TEST_F(LiteralUtilTest, TestR2LinearLayout) { ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout of R2 created with dim0-major (row-major). - auto mat_dim0major = Literal::CreateR2WithLayout( + auto mat_dim0major = LiteralUtil::CreateR2WithLayout( {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_); EXPECT_EQ(mat_dim0major->element_count(), 6); EXPECT_THAT(mat_dim0major->data(), ElementsAre(1, 2, 3, 4, 5, 6)); @@ -695,8 +702,8 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { {10, 11, 12}, }, }); // clang-format on - auto lit_dim0minor = - Literal::CreateR3FromArray3DWithLayout(arr3d, layout_r3_dim0minor_); + auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout( + arr3d, layout_r3_dim0minor_); EXPECT_EQ(lit_dim0minor->element_count(), 12); std::vector expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12}; @@ -710,8 +717,8 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout of R3 created with dim0-major (row-major). - auto lit_dim0major = - Literal::CreateR3FromArray3DWithLayout(arr3d, layout_r3_dim0major_); + auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout( + arr3d, layout_r3_dim0major_); EXPECT_EQ(lit_dim0major->element_count(), 12); EXPECT_THAT(lit_dim0major->data(), testing::ElementsAreArray(expected_dim0major)); @@ -723,28 +730,28 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { } TEST_F(LiteralUtilTest, SliceR0S32) { - auto input = Literal::CreateR0(1); + auto input = LiteralUtil::CreateR0(1); auto result = input->Slice({}, {}); EXPECT_EQ(*input, *result); } TEST_F(LiteralUtilTest, SliceR1F32) { - auto input = Literal::CreateR1({1.0, 2.0, 3.0, 4.0, 5.0}); + auto input = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0, 5.0}); auto result = input->Slice({3}, {4}); - auto expected = Literal::CreateR1({4.0}); + auto expected = LiteralUtil::CreateR1({4.0}); EXPECT_EQ(*expected, *result); } TEST_F(LiteralUtilTest, SliceR2U32) { - auto input_3x4 = - Literal::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); + auto input_3x4 = LiteralUtil::CreateR2( + {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); auto result = input_3x4->Slice({0, 2}, {2, 4}); - auto expected = Literal::CreateR2({{3, 4}, {7, 8}}); + auto expected = LiteralUtil::CreateR2({{3, 4}, {7, 8}}); EXPECT_EQ(*expected, *result); } TEST_F(LiteralUtilTest, SliceR3U32Full) { - auto input_2x3x2 = Literal::CreateR3( + auto input_2x3x2 = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2}); EXPECT_EQ(*input_2x3x2, *result); @@ -753,21 +760,21 @@ TEST_F(LiteralUtilTest, SliceR3U32Full) { TEST_F(LiteralUtilTest, PopulateR1S64) { Literal output(ShapeUtil::MakeShape(S64, {1})); output.PopulateR1({77}); - auto expected = Literal::CreateR1({77}); + auto expected = LiteralUtil::CreateR1({77}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateR1U64) { Literal output(ShapeUtil::MakeShape(U64, {2})); output.PopulateR1({{77, 88}}); - auto expected = Literal::CreateR1({{77, 88}}); + auto expected = LiteralUtil::CreateR1({{77, 88}}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateR1C64) { Literal output(ShapeUtil::MakeShape(C64, {1})); output.PopulateR1({{77, 88}}); - auto expected = Literal::CreateR1({{77, 88}}); + auto expected = LiteralUtil::CreateR1({{77, 88}}); EXPECT_EQ(output, *expected); } @@ -775,7 +782,7 @@ TEST_F(LiteralUtilTest, PopulateR2C64) { Literal output(ShapeUtil::MakeShape(C64, {2, 2})); output.PopulateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); auto expected = - Literal::CreateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); + LiteralUtil::CreateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); EXPECT_EQ(output, *expected); } @@ -783,7 +790,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { Literal output(ShapeUtil::MakeShape(BF16, {})); bfloat16 h(0.25f); output.PopulateWithValue(h); - auto expected = Literal::CreateR0(h); + auto expected = LiteralUtil::CreateR0(h); EXPECT_EQ(output, *expected); } @@ -791,7 +798,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { Literal output(ShapeUtil::MakeShape(BF16, {3})); bfloat16 h(0.5f); output.PopulateWithValue(h); - auto expected = Literal::CreateR1({h, h, h}); + auto expected = LiteralUtil::CreateR1({h, h, h}); EXPECT_EQ(output, *expected); } @@ -799,28 +806,28 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { Literal output(ShapeUtil::MakeShape(BF16, {2, 2})); bfloat16 h(2.0f); output.PopulateWithValue(h); - auto expected = Literal::CreateR2({{h, h}, {h, h}}); + auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { Literal output(ShapeUtil::MakeShape(F32, {})); output.PopulateWithValue(2.5f); - auto expected = Literal::CreateR0(2.5f); + auto expected = LiteralUtil::CreateR0(2.5f); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1S64) { Literal output(ShapeUtil::MakeShape(S64, {3})); output.PopulateWithValue(-7); - auto expected = Literal::CreateR1({-7, -7, -7}); + auto expected = LiteralUtil::CreateR1({-7, -7, -7}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { Literal output(ShapeUtil::MakeShape(U64, {2, 2})); output.PopulateWithValue(42); - auto expected = Literal::CreateR2({{42, 42}, {42, 42}}); + auto expected = LiteralUtil::CreateR2({{42, 42}, {42, 42}}); EXPECT_EQ(output, *expected); } @@ -828,7 +835,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { Literal output(ShapeUtil::MakeShape(C64, {2, 2})); output.PopulateWithValue({4, 2}); auto expected = - Literal::CreateR2({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); + LiteralUtil::CreateR2({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); EXPECT_EQ(output, *expected); } @@ -836,7 +843,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { Literal output(ShapeUtil::MakeShape(F16, {})); half h(0.25f); output.PopulateWithValue(h); - auto expected = Literal::CreateR0(h); + auto expected = LiteralUtil::CreateR0(h); EXPECT_EQ(output, *expected); } @@ -844,7 +851,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { Literal output(ShapeUtil::MakeShape(F16, {3})); half h(0.5f); output.PopulateWithValue(h); - auto expected = Literal::CreateR1({h, h, h}); + auto expected = LiteralUtil::CreateR1({h, h, h}); EXPECT_EQ(output, *expected); } @@ -852,15 +859,15 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { Literal output(ShapeUtil::MakeShape(F16, {2, 2})); half h(2.0f); output.PopulateWithValue(h); - auto expected = Literal::CreateR2({{h, h}, {h, h}}); + auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, ReplicateR2U32) { - auto input = - Literal::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); + auto input = LiteralUtil::CreateR2( + {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); auto output = input->Replicate(3); - auto expected = Literal::CreateR3( + auto expected = LiteralUtil::CreateR3( {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}}); @@ -914,12 +921,12 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { } TEST_F(LiteralUtilTest, CopyFromScalars) { - auto zero = Literal::CreateR0(0); - auto nine = Literal::CreateR0(9); + auto zero = LiteralUtil::CreateR0(0); + auto nine = LiteralUtil::CreateR0(9); TF_EXPECT_OK(zero->CopyFrom(*nine)); EXPECT_EQ(*zero, *nine); - auto vect = Literal::CreateR1({3, 4, 9, 12, 5, 17, 21}); + auto vect = LiteralUtil::CreateR1({3, 4, 9, 12, 5, 17, 21}); TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {})); EXPECT_EQ(zero->Get({}), 17); TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {})); @@ -928,13 +935,13 @@ TEST_F(LiteralUtilTest, CopyFromScalars) { TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { const Shape empty_r1_shape = ShapeUtil::MakeShape(F32, {0}); - const auto const_nine = Literal::CreateR1({9}); + const auto const_nine = LiteralUtil::CreateR1({9}); const auto const_empty = Literal::CreateFromShape(empty_r1_shape); { // Source contains dimension with zero elements. const auto empty = Literal::CreateFromShape(empty_r1_shape); - auto nine = Literal::CreateR1({9}); + auto nine = LiteralUtil::CreateR1({9}); TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0})); EXPECT_EQ(*nine, *const_nine); @@ -943,7 +950,7 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { { // Copy 0 element to destination with zero elements. const auto empty = Literal::CreateFromShape(empty_r1_shape); - auto nine = Literal::CreateR1({9}); + auto nine = LiteralUtil::CreateR1({9}); TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0})); EXPECT_EQ(*empty, *const_empty); @@ -958,16 +965,16 @@ TEST_F(LiteralUtilTest, CopyFromNilShape) { } TEST_F(LiteralUtilTest, CopyFromArrays) { - auto scalar_42 = Literal::CreateR0(42.0); - auto scalar_123 = Literal::CreateR0(123.0); + auto scalar_42 = LiteralUtil::CreateR0(42.0); + auto scalar_123 = LiteralUtil::CreateR0(123.0); EXPECT_NE(*scalar_42, *scalar_123); TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{}, /*src_shape_index=*/{})); EXPECT_EQ(*scalar_42, *scalar_123); EXPECT_EQ(scalar_42->Get({}), 123.0f); - auto matrix_1234 = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto matrix_5678 = Literal::CreateR2({{5.0, 6.0}, {7.0, 8.0}}); + auto matrix_1234 = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix_5678 = LiteralUtil::CreateR2({{5.0, 6.0}, {7.0, 8.0}}); EXPECT_NE(*matrix_1234, *matrix_5678); EXPECT_EQ(matrix_1234->Get({0, 0}), 1.0f); TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, /*dest_shape_index=*/{}, @@ -977,19 +984,19 @@ TEST_F(LiteralUtilTest, CopyFromArrays) { } TEST_F(LiteralUtilTest, CopyFromTuples) { - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = Literal::MakeTuple( + auto nested_tuple = LiteralUtil::MakeTuple( {matrix.get(), - Literal::MakeTuple({Literal::CreateR0(42).get(), - Literal::CreateR1({23.0, 44.0}).get(), - &nil_literal}) + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(42).get(), + LiteralUtil::CreateR1({23.0, 44.0}).get(), &nil_literal}) .get()}); // Create a tuple the same shape as the inner tuple of nested_tuple but with // different values.. - auto tuple = Literal::MakeTuple({Literal::CreateR0(-5).get(), - Literal::CreateR1({2.0, 4.0}).get(), - &nil_literal}); + auto tuple = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(-5).get(), + LiteralUtil::CreateR1({2.0, 4.0}).get(), &nil_literal}); EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); EXPECT_EQ(nested_tuple->Get({}, {1, 0}), 42); @@ -1010,8 +1017,8 @@ TEST_F(LiteralUtilTest, CopyFromTuples) { EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 4.0); } TEST_F(LiteralUtilTest, CopyBetweenSameTuple) { - auto tuple = Literal::MakeTuple( - {Literal::CreateR0(-2).get(), Literal::CreateR0(4).get()}); + auto tuple = LiteralUtil::MakeTuple({LiteralUtil::CreateR0(-2).get(), + LiteralUtil::CreateR0(4).get()}); EXPECT_EQ(tuple->Get({}, {0}), -2); EXPECT_EQ(tuple->Get({}, {1}), 4); @@ -1025,8 +1032,8 @@ TEST_F(LiteralUtilTest, CopyBetweenSameTuple) { } TEST_F(LiteralUtilTest, CopyFromDifferentShapes) { - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto vector = Literal::CreateR1({5.0, 7.0}); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto vector = LiteralUtil::CreateR1({5.0, 7.0}); Status status = matrix->CopyFrom(*vector); ASSERT_FALSE(status.ok()); ASSERT_THAT(status.error_message(), @@ -1051,7 +1058,7 @@ TEST_F(LiteralUtilTest, F16) { half h1(1.0f); half h2(2.0f); - auto m2 = Literal::CreateR2({{h1, h2}, {h2, h1}}); + auto m2 = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); Literal* l2 = m2.get(); const char* d2 = reinterpret_cast(l2->data().data()); EXPECT_EQ(d2[0], 0); @@ -1150,12 +1157,12 @@ TEST_F(LiteralUtilTest, PopulateParallel) { TEST_F(LiteralUtilTest, ConvertR4) { // clang-format off - auto original = Literal::CreateR4WithLayout({{ + auto original = LiteralUtil::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0major_); - auto expected = Literal::CreateR4WithLayout({{ + auto expected = LiteralUtil::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, @@ -1169,42 +1176,42 @@ TEST_F(LiteralUtilTest, ConvertR4) { TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { // clang-format off - auto s8 = Literal::CreateR4WithLayout({{ + auto s8 = LiteralUtil::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto s32 = Literal::CreateR4WithLayout({{ + auto s32 = LiteralUtil::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto u32 = Literal::CreateR4WithLayout({{ + auto u32 = LiteralUtil::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto s64 = Literal::CreateR4WithLayout({{ + auto s64 = LiteralUtil::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto u64 = Literal::CreateR4WithLayout({{ + auto u64 = LiteralUtil::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto pred = Literal::CreateR4WithLayout({{ + auto pred = LiteralUtil::CreateR4WithLayout({{ {{true, false, true, false}, {false, true, false, true}}, {{false, true, false, true}, {true, false, true, false}}, {{true, false, true, false}, {false, true, false, true}}, }}, layout_r4_dim0major_); - auto int32_pred = Literal::CreateR4WithLayout({{ + auto int32_pred = LiteralUtil::CreateR4WithLayout({{ {{1, 0, 1, 0}, {0, 1, 0, 1}}, {{0, 1, 0, 1}, {1, 0, 1, 0}}, {{1, 0, 1, 0}, {0, 1, 0, 1}}, }}, layout_r4_dim0major_); - auto f16 = Literal::CreateR4WithLayout({{ + auto f16 = LiteralUtil::CreateR4WithLayout({{ {{half(10.0), half(0.0), half(12.0), half(0.0)}, {half(0.0), half(15.0), half(0.0), half(17.0)}}, {{half(0.0), half(19.0), half(0.0), half(21.0)}, @@ -1212,7 +1219,7 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{half(26.0), half(0.0), half(28.0), half(0.0)}, {half(0.0), half(31.0), half(0.0), half(33.0)}}, }}, layout_r4_dim0major_); - auto bf16 = Literal::CreateR4WithLayout({{ + auto bf16 = LiteralUtil::CreateR4WithLayout({{ {{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)}, {bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}}, {{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)}, @@ -1220,17 +1227,17 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)}, {bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}}, }}, layout_r4_dim0major_); - auto f32 = Literal::CreateR4WithLayout({{ + auto f32 = LiteralUtil::CreateR4WithLayout({{ {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, }}, layout_r4_dim0major_); - auto f64 = Literal::CreateR4WithLayout({{ + auto f64 = LiteralUtil::CreateR4WithLayout({{ {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}}, {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}}, {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}}, }}, layout_r4_dim0major_); - auto c64 = Literal::CreateR4WithLayout({{ + auto c64 = LiteralUtil::CreateR4WithLayout({{ {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, @@ -1302,18 +1309,18 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { } TEST_F(LiteralUtilTest, BitcastConvert) { - auto original = - Literal::CreateR1({tensorflow::bit_cast(2.5f), - tensorflow::bit_cast(-42.25f), - tensorflow::bit_cast(100.f), 0xbeef}); - auto expected = Literal::CreateR1( + auto original = LiteralUtil::CreateR1( + {tensorflow::bit_cast(2.5f), + tensorflow::bit_cast(-42.25f), + tensorflow::bit_cast(100.f), 0xbeef}); + auto expected = LiteralUtil::CreateR1( {2.5f, -42.25f, 100.0f, tensorflow::bit_cast(0xbeef)}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr converted, original->BitcastConvert(F32)); } TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) { - auto literal = Literal::CreateR0(1234); + auto literal = LiteralUtil::CreateR0(1234); Status status = literal->BitcastConvert(F64).status(); EXPECT_NE(Status::OK(), status); EXPECT_TRUE(tensorflow::str_util::StrContains(status.error_message(), @@ -1348,7 +1355,7 @@ TEST_F(LiteralUtilTest, ToProto_f16) { half h1(1.0f); half h2(2.0f); - auto m = Literal::CreateR2({{h1, h2}, {h2, h1}}); + auto m = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); Literal* l = m.get(); EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape())); EXPECT_EQ(4, l->data().size()); @@ -1391,10 +1398,10 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { } TEST_F(LiteralUtilTest, LiteralSliceTest) { - auto scalar = Literal::CreateR0(1.0); - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); + auto scalar = LiteralUtil::CreateR0(1.0); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); Literal nil(ShapeUtil::MakeNil()); EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar); @@ -1413,10 +1420,10 @@ TEST_F(LiteralUtilTest, LiteralSliceTest) { } TEST_F(LiteralUtilTest, MutatingLiteralSlice) { - auto scalar = Literal::CreateR0(1.0); - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); + auto scalar = LiteralUtil::CreateR0(1.0); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); // Verify that changing the underlying data beneath the view changes the // data of the view itself. const auto nested_tuple_view = LiteralSlice(*nested_tuple); @@ -1436,15 +1443,16 @@ TEST_F(LiteralUtilTest, MutatingLiteralSlice) { } TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) { - auto scalar = Literal::CreateR0(1.0); - auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); + auto scalar = LiteralUtil::CreateR0(1.0); + auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); const auto nested_tuple_view = LiteralSlice(*nested_tuple); const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0}); const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1}); - EXPECT_EQ(matrix_view, *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); + EXPECT_EQ(matrix_view, + *LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); } TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) { @@ -1488,7 +1496,7 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) { TEST_F(LiteralUtilTest, LiteralMove) { std::unique_ptr matrix = - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); Literal literal(std::move(*matrix)); EXPECT_TRUE( @@ -1501,11 +1509,11 @@ TEST_F(LiteralUtilTest, LiteralMove) { TEST_F(LiteralUtilTest, DecomposeTuple) { Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = Literal::MakeTuple( - {Literal::CreateR2({{1, 2}, {3, 4}}).get(), - Literal::MakeTuple({Literal::CreateR0(42).get(), - Literal::CreateR1({23.0, 44.0}).get(), - &nil_literal}) + auto nested_tuple = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1, 2}, {3, 4}}).get(), + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(42).get(), + LiteralUtil::CreateR1({23.0, 44.0}).get(), &nil_literal}) .get(), &nil_literal}); @@ -1542,13 +1550,13 @@ TEST_F(LiteralUtilTest, DecomposeEmptyTuple) { TEST_F(LiteralUtilTest, MoveIntoTuple) { std::vector elements; - elements.push_back(std::move(*Literal::CreateR0(1.0))); - elements.push_back(std::move(*Literal::CreateR1({4, 8}))); - elements.push_back(std::move( - *Literal::MakeTuple({Literal::CreateR0(42).get(), - Literal::CreateR1({23.0, 44.0}).get()}) + elements.push_back(std::move(*LiteralUtil::CreateR0(1.0))); + elements.push_back(std::move(*LiteralUtil::CreateR1({4, 8}))); + elements.push_back(std::move(*LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(42).get(), + LiteralUtil::CreateR1({23.0, 44.0}).get()}) - )); + )); Literal literal = Literal::MoveIntoTuple(&elements); ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); @@ -1577,7 +1585,7 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) { EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape())); std::unique_ptr matrix = - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); literal = std::move(*matrix); EXPECT_TRUE( @@ -1590,7 +1598,7 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) { TEST_F(LiteralUtilTest, LiteralSliceCopy) { std::unique_ptr matrix = - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); const auto matrix_view = LiteralSlice(*matrix); LiteralSlice matrix_view_copy(matrix_view); @@ -1601,9 +1609,9 @@ TEST_F(LiteralUtilTest, LiteralSliceCopy) { } TEST_F(LiteralUtilTest, GetSetTuple) { - auto tuple = Literal::MakeTuple( - {Literal::CreateR0(42.0).get(), - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get()}); + auto tuple = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(42.0).get(), + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get()}); EXPECT_EQ(tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0); tuple->Set(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0); EXPECT_EQ(tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0); @@ -1644,20 +1652,20 @@ TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) { TEST_F(LiteralUtilTest, ProtoRoundTrip) { // Test serializing then deserializing a Literal through a proto. - auto one_f32 = Literal::CreateR0(1.0); - auto two_f32 = Literal::CreateR0(2.0); - auto vector_int8 = Literal::CreateR1({-128, 0, 2, 4, 7, 56, 127}); - auto vector_c64 = Literal::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); - auto vector_bfloat16 = Literal::CreateR1( + auto one_f32 = LiteralUtil::CreateR0(1.0); + auto two_f32 = LiteralUtil::CreateR0(2.0); + auto vector_int8 = LiteralUtil::CreateR1({-128, 0, 2, 4, 7, 56, 127}); + auto vector_c64 = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + auto vector_bfloat16 = LiteralUtil::CreateR1( {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}}); auto vector_half = - Literal::CreateR1({half{10.0}, half{20.0}, half{-30.0}}); + LiteralUtil::CreateR1({half{10.0}, half{20.0}, half{-30.0}}); auto matrix_pred = - Literal::CreateR2({{true, false, true}, {false, false, true}}); - auto tuple = Literal::MakeTuple( + LiteralUtil::CreateR2({{true, false, true}, {false, false, true}}); + auto tuple = LiteralUtil::MakeTuple( {one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()}); Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = Literal::MakeTuple( + auto nested_tuple = LiteralUtil::MakeTuple( {tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal}); auto to_from_proto = [](const Literal& literal) -> Literal { @@ -1790,8 +1798,8 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { } TEST_F(LiteralUtilTest, SortSparseElements) { - auto literal = - Literal::CreateSparse({10, 10, 10}, SparseIndexArray(10, 3), {}); + auto literal = LiteralUtil::CreateSparse({10, 10, 10}, + SparseIndexArray(10, 3), {}); literal->AppendSparseElement({2, 3, 4}, 2.0); literal->AppendSparseElement({3, 4, 5}, 3.0); literal->AppendSparseElement({1, 2, 3}, 1.0); @@ -1805,21 +1813,22 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}}); ASSERT_EQ( - Literal::CreateSparse(dimensions, indices, {true, false, true}) + LiteralUtil::CreateSparse(dimensions, indices, {true, false, true}) ->GetSparseElementAsString(1), "false"); - ASSERT_EQ(Literal::CreateSparse(dimensions, indices, {1, 2, 3}) + ASSERT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {1, 2, 3}) ->GetSparseElementAsString(1), tensorflow::strings::StrCat(int64{2})); - ASSERT_EQ(Literal::CreateSparse(dimensions, indices, {1.0, 2.0, 3.0}) - ->GetSparseElementAsString(1), - tensorflow::strings::StrCat(double{2.0})); - ASSERT_EQ(Literal::CreateSparse(dimensions, indices, - {half{1.0}, half{2.0}, half{3.0}}) + ASSERT_EQ( + LiteralUtil::CreateSparse(dimensions, indices, {1.0, 2.0, 3.0}) + ->GetSparseElementAsString(1), + tensorflow::strings::StrCat(double{2.0})); + ASSERT_EQ(LiteralUtil::CreateSparse(dimensions, indices, + {half{1.0}, half{2.0}, half{3.0}}) ->GetSparseElementAsString(1), tensorflow::strings::StrCat(static_cast(half{2.0}))); ASSERT_EQ( - Literal::CreateSparse( + LiteralUtil::CreateSparse( dimensions, indices, std::vector{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}) ->GetSparseElementAsString(1), @@ -1827,33 +1836,36 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { } TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) { - std::unique_ptr literal = Literal::CreateR1({1, 2}); + std::unique_ptr literal = LiteralUtil::CreateR1({1, 2}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr broadcasted_literal, literal->Broadcast( /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), /*dimensions=*/{0})); - EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{1, 1}, {2, 2}})); + EXPECT_EQ(*broadcasted_literal, + *LiteralUtil::CreateR2({{1, 1}, {2, 2}})); } TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) { - std::unique_ptr literal = Literal::CreateR1({1, 2}); + std::unique_ptr literal = LiteralUtil::CreateR1({1, 2}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr broadcasted_literal, literal->Broadcast( /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), /*dimensions=*/{1})); - EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{1, 2}, {1, 2}})); + EXPECT_EQ(*broadcasted_literal, + *LiteralUtil::CreateR2({{1, 2}, {1, 2}})); } TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) { - std::unique_ptr literal = Literal::CreateR0(9); + std::unique_ptr literal = LiteralUtil::CreateR0(9); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr broadcasted_literal, literal->Broadcast( /*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}), /*dimensions=*/{})); - EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{9, 9}, {9, 9}})); + EXPECT_EQ(*broadcasted_literal, + *LiteralUtil::CreateR2({{9, 9}, {9, 9}})); } } // namespace diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index eeabf835ac348a5ba55699631188b0e329c98c43..5d33df7d40bf3bfcc8012ce1129d532b34555344 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -34,34 +34,15 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/types.h" -using tensorflow::strings::Printf; using tensorflow::strings::StrCat; namespace xla { namespace { -constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; - -// Converts between little and big endian. -// -// Precondition: size % 2 == 0 (elements in the array are 16 bits long) -void ConvertEndianShort(string* bytes) { - CHECK_EQ(bytes->size() / 2, 0); - for (int64 i = 0; i < bytes->size(); i += 2) { - std::swap((*bytes)[i], (*bytes)[i + 1]); - } -} - -void ConvertEndianShort(char* bytes, int64 size) { - CHECK_EQ(size / 2, 0); - for (int64 i = 0; i < size; i += 2) { - std::swap(bytes[i], bytes[i + 1]); - } -} - // Return a literal with all arrays of type FromNativeT converted to type // ToNativeT in the given literal. template @@ -103,505 +84,54 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } // namespace -LiteralBase::~LiteralBase() {} - -std::ostream& operator<<(std::ostream& out, const Literal& literal) { - out << literal.ToString(); - return out; -} - -Literal::StrideConfig::StrideConfig( - const Shape& source_shape, const Shape& dest_shape, - tensorflow::gtl::ArraySlice dimensions) - : dimensions(dimensions), - base(dimensions.size(), 0), - step(dimensions.size(), 1) { - if (!dimensions.empty()) { - // Selects the shape with the largest minor dimension as the one upon - // which to run the tight stride loop. - if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >= - dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) { - minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0); - dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension); - } else { - minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0); - source_stride = - IndexUtil::GetDimensionStride(source_shape, minor_dimension); - } - minor_loop_size = dimensions[minor_dimension]; - step[minor_dimension] = minor_loop_size; - } -} - -Literal::Literal(const Shape& shape) - : Literal(shape, /*allocate_arrays=*/true) {} - -void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { - if (ShapeUtil::IsTuple(shape)) { - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - const Shape& subshape = shape.tuple_shapes(i); - - auto child_piece = Piece(); - child_piece.set_subshape(&subshape); - - SetPiece(subshape, &child_piece, allocate_arrays); - - piece->emplace_back(std::move(child_piece)); - } - } else if (ShapeUtil::IsArray(shape)) { - if (allocate_arrays) { - if (LayoutUtil::IsSparseArray(shape)) { - // For sparse arrays, the buffer must be of the size of the maximum - // number of sparse elements possible. - const int64 max_sparse_elements = - LayoutUtil::MaxSparseElements(shape.layout()); - piece->set_buffer( - new char[max_sparse_elements * - ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]); - piece->set_sparse_indices( - new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape))); - } else { - piece->set_buffer(new char[piece->size_bytes()]); - } - } - } else { - // If the shape is neither an array nor tuple, then it must be - // zero-sized. Otherwise, some memory needs to be allocated for it. - CHECK_EQ(piece->size_bytes(), 0); - } -} - -Literal::Literal(const Shape& shape, bool allocate_arrays) - : LiteralBase(), shape_(MakeUnique(shape)) { - CHECK(LayoutUtil::HasLayout(*shape_)); - root_piece_ = new Piece(); - root_piece_->set_subshape(shape_.get()); - CHECK(&root_piece_->subshape() == shape_.get()); - - SetPiece(*shape_, root_piece_, allocate_arrays); -} - -Literal::~Literal() { - if (root_piece_ != nullptr) { - DeallocateBuffers(); - delete root_piece_; - } -} - -void Literal::DeallocateBuffers() { - root_piece_->ForEachMutableSubpiece( - [&](const ShapeIndex& index, Piece* piece) { - if (piece->buffer() != nullptr) { - delete[] piece->buffer(); - delete piece->sparse_indices(); - } - }); -} - -Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); } - -Literal& Literal::operator=(Literal&& other) { - DCHECK(&other.root_piece_->subshape() == other.shape_.get()); - using std::swap; - swap(shape_, other.shape_); - swap(root_piece_, other.root_piece_); - DCHECK(&root_piece_->subshape() == shape_.get()); - - return *this; -} - -std::unique_ptr LiteralBase::CreateFromShape(const Shape& shape) { - auto literal = MakeUnique(shape); - literal->root_piece_->ForEachMutableSubpiece( - [&](const ShapeIndex& index, Piece* piece) { - if (ShapeUtil::IsArray(piece->subshape())) { - memset(piece->untyped_data(), 0, piece->size_bytes()); - } - }); - return literal; -} - -const SparseIndexArray* LiteralBase::sparse_indices( - const ShapeIndex& shape_index) const { - return piece(shape_index).sparse_indices(); -} - -SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) { - return piece(shape_index).sparse_indices(); -} - -/* static */ std::unique_ptr Literal::CreateFromDimensions( +/* static */ std::unique_ptr LiteralUtil::CreateFromDimensions( PrimitiveType primitive_type, tensorflow::gtl::ArraySlice dimensions) { - return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions)); + return Literal::CreateFromShape( + ShapeUtil::MakeShape(primitive_type, dimensions)); } -/* static */ std::unique_ptr Literal::ConvertBF16ToF32( +/* static */ std::unique_ptr LiteralUtil::ConvertBF16ToF32( const LiteralSlice& bf16_literal) { return ConvertType(bf16_literal); } -/* static */ std::unique_ptr Literal::ConvertF32ToBF16( +/* static */ std::unique_ptr LiteralUtil::ConvertF32ToBF16( const LiteralSlice& f32_literal) { return ConvertType(f32_literal); } -template -Status Literal::CopySliceFromInternal( - const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { - TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); - TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size()); - - auto linear_index = [](const Shape& shape, - tensorflow::gtl::ArraySlice multi_index) { - return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index); - }; - - if (ShapeUtil::Rank(src_literal.shape()) == 0 || - ShapeUtil::Rank(shape()) == 0) { - // If any of the two shapes are scalars, we can just call the StridedCopy() - // directly, and we know we will be copying only one value. - TF_RET_CHECK(copy_size.empty()); - StridedCopy(data(), linear_index(shape(), dest_base), 0, - src_literal.data(), - linear_index(src_literal.shape(), src_base), 0, 1); - } else if (!ShapeUtil::IsZeroElementArray(shape()) && - !ShapeUtil::IsZeroElementArray(src_literal.shape())) { - // Perform copy if neither src nor dest has dimensions with zero element, - // otherwise it's a no-op. - TF_RET_CHECK(src_base.size() == dest_base.size()); - TF_RET_CHECK(src_base.size() == copy_size.size()); - - // Scan the source from minor, stepping in copy size blocks, then within - // the index enumaration functor, do a strided copy advancing source index - // by one (walking through the minor dimension), and destination index by - // proper stride size at the matching dimension. - DimensionVector src_indexes(src_base.size(), 0); - DimensionVector dest_indexes(dest_base.size(), 0); - Literal::StrideConfig stride_config(src_literal.shape(), shape(), - copy_size); - - auto copy_proc = [&](tensorflow::gtl::ArraySlice indexes) { - // Map from multi-dimensional index, to source index. - std::transform(indexes.begin(), indexes.end(), src_base.begin(), - src_indexes.begin(), std::plus()); - // Map from multi-dimensional index, to destination index. - std::transform(indexes.begin(), indexes.end(), dest_base.begin(), - dest_indexes.begin(), std::plus()); - - int64 src_index = linear_index(src_literal.shape(), src_indexes); - int64 dest_index = linear_index(shape(), dest_indexes); - - // `this->` is needed to workaround MSVC bug: #16882 - StridedCopy(this->data(), dest_index, stride_config.dest_stride, - src_literal.data(), src_index, - stride_config.source_stride, stride_config.minor_loop_size); - return true; - }; - - ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base, - stride_config.dimensions, stride_config.step, - copy_proc); - } - return Status::OK(); -} - -Status Literal::CopyElementFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_index, - tensorflow::gtl::ArraySlice dest_index) { - DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); - const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex( - src_literal.shape(), src_index); - const int64 dest_linear_index = - IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index); - const int64 primitive_size = - ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); - - char* dest_address = - static_cast(untyped_data()) + dest_linear_index * primitive_size; - const char* source_address = - static_cast(src_literal.untyped_data()) + - src_linear_index * primitive_size; - if (dest_address != source_address) { - memcpy(dest_address, source_address, primitive_size); - } - return Status::OK(); -} - -/* static */ std::unique_ptr Literal::CreateToken() { +/* static */ std::unique_ptr LiteralUtil::CreateToken() { return MakeUnique(ShapeUtil::MakeTokenShape()); } -std::vector Literal::DecomposeTuple() { - CHECK(ShapeUtil::IsTuple(shape())); - std::vector elements; - for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { - elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}), - /*allocate_arrays=*/false)); - Literal& element = elements.back(); - element.root_piece_->ForEachMutableSubpiece( - [&](const ShapeIndex& index, Piece* dest_piece) { - ShapeIndex src_index = {i}; - for (int64 j : index) { - src_index.push_back(j); - } - Piece& src_piece = piece(src_index); - - // Move the respective buffer and sparse indices over to the element - // Literal. - dest_piece->set_buffer(src_piece.buffer()); - src_piece.set_buffer(nullptr); - dest_piece->set_sparse_indices(src_piece.sparse_indices()); - src_piece.set_sparse_indices(nullptr); - }); - } - // Set this literal to be nil-shaped. - *this = Literal(); - return elements; -} - -/* static */ Literal Literal::MoveIntoTuple( - tensorflow::gtl::MutableArraySlice elements) { - std::vector element_shapes; - for (const Literal& element : elements) { - element_shapes.push_back(element.shape()); - } - Literal literal(ShapeUtil::MakeTupleShape(element_shapes), - /*allocate_arrays=*/false); - for (int i = 0; i < elements.size(); ++i) { - TF_CHECK_OK( - literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i})); - } - return literal; -} - -namespace { - -// Copies the elements in 'src' to 'dest'. The shape and layout of the data in -// the array slices are indicated by dest_shape and src_shape respectively. -template -void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, - tensorflow::gtl::ArraySlice src, - const Shape& dest_shape, const Shape& src_shape) { - CHECK(ShapeUtil::Compatible(dest_shape, src_shape)); - if (ShapeUtil::IsZeroElementArray(dest_shape)) { - return; - } - std::vector index(ShapeUtil::Rank(dest_shape)); - do { - dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] = - src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)]; - } while (IndexUtil::BumpIndices(dest_shape, &index)); -} - -} // namespace - -Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { - CHECK(subshape_ != nullptr); - CHECK(src.subshape_ != nullptr); - if (ShapeUtil::Equal(subshape(), src.subshape())) { - // If the layouts are equal it's faster just to memcpy. - memcpy(buffer(), src.buffer(), src.size_bytes()); - } else { - TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape())); - std::vector origin(ShapeUtil::Rank(subshape()), 0); - switch (subshape().element_type()) { -#define COPY_ELEMENTS(XLA_T, NATIVE_T) \ - case (XLA_T): \ - CopyElementsBetween(data(), src.data(), \ - subshape(), src.subshape()); \ - break; - COPY_ELEMENTS(U8, uint8); - COPY_ELEMENTS(U16, uint16); - COPY_ELEMENTS(U32, uint32); - COPY_ELEMENTS(U64, uint64); - COPY_ELEMENTS(S8, int8); - COPY_ELEMENTS(S16, int16); - COPY_ELEMENTS(S32, int32); - COPY_ELEMENTS(S64, int64); - COPY_ELEMENTS(F16, half); - COPY_ELEMENTS(BF16, bfloat16); - COPY_ELEMENTS(F32, float); - COPY_ELEMENTS(F64, double); - COPY_ELEMENTS(C64, complex64); - COPY_ELEMENTS(PRED, bool); -#undef COPY_ELEMENTS - default: - return Unimplemented( - "Copying a Literal object with element type %s is not implemented.", - PrimitiveType_Name(subshape().element_type()).c_str()); - } - } - return Status::OK(); -} - -Status Literal::CopyFrom(const LiteralSlice& src_literal, - const ShapeIndex& dest_shape_index, - const ShapeIndex& src_shape_index) { - const Shape& dest_subshape = - ShapeUtil::GetSubshape(shape(), dest_shape_index); - const Shape& src_subshape = - ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index); - if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) { - return InvalidArgument( - "Destination subshape incompatible with source subshape: %s vs %s", - ShapeUtil::HumanString(dest_subshape).c_str(), - ShapeUtil::HumanString(src_subshape).c_str()); - } - return root_piece_->ForEachMutableSubpieceWithStatus( - [&](const ShapeIndex& index, Piece* piece) { - if (!ShapeUtil::IsArray(piece->subshape())) { - return Status::OK(); - } - - // Determine if this index is in the part of this literal that we want - // to copy over from src_literal. - bool in_subtree_to_copy = true; - for (int i = 0; i < dest_shape_index.size(); ++i) { - if (index[i] != dest_shape_index[i]) { - in_subtree_to_copy = false; - break; - } - } - if (!in_subtree_to_copy) { - return Status::OK(); - } - // Construct the index of the corresponding piece in the source literal. - ShapeIndex src_piece_index = src_shape_index; - for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { - src_piece_index.push_back(index[i]); - } - TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index))); - return Status::OK(); - }); -} - -Status Literal::MoveFrom(Literal&& src_literal, - const ShapeIndex& dest_shape_index) { - const Shape& dest_subshape = - ShapeUtil::GetSubshape(shape(), dest_shape_index); - if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) { - return InvalidArgument( - "Destination subshape not equal to source shape: %s vs %s", - ShapeUtil::HumanString(dest_subshape).c_str(), - ShapeUtil::HumanString(src_literal.shape()).c_str()); - } - - src_literal.root_piece_->ForEachSubpiece( - [&](const ShapeIndex& src_index, const Piece& src_piece) { - if (!ShapeUtil::IsArray(src_piece.subshape())) { - return; - } - - ShapeIndex dest_index = dest_shape_index; - for (int64 i : src_index) { - dest_index.push_back(i); - } - Piece& dest_piece = piece(dest_index); - delete[] dest_piece.buffer(); - dest_piece.set_buffer(src_piece.buffer()); - delete dest_piece.sparse_indices(); - dest_piece.set_sparse_indices(src_piece.sparse_indices()); - }); - - src_literal.shape_ = MakeUnique(ShapeUtil::MakeNil()); - delete src_literal.root_piece_; - src_literal.root_piece_ = new LiteralBase::Piece(); - src_literal.root_piece_->set_subshape(src_literal.shape_.get()); - - return Status::OK(); -} - -Status Literal::CopySliceFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { - TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape()); - TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape())) - << ShapeUtil::HumanString(src_literal.shape()); - TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape())); - - switch (shape().element_type()) { - case U8: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case U16: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case U32: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case U64: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case S8: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case S16: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case S32: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case S64: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case F16: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case BF16: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case F32: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case F64: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case C64: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - case PRED: - return CopySliceFromInternal(src_literal, src_base, dest_base, - copy_size); - default: - break; - } - return Unimplemented( - "Copying a slice from a Literal object with element type %d is not " - "implemented.", - shape().element_type()); -} - -/* static */ Literal Literal::Zero(PrimitiveType primitive_type) { +/* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case U32: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case U64: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case S8: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case S32: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case S64: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case F16: - return std::move(*Literal::CreateR0(static_cast(0.0f))); + return std::move(*LiteralUtil::CreateR0(static_cast(0.0f))); case BF16: return std::move( - *Literal::CreateR0(static_cast(0.0f))); + *LiteralUtil::CreateR0(static_cast(0.0f))); case F32: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case F64: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case C64: - return std::move(*Literal::CreateR0(0)); + return std::move(*LiteralUtil::CreateR0(0)); case PRED: - return std::move(*Literal::CreateR0(false)); + return std::move(*LiteralUtil::CreateR0(false)); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; @@ -614,33 +144,33 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal, } } -/* static */ Literal Literal::One(PrimitiveType primitive_type) { +/* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case U32: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case U64: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case S8: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case S32: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case S64: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case F16: - return std::move(*Literal::CreateR0(static_cast(1.0f))); + return std::move(*LiteralUtil::CreateR0(static_cast(1.0f))); case BF16: return std::move( - *Literal::CreateR0(static_cast(1.0f))); + *LiteralUtil::CreateR0(static_cast(1.0f))); case F32: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case F64: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case C64: - return std::move(*Literal::CreateR0(1)); + return std::move(*LiteralUtil::CreateR0(1)); case PRED: - return std::move(*Literal::CreateR0(true)); + return std::move(*LiteralUtil::CreateR0(true)); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; @@ -653,44 +183,44 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal, } } -/* static */ Literal Literal::MinValue(PrimitiveType primitive_type) { +/* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: return std::move( - *Literal::CreateR0(std::numeric_limits::min())); + *LiteralUtil::CreateR0(std::numeric_limits::min())); case U32: return std::move( - *Literal::CreateR0(std::numeric_limits::min())); + *LiteralUtil::CreateR0(std::numeric_limits::min())); case U64: return std::move( - *Literal::CreateR0(std::numeric_limits::min())); + *LiteralUtil::CreateR0(std::numeric_limits::min())); case S8: return std::move( - *Literal::CreateR0(std::numeric_limits::min())); + *LiteralUtil::CreateR0(std::numeric_limits::min())); case S32: return std::move( - *Literal::CreateR0(std::numeric_limits::min())); + *LiteralUtil::CreateR0(std::numeric_limits::min())); case S64: return std::move( - *Literal::CreateR0(std::numeric_limits::min())); + *LiteralUtil::CreateR0(std::numeric_limits::min())); case F32: - return std::move( - *Literal::CreateR0(-std::numeric_limits::infinity())); + return std::move(*LiteralUtil::CreateR0( + -std::numeric_limits::infinity())); case F64: - return std::move( - *Literal::CreateR0(-std::numeric_limits::infinity())); + return std::move(*LiteralUtil::CreateR0( + -std::numeric_limits::infinity())); case C64: LOG(FATAL) << "C64 element type has no minimum value"; case PRED: - return std::move(*Literal::CreateR0(false)); + return std::move(*LiteralUtil::CreateR0(false)); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return std::move(*Literal::CreateR0( + return std::move(*LiteralUtil::CreateR0( static_cast(-std::numeric_limits::infinity()))); case BF16: - return std::move(*Literal::CreateR0( + return std::move(*LiteralUtil::CreateR0( static_cast(-std::numeric_limits::infinity()))); case TUPLE: LOG(FATAL) << "tuple element type has no minimum value"; @@ -701,42 +231,42 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal, } } -/* static */ Literal Literal::MaxValue(PrimitiveType primitive_type) { +/* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: return std::move( - *Literal::CreateR0(std::numeric_limits::max())); + *LiteralUtil::CreateR0(std::numeric_limits::max())); case U32: return std::move( - *Literal::CreateR0(std::numeric_limits::max())); + *LiteralUtil::CreateR0(std::numeric_limits::max())); case U64: return std::move( - *Literal::CreateR0(std::numeric_limits::max())); + *LiteralUtil::CreateR0(std::numeric_limits::max())); case S8: return std::move( - *Literal::CreateR0(std::numeric_limits::max())); + *LiteralUtil::CreateR0(std::numeric_limits::max())); case S32: return std::move( - *Literal::CreateR0(std::numeric_limits::max())); + *LiteralUtil::CreateR0(std::numeric_limits::max())); case S64: return std::move( - *Literal::CreateR0(std::numeric_limits::max())); + *LiteralUtil::CreateR0(std::numeric_limits::max())); case F32: - return std::move( - *Literal::CreateR0(std::numeric_limits::infinity())); + return std::move(*LiteralUtil::CreateR0( + std::numeric_limits::infinity())); case F64: - return std::move( - *Literal::CreateR0(std::numeric_limits::infinity())); + return std::move(*LiteralUtil::CreateR0( + std::numeric_limits::infinity())); case PRED: - return std::move(*Literal::CreateR0(true)); + return std::move(*LiteralUtil::CreateR0(true)); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return std::move(*Literal::CreateR0( + return std::move(*LiteralUtil::CreateR0( static_cast(std::numeric_limits::infinity()))); case BF16: - return std::move(*Literal::CreateR0( + return std::move(*LiteralUtil::CreateR0( static_cast(std::numeric_limits::infinity()))); case TUPLE: LOG(FATAL) << "tuple element type has no maximum value"; @@ -747,7 +277,7 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal, } } -/* static */ std::unique_ptr Literal::CreateR1( +/* static */ std::unique_ptr LiteralUtil::CreateR1( const tensorflow::core::Bitmap& values) { auto literal = MakeUnique( ShapeUtil::MakeShape(PRED, {static_cast(values.bits())})); @@ -755,17 +285,7 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal, return literal; } -void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 1); - CHECK_EQ(element_count(), values.bits()); - CHECK_EQ(shape().element_type(), PRED); - for (int64 i = 0; i < static_cast(values.bits()); ++i) { - Set({i}, values.get(i)); - } -} - -/* static */ std::unique_ptr Literal::CreateR1U8( +/* static */ std::unique_ptr LiteralUtil::CreateR1U8( tensorflow::StringPiece value) { auto literal = MakeUnique( ShapeUtil::MakeShape(U8, {static_cast(value.size())})); @@ -775,116 +295,13 @@ void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { return literal; } -/* static */ std::unique_ptr Literal::CreateR2F32Linspace(float from, - float to, - int64 rows, - int64 cols) { +/* static */ std::unique_ptr LiteralUtil::CreateR2F32Linspace( + float from, float to, int64 rows, int64 cols) { auto value = MakeLinspaceArray2D(from, to, rows, cols); return CreateR2FromArray2D(*value); } -std::unique_ptr LiteralBase::Relayout( - const Layout& new_layout, const ShapeIndex& shape_index) const { - // Create new shape with 'new_layout' set at the given shape index. - Shape new_shape = shape(); - Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index); - TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape)); - *subshape->mutable_layout() = new_layout; - auto result = MakeUnique(new_shape); - TF_CHECK_OK(result->CopyFrom(*this)); - return result; -} - -std::unique_ptr LiteralBase::Relayout( - const Shape& shape_with_layout) const { - CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) - << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) - << " not compatible with literal shape " - << ShapeUtil::HumanString(shape()); - std::unique_ptr result = CreateFromShape(shape_with_layout); - ShapeUtil::ForEachSubshape( - result->shape(), - [this, &result](const Shape& subshape, const ShapeIndex& index) { - if (ShapeUtil::IsArray(subshape)) { - TF_CHECK_OK(result->CopyFrom(*this, - /*dest_shape_index=*/index, - /*src_shape_index=*/index)); - } - }); - return result; -} - -StatusOr> LiteralBase::Broadcast( - const Shape& result_shape, - tensorflow::gtl::ArraySlice dimensions) const { - if (!ShapeUtil::IsArray(shape())) { - return InvalidArgument("Broadcast only supports arrays."); - } - - for (int64 i = 0; i < dimensions.size(); i++) { - TF_RET_CHECK(shape().dimensions(i) == - result_shape.dimensions(dimensions[i])); - } - - std::unique_ptr result = MakeUnique(result_shape); - - // scratch_source_index is temporary storage space for the computed index into - // the input literal. We put it here to avoid allocating an std::vector in - // every iteration of ShapeUtil::ForEachIndex. - std::vector scratch_source_index(shape().dimensions_size()); - - char* dest_data = static_cast(result->untyped_data()); - const char* source_data = static_cast(untyped_data()); - const int64 primitive_size = - ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); - - ShapeUtil::ForEachIndex( - result_shape, [&](tensorflow::gtl::ArraySlice output_index) { - for (int64 i = 0; i < dimensions.size(); ++i) { - scratch_source_index[i] = output_index[dimensions[i]]; - } - int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex( - result_shape, output_index); - int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex( - shape(), scratch_source_index); - memcpy(dest_data + primitive_size * dest_index, - source_data + primitive_size * source_index, primitive_size); - return true; - }); - - return std::move(result); -} - -StatusOr> LiteralBase::Reshape( - tensorflow::gtl::ArraySlice dimensions) const { - if (!ShapeUtil::IsArray(shape())) { - return InvalidArgument("Reshape does not support tuples."); - } - std::unique_ptr output; - if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { - output = - Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape()))); - } else { - output = CloneToUnique(); - } - // Because the layout is monotonic, we can simply reuse the same sequence of - // values without changing their order. - *output->mutable_shape_do_not_use() = - ShapeUtil::MakeShape(shape().element_type(), dimensions); - - int64 elements_before = ShapeUtil::ElementsIn(shape()); - int64 elements_after = ShapeUtil::ElementsIn(output->shape()); - if (elements_before != elements_after) { - return InvalidArgument( - "Shapes before and after Literal::Reshape have different numbers " - "of elements: %s vs %s.", - ShapeUtil::HumanString(shape()).c_str(), - ShapeUtil::HumanString(output->shape()).c_str()); - } - return std::move(output); -} - -/* static */ std::unique_ptr Literal::ReshapeSlice( +/* static */ std::unique_ptr LiteralUtil::ReshapeSlice( tensorflow::gtl::ArraySlice new_dimensions, tensorflow::gtl::ArraySlice minor_to_major, const LiteralSlice& literal) { @@ -956,588 +373,77 @@ StatusOr> LiteralBase::Reshape( return new_literal; } -std::unique_ptr LiteralBase::Transpose( - tensorflow::gtl::ArraySlice permutation) const { - CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; - CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) - << "Given permutation is not a permutation of dimension numbers"; - // To transpose the array, we just permute the dimensions and layout, and - // do a straight memory copy of the raw data set. - // This is considerably faster than iterating over every array element using - // the EachCell<>() and Set<>() APIs. - std::vector inverse_permutation = InversePermutation(permutation); - Shape permuted_shape = - ShapeUtil::PermuteDimensions(inverse_permutation, shape()); - // Replace the layout with one affine to this shape, such that a - // transpose operation can be performed by leaving the flat values - // representation intact. - // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation. - // The shape with affine layout resulting from that operation will be - // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the - // most minor. - // - // Essentially, given MinMaj(Di) the position of the Di dimension within the - // minor to major vector, and given T(Di) the index that the original Di - // dimension has within the transposed array, a layout is affine if - // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major - // vector of the affine layout. - CHECK(LayoutUtil::IsDenseArray(permuted_shape)); - Layout* layout = permuted_shape.mutable_layout(); - layout->clear_minor_to_major(); - for (auto index : LayoutUtil::MinorToMajor(shape())) { - layout->add_minor_to_major(inverse_permutation[index]); - } - auto new_literal = MakeUnique(permuted_shape); - DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), - ShapeUtil::ByteSizeOf(shape())); - std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); - return new_literal; -} - -template -std::unique_ptr LiteralBase::SliceInternal( - const Shape& result_shape, - tensorflow::gtl::ArraySlice start_indices) const { - auto result_literal = MakeUnique(result_shape); - DimensionVector new_indices(ShapeUtil::Rank(result_shape)); - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, NativeT /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - NativeT value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; -} - -std::unique_ptr LiteralBase::Slice( - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) const { - CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; - - DimensionVector result_dimensions; - for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) { - CHECK_GE(start_indices[dnum], 0); - CHECK_LE(limit_indices[dnum], shape().dimensions(dnum)) - << "dnum = " << dnum; - int64 dimension = limit_indices[dnum] - start_indices[dnum]; - CHECK_GE(dimension, 0) << "dnum = " << dnum; - result_dimensions.push_back(dimension); - } - const auto result_shape = - ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, - LayoutUtil::MinorToMajor(shape())); - switch (result_shape.element_type()) { - case F32: - return SliceInternal(result_shape, start_indices); - case BF16: - return SliceInternal(result_shape, start_indices); - case C64: - return SliceInternal(result_shape, start_indices); - case S32: - return SliceInternal(result_shape, start_indices); - case U32: - return SliceInternal(result_shape, start_indices); - default: - LOG(FATAL) << "not yet implemented: " - << PrimitiveType_Name(result_shape.element_type()); - } -} - -Literal LiteralBase::Clone() const { - Literal result(shape()); - TF_CHECK_OK(result.CopyFrom(*this)); - return result; -} - -std::unique_ptr LiteralBase::CloneToUnique() const { - auto result = MakeUnique(shape()); - TF_CHECK_OK(result->CopyFrom(*this)); - return result; -} - -string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const { - const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); - CHECK(LayoutUtil::IsDenseArray(subshape)); - switch (subshape.element_type()) { +/* static */ Literal LiteralUtil::GetFirstScalarLiteral( + const LiteralSlice& literal) { + CHECK(ShapeUtil::IsArray(literal.shape())); + CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0); + switch (literal.shape().element_type()) { case PRED: - return Get(multi_index, shape_index) ? "true" : "false"; + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); + // 8 bit types. case S8: - return StrCat(Get(multi_index, shape_index)); - case S16: - return StrCat(Get(multi_index, shape_index)); - case S32: - return StrCat(Get(multi_index, shape_index)); - case S64: - return StrCat(Get(multi_index, shape_index)); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); case U8: - return StrCat(Get(multi_index, shape_index)); - case U16: - return StrCat(Get(multi_index, shape_index)); - case U32: - return StrCat(Get(multi_index, shape_index)); - case U64: - return StrCat(Get(multi_index, shape_index)); - case F16: - return StrCat(static_cast(Get(multi_index, shape_index))); - case F32: - return StrCat(Get(multi_index, shape_index)); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); + // 16 bit types. case BF16: - return StrCat( - static_cast(Get(multi_index, shape_index))); - case F64: - return StrCat(Get(multi_index, shape_index)); - case C64: { - complex64 c = Get(multi_index, shape_index); - return StrCat("(", c.real(), ", ", c.imag(), ")"); - } - default: - LOG(FATAL) << PrimitiveType_Name(subshape.element_type()); - } -} - -string LiteralBase::GetSparseElementAsString( - int64 sparse_element_number, const ShapeIndex& shape_index) const { - const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); - CHECK(LayoutUtil::IsSparseArray(subshape)); - switch (subshape.element_type()) { - case PRED: - return GetSparseElement(sparse_element_number, shape_index) - ? "true" - : "false"; - case S8: - return StrCat(GetSparseElement(sparse_element_number, shape_index)); + return std::move(*LiteralUtil::CreateR0( + literal.GetFirstElement())); + case F16: + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); case S16: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case S32: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case S64: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case U8: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); case U16: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case U32: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case U64: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case F16: - return StrCat(static_cast( - GetSparseElement(sparse_element_number, shape_index))); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); + // 32 bit types. case F32: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case BF16: - return StrCat(static_cast( - GetSparseElement(sparse_element_number, shape_index))); - case F64: - return StrCat( - GetSparseElement(sparse_element_number, shape_index)); - case C64: { - complex64 c = - GetSparseElement(sparse_element_number, shape_index); - return StrCat("(", c.real(), ", ", c.imag(), ")"); - } - default: - LOG(FATAL) << "Invalid element type for sparse arrays: " - << PrimitiveType_Name(subshape.element_type()); - } -} - -StatusOr LiteralBase::GetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index) const { - CHECK(LayoutUtil::IsDenseArray(shape())); - switch (shape().element_type()) { - case PRED: - return Get(multi_index); - case U8: - return Get(multi_index); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); case S32: - return Get(multi_index); - case S64: - return Get(multi_index); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); case U32: - return Get(multi_index); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); + // 64 bit types. + case C64: + return std::move(*LiteralUtil::CreateR0( + literal.GetFirstElement())); + case F64: + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); + case S64: + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); case U64: - return Get(multi_index); + return std::move( + *LiteralUtil::CreateR0(literal.GetFirstElement())); default: - return FailedPrecondition( - "Array element type is not integral: %s", - PrimitiveType_Name(shape().element_type()).c_str()); + LOG(FATAL) << "Unhandled primitive type " + << literal.shape().element_type(); } } -size_t LiteralBase::Hash() const { - using tensorflow::Hash64; - using tensorflow::Hash64Combine; - - size_t hash_value = ShapeUtil::Hash(shape()); - - ShapeUtil::ForEachSubshape( - shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (!ShapeUtil::IsArray(subshape)) { - return; - } - - CHECK(LayoutUtil::IsDense(subshape.layout())); - hash_value = Hash64Combine( - hash_value, Hash64(static_cast(untyped_data(index)), - size_bytes(index))); - }); - - return hash_value; +/* static */ std::unique_ptr LiteralUtil::MakeTuple( + tensorflow::gtl::ArraySlice elements) { + std::vector element_shapes; + for (const auto* element : elements) { + element_shapes.push_back(element->shape()); + } + auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + for (int i = 0; i < elements.size(); ++i) { + TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i})); + } + return literal; } -Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, - int64 value) { - CHECK(LayoutUtil::IsDenseArray(shape())); - switch (shape().element_type()) { - case PRED: - Set(multi_index, value); - break; - case U8: - Set(multi_index, value); - break; - case S32: - Set(multi_index, value); - break; - case S64: - Set(multi_index, value); - break; - case U32: - Set(multi_index, value); - break; - case U64: - Set(multi_index, value); - break; - default: - return FailedPrecondition( - "Array element type is not integral: %s", - PrimitiveType_Name(shape().element_type()).c_str()); - } - return Status::OK(); -} - -tensorflow::gtl::ArraySlice LiteralBase::GetSparseIndex( - int64 sparse_element_number, const ShapeIndex& shape_index) const { - const Piece& p = piece(shape_index); - CHECK_GE(sparse_element_number, 0); - CHECK_LT(sparse_element_number, p.sparse_indices()->index_count()); - return p.sparse_indices()->At(sparse_element_number); -} - -void Literal::SortSparseElements(const ShapeIndex& shape_index) { - piece(shape_index).SortSparseElements(); -} - -Literal LiteralBase::GetFirstScalarLiteral() const { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_GT(ShapeUtil::ElementsIn(shape()), 0); - switch (shape().element_type()) { - case PRED: - return std::move(*Literal::CreateR0(GetFirstElement())); - // 8 bit types. - case S8: - return std::move(*Literal::CreateR0(GetFirstElement())); - case U8: - return std::move(*Literal::CreateR0(GetFirstElement())); - // 16 bit types. - case BF16: - return std::move( - *Literal::CreateR0(GetFirstElement())); - case F16: - return std::move(*Literal::CreateR0(GetFirstElement())); - case S16: - return std::move(*Literal::CreateR0(GetFirstElement())); - case U16: - return std::move(*Literal::CreateR0(GetFirstElement())); - // 32 bit types. - case F32: - return std::move(*Literal::CreateR0(GetFirstElement())); - case S32: - return std::move(*Literal::CreateR0(GetFirstElement())); - case U32: - return std::move(*Literal::CreateR0(GetFirstElement())); - // 64 bit types. - case C64: - return std::move( - *Literal::CreateR0(GetFirstElement())); - case F64: - return std::move(*Literal::CreateR0(GetFirstElement())); - case S64: - return std::move(*Literal::CreateR0(GetFirstElement())); - case U64: - return std::move(*Literal::CreateR0(GetFirstElement())); - default: - LOG(FATAL) << "Unhandled primitive type " << shape().element_type(); - } -} - -void LiteralBase::Piece::SortSparseElements() { - switch (subshape().element_type()) { - case PRED: - SortSparseElementsInternal(); - break; - case S8: - SortSparseElementsInternal(); - break; - case U8: - SortSparseElementsInternal(); - break; - case S16: - SortSparseElementsInternal(); - break; - case U16: - SortSparseElementsInternal(); - break; - case S32: - SortSparseElementsInternal(); - break; - case U32: - SortSparseElementsInternal(); - break; - case S64: - SortSparseElementsInternal(); - break; - case U64: - SortSparseElementsInternal(); - break; - case F32: - SortSparseElementsInternal(); - break; - case F64: - SortSparseElementsInternal(); - break; - case C64: - SortSparseElementsInternal(); - break; - case F16: - SortSparseElementsInternal(); - break; - case BF16: - SortSparseElementsInternal(); - break; - default: - LOG(FATAL) << "Element type not valid for sparse array: " - << PrimitiveType_Name(subshape().element_type()); - } -} - -template -void LiteralBase::Piece::SortSparseElementsInternal() { - CHECK(LayoutUtil::IsSparseArray(subshape())); - int64 num_elements = sparse_indices()->index_count(); - auto values = data(); - CHECK_LE(num_elements, values.size()); - sparse_indices()->SortWithValues( - tensorflow::gtl::MutableArraySlice(values.data(), num_elements)); -} - -namespace { - -void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, - bool print_layout, std::vector* pieces) { - const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); - CHECK(LayoutUtil::HasLayout(literal.shape())); - CHECK(LayoutUtil::HasLayout(subshape)); - - auto shape_to_string = [print_layout](const Shape& shape) { - if (print_layout) { - return ShapeUtil::HumanStringWithLayout(shape); - } else { - return ShapeUtil::HumanString(shape); - } - }; - - // TODO(b/32894291): refactor this code to reduce code duplication. - if (ShapeUtil::IsTuple(subshape)) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" (\n"); - std::vector tuple_pieces; - for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { - ShapeIndex element_index = shape_index; - element_index.push_back(i); - std::vector element_pieces; - ToStringHelper(literal, element_index, print_layout, &element_pieces); - tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, "")); - } - pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n")); - pieces->push_back("\n)"); - return; - } - - if (ShapeUtil::IsToken(subshape)) { - pieces->push_back("token"); - return; - } - - if (LayoutUtil::IsSparseArray(subshape)) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back("{"); - int64 rank = ShapeUtil::Rank(subshape); - int64 num_elements = literal.sparse_element_count(); - for (int64 i = 0; i < num_elements; ++i) { - if (i > 0) { - pieces->push_back(", "); - } - if (rank == 1) { - pieces->push_back(StrCat(literal.GetSparseIndex(i)[0])); - pieces->push_back(": "); - } else { - pieces->push_back("["); - pieces->push_back( - tensorflow::str_util::Join(literal.GetSparseIndex(i), ", ")); - pieces->push_back("]: "); - } - pieces->push_back(literal.GetSparseElementAsString(i)); - } - pieces->push_back("}"); - return; - } - - CHECK(LayoutUtil::IsDenseArray(subshape)); - - auto element_to_string = - [&](tensorflow::gtl::ArraySlice indices) -> string { - PrimitiveType element_type = subshape.element_type(); - if (element_type == PRED) { - // We display predicates in a densely packed form. - return literal.Get(indices, shape_index) ? "1" : "0"; - } - return ((!indices.empty() && indices.back() > 0) ? ", " : "") + - literal.GetAsString(indices, shape_index); - }; - - if (ShapeUtil::Rank(subshape) == 0) { - pieces->push_back(literal.GetAsString({}, shape_index)); - } else if (ShapeUtil::Rank(subshape) == 1) { - pieces->push_back("{"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(element_to_string({i0})); - } - pieces->push_back("}"); - } else if (ShapeUtil::Rank(subshape) == 2) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(" { "); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(element_to_string({i0, i1})); - } - pieces->push_back(" "); - pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "}\n" : "},\n"); - } - pieces->push_back("}"); - } else if (ShapeUtil::Rank(subshape) == 3) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(i0 > 0 ? ",\n{" : "{"); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(i1 > 0 ? ",\n { " : " { "); - for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(element_to_string({i0, i1, i2})); - } - pieces->push_back(" }"); - } - pieces->push_back(" }"); - } - pieces->push_back("\n}"); - } else if (ShapeUtil::Rank(subshape) == 4) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); - for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(" {"); - for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { - pieces->push_back(element_to_string({i0, i1, i2, i3})); - } - pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "}\n" : "},\n"); - } - pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" - : " },\n"); - } - pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n"); - } - pieces->push_back("}"); - } else if (ShapeUtil::Rank(subshape) == 5) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); - for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(Printf(" { /*i2=%lld*/\n", i2)); - for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { - pieces->push_back(" {"); - for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) { - pieces->push_back(element_to_string({i0, i1, i2, i3, i4})); - } - pieces->push_back(i3 == subshape.dimensions(3) - 1 ? "}\n" - : "},\n"); - } - pieces->push_back(i2 == subshape.dimensions(2) - 1 ? " }\n" - : " },\n"); - } - pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" - : " },\n"); - } - pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n"); - } - pieces->push_back("}"); - } else { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {"); - literal.EachCellAsString( - [&](tensorflow::gtl::ArraySlice indices, const string& value) { - pieces->push_back(" "); - pieces->push_back(value); - }); - pieces->push_back("}"); - } -} - -} // namespace - -int64 LiteralBase::sparse_element_count() const { - CHECK(LayoutUtil::IsSparseArray(shape())); - return sparse_indices()->index_count(); -} - -string LiteralBase::ToString(bool print_layout) const { - std::vector pieces; - CHECK(LayoutUtil::HasLayout(this->shape())); - ToStringHelper(*this, {}, print_layout, &pieces); - return tensorflow::str_util::Join(pieces, ""); -} - -/* static */ std::unique_ptr Literal::MakeTuple( - tensorflow::gtl::ArraySlice elements) { - std::vector element_shapes; - for (const auto* element : elements) { - element_shapes.push_back(element->shape()); - } - auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); - for (int i = 0; i < elements.size(); ++i) { - TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i})); - } - return literal; -} - -/* static */ std::unique_ptr Literal::MakeTupleFromSlices( +/* static */ std::unique_ptr LiteralUtil::MakeTupleFromSlices( tensorflow::gtl::ArraySlice elements) { std::vector element_shapes; for (const auto& element : elements) { @@ -1550,7 +456,7 @@ string LiteralBase::ToString(bool print_layout) const { return literal; } -/* static */ std::unique_ptr Literal::MakeTupleOwned( +/* static */ std::unique_ptr LiteralUtil::MakeTupleOwned( std::vector> elements) { std::vector element_shapes; element_shapes.reserve(elements.size()); @@ -1565,822 +471,9 @@ string LiteralBase::ToString(bool print_layout) const { return literal; } -void LiteralBase::EachCellAsString( - const std::function indices, - const string& value)>& per_cell) const { - if (ShapeUtil::IsZeroElementArray(shape())) { - return; - } - std::vector indices = IndexUtil::LinearIndexToMultidimensionalIndex( - shape(), /*linear_index=*/0); - do { - per_cell(indices, GetAsString(indices)); - } while (IndexUtil::BumpIndices(shape(), &indices)); -} - -namespace { -template -std::unique_ptr ConvertBetweenNativeTypesWithConverter( - const LiteralBase& src_literal, const ConverterType& converter) { - CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = MakeUnique(ShapeUtil::ChangeElementType( - src_literal.shape(), - primitive_util::NativeToPrimitiveType())); - auto src_data = src_literal.data(); - auto dest_data = result_literal->template data(); - int64 num_elements = src_literal.element_count(); - - for (int64 i = 0; i < num_elements; ++i) { - dest_data[i] = converter(src_data[i]); - } - return result_literal; -} - -template -std::unique_ptr ConvertBetweenNativeTypes( - const LiteralBase& src_literal) { - auto converter = [](NativeSrcT src) { return static_cast(src); }; - return ConvertBetweenNativeTypesWithConverter( - src_literal, converter); -} - -template -typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), - std::unique_ptr>::type -BitcastBetweenNativeTypes(const LiteralBase& src_literal) { - auto converter = [](NativeSrcT src) { - return tensorflow::bit_cast(src); - }; - return ConvertBetweenNativeTypesWithConverter( - src_literal, converter); -} - -// This template specialization is here to make the compiler happy. bit_cast has -// a static check that the types are the same size. This specialization should -// never be used because the source and destination types are checked for -// identical sizes higher up. -template -typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), - std::unique_ptr>::type -BitcastBetweenNativeTypes(const LiteralBase& src_literal) { - LOG(FATAL) << "Invalid bitcast between types of different sizes."; -} - -template -std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { - CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = MakeUnique( - ShapeUtil::ChangeElementType(src_literal.shape(), C64)); - using NativeSrcT = - typename primitive_util::PrimitiveTypeToNative::type; - tensorflow::gtl::ArraySlice src_data = - src_literal.data(); - tensorflow::gtl::MutableArraySlice dest_data = - result_literal->data(); - int64 num_elements = src_literal.element_count(); - for (int64 i = 0; i < num_elements; ++i) { - dest_data[i] = complex64(static_cast(src_data[i]), 0); - } - return result_literal; -} - -template -std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, - bool bitcast) { - CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); - if (bitcast) { - return BitcastBetweenNativeTypes< - typename primitive_util::PrimitiveTypeToNative< - primitive_src_type>::type, - typename primitive_util::PrimitiveTypeToNative< - primitive_dest_type>::type>(src_literal); - } else { - return ConvertBetweenNativeTypes< - typename primitive_util::PrimitiveTypeToNative< - primitive_src_type>::type, - typename primitive_util::PrimitiveTypeToNative< - primitive_dest_type>::type>(src_literal); - } -} - -template -StatusOr> ConvertIfDestTypeMatches( - const LiteralBase& src_literal, PrimitiveType primitive_dest_type, - bool bitcast) { - switch (primitive_dest_type) { -#define CONVERT_IF_TYPES_MATCH(type) \ - case (type): \ - return ConvertIfTypesMatch(src_literal, \ - bitcast); - CONVERT_IF_TYPES_MATCH(PRED) - CONVERT_IF_TYPES_MATCH(S8) - CONVERT_IF_TYPES_MATCH(S32) - CONVERT_IF_TYPES_MATCH(S64) - CONVERT_IF_TYPES_MATCH(U8) - CONVERT_IF_TYPES_MATCH(U32) - CONVERT_IF_TYPES_MATCH(U64) - CONVERT_IF_TYPES_MATCH(F16) - CONVERT_IF_TYPES_MATCH(F32) - CONVERT_IF_TYPES_MATCH(F64) - CONVERT_IF_TYPES_MATCH(BF16) -#undef CONVERT_IF_TYPES_MATCH - case C64: - if (!bitcast) { - return ConvertToC64(src_literal); - } - break; - // Other types are not yet supported. - default: - break; - } - return Unimplemented( - "Converting from type %s to type %s is not implemented.", - PrimitiveType_Name(src_literal.shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str()); -} - -StatusOr> ConvertSwitch( - const LiteralBase& literal, PrimitiveType primitive_dest_type, - bool bitcast) { - TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); - if (literal.shape().element_type() == primitive_dest_type) { - return literal.CloneToUnique(); - } - switch (literal.shape().element_type()) { -#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ - case (type): \ - return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \ - bitcast); - CONVERT_IF_DEST_TYPE_MATCHES(PRED) - CONVERT_IF_DEST_TYPE_MATCHES(S8) - CONVERT_IF_DEST_TYPE_MATCHES(S32) - CONVERT_IF_DEST_TYPE_MATCHES(S64) - CONVERT_IF_DEST_TYPE_MATCHES(U8) - CONVERT_IF_DEST_TYPE_MATCHES(U32) - CONVERT_IF_DEST_TYPE_MATCHES(U64) - CONVERT_IF_DEST_TYPE_MATCHES(F16) - CONVERT_IF_DEST_TYPE_MATCHES(F32) - CONVERT_IF_DEST_TYPE_MATCHES(F64) - CONVERT_IF_DEST_TYPE_MATCHES(BF16) -#undef CONVERT_IF_DEST_TYPE_MATCHES - // Other types are not yet supported. - default: - return Unimplemented( - "%s from type %s to type %s is not implemented.", - (bitcast ? "Bitcast converting" : "Converting"), - PrimitiveType_Name(literal.shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str()); - } -} - -} // namespace - -StatusOr> LiteralBase::Convert( - PrimitiveType primitive_dest_type) const { - return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false); -} - -StatusOr> LiteralBase::BitcastConvert( - PrimitiveType primitive_dest_type) const { - if (primitive_util::BitWidth(shape().element_type()) != - primitive_util::BitWidth(primitive_dest_type)) { - return InvalidArgument( - "Cannot bitcast convert from %s to %s, bit widths are different: %d != " - "%d", - PrimitiveType_Name(shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str(), - primitive_util::BitWidth(shape().element_type()), - primitive_util::BitWidth(primitive_dest_type)); - } - return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true); -} - -StatusOr> LiteralBase::ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16) const { - if (!ShapeUtil::IsTuple(dest_shape)) { - if (round_f32_to_bf16 && shape().element_type() == F32 && - dest_shape.element_type() == BF16) { - auto converter = [](float src) { - return tensorflow::bfloat16::round_to_bfloat16(src); - }; - return ConvertBetweenNativeTypesWithConverter(*this, - converter); - } - return Convert(dest_shape.element_type()); - } - std::vector elements; - for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { - auto element = LiteralSlice(*this, {i}); - TF_ASSIGN_OR_RETURN( - auto new_element, - element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); - elements.push_back(std::move(*new_element)); - } - auto converted = MakeUnique(); - *converted = Literal::MoveIntoTuple(&elements); - return std::move(converted); -} - -template -bool LiteralBase::Piece::EqualElementsInternal( - const LiteralBase::Piece& other, std::vector* multi_index) const { - if (multi_index->size() == ShapeUtil::Rank(subshape())) { - return (Get(*multi_index) == other.Get(*multi_index)); - } - for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) { - multi_index->push_back(i); - if (!EqualElementsInternal(other, multi_index)) { - return false; - } - multi_index->pop_back(); - } - return true; -} - -bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { - DCHECK(ShapeUtil::Compatible(subshape(), other.subshape())); - - std::vector multi_index; - switch (subshape().element_type()) { - case PRED: - return EqualElementsInternal(other, &multi_index); - case U8: - return EqualElementsInternal(other, &multi_index); - case S32: - return EqualElementsInternal(other, &multi_index); - case S64: - return EqualElementsInternal(other, &multi_index); - case U32: - return EqualElementsInternal(other, &multi_index); - case U64: - return EqualElementsInternal(other, &multi_index); - case F32: - return EqualElementsInternal(other, &multi_index); - case F64: - return EqualElementsInternal(other, &multi_index); - case F16: - return EqualElementsInternal(other, &multi_index); - case BF16: - return EqualElementsInternal(other, &multi_index); - case C64: - return EqualElementsInternal(other, &multi_index); - default: - LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type " - << PrimitiveType_Name(subshape().element_type()); - } -} - -bool LiteralBase::operator==(const LiteralBase& other) const { - if (!ShapeUtil::Compatible(shape(), other.shape())) { - return false; - } - - return root_piece().ForEachSubpieceWithBool( - [&](const ShapeIndex& index, const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { - return true; - } - - const Piece& other_piece = other.piece(index); - if (!piece.EqualElements(other_piece)) { - return false; - } - return true; - }); -} - -namespace { - -template -static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice data, - NativeT value) { - for (int64 i = 0; i < data.size(); ++i) { - if (data[i] != value) { - return false; - } - } - return true; -} - -} // namespace - -bool LiteralBase::IsAll(int8 value) const { - return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index, - const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { - return true; - } - - auto piece_is_all = [&]() { - switch (shape().element_type()) { - case U8: - if (value >= 0) { - return AllElementsEqualValue(piece.data(), value); - } - return false; - case U32: - if (value >= 0) { - return AllElementsEqualValue(piece.data(), value); - } - return false; - case U64: - if (value >= 0) { - return AllElementsEqualValue(piece.data(), value); - } - return false; - case S8: - return AllElementsEqualValue(piece.data(), value); - case S32: - return AllElementsEqualValue(piece.data(), value); - case S64: - return AllElementsEqualValue(piece.data(), value); - case F32: - return AllElementsEqualValue(piece.data(), value); - case F64: - return AllElementsEqualValue(piece.data(), value); - case F16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - case BF16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - case PRED: - if (value == 0) { - return AllElementsEqualValue(piece.data(), false); - } - if (value == 1) { - return AllElementsEqualValue(piece.data(), true); - } - return false; - default: - return false; - } - return false; - }; - - if (!piece_is_all()) { - return false; - } - return true; - }); -} - -bool LiteralBase::IsAllFloat(float value) const { - return root_piece().ForEachSubpieceWithBool( - [&](const ShapeIndex& index, const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { - return true; - } - - auto piece_is_all = [&]() { - switch (shape().element_type()) { - case F32: - return AllElementsEqualValue(piece.data(), value); - case F64: - return AllElementsEqualValue(piece.data(), value); - case F16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - case BF16: - return AllElementsEqualValue( - piece.data(), static_cast(value)); - default: - return false; - } - }; - if (!piece_is_all()) { - return false; - } - return true; - }); -} - -bool LiteralBase::IsAllComplex(complex64 value) const { - switch (shape().element_type()) { - case C64: - return AllElementsEqualValue(root_piece().data(), - value); - default: - return false; - } -} - -bool LiteralBase::IsAllFirst() const { - return root_piece().ForEachSubpieceWithBool( - [&](const ShapeIndex& index, const Piece& piece) { - if (!ShapeUtil::IsArray(piece.subshape())) { - return true; - } - - // Empty shapes are not all the first element since there is no first - // element. - if (ShapeUtil::IsZeroElementArray(piece.subshape())) { - return false; - } - auto piece_is_all = [&]() { - switch (piece.subshape().element_type()) { - case PRED: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 8 bit types - case S8: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U8: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 16 bit types - case BF16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case F16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 32 bit types - case F32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 64 bit types - case C64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case F64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - default: - return false; - } - }; - - if (!piece_is_all()) { - return false; - } - return true; - }); -} - -bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice indices) const { - CHECK(ShapeUtil::IsArray(shape())); - switch (shape().element_type()) { - case U8: - return Get(indices) == 0; - case U32: - return Get(indices) == 0; - case U64: - return Get(indices) == 0; - case S8: - return Get(indices) == 0; - case S32: - return Get(indices) == 0; - case S64: - return Get(indices) == 0; - case F32: - return Get(indices) == 0.0f; - case F64: - return Get(indices) == 0.0; - case C64: - return Get(indices) == complex64(0.0f, 0.0f); - case F16: - return Get(indices) == static_cast(0.0f); - case BF16: - return Get(indices) == static_cast(0.0f); - case PRED: - return Get(indices) == false; - default: - LOG(FATAL) << "Input literal must be an array."; - } -} - -namespace { - -template -void CopyToRepeatedField(RepeatedFieldT* dest, - const tensorflow::gtl::ArraySlice src) { - *dest = RepeatedFieldT(src.begin(), src.end()); -} - -} // namespace - -void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { - *proto->mutable_shape() = subshape(); - switch (subshape().element_type()) { - case PRED: - CopyToRepeatedField(proto->mutable_preds(), data()); - break; - case U8: - proto->set_u8s(static_cast(data().data()), - element_count()); - break; - case U32: - CopyToRepeatedField(proto->mutable_u32s(), data()); - break; - case U64: - CopyToRepeatedField(proto->mutable_u64s(), data()); - break; - case S32: - CopyToRepeatedField(proto->mutable_s32s(), data()); - break; - case S64: - CopyToRepeatedField(proto->mutable_s64s(), data()); - break; - case F16: - *proto->mutable_f16s() = string( - reinterpret_cast(data().data()), size_bytes()); - if (!kLittleEndian) { - ConvertEndianShort(proto->mutable_f16s()); - } - break; - case BF16: - *proto->mutable_bf16s() = string( - reinterpret_cast(data().data()), size_bytes()); - if (!kLittleEndian) { - ConvertEndianShort(proto->mutable_bf16s()); - } - break; - case F32: - CopyToRepeatedField(proto->mutable_f32s(), data()); - break; - case F64: - CopyToRepeatedField(proto->mutable_f64s(), data()); - break; - case C64: - for (complex64 value : data()) { - proto->add_c64s(value.real()); - proto->add_c64s(value.imag()); - } - break; - case TUPLE: - case TOKEN: - // Nothing to do but assign the shape which is done above. - return; - default: - LOG(FATAL) << "Unhandled primitive type " << subshape().element_type(); - } -} - -const void* LiteralBase::Piece::untyped_data() const { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); - return buffer(); -} - -void* LiteralBase::Piece::untyped_data() { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); - return buffer(); -} - -namespace { - -template -Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, - const RepeatedFieldT& src) { - if (dest.size() != src.size()) { - return InvalidArgument( - "Expected %lu elements in LiteralProto repeated field, has %d", - dest.size(), src.size()); - } - std::copy(src.begin(), src.end(), dest.begin()); - return Status::OK(); -} - -} // namespace - -Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { - // These conditions should have been checked in Literal::CreateFromProto. - TF_RET_CHECK(proto.has_shape()); - TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); - TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape())); - - switch (subshape().element_type()) { - case PRED: - TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.preds())); - break; - case U8: { - auto u8_data = data(); - TF_RET_CHECK(proto.u8s().size() == u8_data.size()); - std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin()); - } break; - case S32: - TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.s32s())); - break; - case S64: - TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.s64s())); - break; - case U32: - TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u32s())); - break; - case U64: - TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u64s())); - break; - case F16: { - const string& s(proto.f16s()); - TF_RET_CHECK(data().size() * sizeof(half) == s.size()); - memcpy(untyped_data(), s.data(), s.size()); - if (!kLittleEndian) { - ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); - } - } break; - - case BF16: { - const string& s(proto.bf16s()); - TF_RET_CHECK(data().size() * sizeof(bfloat16) == s.size()); - memcpy(untyped_data(), s.data(), s.size()); - if (!kLittleEndian) { - ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); - } - } break; - case F32: - TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.f32s())); - break; - case F64: - TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.f64s())); - break; - case C64: { - auto complex_data = data(); - TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2); - for (int64 i = 0; i < complex_data.size(); ++i) { - complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)}; - } - } break; - case TUPLE: - LOG(FATAL) << "Should not be called on tuple shapes: " - << ShapeUtil::HumanString(subshape()); - break; - default: - LOG(FATAL) << "Unhandled primitive type " << subshape().element_type(); - } - return Status::OK(); -} - -LiteralProto LiteralBase::ToProto() const { - LiteralProto proto; - root_piece().ForEachSubpiece( - [&](const ShapeIndex& index, const Piece& piece) { - LiteralProto* proto_piece = &proto; - for (int64 i : index) { - while (proto_piece->tuple_literals_size() <= i) { - proto_piece->add_tuple_literals(); - } - proto_piece = proto_piece->mutable_tuple_literals(i); - } - piece.WriteToProto(proto_piece); - }); - - if (LayoutUtil::IsSparseArray(shape())) { - CopyToRepeatedField(proto.mutable_sparse_indices(), - sparse_indices()->data()); - } - - return proto; -} - -/* static */ -StatusOr> Literal::CreateFromProto( - const LiteralProto& proto) { - if (!proto.has_shape()) { - return InvalidArgument("LiteralProto has no shape"); - } - if (!LayoutUtil::HasLayout(proto.shape())) { - return InvalidArgument("LiteralProto has no layout"); - } - - auto literal = MakeUnique(proto.shape()); - - TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( - [&](const ShapeIndex& index, Piece* piece) { - const LiteralProto* proto_element = &proto; - for (int64 i : index) { - CHECK(i < proto_element->tuple_literals_size()); - proto_element = &proto_element->tuple_literals(i); - } - - if (ShapeUtil::IsTuple(piece->subshape())) { - if (proto_element->tuple_literals_size() != - ShapeUtil::TupleElementCount(piece->subshape())) { - return InvalidArgument( - "Expected %lld tuple elements in LiteralProto, has %d", - ShapeUtil::TupleElementCount(piece->subshape()), - proto_element->tuple_literals_size()); - } - return Status::OK(); - } - if (piece->subshape().element_type() == TOKEN) { - return Status::OK(); - } - - CHECK(ShapeUtil::IsArray(piece->subshape())); - TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); - - return Status::OK(); - })); - - return std::move(literal); -} - -/* static */ string Literal::MultiIndexAsString( +/* static */ string LiteralUtil::MultiIndexAsString( tensorflow::gtl::ArraySlice multi_index) { return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}"); } -const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const { - return piece(shape_index).untyped_data(); -} - -void* Literal::untyped_data(const ShapeIndex& shape_index) { - return piece(shape_index).untyped_data(); -} - -int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const { - return piece(shape_index).size_bytes(); -} - -string LiteralBase::GetR1U8AsString() const { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 1); - CHECK_EQ(shape().element_type(), U8); - return string(tensorflow::bit_cast(data().data()), - ShapeUtil::ElementsIn(shape())); -} - -void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { - CHECK(ShapeUtil::IsTuple(shape)); - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - const Shape& subshape = shape.tuple_shapes(i); - - auto child_piece = Piece(); - child_piece.set_subshape(&subshape); - - if (ShapeUtil::IsTuple(subshape)) { - BuildPieceSubtree(subshape, &child_piece); - } - - piece->emplace_back(std::move(child_piece)); - } -} - -LiteralSlice::LiteralSlice(const LiteralBase& literal) - : LiteralBase(), root_piece_(&literal.root_piece()) {} - -LiteralSlice::LiteralSlice(const LiteralBase& literal, - const ShapeIndex& view_root) - : LiteralBase(), root_piece_(&literal.piece(view_root)) {} - -BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) - : LiteralBase(), shape_(MakeUnique(shape)) { - CHECK(ShapeUtil::IsArray(*shape_)); - CHECK(LayoutUtil::HasLayout(*shape_)); - - root_piece_ = Piece(); - root_piece_.set_buffer(const_cast(src_buf_ptr)); - root_piece_.set_subshape(shape_.get()); -} - -BorrowingLiteral::BorrowingLiteral( - tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) - : LiteralBase(), shape_(MakeUnique(shape)) { - CHECK(ShapeUtil::IsTuple(*shape_)); - CHECK(!ShapeUtil::IsNestedTuple(*shape_)); - CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); - root_piece_ = Piece(); - root_piece_.set_subshape(shape_.get()); - BuildPieceSubtree(*shape_, &root_piece_); - - for (int i = 0; i < src_buf_ptrs.size(); ++i) { - const auto& src_shape = shape_->tuple_shapes(i); - CHECK(ShapeUtil::IsArray(src_shape)); - root_piece_.child(i).set_buffer(const_cast(src_buf_ptrs[i])); - } -} - } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 37ca8ea9f1d158b6bce8d5688288351f55c3b3c8..e3737a9d0051b32dc0becc19e1849c856a50e52e 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -51,679 +52,12 @@ limitations under the License. namespace xla { -// Forward declare Literal and LiteralSlice class to be used by the creation -// methods in the base class. -class Literal; -class LiteralSlice; - -// Abstract base class for literals. -class LiteralBase { +class LiteralUtil { public: - virtual ~LiteralBase() = 0; - - // Literals are equal if they have compatible shapes and the same data - // values. Layout is not compared. - bool operator==(const LiteralBase& other) const; - bool operator!=(const LiteralBase& other) const { return !(*this == other); } - - // Returns the shape of the literal. - const Shape& shape() const { return root_piece().subshape(); } - - // Serialize to proto. - LiteralProto ToProto() const; - - // Returns an ArraySlice of the array for this literal for the given NativeT - // (e.g., float). CHECKs if the subshape of the literal at the given - // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type - // to native type. - template - tensorflow::gtl::ArraySlice data( - const ShapeIndex& shape_index = {}) const; - - // Returns a const pointer to the sparse index array. Returns nullptr if the - // literal is not a sparse array. - const SparseIndexArray* sparse_indices( - const ShapeIndex& shape_index = {}) const; - - // Returns a const pointer to (or size of) the underlying buffer holding the - // array at the given shape index. CHECKs if the subshape of the literal at - // the given ShapeIndex is not array. - const void* untyped_data(const ShapeIndex& shape_index = {}) const; - int64 size_bytes(const ShapeIndex& shape_index = {}) const; - - // Returns this literal's data as a string. This literal must be a rank-1 U8 - // array. - string GetR1U8AsString() const; - - // Returns a string representation of the literal value. - // Warning: this function can take minutes for multi-million element Literals. - string ToString(bool print_layout = false) const; - - // Gets an element in the literal at the given index. The multi_index is - // CHECKed against the dimension sizes. - template - NativeT Get(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const; - // Overloads of Get for array literals. CHECKs if the literal is not - // array-shaped and dense. - template - NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; - - // Returns the element value at index (0, ..., 0), however many zeroes are - // required for that index. - template - NativeT GetFirstElement() const; - - // As Get(), but determines the correct type and converts the value - // into text. - string GetAsString(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index = {}) const; - // As GetSparseElement(), but determines the correct type and converts the - // value into text. - string GetSparseElementAsString(int64 sparse_element_number, - const ShapeIndex& shape_index = {}) const; - // As Get(), but determines the correct type and converts the value into - // int64. This literal must be an array. - StatusOr GetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index) const; - - // Returns the multi-index of the element in a sparse literal at the given - // sparse element number. The sparse element number is the position with in - // the sparse array's list of (index, value) pairs, and is checked against the - // total number of (index, value) pairs in the sparse array. - tensorflow::gtl::ArraySlice GetSparseIndex( - int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; - - // Returns the value of the element in a sparse literal at the given sparse - // element number. The sparse element number is the position with in the - // sparse array's list of (index, value) pairs, and is checked against the - // total number of (index, value) pairs in the sparse array. - template - NativeT GetSparseElement(int64 sparse_element_number, - const ShapeIndex& shape_index = {}) const; - - // Invokes the "per cell" callback for each element in the provided - // literal with the element's indices and a string representation of - // the element's value. - // - // This function is useful if you want a polymorphic representation - // of the tensor's elements (turning it to a string for something - // like representation in a protobuf). - // - // This literal must have a dense layout. - void EachCellAsString( - const std::function indices, - const string& value)>& per_cell) const; - template - void EachCell(std::function indices, - NativeT value)> - per_cell) const; - - // Returns whether every element in this literal is equal to value. - // - // value is an int8 because we expect this to be called with small - // compile-time constants (0, -1, etc.) and so that whatever value you pass - // can be represented exactly by floating-point types as small as 16 bits. - // - // If value doesn't fit in this literal's type, returns false. Values of 1/0 - // are considered equal to true/false; other values are not considered equal - // to true. Also if this literal is not array-shaped false is returned. - bool IsAll(int8 value) const; - - // Like IsAll(const Literal&, int8), except we check whether the literal is - // equal to a particular floating-point number. - // - // If the literal is not a floating-point value, this always returns false. - // - // This casts value to the type of literal, then compares using ==. The usual - // admonishments about floating-point equality checks apply. We expect you to - // use this to check for values that can be expressed precisely as a float, - // e.g. -0.5. Also if this literal is not array-shaped false is returned. - bool IsAllFloat(float value) const; - - // Like IsAll(const Literal&, int8), except we check whether the literal is - // equal to a particular complex number. - // - // If the literal is not a complex value, this always returns false. - // - // This casts value to the type of literal, then compares using ==. The usual - // admonishments about floating-point equality checks apply. We expect you to - // use this to check for complex values that can be expressed precisely as - // float pairs e.g. (-0.5, 1.0). - // - // This literal must have a dense layout. - bool IsAllComplex(complex64 value) const; - - // Literal consists entirely of the first element of the literal. - bool IsAllFirst() const; - - // Returns whether this literal is zero at the specified index. This literal - // must be an array with a dense layout. - bool IsZero(tensorflow::gtl::ArraySlice indices) const; - - // Returns the count of the elements in the array at the given shape index in - // this literal. - int64 element_count(const ShapeIndex& index = {}) const { - return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); - } - - // Returns the count of the elements in the sparse array at the given shape - // index in this literal, which will be no larger than - // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). - int64 sparse_element_count() const; - - // Compute a hash for this literal. This literal must not be a sparse tensor - // or a tuple containing a sparse tensor. - size_t Hash() const; - - // Converts this literal to the given shape. Returns an error is the - // conversion is not possible. - // - // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding - // instead of truncation; otherwise, truncation is used. - // - // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes - // the default behavior. - StatusOr> ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16 = false) const; - - // Converts this literal to another primitive type using a bitcast - // conversion. The to and from primitive types must have the same bit - // width. Returns an error if the conversion is not possible. This literal - // must be array-shaped. - StatusOr> BitcastConvert( - PrimitiveType primitive_dest_type) const; - - // Converts this literal to another primitive type. Returns an error if the - // conversion is not possible. This literal must be array-shaped. - StatusOr> Convert( - PrimitiveType primitive_dest_type) const; + LiteralUtil() = delete; // Returns a literal scalar representing the first element. - Literal GetFirstScalarLiteral() const; - - // Clones the underlying buffers into a new Literal, or new - // std::unique_ptr. - Literal Clone() const; - std::unique_ptr CloneToUnique() const; - - // TODO(b/67651157): The methods below which perform computation on Literals - // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with - // evaluator code which operates on Literals. - // - // Creates a new value that has the equivalent value as this - // literal, but conforms to new_layout; e.g. a literal matrix that was in {0, - // 1} minor-to-major dimension layout can be re-layed-out as {1, 0} - // minor-to-major dimension layout and the value in the cell at any given - // logical index (i0, i1) will be the same. - // - // For tuple shaped literals, shape_index should be used to select the inner - // array that the new layout applies to. - // - // Note: this is useful when the client wants to ensure that a value placed in - // the XLA allocation tracker has a particular layout; for efficiency - // purposes or avoiding unimplemented operation/layout combinations. - std::unique_ptr Relayout(const Layout& new_layout, - const ShapeIndex& shape_index = {}) const; - - // An overload of Relayout which changes the layout of the entire shape rather - // than being limited to a single array within the shape. - std::unique_ptr Relayout(const Shape& shape_with_layout) const; - - // Creates a new literal by reshaping this literal to have the given - // dimensions. The total number of elements must not change; The - // implementation currently only supports monotonic dim0-major layouts. - // This literal must be an array. - StatusOr> Reshape( - tensorflow::gtl::ArraySlice dimensions) const; - - // Creates a new literal by broadcasting this literal with `dimensions` to - // yield a literal of shape `result_shape`. - StatusOr> Broadcast( - const Shape& result_shape, - tensorflow::gtl::ArraySlice dimensions) const; - - // Creates a new literal by reordering the dimensions of this literal. - // The given `permutation` must be a permutation of the dimension numbers - // in the original literal, and it specifies the order of the new dimensions - // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). - // For example, a transpose call on a literal of shape [3 x 8 x 4] and - // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. - // This literal must be an array. - std::unique_ptr Transpose( - tensorflow::gtl::ArraySlice permutation) const; - - // Creates a sub-array from this literal by extracting the indices - // [start_index, limit_index) of each dimension. The result literal has the - // same rank and layout as for the given literal. The number of indices in - // start_indices and limit_indices must be the rank of the literal, and the - // indices follow the order of the dimensions. - // This literal must be an array. - std::unique_ptr Slice( - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) const; - - // Creates a literal with a prepended dimension with bound "times"; e.g. a - // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this - // literal replicated four times. - // This literal must be an array. - template - std::unique_ptr Replicate(int64 times) const; - - // Creates a new Literal object with the shape specified as parameter. - // The content of the literal values is the default value of the primitive - // type of literal itself (0 for numeric types, and false for predicates). - // - // Note: It's an antipattern to use this method then immediately call - // Literal::Populate on the result (since that results in zero initialization, - // then reinitialization. Conside if a call to MakeUnique(shape), - // followed by the call to Literal::Populate can be used instead. - static std::unique_ptr CreateFromShape(const Shape& shape); - - protected: - // A data structure representing a subshape at a particular ShapeIndex within - // the literal. For array-shaped ShapeIndexes, this data structure holds the - // pointer to the memory allocated for the array data. - class Piece { - public: - // Returns the buffer holding the array data for this piece as an array - // slice. This piece must be array-shaped. - template - tensorflow::gtl::ArraySlice data() const; - template - tensorflow::gtl::MutableArraySlice data(); - - // Returns the buffer holding the array data for this piece as a void*. This - // piece must be array-shaped. - void* untyped_data(); - const void* untyped_data() const; - - // Gets or sets an element in the array at the given index. The multi_index - // is CHECKed against the dimension sizes of the array. This piece must be - // array-shaped. - template - NativeT Get(tensorflow::gtl::ArraySlice index) const; - template - void Set(tensorflow::gtl::ArraySlice index, NativeT value); - - // Gets/sets the buffer holding the array data. - char* buffer() const { return buffer_; } - void set_buffer(char* buffer) { buffer_ = buffer; } - - // The array of multi-indices that provide the locations of non-zero - // elements in a sparse array. Only used if - // LayoutUtil::IsSparseArray(shape()) is true. - SparseIndexArray* sparse_indices() const { return sparse_indices_; } - void set_sparse_indices(SparseIndexArray* sparse_indices) { - sparse_indices_ = sparse_indices; - } - - // Gets or sets the subshape of this piece. This reference points to a - // subshape within the shape in the containing Literal (Literal::shape_). - const Shape& subshape() const { return *subshape_; } - void set_subshape(const Shape* subshape) { subshape_ = subshape; } - - // Returns the size in bytes of the buffer holding the array data. - int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } - - // Returns the number of elements in this piece's array. - int64 element_count() const { - // If this is a sparse array, use the number of elements represented by - // the indices in the associated SparseIndexArray. - return LayoutUtil::IsSparseArray(subshape()) - ? sparse_indices()->index_count() - : ShapeUtil::ElementsIn(subshape()); - } - - // Returns the child piece at 'index' of this piece. - Piece& child(int64 index) { return children_[index]; } - - // Adds a child piece to this piece's children. - void emplace_back(Piece child_piece) { - children_.emplace_back(std::move(child_piece)); - } - - // Returns the size of children pieces of this piece. - int64 children_size() { return children_.size(); } - - // Visitor functions that recursively traverses the piece and calls the - // given function at each child piece. The function has the type: - // void (const ShapeIndex& index, const Piece& piece) - template - void ForEachSubpiece(const Fn& func) const { - ShapeIndex index; - return ForEachHelper( - [&func](const ShapeIndex& index, const Piece& piece) { - func(index, piece); - return Status::OK(); - }, - *this, &index) - .IgnoreError(); - } - // Same as above, but the function has the type: - // Status (const ShapeIndex& index, const Piece& piece) - // The first non-OK return value is returned by the function. - template - Status ForEachSubpieceWithStatus(const Fn& func) const { - ShapeIndex index; - return ForEachHelper(func, *this, &index); - } - // Same as above, but the function has the type: - // Bool (const ShapeIndex& index, const Piece& piece) - // The first non-true return value is returned by the function. - template - bool ForEachSubpieceWithBool(const Fn& func) const { - ShapeIndex index; - return ForEachHelperBool(func, *this, &index); - } - // Same as above, but the function has the type: - // Void (const ShapeIndex& index, Piece& piece) - template - void ForEachMutableSubpiece(const Fn& func) { - ShapeIndex index; - return ForEachMutableHelper( - [&func](const ShapeIndex& index, Piece* piece) { - func(index, piece); - return Status::OK(); - }, - const_cast(this), &index) - .IgnoreError(); - } - // Same as above, but the function has the type: - // Status (const ShapeIndex& index, Piece& piece) - // The first non-OK return value is returned by the function. - template - Status ForEachMutableSubpieceWithStatus(const Fn& func) { - ShapeIndex index; - return ForEachMutableHelper( - func, const_cast(this), &index); - } - - // Returns true if this piece and 'other' contain the same data. This piece - // and 'other' must be array-shaped and compatible. - bool EqualElements(const Piece& other) const; - - // Writes the shape and data (if array-shaped) into the given proto. - void WriteToProto(LiteralProto* proto) const; - - // Copy the data from 'src' into this piece's buffer. Shapes of this piece - // and src must be compatible. - Status CopyFrom(const Piece& src); - - // Copies the data from the given proto into this piece. The shape of this - // piece must be equal (not just compatible) to the shape of the proto. - Status CopyFromProto(const LiteralProto& proto); - - // Sorts the elements in a sparse array. - void SortSparseElements(); - - private: - // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'. - // The first non-OK (or non-true) value is returned by the function. - // The callable 'func' has the same signature as described above in - // ForEachSubpiece*. - template - Status ForEachHelper(const Fn& func, const Piece& piece, - ShapeIndex* index) const { - TF_RETURN_IF_ERROR(func(*index, piece)); - for (int64 i = 0; i < piece.children_.size(); ++i) { - index->push_back(i); - TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index)); - index->pop_back(); - } - return Status::OK(); - } - template - bool ForEachHelperBool(const Fn& func, const Piece& piece, - ShapeIndex* index) const { - if (!func(*index, piece)) { - return false; - } - for (int64 i = 0; i < piece.children_.size(); ++i) { - index->push_back(i); - if (!ForEachHelperBool(func, piece.children_[i], index)) { - return false; - } - index->pop_back(); - } - return true; - } - template - Status ForEachMutableHelper(const Fn& func, Piece* piece, - ShapeIndex* index) { - TF_RETURN_IF_ERROR(func(*index, piece)); - for (int64 i = 0; i < piece->children_.size(); ++i) { - index->push_back(i); - TF_RETURN_IF_ERROR( - ForEachMutableHelper(func, &piece->children_[i], index)); - index->pop_back(); - } - return Status::OK(); - } - - // Recursive helper for EqualElements. - template - bool EqualElementsInternal(const Piece& other, - std::vector* multi_index) const; - - // Helper for SortSparseElements that has the element type as a template - // parameter. - template - void SortSparseElementsInternal(); - - // For array-shaped pieces, this is the buffer holding the literal data. - char* buffer_ = nullptr; - - // For sparse arrays, this is the array of indices. - SparseIndexArray* sparse_indices_ = nullptr; - - // The shape of piece. This points into the shape of the containing Literal - // (Literal::shape_). - const Shape* subshape_ = nullptr; - - // Children pieces for tuple shaped pieces. - std::vector children_ = {}; - }; // class Piece - - const Piece& piece(const ShapeIndex& shape_index) const { - Piece* piece = &const_cast(root_piece()); - for (const auto i : shape_index) { - DCHECK_GE(i, 0); - DCHECK_LT(i, piece->children_size()); - piece = &piece->child(i); - } - return *piece; - } - - // Returns the piece at the root of the shape. - virtual const Piece& root_piece() const = 0; - - // LiteralSlice and Literal must access Pieces of other Literals. - friend class Literal; - friend class LiteralSlice; - friend class BorrowingLiteral; - - private: - template - std::unique_ptr SliceInternal( - const Shape& result_shape, - tensorflow::gtl::ArraySlice start_indices) const; -}; - -// Class representing literal values in XLA. -// -// The underlying buffer and shape is always owned by this class. -class Literal : public LiteralBase { - public: - Literal() : Literal(ShapeUtil::MakeNil()) {} - - // Create a literal of the given shape. The literal is allocated sufficient - // memory to hold the shape. Memory is uninitialized. - explicit Literal(const Shape& shape); - virtual ~Literal(); - - // Literals are moveable, but not copyable. To copy a literal use - // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies - // of literals which can be expensive. - Literal(const Literal& other) = delete; - Literal& operator=(const Literal& other) = delete; - Literal(Literal&& other); - // 'allocate_arrays' indicates whether to allocate memory for the arrays in - // the shape. If false, buffer pointers inside of the Literal::Pieces are set - // to nullptr. - Literal(const Shape& shape, bool allocate_arrays); - Literal& operator=(Literal&& other); - - // TODO(b/67651157): Remove this accessor. Literal users should not be able to - // mutate the shape as this can produce malformed Literals. - Shape* mutable_shape_do_not_use() { return shape_.get(); } - - // Returns a MutableArraySlice view of the array for this literal for the - // given NativeT (e.g., float). CHECKs if the subshape of the literal at the - // given ShapeIndex is not array. See primitive_util.h for the mapping from - // XLA type to native type. - template - tensorflow::gtl::MutableArraySlice data( - const ShapeIndex& shape_index = {}); - // Unhide const method from parent class. - using LiteralBase::data; - - // Returns a pointer to the sparse index array. Returns nullptr if the literal - // is not a sparse array. - SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); - - // Returns a pointer to the underlying buffer holding the array at the given - // shape index. CHECKs if the subshape of the literal at the given ShapeIndex - // is not array. - void* untyped_data(const ShapeIndex& shape_index = {}); - // Unhide const method from parent class. - using LiteralBase::untyped_data; - - // Populates a literal with a sparse layout with the given indices and values. - // Each index in the indices array is CHECKed against the dimensions in the - // literal's shape. If sort is true, then the indices and values will be - // sorted. If sort is false, then the indices and values are assumed to - // already be in sorted order. See CreateSparse for an example of how data - // are populated. - template - void PopulateSparse(SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, - bool sort = true); - - // Copy values from 'src_literal' rooted at 'src_shape_index' into this - // literal rooted at 'dest_shape_index'. The subshape of this literal rooted - // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' - // rooted at 'src_shape_index', but need not be arrays. - Status CopyFrom(const LiteralSlice& src_literal, - const ShapeIndex& dest_shape_index = {}, - const ShapeIndex& src_shape_index = {}); - - // Similar to CopyFrom, but with move semantincs. The subshape of this literal - // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' - // (layouts and shapes must match), but need not be arrays. The memory - // allocated in this literal for the subshape at dest_shape_index is - // deallocated, and the respective buffers are replaced with those in - // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). - Status MoveFrom(Literal&& src_literal, - const ShapeIndex& dest_shape_index = {}); - - // Copies the values from src_literal, starting at src_base shape indexes, - // to this literal, starting at dest_base, where the copy size in each - // dimension is specified by copy_size. - // The src_literal and this literal must have the same primitive type, - // src_base+copy_size must fit the source literal dimensions, as well as - // dest_base+copy_size must fit the destination literal dimensions. - // Note: if either src_literal or this literal contains dimensions with zero - // element, then copy_size must be 0 in these dimensions while the - // corresponding base indices being 0. - // This literal and 'src_literal' must be arrays. - Status CopySliceFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); - - // Copies one element from src_literal[src_index] to (*this)[dest_index]. - Status CopyElementFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_index, - tensorflow::gtl::ArraySlice dest_index); - - // Sets an element in the literal at the given index. The multi_index is - // CHECKed against the dimension sizes. - template - void Set(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value); - // Overloads of Set for array literals. CHECKs if the literal is not - // array-shaped and dense. - template - void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); - - // Appends the given element to the literal. If the elements are not appended - // in sorted order, then SortSparseElements should be called before calling - // other methods. This literal must have a sparse layout. - template - void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, - NativeT value, const ShapeIndex& shape_index = {}); - - // Sorts the elements in a sparse array. - void SortSparseElements(const ShapeIndex& shape_index = {}); - - // As Set(), but truncates `value` to the literal element type before storing. - // This literal must be an array. - Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, - int64 value); - - // Populate this literal with the given values. Examples: - // - // // Populate with floats. - // Array2D float_values = ... - // literal.PopulateR2FromArray2D(values); - // - // // Populate with int32s. - // literal.PopulateR2({{1, 2}, {3, 4}}); - // - // The shape and element type of this literal must match given values. For - // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 - // array of S32. - template - void PopulateR1(tensorflow::gtl::ArraySlice values); - void PopulateR1(const tensorflow::core::Bitmap& values); - template - void PopulateR2(std::initializer_list> values); - template - void PopulateFromArray(const Array& values); - template - void PopulateR2FromArray2D(const Array2D& values); - template - void PopulateR3FromArray3D(const Array3D& values); - template - void PopulateR4FromArray4D(const Array4D& values); - - // Populates literal values by calling the generator function for every cell - // in this literal object. - // - // generator must be a callable of the type - // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. - // - // This literal must have a dense layout. - template - Status Populate(const FnType& generator); - - // A parallel version of Populate(). This can be used if the generator is - // thread-safe and the values for the shape's different elements are - // independent. - template - Status PopulateParallel(const FnType& generator); - - // Fills this literal with the given value. - template - void PopulateWithValue(NativeT value); - - // Factory methods below. - // - - // Serialize from a proto. - static StatusOr> CreateFromProto( - const LiteralProto& proto); + static Literal GetFirstScalarLiteral(const LiteralSlice& literal); // Creates a new literal of a given rank. To minimize ambiguity (for users // and the compiler) these CreateR[0-2] methods should explicitly specify the @@ -889,7 +223,7 @@ class Literal : public LiteralBase { // As above, but intended to be invoked with move semantics; i.e. // // std::vector> elements = ...; - // auto result = Literal::MakeTupleOwned(std::move(elements)); + // auto result = LiteralUtil::MakeTupleOwned(std::move(elements)); // // This would have been declared as an overload, but there is ambiguity // in invocation between the above signature and this one. @@ -899,7 +233,7 @@ class Literal : public LiteralBase { // This overload lets you pass a braced list of unique_ptrs to // MakeTupleOwned: // - // Literal::MakeTupleOwned(Literal::CreateR1(...), ...). + // LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...). // // Simply relying on the MakeTupleOwned(std::vector>) // overload doesn't work because std::initializer_list's elements are always @@ -920,19 +254,6 @@ class Literal : public LiteralBase { // Create a constant token literal. Token types have no value. static std::unique_ptr CreateToken(); - // Returns a vector containing the tuple elements of this Literal as separate - // Literals. This Literal must be tuple-shaped and can be a nested tuple. The - // elements are moved into the new Literals; no data is copied. Upon return - // this Literal is set to a nil shape (empty tuple) - std::vector DecomposeTuple(); - - // This operation is the inverse of DecomposeTuple. The given elements are - // moved into the tuple elements of a new tuple-shaped Literal which is - // returned. Upon return, each of the Literals in 'elements' is set to a nil - // shape (empty tuple). - static Literal MoveIntoTuple( - tensorflow::gtl::MutableArraySlice elements); - // Creates a new Literal object with its values havings the primitive_type // type, and with dimensions defined by the dimensions parameter. // The content of the literal values is the default value of the primitive @@ -1000,194 +321,12 @@ class Literal : public LiteralBase { // dimension 1 equal to 8. static string MultiIndexAsString( tensorflow::gtl::ArraySlice multi_index); - - private: - // Recursively sets the subshapes and buffers of all subpieces rooted at - // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in - // the shape. - void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); - - // Returns the piece at the given ShapeIndex. - Piece& piece(const ShapeIndex& shape_index) { - return const_cast(LiteralBase::piece(shape_index)); - } - - Piece& root_piece() const override { return *root_piece_; }; - - // Internal template helper for the Literal::CopySliceFrom(), matching its - // arguments one by one. - template - Status CopySliceFromInternal(const LiteralBase& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); - - // Utility structure which is used to create the optimal configuration for - // a ShapeUtil::ForEachIndex() scan across two literals. - struct StrideConfig { - StrideConfig(const Shape& source_shape, const Shape& dest_shape, - tensorflow::gtl::ArraySlice dimensions); - - // The dimensions of the stride operation. Essentially every dimension - // will be iterated from base[i] to base[i]+dimensions[i], in step[i] - // steps. - tensorflow::gtl::ArraySlice dimensions; - DimensionVector base; - DimensionVector step; - int64 minor_dimension = 0; - // The size of the strides for source and destination. One of the two - // (the one looping through its most minor dimension) will be 1, while - // the other will be the stride size at the dimension matching the other - // shape most minor dimension being scanned. - int64 dest_stride = 1; - int64 source_stride = 1; - // The size of the inner loop on the most minor dimension. - int64 minor_loop_size = 1; - }; - - // Literal class always owns the shape. The parent class borrows this shape. - std::unique_ptr shape_; - - Piece* root_piece_ = nullptr; - - // Implementation details shared between Populate() and PopulateParallel() - template - Status PopulateInternal(const FnType& generator, bool parallel); - - // Deallocate the buffers held by this literal. - void DeallocateBuffers(); - - friend class LiteralBase; -}; -std::ostream& operator<<(std::ostream& out, const Literal& literal); - -// A read-only view of a Literal. A LiteralSlice contains pointers to shape and -// literal buffers always owned by others. -class LiteralSlice : public LiteralBase { - public: - LiteralSlice() : LiteralBase() {} - - // Implicit conversion constructors. - LiteralSlice(const LiteralBase& literal); - LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root); - - private: - const Piece& root_piece() const override { return *root_piece_; }; - - const Piece* root_piece_; // Not owned. -}; - -// A read-only Literal where the underlying buffers are never owned by this -// class. -class BorrowingLiteral : public LiteralBase { - public: - BorrowingLiteral() : LiteralBase() {} - - // 'src_buf_ptr' is not owned by this class and must outlive the - // lifetime of this class. It points to an appropirately sized buffer with - // data interpretered as indicated by 'shape'. - // This constructor is only used for array shapes. - BorrowingLiteral(const char* src_buf_ptr, const Shape& shape); - // Similar as above, except to be used for constructing non-nested tuples. - BorrowingLiteral(tensorflow::gtl::ArraySlice src_buf_ptrs, - const Shape& shape); - // TODO(b/79707221): adding constructors for nested tuples as well. - - private: - // Recursively builds the subtree for the given piece and sets the subshapes - // of the given piece with the given shape. - void BuildPieceSubtree(const Shape& shape, Piece* piece); - - // Accessor for the root piece of this literal. - const Piece& root_piece() const override { return root_piece_; }; - Piece root_piece_; - - // Shape of this literal. Stored as unique_ptr so such that the (default) - // move construction of this class would be trivially correct: the pointer to - // Shape root_piece_ stores will still point to the correct address. - std::unique_ptr shape_; }; -template -tensorflow::gtl::ArraySlice LiteralBase::Piece::data() const { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); - CHECK_EQ(subshape().element_type(), - primitive_util::NativeToPrimitiveType()) - << "Attempting to access " - << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) - << " type, but literal element type is " - << PrimitiveType_Name(subshape().element_type()); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(buffer()), element_count()); -} - -template -tensorflow::gtl::MutableArraySlice LiteralBase::Piece::data() { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); - CHECK_EQ(subshape().element_type(), - primitive_util::NativeToPrimitiveType()) - << "Attempting to access " - << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) - << " type, but literal element type is " - << PrimitiveType_Name(subshape().element_type()); - return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(buffer()), element_count()); -} - -template -NativeT LiteralBase::Piece::Get( - tensorflow::gtl::ArraySlice multi_index) const { - CHECK(LayoutUtil::IsDenseArray(subshape())); - return data()[IndexUtil::MultidimensionalIndexToLinearIndex( - subshape(), multi_index)]; -} - -template -void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice multi_index, - NativeT value) { - CHECK(LayoutUtil::IsDenseArray(subshape())); - data()[IndexUtil::MultidimensionalIndexToLinearIndex( - subshape(), multi_index)] = value; -} - -template -tensorflow::gtl::ArraySlice LiteralBase::data( - const ShapeIndex& shape_index) const { - return piece(shape_index).data(); -} - -template -tensorflow::gtl::MutableArraySlice Literal::data( - const ShapeIndex& shape_index) { - return piece(shape_index).data(); -} - -template -inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const { - return piece(shape_index).Get(multi_index); -} - -template -inline NativeT LiteralBase::Get( - tensorflow::gtl::ArraySlice multi_index) const { - return root_piece().Get(multi_index); -} - -template -inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value) { - return piece(shape_index).Set(multi_index, value); -} - -template -inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, - NativeT value) { - return root_piece().Set(multi_index, value); -} +std::ostream& operator<<(std::ostream& out, const Literal& literal); template -/* static */ std::unique_ptr Literal::CreateR0(NativeT value) { +/* static */ std::unique_ptr LiteralUtil::CreateR0(NativeT value) { auto literal = MakeUnique(ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), {})); literal->Set({}, value); @@ -1195,7 +334,7 @@ template } template -/* static */ std::unique_ptr Literal::CreateR1( +/* static */ std::unique_ptr LiteralUtil::CreateR1( tensorflow::gtl::ArraySlice values) { auto literal = MakeUnique( ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), @@ -1205,7 +344,7 @@ template } template -/* static */ std::unique_ptr Literal::CreateR2WithLayout( +/* static */ std::unique_ptr LiteralUtil::CreateR2WithLayout( std::initializer_list> values, const Layout& layout) { auto literal = MakeUnique(ShapeUtil::MakeShapeWithLayout( @@ -1218,13 +357,13 @@ template } template -/* static */ std::unique_ptr Literal::CreateR2( +/* static */ std::unique_ptr LiteralUtil::CreateR2( std::initializer_list> values) { return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } template -/* static */ std::unique_ptr Literal::CreateR3WithLayout( +/* static */ std::unique_ptr LiteralUtil::CreateR3WithLayout( std::initializer_list>> values, const Layout& layout) { @@ -1249,14 +388,14 @@ template } template -/* static */ std::unique_ptr Literal::CreateR3( +/* static */ std::unique_ptr LiteralUtil::CreateR3( std::initializer_list>> values) { return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); } template -/* static */ std::unique_ptr Literal::CreateR4WithLayout( +/* static */ std::unique_ptr LiteralUtil::CreateR4WithLayout( std::initializer_list>>> values, @@ -1287,7 +426,7 @@ template } template -/* static */ std::unique_ptr Literal::CreateSparse( +/* static */ std::unique_ptr LiteralUtil::CreateSparse( tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, tensorflow::gtl::ArraySlice values, bool sort) { int64 num_elements = values.size(); @@ -1302,7 +441,7 @@ template } template -/* static */ std::unique_ptr Literal::CreateR4( +/* static */ std::unique_ptr LiteralUtil::CreateR4( std::initializer_list>>> values) { @@ -1310,7 +449,7 @@ template } template -/* static */ std::unique_ptr Literal::CreateFromArrayWithLayout( +/* static */ std::unique_ptr LiteralUtil::CreateFromArrayWithLayout( const Array& values, const Layout& layout) { auto literal = MakeUnique(ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), values.dimensions(), @@ -1320,38 +459,40 @@ template } template -/* static */ std::unique_ptr Literal::CreateFromArray( +/* static */ std::unique_ptr LiteralUtil::CreateFromArray( const Array& values) { return CreateFromArrayWithLayout( values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); } template -/* static */ std::unique_ptr Literal::CreateR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout) { +/* static */ std::unique_ptr +LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } template -/* static */ std::unique_ptr Literal::CreateR2FromArray2D( +/* static */ std::unique_ptr LiteralUtil::CreateR2FromArray2D( const Array2D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr Literal::CreateR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout) { +/* static */ std::unique_ptr +LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } template -/* static */ std::unique_ptr Literal::CreateR3FromArray3D( +/* static */ std::unique_ptr LiteralUtil::CreateR3FromArray3D( const Array3D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr Literal::CreateR3Projected( +/* static */ std::unique_ptr LiteralUtil::CreateR3Projected( std::initializer_list> values, int64 projection) { int64 dim0_size = projection; @@ -1376,7 +517,7 @@ template } template -/* static */ std::unique_ptr Literal::CreateR4Projected( +/* static */ std::unique_ptr LiteralUtil::CreateR4Projected( std::initializer_list> values, int64 projection_p, int64 projection_z) { int64 dim0_size = projection_p; @@ -1404,49 +545,21 @@ template } template -/* static */ std::unique_ptr Literal::CreateR4FromArray4D( +/* static */ std::unique_ptr LiteralUtil::CreateR4FromArray4D( const Array4D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr Literal::CreateR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout) { +/* static */ std::unique_ptr +LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } -template -NativeT LiteralBase::GetFirstElement() const { - return data().at(0); -} - -template -NativeT LiteralBase::GetSparseElement(int64 sparse_element_number, - const ShapeIndex& shape_index) const { - CHECK( - LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index))); - return data(shape_index)[sparse_element_number]; -} - -template -void Literal::AppendSparseElement( - tensorflow::gtl::ArraySlice multi_index, NativeT value, - const ShapeIndex& shape_index) { - Piece& p = piece(shape_index); - const Shape& subshape = p.subshape(); - CHECK(LayoutUtil::IsSparseArray(subshape)); - int64 rank = ShapeUtil::Rank(subshape); - CHECK_EQ(multi_index.size(), rank); - int64 last_element = p.sparse_indices()->index_count(); - CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout())); - p.sparse_indices()->Append(multi_index); - CHECK_LT(last_element, p.data().size()); - p.data()[last_element] = value; -} - // Returns an identity matrix (rank 2) with the given row and column count. template -/* static */ std::unique_ptr Literal::MakeIdentityR2(int64 size) { +/* static */ std::unique_ptr LiteralUtil::MakeIdentityR2(int64 size) { Array2D array(size, size, 0); for (int64 i = 0; i < size; ++i) { array(i, i) = 1; @@ -1455,174 +568,8 @@ template } template -void LiteralBase::EachCell( - std::function indices, - NativeT value)> - per_cell) const { - if (ShapeUtil::IsZeroElementArray(shape())) { - return; - } - std::vector indices(ShapeUtil::Rank(shape()), 0); - do { - per_cell(indices, Get(indices)); - } while (IndexUtil::BumpIndices(shape(), &indices)); -} - -template -inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice values) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 1); - CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); - CHECK_EQ(shape().element_type(), - primitive_util::NativeToPrimitiveType()); - for (int64 i = 0; i < values.size(); ++i) { - Set({i}, values[i]); - } -} - -template -void Literal::PopulateR2( - std::initializer_list> values) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(ShapeUtil::Rank(shape()), 2); - CHECK_EQ(shape().element_type(), - primitive_util::NativeToPrimitiveType()); - - const int64 dim0_size = values.size(); - const int64 dim1_size = values.begin()->size(); - CHECK_EQ(dim0_size, shape().dimensions(0)); - CHECK_EQ(dim1_size, shape().dimensions(1)); - - int64 dim0 = 0; - for (auto inner_list : values) { - int64 dim1 = 0; - for (auto value : inner_list) { - Set({dim0, dim1}, value); - ++dim1; - } - CHECK_EQ(dim1_size, dim1); - ++dim0; - } -} - -template -void Literal::PopulateFromArray(const Array& values) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(shape().element_type(), - primitive_util::NativeToPrimitiveType()); - CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions()); - for (int dim = 0; dim < values.num_dimensions(); ++dim) { - CHECK_EQ(values.dim(dim), shape().dimensions(dim)); - } - values.Each([this](tensorflow::gtl::ArraySlice indices, - NativeT value) { this->Set(indices, value); }); -} - -template -void Literal::PopulateR2FromArray2D(const Array2D& values) { - PopulateFromArray(values); -} - -template -void Literal::PopulateR3FromArray3D(const Array3D& values) { - PopulateFromArray(values); -} - -template -void Literal::PopulateR4FromArray4D(const Array4D& values) { - PopulateFromArray(values); -} - -template -void Literal::PopulateSparse(SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, - bool sort) { - CHECK(LayoutUtil::IsSparseArray(shape())); - int rank = ShapeUtil::Rank(shape()); - CHECK_EQ(indices.rank(), rank); - int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout()); - CHECK_LE(indices.max_indices(), max_elements); - int64 num_elements = values.size(); - CHECK_LE(num_elements, max_elements); - CHECK_EQ(num_elements, indices.index_count()); - auto root_data = root_piece().data(); - // Piece::data() returns an ArraySlice of size equal to the number of indices - // in the SparseIndexArray. So there is no need to adjust the size of the data - // here. It is enough to just copy the incoming values into the data buffer. - std::copy(values.begin(), values.end(), root_data.begin()); - *this->root_piece().sparse_indices() = std::move(indices); - if (sort) { - auto root_data = this->root_piece().data(); - this->root_piece().sparse_indices()->SortWithValues(root_data); - } - DCHECK(this->root_piece().sparse_indices()->Validate(shape())); -} - -template -Status Literal::PopulateInternal(const FnType& generator, bool parallel) { - const Shape& this_shape = shape(); - const int64 rank = ShapeUtil::Rank(this_shape); - TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); - TF_RET_CHECK(this_shape.element_type() == - primitive_util::NativeToPrimitiveType()); - tensorflow::gtl::MutableArraySlice literal_data = data(); - if (rank > 0) { - StrideConfig stride_config(this_shape, this_shape, - AsInt64Slice(this_shape.dimensions())); - int64 minor_dimension_size = - ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); - - auto init_function = [&](tensorflow::gtl::ArraySlice indexes) { - DimensionVector minor_scan_indexes(rank, 0); - const int64 index = - IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes); - std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin()); - for (int64 i = 0; i < minor_dimension_size; ++i) { - minor_scan_indexes[stride_config.minor_dimension] = i; - literal_data.at(index + i) = generator(minor_scan_indexes); - } - }; - if (parallel) { - ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base, - stride_config.dimensions, - stride_config.step, init_function); - } else { - ShapeUtil::ForEachIndex( - this_shape, stride_config.base, stride_config.dimensions, - stride_config.step, - [&init_function](tensorflow::gtl::ArraySlice indexes) { - init_function(indexes); - return true; - }); - } - } else { - // For scalars. - literal_data.at(0) = generator({}); - } - return Status::OK(); -} -template -Status Literal::Populate(const FnType& generator) { - return PopulateInternal(generator, /*parallel=*/false); -} - -template -Status Literal::PopulateParallel(const FnType& generator) { - return PopulateInternal(generator, /*parallel=*/true); -} - -template -void Literal::PopulateWithValue(NativeT value) { - CHECK(ShapeUtil::IsArray(shape())); - CHECK_EQ(shape().element_type(), - primitive_util::NativeToPrimitiveType()); - for (NativeT& element : data()) { - element = value; - } -} - -template -/* static */ std::unique_ptr Literal::CreateFullWithDescendingLayout( +/* static */ std::unique_ptr +LiteralUtil::CreateFullWithDescendingLayout( tensorflow::gtl::ArraySlice dimensions, NativeT value) { auto literal = MakeUnique(ShapeUtil::MakeShapeWithDescendingLayout( primitive_util::NativeToPrimitiveType(), dimensions)); @@ -1630,44 +577,9 @@ template return literal; } -template -std::unique_ptr LiteralBase::Replicate(int64 times) const { - DimensionVector bounds = {times}; - bounds.reserve(shape().dimensions_size() + 1); - for (int64 bound : shape().dimensions()) { - bounds.push_back(bound); - } - auto literal = - MakeUnique(ShapeUtil::MakeShape(shape().element_type(), bounds)); - int64 elements = ShapeUtil::ElementsIn(literal->shape()); - if (elements == 0) { - return literal; - } - - DimensionVector output_indices(bounds.size(), 0); - tensorflow::gtl::ArraySlice input_indices = output_indices; - input_indices.remove_prefix(1); - - bool done = false; - while (!done) { - const auto element = Get(input_indices); - literal->Set(output_indices, element); - - done = true; - for (int n = 0; n < output_indices.size(); ++n) { - ++output_indices[n]; - if (output_indices[n] < bounds[n]) { - done = false; - break; - } - output_indices[n] = 0; - } - } - return literal; -} - template -/* static */ StatusOr> Literal::CreateRandomLiteral( +/* static */ StatusOr> +LiteralUtil::CreateRandomLiteral( const Shape& shape, const std::function)>& generator) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; @@ -1681,8 +593,9 @@ template } template -/* static */ StatusOr> Literal::CreateRandomLiteral( - const Shape& shape, E* engine, T mean, T stddev) { +/* static */ StatusOr> +LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, + T stddev) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; std::normal_distribution generator(mean, stddev); return CreateRandomLiteral( @@ -1692,8 +605,8 @@ template } template -/* static */ StatusOr> Literal::CreateRandomLiteral( - const Shape& shape, T mean, T stddev) { +/* static */ StatusOr> +LiteralUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) { std::minstd_rand0 engine; return CreateRandomLiteral(shape, &engine, mean, stddev); } diff --git a/tensorflow/compiler/xla/metric_table_report.cc b/tensorflow/compiler/xla/metric_table_report.cc index fed0e58e66a04df2ff9554cb0dd0053b7c669803..69ef4f7a2f3ea559a334a11cbe8392b610742bab 100644 --- a/tensorflow/compiler/xla/metric_table_report.cc +++ b/tensorflow/compiler/xla/metric_table_report.cc @@ -134,8 +134,7 @@ void MetricTableReport::AppendHeader() { void MetricTableReport::AppendCategoryTable() { const std::vector categories = MakeCategories(&entries_); - AppendLine("********** categories table **********"); - AppendLine("The left hand side numbers are ", metric_name_, "."); + AppendLine("********** categories table for ", metric_name_, " **********"); AppendLine(); double metric_sum = UnaccountedMetric(); @@ -185,8 +184,8 @@ void MetricTableReport::AppendCategoryTable() { } void MetricTableReport::AppendEntryTable() { - AppendLine("********** ", entry_name_, " table **********"); - AppendLine("The left hand side numbers are ", metric_name_, "."); + AppendLine("********** ", entry_name_, " table for ", metric_name_, + " **********"); AppendLine(); double metric_sum = UnaccountedMetric(); diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 857aae0a7982a57bb3057a6f267f5f033a0fdde4..6b7fd10d63f8f97b0e0bf7570488c06323368d75 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/packed_literal_reader.h b/tensorflow/compiler/xla/packed_literal_reader.h index 45a9fe012784d3e4168e7549240dec962aa1a17a..98dccaa9a246520bf60217b96d67a13a24c34b4a 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.h +++ b/tensorflow/compiler/xla/packed_literal_reader.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 83834c1ff65ea2f9989fe08279c29056d9070adb..c8f2d65c223ccfe20862954c224d016cca421812 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -33,6 +33,7 @@ cc_library( srcs = ["numpy_bridge.cc"], hdrs = ["numpy_bridge.h"], deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", @@ -52,9 +53,9 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", - "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", @@ -70,7 +71,7 @@ tf_py_wrap_cc( deps = [ ":local_computation_builder", ":numpy_bridge", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:cpu_plugin", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index b5ba4e2d429e465649fc1b7acaf19fcb75f6d1ef..8246f76d3443d58f4174cc4f86100f54d6b46928 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -14,7 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/python/local_computation_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" @@ -485,6 +486,11 @@ LocalOp LocalComputationBuilder::ConvertElementType( return xla::ConvertElementType(operand.op(), new_element_type); } +LocalOp LocalComputationBuilder::BitcastConvertType( + const LocalOp& operand, PrimitiveType new_element_type) { + return xla::BitcastConvertType(operand.op(), new_element_type); +} + LocalOp LocalComputationBuilder::Call( const LocalComputation& local_computation, tensorflow::gtl::ArraySlice operands) { @@ -613,6 +619,12 @@ _FORWARD_BINOP(Min) _FORWARD_BINOP(And) _FORWARD_BINOP(Or) _FORWARD_BINOP(Xor) +_FORWARD_BINOP(ShiftLeft) +_FORWARD_BINOP(ShiftRightArithmetic) +_FORWARD_BINOP(ShiftRightLogical) +_FORWARD_BINOP(Atan2) +_FORWARD_BINOP(Pow) +_FORWARD_BINOP(Complex) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) @@ -626,13 +638,30 @@ _FORWARD_UNOP(Sign) _FORWARD_UNOP(Cos) _FORWARD_UNOP(Sin) _FORWARD_UNOP(Tanh) -_FORWARD_UNOP(SqrtF32) -_FORWARD_UNOP(SquareF32) -_FORWARD_BINOP(Pow) _FORWARD_UNOP(IsFinite) -_FORWARD_UNOP(ReciprocalF32) _FORWARD_UNOP(Neg) _FORWARD_UNOP(Sort) +_FORWARD_UNOP(Sqrt) +_FORWARD_UNOP(Rsqrt) +_FORWARD_UNOP(Square) +_FORWARD_UNOP(Reciprocal) +_FORWARD_UNOP(Erfc) +_FORWARD_UNOP(Erf) +_FORWARD_UNOP(ErfInv) +_FORWARD_UNOP(Lgamma) +_FORWARD_UNOP(Digamma) +_FORWARD_UNOP(Acos) +_FORWARD_UNOP(Asin) +_FORWARD_UNOP(Atan) +_FORWARD_UNOP(Tan) +_FORWARD_UNOP(Acosh) +_FORWARD_UNOP(Asinh) +_FORWARD_UNOP(Atanh) +_FORWARD_UNOP(Cosh) +_FORWARD_UNOP(Sinh) +_FORWARD_UNOP(Real) +_FORWARD_UNOP(Imag) +_FORWARD_UNOP(Conj) #undef _FORWARD #undef _FORWARD_UNOP diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index e920f8aecd6cfc6fd4c965b1cc9eceb36b2d7371..a568c24c6376e1fe17f5e5a4f6626bf0970985a3 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -259,6 +259,9 @@ class LocalComputationBuilder { LocalOp ConvertElementType(const LocalOp& operand, PrimitiveType new_element_type); + LocalOp BitcastConvertType(const LocalOp& operand, + PrimitiveType new_element_type); + LocalOp Call(const LocalComputation& local_computation, tensorflow::gtl::ArraySlice operands); @@ -333,6 +336,12 @@ class LocalComputationBuilder { _FORWARD_BINOP(And) _FORWARD_BINOP(Or) _FORWARD_BINOP(Xor) + _FORWARD_BINOP(ShiftLeft) + _FORWARD_BINOP(ShiftRightArithmetic) + _FORWARD_BINOP(ShiftRightLogical) + _FORWARD_BINOP(Atan2) + _FORWARD_BINOP(Pow) + _FORWARD_BINOP(Complex) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) @@ -346,13 +355,30 @@ class LocalComputationBuilder { _FORWARD_UNOP(Cos) _FORWARD_UNOP(Sin) _FORWARD_UNOP(Tanh) - _FORWARD_UNOP(SqrtF32) - _FORWARD_UNOP(SquareF32) - _FORWARD_BINOP(Pow) _FORWARD_UNOP(IsFinite) - _FORWARD_UNOP(ReciprocalF32) _FORWARD_UNOP(Neg) _FORWARD_UNOP(Sort) + _FORWARD_UNOP(Sqrt) + _FORWARD_UNOP(Rsqrt) + _FORWARD_UNOP(Square) + _FORWARD_UNOP(Reciprocal) + _FORWARD_UNOP(Erfc) + _FORWARD_UNOP(Erf) + _FORWARD_UNOP(ErfInv) + _FORWARD_UNOP(Lgamma) + _FORWARD_UNOP(Digamma) + _FORWARD_UNOP(Acos) + _FORWARD_UNOP(Asin) + _FORWARD_UNOP(Atan) + _FORWARD_UNOP(Tan) + _FORWARD_UNOP(Acosh) + _FORWARD_UNOP(Asinh) + _FORWARD_UNOP(Atanh) + _FORWARD_UNOP(Cosh) + _FORWARD_UNOP(Sinh) + _FORWARD_UNOP(Real) + _FORWARD_UNOP(Imag) + _FORWARD_UNOP(Conj) #undef _FORWARD #undef _FORWARD_UNOP diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 76e9e637cd45509ec443be092fd9934db1a9653f..5d5a955bfee35b38a61b9a9f792c1b31259ce044 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -109,7 +109,7 @@ limitations under the License. // Must be included first #include "tensorflow/python/lib/core/numpy.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -957,6 +957,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Tuple; %unignore xla::swig::LocalComputationBuilder::GetTupleElement; %unignore xla::swig::LocalComputationBuilder::ConvertElementType; +%unignore xla::swig::LocalComputationBuilder::BitcastConvertType; %unignore xla::swig::LocalComputationBuilder::Call; %unignore xla::swig::LocalComputationBuilder::Transpose; %unignore xla::swig::LocalComputationBuilder::Rev; @@ -989,6 +990,9 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::And; %unignore xla::swig::LocalComputationBuilder::Or; %unignore xla::swig::LocalComputationBuilder::Xor; +%unignore xla::swig::LocalComputationBuilder::ShiftLeft; +%unignore xla::swig::LocalComputationBuilder::ShiftRightArithmetic; +%unignore xla::swig::LocalComputationBuilder::ShiftRightLogical; %unignore xla::swig::LocalComputationBuilder::Not; %unignore xla::swig::LocalComputationBuilder::Abs; %unignore xla::swig::LocalComputationBuilder::Exp; @@ -1002,13 +1006,33 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Cos; %unignore xla::swig::LocalComputationBuilder::Sin; %unignore xla::swig::LocalComputationBuilder::Tanh; -%unignore xla::swig::LocalComputationBuilder::SqrtF32; -%unignore xla::swig::LocalComputationBuilder::SquareF32; -%unignore xla::swig::LocalComputationBuilder::Pow; +%unignore xla::swig::LocalComputationBuilder::Atan2; %unignore xla::swig::LocalComputationBuilder::IsFinite; -%unignore xla::swig::LocalComputationBuilder::ReciprocalF32; +%unignore xla::swig::LocalComputationBuilder::Pow; %unignore xla::swig::LocalComputationBuilder::Neg; %unignore xla::swig::LocalComputationBuilder::Sort; +%unignore xla::swig::LocalComputationBuilder::Sqrt; +%unignore xla::swig::LocalComputationBuilder::Rsqrt; +%unignore xla::swig::LocalComputationBuilder::Square; +%unignore xla::swig::LocalComputationBuilder::Reciprocal; +%unignore xla::swig::LocalComputationBuilder::Erfc; +%unignore xla::swig::LocalComputationBuilder::Erf; +%unignore xla::swig::LocalComputationBuilder::ErfInv; +%unignore xla::swig::LocalComputationBuilder::Lgamma; +%unignore xla::swig::LocalComputationBuilder::Digamma; +%unignore xla::swig::LocalComputationBuilder::Acos; +%unignore xla::swig::LocalComputationBuilder::Asin; +%unignore xla::swig::LocalComputationBuilder::Atan; +%unignore xla::swig::LocalComputationBuilder::Tan; +%unignore xla::swig::LocalComputationBuilder::Acosh; +%unignore xla::swig::LocalComputationBuilder::Asinh; +%unignore xla::swig::LocalComputationBuilder::Atanh; +%unignore xla::swig::LocalComputationBuilder::Cosh; +%unignore xla::swig::LocalComputationBuilder::Sinh; +%unignore xla::swig::LocalComputationBuilder::Real; +%unignore xla::swig::LocalComputationBuilder::Imag; +%unignore xla::swig::LocalComputationBuilder::Conj; +%unignore xla::swig::LocalComputationBuilder::Complex; %unignore xla::swig::DestructureLocalShapedBufferTuple; %unignore xla::swig::DeleteLocalShapedBuffer; %unignore xla::swig::DeleteLocalComputation; diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index 68648a3a176363de69a56ecb8070f82862874e94..6f665faf61b25b23a32ce4d0a012543ba18d7e64 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/python/numpy_bridge.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/logging.h" @@ -49,6 +50,8 @@ int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) { return NPY_FLOAT32; case F64: return NPY_FLOAT64; + case C64: + return NPY_COMPLEX64; case TUPLE: return NPY_OBJECT; default: @@ -82,6 +85,8 @@ PrimitiveType NumpyTypeToPrimitiveType(int np_type) { return F32; case NPY_FLOAT64: return F64; + case NPY_COMPLEX64: + return C64; case NPY_OBJECT: return TUPLE; default: @@ -103,6 +108,7 @@ bool NumpyTypeIsValid(int np_type) { case NPY_FLOAT16: case NPY_FLOAT32: case NPY_FLOAT64: + case NPY_COMPLEX64: case NPY_OBJECT: return true; default: @@ -374,7 +380,7 @@ StatusOr> XlaLiteralFromPyObject(PyObject* o) { TF_ASSIGN_OR_RETURN(auto literal, XlaLiteralFromPyObject(element)); elements.push_back(std::move(literal)); } - return Literal::MakeTupleOwned(std::move(elements)); + return LiteralUtil::MakeTupleOwned(std::move(elements)); } else if (PyArray_Check(o)) { PyArrayObject* py_array = reinterpret_cast(o); int rank = PyArray_NDIM(py_array); @@ -383,7 +389,7 @@ StatusOr> XlaLiteralFromPyObject(PyObject* o) { dimensions[i] = PyArray_DIM(py_array, i); } int np_type = PyArray_TYPE(py_array); - auto literal = Literal::CreateFromDimensions( + auto literal = LiteralUtil::CreateFromDimensions( NumpyTypeToPrimitiveType(np_type), dimensions); TF_RETURN_IF_ERROR( CopyNumpyArrayToLiteral(np_type, py_array, literal.get())); @@ -424,6 +430,9 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, case NPY_FLOAT64: CopyNumpyArrayToLiteral(py_array, literal); break; + case NPY_COMPLEX64: + CopyNumpyArrayToLiteral(py_array, literal); + break; default: return InvalidArgument( "No XLA literal container for Numpy type number: %d", np_type); @@ -461,6 +470,9 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, case NPY_FLOAT64: CopyLiteralToNumpyArray(literal, py_array); break; + case NPY_COMPLEX64: + CopyLiteralToNumpyArray(literal, py_array); + break; default: LOG(FATAL) << "No XLA literal container for Numpy type" << np_type; } diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index 64f0aae0f9790f0199ac6cb931a5c9f6dc356f4c..a67c93a4fb7413f9bbcb9afd92c36fd118836e1f 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -25,7 +25,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/python/lib/core/numpy.h" diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index abb97d0c6fae515b8f1c11c7df48299f05fc9fad..a2c6fc344d192265d536ef7e23ad5c6d7c847014 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -99,12 +99,30 @@ _UNARY_OPS = [ 'Cos', 'Sin', 'Tanh', - 'SqrtF32', - 'SquareF32', 'IsFinite', - 'ReciprocalF32', + 'Sqrt', + 'Rsqrt', + 'Square', + 'Reciprocal', 'Neg', 'Sort', + 'Erf', + 'Erfc', + 'ErfInv', + 'Lgamma', + 'Digamma', + 'Acos', + 'Asin', + 'Atan', + 'Tan', + 'Acosh', + 'Asinh', + 'Atanh', + 'Cosh', + 'Sinh', + 'Real', + 'Imag', + 'Conj', ] _BINARY_OPS = [ @@ -125,6 +143,11 @@ _BINARY_OPS = [ 'Or', 'Xor', 'Pow', + 'ShiftLeft', + 'ShiftRightArithmetic', + 'ShiftRightLogical', + 'Atan2', + 'Complex', ] @@ -461,14 +484,16 @@ class LocalComputation(object): if self.is_compiled: raise ValueError('Attempt to compile a compiled local XLA computation.') + result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape()) + if layout_fn: argument_shapes = [ shape.map_leaves(layout_fn) for shape in argument_shapes ] - result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape()) result_shape = result_shape.map_leaves(layout_fn) - compile_options = compile_options or CompileOptions() - compile_options.result_shape = result_shape + + compile_options = compile_options or CompileOptions() + compile_options.result_shape = result_shape return LocalComputation( self.c_local_computation.Compile(argument_shapes, compile_options), is_compiled=True) @@ -700,6 +725,18 @@ class ComputationBuilder(object): """ return self._client.ConvertElementType(operand, new_element_type) + def BitcastConvertType(self, operand, new_element_type): + """Enqueues a bitcast type conversion operation onto the computation. + + Args: + operand: the operand to convert. + new_element_type: the target primitive type. + + Returns: + A LocalOp representing the added conversion op. + """ + return self._client.BitcastConvertType(operand, new_element_type) + def GetShape(self, operand): return _wrap_shape(self._client.GetShape(operand)) diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 0564ddcb85ee3952f82649687e79a864999baf2c..fd98e19457f61aade947aa354d2e415148d127f6 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -171,6 +171,24 @@ class ComputationsWithConstantsTest(LocalComputationTest): c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) + def testShiftLeft(self): + c = self._NewComputation() + c.ShiftLeft(c.Constant(NumpyArrayS32([3])), + c.Constant(NumpyArrayS32([2]))) + self._ExecuteAndCompareClose(c, expected=[12]) + + def testShiftRightArithmetic(self): + c = self._NewComputation() + c.ShiftRightArithmetic(c.Constant(NumpyArrayS32([-2])), + c.Constant(NumpyArrayS32([1]))) + self._ExecuteAndCompareClose(c, expected=[-1]) + + def testShiftRightLogical(self): + c = self._NewComputation() + c.ShiftRightLogical(c.Constant(NumpyArrayS32([-1])), + c.Constant(NumpyArrayS32([1]))) + self._ExecuteAndCompareClose(c, expected=[2**31 - 1]) + def testGetProto(self): c = self._NewComputation() c.Add( @@ -471,6 +489,34 @@ class SingleOpTest(LocalComputationTest): for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): _ConvertAndTest(x, src_dtype, dst_dtype) + def testBitcastConvertType(self): + xla_x32_types = { + np.int32: xla_client.xla_data_pb2.S32, + np.float32: xla_client.xla_data_pb2.F32, + } + + xla_x64_types = { + np.int64: xla_client.xla_data_pb2.S64, + np.float64: xla_client.xla_data_pb2.F64, + } + + def _ConvertAndTest(template, src_dtype, dst_dtype, dst_etype): + c = self._NewComputation() + x = c.Constant(np.array(template, dtype=src_dtype)) + c.BitcastConvertType(x, dst_etype) + + result = c.Build().Compile().Execute() + expected = np.array(template, src_dtype).view(dst_dtype) + + self.assertEqual(result.shape, expected.shape) + self.assertEqual(result.dtype, expected.dtype) + np.testing.assert_equal(result, expected) + + x = [0, 1, 0, 0, 1] + for xla_types in [xla_x32_types, xla_x64_types]: + for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): + _ConvertAndTest(x, src_dtype, dst_dtype, xla_types[dst_dtype]) + def testCrossReplicaSumOneReplica(self): samples = [ NumpyArrayF32(42.0), diff --git a/tensorflow/compiler/xla/python_api/BUILD b/tensorflow/compiler/xla/python_api/BUILD index 8999cda5ef852d1246bea45a3312575ec1ac0721..d790c4db6c466a2bf4d2cf30365749fb901f74a0 100644 --- a/tensorflow/compiler/xla/python_api/BUILD +++ b/tensorflow/compiler/xla/python_api/BUILD @@ -10,6 +10,8 @@ py_library( srcs = ["types.py"], deps = [ "//tensorflow/compiler/xla:xla_data_proto_py", + "//tensorflow/python:dtypes", + "//tensorflow/python:platform", "//third_party/py/numpy", ], ) diff --git a/tensorflow/compiler/xla/python_api/types.py b/tensorflow/compiler/xla/python_api/types.py index b60f8dce92ace1b2c682374a2605b3a477936bbc..57dfce3971b829d2a3052d347e5d2d322db0c841 100644 --- a/tensorflow/compiler/xla/python_api/types.py +++ b/tensorflow/compiler/xla/python_api/types.py @@ -20,9 +20,10 @@ from __future__ import print_function import collections -import numpy as np +import numpy as _np # Avoids becoming a part of public Tensorflow API. from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.python.framework import dtypes # Records corresponsence between a XLA primitive type and Python/Numpy types. # @@ -40,76 +41,82 @@ TypeConversionRecord = collections.namedtuple('TypeConversionRecord', [ # Maps from XLA primitive types to TypeConversionRecord. MAP_XLA_TYPE_TO_RECORD = { + xla_data_pb2.BF16: + TypeConversionRecord( + primitive_type=xla_data_pb2.BF16, + numpy_dtype=dtypes.bfloat16.as_numpy_dtype, + literal_field_name='bf16s', + literal_field_type=float), xla_data_pb2.F16: TypeConversionRecord( primitive_type=xla_data_pb2.F16, - numpy_dtype=np.float16, + numpy_dtype=_np.float16, literal_field_name='f16s', literal_field_type=float), xla_data_pb2.F32: TypeConversionRecord( primitive_type=xla_data_pb2.F32, - numpy_dtype=np.float32, + numpy_dtype=_np.float32, literal_field_name='f32s', literal_field_type=float), xla_data_pb2.F64: TypeConversionRecord( primitive_type=xla_data_pb2.F64, - numpy_dtype=np.float64, + numpy_dtype=_np.float64, literal_field_name='f64s', literal_field_type=float), xla_data_pb2.S8: TypeConversionRecord( primitive_type=xla_data_pb2.S8, - numpy_dtype=np.int8, + numpy_dtype=_np.int8, literal_field_name='s8s', literal_field_type=int), xla_data_pb2.S16: TypeConversionRecord( primitive_type=xla_data_pb2.S16, - numpy_dtype=np.int16, + numpy_dtype=_np.int16, literal_field_name='s16s', literal_field_type=int), xla_data_pb2.S32: TypeConversionRecord( primitive_type=xla_data_pb2.S32, - numpy_dtype=np.int32, + numpy_dtype=_np.int32, literal_field_name='s32s', literal_field_type=int), xla_data_pb2.S64: TypeConversionRecord( primitive_type=xla_data_pb2.S64, - numpy_dtype=np.int64, + numpy_dtype=_np.int64, literal_field_name='s64s', literal_field_type=int), xla_data_pb2.U8: TypeConversionRecord( primitive_type=xla_data_pb2.U8, - numpy_dtype=np.uint8, + numpy_dtype=_np.uint8, literal_field_name='s8s', literal_field_type=int), xla_data_pb2.U16: TypeConversionRecord( primitive_type=xla_data_pb2.U16, - numpy_dtype=np.uint16, + numpy_dtype=_np.uint16, literal_field_name='s16s', literal_field_type=int), xla_data_pb2.U32: TypeConversionRecord( primitive_type=xla_data_pb2.U32, - numpy_dtype=np.uint32, + numpy_dtype=_np.uint32, literal_field_name='s32s', literal_field_type=int), xla_data_pb2.U64: TypeConversionRecord( primitive_type=xla_data_pb2.U64, - numpy_dtype=np.uint64, + numpy_dtype=_np.uint64, literal_field_name='s64s', literal_field_type=int), xla_data_pb2.PRED: TypeConversionRecord( primitive_type=xla_data_pb2.PRED, - numpy_dtype=np.bool, + numpy_dtype=_np.bool, literal_field_name='preds', literal_field_type=bool) } @@ -119,6 +126,6 @@ MAP_XLA_TYPE_TO_RECORD = { # doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, # when keying by dtype in this dict, we use the string form of dtypes. MAP_DTYPE_TO_RECORD = { - str(np.dtype(record.numpy_dtype)): record + str(_np.dtype(record.numpy_dtype)): record for record in MAP_XLA_TYPE_TO_RECORD.values() } diff --git a/tensorflow/compiler/xla/python_api/xla_literal.py b/tensorflow/compiler/xla/python_api/xla_literal.py index b040098c294ffaae92b72f678947f99289239314..757e41a78ad2b57d2ef6e1f3055160be22c7b3ed 100644 --- a/tensorflow/compiler/xla/python_api/xla_literal.py +++ b/tensorflow/compiler/xla/python_api/xla_literal.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np +import numpy as _np # Avoids becoming a part of public Tensorflow API. from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python_api import types @@ -35,7 +35,7 @@ def ConvertLiteralToNumpyArray(literal): type_record = types.MAP_XLA_TYPE_TO_RECORD[element_type] if not literal.shape.dimensions: - return np.array( + return _np.array( getattr(literal, type_record.literal_field_name)[0], type_record.numpy_dtype) else: @@ -54,7 +54,7 @@ def ConvertLiteralToNumpyArray(literal): numpy_reshaper = lambda arr: arr.reshape(numpy_shape, order='C') else: raise NotImplementedError('Unsupported layout: {0}'.format(layout_order)) - ndarray = np.array( + ndarray = _np.array( getattr(literal, type_record.literal_field_name), copy=False, dtype=type_record.numpy_dtype) @@ -69,11 +69,11 @@ def _ConvertNumpyArrayToLiteral(ndarray): if ndarray.ndim == 0: getattr(literal, type_record.literal_field_name).append( - np.asscalar(ndarray.astype(type_record.literal_field_type))) + _np.asscalar(ndarray.astype(type_record.literal_field_type))) else: # Ndarrays with boolean dtypes need special type conversion with protobufs - if ndarray.dtype in {np.bool_, np.dtype('bool')}: - for element in np.nditer(ndarray): + if ndarray.dtype in {_np.bool_, _np.dtype('bool')}: + for element in _np.nditer(ndarray): getattr(literal, type_record.literal_field_name).append( type_record.literal_field_type(element)) else: diff --git a/tensorflow/compiler/xla/python_api/xla_shape.py b/tensorflow/compiler/xla/python_api/xla_shape.py index 6af28958035bbb03e7e1dbb0d0c7bb2c2f25b96d..f158f6b2410352432445f669155aff0af5526abf 100644 --- a/tensorflow/compiler/xla/python_api/xla_shape.py +++ b/tensorflow/compiler/xla/python_api/xla_shape.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np +import numpy as _np # Avoids becoming a part of public Tensorflow API. from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python_api import types @@ -111,7 +111,7 @@ def _CreateShapeFromNumpy(ndarray): # pylint: disable=invalid-name # Set the shape's layout based on the ordering of ndarray. # Numpy arrays come in two orders: Fortran (column-major) and C (row-major). - if np.isfortran(ndarray): + if _np.isfortran(ndarray): # Column-major layout. This corresponds to a "dimension order is # minor-to-major" layout in XLA. layout = range(ndarray.ndim) diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index c289c84cff743871a7126cb932d6cda823ceb696..a803520876952a0ab67ecb827b1f256c915335f9 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -18,7 +18,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -510,8 +511,8 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( std::pair lhs_dilation, std::pair rhs_dilation, ConvolutionDimensionNumbers dnums) { HloComputation::Builder b("ConvArray4DGeneralDimensionDilated"); - auto lhs_literal = Literal::CreateR4FromArray4D(lhs); - auto rhs_literal = Literal::CreateR4FromArray4D(rhs); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(lhs); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(rhs); std::array ordered_kernel_strides; std::array ordered_input_dimensions; diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index 9da9bc60a2025e63b57a3be9ed360d150f88d73c..8091bed4996a753649a5ecedda69a1ae48fb5897 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -53,7 +53,7 @@ class ReferenceUtilTest : public ::testing::Test { TEST_F(ReferenceUtilTest, TransposeArray2D) { auto result = ReferenceUtil::TransposeArray2D(*matrix_); - auto actual_literal = Literal::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}}, *actual_literal, ErrorSpec(0.0001)); } @@ -65,7 +65,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) { {11.f, 12.f}, }); auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs); - auto actual_literal = Literal::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{58.f, 64.f}, {139.f, 154.f}}, *actual_literal, ErrorSpec(0.0001)); } @@ -73,7 +73,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) { TEST_F(ReferenceUtilTest, ReduceToColArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add); - auto actual_literal = Literal::CreateR1(*result); + auto actual_literal = LiteralUtil::CreateR1(*result); LiteralTestUtil::ExpectR1Near({6.f, 15.f}, *actual_literal, ErrorSpec(0.0001)); } @@ -81,13 +81,13 @@ TEST_F(ReferenceUtilTest, ReduceToColArray2D) { TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add); - auto actual_literal = Literal::CreateR1(*result); + auto actual_literal = LiteralUtil::CreateR1(*result); LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) { - auto result = Literal::CreateR1(ReferenceUtil::Reduce4DTo1D( + auto result = LiteralUtil::CreateR1(ReferenceUtil::Reduce4DTo1D( Array4D(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2}, [](float a, float b) { return a + b; })); LiteralTestUtil::ExpectR1Equal({0}, *result); @@ -96,7 +96,7 @@ TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) { TEST_F(ReferenceUtilTest, MapArray2D) { auto identity = [](float value) { return log(exp(value)); }; auto result = ReferenceUtil::MapArray2D(*matrix_, identity); - auto actual_literal = Literal::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal, ErrorSpec(0.0001)); } @@ -106,7 +106,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) { return value + row + col; }; auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index); - auto actual_literal = Literal::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}}, *actual_literal, ErrorSpec(0.0001)); } @@ -117,7 +117,7 @@ TEST_F(ReferenceUtilTest, MapArray4D) { input->FillWithMultiples(1.0f); auto multiply_by_two = [](float value) { return 2 * value; }; auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two); - auto actual_literal = Literal::CreateR4FromArray4D(*result); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.FillWithMultiples(2.0f); @@ -134,7 +134,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { return value - (3 * 4 * 5 * plane + 4 * 5 * depth + 5 * height + width); }; auto result = ReferenceUtil::MapWithIndexArray4D(*input, subtract_index); - auto actual_literal = Literal::CreateR4FromArray4D(*result); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.Fill(0.0f); @@ -144,7 +144,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { TEST_F(ReferenceUtilTest, SliceArray2D) { auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}}); - auto actual_literal = Literal::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 2.f}, {4.f, 5.f}}, *actual_literal, ErrorSpec(0.0001)); @@ -152,7 +152,7 @@ TEST_F(ReferenceUtilTest, SliceArray2D) { TEST_F(ReferenceUtilTest, SliceStridedArray2D) { auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}}); - auto actual_literal = Literal::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 3.f}, {4.f, 6.f}}, *actual_literal, ErrorSpec(0.0001)); @@ -164,7 +164,7 @@ TEST_F(ReferenceUtilTest, SliceArray3D) { auto result = ReferenceUtil::Slice3D(input, {{0, 0, 0}}, {{2, 2, 2}}, {{1, 1, 1}}); - auto actual_literal = Literal::CreateR3FromArray3D(*result); + auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result); LiteralTestUtil::ExpectR3Near( {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal, @@ -177,7 +177,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray3D) { auto result = ReferenceUtil::Slice3D(input, {{0, 0, 0}}, {{2, 3, 4}}, {{1, 2, 2}}); - auto actual_literal = Literal::CreateR3FromArray3D(*result); + auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result); LiteralTestUtil::ExpectR3Near( {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, @@ -190,7 +190,7 @@ TEST_F(ReferenceUtilTest, SliceArray4D) { auto result = ReferenceUtil::Slice4D(input, {{1, 0, 0, 0}}, {{2, 2, 2, 2}}, {{1, 1, 1, 1}}); - auto actual_literal = Literal::CreateR4FromArray4D(*result); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); LiteralTestUtil::ExpectR4Near( {{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}}, @@ -203,7 +203,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray4D) { auto result = ReferenceUtil::Slice4D(input, {{1, 0, 0, 0}}, {{2, 3, 4, 5}}, {{1, 2, 2, 2}}); - auto actual_literal = Literal::CreateR4FromArray4D(*result); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); LiteralTestUtil::ExpectR4Near( {{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}}, @@ -218,7 +218,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) { ReferenceUtil::ConvArray3D(input, weights, 1, Padding::kSame); Array3D expected = {{{17, 28, 39, 20}}}; - auto actual_literal = Literal::CreateR3FromArray3D(*actual); + auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual); LiteralTestUtil::ExpectR3NearArray3D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -231,7 +231,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithValidPadding) { ReferenceUtil::ConvArray3D(input, weights, 1, Padding::kValid); Array3D expected = {{{17, 28, 39}}}; - auto actual_literal = Literal::CreateR3FromArray3D(*actual); + auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual); LiteralTestUtil::ExpectR3NearArray3D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -266,7 +266,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) { })); // clang-format on - auto actual_literal = Literal::CreateR4FromArray4D(*actual); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -300,7 +300,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) { })); // clang-format on - auto actual_literal = Literal::CreateR4FromArray4D(*actual); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -356,7 +356,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) { }}); // clang-format on - auto actual_literal = Literal::CreateR4FromArray4D(*actual); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -409,7 +409,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) { Array4D expected({{{{2514, 2685}}}}); // clang-format on - auto actual_literal = Literal::CreateR4FromArray4D(*actual); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -422,7 +422,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) { auto actual = ReferenceUtil::ApplyElementwise2D( [](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c); - auto actual_literal = Literal::CreateR2FromArray2D(*actual); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual); LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}}, *actual_literal, ErrorSpec(0.0001)); } diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 0b1cec1925d4424db086f8a3f62c91ede090189c..44b22a5586dee3f7dd8ea0edbf9deb2090986ac8 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -56,7 +56,7 @@ tf_cc_test( ":grpc_stub", "//tensorflow:grpc++", "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index f8414468bd9e0a9faf0072c47d94d12ab11b908d..67886761813f0bb45a600661b017be91ffeade73 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "grpcpp/security/credentials.h" #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/rpc/grpc_stub.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/lib/io/path.h" @@ -97,7 +97,7 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) { 1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796, 6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327}; std::unique_ptr expected_literal = - Literal::CreateR1(expected); + LiteralUtil::CreateR1(expected); TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer( computation, {}, nullptr)); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index fe99f700d23dbab799ba011b705c59d6ef7a2e52..1b93d72a3e2ad64f54407667b98fc5ef86a13d02 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -136,7 +136,7 @@ cc_library( ":hlo_dce", ":hlo_pass", ":tuple_simplifier", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", @@ -182,6 +182,7 @@ tf_cc_test( name = "shape_inference_test", srcs = ["shape_inference_test.cc"], deps = [ + ":hlo", ":shape_inference", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -227,6 +228,7 @@ cc_library( ":hlo", ":hlo_query", ":shape_inference", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -244,7 +246,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_evaluator", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -254,7 +256,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -294,6 +296,7 @@ cc_library( ":hlo_reachability", ":name_uniquer", "//tensorflow/compiler/xla:array", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_tree", @@ -396,6 +399,7 @@ tf_cc_test( deps = [ ":hlo_matchers", ":hlo_parser", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -407,7 +411,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_parser", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -424,7 +428,7 @@ tf_cc_test( srcs = ["hlo_sharding_test.cc"], deps = [ ":hlo", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -453,7 +457,7 @@ tf_cc_test( srcs = ["call_graph_test.cc"], deps = [ ":call_graph", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", @@ -487,6 +491,7 @@ cc_library( hdrs = ["call_inliner.h"], deps = [ ":call_graph", + ":hlo_dce", ":hlo_pass", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", @@ -502,7 +507,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -521,7 +526,7 @@ tf_cc_test( deps = [ ":call_graph", ":flatten_call_graph", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", @@ -559,7 +564,7 @@ cc_library( ":computation_placer", ":device_memory_allocator", ":platform_util", - ":pool", + ":stream_pool", ":transfer_manager", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -593,6 +598,7 @@ cc_library( ":hlo_proto_util", ":platform_util", ":source_map_util", + ":stream_pool", ":transfer_manager", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:execution_options_util", @@ -637,7 +643,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:executable_build_options", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", ], @@ -746,8 +752,8 @@ cc_library( ":hlo_execution_profile", ":hlo_graph_dumper", ":hlo_proto", - ":pool", ":shaped_buffer", + ":stream_pool", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -797,7 +803,7 @@ cc_library( hdrs = ["transfer_manager.h"], deps = [ ":shaped_buffer", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -833,7 +839,7 @@ cc_library( hdrs = ["execution_tracker.h"], deps = [ ":backend", - ":pool", + ":stream_pool", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -941,7 +947,6 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", ], @@ -960,7 +965,7 @@ tf_cc_test( ":hlo", ":hlo_ordering", ":hlo_scheduling", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", @@ -1038,7 +1043,7 @@ tf_cc_test( ":hlo_ordering", ":hlo_value", ":tuple_points_to_analysis", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1052,6 +1057,7 @@ cc_library( hdrs = ["hlo_module_group_metadata.h"], deps = [ ":hlo", + ":hlo_casting_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -1121,7 +1127,7 @@ cc_library( hdrs = ["hlo_query.h"], deps = [ ":hlo", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", ], ) @@ -1170,6 +1176,7 @@ cc_library( deps = [ ":hlo", ":shape_inference", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -1200,6 +1207,7 @@ cc_library( deps = [ ":hlo", ":hlo_pass", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1219,6 +1227,7 @@ cc_library( ":hlo_creation_utils", ":hlo_pass", ":while_util", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", ], @@ -1232,8 +1241,9 @@ tf_cc_test( ":batchnorm_expander", ":hlo", ":hlo_matchers", + ":hlo_parser", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -1255,6 +1265,7 @@ cc_library( ":hlo_pass", ":hlo_query", ":pattern_matcher", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1274,7 +1285,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -1310,7 +1321,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -1345,7 +1356,7 @@ cc_library( ":call_inliner", ":hlo", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -1361,6 +1372,7 @@ tf_cc_test( ":conditional_simplifier", ":hlo", ":hlo_matchers", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1372,6 +1384,18 @@ tf_cc_test( ], ) +cc_library( + name = "while_loop_analysis", + srcs = ["while_loop_analysis.cc"], + hdrs = ["while_loop_analysis.h"], + deps = [ + ":hlo", + ":hlo_evaluator", + "//tensorflow/compiler/xla:literal", + "//tensorflow/core:lib", + ], +) + cc_library( name = "while_loop_simplifier", srcs = ["while_loop_simplifier.cc"], @@ -1379,8 +1403,8 @@ cc_library( deps = [ ":call_inliner", ":hlo", - ":hlo_evaluator", ":hlo_pass", + ":while_loop_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", ], @@ -1420,7 +1444,7 @@ tf_cc_test( deps = [ ":defuser", ":hlo_matchers", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", ], @@ -1448,7 +1472,7 @@ tf_cc_test( deps = [ ":hlo_matchers", ":implicit_broadcast_remover", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", ], @@ -1490,7 +1514,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":tuple_simplifier", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -1505,7 +1529,7 @@ cc_library( hdrs = ["reshape_mover.h"], deps = [ ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", @@ -1520,7 +1544,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":reshape_mover", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", @@ -1555,7 +1579,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":inliner", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", @@ -1572,7 +1596,7 @@ cc_library( hdrs = ["computation_placer.h"], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -1604,7 +1628,7 @@ cc_library( hdrs = ["generic_transfer_manager.h"], deps = [ ":transfer_manager", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -1651,8 +1675,8 @@ tf_cc_test( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -1695,7 +1719,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_matchers", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", @@ -1710,6 +1734,7 @@ tf_cc_binary( deps = [ ":hlo", ":hlo_graph_dumper", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -1724,7 +1749,7 @@ tf_cc_test( srcs = ["hlo_module_test.cc"], deps = [ ":hlo", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", @@ -1822,7 +1847,7 @@ tf_cc_test( ":hlo_matchers", ":hlo_ordering", ":instruction_fusion", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", @@ -1859,7 +1884,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_liveness_analysis", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", @@ -1920,7 +1945,7 @@ tf_cc_test( ":hlo_matchers", ":hlo_ordering", ":instruction_fusion", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", @@ -1955,6 +1980,7 @@ cc_library( ":hlo_dataflow_analysis", ":logical_buffer", ":logical_buffer_analysis", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -1973,6 +1999,7 @@ tf_cc_test( ":hlo_matchers", ":instruction_fusion", ":tuple_points_to_analysis", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1996,6 +2023,7 @@ cc_library( deps = [ ":computation_layout", ":hlo", + ":hlo_casting_utils", ":hlo_dce", ":hlo_graph_dumper", ":hlo_pass", @@ -2044,7 +2072,7 @@ tf_cc_test( ":hlo_graph_dumper", ":hlo_matchers", ":hlo_runner", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", @@ -2108,6 +2136,7 @@ tf_cc_test( srcs = ["hlo_verifier_test.cc"], deps = [ ":hlo", + ":hlo_parser", ":hlo_verifier", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -2169,6 +2198,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_dce", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -2189,7 +2219,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_module_dce", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -2213,7 +2243,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":layout_assignment", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -2272,7 +2302,7 @@ cc_library( ":hlo", ":hlo_domain_map", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -2288,7 +2318,7 @@ tf_cc_test( ":hlo", ":hlo_cse", ":hlo_matchers", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -2310,7 +2340,7 @@ cc_library( ":hlo_evaluator", ":hlo_pass", ":hlo_query", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", @@ -2325,7 +2355,7 @@ tf_cc_test( ":hlo_constant_folding", ":hlo_matchers", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -2362,6 +2392,20 @@ cc_library( ], ) +cc_library( + name = "hlo_domain_verifier", + srcs = ["hlo_domain_verifier.cc"], + hdrs = ["hlo_domain_verifier.h"], + deps = [ + ":hlo", + ":hlo_domain_map", + ":hlo_graph_dumper", + ":hlo_pass", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + ], +) + cc_library( name = "hlo_domain_isolator", srcs = ["hlo_domain_isolator.cc"], @@ -2381,8 +2425,8 @@ cc_library( hdrs = ["hlo_domain_remover.h"], deps = [ ":hlo", - ":hlo_domain_isolator", ":hlo_domain_map", + ":hlo_domain_verifier", ":hlo_graph_dumper", ":hlo_pass", "//tensorflow/compiler/xla:types", @@ -2417,7 +2461,7 @@ cc_library( ":hlo_evaluator", ":hlo_pass", ":hlo_query", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", @@ -2552,7 +2596,7 @@ cc_library( hdrs = ["hlo_tfgraph_builder.h"], deps = [ ":hlo", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:framework", @@ -2583,7 +2627,7 @@ cc_library( ":hlo_casting_utils", ":hlo_execution_profile", ":hlo_tfgraph_builder", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:window_util", @@ -2601,6 +2645,7 @@ tf_cc_test( deps = [ ":hlo", ":hlo_graph_dumper", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/tests:test_utils", @@ -2632,12 +2677,12 @@ tf_cc_test( ":hlo_matchers", ":shape_inference", ":transpose_folding", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -2653,7 +2698,7 @@ cc_library( deps = [ ":hlo", ":hlo_pass", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", @@ -2668,13 +2713,13 @@ tf_cc_test( ":hlo", ":shape_inference", ":zero_sized_hlo_elimination", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -2682,21 +2727,25 @@ tf_cc_test( ) cc_library( - name = "pool", - hdrs = ["pool.h"], + name = "stream_pool", + srcs = ["stream_pool.cc"], + hdrs = ["stream_pool.h"], deps = [ + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", ], ) tf_cc_test( - name = "pool_test", - srcs = ["pool_test.cc"], + name = "stream_pool_test", + srcs = ["stream_pool_test.cc"], deps = [ - ":pool", + ":stream_pool", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -2828,6 +2877,7 @@ cc_library( ":hlo", ":hlo_creation_utils", ":tuple_util", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/core:lib", ], ) @@ -2963,6 +3013,7 @@ cc_library( ":hlo", ":hlo_lexer", ":hlo_sharding_metadata", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 48fd07371d60a27404be25af163678b582b02d7b..946ef6f0d6b9025b84c4b9341f4ec600465d4b1e 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -149,6 +150,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) override; + Status HandleSort(HloInstruction* sort) override; + Status HandleTranspose(HloInstruction* transpose) override; Status HandleSubtract(HloInstruction* sub) override; @@ -195,7 +198,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { HloInstruction* zero = computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::Zero(hlo->shape().element_type()).CloneToUnique())); + LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique())); HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( @@ -537,8 +540,8 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { // If a literal is all the same element replace it with a scalar broadcast. if (ShapeUtil::ElementsIn(constant->shape()) > 1 && constant->literal().IsAllFirst()) { - std::unique_ptr unique_scalar = - MakeUnique(constant->literal().GetFirstScalarLiteral()); + std::unique_ptr unique_scalar = MakeUnique( + LiteralUtil::GetFirstScalarLiteral(constant->literal())); HloInstruction* scalar = computation_->AddInstruction( HloInstruction::CreateConstant(std::move(unique_scalar))); return ReplaceWithNewInstruction( @@ -1093,7 +1096,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { ShapeUtil::IsZeroElementArray(lhs->shape()) || ShapeUtil::IsZeroElementArray(rhs->shape())) { auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); return ReplaceWithNewInstruction( dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); } @@ -1155,6 +1158,19 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { return Status::OK(); } + // 0*A => 0. Only applies for integral types for correct NaN-handling. + if (IsAll(lhs, 0) && + primitive_util::IsIntegralType(multiply->shape().element_type()) && + ReplaceInstructionIfSameShape(multiply, lhs)) { + return Status::OK(); + } + // A*0 => 0 + if (IsAll(rhs, 0) && + primitive_util::IsIntegralType(multiply->shape().element_type()) && + ReplaceInstructionIfSameShape(multiply, rhs)) { + return Status::OK(); + } + // exp(A) * exp(B) => exp(A+B) if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) { auto add = computation_->AddInstruction(HloInstruction::CreateBinary( @@ -1252,9 +1268,10 @@ bool OutputIsPermutationOfOperandElements(HloInstruction* instruction, switch (instruction->opcode()) { case HloOpcode::kReshape: case HloOpcode::kReverse: - case HloOpcode::kSort: case HloOpcode::kTranspose: return true; + case HloOpcode::kSort: + return (!ShapeUtil::IsTuple(instruction->shape())); default: return false; } @@ -1518,7 +1535,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs)))); if (IsAll(rhs, 0)) { auto one = HloInstruction::CreateConstant( - Literal::One(power->shape().element_type()).CloneToUnique()); + LiteralUtil::One(power->shape().element_type()).CloneToUnique()); std::unique_ptr ones; if (ShapeUtil::IsScalar(power->shape())) { ones = std::move(one); @@ -1553,7 +1570,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); if (IsAll(rhs, -1)) { auto* one = computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::One(rhs->shape().element_type()).CloneToUnique())); + LiteralUtil::One(rhs->shape().element_type()).CloneToUnique())); // Explicitly broadcast scalar 1 to the output shape, to avoid implicit // broadcast in divide HLO as we are trying to eliminate implicit @@ -1729,19 +1746,37 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { if (ReplaceInstructionIfSameShape(slice, slice->mutable_operand(0))) { return Status::OK(); } + + auto is_unstrided_slice = [](const HloInstruction* hlo) { + return c_all_of(hlo->slice_strides(), + [](int64 stride) { return stride == 1; }); + }; + if (slice->operand(0)->opcode() == HloOpcode::kSlice && + is_unstrided_slice(slice) && is_unstrided_slice(slice->operand(0))) { + HloInstruction* operand_slice = slice->mutable_operand(0); + std::vector new_slice_starts = slice->slice_starts(); + std::vector new_slice_limits = slice->slice_limits(); + for (int64 i = 0; i < new_slice_starts.size(); ++i) { + new_slice_starts[i] += operand_slice->slice_starts(i); + new_slice_limits[i] += operand_slice->slice_starts(i); + } + return ReplaceWithNewInstruction( + slice, HloInstruction::CreateSlice( + slice->shape(), operand_slice->mutable_operand(0), + new_slice_starts, new_slice_limits, slice->slice_strides())); + } return Status::OK(); } Status AlgebraicSimplifierVisitor::HandleDynamicSlice( HloInstruction* dynamic_slice) { auto operand = dynamic_slice->mutable_operand(0); - auto start_indices = dynamic_slice->operand(1); if (ShapeUtil::IsScalar(dynamic_slice->shape())) { return ReplaceInstruction(dynamic_slice, operand); } - // DynamicSlice where operand has the same size as the output and - // start_indices are all zero is simply equal to operand. - if (IsAll(start_indices, 0) && SameShape(operand, dynamic_slice)) { + // DynamicSlice where operand has the same size as the output is simply equal + // to operand. + if (SameShape(operand, dynamic_slice)) { return ReplaceInstruction(dynamic_slice, operand); } return Status::OK(); @@ -1750,20 +1785,10 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice( Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) { auto update = dynamic_update_slice->mutable_operand(1); - auto start_indices = dynamic_update_slice->operand(2); - // DynamicUpdateSlice on a scalar just passes through the update argument. - if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) { - return ReplaceInstruction(dynamic_update_slice, update); - } - // DynamicUpdateSlice where operand and update have the same size and - // start_indices are all zero is simply equal to update. - // - // (We require start_indices to be all zero because we want this optimization - // not to affect the visible behavior of this op even when the indices are out - // of range. Currently dynamic-update-slice wraps out-of-range indices, so - // we can only remove the op if its indices never wrap.) - if (IsAll(start_indices, 0) && SameShape(dynamic_update_slice, update)) { + // DynamicUpdateSlice where operand and update have the same size is simply + // equal to update. + if (SameShape(dynamic_update_slice, update)) { return ReplaceInstruction(dynamic_update_slice, update); } @@ -1889,6 +1914,26 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { new_reduce_dimensions, function)); } } + // Convert Reduce(concat({a,b,...})) to + // map(reduce(a),map(reduce(b),...,)) + // + // This should make fusion easier or use less memory bandwidth in the unfused + // case. + if (arg->opcode() == HloOpcode::kConcatenate && + c_linear_search(reduce->dimensions(), arg->concatenate_dimension())) { + HloInstruction* old_reduce = nullptr; + for (HloInstruction* operand : arg->operands()) { + HloInstruction* new_reduce = computation_->AddInstruction( + HloInstruction::CreateReduce(reduce->shape(), operand, init_value, + reduce->dimensions(), function)); + if (old_reduce != nullptr) { + new_reduce = computation_->AddInstruction(HloInstruction::CreateMap( + reduce->shape(), {old_reduce, new_reduce}, function)); + } + old_reduce = new_reduce; + } + return ReplaceInstruction(reduce, old_reduce); + } return Status::OK(); } @@ -2062,6 +2107,21 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( /*reduce_computation=*/function)); } +Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { + auto operand = sort->mutable_operand(0); + int64 dimension_to_sort = sort->dimensions(0); + if (ShapeUtil::IsZeroElementArray(operand->shape()) || + operand->shape().dimensions(dimension_to_sort) <= 1) { + if (sort->operand_count() == 1) { + return ReplaceInstruction(sort, operand); + } + // If it is key/value sort, the output of sort is a tuple. + return ReplaceWithNewInstruction( + sort, HloInstruction::CreateTuple({operand, sort->mutable_operand(1)})); + } + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { auto operand = transpose->mutable_operand(0); if (std::is_sorted(transpose->dimensions().begin(), @@ -2097,7 +2157,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( HloInstruction::CreateBroadcast( convolution->shape(), computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::Zero(convolution->shape().element_type()) + LiteralUtil::Zero(convolution->shape().element_type()) .CloneToUnique())), {})); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index b733f6f59eb028b2dff921722c462441251772fe..862cbeeba6b82e1f24a6616b3237dc47d022e9af 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -60,7 +60,7 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero)); @@ -74,12 +74,32 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { EXPECT_EQ(root, param0); } +// Test that A * 0 is simplified to 0 +TEST_F(AlgebraicSimplifierTest, MulZero) { + Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0s32, "param0")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0s32, HloOpcode::kMultiply, param0, zero)); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kMultiply); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_EQ(computation->root_instruction(), zero); +} + // Test that Reduce(Reduce(A)) -> Reduce(A) TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { HloComputation::Builder builder(TestName()); // Create add computation. HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); HloComputation* add_computation = nullptr; { HloComputation::Builder builder(TestName() + ".add"); @@ -119,7 +139,7 @@ TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0)); @@ -140,9 +160,9 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); HloInstruction* constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.14159f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.14159f))); HloInstruction* add1 = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, constant1)); @@ -165,7 +185,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); HloInstruction* bcast = builder.AddInstruction( HloInstruction::CreateBroadcast(r2f32, zero, {0, 1})); builder.AddInstruction( @@ -200,7 +220,7 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); builder.AddInstruction(HloInstruction::CreateMap( r2f32, {param0, builder.AddInstruction( @@ -223,7 +243,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 0, 0}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 0, 0}))); HloInstruction* bcast = builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {1})); builder.AddInstruction( @@ -242,7 +262,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) { HloComputation::Builder builder(TestName()); builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({3.14f, 3.14f, 3.14f}))); + LiteralUtil::CreateR1({3.14f, 3.14f, 3.14f}))); auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); @@ -258,7 +278,7 @@ TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) { TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) { HloComputation::Builder builder(TestName()); builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({3.14, 3.14, 4}))); + LiteralUtil::CreateR1({3.14, 3.14, 4}))); auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); @@ -277,7 +297,7 @@ TEST_F(AlgebraicSimplifierTest, SubZero) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero)); @@ -298,7 +318,7 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kSubtract, param0, constant)); @@ -493,7 +513,7 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({0.f, 1.f, 2.f}))); + LiteralUtil::CreateR1({0.f, 1.f, 2.f}))); builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, constant)); @@ -559,7 +579,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one)); @@ -580,7 +600,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* one = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 1.0}, {1.0, 1.0}}))); + LiteralUtil::CreateR2({{1.0, 1.0}, {1.0, 1.0}}))); HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one)); @@ -860,7 +880,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero)); @@ -884,7 +904,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero)); @@ -912,7 +932,7 @@ TEST_F(AlgebraicSimplifierTest, Pow1) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one)); @@ -934,7 +954,7 @@ TEST_F(AlgebraicSimplifierTest, Pow2) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* two = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two)); @@ -956,7 +976,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* negative_one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(-1))); builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, negative_one)); @@ -1047,7 +1067,7 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { builder.AddInstruction(HloInstruction::CreateReduceWindow( ShapeUtil::MakeShape(F32, {5, 2}), param, builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), window, add_computation)); module().AddEntryComputation(builder.Build()); HloPassFix simplifier(/*is_layout_sensitive=*/false, @@ -1074,7 +1094,7 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {5, 2}), param, builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), padding)); module().AddEntryComputation(builder.Build()); EXPECT_THAT(module().entry_computation()->root_instruction(), @@ -1116,7 +1136,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); @@ -1208,7 +1228,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r1f32, "param1")); HloInstruction* empty_literal = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); HloInstruction* empty_slice = builder.AddInstruction(HloInstruction::CreateSlice( ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1})); @@ -1230,6 +1250,55 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { op::Concatenate(param0, param0, param1)); } +// Test that reduce of concat is simplified. +TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) { + const int kParamLength = 100; + Shape r3f32 = + ShapeUtil::MakeShape(F32, {kParamLength, kParamLength, kParamLength}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r3f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r3f32, "param1")); + HloInstruction* param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, r3f32, "param2")); + Shape concat_shape = + ShapeUtil::MakeShape(F32, {kParamLength, 3 * kParamLength, kParamLength}); + HloInstruction* Concatenate = + builder.AddInstruction(HloInstruction::CreateConcatenate( + concat_shape, {param0, param1, param2}, 1)); + HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module().AddEmbeddedComputation(builder.Build()); + } + Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); + Shape reduce_shape = ShapeUtil::MakeShape(F32, {kParamLength}); + + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + builder.AddInstruction(HloInstruction::CreateReduce( + reduce_shape, Concatenate, zero, {1, 2}, add_computation)); + + auto computation = module().AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT( + computation->root_instruction(), + op::Map(op::Map(op::Reduce(param0, zero), op::Reduce(param1, zero)), + op::Reduce(param2, zero))); +} + // Test a concatenate with only empty operands is removed. TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { const int kParamLength = 100; @@ -1238,7 +1307,7 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* empty_literal = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); HloInstruction* empty_slice = builder.AddInstruction(HloInstruction::CreateSlice( ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1})); @@ -1420,7 +1489,7 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param0")), builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{0, 0}, {0, 0}}))))); + LiteralUtil::CreateR2({{0, 0}, {0, 0}}))))); builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add)); @@ -1443,7 +1512,7 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param0")), builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{0, 0}, {0, 0}}))))); + LiteralUtil::CreateR2({{0, 0}, {0, 0}}))))); builder.AddInstruction( HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add, @@ -1726,7 +1795,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {2, 2}), "param")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); PaddingConfig no_padding; for (int i = 0; i < 2; ++i) { auto dimension = no_padding.add_dimensions(); @@ -1757,7 +1826,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {10, 10}), "param")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); PaddingConfig padding; int64 low_padding[2] = {-1, -2}; int64 high_padding[2] = {2, -3}; @@ -1839,6 +1908,73 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { EXPECT_THAT(computation->root_instruction(), param); } +TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) { + HloComputation::Builder builder(TestName()); + const int64 dim0 = 11; + const int64 dim1 = 12; + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {dim0, dim1}), "param")); + HloInstruction* original_slice = + builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {dim0 - 2, dim1 - 4}), param, + /*start_indices=*/{1, 2}, + /*limit_indices=*/{dim0 - 1, dim1 - 2}, /*strides=*/{1, 1})); + + builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {dim0 - 5, dim1 - 9}), original_slice, + /*start_indices=*/{2, 3}, + /*limit_indices=*/{dim0 - 3, dim1 - 6}, /*strides=*/{1, 1})); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Slice(op::Slice(param))); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Slice(param)); + EXPECT_EQ(computation->root_instruction()->slice_starts(0), 3); + EXPECT_EQ(computation->root_instruction()->slice_starts(1), 5); + EXPECT_EQ(computation->root_instruction()->slice_limits(0), dim0 - 2); + EXPECT_EQ(computation->root_instruction()->slice_limits(1), dim1 - 4); +} + +TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {1}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), keys); +} + +TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {5, 0}); + Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + auto values = builder.AddInstruction( + HloInstruction::CreateParameter(1, values_shape, "values")); + builder.AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values)); +} + TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { struct ConvTestOptions { int in_batch = 10; @@ -1870,7 +2006,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { // Builds a convolution from and runs algebraic simplification on // the computation. Returns a string description of the result of // simplification. - auto build_and_simplify = [&options, this]() -> string { + auto build_and_simplify = [&options]() -> string { HloComputation::Builder b(TestName()); Window window; @@ -2109,7 +2245,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloComputation::Builder builder(TestName()); HloInstruction* forty_two = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6}); HloInstruction* broadcast = builder.AddInstruction( @@ -2156,7 +2292,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { padding.mutable_dimensions(3)->set_edge_padding_high(2); HloInstruction* pad_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {1, 3, 3, 5}), operand, pad_value, padding)); @@ -2187,7 +2323,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { const Shape reduce_window_shape = ShapeUtil::MakeShape(F32, {111, 113, 113, 115}); HloInstruction* reduce_init_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); HloInstruction* reduce_window = builder.AddInstruction(HloInstruction::CreateReduceWindow( reduce_window_shape, pad, reduce_init_value, window, @@ -2238,7 +2374,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { padding.mutable_dimensions(3)->set_edge_padding_high(2); HloInstruction* pad_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(BF16, {1, 3, 3, 5}), parameter, pad_value, padding)); @@ -2273,7 +2409,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { const Shape reduce_window_shape = ShapeUtil::MakeShape(F32, {111, 113, 113, 115}); HloInstruction* reduce_init_value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); HloInstruction* reduce_window = builder.AddInstruction(HloInstruction::CreateReduceWindow( reduce_window_shape, convert, reduce_init_value, window, @@ -2344,9 +2480,9 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { HloComputation::Builder call_builder(TestName() + ".Call"); HloInstruction* zero = call_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0.0f}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0.0f}))); HloInstruction* one = call_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1.0f}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1.0f}))); call_builder.AddInstruction( HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get())); @@ -2362,9 +2498,9 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { HloComputation::Builder builder(TestName()); const float constant_scalar = 7.3f; std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; - std::unique_ptr value = - Literal::MakeTuple({Literal::CreateR0(constant_scalar).get(), - Literal::CreateR1(constant_vector).get()}); + std::unique_ptr value = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(constant_scalar).get(), + LiteralUtil::CreateR1(constant_vector).get()}); builder.AddInstruction(HloInstruction::CreateConstant(std::move(value))); auto computation = module().AddEntryComputation(builder.Build()); @@ -2387,8 +2523,8 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { shape, builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "slice_from")), - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 0, 0}))), + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), /*slice_sizes=*/{10, 100, 1000})); auto computation = module().AddEntryComputation(builder.Build()); @@ -2421,8 +2557,8 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { builder.AddInstruction( HloInstruction::CreateParameter(2, slice_shape, "to_update")), slice, - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 0, 0}))))); + builder.AddInstruction(HloInstruction::CreateParameter( + 3, ShapeUtil::MakeShape(U32, {3}), "update_indices")))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, @@ -2437,7 +2573,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { HloComputation::Builder builder(TestName()); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloInstruction* input_array = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({3, 4}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({3, 4}))); HloInstruction* inner_bcast = builder.AddInstruction( HloInstruction::CreateBroadcast(r2f32, input_array, {1})); Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2}); @@ -2546,7 +2682,7 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( pad_shape, input, builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), padding)); HloComputation* add_computation = nullptr; @@ -2565,7 +2701,7 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { Window window = window_util::MakeWindow( decorate_spatials(param.reduce_window_spatials, 1, 1)); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, ShapeInference::InferReduceWindowShape( pad->shape(), zero->shape(), window, @@ -2704,7 +2840,7 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k}); auto* lhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.m, /*cols=*/spec.k))); Shape rhs0_shape = ShapeUtil::MakeShape(F32, {k0, spec.n}); @@ -2783,7 +2919,7 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n}); auto* rhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.n))); DotDimensionNumbers dot_dnums; @@ -2830,7 +2966,7 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { HloInstruction* const update = builder.AddInstruction( HloInstruction::CreateParameter(1, update_shape, "update")); HloInstruction* const start_indices = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0}))); builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( dslice_shape, operand, update, start_indices)); const HloComputation* const computation = @@ -2879,7 +3015,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { int64 lhs_cols = (spec.lcd == 0) ? spec.m : (spec.k + k_increase); Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols}); auto* lhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows, /*cols=*/lhs_cols))); @@ -2887,7 +3023,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { int32 start_col = (spec.lcd == 0) ? spec.s : 0; const auto start_indices = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({start_row, start_col}))); + LiteralUtil::CreateR1({start_row, start_col}))); int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1; int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k; Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); @@ -2898,7 +3034,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { int64 rhs_cols = (spec.rcd == 0) ? spec.n : spec.k; Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols}); auto* rhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows, /*cols=*/rhs_cols))); @@ -2946,7 +3082,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { int64 lhs_cols = (spec.lcd == 0) ? spec.m : spec.k; Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols}); auto* lhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows, /*cols=*/lhs_cols))); @@ -2957,7 +3093,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { int64 rhs_cols = (spec.rcd == 0) ? spec.n : (spec.k + k_increase); Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols}); auto* rhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows, /*cols=*/rhs_cols))); @@ -2965,7 +3101,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { int32 start_col = (spec.rcd == 0) ? spec.s : 0; const auto start_indices = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({start_row, start_col}))); + LiteralUtil::CreateR1({start_row, start_col}))); int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1; int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k; Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 95b4cb6d2e694063b648b264bd2454ae0a5469ff..51ebc4763b612884a4453edec5711f78c4006fc3 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -109,11 +109,11 @@ Status AllocationTracker::Unregister(const GlobalDataHandle& data) { ResolveInternal(data)); for (const auto& shaped_buffer : replicated_buffers) { std::vector shape_indices; - ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), - [this, &shape_indices](const Shape& /*subshape*/, - const ShapeIndex& index) { - shape_indices.push_back(index); - }); + ShapeUtil::ForEachSubshape( + shaped_buffer->on_device_shape(), + [&shape_indices](const Shape& /*subshape*/, const ShapeIndex& index) { + shape_indices.push_back(index); + }); for (const ShapeIndex& index : shape_indices) { TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index), shaped_buffer->device_ordinal())); diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 349b32451a697dbd6804b44cd1a36419c753bb14..d12be3e007fe0b16ac850d64521f0025d481b5d2 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -96,24 +96,19 @@ Backend::CreateDefaultBackend() { return CreateBackend(backend_options); } -StatusOr Backend::BorrowStream(int device_ordinal) { - TF_ASSIGN_OR_RETURN(auto exec, stream_executor(device_ordinal)); - return BorrowStream(exec); +StatusOr Backend::BorrowStream(int device_ordinal) { + TF_ASSIGN_OR_RETURN(auto executor, stream_executor(device_ordinal)); + return BorrowStream(executor); } -StatusOr Backend::BorrowStream( - se::StreamExecutor* executor) { +StatusOr Backend::BorrowStream(se::StreamExecutor* executor) { tensorflow::mutex_lock l(mu_); if (0 == stream_pools_.count(executor)) { stream_pools_.emplace(std::piecewise_construct, std::forward_as_tuple(executor), - std::forward_as_tuple([executor]() { - auto stream = MakeUnique(executor); - stream->Init(); - return stream; - })); + std::forward_as_tuple()); } - return stream_pools_.at(executor).Allocate(); + return stream_pools_.at(executor).BorrowStream(executor); } Backend::Backend( diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index 6546602473e3381cf13879ddebd05d34d1f7a055..1bc3796fa48c1627538474d04ef5358ba64dfce9 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" -#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -63,11 +63,9 @@ class BackendOptions { // // It also offers a pooling API for creation/use of initialized streams: // -// StreamPtr stream = backend->BorrowStream().ConsumeValueOrDie(); +// StreamPool::Ptr stream = backend->BorrowStream().ConsumeValueOrDie(); class Backend { public: - using StreamPtr = Pool::SmartPtr; - // Creates a new backend. static StatusOr> CreateBackend( const BackendOptions& options); @@ -114,13 +112,13 @@ class Backend { // Borrows a stream for use by the caller, either by grabbing it from an // internal pool, or by constructing/initializating it, and returns the result // to the caller. - StatusOr BorrowStream(int device_ordinal); - StatusOr BorrowStream(se::StreamExecutor* executor); + StatusOr BorrowStream(int device_ordinal); + StatusOr BorrowStream(se::StreamExecutor* executor); // Returns a function to borrow a stream, as `BorrowStream` above does. // Purely for convenience, the caller could rather make this anonymous // function itself. - std::function(int)> StreamBorrower() { + std::function(int)> StreamBorrower() { return [this](int device_ordinal) { return BorrowStream(device_ordinal); }; } @@ -169,7 +167,7 @@ class Backend { tensorflow::mutex mu_; // Mapping from stream executor to stream pools, used by `BorrowStream` above. - std::map> stream_pools_ GUARDED_BY(mu_); + std::map stream_pools_ GUARDED_BY(mu_); // The default memory allocator to use. std::unique_ptr memory_allocator_; diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index ec13fadbc75e2315d1d6ef72e24a0faca0c7de40..c4cd60c1201f7ddbf0aba4b6d587952531b74bfa 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -34,6 +35,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -41,6 +43,8 @@ namespace xla { namespace { +using tensorflow::gtl::optional; + // BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm // operations into smaller operations. class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { @@ -97,7 +101,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { add_instruction(HloInstruction::CreateConvert( ShapeUtil::MakeShape(operand->shape().element_type(), {}), add_instruction(HloInstruction::CreateConstant( - Literal::CreateR0(-0.5f))))), + LiteralUtil::CreateR0(-0.5f))))), {})); return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kPower, operand, exponent); @@ -113,7 +117,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { add_instruction(HloInstruction::CreateConvert( ShapeUtil::MakeShape(operand->shape().element_type(), {}), add_instruction(HloInstruction::CreateConstant( - Literal::CreateR0(1.0 / element_count))))), + LiteralUtil::CreateR0(1.0 / element_count))))), {})); return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kMultiply, operand, elem_count_recip); @@ -200,11 +204,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( HloInstruction* offset = batch_norm->mutable_operand(2); const Shape feature_shape = scale->shape(); - auto zero_literal = Literal::CreateR0(0.0f); + auto zero_literal = LiteralUtil::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); - auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); + auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); auto epsilon = add(HloInstruction::CreateBroadcast( operand_shape, @@ -288,16 +292,22 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( int64 instruction_count_after = computation_->instruction_count(); CHECK_EQ(instruction_count_after, instruction_count_before + added_instructions.size()); + const HloSharding& sharding = batch_norm->sharding(); HloSharding operand_sharding = - batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0}); + sharding.GetAsShapeTree(batch_norm->shape()).element({0}); + optional unique_device = batch_norm->sharding_unique_device(); + HloSharding default_sharding = + unique_device.has_value() + ? HloSharding::AssignDevice(unique_device.value()) + : HloSharding::Replicate(); for (HloInstruction* inst : added_instructions) { if (ShapeUtil::Equal(inst->shape(), operand_shape)) { inst->set_sharding(operand_sharding); } else { - inst->set_sharding(HloSharding::Replicate()); + inst->set_sharding(default_sharding); } } - tuple->set_sharding(batch_norm->sharding()); + tuple->set_sharding(sharding); } TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple))); return Status::OK(); @@ -320,7 +330,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( HloInstruction* var = batch_norm->mutable_operand(4); const Shape feature_shape = scale->shape(); - auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); + auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast( operand_shape, @@ -388,14 +398,20 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( CHECK_EQ(instruction_count_after, instruction_count_before + added_instructions.size()); if (batch_norm->has_sharding()) { + const HloSharding& sharding = batch_norm->sharding(); + optional unique_device = batch_norm->sharding_unique_device(); + HloSharding default_sharding = + unique_device.has_value() + ? HloSharding::AssignDevice(unique_device.value()) + : HloSharding::Replicate(); for (HloInstruction* inst : added_instructions) { if (ShapeUtil::Equal(inst->shape(), operand_shape)) { - inst->set_sharding(batch_norm->sharding()); + inst->set_sharding(sharding); } else { - inst->set_sharding(HloSharding::Replicate()); + inst->set_sharding(default_sharding); } } - shifted_normalized->set_sharding(batch_norm->sharding()); + shifted_normalized->set_sharding(sharding); } TF_CHECK_OK( ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized))); @@ -447,11 +463,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( const int64 feature_count = activation_shape.dimensions(feature_index); const int64 elements_per_feature_int64 = size_in_elements / feature_count; - auto zero_literal = Literal::CreateR0(0.0f); + auto zero_literal = LiteralUtil::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); - auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); + auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); auto epsilon_scalar = add(HloInstruction::CreateConstant(std::move(epsilon_literal))); @@ -542,7 +558,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon, add)); auto elements_per_feature_literal = - Literal::CreateR0(elements_per_feature_int64); + LiteralUtil::CreateR0(elements_per_feature_int64); TF_ASSIGN_OR_RETURN(elements_per_feature_literal, elements_per_feature_literal->Convert(ptype)); auto elements_per_feature = add( @@ -562,19 +578,25 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( auto tuple = HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta}); if (batch_norm->has_sharding()) { + const HloSharding& sharding = batch_norm->sharding(); int64 instruction_count_after = computation_->instruction_count(); CHECK_EQ(instruction_count_after, instruction_count_before + added_instructions.size()); HloSharding activation_sharding = - batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0}); + sharding.GetAsShapeTree(batch_norm->shape()).element({0}); + auto unique_device = batch_norm->sharding_unique_device(); + HloSharding default_sharding = + unique_device.has_value() + ? HloSharding::AssignDevice(unique_device.value()) + : HloSharding::Replicate(); for (HloInstruction* inst : added_instructions) { if (ShapeUtil::Equal(inst->shape(), activation_shape)) { inst->set_sharding(activation_sharding); } else { - inst->set_sharding(HloSharding::Replicate()); + inst->set_sharding(default_sharding); } } - tuple->set_sharding(batch_norm->sharding()); + tuple->set_sharding(sharding); } TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple))); diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index aa36e64b07099a372dab67babc7a18a2d39596bc..a725351462809e5b670bbf1d79d2dded87e54f07 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -19,12 +19,13 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -114,5 +115,33 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) { EXPECT_EQ(root->opcode(), HloOpcode::kTuple); } +TEST_F(BatchNormExpanderTest, BatchNormTrainingSharding) { + const char* module_str = R"( +HloModule module +ENTRY entry { + %param.0 = f32[8,4] parameter(0) + %param.1 = f32[4] parameter(1) + %param.2 = f32[4] parameter(2) + ROOT %batch-norm-training = (f32[8,4], f32[4], f32[4]) + batch-norm-training(f32[8,4] %param.0, f32[4] %param.1, f32[4] %param.2), + epsilon=0.001, feature_index=1, sharding={maximal device=1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(module_str)); + BatchNormExpander rewriter(/*rewrite_training_op=*/true, + /*rewrite_inference_op=*/true, + /*rewrite_grad_op=*/true); + ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); + + for (auto* instruction : module->entry_computation()->instructions()) { + if (instruction->opcode() == HloOpcode::kParameter) { + continue; + } + auto device = instruction->sharding_unique_device(); + ASSERT_TRUE(device); + EXPECT_EQ(*device, 1); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index f7b4c1405dbc8719d8fba5476e6e41d2921ea877..7cf05ca443c00c3b40eeb7d756cf216b45c45c39 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -235,7 +235,8 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, - sum, /*replica_group_ids=*/{}, /*barrier=*/"")); + sum, /*replica_group_ids=*/{}, /*barrier=*/"", + /*all_reduce_id=*/tensorflow::gtl::nullopt)); HloInstruction* gte_a = builder.AddInstruction( HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); HloInstruction* gte_b = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index 14c54ddd135af024327f63418b410da1ed3c4fd4..16e99b57220cc185fbfaa75d30a0de709cf61ee7 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -34,8 +34,10 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo) override; - // Special handling for cross-replica-sum which can have a tuple output. + // Special handling for cross-replica-sum and sort which can have a tuple + // output. Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleSort(HloInstruction* sort) override; static bool Run(HloComputation* computation, const BFloat16Support* bfloat16_support) { @@ -49,6 +51,10 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { // conversions between F32 and BF16 to make it supported. Status HandleInstruction(HloInstruction* hlo); + // Handle instructions with tuple outputs by examining each output + // independently. + Status HandleMultipleOutputs(HloInstruction* hlo); + // Inserts a conversion HLO that changes the given HLO's output type. Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType to, HloComputation* computation); @@ -148,22 +154,35 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( HloInstruction* crs) { if (!ShapeUtil::IsTuple(crs->shape())) { return HandleInstruction(crs); + } else { + return HandleMultipleOutputs(crs); } +} + +Status BFloat16NormalizationVisitor::HandleSort(HloInstruction* sort) { + if (!ShapeUtil::IsTuple(sort->shape())) { + return HandleInstruction(sort); + } else { + return HandleMultipleOutputs(sort); + } +} - std::vector operand_types(crs->operand_count()); - std::vector output_types(crs->operand_count()); +Status BFloat16NormalizationVisitor::HandleMultipleOutputs( + HloInstruction* hlo) { + std::vector operand_types(hlo->operand_count()); + std::vector output_types(hlo->operand_count()); int64 f32_count = 0; int64 bf16_count = 0; bool has_unsupported_bf16_operand = false; bool has_unsupported_bf16_output = false; - for (int64 i = 0; i < crs->operand_count(); ++i) { - operand_types[i] = crs->operand(i)->shape().element_type(); - output_types[i] = ShapeUtil::GetSubshape(crs->shape(), {i}).element_type(); + for (int64 i = 0; i < hlo->operand_count(); ++i) { + operand_types[i] = hlo->operand(i)->shape().element_type(); + output_types[i] = ShapeUtil::GetSubshape(hlo->shape(), {i}).element_type(); if (operand_types[i] == F32) { f32_count += 1; } else if (operand_types[i] == BF16) { bf16_count += 1; - if (!bfloat16_support_->SupportsBF16Operand(*crs, i)) { + if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) { has_unsupported_bf16_operand = true; } } @@ -171,7 +190,7 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( f32_count += 1; } else if (output_types[i] == BF16) { bf16_count += 1; - if (!bfloat16_support_->SupportsBF16Output(*crs)) { + if (!bfloat16_support_->SupportsBF16Output(*hlo)) { has_unsupported_bf16_output = true; } } @@ -185,43 +204,43 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( if (operand_types[i] != BF16) { return false; } - if (!bfloat16_support_->SupportsBF16Operand(*crs, i)) { + if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) { return true; } - if (bfloat16_support_->SupportsMixedPrecisions(*crs)) { + if (bfloat16_support_->SupportsMixedPrecisions(*hlo)) { return false; } return has_unsupported_bf16_operand || has_unsupported_bf16_output || f32_count > 0; }; - for (int64 i = 0; i < crs->operand_count(); ++i) { + for (int64 i = 0; i < hlo->operand_count(); ++i) { if (should_convert_operand(i)) { - TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_)); + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); f32_count += 1; bf16_count -= 1; } } if (!has_unsupported_bf16_output && - (bfloat16_support_->SupportsMixedPrecisions(*crs) || f32_count == 0 || + (bfloat16_support_->SupportsMixedPrecisions(*hlo) || f32_count == 0 || bf16_count == 0)) { return Status::OK(); } - std::vector materialized_users = crs->users(); - std::vector output_elements(crs->operand_count()); - auto original_shape = crs->shape(); - for (int64 i = 0; i < crs->operand_count(); ++i) { - auto subshape = ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i}); + std::vector materialized_users = hlo->users(); + std::vector output_elements(hlo->operand_count()); + auto original_shape = hlo->shape(); + for (int64 i = 0; i < hlo->operand_count(); ++i) { + auto subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), {i}); if (output_types[i] != BF16) { output_elements[i] = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(*subshape, crs, i)); + HloInstruction::CreateGetTupleElement(*subshape, hlo, i)); continue; } subshape->set_element_type(F32); auto gte = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(*subshape, crs, i)); + HloInstruction::CreateGetTupleElement(*subshape, hlo, i)); output_elements[i] = computation_->AddInstruction(HloInstruction::CreateConvert( ShapeUtil::ChangeElementType(*subshape, BF16), gte)); @@ -229,11 +248,11 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( auto tuple = computation_->AddInstruction( HloInstruction::CreateTuple(output_elements)); - // Use the crs' shape temporarily, in order to pass checks in + // Use the hlo' shape temporarily, in order to pass checks in // ReplaceUseWith. - *tuple->mutable_shape() = crs->shape(); + *tuple->mutable_shape() = hlo->shape(); for (auto* user : materialized_users) { - TF_RETURN_IF_ERROR(crs->ReplaceUseWith(user, tuple)); + TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, tuple)); } *tuple->mutable_shape() = original_shape; return Status::OK(); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 830f26422bdc2b3bd789e7d5926bcebac815d34a..f9f1f64998f5b925102dc238941897ff6d441b3f 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -251,7 +251,8 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, - /*replica_group_ids=*/{}, /*barrier=*/"")); + /*replica_group_ids=*/{}, /*barrier=*/"", + /*all_reduce_id=*/tensorflow::gtl::nullopt)); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); @@ -265,6 +266,33 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32); } +TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {1024}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {1024}); + Shape s32_shape = ShapeUtil::MakeShape(BF16, {1024}); + + HloInstruction* key = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "key")); + HloInstruction* value = builder.AddInstruction( + HloInstruction::CreateParameter(1, s32_shape, "value")); + + HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), 0, key, value)); + HloInstruction* gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction(), gte); + EXPECT_EQ(gte->shape().element_type(), BF16); + EXPECT_EQ(sort->operand(0)->shape().element_type(), F32); + EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), F32); +} + // Tests that the normalization should not cause unsupported mixed precision due // to resolving unsupported BF16 operand. TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index ff6d5027efba813042af65a0e50e172cc0a99ff8..2fb401c4289728f3f59538464c5b8ad49957985b 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/bfloat16_propagation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -215,7 +215,12 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, if (ContainsKey(values_that_must_be_kept_as_f32_, value)) { return false; } - if (ValueTypeAfterChange(value) == BF16) { + // We use the original type for the value because we are going to examine + // the uses of it, instead of the value itself. If ValueTypeAfterChange() + // were used, it would cause problems when there are aliasing buffers, i.e., + // ResolveInconsistencyOfAliasingBuffers() would fail to revert the + // tentative change to BF16 even if the uses require F32. + if (value->shape().element_type() == BF16) { continue; } for (const HloUse& use : value->uses()) { @@ -566,6 +571,9 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( } visited_computations->insert(visited_in_while.begin(), visited_in_while.end()); + } else if (hlo->opcode() == HloOpcode::kFusion) { + ResolveInconsistencyOfAliasingBuffersHelper( + hlo->fused_instructions_computation(), visited_computations); } } // Now adjust parameters of called computations. @@ -615,7 +623,6 @@ Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) { // (1) a is F32 but tuple is BF16 // (2) after adding conversion // (3) after tuple simplifier and DCE. - bool needs_tuple_simplifier = false; for (auto computation : module->MakeComputationPostOrder()) { auto insts = computation->MakeInstructionPostOrder(); for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { @@ -629,67 +636,25 @@ Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) { continue; } ShapeTree converted_outputs(hlo->shape()); - // Iterate through nodes in the shape tree in pre-order and initialize - // each non-root node with a corresponding get-tuple-element. For a leaf - // node, if its shape does not match the fusion output, create a - // conversion node to overwrite the node value. - for (auto it = converted_outputs.begin(); it != converted_outputs.end(); - ++it) { - ShapeIndex output_index = it->first; - HloInstruction*& output = it->second; - const Shape subshape = - ShapeUtil::GetSubshape(hlo->shape(), output_index); - if (output_index.empty()) { - output = fusion_root; - } else { - ShapeIndex parent_index = output_index; - parent_index.pop_back(); - output = fusion_computation->AddInstruction( - HloInstruction::CreateGetTupleElement( - subshape, converted_outputs.element(parent_index), - output_index.back())); - } - if (!ShapeUtil::IsArray(subshape)) { - continue; - } - if (!ShapeUtil::Compatible( - subshape, - ShapeUtil::GetSubshape(fusion_root->shape(), output_index))) { - output = fusion_computation->AddInstruction( - HloInstruction::CreateConvert(subshape, output)); - } - } - // Iterate through nodes in the shape tree in reverse pre-order and create - // a tuple instruction for each non-leaf node where the elements are the - // values of its child nodes. - for (auto it = converted_outputs.rbegin(); it != converted_outputs.rend(); - ++it) { - ShapeIndex output_index = it->first; - HloInstruction*& output = it->second; - const Shape& subshape = - ShapeUtil::GetSubshape(hlo->shape(), output_index); - if (!ShapeUtil::IsTuple(subshape)) { - continue; - } - std::vector elements( - ShapeUtil::TupleElementCount(subshape)); - ShapeIndex child_index = output_index; - for (int64 i = 0; i < elements.size(); ++i) { - child_index.push_back(i); - elements[i] = converted_outputs.element(child_index); - child_index.pop_back(); - } - output = fusion_computation->AddInstruction( - HloInstruction::CreateTuple(elements)); - } - fusion_computation->set_root_instruction(converted_outputs.element({})); - needs_tuple_simplifier |= ShapeUtil::IsTuple(hlo->shape()); + // Deep copy the fusion root, and convert a leaf node only if its shape + // does not match the fusion output. + TF_ASSIGN_OR_RETURN( + HloInstruction * copy, + fusion_computation->DeepCopyInstructionWithCustomCopier( + fusion_root, + [hlo](HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* comp) { + const Shape& hlo_subshape = + ShapeUtil::GetSubshape(hlo->shape(), leaf_index); + if (ShapeUtil::Compatible(leaf->shape(), hlo_subshape)) { + return leaf; + } + return comp->AddInstruction( + HloInstruction::CreateConvert(hlo_subshape, leaf)); + })); + fusion_computation->set_root_instruction(copy); } } - if (needs_tuple_simplifier) { - TupleSimplifier tuple_simplifier; - TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); - } return Status::OK(); } @@ -758,10 +723,38 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { changes_to_bf16_.clear(); changed_ = false; + auto computations_topological_order = module->MakeComputationPostOrder(); + + // Before running the propagation pass, we insert copies (kConvert to the same + // type) of F32 inputs to while loops. This prevents other uses of the same + // input from aliasing the while loop input/output, so that there's greater + // chance to use BF16 inside the loop. If some of these added copies do not + // help, they will remain F32 after BF16 propagation and will be removed since + // they are no-ops. + for (auto computation : computations_topological_order) { + for (auto inst : computation->MakeInstructionPostOrder()) { + if (inst->opcode() != HloOpcode::kWhile) { + continue; + } + + auto operand = inst->mutable_operand(0); + TF_ASSIGN_OR_RETURN( + HloInstruction * copy, + computation->DeepCopyInstructionWithCustomCopier( + operand, [](HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* comp) { + if (leaf->shape().element_type() != F32) { + return leaf; + } + return comp->AddInstruction( + HloInstruction::CreateConvert(leaf->shape(), leaf)); + })); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(inst, copy)); + } + } + TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module)); - const auto& computations_topological_order = - module->MakeComputationPostOrder(); // The first step is a forward pass (parameters to root), where we determine // the potential candidate instructions to use bfloat16 in the outputs that // are not likely to cause overhead from extra explicit conversions. This is @@ -784,8 +777,7 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { // propagation in reverse topological order. for (auto comp_it = computations_topological_order.rbegin(); comp_it != computations_topological_order.rend(); ++comp_it) { - if ((*comp_it)->IsFusionComputation()) { - // Fusion computations are handled when visiting the fusion instruction. + if (ContainsKey(computations_visited_in_backward_pass_, *comp_it)) { continue; } auto insts = (*comp_it)->MakeInstructionPostOrder(); @@ -793,6 +785,7 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/true); } + computations_visited_in_backward_pass_.insert(*comp_it); } // It's possible that an instruction does not define a buffer, but the @@ -810,23 +803,27 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { } } + // Removes redundant HLOs added by this pass, either when inserting + // de-aliasing copies to while loop inputs, or later when converting output + // types. + auto clean_up = [this, module]() { + TF_RETURN_IF_ERROR(SkipNoopConversions(module)); + TupleSimplifier tuple_simplifier; + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + HloDCE dce; + TF_RETURN_IF_ERROR(dce.Run(module).status()); + return Status::OK(); + }; + if (!changed_) { + TF_RETURN_IF_ERROR(clean_up()); return false; } TF_RETURN_IF_ERROR(ResolveInconsistentFusions(module)); TF_RETURN_IF_ERROR(ResolveConvertedConstants(module)); - // This pass could have turned an F32 -> BF16 conversion to a no-op (BF16 -> - // BF16), so we skip them now. - TF_RETURN_IF_ERROR(SkipNoopConversions(module)); - - { - // We may have dead HLOs after ResolveInconsistentFusions, - // ResolveConvertedConstants and SkipNoopConversions. - HloDCE dce; - TF_RETURN_IF_ERROR(dce.Run(module).status()); - } + TF_RETURN_IF_ERROR(clean_up()); return true; } diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 560910cc5ffbf74737b6f025f7da2928c9cd621b..69b654d30e42b1ed69304206f09120e86831d468 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -133,9 +133,9 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { array_b.FillUnique(10.0f); HloInstruction* a = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateFromArray(array_a))); + HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_a))); HloInstruction* b = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateFromArray(array_b))); + HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b))); HloInstruction* dot = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kDot, a, b)); @@ -150,10 +150,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant); EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)), + *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_a)), dot->operand(0)->literal())); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)), + *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_b)), dot->operand(1)->literal())); } @@ -240,12 +240,10 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) { EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); - EXPECT_TRUE(OutputsBF16(add0)); EXPECT_TRUE(OutputsBF16(add1)); EXPECT_TRUE(OutputsBF16(lhs)); - // rhs is a get-tuple-element, which does not define a buffer, but its shape - // should also be adjusted accordingly. - EXPECT_TRUE(OutputsBF16(rhs)); + + // add0 and rhs have been eliminated by simplification and DCE. } // Tests that a non-fusion computation's root should not be changed. @@ -434,7 +432,7 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) { HloInstruction* tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({param, add1})); HloInstruction* sel = builder.AddInstruction(HloInstruction::CreateTernary( - tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); + tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1)); HloInstruction* gte0 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(shape, sel, 0)); HloInstruction* gte1 = builder.AddInstruction( @@ -510,6 +508,63 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { EXPECT_FALSE(OutputsBF16(dot)); } +// Tests that if the while condition prevents using BF16, no changes should be +// made to the while body and thus the fusion node inside it. +TEST_F(BFloat16PropagationTest, + ConditionPreventsPropagationForFusionInsideWhile) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "param1")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + + auto builder_cond = HloComputation::Builder("cond"); + auto cond_param = builder_cond.AddInstruction( + HloInstruction::CreateParameter(0, shape, "cond_param")); + builder_cond.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond_param, {0, 0}, {1, 1}, {1, 1})), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond_param, {1, 1}, {2, 2}, {1, 1})))); + auto cond = module->AddEmbeddedComputation(builder_cond.Build()); + + auto builder_body = HloComputation::Builder("body"); + auto body_param = builder_body.AddInstruction( + HloInstruction::CreateParameter(0, shape, "body_param")); + auto body_transpose = builder_body.AddInstruction( + HloInstruction::CreateTranspose(shape, body_param, {0, 1})); + + auto builder_f = HloComputation::Builder("fusion"); + HloInstruction* a_f = + builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + builder_f.AddInstruction(HloInstruction::CreateTranspose(shape, a_f, {0, 1})); + auto comp_f = module->AddEmbeddedComputation(builder_f.Build()); + auto body_fusion = builder_body.AddInstruction(HloInstruction::CreateFusion( + shape, HloInstruction::FusionKind::kCustom, {body_transpose}, comp_f)); + auto body = module->AddEmbeddedComputation(builder_body.Build()); + + auto while_hlo = builder.AddInstruction( + HloInstruction::CreateWhile(shape, cond, body, add)); + + auto dot = builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, while_hlo, while_hlo)); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_FALSE(OutputsBF16(add)); + EXPECT_FALSE(OutputsBF16(body_fusion)); + EXPECT_FALSE(OutputsBF16(body_param)); + EXPECT_FALSE(OutputsBF16(body_transpose)); + EXPECT_FALSE(OutputsBF16(a_f)); +} + // Tests that BF16 is propagated properly through while computations with // tuple-shaped input/output. TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { @@ -555,10 +610,14 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { HloInstruction::CreateGetTupleElement(shape, body_param, 0)); auto body_rhs = builder_body.AddInstruction( HloInstruction::CreateGetTupleElement(shape, body_param, 1)); - auto body_dot = builder_body.AddInstruction( + auto body_dot1 = builder_body.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs)); + auto body_dot2 = builder_body.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_rhs, body_lhs)); + auto body_transpose = builder_body.AddInstruction( + HloInstruction::CreateTranspose(shape, body_dot2, {0, 1})); builder_body.AddInstruction( - HloInstruction::CreateTuple({body_dot, body_rhs})); + HloInstruction::CreateTuple({body_dot1, body_transpose})); auto body = module->AddEmbeddedComputation(builder_body.Build()); auto while_hlo = builder.AddInstruction( @@ -577,9 +636,11 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(lhs)); EXPECT_FALSE(OutputsBF16(rhs)); - EXPECT_TRUE(OutputsBF16(body_dot)); + EXPECT_TRUE(OutputsBF16(body_dot1)); EXPECT_TRUE(OutputsBF16(body_lhs)); EXPECT_FALSE(OutputsBF16(body_rhs)); + EXPECT_FALSE(OutputsBF16(body_dot2)); + EXPECT_FALSE(OutputsBF16(body_transpose)); EXPECT_TRUE(OutputsBF16(cond_lhs)); EXPECT_FALSE(OutputsBF16(cond_rhs)); EXPECT_TRUE(OutputsBF16(add0)); @@ -734,10 +795,8 @@ TEST_F(BFloat16PropagationTest, NoopConversionRemoved) { EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), add2); - EXPECT_EQ(add2->operand(0), gte0); - EXPECT_EQ(add2->operand(1), gte1); - EXPECT_EQ(gte0->shape().element_type(), BF16); - EXPECT_EQ(gte1->shape().element_type(), BF16); + EXPECT_EQ(add2->operand(0), add0); + EXPECT_EQ(add2->operand(1), add1); EXPECT_EQ(add0->shape().element_type(), BF16); EXPECT_EQ(add1->shape().element_type(), BF16); } diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc index 8595afca7e735528d9ef29a323696c0661fe971c..23645346e6f491beb5171cc839c013ce5f83d789 100644 --- a/tensorflow/compiler/xla/service/bfloat16_support.cc +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -103,6 +103,7 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( case HloOpcode::kDynamicUpdateSlice: return operand_index == 0 || operand_index == 1; case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: return operand_index == 1 || operand_index == 2; default: break; diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index afe4b2e1425f9e84320ffd5f08beceaac8168c22..118a11c8de3c06d240079723f0a5db314cfcace5 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -270,7 +270,7 @@ BufferAllocationProto BufferAllocation::ToProto() const { proto.set_index(index_); proto.set_size(size_); proto.set_is_thread_local(is_thread_local_); - proto.set_is_reusable(is_reusable_); + proto.set_is_tuple(is_tuple_); proto.set_color(color_.value()); if (is_entry_computation_parameter_) { proto.set_is_entry_computation_parameter(true); @@ -279,6 +279,7 @@ BufferAllocationProto BufferAllocation::ToProto() const { } proto.set_parameter_number(parameter_number_); } + proto.set_is_constant(is_constant_); proto.set_maybe_live_out(maybe_live_out_); for (const auto& buffer_offset_size : assigned_buffers_) { BufferAllocationProto::Assigned* proto_assigned = proto.add_assigned(); @@ -304,6 +305,9 @@ string BufferAllocation::ToString() const { StrAppend(&output, ", parameter ", parameter_number(), " at ShapeIndex ", param_shape_index().ToString()); } + if (is_constant()) { + StrAppend(&output, ", constant"); + } if (is_thread_local()) { StrAppend(&output, ", thread-local"); } @@ -491,20 +495,16 @@ BufferAssignment::GetUniqueTopLevelOutputSlice() const { } BufferAllocation* BufferAssignment::NewEmptyAllocation( - int64 size, bool is_thread_local, bool is_reusable, - LogicalBuffer::Color color) { + int64 size, LogicalBuffer::Color color) { BufferAllocation::Index index = allocations_.size(); - allocations_.emplace_back(index, size, is_thread_local, is_reusable, color); + allocations_.emplace_back(index, size, color); BufferAllocation* allocation = &allocations_.back(); return allocation; } BufferAllocation* BufferAssignment::NewAllocation(const LogicalBuffer& buffer, - int64 size, - bool is_thread_local, - bool is_reusable) { - BufferAllocation* allocation = - NewEmptyAllocation(size, is_thread_local, is_reusable, buffer.color()); + int64 size) { + BufferAllocation* allocation = NewEmptyAllocation(size, buffer.color()); AddAssignment(allocation, buffer, /*offset=*/0, size); allocation->peak_buffers_.push_back(&buffer); return allocation; @@ -517,7 +517,8 @@ void BufferAssignment::AddAssignment(BufferAllocation* allocation, CHECK_EQ(0, allocation_index_for_buffer_.count(&buffer)) << "LogicalBuffer " << buffer << " already has an allocation."; CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty()) - << "Non-reusable allocation already assigned a buffer"; + << "Non-reusable allocation already assigned a buffer: " + << allocation->ToString(); TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer)); @@ -609,6 +610,10 @@ Status BufferAssignment::ComputeSummaryStats() { stats_.parameter_allocation_count++; stats_.parameter_allocation_bytes += allocation.size(); } + if (allocation.is_constant()) { + stats_.constant_allocation_count++; + stats_.constant_allocation_bytes += allocation.size(); + } if (allocation.maybe_live_out()) { stats_.maybe_live_out_allocation_count++; stats_.maybe_live_out_allocation_bytes += allocation.size(); @@ -645,6 +650,8 @@ string BufferAssignment::Stats::ToString() const { Appendf(&s, "BufferAssignment stats:\n"); Appendf(&s, " parameter allocation: %10s\n", HumanReadableNumBytes(parameter_allocation_bytes).c_str()); + Appendf(&s, " constant allocation: %10s\n", + HumanReadableNumBytes(constant_allocation_bytes).c_str()); Appendf(&s, " maybe_live_out allocation: %10s\n", HumanReadableNumBytes(maybe_live_out_allocation_bytes).c_str()); Appendf(&s, " preallocated temp allocation: %10s\n", @@ -722,8 +729,10 @@ StatusOr> BufferAssigner::Run( const HloModule* module, std::unique_ptr hlo_ordering, LogicalBuffer::SizeFunction buffer_size, LogicalBuffer::AlignmentFunction color_alignment, - bool allow_input_output_aliasing, BufferLiveness::Colorer colorer) { - BufferAssigner assigner(allow_input_output_aliasing, std::move(colorer)); + bool allow_input_output_aliasing, bool allocate_buffers_for_constants, + BufferLiveness::Colorer colorer) { + BufferAssigner assigner(allow_input_output_aliasing, + allocate_buffers_for_constants, std::move(colorer)); return assigner.CreateAssignment(module, std::move(hlo_ordering), std::move(buffer_size), std::move(color_alignment)); @@ -751,8 +760,8 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, return false; } - if (allocation->is_entry_computation_parameter()) { - VLOG(4) << "Can't assign: allocation holds parameter"; + if (allocation->is_readonly()) { + VLOG(4) << "Can't assign: allocation is readonly"; return false; } @@ -808,8 +817,7 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, } Status BufferAssigner::AssignBuffersForComputation( - const HloComputation* computation, const DebugOptions& debug_options, - bool is_thread_local, + const HloComputation* computation, bool is_thread_local, const FlatSet& colocated_buffers, const FlatSet& colocated_allocations, FlatMap>* @@ -869,8 +877,8 @@ Status BufferAssigner::AssignBuffersForComputation( // important reuse case where an elementwise instruction reuses one of its // operand's buffer. This improves locality. std::sort(sorted_buffers.begin(), sorted_buffers.end(), - [this, has_sequential_order, &liveness, &post_order_position, - assignment](const LogicalBuffer* a, const LogicalBuffer* b) { + [has_sequential_order, &liveness, &post_order_position, assignment]( + const LogicalBuffer* a, const LogicalBuffer* b) { // Primary sort is by decreasing buffer size. const int64 a_size = assignment->buffer_size_(*a); const int64 b_size = assignment->buffer_size_(*b); @@ -905,15 +913,19 @@ Status BufferAssigner::AssignBuffersForComputation( TF_RET_CHECK(!assignment->HasAllocation(*buffer)); const HloInstruction* instruction = buffer->instruction(); + const int64 buffer_size = assignment->buffer_size_(*buffer); + if (instruction->opcode() == HloOpcode::kConstant) { - // No BufferAllocations for constants. - // TODO(b/32248867): For consistency, constants should get allocations. - VLOG(3) << "Skipping constant: " << *buffer; + if (allocate_buffers_for_constants_) { + BufferAllocation* allocation = + assignment->NewAllocation(*buffer, buffer_size); + allocation->set_constant(true); + VLOG(3) << "New allocation #" << allocation->index() << " for constant " + << *buffer; + } continue; } - const int64 buffer_size = assignment->buffer_size_(*buffer); - const bool is_entry_parameter = instruction->opcode() == HloOpcode::kParameter && computation == computation->parent()->entry_computation(); @@ -923,9 +935,7 @@ Status BufferAssigner::AssignBuffersForComputation( // computations do not need special allocations because they live inside // callers. BufferAllocation* allocation = - assignment->NewAllocation(*buffer, buffer_size, - /*is_thread_local=*/false, - /*is_reusable=*/false); + assignment->NewAllocation(*buffer, buffer_size); allocation->set_entry_computation_parameter( instruction->parameter_number(), buffer->index()); VLOG(3) << "New allocation #" << allocation->index() @@ -934,20 +944,18 @@ Status BufferAssigner::AssignBuffersForComputation( } if (is_thread_local) { - // We do not reuse thread-local buffers for now, because they are - // dynamically allocated and their lifetimes are hard to compute. - BufferAllocation* allocation = assignment->NewAllocation( - *buffer, buffer_size, is_thread_local, /*is_reusable=*/false); + BufferAllocation* allocation = + assignment->NewAllocation(*buffer, buffer_size); + allocation->set_is_thread_local(true); VLOG(3) << "New allocation #" << allocation->index() << " for thread-local: " << *buffer; continue; } if (ShapeUtil::IsTuple(buffer->shape())) { - // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend - // assumes longer buffer liveness than indicated by the analysis. - BufferAllocation* allocation = assignment->NewAllocation( - *buffer, buffer_size, is_thread_local, /*is_reusable=*/false); + BufferAllocation* allocation = + assignment->NewAllocation(*buffer, buffer_size); + allocation->set_is_tuple(true); VLOG(3) << "New allocation #" << allocation->index() << " for tuple-shaped buffer: " << *buffer; continue; @@ -1030,8 +1038,8 @@ Status BufferAssigner::AssignBuffersForComputation( } if (!assignment->HasAllocation(*buffer)) { - BufferAllocation* allocation = assignment->NewAllocation( - *buffer, buffer_size, is_thread_local, /*is_reusable=*/true); + BufferAllocation* allocation = + assignment->NewAllocation(*buffer, buffer_size); allocation_indices.push_back(allocation->index()); VLOG(3) << "New allocation #" << allocation->index() << " for: " << *buffer; @@ -1085,6 +1093,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( VLOG(2) << "Simulating heap for color " << color; int64 alignment = assignment->color_alignment_(color); HeapSimulator::Options options; + options.alloc_constants = allocate_buffers_for_constants_; BufferValueFlatSet buffer_value_set = ToBufferValueFlatSet(single_colored_set.second); options.buffers_to_assign = &buffer_value_set; @@ -1227,8 +1236,8 @@ void BufferAssigner::AssignBuffersFromHeapSimulator( result.fragmentation_size; } - BufferAllocation* allocation = assignment->NewEmptyAllocation( - result.heap_size, /*is_thread_local=*/false, /*is_reusable=*/true, color); + BufferAllocation* allocation = + assignment->NewEmptyAllocation(result.heap_size, color); for (const auto& buffer_chunk : result.chunk_map) { // TODO(lauj) Remove this down_cast after downstream users of // BufferAllocation::assigned_buffers() are updated to use BufferValue. @@ -1332,11 +1341,25 @@ BufferAssigner::MergeColocatedBufferSets( auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness, &buffer_size, &is_entry_parameter](int64 i, int64 j) { - // Do not merge if one of the sets includes live outs or entry parameters. + // Do not merge if one of the sets includes live outs, entry parameters or + // constants. + // + // Buffer liveness does not report the correct live range for entry + // parameter and live out buffers so we have to special case them here. On + // backends that support constant buffer allocations, constant buffers are + // assigned globals in readonly storage so we can't merge colocated buffer + // sets containing constants with colocated buffer sets containing writing + // instructions or other constants. + // + // Moreover (on the CPU/GPU backends) the entry parameter buffers belong to + // the caller of the executable so we can't write to entry parameters + // either, and the argument for not merging constants also applies to entry + // parameters. for (int64 key : {i, j}) { for (auto& buffer : colocated_buffer_sets[key]) { if (buffer_liveness.MaybeLiveOut(*buffer) || - is_entry_parameter(*buffer)) { + is_entry_parameter(*buffer) || + buffer->instruction()->opcode() == HloOpcode::kConstant) { return true; } } @@ -1418,9 +1441,9 @@ void BufferAssigner::BuildColocatedBufferSets( const HloInstruction* while_hlo = instruction; ShapeUtil::ForEachSubshape( while_hlo->shape(), - [this, while_hlo, &points_to_analysis, &buffer_liveness, - buffer_size, computation, colocated_buffer_sets]( - const Shape& /*subshape*/, const ShapeIndex& index) { + [this, while_hlo, &points_to_analysis, buffer_size, + colocated_buffer_sets](const Shape& /*subshape*/, + const ShapeIndex& index) { std::vector colocated_set; // Add while.init. AddBufferToColocatedSet(while_hlo->operand(0), index, @@ -1444,8 +1467,23 @@ void BufferAssigner::BuildColocatedBufferSets( }); } else if (opcode == HloOpcode::kCall) { const HloInstruction* call_hlo = instruction; - const HloInstruction* root_hlo = - call_hlo->to_apply()->root_instruction(); + const HloComputation* callee = call_hlo->to_apply(); + const HloInstruction* root_hlo = callee->root_instruction(); + for (int64 i = 0; i < call_hlo->operand_count(); i++) { + const HloInstruction* call_param = callee->parameter_instruction(i); + const HloInstruction* call_operand = call_hlo->operand(i); + ShapeUtil::ForEachSubshape( + call_operand->shape(), + [&](const Shape& /*subshape*/, const ShapeIndex& index) { + std::vector colocated_set; + AddBufferToColocatedSet(call_param, index, points_to_analysis, + &colocated_set); + AddBufferToColocatedSet(call_operand, index, points_to_analysis, + &colocated_set); + AddSetToColocatedBufferSets(colocated_set, + colocated_buffer_sets); + }); + } ShapeUtil::ForEachSubshape( call_hlo->shape(), [this, call_hlo, root_hlo, &points_to_analysis, @@ -1551,6 +1589,7 @@ void BufferAssigner::AssignColocatedBufferSets( // param in 'colocated_buffer_set'. int64 entry_parameter_number = -1; const ShapeIndex* entry_parameter_shape_idx = nullptr; + bool is_constant = false; for (const LogicalBuffer* buffer : colocated_buffer_set) { const HloInstruction* instruction = buffer->instruction(); const HloComputation* computation = instruction->parent(); @@ -1558,10 +1597,14 @@ void BufferAssigner::AssignColocatedBufferSets( computation == computation->parent()->entry_computation()) { entry_parameter_number = instruction->parameter_number(); entry_parameter_shape_idx = &buffer->index(); - break; + } else if (instruction->opcode() == HloOpcode::kConstant) { + is_constant = true; } } + CHECK(!is_constant || entry_parameter_number == -1) + << "Copy insertion should have inserted copies to prevent this."; + for (const LogicalBuffer* buffer : colocated_buffer_set) { const int64 buffer_size = assignment->buffer_size_(*buffer); if (allocation == nullptr) { @@ -1569,18 +1612,14 @@ void BufferAssigner::AssignColocatedBufferSets( // allocations for each colocated buffer set. When liveness has // module-level scope, we can allow buffers to be shared across // computations (in some cases). - allocation = assignment->NewAllocation(*buffer, buffer_size, - /*is_thread_local=*/false, - /*is_reusable=*/true); + allocation = assignment->NewAllocation(*buffer, buffer_size); if (entry_parameter_number >= 0) { - // This colocated buffer set contains an entry parameter and other - // logical buffers which use the parameter as read-only in a while - // body computation (which updates in place). - // Set 'entry_computation_parameter' to indicate that it contains - // an entry parameter, and to prevent reuse in MaybeAssignBuffer. allocation->set_entry_computation_parameter( entry_parameter_number, *entry_parameter_shape_idx); } + if (is_constant) { + allocation->set_constant(true); + } colocated_allocations->insert(allocation->index()); } else { CHECK_EQ(buffer_size, allocation->size()) @@ -1638,7 +1677,7 @@ StatusOr> BufferAssigner::CreateAssignment( buffers_to_assign_sequentially; for (auto* computation : global_computations) { TF_RETURN_IF_ERROR(AssignBuffersForComputation( - computation, module->config().debug_options(), + computation, /*is_thread_local=*/false, colocated_buffers, colocated_allocations, &buffers_to_assign_sequentially, assignment.get())); } @@ -1659,7 +1698,7 @@ StatusOr> BufferAssigner::CreateAssignment( continue; } TF_RETURN_IF_ERROR(AssignBuffersForComputation( - computation, module->config().debug_options(), + computation, /*is_thread_local=*/true, colocated_buffers, colocated_allocations, /*buffers_to_assign_sequentially=*/nullptr, assignment.get())); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index ad0b0bf7c25d7194a06801e4ef1c9ee961f6b915..94495290c131e22392079dc2d0237d990b646d3e 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -32,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" @@ -58,13 +57,8 @@ class BufferAllocation { // contiguously and can be used as array indexes. using Index = int64; - BufferAllocation(Index index, int64 size, bool is_thread_local, - bool is_reusable, LogicalBuffer::Color color) - : index_(index), - size_(size), - is_thread_local_(is_thread_local), - is_reusable_(is_reusable), - color_(color) {} + BufferAllocation(Index index, int64 size, LogicalBuffer::Color color) + : index_(index), size_(size), color_(color) {} ~BufferAllocation() {} // Returns the index of this allocation. @@ -74,9 +68,28 @@ class BufferAllocation { // inside of a map or reduce computation. Such allocations need to be thread // local. bool is_thread_local() const { return is_thread_local_; } + void set_is_thread_local(bool is_thread_local) { + is_thread_local_ = is_thread_local; + } // Whether this allocation can be used by more than one logical buffer. - bool is_reusable() const { return is_reusable_; } + bool is_reusable() const { + // We do not reuse thread-local buffers for now, because they are + // dynamically allocated and their lifetimes are hard to compute. + // + // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend + // assumes longer buffer liveness than indicated by the analysis. + return !is_thread_local() && !is_tuple(); + } + + // Whether this allocation is readonly i.e. backed by memory we cannot write + // to. + bool is_readonly() const { + return is_entry_computation_parameter() || is_constant(); + } + + bool is_tuple() const { return is_tuple_; } + void set_is_tuple(bool is_tuple) { is_tuple_ = is_tuple; } // Whether this allocation holds a LogicalBuffer from a parameter of the entry // computation. These buffers have lifetimes which may be longer than the @@ -84,6 +97,13 @@ class BufferAllocation { bool is_entry_computation_parameter() const { return is_entry_computation_parameter_; } + + // Whether this allocation holds a constant. On the CPU and GPU backends + // constant allocations are not allocated dynamically, instead we resolve + // references to these buffer allocations to a global in the readonly section + // of the binary. + bool is_constant() const { return is_constant_; } + // If this allocation holds a Buffer from a parameter of the entry // computation, this methods returns the parameter number. CHECKs otherwise. int64 parameter_number() const { @@ -189,7 +209,9 @@ class BufferAllocation { // of the computation. !maybe_live_out() && // Thread-local buffers are allocated using `alloca`s. - !is_thread_local(); + !is_thread_local() && + // Constant buffers are allocated as global values. + !is_constant(); } // Add a heap trace which was used to assign slices to logical buffers in this @@ -245,6 +267,8 @@ class BufferAllocation { parameter_number_ = parameter_number; param_shape_index_ = std::move(param_shape_index); } + + void set_constant(bool is_constant) { is_constant_ = is_constant; } void set_maybe_live_out(bool value) { maybe_live_out_ = value; } void set_index(Index index) { index_ = index; } void set_size(int64 size) { size_ = size; } @@ -256,10 +280,10 @@ class BufferAllocation { int64 size_; // Whether this buffer needs to be thread-local. - bool is_thread_local_; + bool is_thread_local_ = false; - // Whether this buffer is usable by more than one logical buffer. - bool is_reusable_; + // Whether this buffer holds a tuple. + bool is_tuple_ = false; // Color of the allocation. LogicalBuffer::Color color_; @@ -283,6 +307,9 @@ class BufferAllocation { // might not actually escape. bool maybe_live_out_ = false; + // See comment on the is_constant() accessor. + bool is_constant_ = false; + // Mapping from the set of buffers assigned to this allocation to their // logical offsets and sizes. tensorflow::gtl::FlatMap assigned_buffers_; @@ -398,6 +425,8 @@ class BufferAssignment { struct Stats { int64 parameter_allocation_count = 0; int64 parameter_allocation_bytes = 0; + int64 constant_allocation_count = 0; + int64 constant_allocation_bytes = 0; int64 maybe_live_out_allocation_count = 0; int64 maybe_live_out_allocation_bytes = 0; int64 preallocated_temp_allocation_count = 0; @@ -426,14 +455,11 @@ class BufferAssignment { // Creates and returns a new BufferAllocation, with no assigned // LogicalBuffers. Ownership is maintained internally. - BufferAllocation* NewEmptyAllocation(int64 size, bool is_thread_local, - bool is_reusable, - LogicalBuffer::Color color); + BufferAllocation* NewEmptyAllocation(int64 size, LogicalBuffer::Color color); // Helper that calls NewEmptyAllocation and AddAssignment in one call, // creating an allocation containing a single LogicalBuffer. - BufferAllocation* NewAllocation(const LogicalBuffer& buffer, int64 size, - bool is_thread_local, bool is_reusable); + BufferAllocation* NewAllocation(const LogicalBuffer& buffer, int64 size); // Adds a LogicalBuffer to the set assigned to the given allocation. void AddAssignment(BufferAllocation* allocation, const LogicalBuffer& buffer, @@ -493,12 +519,15 @@ class BufferAssigner { LogicalBuffer::SizeFunction buffer_size, LogicalBuffer::AlignmentFunction color_alignment, bool allow_input_output_aliasing = false, + bool allocate_buffers_for_constants = false, BufferLiveness::Colorer colorer = BufferLiveness::DefaultColorer()); private: BufferAssigner(bool allow_input_output_aliasing, + bool allocate_buffers_for_constants, BufferLiveness::Colorer colorer) : allow_input_output_aliasing_(allow_input_output_aliasing), + allocate_buffers_for_constants_(allocate_buffers_for_constants), colorer_(colorer) {} virtual ~BufferAssigner() = default; @@ -513,8 +542,7 @@ class BufferAssigner { // true, then all assigned buffers have the is_thread_local flag set to // true. Status AssignBuffersForComputation( - const HloComputation* computation, const DebugOptions& debug_options, - bool is_thread_local, + const HloComputation* computation, bool is_thread_local, const tensorflow::gtl::FlatSet& colocated_buffers, const tensorflow::gtl::FlatSet& colocated_allocations, @@ -595,6 +623,9 @@ class BufferAssigner { // buffers can be shared if their sizes match. bool allow_input_output_aliasing_; + // If true, allocate buffers for constant instructions. + bool allocate_buffers_for_constants_; + // Functor used to assign colors to newly allocated logical buffers. BufferLiveness::Colorer colorer_; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 28b5a5784ff7f5d0b7fd412d1c50f3025f11bb81..eccb146a0d7d628870be179a540d9750df3fe41c 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/call_graph.h" @@ -89,7 +89,20 @@ class BufferAssignmentTest : public HloTestBase { return BufferAssigner::Run( module, xla::MakeUnique(module), backend().compiler()->BufferSizeBytesFunction(), - [alignment](LogicalBuffer::Color) { return alignment; }) + [alignment](LogicalBuffer::Color) { return alignment; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true) + .ConsumeValueOrDie(); + } + + std::unique_ptr RunBufferAssignmentNoBuffersForConstants( + HloModule* module, int64 alignment = 1) { + return BufferAssigner::Run( + module, xla::MakeUnique(module), + backend().compiler()->BufferSizeBytesFunction(), + [alignment](LogicalBuffer::Color) { return alignment; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/false) .ConsumeValueOrDie(); } @@ -98,8 +111,9 @@ class BufferAssignmentTest : public HloTestBase { return BufferAssigner::Run( module, xla::MakeUnique(module), backend().compiler()->BufferSizeBytesFunction(), - [alignment](LogicalBuffer::Color) { return alignment; }, false, - std::move(colorer)) + [alignment](LogicalBuffer::Color) { return alignment; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true, std::move(colorer)) .ConsumeValueOrDie(); } @@ -115,7 +129,9 @@ class BufferAssignmentTest : public HloTestBase { module, xla::MakeUnique(module, module_sequence), backend().compiler()->BufferSizeBytesFunction(), - [alignment](LogicalBuffer::Color) { return alignment; }) + [alignment](LogicalBuffer::Color) { return alignment; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true) .ConsumeValueOrDie(); } @@ -125,7 +141,7 @@ class BufferAssignmentTest : public HloTestBase { auto param = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); auto value = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value)); return builder.Build(); @@ -142,7 +158,7 @@ class BufferAssignmentTest : public HloTestBase { const string& name) { auto builder = HloComputation::Builder(name); auto const4 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(4))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); auto index = builder.AddInstruction( @@ -167,9 +183,9 @@ class BufferAssignmentTest : public HloTestBase { const string& name) { auto builder = HloComputation::Builder(name); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto constv = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); auto indexc = builder.AddInstruction( @@ -290,13 +306,19 @@ static bool BuffersDistinct(const std::vector& a, TEST_F(BufferAssignmentTest, ScalarConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); - // Check that the constant does not have a buffer assigned. - EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); + { + auto buffers = RunBufferAssignment(module.get()); + EXPECT_TRUE(buffers->HasTopLevelAllocation(const0)); + } + + { + auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get()); + EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); + } } TEST_F(BufferAssignmentTest, BufferForConst) { @@ -304,20 +326,26 @@ TEST_F(BufferAssignmentTest, BufferForConst) { // no buffers assigned, and their consumer has a buffer. auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({4.1f, 4.2f, 4.3f, 4.4f}))); + LiteralUtil::CreateR1({4.1f, 4.2f, 4.3f, 4.4f}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); - // The two constant nodes have no buffers assigned. - EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); - EXPECT_FALSE(buffers->HasTopLevelAllocation(const1)); - // The add node has an output buffer. - GetAssignedOutputAllocation(*buffers, add); + { + auto buffers = RunBufferAssignment(module.get()); + EXPECT_TRUE(buffers->HasTopLevelAllocation(const0)); + EXPECT_TRUE(buffers->HasTopLevelAllocation(const1)); + GetAssignedOutputAllocation(*buffers, add); + } + { + auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get()); + EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); + EXPECT_FALSE(buffers->HasTopLevelAllocation(const1)); + GetAssignedOutputAllocation(*buffers, add); + } } TEST_F(BufferAssignmentTest, HasAllocationAt) { @@ -327,7 +355,7 @@ TEST_F(BufferAssignmentTest, HasAllocationAt) { auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, f32vec100_, "param0")); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); auto tuple = builder.AddInstruction( @@ -352,7 +380,7 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) { // This computation copies a constant to output. auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto copy = builder.AddInstruction( HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0)); auto module = CreateNewModule(); @@ -660,7 +688,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { auto exp2 = builder.AddInstruction( HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, exp1)); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( /*shape=*/f32vec10_, /*operand=*/exp2, @@ -708,9 +736,9 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { // Creates the main kernel and verifies instruction counts. auto builder = HloComputation::Builder(TestName()); auto const3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto const4 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({const3, const4})); auto while_op = builder.AddInstruction(HloInstruction::CreateWhile( @@ -773,11 +801,11 @@ TEST_F(BufferAssignmentTest, ExampleConditional) { auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(56.4f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(56.4f))); auto const2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(12.4f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(12.4f))); auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( r0f32_, pred, const1, true_computation, const2, false_computation)); module->AddEntryComputation(builder.Build()); @@ -1094,7 +1122,7 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) { // Allocations for the call computation should not be thread-local. auto& call_param_alloc = GetTopLevelAllocation(*assignment, call_param); - EXPECT_FALSE(call_param_alloc.is_entry_computation_parameter()); + EXPECT_TRUE(call_param_alloc.is_entry_computation_parameter()); EXPECT_FALSE(call_param_alloc.maybe_live_out()); EXPECT_FALSE(call_param_alloc.is_thread_local()); @@ -1196,12 +1224,13 @@ TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) { // TODO(b/32248867): Enable when buffer assignment gives allocations to // constants. -TEST_F(BufferAssignmentTest, DISABLED_TupleConstantAsOutput) { +TEST_F(BufferAssignmentTest, TupleConstantAsOutput) { // Test that a tuple constant which is forwarded to the computation output // is properly handled. auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction(HloInstruction::CreateConstant(Literal::MakeTuple( - {Literal::CreateR0(0).get(), Literal::CreateR0(1).get()}))); + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), + LiteralUtil::CreateR0(1).get()}))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -1252,16 +1281,18 @@ TEST_F(BufferAssignmentTest, TupleCallAsOutput) { auto assignment = RunBufferAssignment(module.get()); - EXPECT_EQ(3, assignment->Allocations().size()); + EXPECT_EQ(2, assignment->Allocations().size()); // Buffers for call are colocated with the sub-computation. EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{}), GetAllocation(*assignment, sub_tuple, /*index=*/{})); EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{0}), GetAllocation(*assignment, sub_param, /*index=*/{})); - // The parameter isn't aliased with anything. + + // The parameter isn't aliased with the result tuple, but it is aliased with + // the call operand. EXPECT_NE(GetTopLevelAllocation(*assignment, param), GetTopLevelAllocation(*assignment, sub_tuple)); - EXPECT_NE(GetTopLevelAllocation(*assignment, param), + EXPECT_EQ(GetTopLevelAllocation(*assignment, param), GetTopLevelAllocation(*assignment, sub_param)); } @@ -1325,13 +1356,15 @@ TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) { GetAllocation(*assignment, c_call, /*index=*/{0})); EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{0}), GetAllocation(*assignment, d_param, /*index=*/{0})); - // The parameters aren't aliased with anything. + EXPECT_TRUE(BuffersDistinct({a_param}, {b_param}, *assignment)); EXPECT_TRUE(BuffersDistinct({a_param}, {c_param}, *assignment)); EXPECT_TRUE(BuffersDistinct({a_param}, {d_param}, *assignment)); - EXPECT_TRUE(BuffersDistinct({b_param}, {c_param}, *assignment)); - EXPECT_TRUE(BuffersDistinct({b_param}, {d_param}, *assignment)); - EXPECT_TRUE(BuffersDistinct({c_param}, {d_param}, *assignment)); + + EXPECT_EQ(GetAllocation(*assignment, b_param, /*index=*/{0}), + GetAllocation(*assignment, c_param, /*index=*/{0})); + EXPECT_EQ(GetAllocation(*assignment, c_param, /*index=*/{0}), + GetAllocation(*assignment, d_param, /*index=*/{0})); } TEST_F(BufferAssignmentTest, BitcastAsOutput) { @@ -1365,8 +1398,9 @@ TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) { HloInstruction::CreateParameter(1, tuple_shape, "param1")); auto pred_param = builder.AddInstruction(HloInstruction::CreateParameter( 2, ShapeUtil::MakeShape(PRED, {}), "param1")); - auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred_param, tuple_param0, tuple_param1)); + auto select = builder.AddInstruction( + HloInstruction::CreateTernary(tuple_shape, HloOpcode::kTupleSelect, + pred_param, tuple_param0, tuple_param1)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -1583,7 +1617,7 @@ TEST_F(BufferAssignmentTest, PeakBuffersWhile) { auto b = HloComputation::Builder(TestName() + ".cond"); b.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); condition = module->AddEmbeddedComputation(b.Build()); } HloComputation* body; @@ -1638,6 +1672,66 @@ TEST_F(BufferAssignmentTest, PeakBuffersWhile) { nonbcast_buffer->instruction() == condition->parameter_instruction(0)); } +TEST_F(BufferAssignmentTest, ConstantBuffersAreNotReused) { + const char* hlo_text = R"( +HloModule Module + +True { + ROOT x.0.1 = f32[] parameter(0) +} + +False { + x.0.0 = f32[] parameter(0) + ROOT copy.1 = f32[] copy(x.0.0) +} + +ENTRY main { + pred.1.0 = pred[] parameter(0) + constant.1.1 = f32[] constant(56) + copy.2 = f32[] copy(constant.1.1) + constant.1.2 = f32[] constant(12) + ROOT conditional.1.3 = f32[] conditional(pred.1.0, copy.2, constant.1.2), + true_computation=True, false_computation=False +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + HloInstruction* constant_1 = + module->entry_computation()->GetInstructionWithName("constant.1.1"); + HloInstruction* constant_2 = + module->entry_computation()->GetInstructionWithName("constant.1.2"); + + auto buffers = RunBufferAssignment(module.get()); + + { + const BufferAllocation& allocation_for_const_1 = + GetTopLevelAllocation(*buffers, constant_1); + EXPECT_TRUE(allocation_for_const_1.is_constant()); + for (const auto& buffer_offset_pair : + allocation_for_const_1.assigned_buffers()) { + EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(), + HloOpcode::kCopy); + EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(), + HloOpcode::kConditional); + } + } + + { + const BufferAllocation& allocation_for_const_2 = + GetTopLevelAllocation(*buffers, constant_2); + EXPECT_TRUE(allocation_for_const_2.is_constant()); + for (const auto& buffer_offset_pair : + allocation_for_const_2.assigned_buffers()) { + EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(), + HloOpcode::kCopy); + EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(), + HloOpcode::kConditional); + } + } +} + class WhileBufferAssignmentTest : public HloTestBase { protected: std::unique_ptr BuildWhileConditionComputation( @@ -1646,9 +1740,9 @@ class WhileBufferAssignmentTest : public HloTestBase { builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto ten = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(10))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten)); return builder.Build(); @@ -1677,7 +1771,9 @@ class WhileBufferAssignmentTest : public HloTestBase { return BufferAssigner::Run( module, xla::MakeUnique(module, sequence), ByteSizeOf, - [alignment](LogicalBuffer::Color) { return alignment; }) + [alignment](LogicalBuffer::Color) { return alignment; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true) .ConsumeValueOrDie(); } @@ -1707,7 +1803,7 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { HloInstruction::CreateParameter(2, data_shape_, "weights1")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); auto output1 = builder.AddInstruction( @@ -1827,6 +1923,74 @@ ENTRY %test_module { EXPECT_NE(slice_param, slice_while1); } +TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithConstant) { + const Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + + const char* module_str = R"( +HloModule test_module + +%cond.v0 { + %param = s32[] parameter(0) + ROOT %constant = pred[] constant(true) +} + +%cond.v1 { + %param.0 = s32[] parameter(0) + ROOT %constant.0 = pred[] constant(true) +} + +%body.v0 { + ROOT %param.1 = s32[] parameter(0) +} + +%body.v1 { + %param.2 = s32[] parameter(0) + ROOT add = s32[] add(%param.2, %param.2) +} + +ENTRY %test_module { + %constant.42 = s32[] constant(42) + %while.0 = s32[] while(%constant.42), condition=%cond.v0, body=%body.v0 + %mul = s32[] multiply(%while.0, %while.0) + %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1 + ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + // Run CopyInsertion and check if the graph constructed above doesn't need + // any copies inserted for BufferAssignment to run. + int64 instruction_count = module->instruction_count(); + CopyInsertion copy_insertion; + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + ASSERT_EQ(instruction_count, module->instruction_count()); + + // Get the instructions in the module. + const HloInstruction* bcast = module->entry_computation()->root_instruction(); + const HloInstruction* constant = + module->entry_computation()->GetInstructionWithName("constant.42"); + ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); + const HloInstruction* while1 = bcast->operand(0); + ASSERT_EQ(while1->opcode(), HloOpcode::kWhile); + const HloInstruction* while0 = while1->operand(0)->operand(0); + ASSERT_EQ(while0->opcode(), HloOpcode::kWhile); + + // Run buffer assignment. + auto assignment = RunBufferAssignment(module.get()); + TF_ASSERT_OK_AND_ASSIGN(auto slice_constant, + assignment->GetUniqueSlice(constant, {})); + TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, + assignment->GetUniqueSlice(while0, {})); + TF_ASSERT_OK_AND_ASSIGN(auto slice_while1, + assignment->GetUniqueSlice(while1, {})); + + // The constant slice is part of the while0's colocation set (init value), but + // not merged into the while1's colocation set. + EXPECT_EQ(slice_constant, slice_while0); + EXPECT_NE(slice_constant, slice_while1); +} + // Tests that the colocated buffers for while instructions are properly assigned // during buffer assignment such that the result tuple elements are not assigned // to the same buffer. @@ -1850,7 +2014,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { auto build_cond = [&]() { auto builder = HloComputation::Builder("cond"); auto const4 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(4))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); auto param = builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x")); builder.AddInstruction(HloInstruction::CreateBinary( @@ -1862,7 +2026,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { auto build_body = [&]() { auto builder = HloComputation::Builder("body"); auto const9 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(9))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(9))); auto param = builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x")); builder.AddInstruction( @@ -1874,7 +2038,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { auto module = CreateNewModule(); auto builder = HloComputation::Builder("entry"); - auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); auto infeed = builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, token, "")); auto infeed_data = builder.AddInstruction( @@ -1890,7 +2054,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { HloInstruction::CreateWhile(r0s32, cond1, body1, while0)); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, zero, zero)); auto cond2 = module->AddEmbeddedComputation(build_cond()); @@ -1921,7 +2085,9 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { module.get(), xla::MakeUnique(module.get(), sequence), backend().compiler()->BufferSizeBytesFunction(), - [](LogicalBuffer::Color) { return 1; })); + [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // The result tuple elements must be assigned with different buffers. TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0})); @@ -1952,7 +2118,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { HloInstruction::CreateParameter(1, data_shape_, "weights0")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); @@ -1996,16 +2162,16 @@ TEST_F(BufferAssignmentTest, TwoCalls) { auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param")); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1)); sub_computation = module->AddEmbeddedComputation(builder.Build(add)); } auto builder = HloComputation::Builder(TestName()); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto call1 = builder.AddInstruction( HloInstruction::CreateCall(r0f32, {constant2}, sub_computation)); auto call2 = builder.AddInstruction( @@ -2029,6 +2195,56 @@ TEST_F(BufferAssignmentTest, TwoCalls) { EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment)); } +TEST_F(BufferAssignmentTest, CallParamCoAllocation) { + const char* hlo_text = R"( +HloModule CallParamCoAllocation + +Callee { + param0 = (f32[100],(f32[200],f32[300])) parameter(0) + param1 = s32[20] parameter(1) + ROOT constant = f32[] constant(1) +} + +ENTRY Main { + entry_param0 = f32[100] parameter(0) + entry_param1 = s32[20] parameter(1) + custom_call = (f32[200],f32[300]) custom-call(), custom_call_target="call-target" + call_op0 = (f32[100],(f32[200],f32[300])) tuple(entry_param0, custom_call) + ROOT call_result = f32[] call(call_op0, entry_param1), to_apply=Callee +} +)"; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + HloRunner::CreateModuleFromString( + hlo_text, legacy_flags::GetDebugOptionsFromFlags())); + + auto buffers = RunBufferAssignment(module.get()); + + HloComputation* main = module->entry_computation(); + HloComputation* callee = module->GetComputationWithName("Callee"); + EXPECT_NE(callee, nullptr); + + HloInstruction* param0 = callee->parameter_instruction(0); + HloInstruction* param1 = callee->parameter_instruction(1); + + HloInstruction* entry_param0 = main->parameter_instruction(0); + HloInstruction* entry_param1 = main->parameter_instruction(1); + HloInstruction* custom_call = main->GetInstructionWithName("custom_call"); + + EXPECT_EQ(GetAllocation(*buffers, entry_param0, {}), + GetAllocation(*buffers, param0, {0})); + EXPECT_EQ(GetAllocation(*buffers, entry_param1, {}), + GetAllocation(*buffers, param1, {})); + + EXPECT_EQ(GetAllocation(*buffers, custom_call, {}), + GetAllocation(*buffers, param0, {1})); + EXPECT_EQ(GetAllocation(*buffers, custom_call, {0}), + GetAllocation(*buffers, param0, {1, 0})); + EXPECT_EQ(GetAllocation(*buffers, custom_call, {1}), + GetAllocation(*buffers, param0, {1, 1})); +} + static bool IsPostOrderTraversal( const std::vector& sequence) { tensorflow::gtl::FlatSet seen_so_far; @@ -2057,9 +2273,9 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto builder = HloComputation::Builder(TestName()); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto input0 = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape_, "input0")); @@ -2125,7 +2341,9 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { BufferAssigner::Run( module.get(), xla::MakeUnique(module.get(), sequence), - ByteSizeOf, [](LogicalBuffer::Color) { return 1; }) + ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true) .ConsumeValueOrDie(); EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); @@ -2141,7 +2359,7 @@ TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { HloInstruction::CreateParameter(1, data_shape_, "weights0")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); auto output1 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index f623aef67a4f98b447a9a15634a78deb60cfe6f1..4a927b57674345f8b3493c098778182a299c5902 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -327,11 +327,12 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, param, param)); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); auto recv = builder.AddInstruction( - HloInstruction::CreateRecv(vec_, /*channel_id=*/0)); + HloInstruction::CreateRecv(vec_, token, /*channel_id=*/0)); auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); auto send = builder.AddInstruction( - HloInstruction::CreateSend(recv_done, /*channel_id=*/1)); + HloInstruction::CreateSend(recv_done, token, /*channel_id=*/1)); auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); auto module = CreateNewModule(); @@ -438,11 +439,13 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { // computation. The buffer containing {0, 1} is copied by GetTupleElement, and // the buffers containing {3} and 3 are dead. auto builder = HloComputation::Builder(TestName()); - auto inner_tuple0 = Literal::MakeTuple( - {Literal::CreateR0(0).get(), Literal::CreateR0(1).get()}); - auto inner_tuple1 = Literal::MakeTuple({Literal::CreateR0(3).get()}); + auto inner_tuple0 = + LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), + LiteralUtil::CreateR0(1).get()}); + auto inner_tuple1 = + LiteralUtil::MakeTuple({LiteralUtil::CreateR0(3).get()}); auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::MakeTuple({inner_tuple0.get(), inner_tuple1.get()}))); + LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()}))); builder.AddInstruction(HloInstruction::CreateGetTupleElement( inner_tuple0->shape(), tuple_constant, 0)); @@ -490,7 +493,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple_element0_shape, tuple_param0, 0)); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0)); @@ -502,7 +505,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple_element1_shape, tuple_param0, 1)); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}))); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( tuple_element1_shape, HloOpcode::kAdd, tuple_element1, const1)); @@ -554,7 +557,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple_element0_shape, tuple_param0, 0)); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0)); @@ -626,7 +629,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1)); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); HloInstruction* slice = nullptr; if (update_uses_tuple_element1) { // Create a slice instruction as an additional user of 'gte1'. @@ -637,7 +640,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { } // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -756,7 +759,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1)); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); if (tuple_element1_has_two_uses) { // Add 'gte0' and 'gte1' to create another user of 'gte1'. @@ -765,7 +768,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { } // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index 1ea7d538cd515c3098b6a1f03c6146d288330406..cc80b7484313329104eec1ce71a150b47d8330c9 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/call_graph.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -82,7 +82,7 @@ class CallGraphTest : public HloTestBase { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, kScalarShape, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); return builder.Build(); @@ -247,11 +247,11 @@ TEST_F(CallGraphTest, ComputationWithConditional) { HloComputation::Builder builder(TestName()); HloInstruction* pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloInstruction* const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(56.4f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(56.4f))); HloInstruction* const2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(12.6f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(12.6f))); HloInstruction* conditional = builder.AddInstruction(HloInstruction::CreateConditional( kScalarShape, pred, const1, true_computation, const2, diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc index 482ccc5b67109258f544e5657ecfa0e8f62192c0..256d05a73e0bf61d959d21795c106286b52d0b19 100644 --- a/tensorflow/compiler/xla/service/call_inliner.cc +++ b/tensorflow/compiler/xla/service/call_inliner.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { @@ -151,6 +152,14 @@ StatusOr CallInliner::Run(HloModule* module) { } return Status::OK(); })); + if (did_mutate) { + // Run DCE to remove called computations which are now becoming unused. + // This can result then in problems if within the called computation, there + // were send/recv instructions, which the module group verifier will flag as + // error findingthe same channel ID used for multiple send/recv + // instructions. + TF_RETURN_IF_ERROR(HloDCE().Run(module).status()); + } return did_mutate; } diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index 924348c870b9ca3d86af560a0c8359af7220427e..ff968bca297077c7cf869ff8d2becb8bf739dce3 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -48,9 +48,9 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { // the "one" value. HloComputation::Builder inner(TestName() + ".inner"); HloInstruction* zero = inner.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(24.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(24.0f))); HloInstruction* one = inner.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); TF_ASSERT_OK(zero->AddControlDependencyTo(one)); auto module = CreateNewModule(); HloComputation* inner_computation = @@ -87,7 +87,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { // little trickier. HloComputation::Builder just_false(TestName() + ".false"); just_false.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* false_computation = module->AddEmbeddedComputation(just_false.Build()); @@ -99,7 +99,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { HloComputation::Builder outer(TestName() + ".outer"); HloInstruction* init_value = outer.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); outer.AddInstruction( HloInstruction::CreateWhile(pred, call_false, call_false, init_value)); @@ -123,9 +123,9 @@ TEST_F(CallInlinerTest, InlineWithoutRunningPass) { HloComputation::Builder just_false(TestName() + ".false"); auto* true_constant = just_false.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({true}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({true}))); auto* false_constant = just_false.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); TF_ASSERT_OK(false_constant->AddControlDependencyTo(true_constant)); HloComputation* false_computation = module->AddEmbeddedComputation(just_false.Build()); @@ -147,8 +147,8 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { HloComputation::Builder outfeeder(TestName() + ".outfeeder"); auto value = outfeeder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - auto token = outfeeder.AddInstruction(HloInstruction::CreateAfterAll({})); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + auto token = outfeeder.AddInstruction(HloInstruction::CreateToken()); outfeeder.AddInstruction( HloInstruction::CreateOutfeed(f32, value, token, /*outfeed_config=*/"")); diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc index a5b392cbc33c12c3255f3c06e9842fc116e672e5..13008efed1494402eaff47904c2e4797334381a1 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.cc +++ b/tensorflow/compiler/xla/service/channel_tracker.cc @@ -31,16 +31,23 @@ namespace xla { ChannelTracker::ChannelTracker() : next_channel_(1) {} -ChannelHandle ChannelTracker::NewChannel() { +StatusOr ChannelTracker::NewChannel( + ChannelHandle::ChannelType type) { + if (type != ChannelHandle::DEVICE_TO_DEVICE && + type != ChannelHandle::HOST_TO_DEVICE && + type != ChannelHandle::DEVICE_TO_HOST) { + return InvalidArgument("Invalid channel type: %d", type); + } tensorflow::mutex_lock lock(channel_mutex_); // Create a new channel handle with a unique value. - const ChannelHandle new_handle = AllocateHandle(); + ChannelHandle new_handle = AllocateHandle(type); // Register a channel object associated with the handle. Channel channel; channel.has_sender = false; channel.receiver_count = 0; + channel.type = type; opaque_to_channel_[new_handle.handle()] = channel; return new_handle; @@ -56,10 +63,11 @@ Status ChannelTracker::RegisterRecv(const ChannelHandle& handle) { return RegisterRecvInternal(handle); } -ChannelHandle ChannelTracker::AllocateHandle() { +ChannelHandle ChannelTracker::AllocateHandle(ChannelHandle::ChannelType type) { int64 handle_value = next_channel_++; ChannelHandle result; result.set_handle(handle_value); + result.set_type(type); return result; } @@ -68,6 +76,13 @@ Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { return NotFound("channel handle not found: %lld", handle.handle()); } Channel& channel = opaque_to_channel_[handle.handle()]; + if (channel.type == ChannelHandle::HOST_TO_DEVICE) { + return FailedPrecondition( + "host-to-device channels cannot be used with a Send operation; " + "channel handle: %lld", + handle.handle()); + } + if (channel.has_sender) { return FailedPrecondition( "when registering send, passed a channel handle that is already used " @@ -83,6 +98,13 @@ Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) { return NotFound("channel handle not found: %lld", handle.handle()); } Channel& channel = opaque_to_channel_[handle.handle()]; + if (channel.type == ChannelHandle::DEVICE_TO_HOST) { + return FailedPrecondition( + "device-to-host channels cannot be used with a Recv operation; " + "channel handle: %lld", + handle.handle()); + } + // TODO(b/33942691): Allow more than 1 receivers for broadcast. if (channel.receiver_count >= 1) { return FailedPrecondition( diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h index fac0afd672ff3ed083aacf778dd9c4f90a2ee870..d773558c284a7d645f2766bb88c50f7da3777e5d 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.h +++ b/tensorflow/compiler/xla/service/channel_tracker.h @@ -48,11 +48,12 @@ class ChannelTracker { struct Channel { bool has_sender; int64 receiver_count; + ChannelHandle::ChannelType type; }; // Creates a new Channel object and returns the corresponding // ChannelHandle for it. - ChannelHandle NewChannel(); + StatusOr NewChannel(ChannelHandle::ChannelType type); // Informs that the given channel handle is used for a Send operation. // Returns an error status if the handle is already used by another Send. @@ -65,7 +66,8 @@ class ChannelTracker { private: // Bumps the next_channel_ number and returns the allocated number // wrapped in a ChannelHandle. - ChannelHandle AllocateHandle() EXCLUSIVE_LOCKS_REQUIRED(channel_mutex_); + ChannelHandle AllocateHandle(ChannelHandle::ChannelType type) + EXCLUSIVE_LOCKS_REQUIRED(channel_mutex_); Status RegisterSendInternal(const ChannelHandle& handle) EXCLUSIVE_LOCKS_REQUIRED(channel_mutex_); diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index 7c1bacff92b231661477b9931a3066fd91110445..187ce568cbb6c6666e978b8c8114262313c70ba5 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" @@ -29,9 +29,13 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +using tensorflow::strings::StrAppend; +using tensorflow::strings::StrCat; + namespace xla { Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const { @@ -71,6 +75,19 @@ DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) { return std::move(assignment); } +string DeviceAssignment::ToString() const { + string output = StrCat("Computations: ", computation_count(), + " Replicas: ", replica_count(), "\n"); + for (int computation = 0; computation < computation_count(); ++computation) { + StrAppend(&output, "Computation ", computation, ": "); + for (int replica = 0; replica < replica_count(); ++replica) { + StrAppend(&output, operator()(replica, computation), " "); + } + StrAppend(&output, "\n"); + } + return output; +} + StatusOr ComputationPlacer::DeviceId(int replica, int computation, int replica_count, int computation_count) { diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h index 737d00e93ecb51a9bd544bbcbe99d93374d108fb..c899ffb9dc562426ef14c0d414469c04debeec70 100644 --- a/tensorflow/compiler/xla/service/computation_placer.h +++ b/tensorflow/compiler/xla/service/computation_placer.h @@ -55,6 +55,8 @@ class DeviceAssignment : public Array2D { // due to a StatusOr of an incomplete type (DeviceAssignment). static StatusOr> Deserialize( const DeviceAssignmentProto& proto); + + string ToString() const; }; // A generic implementation of the XLA computation placer, which assigns device diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc index e9ec796121fff223474c3e81a5e973cc37f8caec..b7be3ba605a89a736b032eaab5a5085ac64fc549 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index c38719d50efaf7e1b95b5ed2cf3030f9bfdfe57f..c43a31b167d47af3c92ed35fa52594fa5da1e4af 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -55,7 +55,7 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) { true_computation_builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {}), "param")); auto one = true_computation_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); true_computation_builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, one)); @@ -73,7 +73,7 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) { HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "param")); auto forty_two = false_computation_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); false_computation_builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, forty_two)); @@ -82,11 +82,11 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) { } auto false_instrn = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto false_param = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {}), "false_param")); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); builder.AddInstruction(HloInstruction::CreateConditional( ShapeUtil::MakeShape(S32, {}), false_instrn, one, true_computation, @@ -106,7 +106,7 @@ TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) { HloComputation* computation = MakeConditional(&module()); auto* true_op = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); TF_ASSERT_OK( true_op->AddControlDependencyTo(computation->root_instruction())); @@ -119,10 +119,11 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) { ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); auto* true_computation = conditional->true_computation(); + auto* token = true_computation->AddInstruction(HloInstruction::CreateToken()); auto* send = true_computation->AddInstruction(HloInstruction::CreateSend( true_computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))), - /*channel_id=*/0)); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))), + token, /*channel_id=*/0)); true_computation->AddInstruction(HloInstruction::CreateSendDone(send)); EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); } @@ -133,8 +134,9 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) { ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); auto* true_computation = conditional->true_computation(); + auto* token = true_computation->AddInstruction(HloInstruction::CreateToken()); auto* recv = true_computation->AddInstruction(HloInstruction::CreateRecv( - ShapeUtil::MakeShape(F32, {1}), /*channel_id=*/0)); + ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0)); true_computation->AddInstruction(HloInstruction::CreateRecvDone(recv)); EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); } @@ -144,8 +146,7 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { auto* conditional = computation->root_instruction(); ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); auto* false_computation = conditional->false_computation(); - auto token = - false_computation->AddInstruction(HloInstruction::CreateAfterAll({})); + auto token = false_computation->AddInstruction(HloInstruction::CreateToken()); false_computation->AddInstruction(HloInstruction::CreateInfeed( ShapeUtil::MakeShape(F32, {1}), token, "config")); EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index b0ad433d8ddb7b5e0861150634ff91e4068d10dd..36fb9b43aa20bad788a0638b4fed6c88fc9023f0 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -76,15 +76,6 @@ SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node, policy.copy_parameters_and_constants = true; policy.copy_root_replicated_buffers = true; } - for (const CallSite& site : node.caller_callsites()) { - // The AddCopiesForConditional() already adds copies, but the copy remover - // removes them, so we re-add them by returning the policy here. But really - // the copy remover should not be removing them. - if (site.instruction()->opcode() == HloOpcode::kConditional) { - policy.copy_parameters_and_constants = true; - policy.copy_root_replicated_buffers = true; - } - } return policy; } @@ -360,26 +351,6 @@ Status StripControlDependenciesFrom(HloInstruction* instruction) { return Status::OK(); } -// Add kCopy instructions to the given module to guarantee there is no -// live-range interference. Generally interference can only occur around kWhile -// instructions which have update-in-place semantics. -Status AddCopiesToResolveInterference(HloModule* module) { - TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(module)); - - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kWhile) { - TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction)); - } else if (instruction->opcode() == HloOpcode::kConditional) { - TF_RETURN_IF_ERROR( - AddCopiesForConditional(*alias_analysis, instruction)); - } - } - } - return Status::OK(); -} - // Class for removing unnecessary copies from the module. // // kCopy instructions are added conservatively to guarantee no live range @@ -954,6 +925,36 @@ class CopyRemover { BufferValueTracker buffer_value_tracker_; }; +void MaybeDumpModule(const string& message, const HloModule& module) { + if (VLOG_IS_ON(3)) { + VLOG(3) << message; + XLA_VLOG_LINES(3, module.ToString()); + hlo_graph_dumper::MaybeDumpHloModule(module, message); + } +} + +} // namespace + +// Add kCopy instructions to the given module to guarantee there is no +// live-range interference. Generally interference can only occur around kWhile +// instructions which have update-in-place semantics. +Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) { + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module, fusion_can_share_buffer_)); + + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kWhile) { + TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction)); + } else if (instruction->opcode() == HloOpcode::kConditional) { + TF_RETURN_IF_ERROR( + AddCopiesForConditional(*alias_analysis, instruction)); + } + } + } + return Status::OK(); +} + // Add copies to address special constraints on the roots of computations not // related to live range interference: // @@ -964,9 +965,10 @@ class CopyRemover { // // (3) Constants and parameters cannot be live out of the entry computation // -Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { +Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, + HloModule* module) { TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(module)); + HloAliasAnalysis::Run(module, fusion_can_share_buffer_)); // Identify which shape indices of which instructions need to be copied. Store // these results in 'instructions_to_copy'. @@ -1074,33 +1076,20 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { return Status::OK(); } -Status VerifyNoLiveRangeInterference(HloModule* module) { +Status CopyInsertion::VerifyNoLiveRangeInterference(HloModule* module) { TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(module)); + HloAliasAnalysis::Run(module, fusion_can_share_buffer_)); DependencyHloOrdering ordering(module); TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering)); return Status::OK(); } -void MaybeDumpModule(const string& message, const HloModule& module) { - if (VLOG_IS_ON(3)) { - VLOG(3) << message; - XLA_VLOG_LINES(3, module.ToString()); - hlo_graph_dumper::MaybeDumpHloModule(module, message); - } -} - -} // namespace - -Status RemoveUnnecessaryCopies( - const HloOrdering& ordering, - const tensorflow::gtl::FlatSet& copies_to_exclude, HloModule* module, - const HloDataflowAnalysis::FusionCanShareBufferFunction& - fusion_can_share_buffer) { +Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering, + HloModule* module) { MaybeDumpModule("after adding copies to resolve interference", *module); TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(module, fusion_can_share_buffer)); + HloAliasAnalysis::Run(module, fusion_can_share_buffer_)); CopyRemover copy_remover(*alias_analysis, ordering, module); XLA_VLOG_LINES(3, copy_remover.ToString()); @@ -1108,7 +1097,6 @@ Status RemoveUnnecessaryCopies( for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kCopy && - !ContainsKey(copies_to_exclude, instruction->unique_id()) && instruction->CopyElisionAllowed()) { TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); } @@ -1152,16 +1140,13 @@ StatusOr CopyInsertion::Run(HloModule* module) { "Call graph must be flattened before copy insertion."); } - // Gather Ids of existing kCopy instructions in the module. We avoid removing - // these copies (except via DCE in TupleSimplifier) because they may have been - // added for reasons not considered by copy insertion (eg, layout assignment). - // Instruction id is used instead of HloInstruction* because the pointer - // values may be recycled. - tensorflow::gtl::FlatSet existing_copies; - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy) { - existing_copies.insert(instruction->unique_id()); + int64 num_existing_copies = 0; + if (VLOG_IS_ON(1)) { + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + ++num_existing_copies; + } } } } @@ -1181,8 +1166,7 @@ StatusOr CopyInsertion::Run(HloModule* module) { TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); DependencyHloOrdering ordering(module); - TF_RETURN_IF_ERROR( - RemoveUnnecessaryCopies(ordering, existing_copies, module)); + TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, module)); TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module)); @@ -1203,7 +1187,7 @@ StatusOr CopyInsertion::Run(HloModule* module) { } } } - VLOG(1) << "Num copies before copy-insertion: " << existing_copies.size(); + VLOG(1) << "Num copies before copy-insertion: " << num_existing_copies; VLOG(1) << "Num copies after copy-insertion: " << num_total_copies; } diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index 6d257060891122e56b763b32166fb4c11dfc444b..5ba64b78a3c9aff5f323691df2ece9b5e6bf3232 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -72,21 +71,26 @@ class CopyInsertion : public HloPassInterface { // TODO(b/62548313): Remove this when buffer assignment is module-scoped. static StatusOr AddCopiesForBufferAssignment(HloModule* module); + // Try to remove as many copies from the module as possible without + // introducing live range interference. Only copy instructions that are + // eligible for copy elision are considered for removal. + Status RemoveUnnecessaryCopies(const HloOrdering& ordering, + HloModule* module); + private: + // Verifies that no HLO values have interfering live ranged assuming the + // ordering used by copy insertion. + Status VerifyNoLiveRangeInterference(HloModule* module); + + Status AddCopiesToResolveInterference(HloModule* module); + + Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module); + // Backend specific function that decides whether a fusion can share buffer // with its operand. HloDataflowAnalysis::FusionCanShareBufferFunction fusion_can_share_buffer_; }; -// Try to remove as many copies from the module as possible without introducing -// live range interference. Copy instructions (identified by their unique id) in -// the set copies_to_exclude are not considered for removal. -Status RemoveUnnecessaryCopies( - const HloOrdering& ordering, - const tensorflow::gtl::FlatSet& copies_to_exclude, HloModule* module, - const HloDataflowAnalysis::FusionCanShareBufferFunction& - fusion_can_share_buffer = nullptr); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_COPY_INSERTION_H_ diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index e7539759cea523191c158a6d22e05440b9b01af8..cd735256b83f5f1d69a89e693de6064d460a36e5 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -108,7 +108,7 @@ TEST_F(CopyInsertionTest, SingleConstant) { // be copied before entering the tuple. auto builder = HloComputation::Builder(TestName()); HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant})); @@ -125,21 +125,27 @@ TEST_F(CopyInsertionTest, SingleConstant) { } TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { - // Verify that an kCopy instructions which exist in the pass before + // Verify that kCopy instructions which change layout and exist before // copy-insertion remain in the graph after copy-insertion. auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - HloInstruction* copy_1 = builder.AddInstruction(HloInstruction::CreateUnary( - constant->shape(), HloOpcode::kCopy, constant)); - HloInstruction* copy_2 = builder.AddInstruction(HloInstruction::CreateUnary( - constant->shape(), HloOpcode::kCopy, constant)); + HloInstruction* constant = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{0.f, 2.f}, {2.f, 4.f}}))); + auto minor_to_major = LayoutUtil::MinorToMajor(constant->shape()); + Layout reversed_layout = + LayoutUtil::MakeLayoutFromMajorToMinor(minor_to_major); + Shape copy_shape = constant->shape(); + *copy_shape.mutable_layout() = reversed_layout; + HloInstruction* copy_1 = builder.AddInstruction( + HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant)); + HloInstruction* copy_2 = builder.AddInstruction( + HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant)); HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( constant->shape(), HloOpcode::kAdd, copy_1, copy_2)); - HloInstruction* add_copy = builder.AddInstruction( - HloInstruction::CreateUnary(constant->shape(), HloOpcode::kCopy, add)); + builder.AddInstruction( + HloInstruction::CreateUnary(add->shape(), HloOpcode::kCopy, add)); module->AddEntryComputation(builder.Build()); @@ -147,12 +153,11 @@ TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { InsertCopies(module.get()); - EXPECT_EQ(CountCopies(*module), 3); + EXPECT_EQ(CountCopies(*module), 2); - EXPECT_EQ(module->entry_computation()->root_instruction(), add_copy); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - op::Copy(op::Add(op::Copy(op::Constant()), op::Copy(op::Constant())))); + EXPECT_EQ(module->entry_computation()->root_instruction(), add); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Add(op::Copy(op::Constant()), op::Copy(op::Constant()))); } TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { @@ -162,9 +167,9 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { auto builder = HloComputation::Builder(TestName()); HloInstruction* constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); HloInstruction* constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); HloInstruction* x = builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x")); @@ -192,11 +197,11 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { // the computation result. Verify that copies are added properly. auto builder = HloComputation::Builder(TestName()); HloInstruction* constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); HloInstruction* constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); HloInstruction* constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); HloInstruction* tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -204,9 +209,9 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { HloInstruction::CreateTuple({constant3, constant2})); HloInstruction* pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); builder.AddInstruction(HloInstruction::CreateTernary( - tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); EXPECT_THAT(constant1->users(), UnorderedElementsAre(tuple1)); EXPECT_THAT(constant2->users(), UnorderedElementsAre(tuple1, tuple2)); @@ -250,8 +255,9 @@ TEST_F(CopyInsertionTest, BitcastConstant) { // The output of a bitcast is its operand (same buffer), so a bitcast // constant feeding the result must have a copy added. auto builder = HloComputation::Builder(TestName()); - HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1.0, 42.0}))); + HloInstruction* constant = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1.0, 42.0}))); HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant)); @@ -365,9 +371,9 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { // copy is added. auto builder = HloComputation::Builder(TestName()); HloInstruction* constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); HloInstruction* constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); HloInstruction* tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -375,9 +381,9 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { HloInstruction::CreateTuple({constant2, constant1})); HloInstruction* pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); HloInstruction* gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(select->shape(), {0}), select, 0)); @@ -408,7 +414,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { const Shape& loop_state_shape) { auto builder = HloComputation::Builder(TestName() + ".Condition"); auto limit_const = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(10))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))); auto loop_state = builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); auto induction_variable = @@ -437,7 +443,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); // Update data GTE(1). @@ -475,7 +481,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); // add0 = Add(in0, 1) auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -544,7 +550,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); // add0 = Add(in0, 1) auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); @@ -559,8 +565,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest { data = builder.AddInstruction( HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); } - auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + auto update = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); // add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data, update)); @@ -593,7 +600,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( gte0->shape(), HloOpcode::kAdd, gte0, inc)); @@ -603,8 +610,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // GTE(GTE(loop_state, 1), 0) -> Add auto gte10 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0)); - auto update10 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + auto update10 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); auto add10 = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, gte10, update10)); @@ -628,10 +636,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest { bool nested = false) { auto builder = HloComputation::Builder(TestName() + ".While"); auto induction_var_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); - auto data_init = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); + auto data_init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); if (nested) { auto inner_init = builder.AddInstruction( @@ -654,8 +663,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest { HloInstruction* BuildWhileInstruction_InitPointsToConstant() { auto builder = HloComputation::Builder(TestName() + ".While"); - auto data_init = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); + auto data_init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init, &builder); } @@ -672,11 +682,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto builder = HloComputation::Builder(TestName() + ".While"); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto v1 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, one, {1})); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto v2 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); @@ -684,9 +694,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto data_init = builder.AddInstruction(HloInstruction::CreateTernary( - nested_tuple_shape_, HloOpcode::kSelect, pred, tuple1, tuple2)); + nested_tuple_shape_, HloOpcode::kTupleSelect, pred, tuple1, tuple2)); return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_, data_init, &builder); @@ -696,7 +706,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto builder = HloComputation::Builder(TestName() + ".While"); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto one_vec = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, one, {1})); auto data_init = @@ -709,11 +719,12 @@ class WhileCopyInsertionTest : public CopyInsertionTest { HloInstruction* BuildWhileInstruction_InitPointsToInterfering() { auto builder = HloComputation::Builder(TestName() + ".While"); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto data_init = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, one, {1})); - auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + auto one_vec = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); // Take a reference to 'data_init' to make it interfere with while result. auto add = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data_init, one_vec)); @@ -745,7 +756,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { const bool nested = ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_); auto induction_var_init = builder->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto condition = module_->AddEmbeddedComputation( BuildConditionComputation(loop_state_shape)); auto body = module_->AddEmbeddedComputation( @@ -1247,7 +1258,6 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { auto loop_init = builder.AddInstruction( HloInstruction::CreateTuple({iter_param, data_param, data_param})); - // Two while loops shares the same loop init tuple. auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( loop_state_shape, condition1, body1, loop_init)); @@ -1305,7 +1315,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape, "param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); cond_builder.AddInstruction(HloInstruction::CreateUnary( cond_constant->shape(), HloOpcode::kNot, cond_constant)); HloComputation* condition = @@ -1313,9 +1323,9 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( @@ -1370,7 +1380,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape, "param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); cond_builder.AddInstruction(HloInstruction::CreateUnary( cond_constant->shape(), HloOpcode::kNot, cond_constant)); HloComputation* condition = @@ -1378,9 +1388,9 @@ TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( @@ -1430,7 +1440,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape, "param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); cond_builder.AddInstruction(HloInstruction::CreateUnary( cond_constant->shape(), HloOpcode::kNot, cond_constant)); HloComputation* condition = @@ -1438,7 +1448,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant, constant})); builder.AddInstruction( @@ -1515,7 +1525,7 @@ TEST_F(CopyInsertionTest, SequentialWhiles) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape, "param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); cond_builder.AddInstruction(HloInstruction::CreateUnary( cond_constant->shape(), HloOpcode::kNot, cond_constant)); HloComputation* condition = @@ -1570,14 +1580,14 @@ TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) { body_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param")); body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(123.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0))); HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); cond_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module->AddEmbeddedComputation(cond_builder.Build()); @@ -1639,7 +1649,7 @@ std::unique_ptr MakeTrivialCondition(const Shape& shape) { builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "loop_state")); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kNot, constant)); return builder.Build(); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 2c3eb1ae367ffe1de93c6fc8f4efdc6d69964e10..504b61d134a0099d055d0266408e1dfb94af5b2a 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -37,6 +37,7 @@ cc_library( srcs = ["cpu_transfer_manager.cc"], hdrs = ["cpu_transfer_manager.h"], deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -72,7 +73,7 @@ cc_library( ":ir_emitter", ":parallel_task_assignment", ":simple_orc_jit", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -89,7 +90,6 @@ cc_library( "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", - "//tensorflow/compiler/xla/service:gather_expander", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", @@ -129,7 +129,7 @@ cc_library( "@llvm//:x86_code_gen", # fixdeps: keep "@llvm//:x86_disassembler", # fixdeps: keep ] + select({ - "@org_tensorflow//tensorflow:linux_ppc64le": [ + "//tensorflow:linux_ppc64le": [ "@llvm//:powerpc_disassembler", "@llvm//:powerpc_code_gen", ], @@ -252,12 +252,13 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", + "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", + "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", - "//tensorflow/compiler/xla/service/llvm_ir:ops", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", "@llvm//:code_gen", @@ -355,7 +356,7 @@ tf_cc_binary( srcs = ["sample_harness.cc"], deps = [ "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -363,8 +364,8 @@ tf_cc_binary( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", ], ) @@ -444,6 +445,7 @@ cc_library( deps = [ ":vector_support_library", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/service/llvm_ir:math_ops", "//tensorflow/core:lib", "@llvm//:core", "@llvm//:transform_utils", @@ -717,7 +719,7 @@ tf_cc_test( deps = [ ":cpu_layout_assignment", ":target_machine_features_fake", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -809,7 +811,7 @@ tf_cc_test( ":cpu_executable", ":parallel_task_assignment", ":target_machine_features_fake", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -892,7 +894,7 @@ tf_cc_test( srcs = ["cpu_copy_insertion_test.cc"], deps = [ ":cpu_copy_insertion", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 6a7eb85e3baec3517b8f3ddef6a8dcfae9c9e614..128eea4828b5e514b2ba6b398898e4a5d228e746 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -156,9 +156,26 @@ std::unique_ptr CompilerFunctor::operator()( target_machine_->addPassesToEmitMC(codegen_passes, mc_context, ostream); codegen_passes.run(module); - // Construct ObjectFile from machine code buffer. - return std::unique_ptr( + std::unique_ptr memory_buffer( new llvm::SmallVectorMemoryBuffer(std::move(stream_buffer))); + + if (VLOG_IS_ON(2)) { + llvm::Expected> obj_file = + llvm::object::ObjectFile::createObjectFile(*memory_buffer); + if (obj_file) { + StatusOr disasm_result = + disassembler_->DisassembleObjectFile(*obj_file.get()); + if (disasm_result.ok()) { + XLA_VLOG_LINES(2, disasm_result.ValueOrDie().text); + } else { + LOG(WARNING) << "Could not disassemble object file!"; + } + } else { + LOG(WARNING) << "Could convert memory buffer to object file!"; + } + } + + return memory_buffer; } static std::vector VectorFunctionsForTargetLibraryInfoImpl() { diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 375b017b09263c20c1b1ef8329f7e2f6a573dda4..547d4c696da5cfdde3dece03250ae5fa51c92f25 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -60,11 +60,11 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { auto builder = HloComputation::Builder(TestName()); // The input dimensions are in CNHW order. auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR4FromArray4D(Array4D( + LiteralUtil::CreateR4FromArray4D(Array4D( kInputFeatureCount, kBatchSize, kInputSize, kInputSize)))); // The kernel dimensions are in OIHW order. auto kernel = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR4FromArray4D(Array4D( + LiteralUtil::CreateR4FromArray4D(Array4D( kOutputFeatureCount, kInputFeatureCount, kWindowSize, kWindowSize)))); ConvolutionDimensionNumbers dnums; @@ -122,11 +122,11 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { auto builder = HloComputation::Builder(TestName()); // The input dimensions are in NHWC order. auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR4FromArray4D(Array4D( + LiteralUtil::CreateR4FromArray4D(Array4D( kBatchSize, kInputSize, kInputSize, kInputFeatureCount)))); // The kernel dimensions are in HWIO order. auto kernel = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR4FromArray4D(Array4D( + LiteralUtil::CreateR4FromArray4D(Array4D( kWindowSize, kWindowSize, kInputFeatureCount, kOutputFeatureCount)))); ConvolutionDimensionNumbers dnums; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 55962ba70d213939ccb49cad3bdd75395cc4eaa5..8cbe9a1b0d5b0553b1121d544196412f36f8ce43 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -30,6 +30,7 @@ limitations under the License. #include "llvm/ADT/Triple.h" #include "llvm/IR/Function.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Mangler.h" #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" #include "llvm/Object/ObjectFile.h" @@ -38,7 +39,7 @@ limitations under the License. #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -66,7 +67,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" -#include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" @@ -297,8 +297,6 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pipeline.AddPass(/*is_layout_sensitive=*/false); pipeline.AddPass(); - pipeline.AddPass(); - ReducePrecisionInsertion::AddPasses( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); @@ -564,7 +562,9 @@ StatusOr> CpuCompiler::RunBackend( BufferAssigner::Run( module.get(), xla::MakeUnique(module.get(), module_sequence), - BufferSizeBytesFunction(), memory_alignment)); + BufferSizeBytesFunction(), memory_alignment, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -586,6 +586,8 @@ StatusOr> CpuCompiler::RunBackend( std::move(computation_to_profile_idx), &target_machine_features); + TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); + for (auto embedded_computation : entry_computation->MakeEmbeddedComputationsList()) { if (embedded_computation->IsFusionComputation()) { @@ -607,7 +609,13 @@ StatusOr> CpuCompiler::RunBackend( /*is_top_level_computation=*/true, &module_sequence.at(entry_computation))); - string function_name = llvm_ir::AsString(entry_function->getName()); + string function_name = [&]() { + llvm::SmallVector function_name_vector; + llvm::Mangler::getNameWithPrefix( + function_name_vector, entry_function->getName(), jit->data_layout()); + return string(function_name_vector.begin(), function_name_vector.end()); + }(); + string ir_module_string; if (embed_ir_in_executable) { ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); @@ -743,7 +751,9 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, BufferAssigner::Run( module, xla::MakeUnique(module, module_sequence), - BufferSizeBytesFunction(), memory_alignment)); + BufferSizeBytesFunction(), memory_alignment, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -772,6 +782,9 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), &target_machine_features); + + TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); + HloComputation* computation = module->entry_computation(); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { @@ -827,17 +840,29 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, BufferSizes buffer_sizes; for (const BufferAllocation& allocation : assignment->Allocations()) { - // Callers don't need to allocate temporary buffers for parameters. - if (allocation.is_entry_computation_parameter()) { - buffer_sizes.push_back(-1); - continue; - } // Callers don't need to allocate anything for thread-local temporary // buffers. They are lowered to allocas. if (allocation.is_thread_local()) { buffer_sizes.push_back(-1); continue; } + + // Callers don't need to allocate anything for constant buffers. They are + // lowered to globals. + if (allocation.is_constant()) { + buffer_sizes.push_back(-1); + continue; + } + + // Callers don't need to allocate anything for entry computation buffers, + // but they do need to stash the pointer to the entry computation buffer + // in the temp buffer table. See the comment on + // XlaCompiledCpuFunction::StaticData::temp_sizes. + if (allocation.is_entry_computation_parameter()) { + buffer_sizes.push_back(-allocation.parameter_number() - 2); + continue; + } + buffer_sizes.push_back(allocation.size()); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc index a05a26941786cbf404c4685abb098c9ac8caaa09..4db7fa446ea9188940f930bcadf753bd3e6b79e3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -74,14 +74,14 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) { body_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param")); body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(123.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0))); HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); cond_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module->AddEmbeddedComputation(cond_builder.Build()); @@ -114,7 +114,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) { auto sub_param = sub_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param")); auto constant = sub_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(123.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0))); auto add = sub_builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, sub_param, constant)); sub_builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 1093559892ddb9c238fd9c1f7e3d419ec7022776..946f5124b87bc011df4f3553077dbb37a3333ed2 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -69,12 +69,19 @@ CpuExecutable::CpuExecutable( // guarded by the mutex. compute_function_ = reinterpret_cast(cantFail(sym.getAddress())); + VLOG(1) << "compute_function_ at address " + << reinterpret_cast(compute_function_); } -Status CpuExecutable::AllocateBuffers( +StatusOr, + std::vector>> +CpuExecutable::CreateTempArray( DeviceMemoryAllocator* memory_allocator, int device_ordinal, - std::vector* buffers) { - CHECK_EQ(buffers->size(), assignment_->Allocations().size()); + tensorflow::gtl::ArraySlice arguments) { + std::vector unowning_buffers( + assignment_->Allocations().size()); + std::vector owning_buffers( + assignment_->Allocations().size()); VLOG(3) << "Allocating " << assignment_->Allocations().size() << " allocations for module " << module().name(); for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); @@ -84,44 +91,51 @@ Status CpuExecutable::AllocateBuffers( VLOG(3) << allocation.ToString(); if (allocation.is_entry_computation_parameter()) { + unowning_buffers[i] = arguments[allocation.parameter_number()]->buffer( + allocation.param_shape_index()); VLOG(3) << "allocation #" << i << " is a parameter"; continue; } + if (allocation.is_constant()) { + VLOG(3) << "allocation #" << i << " is a constant"; + continue; + } + if (allocation.is_thread_local()) { VLOG(3) << "buffer #" << i << " is thread-local"; continue; } int64 buffer_size = allocation.size(); - if (!(*buffers)[i].is_null()) { + if (!owning_buffers[i].is_null()) { VLOG(3) << "buffer #" << i << " is in the preallocated result ShapedBuffer"; } else { - TF_ASSIGN_OR_RETURN((*buffers)[i], memory_allocator->Allocate( - device_ordinal, buffer_size)); + TF_ASSIGN_OR_RETURN(owning_buffers[i], memory_allocator->Allocate( + device_ordinal, buffer_size)); + unowning_buffers[i] = owning_buffers[i].AsDeviceMemoryBase(); VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes [" - << (*buffers)[i].opaque() << "]"; + << owning_buffers[i].opaque() << "]"; } // Since the output buffer and all the temporary buffers were written into // by the JITed code, msan has no way of knowing their memory was // initialized. Mark them initialized so that msan doesn't flag loads from // these buffers. - TF_ANNOTATE_MEMORY_IS_INITIALIZED((*buffers)[i].opaque(), buffer_size); + TF_ANNOTATE_MEMORY_IS_INITIALIZED(owning_buffers[i].opaque(), buffer_size); } TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, assignment_->GetUniqueTopLevelOutputSlice()); VLOG(3) << "result index: " << result_slice.index(); - return Status::OK(); + return {{std::move(unowning_buffers), std::move(owning_buffers)}}; } Status CpuExecutable::ExecuteComputeFunction( const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile) { // The calling convention for JITed functions is: @@ -131,17 +145,11 @@ Status CpuExecutable::ExecuteComputeFunction( // // result: Points at the result. // run_options: the ExecutableRunOptions object. - // args_array: An array of pointers, each of which points to a parameter. - // The size of this array is determined by the function's arity - // (ProgramShape). - // temps_array: An array of pointers, each of which points to a temporary - // buffer the computation needs. The size of this array is - // determined by buffer analysis. + // args_array: null + // temps_array: An array of pointers, containing pointers to temporary buffers + // required by the executable adn pointers to entry computation + // parameters. // - std::vector args_array; - for (const ShapedBuffer* argument : arguments) { - args_array.push_back(argument->root_buffer().opaque()); - } uint64 start_micros = tensorflow::Env::Default()->NowMicros(); @@ -164,16 +172,14 @@ Status CpuExecutable::ExecuteComputeFunction( if (VLOG_IS_ON(3)) { VLOG(3) << "Executing compute function:"; VLOG(3) << tensorflow::strings::Printf( - " func(void* result, void* params[%zu], void* temps[%zu], " + " func(void* result, void* params[null], void* temps[%zu], " "uint64 profile_counters[%zu])", - args_array.size(), buffer_pointers.size(), profile_counters_size); + buffer_pointers.size(), profile_counters_size); VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer); auto ptr_printer = [](string* out, const void* p) { tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p)); }; - VLOG(3) << tensorflow::strings::Printf( - " params = [%s]", - tensorflow::str_util::Join(args_array, ", ", ptr_printer).c_str()); + VLOG(3) << " params = nullptr"; VLOG(3) << tensorflow::strings::Printf( " temps = [%s]", tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str()); @@ -181,8 +187,8 @@ Status CpuExecutable::ExecuteComputeFunction( profile_counters); } - compute_function_(result_buffer, run_options, args_array.data(), - buffer_pointers.data(), profile_counters); + compute_function_(result_buffer, run_options, nullptr, buffer_pointers.data(), + profile_counters); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -249,21 +255,18 @@ StatusOr CpuExecutable::ExecuteOnStream( se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); - - TF_RETURN_IF_ERROR(AllocateBuffers( - memory_allocator, stream->parent()->device_ordinal(), &buffers)); + std::vector owning_buffers; std::vector unowning_buffers; - unowning_buffers.reserve(buffers.size()); - for (auto& buffer : buffers) { - unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); - } - TF_RETURN_IF_ERROR(ExecuteComputeFunction(&run_options->run_options(), - arguments, unowning_buffers, - hlo_execution_profile)); + TF_ASSIGN_OR_RETURN( + std::tie(unowning_buffers, owning_buffers), + CreateTempArray(memory_allocator, stream->parent()->device_ordinal(), + arguments)); + + TF_RETURN_IF_ERROR(ExecuteComputeFunction( + &run_options->run_options(), unowning_buffers, hlo_execution_profile)); - return CreateResultShapedBuffer(run_options, &buffers); + return CreateResultShapedBuffer(run_options, &owning_buffers); } StatusOr CpuExecutable::ExecuteAsyncOnStream( @@ -279,17 +282,15 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( run_options->stream()->implementation()); se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); - TF_RETURN_IF_ERROR(AllocateBuffers( - memory_allocator, stream->parent()->device_ordinal(), &buffers)); - + std::vector owning_buffers; std::vector unowning_buffers; - unowning_buffers.reserve(buffers.size()); - for (auto& buffer : buffers) { - unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); - } + TF_ASSIGN_OR_RETURN( + std::tie(unowning_buffers, owning_buffers), + CreateTempArray(memory_allocator, stream->parent()->device_ordinal(), + arguments)); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, - CreateResultShapedBuffer(run_options, &buffers)); + CreateResultShapedBuffer(run_options, &owning_buffers)); // At this point, `unowning_buffers` contains unowning pointers to all of our // buffers, and `buffers` contains owning pointers to the non-live-out @@ -307,7 +308,6 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( struct AsyncRunTask { CpuExecutable* executable; ServiceExecutableRunOptions run_options; - std::vector arguments; std::vector unowning_buffers; std::shared_ptr> buffers; @@ -315,15 +315,14 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( // Failing a CHECK here is not great, but I don't see an obvious way to // return a failed Status asynchronously. TF_CHECK_OK(executable->ExecuteComputeFunction( - &run_options.run_options(), arguments, unowning_buffers, + &run_options.run_options(), unowning_buffers, /*hlo_execution_profile=*/nullptr)); } }; - host_stream->EnqueueTask(AsyncRunTask{ - this, *run_options, - std::vector(arguments.begin(), arguments.end()), - unowning_buffers, - std::make_shared>(std::move(buffers))}); + host_stream->EnqueueTask( + AsyncRunTask{this, *run_options, std::move(unowning_buffers), + std::make_shared>( + std::move(owning_buffers))}); return std::move(result); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 8dd47bfb865e8a0552542f510d3365cff0d111e0..8af8a5dfec2834678418f069619ba88b01633361 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -85,20 +85,29 @@ class CpuExecutable : public Executable { const BufferAssignment& buffer_assignment() const { return *assignment_; } private: - // Allocate buffers required for execution and assign them to the elements of - // "buffers". "buffers" should be sized to the number of buffers in buffer - // assignment. Each vector element corresponds to a particular Index. If - // a vector element already contains a non-null DeviceMemoryBase, then no - // buffer is assigned for this element. - Status AllocateBuffers(DeviceMemoryAllocator* memory_allocator, - int device_ordinal, - std::vector* buffers); + // Creates an array suitable for passing as the "temps" argument to the JIT + // compiled function pointer. + // + // Returns (unowning_buffers, owning_buffers) where: + // + // - unowning_buffers.data() can be passed as the temps argument as-is and + // includes pointers to the scratch storage required by the computation, + // the live-out buffer into which the result will be written and entry + // computation parameters. + // + // - owning_buffers contains owning pointers to the buffers that were + // allocated by this routine. This routine allocates buffers for temporary + // storage and the live-out buffer into which the computation writes it + // result. + StatusOr, + std::vector>> + CreateTempArray(DeviceMemoryAllocator* memory_allocator, int device_ordinal, + tensorflow::gtl::ArraySlice arguments); // Calls the generated function performing the computation with the given // arguments using the supplied buffers. Status ExecuteComputeFunction( const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 750310c633286aa8f964c9ae5dcf847f2dc0557c..991b14f17dbc8cd061af98e032824d3f7075e78b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -282,7 +282,7 @@ class OpcodeFusionTest : public InstructionFusionTest { builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "arg0")); HloInstruction* one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, arg0, one)); return module->AddEmbeddedComputation(builder.Build()); @@ -595,7 +595,7 @@ TEST_F(OpcodeFusionTest, MessOfFusileNodes) { auto pad = builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(S32, {5}), idx_choice, builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), padding_config)); auto slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 429fc7b78608da0e9cd794ac294851b326f5be24..3681d12d8da818d06d2f690024008c9ccb896286 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 54c52bc08f9c53b8c6898689b18c4cb7f4bdcfd0..639064040f521a9e84bd87c5d05f674204e4d6e2 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -92,9 +92,10 @@ tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) { } // namespace -void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length, - const void* shape, - xla::int32 shape_length) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void* +__xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length, + const void* shape, + xla::int32 shape_length) { if (VLOG_IS_ON(2)) { LOG(INFO) << "AcquireInfeedBufferForDequeue: " << ShapeString(shape, shape_length); @@ -111,9 +112,11 @@ void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length, return buffer->data(); } -void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( - xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr, - xla::int32 shape_length) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length, + void* buffer_ptr, + const void* shape_ptr, + xla::int32 shape_length) { if (VLOG_IS_ON(2)) { LOG(INFO) << "ReleaseInfeedBufferAfterDeque: " << ShapeString(shape_ptr, shape_length); @@ -125,8 +128,10 @@ void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( std::move(shape)); } -void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( - xla::int32 buffer_length, const void* shape_ptr, xla::int32 shape_length) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void* +__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length, + const void* shape_ptr, + xla::int32 shape_length) { if (VLOG_IS_ON(2)) { LOG(INFO) << "AcquireOutfeedBufferForPopulation: " << ShapeString(shape_ptr, shape_length); @@ -143,9 +148,11 @@ void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( return buffer->data(); } -void __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( - xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr, - xla::int32 shape_length) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(xla::int32 buffer_length, + void* buffer_ptr, + const void* shape_ptr, + xla::int32 shape_length) { if (VLOG_IS_ON(2)) { LOG(INFO) << "ReleaseOutfeedBufferAfterPopulation: " << ShapeString(shape_ptr, shape_length); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index b877b295814a7e13569a1837ed3e1787f2fc3f56..59bc7e0e16fcc66a010408259a1ccfb2b6bb35fd 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -172,7 +173,7 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, Status CpuTransferManager::TransferLiteralFromOutfeed( se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) { + MutableBorrowingLiteral literal) { if (!ShapeUtil::IsTuple(literal_shape)) { int64 size = GetByteSizeRequirement(literal_shape); // Note: OSS build didn't like implicit conversion from @@ -180,18 +181,16 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( tensorflow::gtl::ArraySlice dimensions( tensorflow::bit_cast(literal_shape.dimensions().data()), literal_shape.dimensions().size()); - *literal = std::move(*Literal::CreateFromDimensions( - literal_shape.element_type(), dimensions)); - TF_ASSIGN_OR_RETURN(Shape received_shape, - TransferArrayBufferFromOutfeed( - executor, literal->untyped_data(), size)); - TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal->shape())) + TF_ASSIGN_OR_RETURN( + Shape received_shape, + TransferArrayBufferFromOutfeed(executor, literal.untyped_data(), size)); + TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal.shape())) << "Shape received from outfeed " << ShapeUtil::HumanString(received_shape) << " did not match the shape that was requested for outfeed: " << ShapeUtil::HumanString(literal_shape); TF_RET_CHECK(size == GetByteSizeRequirement(received_shape)); - *literal->mutable_shape_do_not_use() = received_shape; + *literal.mutable_shape_do_not_use() = received_shape; return Status::OK(); } @@ -200,22 +199,12 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( "Nested tuple outfeeds are not yet implemented on CPU."); } - std::vector> elements; std::vector> buffer_data; for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) { const Shape& tuple_element_shape = ShapeUtil::GetTupleElementShape(literal_shape, i); - // Note: OSS build didn't like implicit conversion from - // literal_shape.dimensions() to the array slice on 2017-07-10. - tensorflow::gtl::ArraySlice dimensions( - tensorflow::bit_cast( - tuple_element_shape.dimensions().data()), - tuple_element_shape.dimensions().size()); - auto empty = Literal::CreateFromDimensions( - tuple_element_shape.element_type(), dimensions); int64 size = GetByteSizeRequirement(tuple_element_shape); - buffer_data.push_back({empty->untyped_data(), size}); - elements.push_back(std::move(empty)); + buffer_data.push_back({literal.untyped_data({i}), size}); } TF_ASSIGN_OR_RETURN(Shape received_shape, @@ -229,11 +218,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( TF_RET_CHECK(GetByteSizeRequirement(literal_shape) == GetByteSizeRequirement(received_shape)); - for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) { - *elements[i]->mutable_shape_do_not_use() = received_shape.tuple_shapes(i); - } - *literal = std::move(*Literal::MakeTupleOwned(std::move(elements))); - TF_RET_CHECK(ShapeUtil::Equal(literal->shape(), literal_shape)); + TF_RET_CHECK(ShapeUtil::Equal(literal.shape(), literal_shape)); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h index 6dfc666f09dfa6df740cd54bea0957e3144181bc..80ef953d532798281c10b7a212b9c4d84a790c27 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h" #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -39,13 +40,14 @@ class CpuTransferManager : public GenericTransferManager { Status TransferLiteralToInfeed(se::StreamExecutor* executor, const LiteralSlice& literal) override; - Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, - const void* source) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) override; + MutableBorrowingLiteral literal) override; private: + Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, + const void* source); + // Transfers infeed data to device. InfeedBuffer->Done() must be // called to clean up the memory allocated for InfeedBuffer. StatusOr TransferBufferToInfeedInternal( diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 58228180ca55ede50c8579bbd73cfdfffc07e208..645888de783e4025cffd6fa4835e60b84bbd7d99 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -49,15 +49,15 @@ class MemoryTile { // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at // `major_dim_offset` in the major dimension. The tile size along the minor // dimension is the vector size, and that is implicitly determined by `vsl`. - MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder, + MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* b, llvm::Value* matrix, int64 matrix_size_along_minor_dim, llvm::Value* major_dim_offset, int64 tile_size_along_major_dim) - : vsl_(vsl), ir_builder_(ir_builder) { + : vsl_(vsl), b_(b) { pointers_.reserve(tile_size_along_major_dim); for (int64 i = 0; i < tile_size_along_major_dim; i++) { - llvm::Value* total_offset = ir_builder->CreateMul( - ir_builder->getInt64(matrix_size_along_minor_dim), - ir_builder->CreateAdd(ir_builder->getInt64(i), major_dim_offset)); + llvm::Value* total_offset = + b->CreateMul(b->getInt64(matrix_size_along_minor_dim), + b->CreateAdd(b->getInt64(i), major_dim_offset)); pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset)); } } @@ -101,8 +101,7 @@ class MemoryTile { for (int64 i = 0; i < pointers_.size(); i++) { for (int64 j = 0; j < tile_size_along_middle_dim; j++) { result[i].push_back(vsl_->LoadBroadcast( - pointers_[i], ir_builder_->CreateAdd(minor_dim_offset, - ir_builder_->getInt64(j)))); + pointers_[i], b_->CreateAdd(minor_dim_offset, b_->getInt64(j)))); } } return result; @@ -110,7 +109,7 @@ class MemoryTile { private: VectorSupportLibrary* vsl_; - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; std::vector pointers_; }; @@ -249,16 +248,15 @@ class ColumnMajorMatrixVectorProductEmitter ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, llvm::Value* result, - llvm::IRBuilder<>* ir_builder) + llvm::IRBuilder<>* b) : config_(config), lhs_(lhs), rhs_(rhs), addend_(addend), result_(result), - ir_builder_(ir_builder), - ksl_(ir_builder_), - vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), - ir_builder_, "") { + b_(b), + ksl_(b_), + vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), b_, "") { CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast(tile_rows()))); CHECK(!has_addend() || addend != nullptr); } @@ -272,7 +270,7 @@ class ColumnMajorMatrixVectorProductEmitter bool is_first_column); MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64 column_count) { - return MemoryTile(&vsl_, ir_builder_, /*matrix=*/lhs_, + return MemoryTile(&vsl_, b_, /*matrix=*/lhs_, /*matrix_size_along_minor_dim=*/m(), /*major_dim_offset=*/column_start, /*tile_size_along_major_dim=*/column_count); @@ -302,7 +300,7 @@ class ColumnMajorMatrixVectorProductEmitter llvm::Value* rhs_; llvm::Value* addend_; llvm::Value* result_; - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; KernelSupportLibrary ksl_; VectorSupportLibrary vsl_; }; @@ -331,7 +329,7 @@ void ColumnMajorMatrixVectorProductEmitter::Emit() { }); if (column_remainder != 0) { - EmitOuterLoopBody(ir_builder_->getInt64(column_limit), column_remainder, + EmitOuterLoopBody(b_->getInt64(column_limit), column_remainder, column_limit == 0); } } @@ -364,7 +362,7 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( return; } - llvm::Value* columns_llvm = ir_builder_->getInt64(columns); + llvm::Value* columns_llvm = b_->getInt64(columns); // for (col = current_tile_col; col < (columns + current_tile_col); col++) // for (row = row_start, row < m_; row++) { @@ -375,12 +373,11 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( ksl_.ForReturnVoid( "dot.inner.epilg.outer", /*start=*/current_tile_col, - /*end=*/ir_builder_->CreateAdd(columns_llvm, current_tile_col), + /*end=*/b_->CreateAdd(columns_llvm, current_tile_col), /*step=*/1, /*peel_first_iteration=*/false, [&](llvm::Value* col, llvm::Value* is_first_scalar_col) { llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col); - llvm::Value* total_offset = - ir_builder_->CreateMul(col, ir_builder_->getInt64(m())); + llvm::Value* total_offset = b_->CreateMul(col, b_->getInt64(m())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); ksl_.ForReturnVoid( @@ -388,9 +385,8 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( /*step=*/1, [&](llvm::Value* scalar_row) { llvm::Value* product = vsl_.Mul( vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); - llvm::Value* setting_result_first_time = ir_builder_->CreateAnd( - is_first_scalar_col, - ir_builder_->getInt1(is_first_tiled_column)); + llvm::Value* setting_result_first_time = b_->CreateAnd( + is_first_scalar_col, b_->getInt1(is_first_tiled_column)); ksl_.IfReturnVoid( setting_result_first_time, /*true_block_generator=*/ @@ -478,16 +474,15 @@ class RowMajorMatrixVectorProductEmitter RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, - llvm::Value* result, - llvm::IRBuilder<>* ir_builder) + llvm::Value* result, llvm::IRBuilder<>* b) : config_(config), lhs_(lhs), rhs_(rhs), addend_(addend), result_(result), - ir_builder_(ir_builder), - ksl_(ir_builder_), - vsl_(scalar_type(), /*vector_size=*/tile_cols(), ir_builder_, "") { + b_(b), + ksl_(b_), + vsl_(scalar_type(), /*vector_size=*/tile_cols(), b_, "") { CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast(tile_cols()))); CHECK(!has_addend() || addend != nullptr); } @@ -498,7 +493,7 @@ class RowMajorMatrixVectorProductEmitter private: MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64 row_count) { - return MemoryTile(&vsl_, ir_builder_, /*matrix=*/lhs_, + return MemoryTile(&vsl_, b_, /*matrix=*/lhs_, /*matrix_size_along_minor_dim=*/k(), /*major_dim_offset=*/row_start, /*tile_size_along_major_dim=*/row_count); @@ -517,7 +512,7 @@ class RowMajorMatrixVectorProductEmitter llvm::Value* rhs_; llvm::Value* addend_; llvm::Value* result_; - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; KernelSupportLibrary ksl_; VectorSupportLibrary vsl_; }; @@ -559,7 +554,7 @@ void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, for (int i = 0; i < row_count; i++) { llvm::Value* result_value = vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get()); - llvm::Value* offset = ir_builder_->CreateAdd(ir_builder_->getInt64(i), row); + llvm::Value* offset = b_->CreateAdd(b_->getInt64(i), row); if (addend_ && row_count != vsl_.vector_size()) { result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value); } @@ -578,7 +573,7 @@ void RowMajorMatrixVectorProductEmitter::Emit() { [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); if (row_remainder != 0) { - EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder); + EmitOuterLoopBody(b_->getInt64(row_limit), row_remainder); } } @@ -609,9 +604,8 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( } for (int r = 0; r < rows; r++) { - llvm::Value* total_offset = ir_builder_->CreateMul( - ir_builder_->CreateAdd(ir_builder_->getInt64(r), current_tile_row), - ir_builder_->getInt64(k())); + llvm::Value* total_offset = b_->CreateMul( + b_->CreateAdd(b_->getInt64(r), current_tile_row), b_->getInt64(k())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); ksl_.ForReturnVoid( @@ -722,13 +716,13 @@ class MatrixMatrixBlockPanelEmitter { // `lhs` with `rhs` and stores the result in `result`. explicit MatrixMatrixBlockPanelEmitter(Config config, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, - llvm::IRBuilder<>* ir_builder) + llvm::IRBuilder<>* b) : lhs_(lhs), rhs_(rhs), result_(result), config_(config), - ir_builder_(ir_builder), - ksl_(ir_builder_) { + b_(b), + ksl_(b_) { CHECK(max_vectorization_width() > 0 && IsPowerOfTwo(static_cast(max_vectorization_width()))); CHECK_GT(max_vector_count(), 0); @@ -761,7 +755,7 @@ class MatrixMatrixBlockPanelEmitter { int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end); - llvm::Value* GetInt64(int64 value) { return ir_builder_->getInt64(value); } + llvm::Value* GetInt64(int64 value) { return b_->getInt64(value); } Config config() const { return config_; } Dimensions dims() const { return config().dims(); } @@ -782,7 +776,7 @@ class MatrixMatrixBlockPanelEmitter { llvm::Value* result_; Config config_; - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; KernelSupportLibrary ksl_; }; @@ -804,8 +798,8 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { current_vectorization_width >= min_vectorization_width()) { int64 n_end = dims().n() - (dims().n() % current_vectorization_width); if (n_start != n_end) { - VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, - ir_builder_, "gebp"); + VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_, + "gebp"); HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end)); n_start = n_end; } @@ -819,10 +813,9 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { } if (n_start != dims().n()) { - VectorSupportLibrary vsl(scalar_type(), 1, ir_builder_, "gebp"); + VectorSupportLibrary vsl(scalar_type(), 1, b_, "gebp"); ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { - llvm::Value* n_i_next = - ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1)); + llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1)); HandleResiduesOnK(&vsl, n_i, n_i_next); }); } @@ -935,11 +928,11 @@ void MatrixMatrixBlockPanelEmitter::EmitTiledGemm( ksl_.ForReturnVoid( "dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { MemoryTile result_memory_tile( - vsl, ir_builder_, /*matrix=*/result_, + vsl, b_, /*matrix=*/result_, /*matrix_size_along_minor_dim=*/dims().n(), /*major_dim_offset=*/m_i, /*tile_size_along_major_dim=*/tile_size_m); - MemoryTile lhs_memory_tile(vsl, ir_builder_, /*matrix=*/lhs_, + MemoryTile lhs_memory_tile(vsl, b_, /*matrix=*/lhs_, /*matrix_size_along_minor_dim=*/dims().k(), /*major_dim_offset=*/m_i, /*tile_size_along_major_dim=*/tile_size_m); @@ -949,8 +942,8 @@ void MatrixMatrixBlockPanelEmitter::EmitTiledGemm( result_memory_tile.LoadTile(n_i)); ksl_.ForReturnVoid( "dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { - MemoryTile rhs_memory_tile(vsl, ir_builder_, rhs_, - dims().n(), k_i, tile_size_k); + MemoryTile rhs_memory_tile(vsl, b_, rhs_, dims().n(), k_i, + tile_size_k); std::vector> lhs_tile = lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k); std::vector rhs_tile = @@ -980,7 +973,7 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* ir_builder, + llvm::IRBuilder<>* b, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) : dot_(dot), @@ -989,7 +982,7 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, rhs_array_(rhs_array), addend_array_(addend_array), executable_run_options_value_(executable_run_options_value), - ir_builder_(ir_builder), + b_(b), hlo_module_config_(hlo_module_config), target_machine_features_(target_machine_features) {} @@ -997,15 +990,14 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, const HloInstruction& dot, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { PrimitiveType type = target_array.GetShape().element_type(); TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type); DotOpEmitter dot_emitter(dot, target_array, lhs_array, rhs_array, - addend_array, executable_run_options_value, - ir_builder, hlo_module_config, - target_machine_features); + addend_array, executable_run_options_value, b, + hlo_module_config, target_machine_features); return dot_emitter.Emit(); } @@ -1050,13 +1042,13 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( } int64 size_bytes = m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); - ir_builder_->CreateMemSet( - target, ir_builder_->getInt8(0), size_bytes, + b_->CreateMemSet( + target, b_->getInt8(0), size_bytes, target_machine_features_.minimum_alignment_for_allocation(size_bytes)); int64 max_target_vector_width = target_machine_features_.vector_register_num_elements( - *ir_builder_->GetInsertBlock()->getParent(), primitive_type); + *b_->GetInsertBlock()->getParent(), primitive_type); int64 tile_size_m, tile_size_k, tile_size_n_in_vector_width; std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) = @@ -1080,12 +1072,12 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, ir_builder_, - config.GetCacheKey(), lhs, rhs, target, + /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), lhs, + rhs, target, [this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) { - MatrixMatrixBlockPanelEmitter gebp_emitter( - config, /*lhs=*/lhs, /*rhs=*/rhs, - /*result=*/target, ir_builder_); + MatrixMatrixBlockPanelEmitter gebp_emitter(config, /*lhs=*/lhs, + /*rhs=*/rhs, + /*result=*/target, b_); gebp_emitter.Emit(); }); @@ -1163,7 +1155,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { const int target_vector_register_element_size = target_machine_features_.vector_register_num_elements( - *ir_builder_->GetInsertBlock()->getParent(), primitive_type); + *b_->GetInsertBlock()->getParent(), primitive_type); // We may not always know the vector register size for the target we're // compiling against, in which case target_vector_register_element_size is 0. @@ -1184,13 +1176,13 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, ir_builder_, - config.GetCacheKey(), lhs_op, rhs_op, + /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), + lhs_op, rhs_op, addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, llvm::Value* result_op) { ColumnMajorMatrixVectorProductEmitter emitter( - config, lhs_op, rhs_op, addend_op, result_op, ir_builder_); + config, lhs_op, rhs_op, addend_op, result_op, b_); emitter.Emit(); }); } else { @@ -1203,13 +1195,13 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, ir_builder_, - config.GetCacheKey(), lhs_op, rhs_op, + /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), + lhs_op, rhs_op, addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, llvm::Value* result_op) { - RowMajorMatrixVectorProductEmitter emitter( - config, lhs_op, rhs_op, addend_op, result_op, ir_builder_); + RowMajorMatrixVectorProductEmitter emitter(config, lhs_op, rhs_op, + addend_op, result_op, b_); emitter.Emit(); }); } @@ -1285,11 +1277,11 @@ Status DotOpEmitter::Emit() { // Create loop nests which loop through the LHS operand dimensions and the RHS // operand dimensions. The reduction dimension of the LHS and RHS are handled // in a separate innermost loop which performs the sum of products. - llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(&dot_), ir_builder_); - llvm_ir::IrArray::Index lhs_index = EmitOperandArrayLoopNest( - &loop_nest, lhs_array_, lhs_reduction_dimension, "lhs"); - llvm_ir::IrArray::Index rhs_index = EmitOperandArrayLoopNest( - &loop_nest, rhs_array_, rhs_reduction_dimension, "rhs"); + llvm_ir::ForLoopNest loop_nest(llvm_ir::IrName(&dot_), b_); + llvm_ir::IrArray::Index lhs_index = loop_nest.EmitOperandArrayLoopNest( + lhs_array_, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs"); + llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest( + rhs_array_, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs"); // Create the loop which does the sum of products reduction. // @@ -1319,62 +1311,55 @@ Status DotOpEmitter::Emit() { // Function entry basic block. // - Emit alloca for accumulator llvm::Function* func = reduction_loop->GetPreheaderBasicBlock()->getParent(); - SetToFirstInsertPoint(&func->getEntryBlock(), ir_builder_); + SetToFirstInsertPoint(&func->getEntryBlock(), b_); llvm::Type* accum_type = target_array_.GetElementLlvmType(); - llvm::Value* accum_address = ir_builder_->CreateAlloca( - accum_type, /*ArraySize=*/nullptr, "accum_address"); + llvm::Value* accum_address = + b_->CreateAlloca(accum_type, /*ArraySize=*/nullptr, "accum_address"); // Preheader basic block of reduction loop: // - Initialize accumulator to zero. llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock(); - ir_builder_->SetInsertPoint(preheader_bb->getTerminator()); + b_->SetInsertPoint(preheader_bb->getTerminator()); - ir_builder_->CreateStore(llvm::Constant::getNullValue(accum_type), - accum_address); + b_->CreateStore(llvm::Constant::getNullValue(accum_type), accum_address); // Body basic block of reduction loop: // - Load elements from lhs and rhs array. // - Multiply lhs-element and rhs-element. // - Load accumulator and add to product. // - Store sum back into accumulator. - SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), ir_builder_); + SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), b_); - llvm::Value* lhs_element = - lhs_array_.EmitReadArrayElement(lhs_index, ir_builder_); - llvm::Value* rhs_element = - rhs_array_.EmitReadArrayElement(rhs_index, ir_builder_); + llvm::Value* lhs_element = lhs_array_.EmitReadArrayElement(lhs_index, b_); + llvm::Value* rhs_element = rhs_array_.EmitReadArrayElement(rhs_index, b_); - llvm::Value* accum = ir_builder_->CreateLoad(accum_address); + llvm::Value* accum = b_->CreateLoad(accum_address); llvm::Value* updated_accum; if (ShapeUtil::ElementIsComplex(lhs_shape)) { - auto real = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {0}); - }; - auto imag = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {1}); - }; - llvm::Value* product_real = ir_builder_->CreateFSub( - ir_builder_->CreateFMul(real(lhs_element), real(rhs_element)), - ir_builder_->CreateFMul(imag(lhs_element), imag(rhs_element))); - llvm::Value* product_imag = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(lhs_element), imag(rhs_element)), - ir_builder_->CreateFMul(imag(lhs_element), real(rhs_element))); - updated_accum = ir_builder_->CreateInsertValue( - accum, ir_builder_->CreateFAdd(real(accum), product_real), {0}); - updated_accum = ir_builder_->CreateInsertValue( - updated_accum, ir_builder_->CreateFAdd(imag(accum), product_imag), {1}); + auto real = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {0}); }; + auto imag = [&](llvm::Value* x) { return b_->CreateExtractValue(x, {1}); }; + llvm::Value* product_real = + b_->CreateFSub(b_->CreateFMul(real(lhs_element), real(rhs_element)), + b_->CreateFMul(imag(lhs_element), imag(rhs_element))); + llvm::Value* product_imag = + b_->CreateFAdd(b_->CreateFMul(real(lhs_element), imag(rhs_element)), + b_->CreateFMul(imag(lhs_element), real(rhs_element))); + updated_accum = b_->CreateInsertValue( + accum, b_->CreateFAdd(real(accum), product_real), {0}); + updated_accum = b_->CreateInsertValue( + updated_accum, b_->CreateFAdd(imag(accum), product_imag), {1}); } else { - llvm::Value* product = ir_builder_->CreateFMul(lhs_element, rhs_element); - updated_accum = ir_builder_->CreateFAdd(accum, product); + llvm::Value* product = b_->CreateFMul(lhs_element, rhs_element); + updated_accum = b_->CreateFAdd(accum, product); } - ir_builder_->CreateStore(updated_accum, accum_address); + b_->CreateStore(updated_accum, accum_address); // Exit basic block of reduction loop. // - Load accumulator value (the result). // - Store into output array. - SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), ir_builder_); + SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), b_); - llvm::Value* result = ir_builder_->CreateLoad(accum_address); + llvm::Value* result = b_->CreateLoad(accum_address); // Create index into target address. The target index is the concatenation of // the rhs and lhs indexes with the reduction dimensions removed. The terms @@ -1392,11 +1377,11 @@ Status DotOpEmitter::Emit() { } } - target_array_.EmitWriteArrayElement(target_index, result, ir_builder_); + target_array_.EmitWriteArrayElement(target_index, result, b_); // Set the IR builder insert point to the exit basic block of the outer most // loop. - ir_builder_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); + b_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); return Status::OK(); } @@ -1405,31 +1390,30 @@ Status DotOpEmitter::EmitScalarDot() { // A scalar dot is just a scalar multiply. llvm::Value* result; // Use the same index_type for all tensor accesses in the same kernel. - llvm::Type* index_type = ir_builder_->getInt64Ty(); + llvm::Type* index_type = b_->getInt64Ty(); llvm_ir::IrArray::Index element_index(index_type); llvm::Value* lhs_value = - lhs_array_.EmitReadArrayElement(/*index=*/element_index, ir_builder_); + lhs_array_.EmitReadArrayElement(/*index=*/element_index, b_); llvm::Value* rhs_value = - rhs_array_.EmitReadArrayElement(/*index=*/element_index, ir_builder_); + rhs_array_.EmitReadArrayElement(/*index=*/element_index, b_); if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) { -#define REAL(x) ir_builder_->CreateExtractValue(x, {0}) -#define IMAG(x) ir_builder_->CreateExtractValue(x, {1}) - llvm::Value* real = ir_builder_->CreateFSub( - ir_builder_->CreateFMul(REAL(lhs_value), REAL(rhs_value)), - ir_builder_->CreateFMul(IMAG(lhs_value), IMAG(rhs_value))); - llvm::Value* imag = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(REAL(lhs_value), IMAG(rhs_value)), - ir_builder_->CreateFMul(IMAG(lhs_value), REAL(rhs_value))); +#define REAL(x) b_->CreateExtractValue(x, {0}) +#define IMAG(x) b_->CreateExtractValue(x, {1}) + llvm::Value* real = + b_->CreateFSub(b_->CreateFMul(REAL(lhs_value), REAL(rhs_value)), + b_->CreateFMul(IMAG(lhs_value), IMAG(rhs_value))); + llvm::Value* imag = + b_->CreateFAdd(b_->CreateFMul(REAL(lhs_value), IMAG(rhs_value)), + b_->CreateFMul(IMAG(lhs_value), REAL(rhs_value))); #undef IMAG #undef REAL result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType()); - result = ir_builder_->CreateInsertValue(result, real, {0}); - result = ir_builder_->CreateInsertValue(result, imag, {1}); + result = b_->CreateInsertValue(result, real, {0}); + result = b_->CreateInsertValue(result, imag, {1}); } else { - result = ir_builder_->CreateFMul(lhs_value, rhs_value); + result = b_->CreateFMul(lhs_value, rhs_value); } - target_array_.EmitWriteArrayElement(/*index=*/element_index, result, - ir_builder_); + target_array_.EmitWriteArrayElement(/*index=*/element_index, result, b_); return Status::OK(); } @@ -1452,7 +1436,7 @@ Status DotOpEmitter::EmitCallToRuntime() { fn_name = multi_threaded ? runtime::kEigenMatMulF16SymbolName : runtime::kEigenSingleThreadedMatMulF16SymbolName; - float_type = ir_builder_->getHalfTy(); + float_type = b_->getHalfTy(); break; case F32: fn_name = multi_threaded @@ -1461,7 +1445,7 @@ Status DotOpEmitter::EmitCallToRuntime() { : (use_mkl_dnn ? runtime::kMKLSingleThreadedMatMulF32SymbolName : runtime::kEigenSingleThreadedMatMulF32SymbolName); - float_type = ir_builder_->getFloatTy(); + float_type = b_->getFloatTy(); break; case F64: fn_name = multi_threaded @@ -1470,7 +1454,7 @@ Status DotOpEmitter::EmitCallToRuntime() { : (use_mkl_dnn ? runtime::kMKLSingleThreadedMatMulF64SymbolName : runtime::kEigenSingleThreadedMatMulF64SymbolName); - float_type = ir_builder_->getDoubleTy(); + float_type = b_->getDoubleTy(); break; default: return Unimplemented("Invalid type %s for dot operation", @@ -1478,16 +1462,16 @@ Status DotOpEmitter::EmitCallToRuntime() { } llvm::Type* float_ptr_type = float_type->getPointerTo(); - llvm::Type* int64_type = ir_builder_->getInt64Ty(); - llvm::Type* int32_type = ir_builder_->getInt32Ty(); - llvm::Type* int8_ptr_type = ir_builder_->getInt8Ty()->getPointerTo(); + llvm::Type* int64_type = b_->getInt64Ty(); + llvm::Type* int32_type = b_->getInt32Ty(); + llvm::Type* int8_ptr_type = b_->getInt8Ty()->getPointerTo(); llvm::FunctionType* matmul_type = llvm::FunctionType::get( - ir_builder_->getVoidTy(), + b_->getVoidTy(), {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type, int64_type, int64_type, int64_type, int32_type, int32_type}, /*isVarArg=*/false); - llvm::Function* function = ir_builder_->GetInsertBlock()->getParent(); + llvm::Function* function = b_->GetInsertBlock()->getParent(); llvm::Module* module = function->getParent(); llvm::Function* matmul_func = llvm::cast( @@ -1522,18 +1506,15 @@ Status DotOpEmitter::EmitCallToRuntime() { std::swap(transpose_lhs, transpose_rhs); } - ir_builder_->CreateCall( + b_->CreateCall( matmul_func, - {ir_builder_->CreateBitCast(executable_run_options_value_, int8_ptr_type), - ir_builder_->CreateBitCast(target_array_.GetBasePointer(), - float_ptr_type), - ir_builder_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type), - ir_builder_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type), - ir_builder_->getInt64(mat_mult_dims.m), - ir_builder_->getInt64(mat_mult_dims.n), - ir_builder_->getInt64(mat_mult_dims.k), - ir_builder_->getInt32(transpose_lhs), - ir_builder_->getInt32(transpose_rhs)}); + {b_->CreateBitCast(executable_run_options_value_, int8_ptr_type), + b_->CreateBitCast(target_array_.GetBasePointer(), float_ptr_type), + b_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type), + b_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type), + b_->getInt64(mat_mult_dims.m), b_->getInt64(mat_mult_dims.n), + b_->getInt64(mat_mult_dims.k), b_->getInt32(transpose_lhs), + b_->getInt32(transpose_rhs)}); return Status::OK(); } @@ -1556,36 +1537,6 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0}; } -llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( - llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array, - int64 reduction_dimension, tensorflow::StringPiece name_suffix) { - // Prepares the dimension list we will use to emit the loop nest. Outermost - // loops are added first. Add loops in major-to-minor order, and skip the - // reduction dimension. - std::vector dimensions; - const Shape& shape = operand_array.GetShape(); - for (int i = LayoutUtil::MinorToMajor(shape).size() - 1; i >= 0; --i) { - int64 dimension = LayoutUtil::Minor(shape.layout(), i); - if (dimension != reduction_dimension) { - dimensions.push_back(dimension); - } - } - - // Create loop nest with one for-loop for each dimension of the - // output. - llvm_ir::IrArray::Index index = - loop_nest->AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix); - // Verify every dimension except the reduction dimension was set in the index. - for (int dimension = 0; dimension < index.size(); ++dimension) { - if (dimension == reduction_dimension) { - DCHECK_EQ(nullptr, index[dimension]); - } else { - DCHECK_NE(nullptr, index[dimension]); - } - } - return index; -} - // Return whether the given shape is a matrix with no padding. static bool IsRank2WithNoPadding(const Shape& shape) { return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index ed2a18976a0f1a88e7bb4632d3a63167d5c146ad..590032fbe907d7ca90bf69b7ccc3170b8efec72e 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -61,7 +61,7 @@ class DotOpEmitter { const HloInstruction& dot, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features); @@ -70,8 +70,7 @@ class DotOpEmitter { const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* ir_builder, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features); @@ -89,17 +88,6 @@ class DotOpEmitter { // Emits a call to the CPU runtime to perform the matrix multiply. Status EmitCallToRuntime(); - // Emits a series of nested loops for iterating over an operand array in the - // dot operation. Loops are constructed in major to minor dimension layout - // order. No loop is emitted for the given reduction_dimension. The function - // returns an IrArray index for the given operand_array containing the indvars - // of the loops. All dimensions of the index are filled except for the - // reduction dimension. name_suffix is the string to append to the names of - // LLVM constructs (eg, basic blocks) constructed by this method. - llvm_ir::IrArray::Index EmitOperandArrayLoopNest( - llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array, - int64 reduction_dimension, tensorflow::StringPiece name_suffix); - // Represents the dimensions of a matrix-matrix multiply operation. struct MatMultDims { // The number of rows in the LHS. @@ -171,7 +159,7 @@ class DotOpEmitter { const llvm_ir::IrArray& rhs_array_; const llvm_ir::IrArray* addend_array_; llvm::Value* executable_run_options_value_; - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; const HloModuleConfig& hlo_module_config_; const TargetMachineFeatures& target_machine_features_; }; diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index e97113dfa0f59e791d614c0093d0781e49c48ee4..c13d36776f94221598338dca4eadf024c0a892df 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -19,6 +19,8 @@ limitations under the License. #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/types.h" @@ -38,8 +40,7 @@ StatusOr CpuElementalIrEmitter::EmitFloatUnaryOp( switch (element_type) { case F16: cast_result_to_fp16 = true; - operand_value = ir_builder_->CreateFPCast(operand_value, - ir_builder_->getFloatTy()); + operand_value = b_->CreateFPCast(operand_value, b_->getFloatTy()); TF_FALLTHROUGH_INTENDED; case F32: function_name = "tanhf"; @@ -59,9 +60,9 @@ StatusOr CpuElementalIrEmitter::EmitFloatUnaryOp( function->setDoesNotThrow(); function->setDoesNotAccessMemory(); // Create an instruction to call the function. - llvm::Value* result = ir_builder_->CreateCall(function, operand_value); + llvm::Value* result = b_->CreateCall(function, operand_value); if (cast_result_to_fp16) { - result = ir_builder_->CreateFPCast(result, ir_builder_->getHalfTy()); + result = b_->CreateFPCast(result, b_->getHalfTy()); } return result; } @@ -77,8 +78,8 @@ StatusOr CpuElementalIrEmitter::EmitAtan2( switch (prim_type) { case F16: cast_result_to_fp16 = true; - lhs = ir_builder_->CreateFPCast(lhs, ir_builder_->getFloatTy()); - rhs = ir_builder_->CreateFPCast(rhs, ir_builder_->getFloatTy()); + lhs = b_->CreateFPCast(lhs, b_->getFloatTy()); + rhs = b_->CreateFPCast(rhs, b_->getFloatTy()); TF_FALLTHROUGH_INTENDED; case F32: function_name = "atan2f"; @@ -98,9 +99,9 @@ StatusOr CpuElementalIrEmitter::EmitAtan2( function->setDoesNotThrow(); function->setDoesNotAccessMemory(); // Create an instruction to call the function. - llvm::Value* result = ir_builder_->CreateCall(function, {lhs, rhs}); + llvm::Value* result = b_->CreateCall(function, {lhs, rhs}); if (cast_result_to_fp16) { - result = ir_builder_->CreateFPCast(result, ir_builder_->getHalfTy()); + result = b_->CreateFPCast(result, b_->getHalfTy()); } return result; } @@ -118,9 +119,8 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( ElementwiseSourceIndex(index, *hlo, i))); operands.push_back(operand_value); } - return ir_emitter_->EmitScalarCall(hlo->shape().element_type(), - hlo->to_apply(), operands, - llvm_ir::IrName(hlo)); + return ir_emitter_->EmitElementalMap(*Cast(hlo), + operands, llvm_ir::IrName(hlo)); }; } return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator); diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h index 4446dfd2821fb4b6e75f33694367392ecbcdd8bf..9598a886ab49fcecf5df7bd65f425fe485de3574 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h @@ -31,7 +31,7 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { public: CpuElementalIrEmitter(const HloModuleConfig& module_config, IrEmitter* ir_emitter, llvm::Module* module) - : ElementalIrEmitter(module_config, module, ir_emitter->ir_builder()), + : ElementalIrEmitter(module_config, module, ir_emitter->b()), ir_emitter_(ir_emitter) {} llvm_ir::ElementGenerator MakeElementGenerator( diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 6b66a4b0b7cef0058a761801815606b9440016cf..ca645d3f1da18fb26378a10526c27a7d254896e2 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -51,10 +51,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -89,14 +90,14 @@ IrEmitter::IrEmitter( : assignment_(assignment), module_(llvm_module), arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), - ir_builder_(llvm_module->getContext()), + b_(llvm_module->getContext()), instruction_to_profile_idx_(std::move(instruction_to_profile_idx)), computation_to_profile_idx_(std::move(computation_to_profile_idx)), alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), hlo_module_config_(hlo_module.config()), is_top_level_computation_(false), target_machine_features_(*target_machine_features) { - ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags( + b_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() .xla_enable_fast_math())); } @@ -115,6 +116,19 @@ StatusOr IrEmitter::EmitComputation( computation->root_instruction()->outer_dimension_partitions().size(); } + if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) { + TF_ASSIGN_OR_RETURN( + computation_root_allocation_, + assignment_.GetUniqueTopLevelSlice(computation->root_instruction())); + } + + for (const HloInstruction* param : computation->parameter_instructions()) { + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice, + assignment_.GetUniqueTopLevelSlice(param)); + computation_parameter_allocations_[param_slice.allocation()->index()] = + param->parameter_number(); + } + InitializeIrFunction(function_name); // The rdtscp instruction is x86 specific. We will fallback to LLVM's generic // readcyclecounter if it is unavailable. @@ -131,6 +145,8 @@ StatusOr IrEmitter::EmitComputation( // Delete 'compute_function', finalizing 'ir_function' and restoring caller // IR insert point. compute_function_.reset(); + computation_root_allocation_ = BufferAllocation::Slice(); + computation_parameter_allocations_.clear(); return ir_function; } @@ -146,7 +162,7 @@ void IrEmitter::InitializeIrFunction(const string& function_name) { new IrFunction(function_name, linkage, options::OptimizeForSizeRequested(hlo_module_config_), hlo_module_config_.debug_options().xla_enable_fast_math(), - module_, &ir_builder_, num_dynamic_loop_bounds_)); + module_, &b_, num_dynamic_loop_bounds_)); } IrEmitter::~IrEmitter() {} @@ -154,9 +170,9 @@ IrEmitter::~IrEmitter() {} Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { VLOG(2) << "HandleBitcast: " << bitcast->ToString(); emitted_value_[bitcast] = - ir_builder_.CreateBitCast(GetEmittedValueFor(bitcast->operand(0)), - IrShapeType(bitcast->shape())->getPointerTo(), - AsStringRef(IrName(bitcast))); + b_.CreateBitCast(GetEmittedValueFor(bitcast->operand(0)), + IrShapeType(bitcast->shape())->getPointerTo(), + AsStringRef(IrName(bitcast))); return Status::OK(); } @@ -175,25 +191,36 @@ llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { result_global, IrShapeType(literal.shape())->getPointerTo()); } -Status IrEmitter::HandleConstant(HloInstruction* constant) { - VLOG(2) << "HandleConstant: " << constant->ToString(); - const Literal& literal = constant->literal(); - llvm::Constant* global_for_const; +Status IrEmitter::EmitConstantGlobals() { + for (const BufferAllocation& allocation : assignment_.Allocations()) { + if (!allocation.is_constant()) { + continue; + } - auto it = emitted_literals_.find(&literal); - if (it != emitted_literals_.end()) { - global_for_const = it->second; - } else { - global_for_const = EmitGlobalForLiteral(literal); - emitted_literals_[&literal] = global_for_const; + const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation); + llvm::Constant* global_for_const; + auto it = emitted_literals_.find(&literal); + if (it != emitted_literals_.end()) { + global_for_const = it->second; + } else { + global_for_const = EmitGlobalForLiteral(literal); + InsertOrDie(&emitted_literals_, &literal, global_for_const); + } + + InsertOrDie(&constant_buffer_to_global_, allocation.index(), + global_for_const); } - emitted_value_[constant] = global_for_const; - VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*global_for_const); - VLOG(2) << " its type: " - << llvm_ir::DumpToString(*global_for_const->getType()); + return Status::OK(); } +Status IrEmitter::HandleConstant(HloInstruction* constant) { + VLOG(2) << "HandleConstant: " << constant->ToString(); + // IrEmitter::EmitConstantGlobals has already taken care of emitting the body + // of the constant. + return EmitTargetAddressForOp(constant); +} + Status IrEmitter::HandleCopy(HloInstruction* copy) { if (ShapeUtil::IsTuple(copy->shape())) { // kCopy shallow copies a tuple so just memcpy the top-level buffer. @@ -273,27 +300,30 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) { const Shape& shape = get_tuple_element->shape(); emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement( shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape), - GetEmittedValueFor(operand), &ir_builder_, module_); + GetEmittedValueFor(operand), &b_, module_); return Status::OK(); } Status IrEmitter::HandleSelect(HloInstruction* select) { auto pred = select->operand(0); - auto on_true = select->operand(1); - auto on_false = select->operand(2); TF_RET_CHECK(pred->shape().element_type() == PRED); - - if (ShapeUtil::IsTuple(select->shape())) { - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(select)); - llvm_ir::EmitTupleSelect( - GetIrArrayFor(select), GetIrArrayFor(pred), GetEmittedValueFor(on_true), - GetEmittedValueFor(on_false), &ir_builder_, module_); - return Status::OK(); - } - return DefaultAction(select); } +Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) { + auto pred = tuple_select->operand(0); + auto on_true = tuple_select->operand(1); + auto on_false = tuple_select->operand(2); + TF_RET_CHECK(pred->shape().element_type() == PRED); + TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape())); + TF_RET_CHECK(ShapeUtil::IsTuple(tuple_select->shape())); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple_select)); + llvm_ir::EmitTupleSelect(GetIrArrayFor(tuple_select), GetIrArrayFor(pred), + GetEmittedValueFor(on_true), + GetEmittedValueFor(on_false), &b_, module_); + return Status::OK(); +} + Status IrEmitter::HandleInfeed(HloInstruction* instruction) { HloInfeedInstruction* infeed = Cast(instruction); VLOG(2) << "HandleInfeed: " << infeed->ToString(); @@ -313,8 +343,8 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { assignment_.GetUniqueSlice(infeed, {1})); llvm::Value* token_address = EmitTempBufferPointer( token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1)); - llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, - &ir_builder_, module_); + llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_, + module_); if (ShapeUtil::IsTuple(data_shape)) { TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape)); @@ -345,7 +375,7 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { } llvm_ir::EmitTuple(llvm_ir::IrArray(data_address, data_shape), - tuple_element_addresses, &ir_builder_, module_); + tuple_element_addresses, &b_, module_); } else { TF_RETURN_IF_ERROR( EmitXfeedTransfer(XfeedKind::kInfeed, data_shape, data_address)); @@ -366,14 +396,14 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, int32 length_32 = static_cast(length); int32 shape_length; - TF_ASSIGN_OR_RETURN(llvm::Value * shape_ptr, - llvm_ir::EncodeSelfDescribingShapeConstant( - shape, &shape_length, &ir_builder_)); + TF_ASSIGN_OR_RETURN( + llvm::Value * shape_ptr, + llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_)); // The signature of the acquire infeed buffer function is: // // (void*)(int32 length); - llvm::Type* int32_type = ir_builder_.getInt32Ty(); + llvm::Type* int32_type = b_.getInt32Ty(); llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); llvm::FunctionType* acquire_type = llvm::FunctionType::get( i8_ptr_type, {int32_type, i8_ptr_type, int32_type}, @@ -393,8 +423,7 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, // // (void)(int32 length, void* buffer); llvm::FunctionType* release_type = llvm::FunctionType::get( - ir_builder_.getVoidTy(), - {int32_type, i8_ptr_type, i8_ptr_type, int32_type}, + b_.getVoidTy(), {int32_type, i8_ptr_type, i8_ptr_type, int32_type}, /*isVarArg=*/false); llvm::Function* release_func; @@ -411,25 +440,22 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, // of size exactly 'length_32', and the runtime is responsible for // check-failing the process if there is a mismatch, versus passing us back a // buffer that we might overrun. - llvm::Value* acquired_pointer = ir_builder_.CreateCall( - acquire_func, {ir_builder_.getInt32(length_32), shape_ptr, - ir_builder_.getInt32(shape_length)}); + llvm::Value* acquired_pointer = b_.CreateCall( + acquire_func, + {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)}); if (kind == XfeedKind::kInfeed) { // Copy to the program buffer address from the acquired buffer. - ir_builder_.CreateMemCpy(program_buffer_address, /*DstAlign=*/1, - acquired_pointer, - /*SrcAlign=*/1, length_32); + b_.CreateMemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer, + /*SrcAlign=*/1, length_32); } else { // Outfeed -- copy from the in-program address to the acquired buffer. - ir_builder_.CreateMemCpy(acquired_pointer, /*DstAlign=*/1, - program_buffer_address, - /*SrcAlign=*/1, length_32); + b_.CreateMemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address, + /*SrcAlign=*/1, length_32); } - ir_builder_.CreateCall(release_func, - {ir_builder_.getInt32(length_32), acquired_pointer, - shape_ptr, ir_builder_.getInt32(shape_length)}); + b_.CreateCall(release_func, {b_.getInt32(length_32), acquired_pointer, + shape_ptr, b_.getInt32(shape_length)}); return Status::OK(); } @@ -450,7 +476,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { ShapeUtil::GetTupleElementShape(operand_shape, i); llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement( tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape), - value, &ir_builder_, module_); + value, &b_, module_); TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed, tuple_element_shape, tuple_element)); } @@ -469,46 +495,96 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { for (auto operand : tuple->operands()) { base_ptrs.push_back(GetEmittedValueFor(operand)); } - llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &ir_builder_, module_); + llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &b_, module_); return Status::OK(); } -Status IrEmitter::HandleMap(HloInstruction* map) { - gtl::ArraySlice operands(map->operands()); - HloComputation* function = map->to_apply(); - // The called computation should have been emitted previously. - llvm::Function* mapped_ir_function = FindOrDie(emitted_functions_, function); +llvm::Value* IrEmitter::EmitElementalMap( + const HloMapInstruction& map_instr, + tensorflow::gtl::ArraySlice elemental_operands, + tensorflow::StringPiece name) { + return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name); +} + +StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( + HloReduceWindowInstruction* reduce_window, + const llvm_ir::IrArray::Index& index) { + const HloInstruction* operand = reduce_window->operand(0); + const Window& window = reduce_window->window(); + + // We fold inputs into the accumulator and initialize it to + // the initial value on the reduce_window. + PrimitiveType operand_element_type = operand->shape().element_type(); + llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), + "reduce_window_accumulator_address", &b_, + MinimumAlignmentForPrimitiveType(operand_element_type)); + b_.CreateStore(b_.CreateLoad(GetEmittedValueFor(reduce_window->operand(1))), + accumulator_address); - return EmitTargetElementLoop(map, [this, map, operands, mapped_ir_function]( - const llvm_ir::IrArray::Index& index) { - std::vector parameter_addresses; - for (const HloInstruction* operand : operands) { - const llvm_ir::IrArray& array = GetIrArrayFor(operand); - parameter_addresses.push_back( - array.EmitArrayElementAddress(index, &ir_builder_)); + llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &b_); + std::vector window_size; + for (const auto& dim : window.dimensions()) { + window_size.push_back(dim.size()); + } + const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape( + ShapeUtil::MakeShape(operand_element_type, window_size), "window"); + CHECK_EQ(window_index.size(), index.size()); + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); + + llvm_ir::IrArray::Index input_index(b_.getInt64Ty(), index.size()); + llvm::Value* in_bounds_condition = nullptr; + for (size_t i = 0; i < index.size(); ++i) { + llvm::Value* strided_index = + b_.CreateNSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); + input_index[i] = + b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]), + b_.getInt64(window.dimensions(i).padding_low())); + + // We need to check if 0 <= input_index[i] < bound, as otherwise we are in + // the padding so that we can skip the computation. That is equivalent to + // input_index[i] < bound as an *unsigned* comparison, since a negative + // value will wrap to a large positive value. + llvm::Value* index_condition = b_.CreateICmpULT( + input_index[i], + b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); + if (in_bounds_condition == nullptr) { + in_bounds_condition = index_condition; + } else { + in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); } - return EmitElementFunctionCall(mapped_ir_function, map->shape(), - parameter_addresses, "map_function"); - }); + } + CHECK(in_bounds_condition != nullptr); + + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); + SetToFirstInsertPoint(if_data.true_block, &b_); + + // We are not in the padding, so carry out the computation. + llvm_ir::IrArray input_array(GetIrArrayFor(operand)); + llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_); + llvm::Value* result = EmitThreadLocalCall( + *reduce_window->to_apply(), + {b_.CreateLoad(accumulator_address), input_value}, "reducer_function"); + b_.CreateStore(result, accumulator_address); + + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); + return b_.CreateLoad(accumulator_address); } Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { - auto operand = reduce_window->operand(0); - const Window& window = reduce_window->window(); - HloComputation* function = reduce_window->to_apply(); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( - /*instruction=*/*reduce_window, /*operands=*/{operand}, + /*instruction=*/*reduce_window, + /*operands=*/{reduce_window->operand(0)}, /*supported_types=*/{F32, BF16, S32})); // TODO(b/31410564): Implement dilation for reduce-window. - if (window_util::HasDilation(window)) { + if (window_util::HasDilation(reduce_window->window())) { return Unimplemented( "Dilation for ReduceWindow is not implemented on CPU."); } - // The called computation should have been emitted previously. - llvm::Function* reducer_function = FindOrDie(emitted_functions_, function); - // Pseudo code for reduce window: // // for (coordinates O in the output) @@ -523,73 +599,9 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { // This is completely un-optimized and just here to have something // that works. return EmitTargetElementLoop( - reduce_window, [this, reduce_window, operand, window, - reducer_function](const llvm_ir::IrArray::Index& index) { - // We fold inputs into the accumulator and initialize it to - // the initial value on the reduce_window. - PrimitiveType operand_element_type = operand->shape().element_type(); - llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), - "reduce_window_accumulator_address", &ir_builder_, - MinimumAlignmentForPrimitiveType(operand_element_type)); - ir_builder_.CreateStore(ir_builder_.CreateLoad(GetEmittedValueFor( - reduce_window->operand(1))), - accumulator_address); - - llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), - &ir_builder_); - std::vector window_size; - for (const auto& dim : window.dimensions()) { - window_size.push_back(dim.size()); - } - const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape( - ShapeUtil::MakeShape(operand_element_type, window_size), "window"); - CHECK_EQ(window_index.size(), index.size()); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); - - llvm_ir::IrArray::Index input_index(ir_builder_.getInt64Ty(), - index.size()); - llvm::Value* in_bounds_condition = nullptr; - for (size_t i = 0; i < index.size(); ++i) { - llvm::Value* strided_index = ir_builder_.CreateNSWMul( - index[i], ir_builder_.getInt64(window.dimensions(i).stride())); - input_index[i] = ir_builder_.CreateNSWSub( - ir_builder_.CreateNSWAdd(strided_index, window_index[i]), - ir_builder_.getInt64(window.dimensions(i).padding_low())); - - // We need to check if 0 <= input_index[i] < bound, as - // otherwise we are in the padding so that we can skip the - // computation. That is equivalent to input_index[i] < bound - // as an *unsigned* comparison, since a negative value will - // wrap to a large positive value. - llvm::Value* index_condition = ir_builder_.CreateICmpULT( - input_index[i], ir_builder_.getInt64(ShapeUtil::GetDimension( - operand->shape(), i))); - if (in_bounds_condition == nullptr) { - in_bounds_condition = index_condition; - } else { - in_bounds_condition = - ir_builder_.CreateAnd(in_bounds_condition, index_condition); - } - } - CHECK(in_bounds_condition != nullptr); - - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - in_bounds_condition, "in-bounds", &ir_builder_); - SetToFirstInsertPoint(if_data.true_block, &ir_builder_); - - // We are not in the padding, so carry out the computation. - llvm_ir::IrArray input_array(GetIrArrayFor(operand)); - llvm::Value* input_value_address = - input_array.EmitArrayElementAddress(input_index, &ir_builder_); - llvm::Value* result = EmitElementFunctionCall( - reducer_function, reduce_window->shape(), - {accumulator_address, input_value_address}, "reducer_function"); - ir_builder_.CreateStore(result, accumulator_address); - - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); - return ir_builder_.CreateLoad(accumulator_address); + reduce_window, [&](const llvm_ir::IrArray::Index& index) { + return EmitTargetElementLoopBodyForReduceWindow( + Cast(reduce_window), index); }); } @@ -610,12 +622,6 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { "Dilation for SelectAndScatter is not implemented on CPU. "); } - // The select and scatter computations should have been emitted previously. - llvm::Function* select_function = - FindOrDie(emitted_functions_, select_and_scatter->select()); - llvm::Function* scatter_function = - FindOrDie(emitted_functions_, select_and_scatter->scatter()); - // Pseudo code for select-and-scatter: // // initialized_flag is initially off for every window, and is turned on after @@ -641,141 +647,128 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"), [this, init_value](const llvm_ir::IrArray::Index& target_index) { llvm::Value* init_value_addr = GetEmittedValueFor(init_value); - return ir_builder_.CreateLoad(init_value_addr); + return b_.CreateLoad(init_value_addr); })); // Create a loop to iterate over the source array to scatter to the output. - llvm_ir::ForLoopNest source_loops(IrName(select_and_scatter), &ir_builder_); + llvm_ir::ForLoopNest source_loops(IrName(select_and_scatter), &b_); const llvm_ir::IrArray::Index source_index = source_loops.AddLoopsForShape(source->shape(), "source"); - SetToFirstInsertPoint(source_loops.GetInnerLoopBodyBasicBlock(), - &ir_builder_); + SetToFirstInsertPoint(source_loops.GetInnerLoopBodyBasicBlock(), &b_); // Allocate space to keep the currently selected value, its index, and // the boolean initialized_flag, which is initially set to false. llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), - "selected_value_address", &ir_builder_, + "selected_value_address", &b_, MinimumAlignmentForPrimitiveType(operand_element_type)); llvm::Value* selected_index_address = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - ir_builder_.getInt64Ty(), ir_builder_.getInt32(rank), - "selected_index_address", &ir_builder_); + b_.getInt64Ty(), b_.getInt32(rank), "selected_index_address", &b_); llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry( - ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_); - ir_builder_.CreateStore(ir_builder_.getInt1(false), initialized_flag_address); + b_.getInt1Ty(), "initialized_flag_address", &b_); + b_.CreateStore(b_.getInt1(false), initialized_flag_address); // Create the inner loop to iterate over the window. - llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), - &ir_builder_); + llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), &b_); std::vector window_size; for (const auto& dim : window.dimensions()) { window_size.push_back(dim.size()); } const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape( ShapeUtil::MakeShape(operand_element_type, window_size), "window"); - SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(), - &ir_builder_); + SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(), &b_); // Compute the operand index to visit and evaluate the condition whether the // operand index is within the bounds. The unsigned comparison includes // checking whether the operand index >= 0. - llvm_ir::IrArray::Index operand_index(ir_builder_.getInt64Ty(), - source_index.size()); - llvm::Value* in_bounds_condition = ir_builder_.getTrue(); + llvm_ir::IrArray::Index operand_index(b_.getInt64Ty(), source_index.size()); + llvm::Value* in_bounds_condition = b_.getTrue(); for (int64 i = 0; i < rank; ++i) { - llvm::Value* strided_index = ir_builder_.CreateNSWMul( - source_index[i], ir_builder_.getInt64(window.dimensions(i).stride())); - operand_index[i] = ir_builder_.CreateNSWSub( - ir_builder_.CreateNSWAdd(strided_index, window_index[i]), - ir_builder_.getInt64(window.dimensions(i).padding_low())); - llvm::Value* index_condition = ir_builder_.CreateICmpULT( + llvm::Value* strided_index = b_.CreateNSWMul( + source_index[i], b_.getInt64(window.dimensions(i).stride())); + operand_index[i] = + b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]), + b_.getInt64(window.dimensions(i).padding_low())); + llvm::Value* index_condition = b_.CreateICmpULT( operand_index[i], - ir_builder_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); - in_bounds_condition = - ir_builder_.CreateAnd(in_bounds_condition, index_condition); + b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); + in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); } CHECK(in_bounds_condition != nullptr); // Only need to do something if the operand index is within the bounds. First // check if the initialized_flag is set. llvm_ir::LlvmIfData if_in_bounds = - llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &ir_builder_); - SetToFirstInsertPoint(if_in_bounds.true_block, &ir_builder_); - llvm_ir::LlvmIfData if_initialized = - llvm_ir::EmitIfThenElse(ir_builder_.CreateLoad(initialized_flag_address), - "initialized", &ir_builder_); + llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); + SetToFirstInsertPoint(if_in_bounds.true_block, &b_); + llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse( + b_.CreateLoad(initialized_flag_address), "initialized", &b_); // If the initialized_flag is false, initialize the selected value and index // with the currently visiting operand. - SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_); + SetToFirstInsertPoint(if_initialized.false_block, &b_); const auto save_operand_index = [&](const llvm_ir::IrArray::Index& operand_index) { for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - ir_builder_.CreateInBoundsGEP(selected_index_address, - {ir_builder_.getInt32(i)}); - ir_builder_.CreateStore(operand_index[i], - selected_index_address_slot); + b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); + b_.CreateStore(operand_index[i], selected_index_address_slot); } }; llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* operand_data = - operand_array.EmitReadArrayElement(operand_index, &ir_builder_); - ir_builder_.CreateStore(operand_data, selected_value_address); + operand_array.EmitReadArrayElement(operand_index, &b_); + b_.CreateStore(operand_data, selected_value_address); save_operand_index(operand_index); - ir_builder_.CreateStore(ir_builder_.getInt1(true), initialized_flag_address); + b_.CreateStore(b_.getInt1(true), initialized_flag_address); // If the initialized_flag is true, call the `select` function to potentially // update the selected value and index with the currently visiting operand. - SetToFirstInsertPoint(if_initialized.true_block, &ir_builder_); - const Shape output_shape = ShapeUtil::MakeShape(PRED, {}); + SetToFirstInsertPoint(if_initialized.true_block, &b_); llvm::Value* operand_address = - operand_array.EmitArrayElementAddress(operand_index, &ir_builder_); - llvm::Value* result = EmitElementFunctionCall( - select_function, output_shape, {selected_value_address, operand_address}, + operand_array.EmitArrayElementAddress(operand_index, &b_); + llvm::Value* operand_element = b_.CreateLoad(operand_address); + llvm::Value* result = EmitThreadLocalCall( + *select_and_scatter->select(), + {b_.CreateLoad(selected_value_address), operand_element}, "select_function"); // If the 'select' function returns false, update the selected value and the // index to the currently visiting operand. - llvm::Value* cond = ir_builder_.CreateICmpNE( + llvm::Value* cond = b_.CreateICmpNE( result, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), "boolean_predicate"); llvm_ir::LlvmIfData if_select_lhs = - llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_); - SetToFirstInsertPoint(if_select_lhs.false_block, &ir_builder_); - ir_builder_.CreateStore(ir_builder_.CreateLoad(operand_address), - selected_value_address); + llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_); + SetToFirstInsertPoint(if_select_lhs.false_block, &b_); + b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address); save_operand_index(operand_index); // After iterating over the window elements, scatter the source element to // the selected index of the output. The value we store at the output // location is computed by calling the `scatter` function with the source // value and the current output value. - SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), - &ir_builder_); + SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &b_); llvm_ir::IrArray::Index selected_index(source_index.GetType()); for (int64 i = 0; i < rank; ++i) { - llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP( - selected_index_address, {ir_builder_.getInt32(i)}); - selected_index.push_back( - ir_builder_.CreateLoad(selected_index_address_slot)); + llvm::Value* selected_index_address_slot = + b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); + selected_index.push_back(b_.CreateLoad(selected_index_address_slot)); } llvm_ir::IrArray source_array(GetIrArrayFor(source)); - llvm::Value* source_value_address = - source_array.EmitArrayElementAddress(source_index, &ir_builder_); + llvm::Value* source_value = + source_array.EmitReadArrayElement(source_index, &b_); llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter)); - llvm::Value* output_value_address = - output_array.EmitArrayElementAddress(selected_index, &ir_builder_); - llvm::Value* scatter_value = EmitElementFunctionCall( - scatter_function, source->shape(), - {output_value_address, source_value_address}, "scatter_function"); - output_array.EmitWriteArrayElement(selected_index, scatter_value, - &ir_builder_); - - SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), - &ir_builder_); + llvm::Value* output_value = + output_array.EmitReadArrayElement(selected_index, &b_); + llvm::Value* scatter_value = + EmitThreadLocalCall(*select_and_scatter->scatter(), + {output_value, source_value}, "scatter_function"); + output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_); + + SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_); return Status::OK(); } @@ -814,21 +807,155 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // Dot operation is complicated so we delegate to a helper class. return DotOpEmitter::EmitDotOperation( *dot, target_array, lhs_array, rhs_array, /*addend_array=*/nullptr, - GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_, + GetExecutableRunOptionsArgument(), &b_, hlo_module_config_, target_machine_features_); } +StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( + HloConvolutionInstruction* convolution, + const llvm_ir::IrArray::Index& index) { + const HloInstruction* lhs = convolution->operand(0); + const HloInstruction* rhs = convolution->operand(1); + const Window& window = convolution->window(); + + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); + int num_spatial_dims = dnums.output_spatial_dimensions_size(); + std::vector output_spatial(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + output_spatial[i] = index[dnums.output_spatial_dimensions(i)]; + } + llvm::Value* output_feature = index[dnums.output_feature_dimension()]; + llvm::Value* batch = index[dnums.output_batch_dimension()]; + + // We will accumulate the products into this sum to calculate the output entry + // at the given index. + PrimitiveType lhs_element_type = lhs->shape().element_type(); + llvm::Type* lhs_llvm_type = + llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_); + llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry( + lhs_llvm_type, "convolution_sum_address", &b_, + MinimumAlignmentForPrimitiveType(lhs_element_type)); + llvm::Value* constant_zero = llvm::Constant::getNullValue(lhs_llvm_type); + b_.CreateStore(constant_zero, sum_address); + + llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_); + std::vector kernel_spatial(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + kernel_spatial[i] = + loops + .AddLoop( + 0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)), + tensorflow::strings::StrCat("k", i)) + ->GetIndVarValue(); + } + llvm::Value* input_feature = + loops + .AddLoop(0, lhs->shape().dimensions(dnums.input_feature_dimension()), + "iz") + ->GetIndVarValue(); + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); + + // Calculate the spatial index in the input array, taking striding, dilation + // and padding into account. An index in the padding will be out of the bounds + // of the array. + const auto calculate_input_index = [this](llvm::Value* output_index, + llvm::Value* kernel_index, + const WindowDimension& window_dim) { + llvm::Value* strided_index = + b_.CreateNSWMul(output_index, b_.getInt64(window_dim.stride())); + llvm::Value* dilated_kernel_index = b_.CreateNSWMul( + kernel_index, b_.getInt64(window_dim.window_dilation())); + return b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, dilated_kernel_index), + b_.getInt64(window_dim.padding_low())); + }; + std::vector input_spatial(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + input_spatial[i] = calculate_input_index( + output_spatial[i], kernel_spatial[i], window.dimensions(i)); + } + + // We need to check if 0 <= input dim < bound, as otherwise we are in the + // padding so that we can skip the computation. That is equivalent to input + // dim < bound as an *unsigned* comparison, since a negative value will wrap + // to a large positive value. The input dim is dilated, so we need to dilate + // the bound as well to match. + + // Also need to check that the input coordinates are not in one of the + // holes created by base dilation. + const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) { + llvm::Value* remainder = + b_.CreateSRem(input_index, b_.getInt64(base_dilation)); + return b_.CreateICmpEQ(remainder, b_.getInt64(0)); + }; + + llvm::Value* in_bounds_condition = b_.getInt1(true); + for (int i = 0; i < num_spatial_dims; ++i) { + llvm::ConstantInt* input_bound = b_.getInt64(window_util::DilatedBound( + lhs->shape().dimensions(dnums.input_spatial_dimensions(i)), + window.dimensions(i).base_dilation())); + llvm::Value* dim_in_bound = b_.CreateICmpULT(input_spatial[i], input_bound); + llvm::Value* dim_not_in_hole = + not_in_hole(input_spatial[i], window.dimensions(i).base_dilation()); + llvm::Value* dim_ok = b_.CreateAnd(dim_in_bound, dim_not_in_hole); + in_bounds_condition = b_.CreateAnd(in_bounds_condition, dim_ok); + } + + // Now we need to map the dilated base coordinates back to the actual + // data indices on the lhs. + const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) { + return b_.CreateSDiv(input_index, b_.getInt64(base_dilation)); + }; + for (int i = 0; i < num_spatial_dims; ++i) { + input_spatial[i] = + undilate(input_spatial[i], window.dimensions(i).base_dilation()); + } + + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); + SetToFirstInsertPoint(if_data.true_block, &b_); + + // We are not in the padding, so carry out the computation. + int num_dims = num_spatial_dims + 2; + llvm_ir::IrArray::Index input_index(b_.getInt64Ty(), num_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + input_index[dnums.input_spatial_dimensions(i)] = input_spatial[i]; + } + input_index[dnums.input_feature_dimension()] = input_feature; + input_index[dnums.input_batch_dimension()] = batch; + + llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs)); + llvm_ir::IrArray::Index kernel_index(b_.getInt64Ty(), num_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + kernel_index[dnums.kernel_spatial_dimensions(i)] = + window.dimensions(i).window_reversal() + ? b_.CreateNSWSub(b_.getInt64(window.dimensions(i).size() - 1), + kernel_spatial[i]) + : kernel_spatial[i]; + } + + kernel_index[dnums.kernel_input_feature_dimension()] = input_feature; + kernel_index[dnums.kernel_output_feature_dimension()] = output_feature; + + llvm_ir::IrArray input_array(GetIrArrayFor(lhs)); + llvm::Value* product = + b_.CreateFMul(input_array.EmitReadArrayElement(input_index, &b_), + kernel_array.EmitReadArrayElement(kernel_index, &b_)); + llvm::Value* sum = b_.CreateFAdd(b_.CreateLoad(sum_address), product); + b_.CreateStore(sum, sum_address); + + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); + return b_.CreateLoad(sum_address); +} + Status IrEmitter::HandleConvolution(HloInstruction* convolution) { auto lhs = convolution->operand(0); auto rhs = convolution->operand(1); - const auto& window = convolution->window(); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*convolution, /*operands=*/{lhs, rhs}, /*supported_types=*/{F16, F32, C64})); - const ConvolutionDimensionNumbers& dnums = - convolution->convolution_dimension_numbers(); - // TODO(tonywy): Add PotentiallyImplementedAsMKLCovolution to support // different data layouts. if (PotentiallyImplementedAsEigenConvolution(*convolution, @@ -908,12 +1035,12 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { PrimitiveType primitive_type = lhs->shape().element_type(); llvm::Type* ir_ptr_type = primitive_type == F16 - ? ir_builder_.getHalfTy()->getPointerTo() - : ir_builder_.getFloatTy()->getPointerTo(); - llvm::Type* int64_type = ir_builder_.getInt64Ty(); - llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo(); + ? b_.getHalfTy()->getPointerTo() + : b_.getFloatTy()->getPointerTo(); + llvm::Type* int64_type = b_.getInt64Ty(); + llvm::Type* int8_ptr_type = b_.getInt8Ty()->getPointerTo(); llvm::FunctionType* conv_type = llvm::FunctionType::get( - ir_builder_.getVoidTy(), + b_.getVoidTy(), {int8_ptr_type, ir_ptr_type, ir_ptr_type, ir_ptr_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, @@ -945,34 +1072,34 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { conv_func->setCallingConv(llvm::CallingConv::C); conv_func->setDoesNotThrow(); conv_func->setOnlyAccessesArgMemory(); - ir_builder_.CreateCall( - conv_func, { - GetExecutableRunOptionsArgument(), - ir_builder_.CreateBitCast( - GetEmittedValueFor(convolution), ir_ptr_type), - ir_builder_.CreateBitCast(lhs_address, ir_ptr_type), - ir_builder_.CreateBitCast(rhs_address, ir_ptr_type), - ir_builder_.getInt64(input_batch), - ir_builder_.getInt64(input_rows), - ir_builder_.getInt64(input_cols), - ir_builder_.getInt64(input_channels), - ir_builder_.getInt64(kernel_rows), - ir_builder_.getInt64(kernel_cols), - ir_builder_.getInt64(kernel_channels), - ir_builder_.getInt64(kernel_filters), - ir_builder_.getInt64(output_rows), - ir_builder_.getInt64(output_cols), - ir_builder_.getInt64(row_stride), - ir_builder_.getInt64(col_stride), - ir_builder_.getInt64(padding_top), - ir_builder_.getInt64(padding_bottom), - ir_builder_.getInt64(padding_left), - ir_builder_.getInt64(padding_right), - ir_builder_.getInt64(lhs_row_dilation), - ir_builder_.getInt64(lhs_col_dilation), - ir_builder_.getInt64(rhs_row_dilation), - ir_builder_.getInt64(rhs_col_dilation), - }); + b_.CreateCall( + conv_func, + { + GetExecutableRunOptionsArgument(), + b_.CreateBitCast(GetEmittedValueFor(convolution), ir_ptr_type), + b_.CreateBitCast(lhs_address, ir_ptr_type), + b_.CreateBitCast(rhs_address, ir_ptr_type), + b_.getInt64(input_batch), + b_.getInt64(input_rows), + b_.getInt64(input_cols), + b_.getInt64(input_channels), + b_.getInt64(kernel_rows), + b_.getInt64(kernel_cols), + b_.getInt64(kernel_channels), + b_.getInt64(kernel_filters), + b_.getInt64(output_rows), + b_.getInt64(output_cols), + b_.getInt64(row_stride), + b_.getInt64(col_stride), + b_.getInt64(padding_top), + b_.getInt64(padding_bottom), + b_.getInt64(padding_left), + b_.getInt64(padding_right), + b_.getInt64(lhs_row_dilation), + b_.getInt64(lhs_col_dilation), + b_.getInt64(rhs_row_dilation), + b_.getInt64(rhs_col_dilation), + }); return Status::OK(); } @@ -985,150 +1112,9 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // See the description of convolution in the XLA documentation for the pseudo // code for convolution. return EmitTargetElementLoop( - convolution, [this, convolution, lhs, rhs, window, - dnums](const llvm_ir::IrArray::Index& index) { - int num_spatial_dims = dnums.output_spatial_dimensions_size(); - std::vector output_spatial(num_spatial_dims); - for (int i = 0; i < num_spatial_dims; ++i) { - output_spatial[i] = index[dnums.output_spatial_dimensions(i)]; - } - llvm::Value* output_feature = index[dnums.output_feature_dimension()]; - llvm::Value* batch = index[dnums.output_batch_dimension()]; - - // We will accumulate the products into this sum to calculate - // the output entry at the given index. - PrimitiveType lhs_element_type = lhs->shape().element_type(); - llvm::Type* lhs_llvm_type = - llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_); - llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry( - lhs_llvm_type, "convolution_sum_address", &ir_builder_, - MinimumAlignmentForPrimitiveType(lhs_element_type)); - llvm::Value* constant_zero = - llvm::Constant::getNullValue(lhs_llvm_type); - ir_builder_.CreateStore(constant_zero, sum_address); - - llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &ir_builder_); - std::vector kernel_spatial(num_spatial_dims); - for (int i = 0; i < num_spatial_dims; ++i) { - kernel_spatial[i] = - loops - .AddLoop(0, - rhs->shape().dimensions( - dnums.kernel_spatial_dimensions(i)), - tensorflow::strings::StrCat("k", i)) - ->GetIndVarValue(); - } - llvm::Value* input_feature = - loops - .AddLoop( - 0, lhs->shape().dimensions(dnums.input_feature_dimension()), - "iz") - ->GetIndVarValue(); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); - - // Calculate the spatial index in the input array, taking striding, - // dilation and padding into account. An index in the padding will be - // out of the bounds of the array. - const auto calculate_input_index = - [this](llvm::Value* output_index, llvm::Value* kernel_index, - const WindowDimension& window_dim) { - llvm::Value* strided_index = ir_builder_.CreateNSWMul( - output_index, ir_builder_.getInt64(window_dim.stride())); - llvm::Value* dilated_kernel_index = ir_builder_.CreateNSWMul( - kernel_index, - ir_builder_.getInt64(window_dim.window_dilation())); - return ir_builder_.CreateNSWSub( - ir_builder_.CreateNSWAdd(strided_index, dilated_kernel_index), - ir_builder_.getInt64(window_dim.padding_low())); - }; - std::vector input_spatial(num_spatial_dims); - for (int i = 0; i < num_spatial_dims; ++i) { - input_spatial[i] = calculate_input_index( - output_spatial[i], kernel_spatial[i], window.dimensions(i)); - } - - // We need to check if 0 <= input dim < bound, as otherwise we are in - // the padding so that we can skip the computation. That is equivalent - // to input dim < bound as an *unsigned* comparison, since a negative - // value will wrap to a large positive value. The input dim is dilated, - // so we need to dilate the bound as well to match. - - // Also need to check that the input coordinates are not in one of the - // holes created by base dilation. - const auto not_in_hole = [&](llvm::Value* input_index, - int64 base_dilation) { - llvm::Value* remainder = ir_builder_.CreateSRem( - input_index, ir_builder_.getInt64(base_dilation)); - return ir_builder_.CreateICmpEQ(remainder, ir_builder_.getInt64(0)); - }; - - llvm::Value* in_bounds_condition = ir_builder_.getInt1(true); - for (int i = 0; i < num_spatial_dims; ++i) { - llvm::ConstantInt* input_bound = - ir_builder_.getInt64(window_util::DilatedBound( - lhs->shape().dimensions(dnums.input_spatial_dimensions(i)), - window.dimensions(i).base_dilation())); - llvm::Value* dim_in_bound = - ir_builder_.CreateICmpULT(input_spatial[i], input_bound); - llvm::Value* dim_not_in_hole = not_in_hole( - input_spatial[i], window.dimensions(i).base_dilation()); - llvm::Value* dim_ok = - ir_builder_.CreateAnd(dim_in_bound, dim_not_in_hole); - in_bounds_condition = - ir_builder_.CreateAnd(in_bounds_condition, dim_ok); - } - - // Now we need to map the dilated base coordinates back to the actual - // data indices on the lhs. - const auto undilate = [&](llvm::Value* input_index, - int64 base_dilation) { - return ir_builder_.CreateSDiv(input_index, - ir_builder_.getInt64(base_dilation)); - }; - for (int i = 0; i < num_spatial_dims; ++i) { - input_spatial[i] = - undilate(input_spatial[i], window.dimensions(i).base_dilation()); - } - - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - in_bounds_condition, "in-bounds", &ir_builder_); - SetToFirstInsertPoint(if_data.true_block, &ir_builder_); - - // We are not in the padding, so carry out the computation. - int num_dims = num_spatial_dims + 2; - llvm_ir::IrArray::Index input_index(ir_builder_.getInt64Ty(), num_dims); - for (int i = 0; i < num_spatial_dims; ++i) { - input_index[dnums.input_spatial_dimensions(i)] = input_spatial[i]; - } - input_index[dnums.input_feature_dimension()] = input_feature; - input_index[dnums.input_batch_dimension()] = batch; - - llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs)); - llvm_ir::IrArray::Index kernel_index(ir_builder_.getInt64Ty(), - num_dims); - for (int i = 0; i < num_spatial_dims; ++i) { - kernel_index[dnums.kernel_spatial_dimensions(i)] = - window.dimensions(i).window_reversal() - ? ir_builder_.CreateNSWSub( - ir_builder_.getInt64(window.dimensions(i).size() - 1), - kernel_spatial[i]) - : kernel_spatial[i]; - } - - kernel_index[dnums.kernel_input_feature_dimension()] = input_feature; - kernel_index[dnums.kernel_output_feature_dimension()] = output_feature; - - llvm_ir::IrArray input_array(GetIrArrayFor(lhs)); - llvm::Value* product = ir_builder_.CreateFMul( - input_array.EmitReadArrayElement(input_index, &ir_builder_), - kernel_array.EmitReadArrayElement(kernel_index, &ir_builder_)); - llvm::Value* sum = ir_builder_.CreateFAdd( - ir_builder_.CreateLoad(sum_address), product); - ir_builder_.CreateStore(sum, sum_address); - - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); - return ir_builder_.CreateLoad(sum_address); + convolution, [&](const llvm_ir::IrArray::Index& index) { + return EmitTargetElementLoopBodyForConvolution( + Cast(convolution), index); }); } @@ -1152,11 +1138,11 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { } // Args have been computed, make the call. - llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo(); - llvm::Type* int32_type = ir_builder_.getInt32Ty(); - llvm::Type* int64_type = ir_builder_.getInt64Ty(); + llvm::Type* int8_ptr_type = b_.getInt8Ty()->getPointerTo(); + llvm::Type* int32_type = b_.getInt32Ty(); + llvm::Type* int64_type = b_.getInt64Ty(); llvm::FunctionType* fft_type = llvm::FunctionType::get( - ir_builder_.getVoidTy(), + b_.getVoidTy(), {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type, int64_type, int64_type, int64_type, int64_type}, /*isVarArg=*/false); @@ -1173,16 +1159,15 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { fft_func->setDoesNotThrow(); fft_func->setOnlyAccessesInaccessibleMemOrArgMem(); const int fft_rank = fft_length.size(); - ir_builder_.CreateCall( + b_.CreateCall( fft_func, {GetExecutableRunOptionsArgument(), - ir_builder_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type), - ir_builder_.CreateBitCast(operand_address, int8_ptr_type), - ir_builder_.getInt32(fft->fft_type()), ir_builder_.getInt32(fft_rank), - ir_builder_.getInt64(input_batch), - ir_builder_.getInt64(fft_rank > 0 ? fft_length[0] : 0), - ir_builder_.getInt64(fft_rank > 1 ? fft_length[1] : 0), - ir_builder_.getInt64(fft_rank > 2 ? fft_length[2] : 0)}); + b_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type), + b_.CreateBitCast(operand_address, int8_ptr_type), + b_.getInt32(fft->fft_type()), b_.getInt32(fft_rank), + b_.getInt64(input_batch), b_.getInt64(fft_rank > 0 ? fft_length[0] : 0), + b_.getInt64(fft_rank > 1 ? fft_length[1] : 0), + b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)}); return Status::OK(); } @@ -1221,11 +1206,10 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape)); // TODO(b/63762267): Be more aggressive about specifying alignment. - ir_builder_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, - /*SrcAlign=*/1, - ShapeUtil::ByteSizeOf(operand_shape)); + b_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, + /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape)); } - llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &ir_builder_, module_); + llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_, module_); return Status::OK(); } @@ -1258,47 +1242,7 @@ static llvm_ir::IrArray::Index FillReducedDimensionIndex( Status IrEmitter::HandleParameter(HloInstruction* parameter) { VLOG(2) << "HandleParameter: " << parameter->ToString(); - auto param_number = parameter->parameter_number(); - auto param_shape = parameter->shape(); - - // We have to access the parameter at offset param_number in the params - // array. The code generated here is equivalent to this C code: - // - // i8* param_address_untyped = params[param_number]; - // Param* param_address_typed = (Param*)param_address_untyped; - // - // Where Param is the actual element type of the underlying buffer (for - // example, float for an XLA F32 element type). - llvm::Value* params = compute_function_->parameters_arg(); - llvm::Value* param_address_offset = - llvm_ir::EmitBufferIndexingGEP(params, param_number, &ir_builder_); - llvm::LoadInst* param_address_untyped = - ir_builder_.CreateLoad(param_address_offset); - param_address_untyped->setName(AsStringRef(IrName(parameter, "untyped"))); - if (is_top_level_computation_ && - hlo_module_config_.debug_options() - .xla_llvm_enable_invariant_load_metadata()) { - // In the entry computation the parameter slots in the %params argument are - // invariant through program execution. In computations that are called - // from the entry computation (via kWhile, kCall and kConditional) the - // parameter slots are *not* invariant since they're written to by their - // callers. - param_address_untyped->setMetadata( - llvm::LLVMContext::MD_invariant_load, - llvm::MDNode::get(param_address_untyped->getContext(), /*MDs=*/{})); - } - - llvm::Value* param_address_typed = ir_builder_.CreateBitCast( - param_address_untyped, IrShapeType(param_shape)->getPointerTo()); - emitted_value_[parameter] = param_address_typed; - - if (!ShapeUtil::IsOpaque(param_shape)) { - AttachAlignmentMetadataForLoad(param_address_untyped, param_shape); - AttachDereferenceableMetadataForLoad(param_address_untyped, param_shape); - } - - VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*param_address_typed); - return Status::OK(); + return EmitTargetAddressForOp(parameter); } // Returns true if the relative order of the unreduced dimensions stays the same @@ -1396,62 +1340,61 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( return nullptr; case HloOpcode::kAdd: - return [root_is_integral](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, + return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { - return root_is_integral ? ir_builder->CreateAdd(lhs, rhs) - : ir_builder->CreateFAdd(lhs, rhs); + return root_is_integral ? b->CreateAdd(lhs, rhs) + : b->CreateFAdd(lhs, rhs); }; case HloOpcode::kMultiply: - return [root_is_integral](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, + return [root_is_integral](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { - return root_is_integral ? ir_builder->CreateMul(lhs, rhs) - : ir_builder->CreateFMul(lhs, rhs); + return root_is_integral ? b->CreateMul(lhs, rhs) + : b->CreateFMul(lhs, rhs); }; case HloOpcode::kAnd: - return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, - llvm::Value* rhs) { return ir_builder->CreateAnd(lhs, rhs); }; + return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { + return b->CreateAnd(lhs, rhs); + }; case HloOpcode::kOr: - return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, - llvm::Value* rhs) { return ir_builder->CreateOr(lhs, rhs); }; + return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { + return b->CreateOr(lhs, rhs); + }; case HloOpcode::kXor: - return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, - llvm::Value* rhs) { return ir_builder->CreateXor(lhs, rhs); }; + return [](llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { + return b->CreateXor(lhs, rhs); + }; case HloOpcode::kMaximum: return [root_is_floating_point, root_is_signed]( - llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, - llvm::Value* rhs) { + llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { if (root_is_floating_point) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum, - {lhs, rhs}, {lhs->getType()}, - ir_builder); + {lhs, rhs}, {lhs->getType()}, b); } - return ir_builder->CreateSelect( - ir_builder->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SGE - : llvm::ICmpInst::ICMP_UGE, - lhs, rhs), + return b->CreateSelect( + b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SGE + : llvm::ICmpInst::ICMP_UGE, + lhs, rhs), lhs, rhs); }; case HloOpcode::kMinimum: return [root_is_floating_point, root_is_signed]( - llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, - llvm::Value* rhs) { + llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { if (root_is_floating_point) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum, - {lhs, rhs}, {lhs->getType()}, - ir_builder); + {lhs, rhs}, {lhs->getType()}, b); } - return ir_builder->CreateSelect( - ir_builder->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SLE - : llvm::ICmpInst::ICMP_ULE, - lhs, rhs), + return b->CreateSelect( + b->CreateICmp(root_is_signed ? llvm::ICmpInst::ICMP_SLE + : llvm::ICmpInst::ICMP_ULE, + lhs, rhs), lhs, rhs); }; } @@ -1520,34 +1463,31 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( accumulator.reserve(accumulator_type.size()); for (auto accumulator_shard_type : accumulator_type) { accumulator.push_back(llvm_ir::EmitAllocaAtFunctionEntry( - accumulator_shard_type, "accumulator", &ir_builder_, 0)); + accumulator_shard_type, "accumulator", &b_, 0)); } - llvm::Value* init_value_ssa = - ir_builder_.CreateLoad(GetEmittedValueFor(init_value)); + llvm::Value* init_value_ssa = b_.CreateLoad(GetEmittedValueFor(init_value)); for (llvm::Value* accumulator_shard : accumulator) { llvm::Value* initial_value; auto shard_type = accumulator_shard->getType()->getPointerElementType(); if (auto vector_type = llvm::dyn_cast(shard_type)) { - initial_value = ir_builder_.CreateVectorSplat( - vector_type->getNumElements(), init_value_ssa); + initial_value = + b_.CreateVectorSplat(vector_type->getNumElements(), init_value_ssa); } else { initial_value = init_value_ssa; } - ir_builder_.CreateAlignedStore(initial_value, accumulator_shard, - element_alignment); + b_.CreateAlignedStore(initial_value, accumulator_shard, element_alignment); } llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"), - &ir_builder_); + &b_); llvm_ir::IrArray::Index reduced_dims_index = reduction_loop_nest.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, "reduction_dim"); - SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(), - &ir_builder_); + SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(), &b_); llvm_ir::IrArray arg_array(GetIrArrayFor(arg)); llvm_ir::IrArray::Index input_index = reduced_dims_index; @@ -1560,38 +1500,34 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( } CHECK(output_index.end() == it); - llvm::Value* input_address = ir_builder_.CreateBitCast( - arg_array.EmitArrayElementAddress(input_index, &ir_builder_), - ir_builder_.getInt8PtrTy()); + llvm::Value* input_address = b_.CreateBitCast( + arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy()); for (int i = 0; i < accumulator.size(); i++) { auto input_address_typed = - ir_builder_.CreateBitCast(input_address, accumulator[i]->getType()); + b_.CreateBitCast(input_address, accumulator[i]->getType()); auto current_accumulator_value = - ir_builder_.CreateAlignedLoad(accumulator[i], element_alignment); - auto addend = - ir_builder_.CreateAlignedLoad(input_address_typed, element_alignment); + b_.CreateAlignedLoad(accumulator[i], element_alignment); + auto addend = b_.CreateAlignedLoad(input_address_typed, element_alignment); arg_array.AnnotateLoadStoreInstructionWithMetadata(addend); auto reduced_result = - reduction_generator(&ir_builder_, current_accumulator_value, addend); - ir_builder_.CreateAlignedStore(reduced_result, accumulator[i], - element_alignment); + reduction_generator(&b_, current_accumulator_value, addend); + b_.CreateAlignedStore(reduced_result, accumulator[i], element_alignment); if (i != (accumulator.size() - 1)) { - input_address = ir_builder_.CreateConstInBoundsGEP1_32( - reduced_result->getType(), input_address_typed, 1); + input_address = b_.CreateConstInBoundsGEP1_32(reduced_result->getType(), + input_address_typed, 1); } } - SetToFirstInsertPoint(reduction_loop_nest.GetOuterLoopExitBasicBlock(), - &ir_builder_); + SetToFirstInsertPoint(reduction_loop_nest.GetOuterLoopExitBasicBlock(), &b_); ShardedVector result_ssa; result_ssa.reserve(accumulator.size()); for (auto accumulator_shard : accumulator) { result_ssa.push_back( - ir_builder_.CreateAlignedLoad(accumulator_shard, element_alignment)); + b_.CreateAlignedLoad(accumulator_shard, element_alignment)); } return result_ssa; } @@ -1600,17 +1536,17 @@ void IrEmitter::EmitShardedVectorStore( llvm::Value* store_address, const std::vector& value_to_store, const int alignment, const llvm_ir::IrArray& containing_array) { for (int i = 0; i < value_to_store.size(); i++) { - auto store_address_typed = ir_builder_.CreateBitCast( + auto store_address_typed = b_.CreateBitCast( store_address, llvm::PointerType::getUnqual(value_to_store[i]->getType())); - auto store_instruction = ir_builder_.CreateAlignedStore( + auto store_instruction = b_.CreateAlignedStore( value_to_store[i], store_address_typed, alignment); containing_array.AnnotateLoadStoreInstructionWithMetadata( store_instruction); if (i != (value_to_store.size() - 1)) { - store_address = ir_builder_.CreateConstInBoundsGEP1_32( + store_address = b_.CreateConstInBoundsGEP1_32( value_to_store[i]->getType(), store_address_typed, 1); } } @@ -1676,8 +1612,8 @@ StatusOr IrEmitter::EmitVectorizedReduce( // } // } - llvm_ir::ForLoopNest loop_nest(IrName(reduce), &ir_builder_); - llvm_ir::IrArray::Index array_index(ir_builder_.getInt64Ty(), + llvm_ir::ForLoopNest loop_nest(IrName(reduce), &b_); + llvm_ir::IrArray::Index array_index(b_.getInt64Ty(), reduce->shape().dimensions_size()); for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0; --i) { @@ -1696,7 +1632,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( if (llvm::BasicBlock* innermost_body_bb = loop_nest.GetInnerLoopBodyBasicBlock()) { - SetToFirstInsertPoint(innermost_body_bb, &ir_builder_); + SetToFirstInsertPoint(innermost_body_bb, &b_); } auto outermost_loop_exit_block = loop_nest.GetOuterLoopExitBasicBlock(); @@ -1710,7 +1646,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( tensorflow::strings::Printf("dim.%lld", innermost_dimension)); array_index[innermost_dimension] = loop->GetIndVarValue(); - SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &ir_builder_); + SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &b_); ShardedVectorType vector_type = CreateShardedVectorType( reduce->shape().element_type(), vectorization_factor); @@ -1721,16 +1657,16 @@ StatusOr IrEmitter::EmitVectorizedReduce( llvm_ir::IrArray target_array = GetIrArrayFor(reduce); llvm::Value* output_address = - target_array.EmitArrayElementAddress(array_index, &ir_builder_); + target_array.EmitArrayElementAddress(array_index, &b_); EmitShardedVectorStore(output_address, accumulator, element_alignment, target_array); if (auto exit_terminator = loop->GetExitBasicBlock()->getTerminator()) { CHECK_GT(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1); - ir_builder_.SetInsertPoint(exit_terminator); + b_.SetInsertPoint(exit_terminator); } else { CHECK_EQ(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1); - ir_builder_.SetInsertPoint(loop->GetExitBasicBlock()); + b_.SetInsertPoint(loop->GetExitBasicBlock()); } } @@ -1740,8 +1676,8 @@ StatusOr IrEmitter::EmitVectorizedReduce( if (innermost_dimension_size % vectorization_factor) { // TODO(b/63775531): Consider using a scalar loop here to save on code size. array_index[innermost_dimension] = - ir_builder_.getInt64(innermost_dimension_size - - (innermost_dimension_size % vectorization_factor)); + b_.getInt64(innermost_dimension_size - + (innermost_dimension_size % vectorization_factor)); ShardedVectorType vector_type = CreateShardedVectorType( reduce->shape().element_type(), @@ -1753,18 +1689,72 @@ StatusOr IrEmitter::EmitVectorizedReduce( llvm_ir::IrArray target_array = GetIrArrayFor(reduce); llvm::Value* output_address = - target_array.EmitArrayElementAddress(array_index, &ir_builder_); + target_array.EmitArrayElementAddress(array_index, &b_); EmitShardedVectorStore(output_address, accumulator, element_alignment, target_array); } if (outermost_loop_exit_block) { - ir_builder_.SetInsertPoint(outermost_loop_exit_block); + b_.SetInsertPoint(outermost_loop_exit_block); } return true; } +StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( + HloReduceInstruction* reduce, const llvm_ir::IrArray::Index& index) { + const HloInstruction* arg = reduce->mutable_operand(0); + const HloInstruction* init_value = reduce->mutable_operand(1); + gtl::ArraySlice dimensions(reduce->dimensions()); + + // Initialize an accumulator with init_value. + PrimitiveType accumulator_type = reduce->shape().element_type(); + llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator", + &b_, MinimumAlignmentForPrimitiveType(accumulator_type)); + llvm::Value* init_value_addr = GetEmittedValueFor(init_value); + llvm::Value* load_init_value = b_.CreateLoad(init_value_addr); + b_.CreateStore(load_init_value, accumulator_addr); + + // The enclosing loops go over all the target elements. Now we have to compute + // the actual target element. For this, we build a new loop nest to iterate + // over all the reduction dimensions in the argument. + // AddLoopsForShapeOnDimensions will return an Index where induction Value*s + // are placed for each dimension in dimensions, and all the rest are nullptrs. + llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_); + const llvm_ir::IrArray::Index reduced_dims_index = + loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, + "reduction_dim"); + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); + + // Build a full index for the input argument, using reduced_dims_index as the + // base. In reduced_dims_index only the reduction dimensions are filled in. We + // fill in the rest of the dimensions with induction Value*s taken from + // 'index' which iterates over the target array. See the high-level + // description in the XLA documentation for details. + llvm_ir::IrArray arg_array(GetIrArrayFor(arg)); + llvm_ir::IrArray::Index input_index = reduced_dims_index; + llvm_ir::IrArray::Index::const_iterator it = index.begin(); + + for (size_t i = 0; i < input_index.size(); ++i) { + if (input_index[i] == nullptr) { + input_index[i] = *it++; + } + } + CHECK(index.end() == it); + + // Apply the reduction function to the loaded value. + llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_); + llvm::Value* result = EmitThreadLocalCall( + *reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element}, + "reduce_function"); + b_.CreateStore(result, accumulator_addr); + + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); + return b_.CreateLoad(accumulator_addr); +} + Status IrEmitter::HandleReduce(HloInstruction* reduce) { auto arg = reduce->mutable_operand(0); auto init_value = reduce->mutable_operand(1); @@ -1786,61 +1776,11 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { } } - // The called computation should have been emitted previously. - llvm::Function* reducer_function = FindOrDie(emitted_functions_, function); - return EmitTargetElementLoop( - reduce, [this, reduce, arg, init_value, dimensions, - reducer_function](const llvm_ir::IrArray::Index& index) { - // Initialize an accumulator with init_value. - PrimitiveType accumulator_type = reduce->shape().element_type(); - llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), - "accumulator", &ir_builder_, - MinimumAlignmentForPrimitiveType(accumulator_type)); - llvm::Value* init_value_addr = GetEmittedValueFor(init_value); - llvm::Value* load_init_value = ir_builder_.CreateLoad(init_value_addr); - ir_builder_.CreateStore(load_init_value, accumulator_addr); - - // The enclosing loops go over all the target elements. Now we have to - // compute the actual target element. For this, we build a new loop nest - // to iterate over all the reduction dimensions in the argument. - // AddLoopsForShapeOnDimensions will return an Index where induction - // Value*s are placed for each dimension in dimensions, and all the rest - // are nullptrs. - llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &ir_builder_); - const llvm_ir::IrArray::Index reduced_dims_index = - loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, - "reduction_dim"); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); - - // Build a full index for the input argument, using reduced_dims_index - // as the base. In reduced_dims_index only the reduction dimensions are - // filled in. We fill in the rest of the dimensions with induction - // Value*s taken from 'index' which iterates over the target array. - // See the high-level description in the XLA documentation for details. - llvm_ir::IrArray arg_array(GetIrArrayFor(arg)); - llvm_ir::IrArray::Index input_index = reduced_dims_index; - llvm_ir::IrArray::Index::const_iterator it = index.begin(); - - for (size_t i = 0; i < input_index.size(); ++i) { - if (input_index[i] == nullptr) { - input_index[i] = *it++; - } - } - CHECK(index.end() == it); - - // Apply the reduction function to the loaded value. - llvm::Value* input_address = - arg_array.EmitArrayElementAddress(input_index, &ir_builder_); - llvm::Value* result = EmitElementFunctionCall( - reducer_function, reduce->shape(), - {accumulator_addr, input_address}, "reduce_function"); - ir_builder_.CreateStore(result, accumulator_addr); - - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); - return ir_builder_.CreateLoad(accumulator_addr); - }); + return EmitTargetElementLoop(reduce, + [&](const llvm_ir::IrArray::Index& index) { + return EmitTargetElementLoopBodyForReduce( + Cast(reduce), index); + }); } Status IrEmitter::HandleSend(HloInstruction* send) { @@ -1853,6 +1793,10 @@ Status IrEmitter::HandleSendDone(HloInstruction* send_done) { return Unimplemented("Send-done is not implemented on CPU."); } +Status IrEmitter::HandleScatter(HloInstruction*) { + return Unimplemented("Scatter is not implemented on CPUs."); +} + Status IrEmitter::HandleSlice(HloInstruction* slice) { VLOG(2) << "HandleSlice: " << slice->ToString(); auto operand = slice->operand(0); @@ -1942,7 +1886,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { llvm_ir::IrArray target_array = GetIrArrayFor(slice); const int64 num_outer_loops = outer_dims.size(); - llvm_ir::ForLoopNest loops(IrName(slice), &ir_builder_); + llvm_ir::ForLoopNest loops(IrName(slice), &b_); llvm_ir::IrArray::Index target_index = loops.AddLoopsForShapeOnDimensions(slice->shape(), outer_dims, "slice"); @@ -1951,21 +1895,21 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { // for the rest of the dimensions the copy writes to the full dimension. std::replace(target_index.begin(), target_index.end(), static_cast(nullptr), - static_cast(ir_builder_.getInt64(0))); + static_cast(b_.getInt64(0))); if (num_outer_loops > 0) { - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); } llvm_ir::IrArray source_array = GetIrArrayFor(operand); const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice( /*shape=*/slice->shape(), /*starts=*/slice->slice_starts(), - /*strides=*/slice->slice_strides(), /*builder=*/&ir_builder_); + /*strides=*/slice->slice_strides(), /*builder=*/&b_); - llvm::Value* memcpy_dest = target_array.EmitArrayElementAddress( - target_index, &ir_builder_, "slice.dest"); - llvm::Value* memcpy_source = source_array.EmitArrayElementAddress( - source_index, &ir_builder_, "slice.source"); + llvm::Value* memcpy_dest = + target_array.EmitArrayElementAddress(target_index, &b_, "slice.dest"); + llvm::Value* memcpy_source = + source_array.EmitArrayElementAddress(source_index, &b_, "slice.source"); const int64 memcpy_elements = primitive_elements_per_logical_element * memcpy_logical_elements; @@ -1982,7 +1926,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { } if (num_outer_loops > 0) { - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); } return Status::OK(); @@ -2008,7 +1952,7 @@ Status IrEmitter::HandleDynamicUpdateSlice( auto operands = GetIrArraysForOperandsOf(dynamic_update_slice); return llvm_ir::EmitDynamicUpdateSliceInPlace( operands, GetIrArrayFor(dynamic_update_slice), - IrName(dynamic_update_slice, "in_place"), &ir_builder_); + IrName(dynamic_update_slice, "in_place"), &b_); } return DefaultAction(dynamic_update_slice); } @@ -2042,43 +1986,41 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { [this, pad](const llvm_ir::IrArray::Index& target_index) { const HloInstruction* padding_value = pad->operand(1); llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value); - return ir_builder_.CreateLoad(padding_value_addr); + return b_.CreateLoad(padding_value_addr); })); // Create a loop to iterate over the operand elements and update the output // locations where the operand elements should be stored. - llvm_ir::ForLoopNest loops(IrName(pad, "assign"), &ir_builder_); + llvm_ir::ForLoopNest loops(IrName(pad, "assign"), &b_); const HloInstruction* operand = pad->operand(0); const llvm_ir::IrArray::Index operand_index = loops.AddLoopsForShape(operand->shape(), "operand"); - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); // Load an element from the operand. llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* operand_data = - operand_array.EmitReadArrayElement(operand_index, &ir_builder_); + operand_array.EmitReadArrayElement(operand_index, &b_); // Compute the output index the operand element should be assigned to. // output_index := edge_padding_low + operand_index * (interior_padding + 1) const PaddingConfig& padding_config = pad->padding_config(); llvm_ir::IrArray::Index output_index(operand_index.GetType()); for (size_t i = 0; i < operand_index.size(); ++i) { - llvm::Value* offset = ir_builder_.CreateMul( + llvm::Value* offset = b_.CreateMul( operand_index[i], - ir_builder_.getInt64(padding_config.dimensions(i).interior_padding() + - 1)); - llvm::Value* index = ir_builder_.CreateAdd( - offset, - ir_builder_.getInt64(padding_config.dimensions(i).edge_padding_low())); + b_.getInt64(padding_config.dimensions(i).interior_padding() + 1)); + llvm::Value* index = b_.CreateAdd( + offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low())); output_index.push_back(index); } // Store the operand element to the computed output location. llvm_ir::IrArray output_array(GetIrArrayFor(pad)); - output_array.EmitWriteArrayElement(output_index, operand_data, &ir_builder_); + output_array.EmitWriteArrayElement(output_index, operand_data, &b_); - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); return Status::OK(); } @@ -2100,8 +2042,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { // Delegate to common implementation of fused in-place dynamic-update-slice. auto operands = GetIrArraysForOperandsOf(fusion); return llvm_ir::EmitFusedDynamicUpdateSliceInPlace( - fusion, operands, GetIrArrayFor(fusion), &elemental_emitter, - &ir_builder_); + fusion, operands, GetIrArrayFor(fusion), &elemental_emitter, &b_); } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) { VLOG(3) << "HandleFusion kLoop"; CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); @@ -2136,7 +2077,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( *dot, target_array, lhs_array, rhs_array, &addend_array, - GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_, + GetExecutableRunOptionsArgument(), &b_, hlo_module_config_, target_machine_features_)); return Status::OK(); } else { @@ -2148,18 +2089,13 @@ Status IrEmitter::HandleCall(HloInstruction* call) { HloComputation* computation = call->to_apply(); llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation); - std::vector parameter_addresses; - for (const HloInstruction* operand : call->operands()) { - parameter_addresses.push_back(GetEmittedValueFor(operand)); - } - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call)); if (!computation->root_instruction()->outer_dimension_partitions().empty()) { // ParallelTaskAssignment assigned partitions, emit call to // ParallelForkJoin. std::vector call_args = GetArrayFunctionCallArguments( - parameter_addresses, &ir_builder_, computation->name(), + {}, &b_, computation->name(), /*return_value_buffer=*/emitted_value_[call], /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), /*temp_buffers_arg=*/GetTempBuffersArgument(), @@ -2167,11 +2103,10 @@ Status IrEmitter::HandleCall(HloInstruction* call) { HloInstruction* root = computation->root_instruction(); TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin( - call_args, root->shape(), root->outer_dimension_partitions(), - &ir_builder_, call_ir_function, computation->name())); + call_args, root->shape(), root->outer_dimension_partitions(), &b_, + call_ir_function, computation->name())); } else { - EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, - emitted_value_[call], computation->name()); + EmitGlobalCall(*computation, computation->name()); } return Status::OK(); @@ -2180,33 +2115,31 @@ Status IrEmitter::HandleCall(HloInstruction* call) { Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { gtl::ArraySlice operands(custom_call->operands()); tensorflow::StringPiece custom_call_target(custom_call->custom_call_target()); - llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy(); + llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); llvm::AllocaInst* operands_alloca = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - i8_ptr_type, ir_builder_.getInt32(operands.size()), - "cc_operands_alloca", &ir_builder_); + i8_ptr_type, b_.getInt32(operands.size()), "cc_operands_alloca", &b_); for (size_t i = 0; i < operands.size(); ++i) { const HloInstruction* operand = operands[i]; llvm::Value* operand_as_i8ptr = - ir_builder_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type); - llvm::Value* slot_in_operands_alloca = ir_builder_.CreateInBoundsGEP( - operands_alloca, {ir_builder_.getInt64(i)}); - ir_builder_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca); + b_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type); + llvm::Value* slot_in_operands_alloca = + b_.CreateInBoundsGEP(operands_alloca, {b_.getInt64(i)}); + b_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca); } auto* custom_call_ir_function = llvm::cast(module_->getOrInsertFunction( AsStringRef(custom_call_target), llvm::FunctionType::get( - /*Result=*/ir_builder_.getVoidTy(), + /*Result=*/b_.getVoidTy(), /*Params=*/{i8_ptr_type, operands_alloca->getType()}, /*isVarArg=*/false))); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); - auto* output_address_arg = ir_builder_.CreatePointerCast( - GetEmittedValueFor(custom_call), i8_ptr_type); + auto* output_address_arg = + b_.CreatePointerCast(GetEmittedValueFor(custom_call), i8_ptr_type); - ir_builder_.CreateCall(custom_call_ir_function, - {output_address_arg, operands_alloca}); + b_.CreateCall(custom_call_ir_function, {output_address_arg, operands_alloca}); return Status::OK(); } @@ -2254,12 +2187,6 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { const HloInstruction* init = xla_while->operand(0); emitted_value_[xla_while] = GetEmittedValueFor(init); - // The called computation should have been emitted previously. - llvm::Function* condition_ir_function = - FindOrDie(emitted_functions_, condition); - llvm::Function* body_ir_function = - FindOrDie(emitted_functions_, xla_while->while_body()); - // Generating: // while (Condition(while_result)) { // // CopyInsertion pass inserts copies which enable 'while_result' to @@ -2271,17 +2198,15 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { llvm::BasicBlock* header_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "header")), compute_function_->function()); - ir_builder_.CreateBr(header_bb); - ir_builder_.SetInsertPoint(header_bb); + b_.CreateBr(header_bb); + b_.SetInsertPoint(header_bb); // Calls the condition function to determine whether to proceed with the // body. It must return a bool, so use the scalar call form. - llvm::Value* while_result = GetEmittedValueFor(xla_while); - llvm::Value* while_condition = EmitElementFunctionCall( - condition_ir_function, condition->root_instruction()->shape(), - {while_result}, IrName(xla_while, "cond")); - llvm::Value* while_predicate = ir_builder_.CreateICmpNE( - while_condition, + EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond")); + llvm::Value* while_predicate = b_.CreateICmpNE( + b_.CreateLoad( + GetBufferForGlobalCallReturnValue(*xla_while->while_condition())), llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0)); // Branches to the body or to the while exit depending on the condition. @@ -2290,20 +2215,20 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { compute_function_->function()); llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "exit"))); - ir_builder_.CreateCondBr(while_predicate, body_bb, exit_bb); + b_.CreateCondBr(while_predicate, body_bb, exit_bb); // Calls the body function from the body block. - ir_builder_.SetInsertPoint(body_bb); + b_.SetInsertPoint(body_bb); // Calls the body function. - EmitArrayFunctionCallInto(body_ir_function, {while_result}, while_result, - IrName(xla_while, "body")); + EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body")); + // Finishes with a branch back to the header. - ir_builder_.CreateBr(header_bb); + b_.CreateBr(header_bb); // Adds the exit block to the function and sets the insert point there. compute_function_->function()->getBasicBlockList().push_back(exit_bb); - ir_builder_.SetInsertPoint(exit_bb); + b_.SetInsertPoint(exit_bb); return Status::OK(); } @@ -2345,21 +2270,21 @@ StatusOr IrEmitter::EmitFastConcatenate( std::vector outer_dims(std::next(concat_dim_layout_itr), output_min2maj.end()); - llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy(); - llvm::Type* i8_type = ir_builder_.getInt8Ty(); + llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); + llvm::Type* i8_type = b_.getInt8Ty(); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate)); llvm_ir::IrArray target_array = GetIrArrayFor(concatenate); - llvm_ir::ForLoopNest loops(IrName(concatenate), &ir_builder_); + llvm_ir::ForLoopNest loops(IrName(concatenate), &b_); llvm_ir::IrArray::Index outer_dims_index = loops.AddLoopsForShapeOnDimensions(output_shape, outer_dims, "concat"); std::replace(outer_dims_index.begin(), outer_dims_index.end(), static_cast(nullptr), - static_cast(ir_builder_.getInt64(0))); + static_cast(b_.getInt64(0))); if (!outer_dims.empty()) { - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); } PrimitiveType primitive_type = output_shape.element_type(); @@ -2368,10 +2293,10 @@ StatusOr IrEmitter::EmitFastConcatenate( // Contiguous subregions from each operand to the concatenate contribute to a // contiguous subregion in the target buffer starting at target_region_begin. - llvm::Value* target_region_begin = ir_builder_.CreateBitCast( - target_array.EmitArrayElementAddress(outer_dims_index, &ir_builder_, - "target_region"), - i8_ptr_type); + llvm::Value* target_region_begin = + b_.CreateBitCast(target_array.EmitArrayElementAddress( + outer_dims_index, &b_, "target_region"), + i8_ptr_type); int64 byte_offset_into_target_region = 0; int64 inner_dims_product = @@ -2385,14 +2310,13 @@ StatusOr IrEmitter::EmitFastConcatenate( for (HloInstruction* operand : operands) { const Shape& input_shape = operand->shape(); llvm_ir::IrArray source_array = GetIrArrayFor(operand); - llvm::Value* copy_source_address = ir_builder_.CreateBitCast( - source_array.EmitArrayElementAddress(outer_dims_index, &ir_builder_, - "src_addr"), + llvm::Value* copy_source_address = b_.CreateBitCast( + source_array.EmitArrayElementAddress(outer_dims_index, &b_, "src_addr"), i8_ptr_type); - llvm::Value* copy_target_address = ir_builder_.CreateGEP( - i8_type, target_region_begin, - ir_builder_.getInt64(byte_offset_into_target_region)); + llvm::Value* copy_target_address = + b_.CreateGEP(i8_type, target_region_begin, + b_.getInt64(byte_offset_into_target_region)); EmitTransferElements( copy_target_address, copy_source_address, @@ -2405,7 +2329,7 @@ StatusOr IrEmitter::EmitFastConcatenate( } if (!outer_dims.empty()) { - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); } return true; @@ -2424,16 +2348,15 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, llvm_ir::PrimitiveTypeToIrType(primitive_type, module_)); if (element_count == 1) { - auto* load_instruction = ir_builder_.CreateAlignedLoad( - ir_builder_.CreateBitCast(source, primitive_ptr_type), - element_alignment); + auto* load_instruction = b_.CreateAlignedLoad( + b_.CreateBitCast(source, primitive_ptr_type), element_alignment); source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction); - auto* store_instruction = ir_builder_.CreateAlignedStore( - load_instruction, ir_builder_.CreateBitCast(target, primitive_ptr_type), + auto* store_instruction = b_.CreateAlignedStore( + load_instruction, b_.CreateBitCast(target, primitive_ptr_type), element_alignment); target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction); } else { - auto* memcpy_instruction = ir_builder_.CreateMemCpy( + auto* memcpy_instruction = b_.CreateMemCpy( target, /*DstAlign=*/element_alignment, source, /*SrcAlign=*/element_alignment, element_count * primitive_type_size); @@ -2467,8 +2390,6 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { Status IrEmitter::HandleConditional(HloInstruction* conditional) { auto pred = conditional->operand(0); - auto true_arg = conditional->operand(1); - auto false_arg = conditional->operand(2); TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) && pred->shape().element_type() == PRED) << "Predicate on a Conditional must be bool; got: " @@ -2490,37 +2411,31 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { << " and " << ShapeUtil::HumanString(false_computation->root_instruction()->shape()); - llvm::Function* true_function = - FindOrDie(emitted_functions_, true_computation); - llvm::Function* false_function = - FindOrDie(emitted_functions_, false_computation); - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional)); - llvm::Value* conditional_result = GetEmittedValueFor(conditional); // Generating: // if (pred) // cond_result = true_computation(true_operand) // else // cond_result = false_computation(false_operand) - llvm::LoadInst* pred_value = ir_builder_.CreateLoad( + llvm::LoadInst* pred_value = b_.CreateLoad( GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value"); - llvm::Value* pred_cond = ir_builder_.CreateICmpNE( + llvm::Value* pred_cond = b_.CreateICmpNE( pred_value, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), "boolean_predicate"); llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(pred_cond, "conditional", &ir_builder_); + llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_); - SetToFirstInsertPoint(if_data.true_block, &ir_builder_); - EmitArrayFunctionCallInto(true_function, {GetEmittedValueFor(true_arg)}, - conditional_result, IrName(conditional, "_true")); + SetToFirstInsertPoint(if_data.true_block, &b_); + EmitGlobalCall(*conditional->true_computation(), + IrName(conditional, "_true")); - SetToFirstInsertPoint(if_data.false_block, &ir_builder_); - EmitArrayFunctionCallInto(false_function, {GetEmittedValueFor(false_arg)}, - conditional_result, IrName(conditional, "_false")); + SetToFirstInsertPoint(if_data.false_block, &b_); + EmitGlobalCall(*conditional->false_computation(), + IrName(conditional, "_false")); - SetToFirstInsertPoint(if_data.after_block, &ir_builder_); + SetToFirstInsertPoint(if_data.after_block, &b_); return Status::OK(); } @@ -2531,6 +2446,28 @@ Status IrEmitter::HandleAfterAll(HloInstruction* gen_token) { return Status::OK(); } +Status IrEmitter::HandleIota(HloInstruction* iota) { + // TODO(b/64798317): implement iota on CPU. + return Unimplemented("Iota is not implemented on CPU."); +} + +Status IrEmitter::HandleRng(HloInstruction* rng) { + ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; + for (const HloInstruction* operand : rng->operands()) { + operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { + return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_); + }; + } + + CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); + TF_RETURN_IF_ERROR(EmitTargetElementLoop( + rng, elemental_emitter.MakeElementGenerator(rng, operand_to_generator))); + + llvm_ir::IncrementVariableForPhiloxRngState(1, module_, &b_); + + return Status::OK(); +} + Status IrEmitter::FinishVisit(HloInstruction* root) { // When this method is called, we should have already emitted an IR value for // the root (return) op. The IR value holds the address of the buffer holding @@ -2548,7 +2485,7 @@ Status IrEmitter::FinishVisit(HloInstruction* root) { auto record_complete_computation = [&](llvm::Value* prof_counter) { if (prof_counter) { - profiling_state_.RecordCompleteComputation(&ir_builder_, prof_counter); + profiling_state_.RecordCompleteComputation(&b_, prof_counter); } }; @@ -2570,54 +2507,51 @@ llvm::Value* IrEmitter::GetProfileCounterCommon( int64 prof_counter_idx = it->second; string counter_name = IrName("prof_counter", hlo.name()); - return ir_builder_.CreateGEP(GetProfileCountersArgument(), - ir_builder_.getInt64(prof_counter_idx), - AsStringRef(counter_name)); + return b_.CreateGEP(GetProfileCountersArgument(), + b_.getInt64(prof_counter_idx), AsStringRef(counter_name)); } -void IrEmitter::ProfilingState::UpdateProfileCounter( - llvm::IRBuilder<>* ir_builder, llvm::Value* prof_counter, - llvm::Value* cycle_end, llvm::Value* cycle_start) { - auto* cycle_diff = ir_builder->CreateSub(cycle_end, cycle_start); +void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b, + llvm::Value* prof_counter, + llvm::Value* cycle_end, + llvm::Value* cycle_start) { + auto* cycle_diff = b->CreateSub(cycle_end, cycle_start); llvm::LoadInst* old_cycle_count = - ir_builder->CreateLoad(prof_counter, "old_cycle_count"); + b->CreateLoad(prof_counter, "old_cycle_count"); auto* new_cycle_count = - ir_builder->CreateAdd(cycle_diff, old_cycle_count, "new_cycle_count"); - ir_builder->CreateStore(new_cycle_count, prof_counter); + b->CreateAdd(cycle_diff, old_cycle_count, "new_cycle_count"); + b->CreateStore(new_cycle_count, prof_counter); } -llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter( - llvm::IRBuilder<>* ir_builder) { - llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); +llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter(llvm::IRBuilder<>* b) { + llvm::Module* module = b->GetInsertBlock()->getModule(); if (use_rdtscp_) { llvm::Function* func_llvm_readcyclecounter = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::readcyclecounter); - return ir_builder->CreateCall(func_llvm_readcyclecounter); + return b->CreateCall(func_llvm_readcyclecounter); } llvm::Function* func_llvm_x86_rdtscp = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::x86_rdtscp); if (!aux_i8ptr_) { - llvm::AllocaInst* rdtscp_aux = llvm_ir::EmitAllocaAtFunctionEntry( - ir_builder->getInt32Ty(), "rdtscp_aux", ir_builder); - aux_i8ptr_ = - ir_builder->CreateBitCast(rdtscp_aux, ir_builder->getInt8PtrTy()); + llvm::AllocaInst* rdtscp_aux = + llvm_ir::EmitAllocaAtFunctionEntry(b->getInt32Ty(), "rdtscp_aux", b); + aux_i8ptr_ = b->CreateBitCast(rdtscp_aux, b->getInt8PtrTy()); } - llvm::ConstantInt* alloca_size = ir_builder->getInt64(4); + llvm::ConstantInt* alloca_size = b->getInt64(4); llvm::Function* func_llvm_lifetime_start = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::lifetime_start); - ir_builder->CreateCall(func_llvm_lifetime_start, {alloca_size, aux_i8ptr_}); - llvm::Value* rdtscp_call = - ir_builder->CreateCall(func_llvm_x86_rdtscp, aux_i8ptr_); + b->CreateCall(func_llvm_lifetime_start, {alloca_size, aux_i8ptr_}); + llvm::Value* rdtscp_call = b->CreateCall(func_llvm_x86_rdtscp, aux_i8ptr_); llvm::Function* func_llvm_lifetime_end = llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::lifetime_end); - ir_builder->CreateCall(func_llvm_lifetime_end, {alloca_size, aux_i8ptr_}); + b->CreateCall(func_llvm_lifetime_end, {alloca_size, aux_i8ptr_}); return rdtscp_call; } -void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* ir_builder, +void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* b, HloInstruction* hlo) { - auto* cycle_start = ReadCycleCounter(ir_builder); + auto* cycle_start = ReadCycleCounter(b); cycle_start->setName(AsStringRef(IrName(hlo, "cycle_start"))); cycle_starts_[hlo] = cycle_start; if (first_read_cycle_start_ == nullptr) { @@ -2625,20 +2559,20 @@ void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* ir_builder, } } -void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* ir_builder, +void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* b, HloInstruction* hlo, llvm::Value* prof_counter) { - auto* cycle_end = ReadCycleCounter(ir_builder); + auto* cycle_end = ReadCycleCounter(b); cycle_end->setName(AsStringRef(IrName(hlo, "cycle_end"))); auto* cycle_start = cycle_starts_[hlo]; - UpdateProfileCounter(ir_builder, prof_counter, cycle_end, cycle_start); + UpdateProfileCounter(b, prof_counter, cycle_end, cycle_start); last_read_cycle_end_ = cycle_end; } void IrEmitter::ProfilingState::RecordCompleteComputation( - llvm::IRBuilder<>* ir_builder, llvm::Value* prof_counter) { + llvm::IRBuilder<>* b, llvm::Value* prof_counter) { if (last_read_cycle_end_ && first_read_cycle_start_) { - UpdateProfileCounter(ir_builder, prof_counter, last_read_cycle_end_, + UpdateProfileCounter(b, prof_counter, last_read_cycle_end_, first_read_cycle_start_); } } @@ -2646,14 +2580,14 @@ void IrEmitter::ProfilingState::RecordCompleteComputation( Status IrEmitter::Preprocess(HloInstruction* hlo) { VLOG(3) << "Visiting: " << hlo->ToString(); if (instruction_to_profile_idx_.count(hlo)) { - profiling_state_.RecordCycleStart(&ir_builder_, hlo); + profiling_state_.RecordCycleStart(&b_, hlo); } return Status::OK(); } Status IrEmitter::Postprocess(HloInstruction* hlo) { if (auto* prof_counter = GetProfileCounterFor(*hlo)) { - profiling_state_.RecordCycleDelta(&ir_builder_, hlo, prof_counter); + profiling_state_.RecordCycleDelta(&b_, hlo, prof_counter); } return Status::OK(); } @@ -2700,42 +2634,76 @@ llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { return compute_function_->exec_run_options_arg(); } -llvm::Value* IrEmitter::EmitTempBufferPointer( +llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape) { - llvm::Type* element_type = IrShapeType(target_shape); - // The alignment and number of bytes within the temporary buffer is determined - // by the maximal shape as determined by buffer assignment. - const BufferAllocation& allocation = assignment_.GetAllocation(slice.index()); - if (allocation.is_thread_local()) { + const BufferAllocation& allocation = *slice.allocation(); + llvm::Value* tempbuf_address = [&]() -> llvm::Value* { + if (slice == computation_root_allocation_) { + llvm::Argument* retval = compute_function_->result_arg(); + llvm::AttrBuilder attr_builder; + attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); + attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); + retval->addAttrs(attr_builder); + return retval; + } + + auto param_it = + computation_parameter_allocations_.find(slice.allocation()->index()); + if (param_it != computation_parameter_allocations_.end()) { + int64 param_number = param_it->second; + // We have to access the parameter at offset param_number in the params + // array. The code generated here is equivalent to this C code: + // + // i8* param_address_untyped = params[param_number]; + // Param* param_address_typed = (Param*)param_address_untyped; + // + // Where Param is the actual element type of the underlying buffer (for + // example, float for an XLA F32 element type). + llvm::Value* params = compute_function_->parameters_arg(); + llvm::Value* param_address_offset = + llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_); + llvm::LoadInst* param_address_untyped = + b_.CreateLoad(param_address_offset); + + if (!ShapeUtil::IsOpaque(target_shape)) { + AttachAlignmentMetadataForLoad(param_address_untyped, target_shape); + AttachDereferenceableMetadataForLoad(param_address_untyped, + target_shape); + } + return param_address_untyped; + } + // Thread-local allocations should only be assigned a single buffer. const auto& assigned_buffers = allocation.assigned_buffers(); CHECK_EQ(1, assigned_buffers.size()); const Shape& shape = assigned_buffers.begin()->first->shape(); - llvm::AllocaInst*& tempbuf_address = thread_local_buffers_[{ - ir_builder_.GetInsertBlock()->getParent(), slice}]; - if (tempbuf_address == nullptr) { - tempbuf_address = llvm_ir::EmitAllocaAtFunctionEntry( + std::pair key = { + compute_function_->function(), slice}; + auto buf_it = thread_local_buffers_.find(key); + if (buf_it == thread_local_buffers_.end()) { + llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry( IrShapeType(shape), - tensorflow::strings::StrCat("thread_local", slice.ToString()), - &ir_builder_, MinimumAlignmentForShape(target_shape)); + tensorflow::strings::StrCat("thread_local", slice.ToString()), &b_, + MinimumAlignmentForShape(target_shape)); + auto it_inserted_pair = thread_local_buffers_.insert({key, buffer}); + CHECK(it_inserted_pair.second); + buf_it = it_inserted_pair.first; } - return ir_builder_.CreateBitCast(tempbuf_address, - element_type->getPointerTo()); - } + return buf_it->second; + }(); + return b_.CreateBitCast(tempbuf_address, + IrShapeType(target_shape)->getPointerTo()); +} +llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( + const BufferAllocation::Slice& slice, const Shape& target_shape) { + const BufferAllocation& allocation = *slice.allocation(); llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP( - GetTempBuffersArgument(), slice.index(), &ir_builder_); - llvm::LoadInst* tempbuf_address_base = - ir_builder_.CreateLoad(tempbuf_address_ptr); - if (is_top_level_computation_ && - hlo_module_config_.debug_options() + GetTempBuffersArgument(), slice.index(), &b_); + llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr); + if (hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { - // In the entry computation the parameter slots in the %params argument are - // invariant through program execution. In computations that are called - // from the entry computation (via kWhile, kCall and kConditional) the - // parameter slots are *not* invariant since they're written to by their - // callers. tempbuf_address_base->setMetadata( llvm::LLVMContext::MD_invariant_load, llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{})); @@ -2746,90 +2714,29 @@ llvm::Value* IrEmitter::EmitTempBufferPointer( llvm::Value* tempbuf_address_untyped = tempbuf_address_base; if (slice.offset() > 0) { // Adjust the address to account for the slice offset. - tempbuf_address_untyped = ir_builder_.CreateInBoundsGEP( - tempbuf_address_base, ir_builder_.getInt64(slice.offset())); + tempbuf_address_untyped = + b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset())); } - return ir_builder_.CreateBitCast(tempbuf_address_untyped, - element_type->getPointerTo()); -} - -// Emits a function call returning a single array element. Allocates space -// for a single element_type value, and loads it after call. -llvm::Value* IrEmitter::EmitElementFunctionCall( - llvm::Function* function, const Shape& return_shape, - gtl::ArraySlice parameter_addresses, - tensorflow::StringPiece name) { - llvm::Value* return_value_buffer = EmitArrayFunctionCall( - function, return_shape, 1, parameter_addresses, name); - return ir_builder_.CreateLoad( - return_value_buffer, - AsStringRef(tensorflow::strings::StrCat(name, "_return_value"))); -} - -// Emits a core function call based on the following pseudo-code. -// -// char** parameter_addresses_buffer = -// allocate buffer with a pointer for each parameter to the function -// for each parameter index, i.e. for i = 0, ..., #parameters: -// parameter_addresses_buffer[i] = parameter_addresses[i] -// call function(return_value_buffer, -// parameter_addresses_buffer, -// temps) -// return return_value_buffer -- address of the return value. -void IrEmitter::EmitArrayFunctionCallInto( - llvm::Function* function, gtl::ArraySlice parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name) { - ir_builder_.CreateCall( - function, GetArrayFunctionCallArguments( - parameter_addresses, &ir_builder_, name, - /*return_value_buffer=*/return_value_buffer, - /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/GetTempBuffersArgument(), - /*profile_counters_arg=*/GetProfileCountersArgument())); + return b_.CreateBitCast(tempbuf_address_untyped, + IrShapeType(target_shape)->getPointerTo()); } -llvm::Value* IrEmitter::EmitArrayFunctionCall( - llvm::Function* function, const Shape& return_shape, int64 element_count, - gtl::ArraySlice parameter_addresses, - tensorflow::StringPiece name) { - llvm::Value* elements = - llvm::ConstantInt::get(ir_builder_.getInt64Ty(), element_count); - PrimitiveType return_type = return_shape.element_type(); - llvm::Value* return_value_buffer = - llvm_ir::EmitAllocaAtFunctionEntryWithCount( - llvm_ir::PrimitiveTypeToIrType(return_type, module_), elements, - tensorflow::strings::StrCat(name, "_return_value_address"), - &ir_builder_, MinimumAlignmentForPrimitiveType(return_type)); - EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer, - name); - return return_value_buffer; +llvm::Value* IrEmitter::EmitTempBufferPointer( + const BufferAllocation::Slice& slice, const Shape& target_shape) { + if (slice.allocation()->is_thread_local()) { + return EmitThreadLocalTempBufferPointer(slice, target_shape); + } else if (slice.allocation()->is_constant()) { + return FindOrDie(constant_buffer_to_global_, slice.allocation()->index()); + } else { + return EmitGlobalTempBufferPointer(slice, target_shape); + } } Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { - llvm::Value* addr; const Shape& target_shape = op->shape(); - if (op == op->parent()->root_instruction()) { - // For the root node, we write directly to the output buffer of the - // function. - llvm::Argument* retval = compute_function_->result_arg(); - if ((ShapeUtil::IsArray(target_shape) && - !ShapeUtil::IsZeroElementArray(target_shape)) || - (ShapeUtil::IsTuple(target_shape) && - !ShapeUtil::IsEmptyTuple(target_shape))) { - llvm::AttrBuilder attr_builder; - attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); - attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); - retval->addAttrs(attr_builder); - } - addr = ir_builder_.CreateBitCast(retval, - IrShapeType(target_shape)->getPointerTo()); - } else { - // For other nodes, we need the temporary buffer allocated for this node to - // write the result into. - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - assignment_.GetUniqueTopLevelSlice(op)); - addr = EmitTempBufferPointer(slice, target_shape); - } + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + assignment_.GetUniqueTopLevelSlice(op)); + llvm::Value* addr = EmitTempBufferPointer(slice, target_shape); addr->setName(AsStringRef(IrName(op))); emitted_value_[op] = addr; return Status::OK(); @@ -2864,14 +2771,14 @@ Status IrEmitter::EmitTargetElementLoop( llvm_ir::IrArray(op_target_address, element_shape)); } TF_RETURN_IF_ERROR( - llvm_ir::LoopEmitter(element_generator, output_arrays, &ir_builder_) + llvm_ir::LoopEmitter(element_generator, output_arrays, &b_) .EmitLoop(IrName(target_op))); std::vector tuple_operand_ptrs; for (int64 i = 0; i < output_arrays.size(); ++i) { tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); } - llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &ir_builder_, module_); + llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &b_, module_); } else { if (ShouldEmitParallelLoopFor(*target_op)) { @@ -2880,11 +2787,11 @@ Status IrEmitter::EmitTargetElementLoop( compute_function_->GetDynamicLoopBounds(); // Emit parallel loop with dynamic loop bounds for most-major dimensions. TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array, - &dynamic_loop_bounds, &ir_builder_) + &dynamic_loop_bounds, &b_) .EmitLoop(IrName(target_op))); } else { TF_RETURN_IF_ERROR( - llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_) + llvm_ir::LoopEmitter(element_generator, target_array, &b_) .EmitLoop(IrName(target_op))); } } @@ -2897,8 +2804,8 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source, llvm::Value* destination_value = GetEmittedValueFor(&destination); int64 source_size = ByteSizeOf(source.shape()); // TODO(b/63762267): Be more aggressive about specifying alignment. - ir_builder_.CreateMemCpy(destination_value, /*DstAlign=*/1, source_value, - /*SrcAlign=*/1, source_size); + b_.CreateMemCpy(destination_value, /*DstAlign=*/1, source_value, + /*SrcAlign=*/1, source_size); return Status::OK(); } @@ -2926,7 +2833,7 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : hlo->operands()) { operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { - return GetIrArrayFor(operand).EmitReadArrayElement(index, &ir_builder_); + return GetIrArrayFor(operand).EmitReadArrayElement(index, &b_); }; } CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); @@ -2934,20 +2841,69 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator)); } -StatusOr IrEmitter::EmitScalarCall( - PrimitiveType return_type, HloComputation* computation, - const std::vector& arguments, tensorflow::StringPiece name) { - llvm::Function* llvm_function = FindOrDie(emitted_functions_, computation); - std::vector argument_addrs; - for (auto argument : arguments) { - llvm::Value* argument_addr = llvm_ir::EmitAllocaAtFunctionEntry( - argument->getType(), "arg_addr", &ir_builder_); - ir_builder_.CreateStore(argument, argument_addr); - argument_addrs.push_back(argument_addr); +llvm::Value* IrEmitter::EmitThreadLocalCall( + const HloComputation& callee, + tensorflow::gtl::ArraySlice parameters, + tensorflow::StringPiece name) { + const Shape& return_shape = callee.root_instruction()->shape(); + + // Lifting this restriction to allow "small" arrays should be easy. Allowing + // larger arrays is difficult because we allocate the buffer for this return + // value on the stack. + CHECK(ShapeUtil::IsScalar(return_shape)); + + PrimitiveType return_type = return_shape.element_type(); + + std::vector parameter_addrs; + for (llvm::Value* parameter : parameters) { + CHECK(!parameter->getType()->isPointerTy()); + llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry( + parameter->getType(), "arg_addr", &b_); + b_.CreateStore(parameter, parameter_addr); + parameter_addrs.push_back(parameter_addr); + } + + llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(return_type, module_), + tensorflow::strings::StrCat(name, "_retval_addr"), &b_, + MinimumAlignmentForPrimitiveType(return_type)); + + b_.CreateCall( + FindOrDie(emitted_functions_, &callee), + GetArrayFunctionCallArguments( + parameter_addrs, &b_, name, + /*return_value_buffer=*/return_value_buffer, + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/ + llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()), + /*profile_counters_arg=*/GetProfileCountersArgument())); + + return b_.CreateLoad(return_value_buffer); +} + +void IrEmitter::EmitGlobalCall(const HloComputation& callee, + tensorflow::StringPiece name) { + b_.CreateCall(FindOrDie(emitted_functions_, &callee), + GetArrayFunctionCallArguments( + /*parameter_addresses=*/{}, &b_, name, + /*return_value_buffer=*/ + llvm::Constant::getNullValue(b_.getInt8PtrTy()), + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument())); +} + +llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue( + const HloComputation& callee) { + const HloInstruction* root_inst = callee.root_instruction(); + if (root_inst->opcode() == HloOpcode::kOutfeed) { + return llvm::Constant::getNullValue(b_.getInt8PtrTy()); } - return EmitElementFunctionCall(llvm_function, - ShapeUtil::MakeShape(return_type, {}), - argument_addrs, name); + + const BufferAllocation::Slice root_buffer = + assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie(); + return EmitTempBufferPointer(root_buffer, root_inst->shape()); } + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 3c110a320fad931e68e48236d4b4a33d0601ab5a..c9a1dab62dcbcd926baa82737d24efa03fd326e9 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -97,12 +98,16 @@ class IrEmitter : public DfsHloVisitorWithDefault { bool is_top_level_computation, std::vector* instruction_order); - llvm::IRBuilder<>* ir_builder() { return &ir_builder_; } + llvm::IRBuilder<>* b() { return &b_; } - // Emits a call to `computation` with scalar arguments `arguments`. - StatusOr EmitScalarCall( - PrimitiveType return_type, HloComputation* computation, - const std::vector& arguments, tensorflow::StringPiece name); + // Emit an LLVM global variable for every constant buffer allocation. + Status EmitConstantGlobals(); + + // Emit code to map one element according to `map_instr`. + llvm::Value* EmitElementalMap( + const HloMapInstruction& map_instr, + tensorflow::gtl::ArraySlice elemental_operands, + tensorflow::StringPiece name); protected: // @@ -117,6 +122,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleCopy(HloInstruction* copy) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleSelect(HloInstruction* select) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; @@ -138,14 +144,16 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleRecvDone(HloInstruction* recv_done) override; Status HandlePad(HloInstruction* pad) override; Status HandleTuple(HloInstruction* tuple) override; - Status HandleMap(HloInstruction* map) override; Status HandleFusion(HloInstruction* fusion) override; Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call) override; Status HandleWhile(HloInstruction* xla_while) override; Status HandleConcatenate(HloInstruction* concatenate) override; Status HandleConditional(HloInstruction* conditional) override; + Status HandleScatter(HloInstruction* scatter) override; Status HandleAfterAll(HloInstruction* gen_token) override; + Status HandleIota(HloInstruction* iota) override; + Status HandleRng(HloInstruction* rng) override; Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; @@ -211,9 +219,18 @@ class IrEmitter : public DfsHloVisitorWithDefault { // computation function being emitted by this emitter. llvm::Value* GetTempBuffersArgument(); - // Emits code that computes the address of the given temporary buffer to the - // function. target_shape is the shape of this temporary buffer. - // The returned Value's type is a pointer to element_type. + // Helper for EmitTempBufferPointer. + llvm::Value* EmitGlobalTempBufferPointer(const BufferAllocation::Slice& slice, + const Shape& target_shape); + + // Helper for EmitTempBufferPointer. + llvm::Value* EmitThreadLocalTempBufferPointer( + const BufferAllocation::Slice& slice, const Shape& target_shape); + + // Emits code that computes the address of the given buffer allocation slice. + // + // TODO(sanjoy): This should be renamed to reflect that it no longer provides + // access to just temporaries. llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice, const Shape& target_shape); @@ -225,44 +242,27 @@ class IrEmitter : public DfsHloVisitorWithDefault { tensorflow::StringPiece function_name_suffix); // Used for LLVM IR register names. - // Methods that emit a function call. - // Parameters: - // function - The LLVM function to call. - // return_shape - The return shape of the HLO computation that was used to - // make the function. Not the same as the return type of the function - // in LLVM, since we use output parameters for the return type. - // element_count - number of elements to return (array form only). - // parameter_addresses - pointers to be passed to the function as - // parameters. - // name - used for LLVM IR register names. - - // Emits a function call, returning a scalar, often an element of a larger - // array. Returns a Value for the scalar element returned by the function. - llvm::Value* EmitElementFunctionCall( - llvm::Function* function, const Shape& return_shape, - tensorflow::gtl::ArraySlice parameter_addresses, + // Emits a call to a thread local function (e.g. to the computation nested + // within a reduce or a map). Thread local callees (by definition) only write + // to and read from thread local allocations. + // + // `parameters` holds the *scalar values* that need to be passed to the + // callee. The return value is the scalar returned by the callee. + llvm::Value* EmitThreadLocalCall( + const HloComputation& callee, + tensorflow::gtl::ArraySlice parameters, tensorflow::StringPiece name); - // Array function call emitter. Stores the function's result into a supplied - // buffer. - // Parameters: - // function - The LLVM function to call. - // parameter_addresses - pointers to be passed to the function as - // parameters. - // return_value - pointer to a buffer where the call result is stored. - - void EmitArrayFunctionCallInto( - llvm::Function* function, - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name); - - // Array function call emitter. Returns a Value for the function's return - // value buffer address. The return value buffer is alloca'ed by this - // function. - llvm::Value* EmitArrayFunctionCall( - llvm::Function* function, const Shape& return_shape, int64 element_count, - tensorflow::gtl::ArraySlice parameter_addresses, - tensorflow::StringPiece name); + // Emits a call to a "global" function (e.g. to the computation nested within + // a kWhile or a kCall). Buffer assignment unabiguously assignes buffers to + // the parameters and return values for these computations so there is no need + // to explicitly pass parameters or return results. + void EmitGlobalCall(const HloComputation& callee, + tensorflow::StringPiece name); + + // Returns the buffer to which a global call to `callee` would have written + // its result. + llvm::Value* GetBufferForGlobalCallReturnValue(const HloComputation& callee); // Verifies that the element types of all of the given operand instructions // match and are of one of the given supported types. @@ -401,11 +401,10 @@ class IrEmitter : public DfsHloVisitorWithDefault { NameUniquer name_uniquer_; // Map containing all previously emitted computations. - std::map emitted_functions_; + std::map emitted_functions_; // Map containing all previously emitted thread-local temporary buffers. - std::map, - llvm::AllocaInst*> + std::map, llvm::Value*> thread_local_buffers_; // The following fields track the IR emission state. According to LLVM memory @@ -413,7 +412,17 @@ class IrEmitter : public DfsHloVisitorWithDefault { // creates the encapsulated llvm::Function s.t. it is added to the llvm // module's function list). std::unique_ptr compute_function_; - llvm::IRBuilder<> ir_builder_; + llvm::IRBuilder<> b_; + + // The buffer allocation slice for the root of the computation being compiled. + // Only relevant for thread local computations. + BufferAllocation::Slice computation_root_allocation_; + + // Maps the buffer allocation slices for the parameters to the computation + // being compiled to their parameter numbers. Only relevant for thread local + // computations. + tensorflow::gtl::FlatMap + computation_parameter_allocations_; // Maps HLO instructions to their index into the profile counter array. const std::unordered_map @@ -449,23 +458,22 @@ class IrEmitter : public DfsHloVisitorWithDefault { : use_rdtscp_(use_rdtscp), prof_counters_(prof_counters) {} // Record the cycle counter before an HLO executes. - void RecordCycleStart(llvm::IRBuilder<>* ir_builder, HloInstruction* hlo); + void RecordCycleStart(llvm::IRBuilder<>* b, HloInstruction* hlo); // Record the number of cycles it took for an HLO to execute. - void RecordCycleDelta(llvm::IRBuilder<>* ir_builder, HloInstruction* hlo, + void RecordCycleDelta(llvm::IRBuilder<>* b, HloInstruction* hlo, llvm::Value* prof_counter); // Record the number of cycles it took for the entire computation to // execute. - void RecordCompleteComputation(llvm::IRBuilder<>* ir_builder, + void RecordCompleteComputation(llvm::IRBuilder<>* b, llvm::Value* prof_counter); // Convenience function to generate a call to an intrinsic which reads the // CPU cycle counter. - llvm::Value* ReadCycleCounter(llvm::IRBuilder<>* ir_builder); + llvm::Value* ReadCycleCounter(llvm::IRBuilder<>* b); // Store the cycle counter delta to the per-HLO profile counter. - void UpdateProfileCounter(llvm::IRBuilder<>* ir_builder, - llvm::Value* prof_counter, llvm::Value* cycle_end, - llvm::Value* cycle_start); + void UpdateProfileCounter(llvm::IRBuilder<>* b, llvm::Value* prof_counter, + llvm::Value* cycle_end, llvm::Value* cycle_start); private: // Should we use the x86-specific rdtscp or the generic readcyclecounter @@ -513,6 +521,17 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Returns the number of bytes within the shape. int64 ByteSizeOf(const Shape& shape) const; + StatusOr EmitTargetElementLoopBodyForMap( + HloMapInstruction* map, const llvm_ir::IrArray::Index& index); + StatusOr EmitTargetElementLoopBodyForReduceWindow( + HloReduceWindowInstruction* reduce_window, + const llvm_ir::IrArray::Index& index); + StatusOr EmitTargetElementLoopBodyForConvolution( + HloConvolutionInstruction* convolution, + const llvm_ir::IrArray::Index& index); + StatusOr EmitTargetElementLoopBodyForReduce( + HloReduceInstruction* reduce, const llvm_ir::IrArray::Index& index); + enum class XfeedKind { kInfeed, kOutfeed, @@ -546,6 +565,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { LiteralPtrHashFunctor, LiteralPtrEqualityFunctor> emitted_literals_; + tensorflow::gtl::FlatMap + constant_buffer_to_global_; + TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter); }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc index 2d6f2f3818a7bd4424aaa7d918ca86abef15c0e9..2db4d000f5b149969c88fb4325ca28aa11dc3708 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -49,11 +49,10 @@ IrFunction::IrFunction(const string& function_name, llvm::Function::LinkageTypes linkage, const bool optimize_for_size_requested, const bool enable_fast_math, llvm::Module* llvm_module, - llvm::IRBuilder<>* ir_builder, - int64 num_dynamic_loop_bounds) - : ir_builder_(ir_builder), + llvm::IRBuilder<>* b, int64 num_dynamic_loop_bounds) + : b_(b), llvm_module_(llvm_module), - caller_insert_point_guard_(*ir_builder), + caller_insert_point_guard_(*b), num_dynamic_loop_bounds_(num_dynamic_loop_bounds) { Initialize(function_name, linkage, optimize_for_size_requested, enable_fast_math); @@ -61,7 +60,7 @@ IrFunction::IrFunction(const string& function_name, IrFunction::~IrFunction() { // Emit function return value. - ir_builder_->CreateRetVoid(); + b_->CreateRetVoid(); } DynamicLoopBounds IrFunction::GetDynamicLoopBounds() { @@ -81,9 +80,16 @@ void IrFunction::Initialize(const string& function_name, // void function(i8* retval, i8* run_options, i8** params, i8** temps, // i64* dynamic_loop_bounds, i64* prof_counters) // - // retval: points to the returned value. - // params: address of an array with pointers to parameters. - // temps: address of an array with pointers to temporary buffers. + // For thread local functions: + // retval: points to the returned value. + // params: address of an array with pointers to parameters. + // temps: is null + // + // For global functions: + // retval: is null + // params: is null + // temps: address of an array with pointers to temporary buffers and entry + // computation parameters. // // Therefore, the generated function's signature (FunctionType) is statically // determined - parameter unpacking is done in code generated into the @@ -174,7 +180,7 @@ void IrFunction::Initialize(const string& function_name, function_->addAttribute(argument.getArgNo() + 1, llvm::Attribute::NoAlias); } - ir_builder_->SetInsertPoint(llvm::BasicBlock::Create( + b_->SetInsertPoint(llvm::BasicBlock::Create( /*Context=*/llvm_module_->getContext(), /*Name=*/"entry", /*Parent=*/function_)); @@ -184,9 +190,8 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { CHECK_GT(num_dynamic_loop_bounds_, 0); CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); - return ir_builder_->CreateLoad( - ir_builder_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_), - ir_builder_->getInt64(offset), AsStringRef(name))); + return b_->CreateLoad(b_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_), + b_->getInt64(offset), AsStringRef(name))); } // Emits code to allocate an array of parameter address pointers, and store @@ -195,27 +200,32 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { // address buffer). std::vector GetArrayFunctionCallArguments( tensorflow::gtl::ArraySlice parameter_addresses, - llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name, + llvm::IRBuilder<>* b, tensorflow::StringPiece name, llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) { - llvm::Value* parameter_addresses_buffer = - llvm_ir::EmitAllocaAtFunctionEntryWithCount( - ir_builder->getInt8PtrTy(), - ir_builder->getInt32(parameter_addresses.size()), - tensorflow::strings::StrCat(name, "_parameter_addresses"), - ir_builder); - for (size_t i = 0; i < parameter_addresses.size(); ++i) { - llvm::Value* parameter_as_i8ptr = ir_builder->CreateBitCast( - parameter_addresses[i], ir_builder->getInt8PtrTy(), - AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i, - "_address_as_i8ptr"))); - llvm::Value* slot_in_param_addresses = ir_builder->CreateInBoundsGEP( - parameter_addresses_buffer, {ir_builder->getInt64(i)}); - ir_builder->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); + llvm::Value* parameter_addresses_buffer; + + if (parameter_addresses.empty()) { + parameter_addresses_buffer = + llvm::Constant::getNullValue(b->getInt8PtrTy()->getPointerTo()); + } else { + parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount( + b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()), + tensorflow::strings::StrCat(name, "_parameter_addresses"), b); + + for (size_t i = 0; i < parameter_addresses.size(); ++i) { + llvm::Value* parameter_as_i8ptr = + b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(), + AsStringRef(tensorflow::strings::StrCat( + name, "_parameter_", i, "_address_as_i8ptr"))); + llvm::Value* slot_in_param_addresses = + b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)}); + b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); + } } const auto to_int8_ptr = [=](llvm::Value* ptr) { - return ir_builder->CreatePointerCast(ptr, ir_builder->getInt8PtrTy()); + return b->CreatePointerCast(ptr, b->getInt8PtrTy()); }; std::vector arguments{ to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg), @@ -230,22 +240,21 @@ std::vector GetArrayFunctionCallArguments( // calls to 'parallel_function' (and joins threads before returning). Status EmitCallToParallelForkJoin( const std::vector& arguments, const Shape& shape, - const std::vector& dimension_partition_counts, - llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function, - const string& name) { - llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); + const std::vector& dimension_partition_counts, llvm::IRBuilder<>* b, + llvm::Function* parallel_function, const string& name) { + llvm::Module* module = b->GetInsertBlock()->getModule(); // Build ParallelForkJoin function type. std::vector compute_function_params = GetComputeFunctionParams(module, /*num_dynamic_loop_bounds=*/0); // Number of parallel compute functions. - compute_function_params.push_back(ir_builder->getInt32Ty()); + compute_function_params.push_back(b->getInt32Ty()); // Array of partitions. There is an array element for each // partition x partition_dim x 2 (for dimension start and limit). compute_function_params.push_back( llvm::Type::getInt64PtrTy(module->getContext())); // Number of partitioned most-major dimensions in 'shape'. - compute_function_params.push_back(ir_builder->getInt32Ty()); + compute_function_params.push_back(b->getInt32Ty()); // Function pointer for compute function to be dispatched in parallel. compute_function_params.push_back( llvm::Type::getInt8PtrTy(module->getContext())); @@ -268,7 +277,7 @@ Status EmitCallToParallelForkJoin( ShapePartitionIterator partition_iterator(shape, dimension_partition_counts); const int64 num_partitions = partition_iterator.GetTotalPartitionCount(); // Add argument specifying the number of parallel partitions. - fork_join_arguments.push_back(ir_builder->getInt32(num_partitions)); + fork_join_arguments.push_back(b->getInt32(num_partitions)); // The number of partitioned most-major dimensions in 'shape'. const int32 num_partitioned_dims = dimension_partition_counts.size(); @@ -293,15 +302,15 @@ Status EmitCallToParallelForkJoin( const std::pair& dim_partition = dim_partitions[j]; const int32 index = partition_index + j * dim_partition_size; // Store partition [dim_start, dim_limit) intervals for each dimension. - partitions[index] = ir_builder->getInt64(dim_partition.first); + partitions[index] = b->getInt64(dim_partition.first); partitions[index + 1] = - ir_builder->getInt64(dim_partition.first + dim_partition.second); + b->getInt64(dim_partition.first + dim_partition.second); } } // Create global variable out of dimension partitions in 'partitions'. llvm::ArrayType* partitions_array_type = - llvm::ArrayType::get(ir_builder->getInt64Ty(), partition_array_size); + llvm::ArrayType::get(b->getInt64Ty(), partition_array_size); llvm::Constant* partitions_array = llvm::ConstantArray::get(partitions_array_type, partitions); llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable( @@ -315,16 +324,16 @@ Status EmitCallToParallelForkJoin( tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); // Add argument specifying parallel dimension partitions. - fork_join_arguments.push_back(ir_builder->CreateBitCast( - global_partitions_array, - llvm::Type::getInt64PtrTy(module->getContext()))); + fork_join_arguments.push_back( + b->CreateBitCast(global_partitions_array, + llvm::Type::getInt64PtrTy(module->getContext()))); // Add argument specifying the number of partitioned most-major dimensions. - fork_join_arguments.push_back(ir_builder->getInt32(num_partitioned_dims)); + fork_join_arguments.push_back(b->getInt32(num_partitioned_dims)); // Add argument for parallel compute function pointer. fork_join_arguments.push_back( - ir_builder->CreateBitCast(parallel_function, ir_builder->getInt8PtrTy())); + b->CreateBitCast(parallel_function, b->getInt8PtrTy())); // Emit call to parallel fork/join. - ir_builder->CreateCall(fork_join_func, fork_join_arguments); + b->CreateCall(fork_join_func, fork_join_arguments); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h index 2e55181eed867aca762f2b9b8310624ea12c7487..a41cbb64cdd9f5b6de5d1eadfbf7e63e1e984801 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.h +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -54,7 +54,7 @@ class IrFunction { IrFunction(const string& function_name, llvm::Function::LinkageTypes linkage, const bool optimize_for_size_requested, const bool enable_fast_math, llvm::Module* llvm_module, - llvm::IRBuilder<>* ir_builder, int64 num_dynamic_loop_bounds); + llvm::IRBuilder<>* b, int64 num_dynamic_loop_bounds); ~IrFunction(); // Emit ir to read and return the set of ir values representing the dynamic @@ -97,7 +97,7 @@ class IrFunction { // 'offset' from the "dynamic_loop_bounds" argument of this function. llvm::Value* GetDynamicLoopBound(int64 offset); - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; llvm::Module* llvm_module_; llvm::IRBuilder<>::InsertPointGuard caller_insert_point_guard_; @@ -116,7 +116,7 @@ class IrFunction { // Returns an array of compute function call argument ir values. std::vector GetArrayFunctionCallArguments( tensorflow::gtl::ArraySlice parameter_addresses, - llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name, + llvm::IRBuilder<>* b, tensorflow::StringPiece name, llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg); @@ -124,9 +124,8 @@ std::vector GetArrayFunctionCallArguments( // calls to 'parallel_function' (and joins threads before returning). Status EmitCallToParallelForkJoin( const std::vector& arguments, const Shape& shape, - const std::vector& dimension_partition_counts, - llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function, - const string& name); + const std::vector& dimension_partition_counts, llvm::IRBuilder<>* b, + llvm::Function* parallel_function, const string& name); } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index 2e5cc96098241415b82f225afc81981f3e1069e0..cef5e57b0b12b7ae93af0d2508b2b9d6a592d390 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/IR/Verifier.h" #include "llvm/Transforms/Utils/Cloning.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" +#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h" #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/logging.h" @@ -52,46 +53,14 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, llvm::BasicBlock* vector_tanh_body = llvm::BasicBlock::Create(*context, "body", vector_tanh_function); - llvm::IRBuilder<> ir_builder(vector_tanh_body); + llvm::IRBuilder<> b(vector_tanh_body); llvm::FastMathFlags fast_math_flags; - fast_math_flags.setFast(); - ir_builder.setFastMathFlags(fast_math_flags); - - VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "tanh_f32"); + fast_math_flags.setFast(enable_fast_math); + b.setFastMathFlags(fast_math_flags); llvm::Value* input = &*vector_tanh_function->arg_begin(); - CHECK_EQ(input->getType(), vsl.vector_type()); - - // This implements the same rational interpolant as implemented in Eigen3. - llvm::Value* input_clamped = - vsl.Clamp(input, /*low=*/GetIeeeF32(-9.0), /*high=*/GetIeeeF32(9.0)); - - std::array numerator_coeffs{ - -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, - 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f, - 4.89352455891786e-03f}; - - std::array denominator_coeffs{ - 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, - 4.89352518554385e-03f}; - - llvm::Value* input_squared = vsl.Mul(input_clamped, input_clamped); - llvm::Value* numerator = vsl.SplatFloat(GetIeeeF32(numerator_coeffs[0])); - for (int i = 1; i < numerator_coeffs.size(); i++) { - numerator = - vsl.MulAdd(input_squared, numerator, GetIeeeF32(numerator_coeffs[i])); - } - - numerator = vsl.Mul(input_clamped, numerator); - - llvm::Value* denominator = vsl.SplatFloat(GetIeeeF32(denominator_coeffs[0])); - for (int i = 1; i < denominator_coeffs.size(); i++) { - denominator = vsl.MulAdd(input_squared, denominator, - GetIeeeF32(denominator_coeffs[i])); - } - - llvm::Value* result = vsl.Div(numerator, denominator); - ir_builder.CreateRet(result); + CHECK_EQ(vector_width, input->getType()->getVectorNumElements()); + b.CreateRet(llvm_ir::EmitFastTanh(&b, input)); DCHECK(!llvm::verifyFunction(*vector_tanh_function)); return vector_tanh_function; @@ -113,12 +82,12 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module, llvm::BasicBlock* vector_exp_body = llvm::BasicBlock::Create(*context, "body", vector_exp_function); - llvm::IRBuilder<> ir_builder(vector_exp_body); + llvm::IRBuilder<> b(vector_exp_body); llvm::FastMathFlags fast_math_flags; fast_math_flags.setFast(); - ir_builder.setFastMathFlags(fast_math_flags); + b.setFastMathFlags(fast_math_flags); - VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "exp_f32"); + VectorSupportLibrary vsl(F32, vector_width, &b, "exp_f32"); // This implements the same polynomial approximation as implemented in Eigen3. @@ -160,21 +129,21 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module, // VectorSupportLibrary (intentionally) can't juggle more than one type at a // time so drop down to IRBuilder for this bit. llvm::Value* vector_constant_0x7f = - ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(0x7f)); + b.CreateVectorSplat(vector_width, b.getInt32(0x7f)); llvm::Value* vector_constant_23 = - ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(23)); + b.CreateVectorSplat(vector_width, b.getInt32(23)); llvm::Type* i32_vector_type = - llvm::VectorType::get(ir_builder.getInt32Ty(), vector_width); + llvm::VectorType::get(b.getInt32Ty(), vector_width); // fx is clamped so we don't have to worry about it being out of range for // i32. - llvm::Value* emm0 = ir_builder.CreateFPToSI(fx, i32_vector_type); - emm0 = ir_builder.CreateAdd(emm0, vector_constant_0x7f); - emm0 = ir_builder.CreateShl(emm0, vector_constant_23); - llvm::Value* emm0_f32 = ir_builder.CreateBitCast(emm0, vsl.vector_type()); + llvm::Value* emm0 = b.CreateFPToSI(fx, i32_vector_type); + emm0 = b.CreateAdd(emm0, vector_constant_0x7f); + emm0 = b.CreateShl(emm0, vector_constant_23); + llvm::Value* emm0_f32 = b.CreateBitCast(emm0, vsl.vector_type()); llvm::Value* result = vsl.Max(vsl.Mul(y, emm0_f32), input); - ir_builder.CreateRet(result); + b.CreateRet(result); DCHECK(!llvm::verifyFunction(*vector_exp_function)); return vector_exp_function; @@ -196,13 +165,13 @@ llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module, llvm::BasicBlock* vector_log_body = llvm::BasicBlock::Create(*context, "body", vector_log_function); - llvm::IRBuilder<> ir_builder(vector_log_body); + llvm::IRBuilder<> b(vector_log_body); llvm::FastMathFlags fast_math_flags; fast_math_flags.setFast(); - ir_builder.setFastMathFlags(fast_math_flags); + b.setFastMathFlags(fast_math_flags); llvm::Value* input = &*vector_log_function->arg_begin(); - VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "log_f32"); + VectorSupportLibrary vsl(F32, vector_width, &b, "log_f32"); const llvm::APFloat half = GetIeeeF32(0.5); const llvm::APFloat one = GetIeeeF32(1.0); @@ -238,22 +207,21 @@ llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module, // VectorSupportLibrary (intentionally) can't juggle more than one type at a // time so drop down to IRBuilder for this bit. llvm::Value* vector_constant_0x7f = - ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(0x7f)); + b.CreateVectorSplat(vector_width, b.getInt32(0x7f)); llvm::Value* vector_constant_23 = - ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(23)); + b.CreateVectorSplat(vector_width, b.getInt32(23)); llvm::Type* i32_vector_type = - llvm::VectorType::get(ir_builder.getInt32Ty(), vector_width); + llvm::VectorType::get(b.getInt32Ty(), vector_width); - llvm::Value* emm0 = ir_builder.CreateLShr( - ir_builder.CreateBitCast(input, i32_vector_type), vector_constant_23); + llvm::Value* emm0 = + b.CreateLShr(b.CreateBitCast(input, i32_vector_type), vector_constant_23); // Keep only the fractional part. input = vsl.FloatAnd(input, inv_mant_mask); input = vsl.FloatOr(input, half); - emm0 = ir_builder.CreateSub(emm0, vector_constant_0x7f); - llvm::Value* e = - vsl.Add(one, ir_builder.CreateSIToFP(emm0, vsl.vector_type())); + emm0 = b.CreateSub(emm0, vector_constant_0x7f); + llvm::Value* e = vsl.Add(one, b.CreateSIToFP(emm0, vsl.vector_type())); // part2: // if( x < SQRTHF ) { @@ -294,7 +262,7 @@ llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module, llvm::Value* or_rhs = vsl.FloatAnd(iszero_mask, minus_inf); llvm::Value* result = vsl.FloatOr(or_lhs, or_rhs); - ir_builder.CreateRet(result); + b.CreateRet(result); DCHECK(!llvm::verifyFunction(*vector_log_function)); return vector_log_function; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc index 59ae5acd8b7cea049f09eaf4cc98b41339973c77..8560e4296aa95fe791446abb1b4363b9145f343e 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -25,8 +25,8 @@ namespace cpu { ParallelLoopEmitter::ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, const llvm_ir::IrArray& target_array, - const DynamicLoopBounds* dynamic_loop_bounds, llvm::IRBuilder<>* ir_builder) - : LoopEmitter(target_element_generator, target_array, ir_builder), + const DynamicLoopBounds* dynamic_loop_bounds, llvm::IRBuilder<>* b) + : LoopEmitter(target_element_generator, target_array, b), dynamic_loop_bounds_(dynamic_loop_bounds) {} std::vector @@ -37,7 +37,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( CHECK(!ShapeUtil::IsTuple(shape_)); CHECK(!ShapeUtil::IsScalar(shape_)); - llvm_ir::ForLoopNest loop_nest(loop_name, ir_builder_); + llvm_ir::ForLoopNest loop_nest(loop_name, b_); const int64 num_dims = shape_.dimensions_size(); llvm_ir::IrArray::Index array_index(index_type, num_dims); @@ -65,8 +65,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( } } // Point IR builder at inner loop BB. - llvm_ir::SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), - ir_builder_); + llvm_ir::SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), b_); // Set exit_bb_ to the exit block of the loop nest. exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock(); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h index 25e182a26d6f21c7eba550020cf17403aa92abf7..076c683ca566f2c53992c358903d2aadead290f9 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -54,7 +54,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator, const llvm_ir::IrArray& target_array, const DynamicLoopBounds* dynamic_loop_bounds, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); ParallelLoopEmitter(const ParallelLoopEmitter&) = delete; ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete; diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc index d03da46575b331de113cc5f33c2b4267504e8308..a5f34908d70dd18ec017bdf9833c7df40f80db07 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc @@ -20,6 +20,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -58,13 +59,14 @@ using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, // [partition1_dim2_start] // [partition1_dim2_limit] // -void __xla_cpu_runtime_ParallelForkJoin( +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( void* result_ptr, const void* run_options_ptr, const void** params, void** temps, uint64* prof_counters, int32 num_partitions, int64* partitions, int32 num_partitioned_dims, void* function_ptr) { VLOG(2) << "ParallelForkJoin ENTRY" << " num_partitions: " << num_partitions << " num_partitioned_dims: " << num_partitioned_dims; + CHECK_EQ(params, nullptr); CHECK_GT(num_partitions, 1); CHECK_GT(num_partitioned_dims, 0); const xla::ExecutableRunOptions* run_options = @@ -79,9 +81,9 @@ void __xla_cpu_runtime_ParallelForkJoin( for (int32 i = 1; i < num_partitions; ++i) { const int64 offset = i * stride; run_options->intra_op_thread_pool()->enqueueNoNotification( - [i, function, result_ptr, run_options_ptr, params, temps, prof_counters, + [i, function, result_ptr, run_options_ptr, temps, prof_counters, partitions, offset, &bc]() { - function(result_ptr, run_options_ptr, params, temps, + function(result_ptr, run_options_ptr, nullptr, temps, &partitions[offset], prof_counters); bc.DecrementCount(); VLOG(3) << "ParallelForkJoin partition " << i << " done."; diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc index 39b13183ff093611a42b3931d45f64eadb420622..a71a85913cfef271bc2a226cb0cf2dd4204499a4 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc @@ -20,6 +20,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h" +#include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" using tensorflow::int32; @@ -77,27 +78,24 @@ void MatMulImpl(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, } // namespace -void __xla_cpu_runtime_EigenMatMulF16(const void* run_options_ptr, - Eigen::half* out, Eigen::half* lhs, - Eigen::half* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF16( + const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, + Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, + int32 transpose_rhs) { MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } -void __xla_cpu_runtime_EigenMatMulF32(const void* run_options_ptr, float* out, - float* lhs, float* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF32( + const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, + int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } -void __xla_cpu_runtime_EigenMatMulF64(const void* run_options_ptr, double* out, - double* lhs, double* rhs, int64 m, - int64 n, int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64( + const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, + int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc index f8c8dd5e93d53db8d87be0208b5cf4daac3464f1..997fdd2ab309f0b68a9dbd0f156a8dc19955b437 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc @@ -23,6 +23,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool" +#include "tensorflow/core/platform/dynamic_annotations.h" using tensorflow::int32; using tensorflow::int64; @@ -74,10 +75,9 @@ void MatMulF64(const void* run_options_ptr, double* out, double* lhs, } // namespace -void __xla_cpu_runtime_MKLMatMulF32(const void* run_options_ptr, float* out, - float* lhs, float* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_MKLMatMulF32( + const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, + int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); // BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread @@ -88,11 +88,11 @@ void __xla_cpu_runtime_MKLMatMulF32(const void* run_options_ptr, float* out, // Set thread number back to the previous number. mkl_set_num_threads_local(prev_num_threads); } + // BLAS GEMM API for 64-bit Matrix Multiplication -void __xla_cpu_runtime_MKLMatMulF64(const void* run_options_ptr, double* out, - double* lhs, double* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_MKLMatMulF64( + const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, + int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); // BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread @@ -103,22 +103,26 @@ void __xla_cpu_runtime_MKLMatMulF64(const void* run_options_ptr, double* out, // Set thread number back to the previous number. mkl_set_num_threads_local(prev_num_threads); } -void __xla_cpu_runtime_MKLSingleThreadedMatMulF32(const void* run_options_ptr, - float* out, float* lhs, - float* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_MKLSingleThreadedMatMulF32(const void* run_options_ptr, + float* out, float* lhs, float* rhs, + int64 m, int64 n, int64 k, + int32 transpose_lhs, + int32 transpose_rhs) { // Set the thread number to 1 for single threaded excution. int prev_num_threads = mkl_set_num_threads_local(1); MatMulF32(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); // Set thread number back to the previous number. mkl_set_num_threads_local(prev_num_threads); } -void __xla_cpu_runtime_MKLSingleThreadedMatMulF64(const void* run_options_ptr, - double* out, double* lhs, - double* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_MKLSingleThreadedMatMulF64(const void* run_options_ptr, + double* out, double* lhs, + double* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { // Set the thread number to 1 for single threaded excution. int prev_num_threads = mkl_set_num_threads_local(1); MatMulF64(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc index 17303e2f0d34e531a3a56aa147608b949e0f43ae..16692e7f2e6145b2649b67987eef47916e958be2 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc @@ -17,6 +17,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h" +#include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" using tensorflow::int32; @@ -71,7 +72,8 @@ void SingleThreadedMatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, } // namespace -void __xla_cpu_runtime_EigenSingleThreadedMatMulF16( +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulF16( const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { @@ -79,16 +81,22 @@ void __xla_cpu_runtime_EigenSingleThreadedMatMulF16( transpose_lhs, transpose_rhs); } -void __xla_cpu_runtime_EigenSingleThreadedMatMulF32( - const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, - int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulF32(const void* run_options_ptr, + float* out, float* lhs, + float* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } -void __xla_cpu_runtime_EigenSingleThreadedMatMulF64( - const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, - int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void* run_options_ptr, + double* out, double* lhs, + double* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index 7e792a82b8bf28121c054332bc619d736858c729..f227e4ae139b92e56786e38ef8eef72c9e2cd424 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -21,9 +21,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -38,12 +38,13 @@ int main(int argc, char** argv) { // Transfer parameters. std::unique_ptr param0_literal = - xla::Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); + xla::LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = client->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = xla::Literal::CreateR2( - {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); + std::unique_ptr param1_literal = + xla::LiteralUtil::CreateR2( + {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); std::unique_ptr param1_data = client->TransferToServer(*param1_literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 66ae5ef0f66e90982102d73e474f5d0582f5415c..181cec3cdddeb40daf5276d9d1d6a139417a6072 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -40,7 +40,7 @@ tf_cc_test( name = "cpu_fusion_test", srcs = ["cpu_fusion_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -82,7 +82,7 @@ tf_cc_test( name = "cpu_noalias_test", srcs = ["cpu_noalias_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -128,16 +128,16 @@ tf_cc_test( name = "cpu_infeed_test", srcs = ["cpu_infeed_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h b/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h index 7c8d07a10baf55dba8cbd347ebe1459b78e268e0..77b3a0301f2f90b577b7eaad86064dc30e2d9456 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h @@ -22,7 +22,7 @@ namespace xla { namespace cpu { // Tests that verify IR emitted by the CPU backend is as expected. -class CpuCodegenTest : public LLVMIRGenTestBase {}; +class CpuCodegenTest : public LlvmIrGenTestBase {}; } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc index 1d4bf483aedef5a15ef51cf216030b76255d4ec8..00a7aa2ad2f6bac4877302296ccb76222557535c 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -40,7 +40,7 @@ class CpuExternalConstantsTest : public CpuCodegenTest { HloInstruction* constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2FromArray2D(backing_array))); + LiteralUtil::CreateR2FromArray2D(backing_array))); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index 783b2820e922612973632c555fc8ae01418f1754..d98856fdbf4165a5909f193ebe8512e21af83dfc 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -43,8 +43,8 @@ class CpuFusionTest : public HloTestBase { TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { auto builder = HloComputation::Builder(TestName()); - auto input_literal1 = Literal::CreateR1({1.0, 2.0, 3.0}); - auto input_literal2 = Literal::CreateR1({-2.0, -42.0, 2.0}); + auto input_literal1 = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); + auto input_literal2 = LiteralUtil::CreateR1({-2.0, -42.0, 2.0}); Shape vshape = input_literal1->shape(); auto input1 = builder.AddInstruction( @@ -83,7 +83,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { TEST_F(CpuFusionTest, FuseElementwiseOpChain) { auto builder = HloComputation::Builder(TestName()); - auto input_literal = Literal::CreateR1({-1.5, -2.5, -3.0}); + auto input_literal = LiteralUtil::CreateR1({-1.5, -2.5, -3.0}); Shape vshape = input_literal->shape(); auto input = builder.AddInstruction( @@ -99,7 +99,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { auto two = builder.AddInstruction(HloInstruction::CreateBroadcast( vshape, builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))), {})); builder.AddInstruction( HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, two, floor)); @@ -134,7 +134,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { // middle. auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - auto input_literal = Literal::CreateR1({-1.5, -2.5, -3.0}); + auto input_literal = LiteralUtil::CreateR1({-1.5, -2.5, -3.0}); Shape vshape = input_literal->shape(); auto input = builder.AddInstruction( @@ -166,7 +166,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { ShapeUtil::MakeShape(F32, {6, 1}), concatenate)), /*init_value=*/ builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), /*dimensions_to_reduce=*/{1}, add_f32)); auto exp = builder.AddInstruction( @@ -176,7 +176,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { auto two = builder.AddInstruction(HloInstruction::CreateBroadcast( cshape, builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))), {})); builder.AddInstruction( HloInstruction::CreateBinary(cshape, HloOpcode::kMultiply, two, floor)); @@ -231,7 +231,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { // operand vectors. Test for this problem by counting the number of nodes in // each fusion instruction to ensure that negate is not duplicated. auto builder = HloComputation::Builder(TestName()); - auto input_literal = Literal::CreateR1({1.0, 2.0, 3.0}); + auto input_literal = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); Shape vshape = input_literal->shape(); auto constant = builder.AddInstruction( @@ -292,10 +292,10 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) { // computation. The duplication is caused by the other use of exp2 in the // tuple. auto builder = HloComputation::Builder(TestName()); - auto input_literal1 = Literal::CreateR1({1.0, 2.0, 3.0}); - auto input_literal2 = Literal::CreateR1({-2.0, -42.0, 2.0}); + auto input_literal1 = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); + auto input_literal2 = LiteralUtil::CreateR1({-2.0, -42.0, 2.0}); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); Shape shape = constant->shape(); auto exp1 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc index ea7e479d66fbda1bfd388fd77b25db2db56f0d65..c35569c6619ba5b534c5d8bb7ad683d84b6ecf4b 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase { }; TEST_F(InfeedTest, SingleInfeedR0Bool) { - TestInfeedRoundTrip(*Literal::CreateR0(true)); + TestInfeedRoundTrip(*LiteralUtil::CreateR0(true)); } TEST_F(InfeedTest, SingleInfeedR1U32) { - TestInfeedRoundTrip(*Literal::CreateR1({1, 2, 3})); + TestInfeedRoundTrip(*LiteralUtil::CreateR1({1, 2, 3})); } TEST_F(InfeedTest, SingleInfeedR2F32) { - TestInfeedRoundTrip(*Literal::CreateR2F32Linspace(0.0, 1.0, 128, 64)); + TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); } TEST_F(InfeedTest, SingleInfeedR3F32) { TestInfeedRoundTrip( - *Literal::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) { const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2}); const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0}); - TestInfeedRoundTrip( - *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, - r3_dim0minor)); + TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, + r3_dim0minor)); - TestInfeedRoundTrip( - *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, - r3_dim0major)); + TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, + r3_dim0major)); } TEST_F(InfeedTest, SingleInfeedR4S32) { - TestInfeedRoundTrip(*Literal::CreateR4( + TestInfeedRoundTrip(*LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } TEST_F(InfeedTest, SingleInfeedTuple) { TestInfeedRoundTrip( - *Literal::MakeTuple({Literal::CreateR1({1, 2, 3}).get(), - Literal::CreateR0(false).get()})); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), + LiteralUtil::CreateR0(false).get()})); } TEST_F(InfeedTest, SingleInfeedEmptyTuple) { - TestInfeedRoundTrip(*Literal::MakeTuple({})); + TestInfeedRoundTrip(*LiteralUtil::MakeTuple({})); } // Tests Infeed operation used in a while loop, as in the code below. The @@ -156,13 +156,16 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) { }); // Send 5 Infeed data of shape F32[3]. - ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1({1, 2, 3}))); - ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1({4, 5, 6}))); - ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1({7, 8, 9}))); ASSERT_IS_OK( - client_->TransferToInfeed(*Literal::CreateR1({10, 11, 12}))); + client_->TransferToInfeed(*LiteralUtil::CreateR1({1, 2, 3}))); + ASSERT_IS_OK( + client_->TransferToInfeed(*LiteralUtil::CreateR1({4, 5, 6}))); + ASSERT_IS_OK( + client_->TransferToInfeed(*LiteralUtil::CreateR1({7, 8, 9}))); + ASSERT_IS_OK( + client_->TransferToInfeed(*LiteralUtil::CreateR1({10, 11, 12}))); ASSERT_IS_OK( - client_->TransferToInfeed(*Literal::CreateR1({13, 14, 15}))); + client_->TransferToInfeed(*LiteralUtil::CreateR1({13, 14, 15}))); delete computation_thread; // Joins the thread. auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); @@ -217,7 +220,7 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // The body adds the reduced value of the Infeed data (first tuple element) // to the previous accumulator, and returns the accumulator and the continue // flag (second tuple element) as a tuple. - const auto build_body = [this, &result_shape](const Shape& infeed_shape) { + const auto build_body = [&result_shape](const Shape& infeed_shape) { XlaComputation body; XlaBuilder builder("body"); auto prev = Parameter(&builder, 0, result_shape, "prev"); @@ -247,17 +250,17 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // Send the first 4 Infeed data of shape Tuple(F32[2], PRED). ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1({1, 2}).get(), - Literal::CreateR0(true).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2}).get(), + LiteralUtil::CreateR0(true).get()}))); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1({3, 4}).get(), - Literal::CreateR0(true).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({3, 4}).get(), + LiteralUtil::CreateR0(true).get()}))); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1({5, 6}).get(), - Literal::CreateR0(true).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({5, 6}).get(), + LiteralUtil::CreateR0(true).get()}))); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1({7, 8}).get(), - Literal::CreateR0(false).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({7, 8}).get(), + LiteralUtil::CreateR0(false).get()}))); // Asynchronously launch the execution on the device. std::unique_ptr result; @@ -272,14 +275,14 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED). sleep(1); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1({1, 2, 3}).get(), - Literal::CreateR0(true).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), + LiteralUtil::CreateR0(true).get()}))); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1({7, 8, 9}).get(), - Literal::CreateR0(false).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({7, 8, 9}).get(), + LiteralUtil::CreateR0(false).get()}))); ASSERT_IS_OK(client_->TransferToInfeed( - *Literal::MakeTuple({Literal::CreateR1({4, 5, 6}).get(), - Literal::CreateR0(true).get()}))); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({4, 5, 6}).get(), + LiteralUtil::CreateR0(true).get()}))); // Wait for the execution to be done, and transfer the result. delete computation_thread; // Joins the thread. diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index 3b6b0ed74065615fb9e47a0ec3c6c4ab078e45c4..01daed4bcd38323bfe33e798a78c2b00b150a1bc 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "llvm/IR/Module.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" @@ -42,7 +42,7 @@ TEST_F(CpuNoAliasTest, Concat) { HloComputation::Builder builder(TestName()); std::unique_ptr literal = - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto param_shape = ShapeUtil::MakeShape(F32, {2, 2}); HloInstruction* param_x = builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "x")); @@ -78,7 +78,7 @@ TEST_F(CpuNoAliasTest, Concat) { llvm::Function* func = llvm::cast( ir_module.getOrInsertFunction("test_fn", llvm::Type::getVoidTy(context))); llvm::BasicBlock* bb = llvm::BasicBlock::Create(context, "body", func); - llvm::IRBuilder<> ir_builder(bb); + llvm::IRBuilder<> b(bb); auto* zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0); llvm_ir::IrArray::Index zero2D({zero, zero}); @@ -90,7 +90,7 @@ TEST_F(CpuNoAliasTest, Concat) { ir_module.getOrInsertGlobal("param_x", array2d_type); llvm_ir::IrArray param_x_array(param_x_val, param_shape); aa.AddAliasingInformationToIrArray(*param_x, ¶m_x_array); - param_x_array.EmitReadArrayElement(zero2D, &ir_builder) + param_x_array.EmitReadArrayElement(zero2D, &b) ->setName("read_param_x_array"); } @@ -100,7 +100,7 @@ TEST_F(CpuNoAliasTest, Concat) { auto shape = ShapeUtil::MakeShape(F32, {2, 4}); llvm_ir::IrArray concat1_array(concat1_val, shape); aa.AddAliasingInformationToIrArray(*concat1, &concat1_array); - concat1_array.EmitReadArrayElement(zero2D, &ir_builder) + concat1_array.EmitReadArrayElement(zero2D, &b) ->setName("read_concat1_array"); } @@ -110,7 +110,7 @@ TEST_F(CpuNoAliasTest, Concat) { auto shape = ShapeUtil::MakeShape(F32, {2, 6}); llvm_ir::IrArray concat2_array(concat2_val, shape); aa.AddAliasingInformationToIrArray(*concat2, &concat2_array); - concat2_array.EmitReadArrayElement(zero2D, &ir_builder) + concat2_array.EmitReadArrayElement(zero2D, &b) ->setName("read_concat2_array"); } diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index c444d151858d3a152a01b99657ffae89ebc6b487..3274be8d9dbfaa55e250748a389ad34fdeb81922 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -23,14 +23,14 @@ namespace xla { namespace cpu { VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size, - llvm::IRBuilder<>* ir_builder, + llvm::IRBuilder<>* b, std::string name) : vector_size_(vector_size), primitive_type_(primitive_type), - ir_builder_(ir_builder), + b_(b), name_(std::move(name)) { scalar_type_ = llvm_ir::PrimitiveTypeToIrType( - primitive_type, ir_builder_->GetInsertBlock()->getModule()); + primitive_type, b_->GetInsertBlock()->getModule()); scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_); vector_type_ = llvm::VectorType::get(scalar_type_, vector_size); vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_); @@ -63,9 +63,9 @@ llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) { llvm::Value* VectorSupportLibrary::MulInternal(llvm::Value* lhs, llvm::Value* rhs) { if (scalar_type_->isFloatingPointTy()) { - return ir_builder()->CreateFMul(lhs, rhs, name()); + return b()->CreateFMul(lhs, rhs, name()); } else { - return ir_builder()->CreateMul(lhs, rhs, name()); + return b()->CreateMul(lhs, rhs, name()); } } @@ -76,13 +76,13 @@ llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) { llvm::Value* VectorSupportLibrary::Sub(llvm::Value* lhs, llvm::Value* rhs) { AssertCorrectTypes({lhs, rhs}); - return ir_builder()->CreateFSub(lhs, rhs); + return b()->CreateFSub(lhs, rhs); } llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs) { AssertCorrectTypes({lhs, rhs}); if (scalar_type_->isFloatingPointTy()) { - return llvm_ir::EmitFloatMax(lhs, rhs, ir_builder_); + return llvm_ir::EmitFloatMax(lhs, rhs, b_); } else { LOG(FATAL) << "Max for integers is unimplemented"; } @@ -91,13 +91,13 @@ llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs) { llvm::Value* VectorSupportLibrary::Floor(llvm::Value* a) { AssertCorrectTypes({a}); return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {a}, - {a->getType()}, ir_builder()); + {a->getType()}, b()); } llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) { AssertCorrectTypes({lhs, rhs}); if (scalar_type_->isFloatingPointTy()) { - return ir_builder()->CreateFDiv(lhs, rhs, name()); + return b()->CreateFDiv(lhs, rhs, name()); } else { LOG(FATAL) << "Division for integers is unimplemented"; } @@ -111,42 +111,41 @@ llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a, CHECK(low.compare(high) == llvm::APFloat::cmpLessThan); CHECK(scalar_type_->isFloatingPointTy()); return llvm_ir::EmitFloatMin( - llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), ir_builder_), - GetConstantFloat(type, high), ir_builder_); + llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), b_), + GetConstantFloat(type, high), b_); } llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs) { AssertCorrectTypes({lhs, rhs}); - return I1ToFloat(ir_builder()->CreateFCmpOEQ(lhs, rhs, name())); + return I1ToFloat(b()->CreateFCmpOEQ(lhs, rhs, name())); } llvm::Value* VectorSupportLibrary::FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs) { AssertCorrectTypes({lhs, rhs}); - return I1ToFloat(ir_builder()->CreateFCmpOLT(lhs, rhs, name())); + return I1ToFloat(b()->CreateFCmpOLT(lhs, rhs, name())); } llvm::Value* VectorSupportLibrary::FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs) { AssertCorrectTypes({lhs, rhs}); - return I1ToFloat(ir_builder()->CreateFCmpULE(lhs, rhs, name())); + return I1ToFloat(b()->CreateFCmpULE(lhs, rhs, name())); } llvm::Value* VectorSupportLibrary::I1ToFloat(llvm::Value* i1) { bool is_vector = llvm::isa(i1->getType()); llvm::Type* integer_type = IntegerTypeForFloatSize(is_vector); - return ir_builder()->CreateBitCast( - ir_builder()->CreateSExt(i1, integer_type, name()), - is_vector ? vector_type() : scalar_type(), name()); + return b()->CreateBitCast(b()->CreateSExt(i1, integer_type, name()), + is_vector ? vector_type() : scalar_type(), name()); } llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) { CHECK(scalar_type()->isFloatingPointTy()); const llvm::DataLayout& data_layout = - ir_builder()->GetInsertBlock()->getModule()->getDataLayout(); + b()->GetInsertBlock()->getModule()->getDataLayout(); int64 float_size_bits = data_layout.getTypeSizeInBits(scalar_type()); - llvm::Type* scalar_int_type = ir_builder()->getIntNTy(float_size_bits); + llvm::Type* scalar_int_type = b()->getIntNTy(float_size_bits); if (vector) { return llvm::VectorType::get(scalar_int_type, vector_size()); } else { @@ -156,7 +155,7 @@ llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) { llvm::Value* VectorSupportLibrary::BroadcastScalar(llvm::Value* x) { CHECK_EQ(x->getType(), scalar_type()); - return ir_builder()->CreateVectorSplat(vector_size(), x, name()); + return b()->CreateVectorSplat(vector_size(), x, name()); } llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs, @@ -164,10 +163,9 @@ llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs, AssertCorrectTypes({lhs, rhs}); llvm::Type* int_type = IntegerTypeForFloatSize(lhs->getType() == vector_type()); - return ir_builder()->CreateBitCast( - ir_builder()->CreateAnd( - ir_builder()->CreateBitCast(lhs, int_type, name()), - ir_builder()->CreateBitCast(rhs, int_type, name()), name()), + return b()->CreateBitCast( + b()->CreateAnd(b()->CreateBitCast(lhs, int_type, name()), + b()->CreateBitCast(rhs, int_type, name()), name()), vector_type()); } @@ -175,9 +173,8 @@ llvm::Value* VectorSupportLibrary::FloatNot(llvm::Value* lhs) { AssertCorrectTypes({lhs}); llvm::Type* int_type = IntegerTypeForFloatSize(lhs->getType() == vector_type()); - return ir_builder()->CreateBitCast( - ir_builder()->CreateNot( - ir_builder()->CreateBitCast(lhs, int_type, name()), name()), + return b()->CreateBitCast( + b()->CreateNot(b()->CreateBitCast(lhs, int_type, name()), name()), vector_type()); } @@ -185,47 +182,43 @@ llvm::Value* VectorSupportLibrary::FloatOr(llvm::Value* lhs, llvm::Value* rhs) { AssertCorrectTypes({lhs, rhs}); llvm::Type* int_type = IntegerTypeForFloatSize(lhs->getType() == vector_type()); - return ir_builder()->CreateBitCast( - ir_builder()->CreateOr(ir_builder()->CreateBitCast(lhs, int_type, name()), - ir_builder()->CreateBitCast(rhs, int_type, name()), - name()), + return b()->CreateBitCast( + b()->CreateOr(b()->CreateBitCast(lhs, int_type, name()), + b()->CreateBitCast(rhs, int_type, name()), name()), vector_type(), name()); } llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs, llvm::Value* rhs) { if (scalar_type_->isFloatingPointTy()) { - return ir_builder()->CreateFAdd(lhs, rhs, name()); + return b()->CreateFAdd(lhs, rhs, name()); } else { - return ir_builder()->CreateAdd(lhs, rhs, name()); + return b()->CreateAdd(lhs, rhs, name()); } } llvm::Value* VectorSupportLibrary::ComputeOffsetPointer( llvm::Value* base_pointer, llvm::Value* offset_elements) { if (base_pointer->getType() != scalar_pointer_type()) { - base_pointer = ir_builder()->CreateBitCast(base_pointer, - scalar_pointer_type(), name()); + base_pointer = + b()->CreateBitCast(base_pointer, scalar_pointer_type(), name()); } - return ir_builder()->CreateInBoundsGEP(base_pointer, {offset_elements}, - name()); + return b()->CreateInBoundsGEP(base_pointer, {offset_elements}, name()); } llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) { if (pointer->getType() != vector_pointer_type()) { - pointer = - ir_builder()->CreateBitCast(pointer, vector_pointer_type(), name()); + pointer = b()->CreateBitCast(pointer, vector_pointer_type(), name()); } - return ir_builder()->CreateAlignedLoad( + return b()->CreateAlignedLoad( pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name()); } llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) { if (pointer->getType() != scalar_pointer_type()) { - pointer = - ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); + pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name()); } - return ir_builder()->CreateAlignedLoad( + return b()->CreateAlignedLoad( pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name()); } @@ -233,30 +226,28 @@ void VectorSupportLibrary::StoreVector(llvm::Value* value, llvm::Value* pointer) { AssertCorrectTypes({value}); if (pointer->getType() != vector_pointer_type()) { - pointer = ir_builder()->CreateBitCast(pointer, vector_pointer_type()); + pointer = b()->CreateBitCast(pointer, vector_pointer_type()); } - ir_builder()->CreateAlignedStore( - value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)); + b()->CreateAlignedStore(value, pointer, + ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)); } void VectorSupportLibrary::StoreScalar(llvm::Value* value, llvm::Value* pointer) { AssertCorrectTypes({value}); if (pointer->getType() != scalar_pointer_type()) { - pointer = - ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); + pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name()); } - ir_builder()->CreateAlignedStore( - value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)); + b()->CreateAlignedStore(value, pointer, + ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)); } llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) { if (pointer->getType() != scalar_pointer_type()) { - pointer = - ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); + pointer = b()->CreateBitCast(pointer, scalar_pointer_type(), name()); } - return ir_builder()->CreateVectorSplat( - vector_size(), ir_builder()->CreateLoad(pointer), name()); + return b()->CreateVectorSplat(vector_size(), b()->CreateLoad(pointer), + name()); } llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) { @@ -267,20 +258,19 @@ llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) { for (unsigned j = 0; j < vector_size(); ++j) { if (j < (i / 2)) { - mask[j] = ir_builder()->getInt32(i / 2 + j); + mask[j] = b()->getInt32(i / 2 + j); } else { - mask[j] = llvm::UndefValue::get(ir_builder()->getInt32Ty()); + mask[j] = llvm::UndefValue::get(b()->getInt32Ty()); } } - llvm::Value* half_remaining_lanes = ir_builder()->CreateShuffleVector( - vector, llvm::UndefValue::get(vector_type()), - llvm::ConstantVector::get(mask), ""); + llvm::Value* half_remaining_lanes = + b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()), + llvm::ConstantVector::get(mask), ""); vector = Add(vector, half_remaining_lanes); } - return ir_builder()->CreateExtractElement(vector, ir_builder()->getInt32(0), - name()); + return b()->CreateExtractElement(vector, b()->getInt32(0), name()); } llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs, @@ -307,19 +297,19 @@ llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs, // vector, which are the lanes 2 and 3 in the rhs vector. for (int i = 0; i < vector_size(); i += 2) { int increment = i < vector_size() / 2 ? 0 : (vector_size() / 2); - mask_a.push_back(ir_builder()->getInt32(increment + i)); - mask_b.push_back(ir_builder()->getInt32(increment + i + 1)); + mask_a.push_back(b()->getInt32(increment + i)); + mask_b.push_back(b()->getInt32(increment + i + 1)); } for (int i = 0; i < vector_size(); i += 2) { int increment = i < vector_size() / 2 ? (vector_size() / 2) : vector_size(); - mask_a.push_back(ir_builder()->getInt32(increment + i)); - mask_b.push_back(ir_builder()->getInt32(increment + i + 1)); + mask_a.push_back(b()->getInt32(increment + i)); + mask_b.push_back(b()->getInt32(increment + i + 1)); } - llvm::Value* shuffle_0 = ir_builder()->CreateShuffleVector( - lhs, rhs, llvm::ConstantVector::get(mask_a)); - llvm::Value* shuffle_1 = ir_builder()->CreateShuffleVector( - lhs, rhs, llvm::ConstantVector::get(mask_b)); + llvm::Value* shuffle_0 = + b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_a)); + llvm::Value* shuffle_1 = + b()->CreateShuffleVector(lhs, rhs, llvm::ConstantVector::get(mask_b)); return Add(shuffle_0, shuffle_1); } @@ -327,23 +317,21 @@ llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs, llvm::Value* VectorSupportLibrary::ExtractLowHalf(llvm::Value* vector) { llvm::SmallVector mask; for (int i = 0; i < vector_size() / 2; i++) { - mask.push_back(ir_builder()->getInt32(i)); + mask.push_back(b()->getInt32(i)); } - return ir_builder()->CreateShuffleVector(vector, - llvm::UndefValue::get(vector_type()), - llvm::ConstantVector::get(mask)); + return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()), + llvm::ConstantVector::get(mask)); } llvm::Value* VectorSupportLibrary::ExtractHighHalf(llvm::Value* vector) { llvm::SmallVector mask; for (int i = 0; i < vector_size() / 2; i++) { - mask.push_back(ir_builder()->getInt32(i + vector_size() / 2)); + mask.push_back(b()->getInt32(i + vector_size() / 2)); } - return ir_builder()->CreateShuffleVector(vector, - llvm::UndefValue::get(vector_type()), - llvm::ConstantVector::get(mask)); + return b()->CreateShuffleVector(vector, llvm::UndefValue::get(vector_type()), + llvm::ConstantVector::get(mask)); } std::vector VectorSupportLibrary::ComputeHorizontalSums( @@ -360,8 +348,8 @@ std::vector VectorSupportLibrary::ComputeHorizontalSums( [this](llvm::Value* vector) { return AddReduce(vector); }); if (init_values) { for (int64 i = 0, e = result.size(); i < e; i++) { - result[i] = Add(result[i], ir_builder()->CreateExtractElement( - init_values, ir_builder()->getInt32(i))); + result[i] = Add(result[i], + b()->CreateExtractElement(init_values, b()->getInt32(i))); } } return result; @@ -398,9 +386,9 @@ VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums( std::vector results; for (int i = 0; i < lane_width; i++) { - llvm::Value* scalar_result = ir_builder()->CreateExtractElement( - i < (lane_width / 2) ? low : high, - ir_builder()->getInt32(i % (lane_width / 2)), name()); + llvm::Value* scalar_result = + b()->CreateExtractElement(i < (lane_width / 2) ? low : high, + b()->getInt32(i % (lane_width / 2)), name()); results.push_back(scalar_result); } @@ -415,17 +403,14 @@ llvm::Value* VectorSupportLibrary::GetZeroScalar() { return llvm::Constant::getNullValue(scalar_type()); } -LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* ir_builder) - : ir_builder_(ir_builder) { - alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", ir_builder_); +LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* b) : b_(b) { + alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", b_); } -llvm::Value* LlvmVariable::Get() const { - return ir_builder_->CreateLoad(alloca_); -} +llvm::Value* LlvmVariable::Get() const { return b_->CreateLoad(alloca_); } void LlvmVariable::Set(llvm::Value* new_value) { - ir_builder_->CreateStore(new_value, alloca_); + b_->CreateStore(new_value, alloca_); } TileVariable::TileVariable(VectorSupportLibrary* vector_support, diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index 49c2a4e2f4bae9e1672b7d2fe891301bce08bd4b..c728f6df0aef83e6ddc6c932a347f14da06d9d0d 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -46,11 +46,11 @@ class VectorSupportLibrary { // instance (i.e. LoadVector will load a vector of type <`vector_size` x // `primitive_type`>). VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size, - llvm::IRBuilder<>* ir_builder, std::string name); + llvm::IRBuilder<>* b, std::string name); llvm::Value* Mul(llvm::Value* lhs, llvm::Value* rhs); llvm::Value* Mul(int64 lhs, llvm::Value* rhs) { - return Mul(ir_builder()->getInt64(lhs), rhs); + return Mul(b()->getInt64(lhs), rhs); } llvm::Value* Mul(const llvm::APFloat& lhs, llvm::Value* rhs) { return Mul(GetConstantFloat(rhs->getType(), lhs), rhs); @@ -63,7 +63,7 @@ class VectorSupportLibrary { llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs); llvm::Value* Add(int64 lhs, llvm::Value* rhs) { - return Add(ir_builder()->getInt64(lhs), rhs); + return Add(b()->getInt64(lhs), rhs); } llvm::Value* Add(const llvm::APFloat& lhs, llvm::Value* rhs) { return Add(GetConstantFloat(rhs->getType(), lhs), rhs); @@ -147,13 +147,11 @@ class VectorSupportLibrary { llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, llvm::Value* offset_elements, int64 scale) { return ComputeOffsetPointer( - base_pointer, - ir_builder_->CreateMul(ir_builder_->getInt64(scale), offset_elements)); + base_pointer, b_->CreateMul(b_->getInt64(scale), offset_elements)); } llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, int64 offset_elements) { - return ComputeOffsetPointer(base_pointer, - ir_builder()->getInt64(offset_elements)); + return ComputeOffsetPointer(base_pointer, b()->getInt64(offset_elements)); } llvm::Value* LoadVector(llvm::Value* pointer); @@ -164,7 +162,7 @@ class VectorSupportLibrary { } llvm::Value* LoadVector(llvm::Value* base_pointer, int64 offset_elements) { - return LoadVector(base_pointer, ir_builder()->getInt64(offset_elements)); + return LoadVector(base_pointer, b()->getInt64(offset_elements)); } llvm::Value* LoadScalar(llvm::Value* pointer); @@ -175,7 +173,7 @@ class VectorSupportLibrary { } llvm::Value* LoadScalar(llvm::Value* base_pointer, int64 offset_elements) { - return LoadScalar(base_pointer, ir_builder()->getInt64(offset_elements)); + return LoadScalar(base_pointer, b()->getInt64(offset_elements)); } void StoreVector(llvm::Value* value, llvm::Value* pointer); @@ -187,7 +185,7 @@ class VectorSupportLibrary { void StoreVector(llvm::Value* value, llvm::Value* base_pointer, int64 offset_elements) { - StoreVector(value, base_pointer, ir_builder()->getInt64(offset_elements)); + StoreVector(value, base_pointer, b()->getInt64(offset_elements)); } void StoreScalar(llvm::Value* value, llvm::Value* pointer); @@ -198,7 +196,7 @@ class VectorSupportLibrary { void StoreScalar(llvm::Value* value, llvm::Value* base_pointer, int64 offset_elements) { - StoreScalar(base_pointer, ir_builder()->getInt64(offset_elements)); + StoreScalar(base_pointer, b()->getInt64(offset_elements)); } llvm::Value* LoadBroadcast(llvm::Value* pointer); @@ -207,7 +205,7 @@ class VectorSupportLibrary { return LoadBroadcast(ComputeOffsetPointer(base_pointer, offset_elements)); } llvm::Value* LoadBroadcast(llvm::Value* base_pointer, int64 offset_elements) { - return LoadBroadcast(base_pointer, ir_builder()->getInt64(offset_elements)); + return LoadBroadcast(base_pointer, b()->getInt64(offset_elements)); } // Compute the horizontal sum of each vector in `vectors`. The i'th element @@ -220,7 +218,7 @@ class VectorSupportLibrary { llvm::Value* GetZeroVector(); llvm::Value* GetZeroScalar(); - llvm::IRBuilder<>* ir_builder() const { return ir_builder_; } + llvm::IRBuilder<>* b() const { return b_; } int64 vector_size() const { return vector_size_; } llvm::Type* vector_type() const { return vector_type_; } llvm::Type* vector_pointer_type() const { return vector_pointer_type_; } @@ -277,7 +275,7 @@ class VectorSupportLibrary { int64 vector_size_; PrimitiveType primitive_type_; - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; llvm::Type* vector_type_; llvm::Type* vector_pointer_type_; llvm::Type* scalar_type_; @@ -289,22 +287,21 @@ class VectorSupportLibrary { // can later convert to a SSA value. class LlvmVariable { public: - LlvmVariable(llvm::Type*, llvm::IRBuilder<>* ir_builder); + LlvmVariable(llvm::Type*, llvm::IRBuilder<>* b); llvm::Value* Get() const; void Set(llvm::Value* new_value); private: llvm::AllocaInst* alloca_; - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; }; class VectorVariable : public LlvmVariable { public: VectorVariable(VectorSupportLibrary* vector_support, llvm::Value* initial_value) - : LlvmVariable(vector_support->vector_type(), - vector_support->ir_builder()) { + : LlvmVariable(vector_support->vector_type(), vector_support->b()) { Set(initial_value); } }; @@ -313,8 +310,7 @@ class ScalarVariable : public LlvmVariable { public: ScalarVariable(VectorSupportLibrary* vector_support, llvm::Value* initial_value) - : LlvmVariable(vector_support->scalar_type(), - vector_support->ir_builder()) { + : LlvmVariable(vector_support->scalar_type(), vector_support->b()) { Set(initial_value); } }; diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc index 32b5c5d35fae61ae6cb17fafcada1abd6c3c088c..e727ba49cb6321e499b5d50d5f45e7f7f6bb6fef 100644 --- a/tensorflow/compiler/xla/service/defuser_test.cc +++ b/tensorflow/compiler/xla/service/defuser_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/defuser.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" @@ -124,7 +124,7 @@ TEST_F(DefuserTest, NonTrivialFusionInstruction) { auto div = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3)); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto add2 = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div)); @@ -162,7 +162,7 @@ TEST_F(DefuserTest, MultipleFusionInstructions) { auto div = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3)); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto add2 = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div)); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index cb3676c5ba9b55ef4cb46dbd97f84ea9a6a6c5d0..86d57581f84920e8005e8f3c420e7488fc095434 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" @@ -76,6 +76,7 @@ class DfsHloVisitorBase { virtual Status HandleClamp(HloInstructionPtr hlo) = 0; virtual Status HandleSelect(HloInstructionPtr hlo) = 0; + virtual Status HandleTupleSelect(HloInstructionPtr hlo) = 0; virtual Status HandleMaximum(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } @@ -105,6 +106,7 @@ class DfsHloVisitorBase { virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; virtual Status HandleFft(HloInstructionPtr fft) = 0; virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; + virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } @@ -211,6 +213,7 @@ class DfsHloVisitorBase { virtual Status HandleReverse(HloInstructionPtr hlo) = 0; virtual Status HandleSort(HloInstructionPtr hlo) = 0; virtual Status HandleConstant(HloInstructionPtr hlo) = 0; + virtual Status HandleIota(HloInstructionPtr hlo) = 0; virtual Status HandleGetTupleElement(HloInstructionPtr hlo) = 0; virtual Status HandleReduce(HloInstructionPtr hlo) = 0; virtual Status HandleBitcast(HloInstructionPtr hlo) = 0; @@ -231,6 +234,7 @@ class DfsHloVisitorBase { virtual Status HandleWhile(HloInstructionPtr hlo) = 0; virtual Status HandleConditional(HloInstructionPtr hlo) = 0; virtual Status HandleGather(HloInstructionPtr hlo) = 0; + virtual Status HandleScatter(HloInstructionPtr hlo) = 0; virtual Status HandlePad(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 987c91e5ba3eb01a7535d162cbcf6441d568adae..617a5a2eb4796d8003099e39e3d26389e532e954 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" @@ -79,6 +79,9 @@ class DfsHloVisitorWithDefaultBase Status HandleSelect(HloInstructionPtr select) override { return DefaultAction(select); } + Status HandleTupleSelect(HloInstructionPtr tuple_select) override { + return DefaultAction(tuple_select); + } Status HandleDot(HloInstructionPtr dot) override { return DefaultAction(dot); } @@ -91,6 +94,9 @@ class DfsHloVisitorWithDefaultBase Status HandleCrossReplicaSum(HloInstructionPtr crs) override { return DefaultAction(crs); } + Status HandleAllToAll(HloInstructionPtr crs) override { + return DefaultAction(crs); + } Status HandleRng(HloInstructionPtr random) override { return DefaultAction(random); } @@ -112,6 +118,9 @@ class DfsHloVisitorWithDefaultBase Status HandleConstant(HloInstructionPtr constant) override { return DefaultAction(constant); } + Status HandleIota(HloInstructionPtr iota) override { + return DefaultAction(iota); + } Status HandleGetTupleElement(HloInstructionPtr get_tuple_element) override { return DefaultAction(get_tuple_element); } @@ -188,6 +197,9 @@ class DfsHloVisitorWithDefaultBase Status HandleGather(HloInstructionPtr gather) override { return DefaultAction(gather); } + Status HandleScatter(HloInstructionPtr scatter) override { + return DefaultAction(scatter); + } Status HandleAfterAll(HloInstructionPtr token) override { return DefaultAction(token); } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index ce0951bbe1873973c7b97055aba5ba71a14ad24f..6aab317ca5b89fec1a01f92295a41c9fc26ccee1 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -61,13 +61,13 @@ int64 GlobalRandomValue() { llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits, int64 mantissa_bits, - llvm::IRBuilder<>* ir_builder) { + llvm::IRBuilder<>* b) { // Integer and float types for casting and constant generation. llvm::Type* float_type = x->getType(); - llvm::IntegerType* int_type = ir_builder->getInt32Ty(); + llvm::IntegerType* int_type = b->getInt32Ty(); // Cast the input value to an integer for bitwise manipulation. - llvm::Value* x_as_int = ir_builder->CreateBitCast(x, int_type); + llvm::Value* x_as_int = b->CreateBitCast(x, int_type); if (mantissa_bits < 23) { // Last remaining mantissa bit. @@ -77,22 +77,22 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits, // equal to a base value of 0111... plus one bit if the last remaining // mantissa bit is 1. const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1; - llvm::Value* x_last_mantissa_bit = ir_builder->CreateLShr( - ir_builder->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)), + llvm::Value* x_last_mantissa_bit = b->CreateLShr( + b->CreateAnd(x_as_int, + llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)), (23 - mantissa_bits)); - llvm::Value* x_rounding_bias = ir_builder->CreateAdd( - x_last_mantissa_bit, - llvm::ConstantInt::get(int_type, base_rounding_bias)); + llvm::Value* x_rounding_bias = + b->CreateAdd(x_last_mantissa_bit, + llvm::ConstantInt::get(int_type, base_rounding_bias)); // Add rounding bias, and mask out truncated bits. Note that the case // where adding the rounding bias overflows into the exponent bits is // correct; the non-masked mantissa bits will all be zero, and the // exponent will be incremented by one. const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); - x_as_int = ir_builder->CreateAdd(x_as_int, x_rounding_bias); - x_as_int = ir_builder->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, truncation_mask)); + x_as_int = b->CreateAdd(x_as_int, x_rounding_bias); + x_as_int = b->CreateAnd(x_as_int, + llvm::ConstantInt::get(int_type, truncation_mask)); } if (exponent_bits < 8) { @@ -120,29 +120,29 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits, f32_exponent_bias - reduced_exponent_bias; // Do we overflow or underflow? - llvm::Value* x_exponent = ir_builder->CreateAnd( + llvm::Value* x_exponent = b->CreateAnd( x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); - llvm::Value* x_overflows = ir_builder->CreateICmpUGT( + llvm::Value* x_overflows = b->CreateICmpUGT( x_exponent, llvm::ConstantInt::get(int_type, reduced_max_exponent << 23)); - llvm::Value* x_underflows = ir_builder->CreateICmpULE( + llvm::Value* x_underflows = b->CreateICmpULE( x_exponent, llvm::ConstantInt::get(int_type, reduced_min_exponent << 23)); // Compute appropriately-signed values of zero and infinity. - llvm::Value* x_signed_zero = ir_builder->CreateAnd( + llvm::Value* x_signed_zero = b->CreateAnd( x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask)); - llvm::Value* x_signed_inf = ir_builder->CreateOr( + llvm::Value* x_signed_inf = b->CreateOr( x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); // Force to zero or infinity if overflow or underflow. (Note that this // truncates all denormal values to zero, rather than rounding them.) - x_as_int = ir_builder->CreateSelect(x_overflows, x_signed_inf, x_as_int); - x_as_int = ir_builder->CreateSelect(x_underflows, x_signed_zero, x_as_int); + x_as_int = b->CreateSelect(x_overflows, x_signed_inf, x_as_int); + x_as_int = b->CreateSelect(x_underflows, x_signed_zero, x_as_int); } // Cast the result back to a floating-point type. - llvm::Value* result = ir_builder->CreateBitCast(x_as_int, float_type); + llvm::Value* result = b->CreateBitCast(x_as_int, float_type); // Correct result for NaN inputs. // @@ -154,53 +154,49 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits, // // If the fast-math flags are set to assume no NaNs, the comparison is likely // to be optimized away, so there's no point in even emitting it. - if (!ir_builder->getFastMathFlags().noNaNs()) { - llvm::Value* x_is_nan = ir_builder->CreateFCmpUNO(x, x); + if (!b->getFastMathFlags().noNaNs()) { + llvm::Value* x_is_nan = b->CreateFCmpUNO(x, x); if (mantissa_bits > 0) { - result = ir_builder->CreateSelect(x_is_nan, x, result); + result = b->CreateSelect(x_is_nan, x, result); } else { - result = ir_builder->CreateSelect( + result = b->CreateSelect( x_is_nan, llvm::ConstantFP::getInfinity(float_type), result); } } return result; } -llvm::Value* EmitF32ToBF16(llvm::Value* f32_value, - llvm::IRBuilder<>* ir_builder) { +llvm::Value* EmitF32ToBF16(llvm::Value* f32_value, llvm::IRBuilder<>* b) { auto reduced_precision = EmitReducePrecisionFloat( f32_value, /*exponent_bits=*/primitive_util::kBFloat16ExponentBits, - /*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, ir_builder); - auto as_int32 = - ir_builder->CreateBitCast(reduced_precision, ir_builder->getInt32Ty()); - auto shifted = ir_builder->CreateLShr(as_int32, 16); - auto truncated = ir_builder->CreateTrunc(shifted, ir_builder->getInt16Ty()); - return ir_builder->CreateBitCast(truncated, ir_builder->getInt16Ty()); + /*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, b); + auto as_int32 = b->CreateBitCast(reduced_precision, b->getInt32Ty()); + auto shifted = b->CreateLShr(as_int32, 16); + auto truncated = b->CreateTrunc(shifted, b->getInt16Ty()); + return b->CreateBitCast(truncated, b->getInt16Ty()); } -llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value, - llvm::IRBuilder<>* ir_builder) { - auto as_int16 = - ir_builder->CreateBitCast(bf16_value, ir_builder->getInt16Ty()); - auto as_int32 = ir_builder->CreateZExt(as_int16, ir_builder->getInt32Ty()); - auto shifted = ir_builder->CreateShl(as_int32, 16); - return ir_builder->CreateBitCast(shifted, ir_builder->getFloatTy()); +llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value, llvm::IRBuilder<>* b) { + auto as_int16 = b->CreateBitCast(bf16_value, b->getInt16Ty()); + auto as_int32 = b->CreateZExt(as_int16, b->getInt32Ty()); + auto shifted = b->CreateShl(as_int32, 16); + return b->CreateBitCast(shifted, b->getFloatTy()); } llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, PrimitiveType from_type, PrimitiveType to_type, llvm::Module* module, - llvm::IRBuilder<>* ir_builder) { + llvm::IRBuilder<>* b) { if (primitive_util::IsSignedIntegralType(from_type)) { - return ir_builder->CreateSIToFP( - integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module)); + return b->CreateSIToFP(integer_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module)); } else { CHECK(primitive_util::IsUnsignedIntegralType(from_type) || from_type == PRED); - return ir_builder->CreateUIToFP( - integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module)); + return b->CreateUIToFP(integer_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module)); } } @@ -226,39 +222,43 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); PrimitiveType to_type = op->shape().element_type(); - CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED); + CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED) + << from_type; if (from_type == to_type) { return operand_value; } + if (to_type == PRED) { + return b_->CreateZExt( + b_->CreateICmpNE(operand_value, llvm::ConstantInt::get( + operand_value->getType(), 0)), + llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + } if (primitive_util::IsIntegralType(to_type)) { - return ir_builder_->CreateIntCast( + return b_->CreateIntCast( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_), primitive_util::IsSignedIntegralType(from_type)); } if (primitive_util::IsFloatingPointType(to_type)) { if (to_type == BF16) { - return EmitF32ToBF16( - EmitIntegralToFloating(operand_value, from_type, F32, module_, - ir_builder_), - ir_builder_); + return EmitF32ToBF16(EmitIntegralToFloating(operand_value, from_type, + F32, module_, b_), + b_); } return EmitIntegralToFloating(operand_value, from_type, to_type, - module_, ir_builder_); + module_, b_); } if (primitive_util::IsComplexType(to_type)) { auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType( primitive_util::ComplexComponentType(to_type), module_); if (primitive_util::IsSignedIntegralType(from_type)) { return EmitComposeComplex( - op, - ir_builder_->CreateSIToFP(operand_value, to_ir_component_type), + op, b_->CreateSIToFP(operand_value, to_ir_component_type), nullptr); } if (primitive_util::IsUnsignedIntegralType(from_type) || from_type == PRED) { return EmitComposeComplex( - op, - ir_builder_->CreateUIToFP(operand_value, to_ir_component_type), + op, b_->CreateUIToFP(operand_value, to_ir_component_type), nullptr); } } @@ -275,7 +275,7 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( } if (primitive_util::BitWidth(from_type) == primitive_util::BitWidth(to_type)) { - return ir_builder_->CreateBitCast( + return b_->CreateBitCast( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return InvalidArgument( @@ -293,18 +293,18 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); auto zero = llvm::ConstantInt::get(type, 0); - auto cmp = ir_builder_->CreateICmpSGE(operand_value, zero); - return ir_builder_->CreateSelect(cmp, operand_value, - ir_builder_->CreateNeg(operand_value)); + auto cmp = b_->CreateICmpSGE(operand_value, zero); + return b_->CreateSelect(cmp, operand_value, + b_->CreateNeg(operand_value)); } else { return operand_value; } } case HloOpcode::kClz: { - auto is_zero_undef = ir_builder_->getFalse(); - return llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::ctlz, {operand_value, is_zero_undef}, - {operand_value->getType()}, ir_builder_); + auto is_zero_undef = b_->getFalse(); + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctlz, + {operand_value, is_zero_undef}, + {operand_value->getType()}, b_); } case HloOpcode::kSign: { bool is_signed = @@ -312,31 +312,28 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); auto zero = llvm::ConstantInt::get(type, 0); - auto cmp = ir_builder_->CreateICmpEQ(operand_value, zero); + auto cmp = b_->CreateICmpEQ(operand_value, zero); if (is_signed) { - auto ashr = ir_builder_->CreateAShr(operand_value, - type->getIntegerBitWidth() - 1); - return ir_builder_->CreateSelect(cmp, zero, - ir_builder_->CreateOr(ashr, 1)); + auto ashr = + b_->CreateAShr(operand_value, type->getIntegerBitWidth() - 1); + return b_->CreateSelect(cmp, zero, b_->CreateOr(ashr, 1)); } else { - return ir_builder_->CreateSelect(cmp, zero, - llvm::ConstantInt::get(type, 1)); + return b_->CreateSelect(cmp, zero, llvm::ConstantInt::get(type, 1)); } } case HloOpcode::kNegate: - return ir_builder_->CreateNeg(operand_value); + return b_->CreateNeg(operand_value); case HloOpcode::kNot: { auto type = op->shape().element_type(); if (type == PRED) { // It is not sufficient to just call CreateNot() here because a PRED // is represented as an i8 and the truth value is stored only in the // bottom bit. - return ir_builder_->CreateZExt( - ir_builder_->CreateNot(ir_builder_->CreateTrunc( - operand_value, ir_builder_->getInt1Ty())), + return b_->CreateZExt( + b_->CreateNot(b_->CreateTrunc(operand_value, b_->getInt1Ty())), llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } else if (primitive_util::IsIntegralType(type)) { - return ir_builder_->CreateNot(operand_value); + return b_->CreateNot(operand_value); } return Unimplemented("unary op Not is not defined for type '%d'", type); } @@ -352,7 +349,7 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); PrimitiveType to_type = op->shape().element_type(); - CHECK(primitive_util::IsFloatingPointType(from_type)); + CHECK(primitive_util::IsFloatingPointType(from_type)) << from_type; if (from_type == to_type) { return operand_value; } @@ -364,32 +361,38 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } return EmitComposeComplex( op, - ir_builder_->CreateFPCast( - operand_value, - llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)), + b_->CreateFPCast(operand_value, llvm_ir::PrimitiveTypeToIrType( + to_component_type, module_)), nullptr); } if (from_type == BF16) { TF_RET_CHECK(to_type != BF16); - operand_value = EmitBF16ToF32(operand_value, ir_builder_); + operand_value = EmitBF16ToF32(operand_value, b_); from_type = F32; if (from_type == to_type) { return operand_value; } } if (from_type == F32 && to_type == BF16) { - return EmitF32ToBF16(operand_value, ir_builder_); + return EmitF32ToBF16(operand_value, b_); + } + if (to_type == PRED) { + return b_->CreateZExt( + b_->CreateFCmpUNE( + operand_value, + llvm::ConstantFP::get(operand_value->getType(), 0.0)), + llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } if (primitive_util::IsFloatingPointType(to_type)) { - return ir_builder_->CreateFPCast( + return b_->CreateFPCast( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } if (primitive_util::IsSignedIntegralType(to_type)) { - return ir_builder_->CreateFPToSI( + return b_->CreateFPToSI( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } if (primitive_util::IsUnsignedIntegralType(to_type)) { - return ir_builder_->CreateFPToUI( + return b_->CreateFPToUI( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return Unimplemented("unhandled conversion operation: %s => %s", @@ -405,7 +408,7 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } if (primitive_util::BitWidth(from_type) == primitive_util::BitWidth(to_type)) { - return ir_builder_->CreateBitCast( + return b_->CreateBitCast( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return InvalidArgument( @@ -429,45 +432,49 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( case HloOpcode::kSin: return EmitSin(op->shape().element_type(), operand_value); case HloOpcode::kFloor: - return llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::floor, {operand_value}, {operand_value->getType()}, - ir_builder_); + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, + {operand_value}, + {operand_value->getType()}, b_); case HloOpcode::kCeil: - return llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::ceil, {operand_value}, {operand_value->getType()}, - ir_builder_); + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ceil, + {operand_value}, + {operand_value->getType()}, b_); case HloOpcode::kAbs: - return llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::fabs, {operand_value}, {operand_value->getType()}, - ir_builder_); + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, + {operand_value}, + {operand_value->getType()}, b_); case HloOpcode::kRoundNearestAfz: - return llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::round, {operand_value}, {operand_value->getType()}, - ir_builder_); + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::round, + {operand_value}, + {operand_value->getType()}, b_); case HloOpcode::kSign: { // TODO(b/32151903): Ensure consistent sign behavior for -0.0. auto type = operand_value->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = ir_builder_->CreateFCmpOEQ(operand_value, zero); - auto olt = ir_builder_->CreateFCmpOLT(operand_value, zero); - return ir_builder_->CreateSelect( + auto oeq = b_->CreateFCmpOEQ(operand_value, zero); + auto olt = b_->CreateFCmpOLT(operand_value, zero); + return b_->CreateSelect( oeq, zero, - ir_builder_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0), - llvm::ConstantFP::get(type, 1.0))); + b_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0), + llvm::ConstantFP::get(type, 1.0))); } case HloOpcode::kIsFinite: { // abs(x) o!= inf, this works because the comparison returns false if // either operand is NaN. auto type = operand_value->getType(); auto abs_value = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::fabs, {operand_value}, {type}, ir_builder_); + llvm::Intrinsic::fabs, {operand_value}, {type}, b_); auto infinity = llvm::ConstantFP::getInfinity(type); - auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity); - return ir_builder_->CreateZExt( - not_infinite, llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + auto not_infinite = b_->CreateFCmpONE(abs_value, infinity); + return b_->CreateZExt(not_infinite, + llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } case HloOpcode::kNegate: - return ir_builder_->CreateFNeg(operand_value); + return b_->CreateFNeg(operand_value); + case HloOpcode::kReal: + return operand_value; + case HloOpcode::kImag: + return llvm::ConstantFP::get(operand_value->getType(), 0.0); default: return Unimplemented("unary floating-point op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -487,13 +494,12 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto a = EmitExtractReal(operand_value); auto b = EmitExtractImag(operand_value); llvm::Type* llvm_ty = a->getType(); - auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), - ir_builder_->CreateFMul(b, b)); + auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a, a), b_->CreateFMul(b, b)); TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a)); auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex( - op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); + return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq), + angle); } case HloOpcode::kLog1p: { // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1) @@ -501,15 +507,14 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto b = EmitExtractImag(operand_value); llvm::Type* llvm_ty = a->getType(); auto one = llvm::ConstantFP::get(llvm_ty, 1.0); - auto a_plus_one = ir_builder_->CreateFAdd(a, one); - auto sum_sq = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(a_plus_one, a_plus_one), - ir_builder_->CreateFMul(b, b)); + auto a_plus_one = b_->CreateFAdd(a, one); + auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a_plus_one, a_plus_one), + b_->CreateFMul(b, b)); TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one)); auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex( - op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); + return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq), + angle); } case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -523,12 +528,11 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( primitive_util::ComplexComponentType(to_type); auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(to_component_type, module_); - return EmitComposeComplex( - op, - ir_builder_->CreateFPCast(EmitExtractReal(operand_value), - to_ir_component_type), - ir_builder_->CreateFPCast(EmitExtractImag(operand_value), - to_ir_component_type)); + return EmitComposeComplex(op, + b_->CreateFPCast(EmitExtractReal(operand_value), + to_ir_component_type), + b_->CreateFPCast(EmitExtractImag(operand_value), + to_ir_component_type)); } case HloOpcode::kExp: { // e^(a+bi) = e^a*(cos(b)+sin(b)i) @@ -538,8 +542,8 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value))); TF_ASSIGN_OR_RETURN( auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); - return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), - ir_builder_->CreateFMul(exp_a, sin_b)); + return EmitComposeComplex(op, b_->CreateFMul(exp_a, cos_b), + b_->CreateFMul(exp_a, sin_b)); } case HloOpcode::kExpm1: { // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i @@ -550,9 +554,8 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( TF_ASSIGN_OR_RETURN( auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0); - auto real_result = - ir_builder_->CreateFSub(ir_builder_->CreateFMul(exp_a, cos_b), one); - auto imag_result = ir_builder_->CreateFMul(exp_a, sin_b); + auto real_result = b_->CreateFSub(b_->CreateFMul(exp_a, cos_b), one); + auto imag_result = b_->CreateFMul(exp_a, sin_b); return EmitComposeComplex(op, real_result, imag_result); } case HloOpcode::kCos: { @@ -567,18 +570,14 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto b = EmitExtractImag(operand_value); auto type = a->getType(); TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b)); - auto half_exp_b = - ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); auto half_exp_neg_b = - ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); + b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); return EmitComposeComplex( - op, - ir_builder_->CreateFMul( - cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)), - ir_builder_->CreateFMul( - sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b))); + op, b_->CreateFMul(cos_a, b_->CreateFAdd(half_exp_neg_b, half_exp_b)), + b_->CreateFMul(sin_a, b_->CreateFSub(half_exp_neg_b, half_exp_b))); } case HloOpcode::kSin: { // sin(z) = .5i(e^(-iz) - e^(iz)) @@ -594,18 +593,14 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto b = EmitExtractImag(operand_value); auto type = a->getType(); TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b)); - auto half_exp_b = - ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); auto half_exp_neg_b = - ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); + b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); return EmitComposeComplex( - op, - ir_builder_->CreateFMul( - sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)), - ir_builder_->CreateFMul( - cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b))); + op, b_->CreateFMul(sin_a, b_->CreateFAdd(half_exp_b, half_exp_neg_b)), + b_->CreateFMul(cos_a, b_->CreateFSub(half_exp_b, half_exp_neg_b))); } case HloOpcode::kTanh: { /* @@ -633,64 +628,61 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a)); TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b)); TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b)); - auto exp_neg_a = ir_builder_->CreateFDiv( - llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); - auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub( - ir_builder_->CreateFMul(exp_a, exp_a), - ir_builder_->CreateFMul(exp_neg_a, exp_neg_a)); - auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b); - auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b); - auto real_num = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a), - ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); - auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b); - auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a); + auto exp_neg_a = + b_->CreateFDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); + auto exp_2a_minus_exp_neg_2a = b_->CreateFSub( + b_->CreateFMul(exp_a, exp_a), b_->CreateFMul(exp_neg_a, exp_neg_a)); + auto cos_b_sq = b_->CreateFMul(cos_b, cos_b); + auto sin_b_sq = b_->CreateFMul(sin_b, sin_b); + auto real_num = + b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a), + b_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); + auto cos_b_sin_b = b_->CreateFMul(cos_b, sin_b); + auto exp_a_plus_exp_neg_a = b_->CreateFAdd(exp_a, exp_neg_a); auto exp_a_plus_exp_neg_a_sq = - ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); - auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a); + b_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); + auto exp_a_minus_exp_neg_a = b_->CreateFSub(exp_a, exp_neg_a); auto exp_a_minus_exp_neg_a_sq = - ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); - auto imag_num = ir_builder_->CreateFMul( - cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq, - exp_a_minus_exp_neg_a_sq)); - auto denom = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), - ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); - return EmitComposeComplex(op, ir_builder_->CreateFDiv(real_num, denom), - ir_builder_->CreateFDiv(imag_num, denom)); + b_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); + auto imag_num = b_->CreateFMul( + cos_b_sin_b, + b_->CreateFSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq)); + auto denom = + b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), + b_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); + return EmitComposeComplex(op, b_->CreateFDiv(real_num, denom), + b_->CreateFDiv(imag_num, denom)); } case HloOpcode::kAbs: { - auto sum_sq = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(EmitExtractReal(operand_value), - EmitExtractReal(operand_value)), - ir_builder_->CreateFMul(EmitExtractImag(operand_value), - EmitExtractImag(operand_value))); + auto sum_sq = + b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value), + EmitExtractReal(operand_value)), + b_->CreateFMul(EmitExtractImag(operand_value), + EmitExtractImag(operand_value))); return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq}, - {sum_sq->getType()}, ir_builder_); + {sum_sq->getType()}, b_); } case HloOpcode::kSign: { // Sign(c) = c / |c| - auto sum_sq = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(EmitExtractReal(operand_value), - EmitExtractReal(operand_value)), - ir_builder_->CreateFMul(EmitExtractImag(operand_value), - EmitExtractImag(operand_value))); + auto sum_sq = + b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value), + EmitExtractReal(operand_value)), + b_->CreateFMul(EmitExtractImag(operand_value), + EmitExtractImag(operand_value))); auto cplx_abs = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, ir_builder_); + llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, b_); auto type = cplx_abs->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = ir_builder_->CreateFCmpOEQ(cplx_abs, zero); - return ir_builder_->CreateSelect( + auto oeq = b_->CreateFCmpOEQ(cplx_abs, zero); + return b_->CreateSelect( oeq, EmitComposeComplex(op, zero, zero), EmitComposeComplex( - op, - ir_builder_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs), - ir_builder_->CreateFDiv(EmitExtractImag(operand_value), - cplx_abs))); + op, b_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs), + b_->CreateFDiv(EmitExtractImag(operand_value), cplx_abs))); } case HloOpcode::kNegate: - return EmitComposeComplex( - op, ir_builder_->CreateFNeg(EmitExtractReal(operand_value)), - ir_builder_->CreateFNeg(EmitExtractImag(operand_value))); + return EmitComposeComplex(op, + b_->CreateFNeg(EmitExtractReal(operand_value)), + b_->CreateFNeg(EmitExtractImag(operand_value))); case HloOpcode::kReal: return EmitExtractReal(operand_value); case HloOpcode::kImag: @@ -724,15 +716,15 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( case HloOpcode::kComplex: return EmitComposeComplex(op, lhs_value, rhs_value); case HloOpcode::kAdd: - return ir_builder_->CreateFAdd(lhs_value, rhs_value); + return b_->CreateFAdd(lhs_value, rhs_value); case HloOpcode::kSubtract: - return ir_builder_->CreateFSub(lhs_value, rhs_value); + return b_->CreateFSub(lhs_value, rhs_value); case HloOpcode::kMultiply: - return ir_builder_->CreateFMul(lhs_value, rhs_value); + return b_->CreateFMul(lhs_value, rhs_value); case HloOpcode::kDivide: - return ir_builder_->CreateFDiv(lhs_value, rhs_value); + return b_->CreateFDiv(lhs_value, rhs_value); case HloOpcode::kRemainder: - return ir_builder_->CreateFRem(lhs_value, rhs_value); + return b_->CreateFRem(lhs_value, rhs_value); // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered // comparisons always return false when one of the operands is NaN, whereas // unordered comparisons return true. @@ -742,22 +734,22 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( // matches C++'s semantics. case HloOpcode::kEq: return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value, - rhs_value, ir_builder_); + rhs_value, b_); case HloOpcode::kNe: return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value, - rhs_value, ir_builder_); + rhs_value, b_); case HloOpcode::kLt: return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value, - rhs_value, ir_builder_); + rhs_value, b_); case HloOpcode::kGt: return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value, - rhs_value, ir_builder_); + rhs_value, b_); case HloOpcode::kLe: return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value, - rhs_value, ir_builder_); + rhs_value, b_); case HloOpcode::kGe: return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value, - rhs_value, ir_builder_); + rhs_value, b_); case HloOpcode::kMaximum: return EmitFloatMax(lhs_value, rhs_value); @@ -778,64 +770,56 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( llvm::Value* rhs_value) const { switch (op->opcode()) { case HloOpcode::kAdd: - return EmitComposeComplex( - op, - ir_builder_->CreateFAdd(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - ir_builder_->CreateFAdd(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))); + return EmitComposeComplex(op, + b_->CreateFAdd(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + b_->CreateFAdd(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))); case HloOpcode::kSubtract: - return EmitComposeComplex( - op, - ir_builder_->CreateFSub(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - ir_builder_->CreateFSub(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))); + return EmitComposeComplex(op, + b_->CreateFSub(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + b_->CreateFSub(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))); case HloOpcode::kMultiply: return EmitComposeComplex( op, - ir_builder_->CreateFSub( - ir_builder_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - ir_builder_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))), - ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractImag(rhs_value)), - ir_builder_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractReal(rhs_value)))); + b_->CreateFSub(b_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + b_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))), + b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value)), + b_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value)))); case HloOpcode::kDivide: { // (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di)) // = ((ac + bd) + (bc - ad)i) / (c^2 + d^2) - auto rhs_sum_sq = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(EmitExtractReal(rhs_value), - EmitExtractReal(rhs_value)), - ir_builder_->CreateFMul(EmitExtractImag(rhs_value), - EmitExtractImag(rhs_value))); + auto rhs_sum_sq = + b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(rhs_value), + EmitExtractReal(rhs_value)), + b_->CreateFMul(EmitExtractImag(rhs_value), + EmitExtractImag(rhs_value))); auto type = rhs_sum_sq->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = ir_builder_->CreateFCmpOEQ(rhs_sum_sq, zero); - auto real_inf_or_nan = - ir_builder_->CreateFDiv(EmitExtractReal(lhs_value), zero); - auto imag_inf_or_nan = - ir_builder_->CreateFDiv(EmitExtractImag(lhs_value), zero); - return ir_builder_->CreateSelect( + auto oeq = b_->CreateFCmpOEQ(rhs_sum_sq, zero); + auto real_inf_or_nan = b_->CreateFDiv(EmitExtractReal(lhs_value), zero); + auto imag_inf_or_nan = b_->CreateFDiv(EmitExtractImag(lhs_value), zero); + return b_->CreateSelect( oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan), EmitComposeComplex( op, - ir_builder_->CreateFDiv( - ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - ir_builder_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))), + b_->CreateFDiv( + b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + b_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))), rhs_sum_sq), - ir_builder_->CreateFDiv( - ir_builder_->CreateFSub( - ir_builder_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractReal(rhs_value)), - ir_builder_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractImag(rhs_value))), + b_->CreateFDiv( + b_->CreateFSub(b_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value)), + b_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value))), rhs_sum_sq))); } // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered @@ -846,21 +830,21 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( // unordered comparison. This makes x != y equivalent to !(x == y), and // matches C++'s semantics. case HloOpcode::kEq: - return ir_builder_->CreateAnd( + return b_->CreateAnd( llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value), ir_builder_), + EmitExtractReal(rhs_value), b_), llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value), ir_builder_)); + EmitExtractImag(rhs_value), b_)); case HloOpcode::kNe: - return ir_builder_->CreateOr( + return b_->CreateOr( llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value), ir_builder_), + EmitExtractReal(rhs_value), b_), llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value), ir_builder_)); + EmitExtractImag(rhs_value), b_)); case HloOpcode::kPower: { // (a+bi)^(c+di) = @@ -872,29 +856,26 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( auto b = EmitExtractImag(lhs_value); auto c = EmitExtractReal(rhs_value); auto d = EmitExtractImag(rhs_value); - auto aa_p_bb = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), - ir_builder_->CreateFMul(b, b)); + auto aa_p_bb = b_->CreateFAdd(b_->CreateFMul(a, a), b_->CreateFMul(b, b)); auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); - auto half_c = ir_builder_->CreateFMul(one_half, c); + auto half_c = b_->CreateFMul(one_half, c); TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c, EmitPow(component_type, aa_p_bb, half_c)); - auto neg_d = ir_builder_->CreateFNeg(d); + auto neg_d = b_->CreateFNeg(d); TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a)); - auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs); + auto neg_d_arg_lhs = b_->CreateFMul(neg_d, arg_lhs); TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs, EmitExp(component_type, neg_d_arg_lhs)); - auto coeff = - ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); + auto coeff = b_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb)); - auto half_d = ir_builder_->CreateFMul(one_half, d); - auto q = - ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs), - ir_builder_->CreateFMul(half_d, ln_aa_p_bb)); + auto half_d = b_->CreateFMul(one_half, d); + auto q = b_->CreateFAdd(b_->CreateFMul(c, arg_lhs), + b_->CreateFMul(half_d, ln_aa_p_bb)); TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q)); TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q)); - return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q), - ir_builder_->CreateFMul(coeff, sin_q)); + return EmitComposeComplex(op, b_->CreateFMul(coeff, cos_q), + b_->CreateFMul(coeff, sin_q)); } default: return Unimplemented("binary complex op '%s'", @@ -904,12 +885,12 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value) const { - return llvm_ir::EmitFloatMax(lhs_value, rhs_value, ir_builder_); + return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_); } llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value) const { - return llvm_ir::EmitFloatMin(lhs_value, rhs_value, ir_builder_); + return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_); } StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, @@ -921,15 +902,14 @@ StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, "type F32."); } auto getFloat = [&](const float f) { - return llvm::ConstantFP::get(ir_builder_->getFloatTy(), f); + return llvm::ConstantFP::get(b_->getFloatTy(), f); }; auto multiply_add = [&](tensorflow::gtl::ArraySlice coefficients, llvm::Value* w) { llvm::Value* p = getFloat(coefficients.front()); coefficients.pop_front(); for (float coefficient : coefficients) { - p = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(p, w), - getFloat(coefficient)); + p = b_->CreateFAdd(b_->CreateFMul(p, w), getFloat(coefficient)); } return p; }; @@ -947,50 +927,48 @@ StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, // } // return p*x llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration( - module_, llvm::Intrinsic::log, {ir_builder_->getFloatTy()}); + module_, llvm::Intrinsic::log, {b_->getFloatTy()}); - llvm::Value* w = ir_builder_->CreateFNeg(ir_builder_->CreateCall( - logf_fn, - {ir_builder_->CreateFMul(ir_builder_->CreateFSub(getFloat(1.0f), x), - ir_builder_->CreateFAdd(getFloat(1.0f), x))})); + llvm::Value* w = b_->CreateFNeg(b_->CreateCall( + logf_fn, {b_->CreateFMul(b_->CreateFSub(getFloat(1.0f), x), + b_->CreateFAdd(getFloat(1.0f), x))})); - llvm::Value* p_addr = llvm_ir::EmitAllocaAtFunctionEntry( - ir_builder_->getFloatTy(), "p.addr", ir_builder_); + llvm::Value* p_addr = + llvm_ir::EmitAllocaAtFunctionEntry(b_->getFloatTy(), "p.addr", b_); - llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(ir_builder_->CreateFCmpOLT(w, getFloat(5.0f)), - "w_less_than_five", ir_builder_); + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + b_->CreateFCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_); // Handle true BB. - SetToFirstInsertPoint(if_data.true_block, ir_builder_); + SetToFirstInsertPoint(if_data.true_block, b_); { - llvm::Value* lw = ir_builder_->CreateFSub(w, getFloat(2.5f)); + llvm::Value* lw = b_->CreateFSub(w, getFloat(2.5f)); tensorflow::gtl::ArraySlice lq{ 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, -4.39150654e-06f, 0.00021858087f, -0.00125372503f, -0.00417768164f, 0.246640727f, 1.50140941f}; llvm::Value* p = multiply_add(lq, lw); - ir_builder_->CreateStore(p, p_addr); + b_->CreateStore(p, p_addr); } // Handle false BB. - SetToFirstInsertPoint(if_data.false_block, ir_builder_); + SetToFirstInsertPoint(if_data.false_block, b_); { llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( - module_, llvm::Intrinsic::sqrt, {ir_builder_->getFloatTy()}); + module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()}); - llvm::Value* gw = ir_builder_->CreateFSub( - ir_builder_->CreateCall(sqrtf_fn, {w}), getFloat(3.0f)); + llvm::Value* gw = + b_->CreateFSub(b_->CreateCall(sqrtf_fn, {w}), getFloat(3.0f)); tensorflow::gtl::ArraySlice gq{ -0.000200214257f, 0.000100950558f, 0.00134934322f, -0.00367342844f, 0.00573950773f, -0.0076224613f, 0.00943887047f, 1.00167406f, 2.83297682f}; llvm::Value* p = multiply_add(gq, gw); - ir_builder_->CreateStore(p, p_addr); + b_->CreateStore(p, p_addr); } - SetToFirstInsertPoint(if_data.after_block, ir_builder_); - llvm::Value* p = ir_builder_->CreateLoad(p_addr); - return ir_builder_->CreateFMul(p, x); + SetToFirstInsertPoint(if_data.after_block, b_); + llvm::Value* p = b_->CreateLoad(p_addr); + return b_->CreateFMul(p, x); } StatusOr ElementalIrEmitter::EmitErfcInv( @@ -998,13 +976,13 @@ StatusOr ElementalIrEmitter::EmitErfcInv( // Compute erfcinv(value) by calculating erfinv(1.0 - value). auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto one = llvm::ConstantFP::get(type, 1.0); - return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value)); + return EmitErfInv(prim_type, b_->CreateFSub(one, value)); } StatusOr ElementalIrEmitter::EmitLog(PrimitiveType prim_type, llvm::Value* value) const { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value}, - {value->getType()}, ir_builder_); + {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, @@ -1016,35 +994,34 @@ StatusOr ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, // When x is large, the naive evaluation of ln(x + 1) is more // accurate than the Taylor series. TF_ASSIGN_OR_RETURN(auto for_large_x, - EmitLog(prim_type, ir_builder_->CreateFAdd(x, one))); + EmitLog(prim_type, b_->CreateFAdd(x, one))); // The Taylor series for ln(x+1) is x - x^2/2 - x^3/3 + …. - auto for_small_x = ir_builder_->CreateFMul( - ir_builder_->CreateFAdd(ir_builder_->CreateFMul(negative_half, x), one), - x); + auto for_small_x = + b_->CreateFMul(b_->CreateFAdd(b_->CreateFMul(negative_half, x), one), x); const auto kAntilogarithmIsSmallThreshold = 1e-4; - auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, - {type}, ir_builder_); - auto x_is_small = ir_builder_->CreateFCmpOLT( + auto abs_x = + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_); + auto x_is_small = b_->CreateFCmpOLT( abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold)); - return ir_builder_->CreateSelect(x_is_small, for_small_x, for_large_x); + return b_->CreateSelect(x_is_small, for_small_x, for_large_x); } StatusOr ElementalIrEmitter::EmitSin(PrimitiveType prim_type, llvm::Value* value) const { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value}, - {value->getType()}, ir_builder_); + {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitCos(PrimitiveType prim_type, llvm::Value* value) const { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value}, - {value->getType()}, ir_builder_); + {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitExp(PrimitiveType prim_type, llvm::Value* value) const { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value}, - {value->getType()}, ir_builder_); + {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, @@ -1056,25 +1033,25 @@ StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, // When the exponent is large, the naive evaluation of e^(x) - 1 is more // accurate than the Taylor series. TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value)); - auto for_large_x = ir_builder_->CreateFSub(exp_x, one); + auto for_large_x = b_->CreateFSub(exp_x, one); // The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + …. // We want exp(x)-1 which is x + x^2/2 + x^3/6 + …. - auto x_squared = ir_builder_->CreateFAdd(x, x); - auto x_squared_over_two = ir_builder_->CreateFMul(x_squared, half); - auto for_small_x = ir_builder_->CreateFAdd(x, x_squared_over_two); + auto x_squared = b_->CreateFAdd(x, x); + auto x_squared_over_two = b_->CreateFMul(x_squared, half); + auto for_small_x = b_->CreateFAdd(x, x_squared_over_two); const auto kExponentIsSmallThreshold = 1e-5; - auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, - {type}, ir_builder_); - auto x_is_small = ir_builder_->CreateFCmpOLT( + auto abs_x = + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_); + auto x_is_small = b_->CreateFCmpOLT( abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold)); - return ir_builder_->CreateSelect(x_is_small, for_small_x, for_large_x); + return b_->CreateSelect(x_is_small, for_small_x, for_large_x); } StatusOr ElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs}, - {lhs->getType()}, ir_builder_); + {lhs->getType()}, b_); } StatusOr ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, @@ -1089,11 +1066,10 @@ StatusOr ElementalIrEmitter::EmitReducePrecision( return Unimplemented("reduce-precision only implemented for F32"); } return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(), - /*mantissa_bits=*/hlo->mantissa_bits(), - ir_builder_); + /*mantissa_bits=*/hlo->mantissa_bits(), b_); } -static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* ir_builder, +static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* shift_result, bool saturate_to_sign_bit) { @@ -1106,15 +1082,14 @@ static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* ir_builder, llvm::ConstantInt* minus_one = llvm::ConstantInt::get(integer_type, -1); llvm::Value* saturated_value; if (saturate_to_sign_bit) { - saturated_value = ir_builder->CreateSelect( - ir_builder->CreateICmpSLT(lhs, zero), minus_one, zero); + saturated_value = + b->CreateSelect(b->CreateICmpSLT(lhs, zero), minus_one, zero); } else { saturated_value = zero; } llvm::Value* shift_amt_in_range = - ir_builder->CreateICmpULT(rhs, integer_bitsize_constant, "shft.chk"); - return ir_builder->CreateSelect(shift_amt_in_range, shift_result, - saturated_value); + b->CreateICmpULT(rhs, integer_bitsize_constant, "shft.chk"); + return b->CreateSelect(shift_amt_in_range, shift_result, saturated_value); } StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( @@ -1123,49 +1098,49 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( switch (op->opcode()) { // TODO(jingyue): add the "nsw" attribute for signed types. case HloOpcode::kAdd: - return ir_builder_->CreateAdd(lhs_value, rhs_value); + return b_->CreateAdd(lhs_value, rhs_value); case HloOpcode::kSubtract: - return ir_builder_->CreateSub(lhs_value, rhs_value); + return b_->CreateSub(lhs_value, rhs_value); case HloOpcode::kMultiply: - return ir_builder_->CreateMul(lhs_value, rhs_value); + return b_->CreateMul(lhs_value, rhs_value); case HloOpcode::kDivide: - return is_signed ? ir_builder_->CreateSDiv(lhs_value, rhs_value) - : ir_builder_->CreateUDiv(lhs_value, rhs_value); + return is_signed ? b_->CreateSDiv(lhs_value, rhs_value) + : b_->CreateUDiv(lhs_value, rhs_value); case HloOpcode::kRemainder: - return is_signed ? ir_builder_->CreateSRem(lhs_value, rhs_value) - : ir_builder_->CreateURem(lhs_value, rhs_value); + return is_signed ? b_->CreateSRem(lhs_value, rhs_value) + : b_->CreateURem(lhs_value, rhs_value); case HloOpcode::kEq: return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value, - rhs_value, ir_builder_); + rhs_value, b_); case HloOpcode::kNe: return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value, - rhs_value, ir_builder_); + rhs_value, b_); case HloOpcode::kLt: return llvm_ir::EmitComparison( is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT, - lhs_value, rhs_value, ir_builder_); + lhs_value, rhs_value, b_); case HloOpcode::kGt: return llvm_ir::EmitComparison( is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT, - lhs_value, rhs_value, ir_builder_); + lhs_value, rhs_value, b_); case HloOpcode::kLe: return llvm_ir::EmitComparison( is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE, - lhs_value, rhs_value, ir_builder_); + lhs_value, rhs_value, b_); case HloOpcode::kGe: return llvm_ir::EmitComparison( is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE, - lhs_value, rhs_value, ir_builder_); + lhs_value, rhs_value, b_); case HloOpcode::kMinimum: return EmitIntegralMin(lhs_value, rhs_value, is_signed); case HloOpcode::kMaximum: return EmitIntegralMax(lhs_value, rhs_value, is_signed); case HloOpcode::kAnd: - return ir_builder_->CreateAnd(lhs_value, rhs_value); + return b_->CreateAnd(lhs_value, rhs_value); case HloOpcode::kOr: - return ir_builder_->CreateOr(lhs_value, rhs_value); + return b_->CreateOr(lhs_value, rhs_value); case HloOpcode::kXor: - return ir_builder_->CreateXor(lhs_value, rhs_value); + return b_->CreateXor(lhs_value, rhs_value); // Shifting out bits >= the number of bits in the type being shifted // produces a poison value in LLVM which is basically "deferred undefined @@ -1173,20 +1148,17 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( // UB. We replace the poison value with a constant to avoid this deferred // UB. case HloOpcode::kShiftRightArithmetic: - return SaturateShiftIfNecessary( - ir_builder_, lhs_value, rhs_value, - ir_builder_->CreateAShr(lhs_value, rhs_value), - /*saturate_to_sign_bit=*/true); + return SaturateShiftIfNecessary(b_, lhs_value, rhs_value, + b_->CreateAShr(lhs_value, rhs_value), + /*saturate_to_sign_bit=*/true); case HloOpcode::kShiftLeft: - return SaturateShiftIfNecessary( - ir_builder_, lhs_value, rhs_value, - ir_builder_->CreateShl(lhs_value, rhs_value), - /*saturate_to_sign_bit=*/false); + return SaturateShiftIfNecessary(b_, lhs_value, rhs_value, + b_->CreateShl(lhs_value, rhs_value), + /*saturate_to_sign_bit=*/false); case HloOpcode::kShiftRightLogical: - return SaturateShiftIfNecessary( - ir_builder_, lhs_value, rhs_value, - ir_builder_->CreateLShr(lhs_value, rhs_value), - /*saturate_to_sign_bit=*/false); + return SaturateShiftIfNecessary(b_, lhs_value, rhs_value, + b_->CreateLShr(lhs_value, rhs_value), + /*saturate_to_sign_bit=*/false); default: return Unimplemented("binary integer op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -1196,21 +1168,19 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value, bool is_signed) const { - return ir_builder_->CreateSelect( - ir_builder_->CreateICmp( - is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE, - lhs_value, rhs_value), - lhs_value, rhs_value); + return b_->CreateSelect(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE + : llvm::ICmpInst::ICMP_UGE, + lhs_value, rhs_value), + lhs_value, rhs_value); } llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value, bool is_signed) const { - return ir_builder_->CreateSelect( - ir_builder_->CreateICmp( - is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE, - lhs_value, rhs_value), - lhs_value, rhs_value); + return b_->CreateSelect(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE + : llvm::ICmpInst::ICMP_ULE, + lhs_value, rhs_value), + lhs_value, rhs_value); } llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( @@ -1227,7 +1197,14 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( // If no implicit broadcast is needed for this operand, returns the target // index as the source index. - if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape())) { + // + // `IrArray::Index` may contain a physical linear which we can propagate to + // our operand only if our layouts match. "only if" is a bit strong since + // e.g. we can still forward the linear index if the operand shape is + // [5,1,1,5]{3,2,1,0} and the HLO shape is[5,1,1,5]{3,1,2,0}, but those cases + // are probably not worth handling here for now. + if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape()) && + LayoutUtil::Equal(operand_shape.layout(), hlo.shape().layout())) { return target_index; } @@ -1246,180 +1223,265 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( return source_index; } -llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( +StatusOr ElementalIrEmitter::ConvertValueForDistribution( const HloInstruction* hlo, - const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) - const { - PrimitiveType param_prim_type = hlo->operand(0)->shape().element_type(); - llvm::Type* param_ir_type = - llvm_ir::PrimitiveTypeToIrType(param_prim_type, module_); - - // Same values as PCG library - // https://github.com/imneme/pcg-c/blob/master/include/pcg_variants.h - llvm::Value* multiplier = ir_builder_->getInt( - llvm::APInt(128, {0x4385DF649FCCF645, 0x2360ED051FC65DA4})); - llvm::Value* increment = ir_builder_->getInt( - llvm::APInt(128, {0x14057B7EF767814F, 0x5851F42D4C957F2D})); - - auto random_value_from_hlo = [hlo]() { - const HloModule* module = - hlo->IsFused() ? hlo->parent()->FusionInstruction()->parent()->parent() - : hlo->parent()->parent(); - return module->RandomNew64(); - }; + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const { + TF_ASSIGN_OR_RETURN(llvm::Value * a_or_mean, + operand_to_generator.at(hlo->operand(0))(index)); + TF_ASSIGN_OR_RETURN(llvm::Value * b_or_sigma, + operand_to_generator.at(hlo->operand(1))(index)); + PrimitiveType elem_prim_ty = hlo->shape().element_type(); + llvm::Type* elem_ir_ty = + llvm_ir::PrimitiveTypeToIrType(elem_prim_ty, module_); + llvm::Type* raw_value_ty = raw_value->getType(); + + // Convert raw integer to float in range [0, 1) if the element is a float. + llvm::Value* elem_value = raw_value; + if (elem_ir_ty->isFloatingPointTy()) { + unsigned raw_value_size_in_bits = raw_value_ty->getPrimitiveSizeInBits(); + CHECK(raw_value_size_in_bits == 32 || raw_value_size_in_bits == 64); + // Perform the division using the float type with the same number of bits + // as the raw value to avoid overflow. + if (raw_value_size_in_bits == 32) { + elem_value = b_->CreateUIToFP(elem_value, b_->getFloatTy()); + elem_value = b_->CreateFDiv( + elem_value, llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32))); + } else { + elem_value = b_->CreateUIToFP(elem_value, b_->getDoubleTy()); + elem_value = b_->CreateFDiv( + elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64))); + } - // Seed each RNG emitter with a new 64-bit seed from the HloModule. If the - // compilation order is deterministic (i.e., RandomNew64 invocation order is - // deterministic), then the order of RNG is deterministic for a given seed and - // hence tests will be deterministic. - // If the user provides a global seed instruction then we only use 64-bits of - // the host's random number generator to seed the 128 bit value with the other - // 64-bits is due to a user specified global seed instruction. - // Create a GlobalVariable to maintain state between invocations. There is a - // bug in NVPTX with GlobalVariable and 128 bit values, so using 2 64-bit + if (elem_ir_ty != elem_value->getType()) { + elem_value = b_->CreateFPTrunc(elem_value, elem_ir_ty); + } + } + + // Convert the value for the requested distribution. + switch (hlo->random_distribution()) { + case RNG_UNIFORM: { + if (elem_ir_ty->isFloatingPointTy()) { + return b_->CreateFAdd( + b_->CreateFMul(b_->CreateFSub(b_or_sigma, a_or_mean), elem_value), + a_or_mean); + } else { + // To generate a uniform random value in [a, b) from a raw random sample + // in range [0, 2^N), we let range = b - a and return + // (a + raw_value % range). If range is not a power of 2, raw values + // larger than (2^N - 2^N % range) are biased toward results in + // [a, a + (limit % range)). An unbiased algorithm would need to drop + // raw values and re-sample, but we don't do this because re-sampling in + // an efficient way is complex, and it's not clear that users need it. + // In particular, if one thread in a GPU warp needs to re-sample, we pay + // the same cost as if the whole warp were to re-sample. So an + // efficient re-sampling implementation on GPU would need to do + // nontrivial work to share entropy between threads in the warp. + auto range = b_->CreateSub(b_or_sigma, a_or_mean); + return b_->CreateAdd(a_or_mean, b_->CreateURem(elem_value, range)); + } + } + case RNG_NORMAL: { + TF_ASSIGN_OR_RETURN( + llvm::Value * r, + EmitErfcInv(elem_prim_ty, + b_->CreateFMul(llvm::ConstantFP::get(elem_ir_ty, 2.0), + elem_value))); + return b_->CreateFAdd(b_->CreateFMul(r, b_or_sigma), a_or_mean); + } + default: + return InvalidArgument( + "unhandled distribution %s", + RandomDistribution_Name(hlo->random_distribution()).c_str()); + } +} + +namespace { + +// Checks that the primitive type is supported by the elemental IR emitter for +// Philox RNG and returns the number of elements in each 128 bit sample of the +// Philox RNG algorithm. +int32 GetNumberOfElementsPerPhiloxRngSample(PrimitiveType elem_prim_ty) { + // Calculate the number of elements, that is the number of random numbers, in + // a 128 bit sample. + switch (elem_prim_ty) { + case U32: + case S32: + case F32: + // The algorithm uses 32 bits to generate values for F16. + case F16: + return 4; + case U64: + case S64: + case F64: + return 2; + default: + // BF16 is converted to F16 by the hlo pass HloElementTypeConverter. + // Other data types are not supported by XLA random operation. + LOG(FATAL) << "Unrecognized primitive type for RNG " << elem_prim_ty; + } + return 0; +} + +// Calculates the four uint32 values for the 128-bit Philox sample. +std::array CalculateSampleValues( + llvm::Value* sample_idx, llvm::Value* hlo_random_value, + llvm::Value* global_random_number, llvm::Value* rng_state, + llvm::IRBuilder<>* b) { + llvm::Type* index_ty = sample_idx->getType(); + + std::array counter_values; + + // Use the sample index to initialize counter[0] and counter[1]. + unsigned index_ty_size_in_bits = index_ty->getPrimitiveSizeInBits(); + CHECK(index_ty_size_in_bits == 32 || index_ty_size_in_bits == 64); + if (index_ty_size_in_bits == 32) { + counter_values[0] = sample_idx; + counter_values[1] = b->getInt32(0); + } else { + std::tie(counter_values[0], counter_values[1]) = + llvm_ir::SplitInt64ToInt32s(b, sample_idx); + } + + // Xor the global state variable with the global random number seed and use + // the result to initialize counter[2] and counter[3]. + std::tie(counter_values[2], counter_values[3]) = llvm_ir::SplitInt64ToInt32s( + b, b->CreateXor(rng_state, global_random_number)); + + // The algorithm uses a 64 bit key, which is also interpreted as two uint32 // values. - llvm::GlobalVariable* state_ptr0 = new llvm::GlobalVariable( - /*M=*/*module_, - /*Ty=*/ir_builder_->getInt64Ty(), - /*isConstant=*/false, - /*Linkage=*/llvm::GlobalValue::PrivateLinkage, - /*Initializer=*/ir_builder_->getInt64(random_value_from_hlo()), - /*Name=*/"state_ptr0"); - - // When the module config seed is 0, the expected result of a prng is a random - // value. Instead of using the random_value_from_hlo, we need a global random - // value as the graph seed. This is because if we use random_value_from_hlo - // here, then for a newly built hlo graph, it always gives the same number. - uint64 graph_seed = hlo_module_config_.seed() != 0 ? hlo_module_config_.seed() - : GlobalRandomValue(); - llvm::GlobalVariable* state_ptr1 = new llvm::GlobalVariable( - /*M=*/*module_, - /*Ty=*/ir_builder_->getInt64Ty(), - /*isConstant=*/false, - /*Linkage=*/llvm::GlobalValue::PrivateLinkage, - /*Initializer=*/ir_builder_->getInt64(graph_seed), - /*Name=*/"state_ptr1"); - - // We want each thread to use its own stream, so we modify the increment per - // thread. We want the increment to remain odd, so we shift the thread id left - // 1 and add it to the increment. - increment = ir_builder_->CreateAdd(increment, - ir_builder_->CreateShl(EmitThreadId(), 1)); - - // PCG-XSL-RR algorithm - // http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf - // state = multiplier * state + increment - // return uint64_t(state ^ (state >> 64))) >>> (state >> 122) - // where ">>>" is bitwise rotation - auto get_next_i64 = [=]() { - llvm::Value* state0 = ir_builder_->CreateZExtOrTrunc( - ir_builder_->CreateLoad(state_ptr0, "state0"), - ir_builder_->getInt128Ty()); - llvm::Value* state1 = ir_builder_->CreateShl( - ir_builder_->CreateZExtOrTrunc( - ir_builder_->CreateLoad(state_ptr1, "state1"), - ir_builder_->getInt128Ty()), - 64); - llvm::Value* state = ir_builder_->CreateOr(state0, state1); - llvm::Value* updated = ir_builder_->CreateAdd( - ir_builder_->CreateMul(state, multiplier), increment); - ir_builder_->CreateStore( - ir_builder_->CreateTrunc(updated, ir_builder_->getInt64Ty()), - state_ptr0); - ir_builder_->CreateStore( - ir_builder_->CreateTrunc(ir_builder_->CreateLShr(updated, 64), - ir_builder_->getInt64Ty()), - state_ptr1); - - return llvm_ir::CreateRor( - ir_builder_->CreateTrunc( - ir_builder_->CreateXor(state, ir_builder_->CreateLShr(state, 64)), - ir_builder_->getInt64Ty()), - ir_builder_->CreateTrunc(ir_builder_->CreateLShr(state, 122), - ir_builder_->getInt64Ty()), - ir_builder_); - }; + llvm::Value* key_values[2]; + + // Use a module random number to initialize the key. + std::tie(key_values[0], key_values[1]) = + llvm_ir::SplitInt64ToInt32s(b, hlo_random_value); + + // Prepare the constants used in the Philox RNG Algorithm. + llvm::Value* philoxW32A = b->getInt32(0x9E3779B9); + llvm::Value* philoxW32B = b->getInt32(0xBB67AE85); + llvm::Value* philoxM4xW32A = b->getInt32(0xD2511F53); + llvm::Value* philoxM4xW32B = b->getInt32(0xCD9E8D57); + + // Compute the 128 bit value for the current sample by repeating the + // single round computation and key raising computation for ten times. + for (int round = 0; round < 10; ++round) { + // A single round of computation of the counter values is as follows: + // MultiplyHighLow(kPhiloxM4x32A, counter[0], &lo0, &hi0); + // MultiplyHighLow(kPhiloxM4x32B, counter[2], &lo1, &hi1); + // counter[0] = hi1 ^ counter[1] ^ key[0]; + // counter[1] = lo1; + // counter[2] = hi0 ^ counter[3] ^ key[1]; + // counter[3] = lo0; + llvm::Value* lo0; + llvm::Value* hi0; + std::tie(lo0, hi0) = + llvm_ir::UMulLowHigh32(b, philoxM4xW32A, counter_values[0]); + llvm::Value* lo1; + llvm::Value* hi1; + std::tie(lo1, hi1) = + llvm_ir::UMulLowHigh32(b, philoxM4xW32B, counter_values[2]); + counter_values[0] = + b->CreateXor(hi1, b->CreateXor(counter_values[1], key_values[0])); + counter_values[1] = lo1; + counter_values[2] = + b->CreateXor(hi0, b->CreateXor(counter_values[3], key_values[1])); + counter_values[3] = lo0; + key_values[0] = b->CreateAdd(key_values[0], philoxW32A); + key_values[1] = b->CreateAdd(key_values[1], philoxW32B); + } - auto get_next_uniform_float = [=]() { - return ir_builder_->CreateFDiv( - ir_builder_->CreateUIToFP(get_next_i64(), param_ir_type), - llvm::ConstantFP::get(param_ir_type, 0x1p64)); - }; + return counter_values; +} +} // namespace + +// Implements the Philox algorithm to generate random numbers in parallel. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf +// +// The paper presents a few variants of the Philox algorithm, we picked the +// 4x32_10 version of the algorithm for the following reasons: +// . 4x32 uses 32-bit multiplication which is fast on GPUs. +// . The authors recommend the 10-round variant, and TensorFlow also uses it. +// +// Precondition: the RNG instruction is not fused. +llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) + const { + VLOG(3) << "Using philox RNG algorithm"; + CHECK(!hlo->IsFused()); + // A random number generated by the per module random number generator. + // This ensures that each RNG HLO generates a different random sequence. + llvm::Value* hlo_random_value = b_->getInt64(hlo->GetModule()->RandomNew64()); + // A value specified by the configuration or generated by a global random + // number generator. + llvm::Value* global_random_number = + b_->getInt64(hlo_module_config_.seed() != 0 ? hlo_module_config_.seed() + : GlobalRandomValue()); + + int elems_per_sample = + GetNumberOfElementsPerPhiloxRngSample(hlo->shape().element_type()); + + // Allocate stack storage for the 128 bit sample as four int32. + llvm::Type* int32_ty = b_->getInt32Ty(); + llvm::Value* sample_address = llvm_ir::EmitAllocaAtFunctionEntryWithCount( + int32_ty, /*element_count=*/b_->getInt32(4), "sample", b_); + + // Load the global state variable for the Philox RNG algorithm. + llvm::GlobalVariable* rng_state_ptr = + llvm_ir::GetOrCreateVariableForPhiloxRngState(module_, b_); + llvm::Value* rng_state = b_->CreateLoad(rng_state_ptr, "rng_state_value"); + + // Build and return the elemental IR generator to generate a random value for + // the element corresponding to the current thread. + // + // This elemental IR generator computes one sample with multiple random + // numbers but only returns one random number. As a result, neighboring + // threads may calculate the same sample unnecessarily. However, if the + // kernel containing the RNG hlo is unrolled, LLVM is able to optimize away + // the duplicated computation of the same sample. In particular, if the unroll + // factor is a multiplier of elems_per_sample, LLVM is able to completely + // remove such duplicated computation. If the unroll factor is a non-trivial + // factor of elems_per_sample, LLVM can only partially remove such duplicated + // computation. return [=](const llvm_ir::IrArray::Index& index) -> StatusOr { - switch (hlo->random_distribution()) { - case RNG_UNIFORM: { - TF_ASSIGN_OR_RETURN(llvm::Value * p, - operand_to_generator.at(hlo->operand(0))(index)); - TF_ASSIGN_OR_RETURN(llvm::Value * q, - operand_to_generator.at(hlo->operand(1))(index)); - if (primitive_util::IsFloatingPointType(param_prim_type)) { - return ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(ir_builder_->CreateFSub(q, p), - get_next_uniform_float()), - p); - } else { - auto r = ir_builder_->CreateSub(q, p); - auto leading_zeros = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(true)}, - {param_ir_type}, ir_builder_); - auto in_block = ir_builder_->GetInsertBlock(); - - // A terminator should be present iff we're emitting code - // into the middle (as opposed to the end) of a basic block. - CHECK_EQ(ir_builder_->GetInsertPoint() == in_block->end(), - in_block->getTerminator() == nullptr); - - llvm::BasicBlock* body_block; - llvm::BasicBlock* out_block; - - if (ir_builder_->GetInsertPoint() == in_block->end()) { - body_block = llvm_ir::CreateBasicBlock( - nullptr, IrName(hlo, "rng_body"), ir_builder_); - out_block = llvm_ir::CreateBasicBlock( - nullptr, IrName(hlo, "rng_out"), ir_builder_); - llvm::BranchInst::Create(body_block, in_block); - } else { - body_block = in_block->splitBasicBlock( - ir_builder_->GetInsertPoint(), "rng_body"); - out_block = body_block->splitBasicBlock( - ir_builder_->GetInsertPoint(), "rng_out"); - body_block->getTerminator()->eraseFromParent(); - } - - SetToFirstInsertPoint(body_block, ir_builder_); - auto random = ir_builder_->CreateAnd( - ir_builder_->CreateZExtOrTrunc(get_next_i64(), param_ir_type), - ir_builder_->CreateLShr(llvm::ConstantInt::get(param_ir_type, ~0), - leading_zeros)); - llvm::BranchInst::Create(out_block, body_block, - ir_builder_->CreateICmpULT(random, r), - body_block); - SetToFirstInsertPoint(out_block, ir_builder_); - return ir_builder_->CreateAdd( - p, ir_builder_->CreateSelect( - ir_builder_->CreateICmpEQ(p, q), - llvm::ConstantInt::get(param_ir_type, 0), random)); - } - } - case RNG_NORMAL: { - TF_ASSIGN_OR_RETURN(llvm::Value * m, - operand_to_generator.at(hlo->operand(0))(index)); - TF_ASSIGN_OR_RETURN(llvm::Value * s, - operand_to_generator.at(hlo->operand(1))(index)); - TF_ASSIGN_OR_RETURN( - llvm::Value * r, - EmitErfcInv(param_prim_type, - ir_builder_->CreateFMul( - llvm::ConstantFP::get(param_ir_type, 2.0), - get_next_uniform_float()))); - return ir_builder_->CreateFAdd(ir_builder_->CreateFMul(r, s), m); - } - default: - return InvalidArgument( - "unhandled distribution %s", - RandomDistribution_Name(hlo->random_distribution()).c_str()); + llvm::Type* index_ty = index.GetType(); + // Calculate the linear element index. + llvm::Value* elem_idx = index.linear(); + if (elem_idx == nullptr) { + elem_idx = index.Linearize(AsInt64Slice(hlo->shape().dimensions()), b_); + } + + // Calculate the index for the 128 bit sample and the offset of the current + // element within the sample. + llvm::Value* elems_per_sample_value = + llvm::ConstantInt::get(index_ty, elems_per_sample); + llvm::Value* sample_idx = b_->CreateUDiv(elem_idx, elems_per_sample_value); + llvm::Value* elem_offset = b_->CreateURem(elem_idx, elems_per_sample_value); + + std::array counter_values = CalculateSampleValues( + sample_idx, hlo_random_value, global_random_number, rng_state, b_); + + // Store the four counter_values into the sample_address alloca so we can + // load the elem_offset'th one below. + for (int idx = 0; idx < 4; ++idx) { + b_->CreateStore(counter_values[idx], + b_->CreateInBoundsGEP(sample_address, b_->getInt32(idx))); } + + llvm::Type* int64_ty = b_->getInt64Ty(); + CHECK(elems_per_sample == 2 || elems_per_sample == 4); + llvm::Type* raw_value_ty = elems_per_sample == 2 ? int64_ty : int32_ty; + // Retrieve the raw value for the current element from the current sample. + llvm::Value* raw_elem_value = b_->CreateLoad( + b_->CreateInBoundsGEP( + b_->CreatePointerCast(sample_address, raw_value_ty->getPointerTo()), + elem_offset), + "raw_elem_value"); + + return ConvertValueForDistribution(hlo, operand_to_generator, index, + raw_elem_value); }; } @@ -1436,9 +1498,8 @@ StatusOr ElementalIrEmitter::EmitElementalSelect( TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value, operand_to_generator.at(hlo->operand(2))( ElementwiseSourceIndex(index, *hlo, 2))); - return ir_builder_->CreateSelect( - ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()), - on_true_value, on_false_value); + return b_->CreateSelect(b_->CreateTrunc(pred_value, b_->getInt1Ty()), + on_true_value, on_false_value); } StatusOr ElementalIrEmitter::EmitElementalClamp( @@ -1474,64 +1535,62 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( const int64 concat_dim = hlo->dimensions(0); auto source_index = target_index; - llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock(); + llvm::BasicBlock* init_block = b_->GetInsertBlock(); // A terminator should be present iff we're emitting code // into the middle (as opposed to the end) of a basic block. - CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(), + CHECK_EQ(b_->GetInsertPoint() == init_block->end(), init_block->getTerminator() == nullptr); llvm::BasicBlock* exit_block; - if (ir_builder_->GetInsertPoint() == init_block->end()) { + if (b_->GetInsertPoint() == init_block->end()) { exit_block = llvm_ir::CreateBasicBlock( - /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_); + /*insert_before=*/nullptr, IrName(hlo, "merge"), b_); } else { - exit_block = init_block->splitBasicBlock(ir_builder_->GetInsertPoint(), + exit_block = init_block->splitBasicBlock(b_->GetInsertPoint(), AsStringRef(IrName(hlo, "merge"))); init_block->getTerminator()->eraseFromParent(); } - llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_); - llvm::PHINode* output = ir_builder_->CreatePHI( + llvm_ir::SetToFirstInsertPoint(exit_block, b_); + llvm::PHINode* output = b_->CreatePHI( llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), hlo->operands().size()); - auto prior_insert_point = ir_builder_->GetInsertPoint(); + auto prior_insert_point = b_->GetInsertPoint(); - ir_builder_->SetInsertPoint(init_block); + b_->SetInsertPoint(init_block); for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); ++operand_idx) { const HloInstruction* operand = hlo->operand(operand_idx); auto true_block = llvm_ir::CreateBasicBlock( - exit_block, StrCat("concat_index_from_operand", operand_idx), - ir_builder_); + exit_block, StrCat("concat_index_from_operand", operand_idx), b_); auto false_block = llvm_ir::CreateBasicBlock( - exit_block, StrCat("concat_index_not_from_operand", operand_idx), - ir_builder_); + exit_block, StrCat("concat_index_not_from_operand", operand_idx), b_); auto concat_dim_size = llvm::ConstantInt::get(source_index[concat_dim]->getType(), operand->shape().dimensions(concat_dim)); - ir_builder_->CreateCondBr( - ir_builder_->CreateICmpULT(source_index[concat_dim], concat_dim_size), + b_->CreateCondBr( + b_->CreateICmpULT(source_index[concat_dim], concat_dim_size), true_block, false_block); // Create the terminator of the true block before calling operand // generators, because they require non-degenerate basic blocks. - ir_builder_->SetInsertPoint( + b_->SetInsertPoint( llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block)); TF_ASSIGN_OR_RETURN(llvm::Value * value, operand_to_generator.at(operand)(source_index)); - output->addIncoming(value, ir_builder_->GetInsertBlock()); + output->addIncoming(value, b_->GetInsertBlock()); // Subtract the size of the concat dimension of the current operand // from the source index. - ir_builder_->SetInsertPoint(false_block); + b_->SetInsertPoint(false_block); source_index[concat_dim] = - ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size); + b_->CreateSub(source_index[concat_dim], concat_dim_size); } - ir_builder_->CreateUnreachable(); - ir_builder_->SetInsertPoint(exit_block, prior_insert_point); + b_->CreateUnreachable(); + b_->SetInsertPoint(exit_block, prior_insert_point); return output; } @@ -1555,22 +1614,16 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( // Clamp the start index so that the sliced portion fits in the operand: // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size) + start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type); + int64 largest_valid_start_index = + input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i); + CHECK_GE(largest_valid_start_index, 0); - // TODO(b/74360564): This is implementation defined behavior, but is - // currently respected by all implementations. Change this if we ever decide - // to oficially document different behavior. - start_index_value = - ir_builder_->CreateSExtOrTrunc(start_index_value, index_type); - llvm::Value* operand_dim_size = - index_typed_const(input_hlo->shape().dimensions(i)); - llvm::Value* output_dim_size = - index_typed_const(hlo->shape().dimensions(i)); - + bool is_signed = ShapeUtil::ElementIsSigned(hlo->operand(1)->shape()); start_index_value = EmitIntegralMin( - ir_builder_->CreateSub(operand_dim_size, output_dim_size), - EmitIntegralMax(index_typed_const(0), start_index_value, - /*is_signed=*/true), - /*is_signed=*/true); + index_typed_const(largest_valid_start_index), + EmitIntegralMax(index_typed_const(0), start_index_value, is_signed), + is_signed); start_index_value->setName( AsStringRef(IrName(hlo, StrCat("start_idx", i)))); @@ -1581,7 +1634,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( for (int64 i = 0; i < rank; ++i) { // Emit IR which computes: // input_index = start_index + offset_index - input_index[i] = ir_builder_->CreateAdd(slice_start_index[i], index[i]); + input_index[i] = b_->CreateAdd(slice_start_index[i], index[i]); } return operand_to_generator.at(input_hlo)(input_index); } @@ -1603,19 +1656,22 @@ StatusOr ElementalIrEmitter::EmitElementalGather( llvm::Type* index_type = index.GetType(); // This is the index into `operand` that holds the element we want to - // generate. This index "unsafe" as in the components in here may be - // out of bounds. - IrArray::Index unsafe_operand_index(index_type); - - // First copy in the window indices to unsafe_operand_index. - for (int64 i = 0, e = operand_shape.dimensions_size(), - unsafe_operand_index_dim = 0; + // generate. + IrArray::Index operand_index(index_type); + + // First copy in the window indices to operand_index. Also collect a mapping + // from operand dimension to output window dimension. Elided window dimensions + // map to -1. + std::vector operand_to_output_dim(operand_shape.dimensions_size(), -1); + for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0; i < e; i++) { if (c_binary_search(dim_numbers.elided_window_dims(), i)) { - unsafe_operand_index.push_back(index.GetConstantWithIndexType(0)); + operand_index.push_back(index.GetConstantWithIndexType(0)); } else { - unsafe_operand_index.push_back( - index[dim_numbers.output_window_dims(unsafe_operand_index_dim++)]); + int64 output_window_dim = + dim_numbers.output_window_dims(operand_index_dim++); + operand_to_output_dim[i] = output_window_dim; + operand_index.push_back(index[output_window_dim]); } } @@ -1634,20 +1690,40 @@ StatusOr ElementalIrEmitter::EmitElementalGather( } } - auto add_to_unsafe_operand_index = [&](llvm::Value* index_component, - int64 dim) { + auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) { llvm::Value* gather_dim_component_extended = - ir_builder_->CreateSExtOrTrunc(index_component, index_type); - unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)] = - ir_builder_->CreateAdd( - unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)], - gather_dim_component_extended); + b_->CreateSExtOrTrunc(index_component, index_type); + int64 operand_dim = dim_numbers.gather_dims_to_operand_dims(dim); + int64 output_dim = operand_to_output_dim[operand_dim]; + // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim. + // This means we set the iteration index to 0, so for the purpose of the + // following calculations we can consider the output dimension size to be 1. + int64 output_dim_size = + output_dim == -1 ? 1 : output_shape.dimensions(output_dim); + int64 largest_valid_start_index = + operand_shape.dimensions(operand_dim) - output_dim_size; + CHECK_GE(largest_valid_start_index, 0); + + // Clamp the gather index so that the gather region fits in the operand. + // gather_dim_component_extended_inbound = + // clamp(gather_dim_component_extended, 0, largest_valid_start_index); + + // TODO(b/111078873): This is implementation defined behavior. + bool is_signed = ShapeUtil::ElementIsSigned(indices_shape); + auto gather_dim_component_extended_inbound = EmitIntegralMin( + index.GetConstantWithIndexType(largest_valid_start_index), + EmitIntegralMax(index.GetConstantWithIndexType(0), + gather_dim_component_extended, is_signed), + is_signed); + + operand_index[operand_dim] = b_->CreateAdd( + operand_index[operand_dim], gather_dim_component_extended_inbound); }; if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) { TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, indices_generator(gather_index_index)); - add_to_unsafe_operand_index(gather_dim_component, 0); + add_to_operand_index(gather_dim_component, 0); } else { int64 index_vector_size = indices_shape.dimensions(dim_numbers.index_vector_dim()); @@ -1656,18 +1732,10 @@ StatusOr ElementalIrEmitter::EmitElementalGather( index.GetConstantWithIndexType(i); TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, indices_generator(gather_index_index)); - add_to_unsafe_operand_index(gather_dim_component, i); + add_to_operand_index(gather_dim_component, i); } } - - IrArray::Index safe_operand_index(index_type); - for (int64 i = 0, e = unsafe_operand_index.size(); i < e; i++) { - safe_operand_index.push_back(ir_builder_->CreateURem( - unsafe_operand_index[i], - index.GetConstantWithIndexType(operand_shape.dimensions(i)))); - } - - return operand_generator(safe_operand_index); + return operand_generator(operand_index); } StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( @@ -1683,7 +1751,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( llvm_ir::IrArray::Index slice_limit_index(index.GetType(), rank); // Slice intersection gathers (ANDs) conditions on all ranks for which // 'input' is set to 'update' - llvm::Value* slice_intersection = ir_builder_->getTrue(); + llvm::Value* slice_intersection = b_->getTrue(); for (int64 i = 0; i < rank; ++i) { llvm::Type* index_type = index[0]->getType(); @@ -1696,36 +1764,29 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // Clamp the start index so that the update region fits in the operand. // start_index = clamp(start_index, 0, input_dim_size - update_dim_size) - - // TODO(b/74360564): This is implementation defined behavior, but is - // currently respected by all implementations. Change this if we ever decide - // to oficially document different behavior. - start_index_value = - ir_builder_->CreateSExtOrTrunc(start_index_value, index_type); - llvm::Value* input_dim_size = - index_typed_const(input_hlo->shape().dimensions(i)); + start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type); llvm::Value* update_dim_size = index_typed_const(update_hlo->shape().dimensions(i)); + int64 largest_valid_start_index = + input_hlo->shape().dimensions(i) - update_hlo->shape().dimensions(i); + CHECK_GE(largest_valid_start_index, 0); - start_index_value = - EmitIntegralMin(ir_builder_->CreateSub(input_dim_size, update_dim_size), - EmitIntegralMax(index_typed_const(0), start_index_value, - /*is_signed=*/true), - /*is_signed=*/true); + bool is_signed = ShapeUtil::ElementIsSigned(start_hlo->shape()); + start_index_value = EmitIntegralMin( + index_typed_const(largest_valid_start_index), + EmitIntegralMax(index_typed_const(0), start_index_value, is_signed), + is_signed); start_index_value->setName( AsStringRef(IrName(hlo, StrCat("start_idx", i)))); slice_start_index[i] = start_index_value; - slice_limit_index[i] = - ir_builder_->CreateAdd(slice_start_index[i], update_dim_size); + slice_limit_index[i] = b_->CreateAdd(slice_start_index[i], update_dim_size); - slice_intersection = ir_builder_->CreateAnd( - slice_intersection, - ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), + slice_intersection = b_->CreateAnd( + slice_intersection, b_->CreateICmpSGE(index[i], slice_start_index[i]), "slice_intersection"); - slice_intersection = ir_builder_->CreateAnd( - slice_intersection, - ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]), + slice_intersection = b_->CreateAnd( + slice_intersection, b_->CreateICmpSLT(index[i], slice_limit_index[i]), "slice_intersection"); } @@ -1734,29 +1795,29 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // else -> return data from 'input'. llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), - "ret_value_addr", ir_builder_); - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - slice_intersection, "slice_intersection", ir_builder_); + "ret_value_addr", b_); + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(slice_intersection, "slice_intersection", b_); // Handle true BB (return data from 'update') - SetToFirstInsertPoint(if_data.true_block, ir_builder_); + SetToFirstInsertPoint(if_data.true_block, b_); // Compute update index for intersection case. llvm_ir::IrArray::Index update_index(index.GetType(), rank); for (int64 i = 0; i < rank; ++i) { - update_index[i] = ir_builder_->CreateSub(index[i], slice_start_index[i]); + update_index[i] = b_->CreateSub(index[i], slice_start_index[i]); } TF_ASSIGN_OR_RETURN(llvm::Value * true_value, operand_to_generator.at(update_hlo)(update_index)); - ir_builder_->CreateStore(true_value, ret_value_addr); + b_->CreateStore(true_value, ret_value_addr); // Handle false BB (return data from 'input') - SetToFirstInsertPoint(if_data.false_block, ir_builder_); + SetToFirstInsertPoint(if_data.false_block, b_); TF_ASSIGN_OR_RETURN(llvm::Value * false_value, operand_to_generator.at(input_hlo)(index)); - ir_builder_->CreateStore(false_value, ret_value_addr); + b_->CreateStore(false_value, ret_value_addr); - SetToFirstInsertPoint(if_data.after_block, ir_builder_); - return ir_builder_->CreateLoad(ret_value_addr); + SetToFirstInsertPoint(if_data.after_block, b_); + return b_->CreateLoad(ret_value_addr); } StatusOr ElementalIrEmitter::EmitElementalPad( @@ -1764,29 +1825,29 @@ StatusOr ElementalIrEmitter::EmitElementalPad( const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& padded_index) const { auto index = padded_index; - llvm::Value* in_bounds = ir_builder_->getTrue(); + llvm::Value* in_bounds = b_->getTrue(); for (size_t i = 0; i < index.size(); ++i) { auto index_typed_const = [=](int64 n) { return llvm::ConstantInt::get(index[i]->getType(), n); }; const auto& pad_dim = hlo->padding_config().dimensions(i); - index[i] = ir_builder_->CreateSub( - index[i], index_typed_const(pad_dim.edge_padding_low())); - in_bounds = ir_builder_->CreateAnd( - in_bounds, ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)), - "in_bounds"); - in_bounds = ir_builder_->CreateAnd( + index[i] = + b_->CreateSub(index[i], index_typed_const(pad_dim.edge_padding_low())); + in_bounds = b_->CreateAnd(in_bounds, + b_->CreateICmpSGE(index[i], index_typed_const(0)), + "in_bounds"); + in_bounds = b_->CreateAnd( in_bounds, - ir_builder_->CreateICmpEQ( + b_->CreateICmpEQ( index_typed_const(0), - ir_builder_->CreateURem( - index[i], index_typed_const(pad_dim.interior_padding() + 1))), + b_->CreateURem(index[i], + index_typed_const(pad_dim.interior_padding() + 1))), "in_bounds"); - index[i] = ir_builder_->CreateSDiv( + index[i] = b_->CreateSDiv( index[i], index_typed_const(pad_dim.interior_padding() + 1)); - in_bounds = ir_builder_->CreateAnd( + in_bounds = b_->CreateAnd( in_bounds, - ir_builder_->CreateICmpSLT( + b_->CreateICmpSLT( index[i], index_typed_const(hlo->operand(0)->shape().dimensions(i))), "in_bounds"); @@ -1799,26 +1860,26 @@ StatusOr ElementalIrEmitter::EmitElementalPad( // } llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), - "pad_result_addr", ir_builder_); + "pad_result_addr", b_); llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); - SetToFirstInsertPoint(if_data.true_block, ir_builder_); + llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_); + SetToFirstInsertPoint(if_data.true_block, b_); TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, operand_to_generator.at(hlo->operand(0))(index)); - ir_builder_->CreateStore(operand_value, ret_value_addr); + b_->CreateStore(operand_value, ret_value_addr); - SetToFirstInsertPoint(if_data.false_block, ir_builder_); + SetToFirstInsertPoint(if_data.false_block, b_); TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, operand_to_generator.at(hlo->operand(1))( IrArray::Index(index.GetType()))); - ir_builder_->CreateStore(padding_value, ret_value_addr); + b_->CreateStore(padding_value, ret_value_addr); - SetToFirstInsertPoint(if_data.after_block, ir_builder_); + SetToFirstInsertPoint(if_data.after_block, b_); // Don't create phi(operand_value, padding_value) here, because invoking // operand_to_generator may create new basic blocks, making the parent // of operand_value or padding_value no longer a predecessor of // if_data.after_block. - return ir_builder_->CreateLoad(ret_value_addr); + return b_->CreateLoad(ret_value_addr); } StatusOr ElementalIrEmitter::EmitElementalDot( @@ -1842,21 +1903,20 @@ StatusOr ElementalIrEmitter::EmitElementalDot( return llvm::ConstantInt::get(index_type, c); }; - std::unique_ptr inner_loop = - llvm_ir::ForLoop::EmitForLoop(IrName(hlo, "inner"), index_typed_const(0), - index_typed_const(contracted_dim_size), - index_typed_const(1), ir_builder_); + std::unique_ptr inner_loop = llvm_ir::ForLoop::EmitForLoop( + IrName(hlo, "inner"), index_typed_const(0), + index_typed_const(contracted_dim_size), index_typed_const(1), b_); - SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), ir_builder_); + SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), b_); PrimitiveType primitive_type = hlo->shape().element_type(); llvm::Type* primitive_type_llvm = llvm_ir::PrimitiveTypeToIrType(primitive_type, module_); - llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry( - primitive_type_llvm, "dot_acc", ir_builder_); - ir_builder_->CreateStore(llvm::Constant::getNullValue(primitive_type_llvm), - accumulator_alloca); + llvm::Value* accumulator_alloca = + llvm_ir::EmitAllocaAtFunctionEntry(primitive_type_llvm, "dot_acc", b_); + b_->CreateStore(llvm::Constant::getNullValue(primitive_type_llvm), + accumulator_alloca); - SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), ir_builder_); + SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), b_); // This is the inner reduction loop for a dot operation that produces // one element in the output. If the operands to the dot operation have @@ -1876,43 +1936,36 @@ StatusOr ElementalIrEmitter::EmitElementalDot( } rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue()); - llvm::Value* current_accumulator = - ir_builder_->CreateLoad(accumulator_alloca); + llvm::Value* current_accumulator = b_->CreateLoad(accumulator_alloca); TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index)); TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); llvm::Value* next_accumulator; if (primitive_util::IsComplexType(primitive_type)) { - llvm::Value* product_real = ir_builder_->CreateFSub( - ir_builder_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - ir_builder_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))); - llvm::Value* product_imag = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractImag(rhs_value)), - ir_builder_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractReal(rhs_value))); - next_accumulator = ir_builder_->CreateInsertValue( + llvm::Value* product_real = b_->CreateFSub( + b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), + b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))); + llvm::Value* product_imag = b_->CreateFAdd( + b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)), + b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))); + next_accumulator = b_->CreateInsertValue( current_accumulator, - ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator), - product_real), + b_->CreateFAdd(EmitExtractReal(current_accumulator), product_real), {0}); - next_accumulator = ir_builder_->CreateInsertValue( + next_accumulator = b_->CreateInsertValue( next_accumulator, - ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator), - product_imag), + b_->CreateFAdd(EmitExtractImag(current_accumulator), product_imag), {1}); } else if (primitive_util::IsFloatingPointType(primitive_type)) { - next_accumulator = ir_builder_->CreateFAdd( - current_accumulator, ir_builder_->CreateFMul(lhs_value, rhs_value)); + next_accumulator = b_->CreateFAdd(current_accumulator, + b_->CreateFMul(lhs_value, rhs_value)); } else { - next_accumulator = ir_builder_->CreateAdd( - current_accumulator, ir_builder_->CreateMul(lhs_value, rhs_value)); + next_accumulator = + b_->CreateAdd(current_accumulator, b_->CreateMul(lhs_value, rhs_value)); } - ir_builder_->CreateStore(next_accumulator, accumulator_alloca); + b_->CreateStore(next_accumulator, accumulator_alloca); - SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), ir_builder_); - return ir_builder_->CreateLoad(accumulator_alloca); + SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_); + return b_->CreateLoad(accumulator_alloca); } llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( @@ -2012,7 +2065,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); auto source_index = target_index; for (int64 dim : hlo->dimensions()) { - source_index[dim] = ir_builder_->CreateSub( + source_index[dim] = b_->CreateSub( llvm::ConstantInt::get(target_index[dim]->getType(), hlo->shape().dimensions(dim) - 1), target_index[dim]); @@ -2025,16 +2078,16 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); // The `dimensions` member of the broadcast instruction maps from // input dimensions to output dimensions. - return operand_to_generator.at( - operand)(target_index.SourceIndexOfBroadcast( - hlo->shape(), operand->shape(), hlo->dimensions(), ir_builder_)); + return operand_to_generator.at(operand)( + target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(), + hlo->dimensions(), b_)); }; case HloOpcode::kSlice: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { IrArray::Index sliced_index = index.SourceIndexOfSlice( /*shape=*/hlo->shape(), /*starts=*/hlo->slice_starts(), - /*strides=*/hlo->slice_strides(), /*builder=*/ir_builder_); + /*strides=*/hlo->slice_strides(), /*builder=*/b_); return operand_to_generator.at(hlo->operand(0))(sliced_index); }; case HloOpcode::kDynamicSlice: @@ -2059,27 +2112,26 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( ShapeUtil::ElementsIn(hlo->operand(0)->shape())); return [this, hlo, &operand_to_generator](const IrArray::Index& index) { const HloInstruction* operand = hlo->operand(0); - return operand_to_generator.at(operand)(index.SourceIndexOfBitcast( - hlo->shape(), operand->shape(), ir_builder_)); + return operand_to_generator.at(operand)( + index.SourceIndexOfBitcast(hlo->shape(), operand->shape(), b_)); }; case HloOpcode::kReshape: CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()), ShapeUtil::ElementsIn(hlo->operand(0)->shape())); return [this, hlo, &operand_to_generator](const IrArray::Index& index) { const HloInstruction* operand = hlo->operand(0); - return operand_to_generator.at(operand)(index.SourceIndexOfReshape( - hlo->shape(), operand->shape(), ir_builder_)); + return operand_to_generator.at(operand)( + index.SourceIndexOfReshape(hlo->shape(), operand->shape(), b_)); }; case HloOpcode::kTranspose: return [this, hlo, &operand_to_generator](const IrArray::Index& target_index) { return operand_to_generator.at(hlo->operand(0))( target_index.SourceIndexOfTranspose( - hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions(), - ir_builder_)); + hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions(), b_)); }; case HloOpcode::kRng: - return MakeRngElementGenerator(hlo, operand_to_generator); + return MakePhiloxRngElementGenerator(hlo, operand_to_generator); case HloOpcode::kPad: return [this, hlo, &operand_to_generator]( const IrArray::Index& padded_index) -> StatusOr { @@ -2093,7 +2145,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( return EmitElementalDot(hlo, operand_to_generator, dot_result_index); }; default: - return [this, hlo, &operand_to_generator](const IrArray::Index& index) { + return [hlo](const IrArray::Index& index) { return Unimplemented("Unhandled opcode for elemental IR emission: %s", HloOpcodeString(hlo->opcode()).c_str()); }; @@ -2101,11 +2153,11 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( } llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) const { - return ir_builder_->CreateExtractValue(value, {0}); + return b_->CreateExtractValue(value, {0}); } llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) const { - return ir_builder_->CreateExtractValue(value, {1}); + return b_->CreateExtractValue(value, {1}); } llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op, @@ -2113,10 +2165,10 @@ llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op, llvm::Value* imag) const { auto cplx_type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); - auto complex = ir_builder_->CreateInsertValue( + auto complex = b_->CreateInsertValue( llvm::ConstantAggregateZero::get(cplx_type), real, {0}); if (imag != nullptr) { - complex = ir_builder_->CreateInsertValue(complex, imag, {1}); + complex = b_->CreateInsertValue(complex, imag, {1}); } return complex; } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index d199473374ad394913413a7d3fe805f8782936f7..fcb34557a52d35ef30a5dee643171e17407d05c2 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -34,10 +34,8 @@ class ElementalIrEmitter { std::unordered_map; ElementalIrEmitter(const HloModuleConfig& hlo_module_config, - llvm::Module* module, llvm::IRBuilder<>* ir_builder) - : ir_builder_(ir_builder), - module_(module), - hlo_module_config_(hlo_module_config) {} + llvm::Module* module, llvm::IRBuilder<>* b) + : b_(b), module_(module), hlo_module_config_(hlo_module_config) {} virtual ~ElementalIrEmitter() = default; @@ -54,7 +52,7 @@ class ElementalIrEmitter { const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator) const; - llvm::IRBuilder<>* ir_builder() const { return ir_builder_; } + llvm::IRBuilder<>* b() const { return b_; } llvm::Module* module() const { return module_; } protected: @@ -144,9 +142,7 @@ class ElementalIrEmitter { int64 operand_no) const; // Identifier of the thread unique among all threads on the device - virtual llvm::Value* EmitThreadId() const { - return ir_builder_->getIntN(128, 0); - } + virtual llvm::Value* EmitThreadId() const { return b_->getIntN(128, 0); } StatusOr EmitElementalSelect( const HloInstruction* hlo, @@ -188,7 +184,7 @@ class ElementalIrEmitter { const HloToElementGeneratorMap& operand_to_generator, const llvm_ir::IrArray::Index& dot_result_index) const; - llvm::IRBuilder<>* const ir_builder_; + llvm::IRBuilder<>* const b_; llvm::Module* module_; @@ -197,10 +193,17 @@ class ElementalIrEmitter { const HloModuleConfig& hlo_module_config_; private: - // Returns a ElementGenerator for a RNG HloInstruction. - llvm_ir::ElementGenerator MakeRngElementGenerator( + // Returns a ElementGenerator for an RNG HloInstruction using the Philox + // random number generation algorithm. + llvm_ir::ElementGenerator MakePhiloxRngElementGenerator( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator) const; + // Converts the raw value generated by a random number generation algorithm + // to the distribution requested by the RNG HloInstruction. + StatusOr ConvertValueForDistribution( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc index 8980d4303353a132ada2b3c685b4f2856c33c6a1..addb016b0481b744ff42ba827104099b6cdc3bb9 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -57,8 +57,8 @@ ENTRY main { } )"; - std::unique_ptr lhs = Literal::CreateR3({{{1}, {2}}}); - std::unique_ptr rhs = Literal::CreateR3({{{3}, {4}}}); + std::unique_ptr lhs = LiteralUtil::CreateR3({{{1}, {2}}}); + std::unique_ptr rhs = LiteralUtil::CreateR3({{{3}, {4}}}); RunTest(hlo_text, {lhs.get(), rhs.get()}); } } // namespace diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index 6794cfe297b0fb9a15eb9b7e6906d225f9597d07..228c3fac95c3114484637bd93ec51c60b44403cc 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -25,7 +25,7 @@ limitations under the License. namespace xla { AsyncExecution::AsyncExecution(Backend* backend, - std::vector streams, + std::vector streams, const ExecutionProfile& profile, GlobalDataHandle result) : backend_(CHECK_NOTNULL(backend)), @@ -46,9 +46,10 @@ Status AsyncExecution::BlockUntilDone() const { ExecutionTracker::ExecutionTracker() : next_handle_(1) {} -ExecutionHandle ExecutionTracker::Register( - Backend* backend, std::vector streams, - const ExecutionProfile& profile, GlobalDataHandle result) { +ExecutionHandle ExecutionTracker::Register(Backend* backend, + std::vector streams, + const ExecutionProfile& profile, + GlobalDataHandle result) { tensorflow::mutex_lock lock(execution_mutex_); int64 handle = next_handle_++; auto inserted = handle_to_execution_.emplace( diff --git a/tensorflow/compiler/xla/service/execution_tracker.h b/tensorflow/compiler/xla/service/execution_tracker.h index 4458152dd9a98890fc3a3e7f324245ec68821467..4e9b9f883e26f5564a9c63a40d2b4b9348908214 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.h +++ b/tensorflow/compiler/xla/service/execution_tracker.h @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/backend.h" -#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -40,7 +40,7 @@ namespace xla { // the stream when destructed. class AsyncExecution { public: - AsyncExecution(Backend* backend, std::vector streams, + AsyncExecution(Backend* backend, std::vector streams, const ExecutionProfile& profile, GlobalDataHandle result); Status BlockUntilDone() const; @@ -54,7 +54,7 @@ class AsyncExecution { Backend* backend_; // Stream on which the execution is launched. - std::vector streams_; + std::vector streams_; // Profile object of the execution to be returned to the user. ExecutionProfile profile_; @@ -72,7 +72,7 @@ class ExecutionTracker { // Registers an execution with its backend, streams, and data handle to the // execution result. Returns a handle for the registered execution. ExecutionHandle Register(Backend* backend, - std::vector stream, + std::vector stream, const ExecutionProfile& profile, GlobalDataHandle data); diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index d3854b40de3572a60df1ad99d8a4589f59ad7194..8f6608241ed02bbb7e9fde9b6d767c002435e777 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/flatten_call_graph.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -80,7 +80,7 @@ class FlattenCallGraphTest : public HloTestBase { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, kScalarShape, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); return builder.Build(); @@ -157,7 +157,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(PRED, {}), "param0")); HloInstruction* false_constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); builder.AddInstruction( HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, param0, false_constant)); @@ -168,7 +168,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { { HloComputation::Builder builder(TestName() + ".entry"); HloInstruction* false_constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); builder.AddInstruction(HloInstruction::CreateWhile( ShapeUtil::MakeShape(PRED, {}), cond_computation, cond_computation, false_constant)); @@ -232,11 +232,11 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { // computation in the true and false branch. HloComputation::Builder builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(56.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(12.0f))); builder.AddInstruction(HloInstruction::CreateConditional( kScalarShape, pred, constant1, sub_computation, constant2, sub_computation)); diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index 7cd2c9c136acac46e8e6c548c9e58b9bc8e6e0d2..e3a42d0d06be9e4c9ef96ed2e6ff5daa8eebaf3e 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -113,7 +114,7 @@ static StatusOr ExpandIndexVectorIntoOperandSpace( const Shape& index_shape = index_vector->shape(); HloInstruction* zero = computation->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateFromDimensions(index_shape.element_type(), {1}))); + LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1}))); // We extract out individual components from the smaller index and concatenate // them (interspersing zeros as needed) into the larger index. diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 85e28a0dfe38415974e435106a2d0b75863f2df5..0ce2db907b643f3beabd127388370dbe601179e1 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -20,11 +20,10 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -60,17 +59,19 @@ Status GenericTransferManager::WriteSingleTupleIndexTable( void GenericTransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, - std::function>)> done) { + MutableBorrowingLiteral literal, std::function done) { Status status = stream->BlockHostUntilDone(); if (!status.ok()) { return done(status); } - done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer)); + + done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer, + literal)); } -StatusOr> -GenericTransferManager::TransferLiteralFromDeviceInternal( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { +Status GenericTransferManager::TransferLiteralFromDeviceInternal( + se::StreamExecutor* executor, const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal) { VLOG(2) << "transferring literal from device ordinal " << executor->device_ordinal() << "; device buffer: " << device_buffer; TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); @@ -80,9 +81,6 @@ GenericTransferManager::TransferLiteralFromDeviceInternal( TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(), device_buffer.on_host_shape())); - std::unique_ptr literal = - Literal::CreateFromShape(device_buffer.on_host_shape()); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_host_shape(), [&](const Shape& subshape, const ShapeIndex& index) -> Status { @@ -91,12 +89,12 @@ GenericTransferManager::TransferLiteralFromDeviceInternal( /*source=*/device_buffer.buffer(index), /*size=*/GetByteSizeRequirement(subshape), /*destination=*/ - literal->untyped_data(index))); + literal.untyped_data(index))); } return Status::OK(); })); - return std::move(literal); + return Status::OK(); } Status GenericTransferManager::TransferLiteralToDeviceAsync( @@ -158,16 +156,10 @@ Status GenericTransferManager::TransferLiteralToInfeed( return Unimplemented("Generic transfer to Infeed"); } -Status GenericTransferManager::TransferBufferToInfeed( - se::StreamExecutor* executor, int64 size, const void* source) { - return Unimplemented("Generic transfer to Infeed"); -} - Status GenericTransferManager::TransferLiteralFromOutfeed( se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) { - return Unimplemented( - "Outfeed is not supported on this platform (b/30467474)"); + MutableBorrowingLiteral literal) { + return Unimplemented("Generic transfer from Outfeed"); } Status GenericTransferManager::ResetDevices( diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index d216fe7d29e8f2e84ab4f558ee5caec32d07a70a..6c1a21587a7ef5199afb93715dc57be5139fbc22 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/transfer_manager.h" -#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -41,9 +40,10 @@ class GenericTransferManager : public TransferManager { se::Platform::Id PlatformId() const override; - void TransferLiteralFromDevice( - se::Stream* stream, const ShapedBuffer& device_buffer, - std::function>)> done) override; + void TransferLiteralFromDevice(se::Stream* stream, + const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal, + std::function done) override; Status TransferLiteralToDeviceAsync( se::Stream* stream, const LiteralSlice& literal, @@ -53,7 +53,7 @@ class GenericTransferManager : public TransferManager { const LiteralSlice& literal) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) override; + MutableBorrowingLiteral literal) override; Status ResetDevices( tensorflow::gtl::ArraySlice executors) override; @@ -61,17 +61,15 @@ class GenericTransferManager : public TransferManager { int64 GetByteSizeRequirement(const Shape& shape) const override; protected: - Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, - const void* source) override; - Status WriteSingleTupleIndexTable( se::Stream* stream, tensorflow::gtl::ArraySlice elements, const Shape& shape, se::DeviceMemoryBase* region) override; private: - StatusOr> TransferLiteralFromDeviceInternal( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer); + Status TransferLiteralFromDeviceInternal(se::StreamExecutor* executor, + const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal); // The platform this transfer manager targets. const se::Platform::Id platform_id_; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 88f994786a50b2516df845602af796eb12baf579..a3f6e8d9893528642e05354994c1d826949c6063 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -36,6 +36,7 @@ cc_library( hdrs = ["gpu_constants.h"], deps = [ "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework", ], ) @@ -113,11 +114,13 @@ cc_library( srcs = ["hlo_to_ir_bindings.cc"], hdrs = ["hlo_to_ir_bindings.h"], deps = [ + ":buffer_allocations", ":ir_emission_utils", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", + "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", @@ -141,6 +144,7 @@ cc_library( ], deps = [ ":backend_configs", + ":buffer_allocations", ":cudnn_convolution_runner", ":elemental_ir_emitter", ":gpu_constants", @@ -149,8 +153,7 @@ cc_library( ":ir_emission_utils", ":parallel_loop_emitter", ":partition_assignment", - ":while_transformer", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -162,13 +165,17 @@ cc_library( "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:name_uniquer", + "//tensorflow/compiler/xla/service:while_loop_analysis", + "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", + "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", + "//tensorflow/compiler/xla/service/llvm_ir:kernel_tiling", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", - "//tensorflow/compiler/xla/service/llvm_ir:ops", + "//tensorflow/compiler/xla/service/llvm_ir:sort_util", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -199,7 +206,7 @@ cc_library( srcs = ["elemental_ir_emitter.cc"], hdrs = ["elemental_ir_emitter.h"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -214,6 +221,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", + "//tensorflow/compiler/xla/service/llvm_ir:math_ops", "//tensorflow/core:lib", "@llvm//:core", "@llvm//:support", @@ -244,8 +252,9 @@ cc_library( deps = [ "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_execution_profile", - "//tensorflow/compiler/xla/service:pool", + "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:lib", + "//tensorflow/core:ptr_util", "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -264,7 +273,9 @@ cc_library( "infeed_thunk.cc", "kernel_thunk.cc", "memset_thunk.cc", + "outfeed_thunk.cc", "sequential_thunk.cc", + "thunk.cc", "thunk_schedule.cc", "tuple_thunk.cc", "while_thunk.cc", @@ -281,6 +292,7 @@ cc_library( "infeed_thunk.h", "kernel_thunk.h", "memset_thunk.h", + "outfeed_thunk.h", "sequential_thunk.h", "thunk.h", "thunk_schedule.h", @@ -288,15 +300,16 @@ cc_library( "while_thunk.h", ], deps = [ - ":backend_configs", ":buffer_allocations", ":cudnn_convolution_runner", ":hlo_execution_profiler", ":infeed_manager", ":ir_emission_utils", + ":outfeed_manager", ":partition_assignment", ":stream_assignment", "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -314,6 +327,7 @@ cc_library( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service:tuple_points_to_analysis", + "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", @@ -350,6 +364,7 @@ cc_library( ":cudnn_convolution_runner", ":gpu_executable", ":ir_emission_utils", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", @@ -381,7 +396,7 @@ cc_library( hdrs = ["cudnn_convolution_rewriter.h"], deps = [ ":ir_emission_utils", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", @@ -442,6 +457,7 @@ cc_library( srcs = ["multi_output_fusion.cc"], hdrs = ["multi_output_fusion.h"], deps = [ + ":instruction_fusion", ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/service:hlo", @@ -454,6 +470,7 @@ tf_cc_test( name = "multi_output_fusion_test", srcs = ["multi_output_fusion_test.cc"], deps = [ + ":instruction_fusion", ":multi_output_fusion", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", @@ -516,6 +533,24 @@ cc_library( hdrs = ["pad_insertion.h"], deps = [ ":ir_emission_utils", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_creation_utils", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:shape_inference", + ], +) + +cc_library( + name = "pad_for_tensor_cores", + srcs = ["pad_for_tensor_cores.cc"], + hdrs = ["pad_for_tensor_cores.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", @@ -526,13 +561,31 @@ cc_library( ], ) +tf_cc_test( + name = "pad_for_tensor_cores_test", + srcs = ["pad_for_tensor_cores_test.cc"], + deps = [ + ":ir_emission_utils", + ":pad_for_tensor_cores", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep + ], +) + cc_library( name = "gpu_transfer_manager", srcs = ["gpu_transfer_manager.cc"], hdrs = ["gpu_transfer_manager.h"], deps = [ ":gpu_compiler", + ":outfeed_manager", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -551,8 +604,8 @@ cc_library( cc_library( name = "gpu_compiler", - srcs = ["gpu_compiler.cc"], - hdrs = ["gpu_compiler.h"], + srcs = ["nvptx_compiler.cc"], + hdrs = ["nvptx_compiler.h"], deps = [ ":cudnn_convolution_algorithm_picker", ":cudnn_convolution_rewriter", @@ -567,9 +620,11 @@ cc_library( ":ir_emission_utils", ":ir_emitter", ":multi_output_fusion", + ":pad_for_tensor_cores", ":pad_insertion", ":partition_assignment", ":stream_assignment", + ":stream_executor_util", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -581,7 +636,6 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", "//tensorflow/compiler/xla/service:conditional_simplifier", - "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", @@ -601,7 +655,6 @@ cc_library( "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", - "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", "//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter", @@ -623,24 +676,46 @@ cc_library( hdrs = ["cudnn_batchnorm_rewriter.h"], deps = [ ":ir_emission_utils", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", ], ) +cc_library( + name = "xfeed_queue", + hdrs = ["xfeed_queue.h"], + deps = ["//tensorflow/core:lib"], +) + cc_library( name = "infeed_manager", srcs = ["infeed_manager.cc"], hdrs = ["infeed_manager.h"], deps = [ + ":xfeed_queue", + "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", ], ) +cc_library( + name = "outfeed_manager", + srcs = ["outfeed_manager.cc"], + hdrs = ["outfeed_manager.h"], + deps = [ + ":xfeed_queue", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + cc_library( name = "gpu_layout_assignment", srcs = ["gpu_layout_assignment.cc"], @@ -672,6 +747,8 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep ], @@ -710,32 +787,17 @@ tf_cc_test( ], ) -cc_library( - name = "while_transformer", - srcs = ["while_transformer.cc"], - hdrs = ["while_transformer.h"], - deps = [ - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:lib", - ], -) - tf_cc_test( name = "while_transformer_test", srcs = ["while_transformer_test.cc"], deps = [ ":instruction_fusion", - ":while_transformer", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:copy_insertion", "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:while_loop_analysis", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -771,6 +833,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:stream_executor_no_cuda", ], diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index ab5149dcdb09290cd0c0b2233029d0988a95f036..537295292b6ced72c4b2c456557b3c06e0aa5254 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -44,17 +44,27 @@ StatusOr> BufferAllocations::Builder::Build( num_buffers, device_ordinal, memory_allocator, buffer_assignment)); for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { + const BufferAllocation& allocation = buffer_assignment->GetAllocation(i); + const int64 expected_alignment = [&] { + if (allocation.is_entry_computation_parameter()) { + return kEntryParameterAlignBytes; + } else if (allocation.is_constant()) { + return kConstantBufferAlignBytes; + } else { + return kXlaAllocatedBufferAlignBytes; + } + }(); + // If buffer #i's address is already registered (e.g. external arguments or // result buffers), use that registered buffer. if (registered_buffers_.count(i)) { se::DeviceMemoryBase address = FindOrDie(registered_buffers_, i); - if (reinterpret_cast(address.opaque()) % - kCudaMallocAlignBytes != + if (reinterpret_cast(address.opaque()) % expected_alignment != 0) { return InternalError( "Address of registered buffer %lld must be a multiple of %llx, but " "was %p", - i, kCudaMallocAlignBytes, address.opaque()); + i, kEntryParameterAlignBytes, address.opaque()); } buffer_allocations->SetBuffer(i, FindOrDie(registered_buffers_, i)); continue; @@ -62,7 +72,6 @@ StatusOr> BufferAllocations::Builder::Build( // Allocate each allocation that might escape, or is the temp buffer. bool seen_temp_buffer = false; - const BufferAllocation& allocation = buffer_assignment->GetAllocation(i); if (allocation.maybe_live_out() || allocation.IsPreallocatedTempBuffer()) { const int64 buffer_size = allocation.size(); se::DeviceMemoryBase buffer_address; @@ -70,13 +79,12 @@ StatusOr> BufferAllocations::Builder::Build( OwningDeviceMemory buffer; TF_ASSIGN_OR_RETURN( buffer, memory_allocator->Allocate(device_ordinal, buffer_size)); - if (reinterpret_cast(buffer.opaque()) % - kCudaMallocAlignBytes != + if (reinterpret_cast(buffer.opaque()) % expected_alignment != 0) { return InternalError( "Address returned by memory_allocator->Allocate must be a " "multiple of %llx, but was %p", - kCudaMallocAlignBytes, buffer.opaque()); + kXlaAllocatedBufferAlignBytes, buffer.opaque()); } // We do manual memory management within BufferAllocations. Be sure not // to do a TF_RETURN_IF_ERROR between this line and the @@ -165,5 +173,10 @@ void BufferAllocations::SetBuffer(BufferAllocation::Index buffer_index, buffers_[buffer_index] = buffer; } +bool ShouldEmitLiteralInLlvmIr(const Literal& literal) { + // LLVM can sometimes do interesting optimizations using scalar constants. + return ShapeUtil::IsScalar(literal.shape()); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h index 636623502597b3a66523938ba430e9d5a82f796c..f13eab0dd787a2bfa687c991f9d808568360fd24 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h @@ -107,6 +107,12 @@ class BufferAllocations { bool torn_down_ = false; }; +// LLVM and PTXAS don't deal well with large constants, so we only emit very +// small constants directly in LLVM IR. Larger constants are emitted with zero +// initializers in LLVM IR and are later overwritten when the PTX/CUBIN is +// loaded. +bool ShouldEmitLiteralInLlvmIr(const Literal& literal); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index 77a48965e031349b045a956fd3f28c58607328e5..5780e0af40699bb6ac2c190c09cd02023fb44db7 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -32,8 +33,11 @@ ConditionalThunk::ConditionalThunk( predicate_buffer_index_(predicate_buffer_index), true_operand_buffer_index_(true_operand_buffer_index), false_operand_buffer_index_(false_operand_buffer_index), - true_thunk_(std::move(true_thunk_sequence), hlo), - false_thunk_(std::move(false_thunk_sequence), hlo) {} + // Pass nullptr as the HloInstruction* to the true_thunk_ and false_thunk_ + // constructors because these SequentialThunks are logically "part of" + // this ConditionalThunk, and shouldn't be profiled separately from it. + true_thunk_(std::move(true_thunk_sequence), nullptr), + false_thunk_(std::move(false_thunk_sequence), nullptr) {} Status ConditionalThunk::Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) { @@ -43,7 +47,9 @@ Status ConditionalThunk::Initialize(const GpuExecutable& executable, } Status ConditionalThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); // Copy the predicate value from device. bool predicate; se::DeviceMemoryBase predicate_address = @@ -59,10 +65,15 @@ Status ConditionalThunk::ExecuteOnStream( // Execute the true or the false computation depending on the value of the // predicate. if (predicate) { - TF_RETURN_IF_ERROR(true_thunk_.ExecuteOnStream(buffer_allocations, stream)); + profiler->StartHloComputation(); + TF_RETURN_IF_ERROR( + true_thunk_.ExecuteOnStream(buffer_allocations, stream, profiler)); + profiler->FinishHloComputation(hlo_instruction()->true_computation()); } else { + profiler->StartHloComputation(); TF_RETURN_IF_ERROR( - false_thunk_.ExecuteOnStream(buffer_allocations, stream)); + false_thunk_.ExecuteOnStream(buffer_allocations, stream, profiler)); + profiler->FinishHloComputation(hlo_instruction()->false_computation()); } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h index ee03865d174469285a9e98b8a30fea90d997df37..aef24342c9fe182eb54b1c2beff840a76e7b8115 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_ #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -50,7 +51,8 @@ class ConditionalThunk : public Thunk { Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: BufferAllocation::Slice predicate_buffer_index_; diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index f0881124128c9b043392ffc4fa3aee2cd5b754c7..7833a4077e6c6ee4960665f37fb01a35530fd302 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -55,7 +56,8 @@ ConvolutionThunk::ConvolutionThunk( tensor_ops_enabled_(tensor_ops_enabled) {} Status ConvolutionThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { se::DeviceMemoryBase input_data = buffer_allocations.GetDeviceAddress(input_buffer_); se::DeviceMemoryBase filter_data = @@ -68,6 +70,7 @@ Status ConvolutionThunk::ExecuteOnStream( se::dnn::AlgorithmConfig algorithm_config( se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_)); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); TF_RETURN_IF_ERROR(RunCudnnConvolution( convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data, filter_data, output_data, scratch, window_, dim_nums_, algorithm_config, diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 6d845025b1aef2b0a5f147401b6db0598ba94d6d..d76ca6698dcf462c3c4961ce6a9784822af3a81f 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" @@ -66,7 +67,8 @@ class ConvolutionThunk : public Thunk { // Does the convolution for the thunk on "stream". Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: class ScratchAllocator; diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc index ee38c0318a878c7bcdc02afdcd146bfb4498d9a2..92e03f94c11f68082f0a8caa64f82e8533557194 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -30,9 +31,11 @@ HostToDeviceCopyThunk::HostToDeviceCopyThunk( mem_size_(mem_size) {} Status HostToDeviceCopyThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenMemcpy(&destination_data, source_address_, mem_size_); return Status::OK(); } @@ -47,11 +50,13 @@ DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( mem_size_(mem_size) {} Status DeviceToDeviceCopyThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); se::DeviceMemoryBase source_data = buffer_allocations.GetDeviceAddress(source_buffer_); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenMemcpy(&destination_data, source_data, mem_size_); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.h b/tensorflow/compiler/xla/service/gpu/copy_thunk.h index 8b128386f61636de9ac41e856a2b00c578e05735..91564b520acae1839e0a466cf580db00bdf57e46 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -40,7 +41,8 @@ class HostToDeviceCopyThunk : public Thunk { HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const void* source_address_; @@ -63,7 +65,8 @@ class DeviceToDeviceCopyThunk : public Thunk { DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const BufferAllocation::Slice source_buffer_; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc index c77e3c81c9d38af7857ad1389d20221514bf38f1..60289506524759580dbb9b82147c78c4ce1cb25e 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -66,11 +67,12 @@ Status Visitor::HandleBatchNormInference(HloInstruction* batch_norm) { return Status::OK(); } - HloInstruction* epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + HloInstruction* epsilon = + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(batch_norm->epsilon()))); HloInstruction* feature_index = computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(batch_norm->feature_index()))); + LiteralUtil::CreateR0(batch_norm->feature_index()))); std::vector operands(batch_norm->operands().begin(), batch_norm->operands().end()); @@ -101,11 +103,12 @@ Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) { return Status::OK(); } - HloInstruction* epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + HloInstruction* epsilon = + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(batch_norm->epsilon()))); HloInstruction* feature_index = computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(batch_norm->feature_index()))); + LiteralUtil::CreateR0(batch_norm->feature_index()))); std::vector operands(batch_norm->operands().begin(), batch_norm->operands().end()); @@ -128,8 +131,8 @@ Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) { inverse_stddev->shape(), HloOpcode::kPower, inverse_stddev, computation_->AddInstruction(HloInstruction::CreateBroadcast( inverse_stddev->shape(), - computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-2))), + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(-2))), {})))); HloInstruction* variance = computation_->AddInstruction(HloInstruction::CreateBinary( @@ -169,11 +172,12 @@ Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) { return Status::OK(); } - HloInstruction* epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + HloInstruction* epsilon = + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(batch_norm->epsilon()))); HloInstruction* feature_index = computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(batch_norm->feature_index()))); + LiteralUtil::CreateR0(batch_norm->feature_index()))); // The cudnn libcall expects its input to be rsqrt(variance + epsilon), but // the batchnorm HLO takes plain variance as input. Fix it up. @@ -189,7 +193,7 @@ Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) { computation_->AddInstruction(HloInstruction::CreateBroadcast( var_plus_epsilon->shape(), computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(-.5))), + LiteralUtil::CreateR0(-.5))), {})))); std::vector operands(batch_norm->operands().begin(), diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc index 68099fd63847ef9993f9bc7ac0e28b2939631b35..7b172812c36bb141787ef3a9285d6f7ce13e343b 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -99,13 +100,15 @@ CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk( } Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { dnn::BatchDescriptor operand_desc; dnn::BatchDescriptor scale_offset_desc; std::tie(operand_desc, scale_offset_desc) = MakeDescriptors(hlo_instruction()->shape(), feature_index_); se::DeviceMemory output(buffer_allocations.GetDeviceAddress(output_)); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenBatchNormalizationForward( se::DeviceMemory(buffer_allocations.GetDeviceAddress(operand_)), se::DeviceMemory(buffer_allocations.GetDeviceAddress(scale_)), @@ -123,6 +126,7 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( /*is_training=*/false, // /*var_to_inv_var=*/nullptr, // /*inv_var_to_var=*/nullptr); + if (!stream->ok()) { return InternalError("BatchNormalizationForward call failed."); } @@ -158,7 +162,8 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( } Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { dnn::BatchDescriptor operand_desc; dnn::BatchDescriptor scale_offset_desc; // The BatchNormTraining HLO outputs a tuple of three elements: output data, @@ -175,6 +180,7 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( buffer_allocations.GetDeviceAddress(output_inv_stddev_)); se::DeviceMemory null_device_ptr(nullptr); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenBatchNormalizationForward( se::DeviceMemory(buffer_allocations.GetDeviceAddress(operand_)), se::DeviceMemory(buffer_allocations.GetDeviceAddress(scale_)), @@ -240,7 +246,8 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( } Status CudnnBatchNormBackwardThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { dnn::BatchDescriptor operand_desc; dnn::BatchDescriptor scale_offset_desc; @@ -257,6 +264,7 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream( se::DeviceMemory output_grad_offset( buffer_allocations.GetDeviceAddress(output_grad_offset_)); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenBatchNormalizationBackward( se::DeviceMemory( buffer_allocations.GetDeviceAddress(grad_output_)), diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h index 874f85a863092ee05ae5df1f92d732318c5a0554..d2143b3952984722d136757255aa0aa60e9cab7e 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" @@ -60,7 +61,8 @@ class CudnnBatchNormForwardInferenceThunk : public Thunk { const CudnnBatchNormForwardInferenceThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: BufferAllocation::Slice operand_; @@ -90,7 +92,8 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk { const CudnnBatchNormForwardTrainingThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: BufferAllocation::Slice operand_; @@ -123,7 +126,8 @@ class CudnnBatchNormBackwardThunk : public Thunk { delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: BufferAllocation::Slice operand_; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 3dc98c4c93ea2b9b68dd3ee27794a39847f8756c..7348307ec8a7286dfb733d6b9685862b20f11ac9 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -14,12 +14,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/mutex.h" namespace xla { namespace gpu { @@ -80,8 +82,7 @@ bool ShouldIncludeWinogradNonfusedAlgo(const Shape& input_shape, const ConvolutionDimensionNumbers& dnums, se::StreamExecutor* stream_exec) { // Skip this check for cudnn7 and newer. - auto version = - stream_exec->AsDnn()->GetVersion(); + auto version = stream_exec->AsDnn()->GetVersion(); if (version.ok() && version.ValueOrDie().major_version() >= 7) { return true; } @@ -137,6 +138,28 @@ string NumBytesToString(int64 bytes) { tensorflow::strings::HumanReadableNumBytes(bytes), " (", bytes, "B)"); } +// Acquires a process-global lock on the device pointed to by the given +// StreamExecutor. +// +// This is used to prevent other XLA instances from trying to autotune on this +// device while we're using it. +tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + // se::Platform*s are global singletons guaranteed to live forever. + static auto* mutexes = + new std::map, + tensorflow::mutex>(); + + tensorflow::mutex_lock global_lock(mu); + auto it = mutexes + ->emplace(std::piecewise_construct, + std::make_tuple(stream_exec->platform(), + stream_exec->device_ordinal()), + std::make_tuple()) + .first; + return tensorflow::mutex_lock{it->second}; +} + } // anonymous namespace // We could have caching here so that we don't redo this work for two identical @@ -155,6 +178,13 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, const ConvolutionDimensionNumbers& dnums, HloInstruction* instr) { + // Don't run this function concurrently on the same GPU. + // + // This is a bit of a hack and doesn't protect us against arbitrary concurrent + // use of a GPU, but it's sufficient to let us compile two HLO modules + // concurrently and then run them sequentially. + tensorflow::mutex_lock lock = LockGpu(stream_exec_); + // Create a stream for us to do our work on. se::Stream stream{stream_exec_}; stream.Init(); @@ -338,8 +368,8 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( computation->AddInstruction(HloInstruction::CreateTuple( {computation->AddInstruction(HloInstruction::CreateGetTupleElement( new_call_shape.tuple_shapes(0), new_call, 0)), - computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({})))})); + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({})))})); TF_RETURN_IF_ERROR(instr->parent()->ReplaceInstruction(instr, new_tuple)); return true; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index f9dccd287d955502858f6c24ccd4de80256fc148..905b5ee8767d0fa0514c7f1abf83bc089cd08045 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 27d2c3e491bfc2108cbd168d1a5e1575c2eed11f..cc38db27e2680e950f74e104cef8829585c7b81c 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -29,12 +29,13 @@ limitations under the License. #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -67,8 +68,8 @@ bool IsFPLiteralWithValue(const HloInstruction* operand, float value) { GpuElementalIrEmitter::GpuElementalIrEmitter( const HloModuleConfig& hlo_module_config, llvm::Module* module, - llvm::IRBuilder<>* ir_builder, NestedComputer compute_nested) - : ElementalIrEmitter(hlo_module_config, module, ir_builder), + llvm::IRBuilder<>* b, NestedComputer compute_nested) + : ElementalIrEmitter(hlo_module_config, module, b), hlo_module_config_(hlo_module_config), compute_nested_(std::move(compute_nested)) {} @@ -92,8 +93,8 @@ StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( cast_result_to_fp16 = true; for (int64 i = 0; i < operands.size(); ++i) { if (input_types[i] == F16) { - converted_operands[i] = ir_builder_->CreateFPCast( - converted_operands[i], ir_builder_->getFloatTy()); + converted_operands[i] = + b_->CreateFPCast(converted_operands[i], b_->getFloatTy()); converted_input_types[i] = F32; } } @@ -112,7 +113,7 @@ StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( converted_input_types, output_type) .ValueOrDie(); if (cast_result_to_fp16) { - result = ir_builder_->CreateFPCast(result, ir_builder_->getHalfTy()); + result = b_->CreateFPCast(result, b_->getHalfTy()); } return result; } @@ -215,7 +216,7 @@ StatusOr GpuElementalIrEmitter::EmitPowerOp( // LLVM's NVPTX backend knows how to transform 1/sqrt(A) into the NVPTX // rsqrt.approx instruction. TF_ASSIGN_OR_RETURN(auto* sqrt, make_sqrt()); - return ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt); + return b_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt); } VLOG(10) << "emitting pow as regular call to pow(): " << op->ToString(); @@ -277,6 +278,16 @@ StatusOr GpuElementalIrEmitter::EmitFloatUnaryOp( PrimitiveType output_type = op->shape().element_type(); switch (op->opcode()) { case HloOpcode::kTanh: + // If we don't care much about precision, emit a fast approximation of + // tanh. + if (hlo_module_config_.debug_options().xla_enable_fast_math()) { + // Upcast F16 to F32 if necessary. + llvm::Type* type = + input_type == F16 ? b_->getFloatTy() : operand_value->getType(); + llvm::Value* input = b_->CreateFPCast(operand_value, type); + llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input); + return b_->CreateFPCast(fast_tanh, operand_value->getType()); + } return EmitLibdeviceMathCall("__nv_tanh", {operand_value}, {input_type}, output_type); default: @@ -302,32 +313,31 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( // Declares the callee if it is not declared already. llvm::Function* callee = llvm::cast( - ir_builder_->GetInsertBlock()->getModule()->getOrInsertFunction( + b_->GetInsertBlock()->getModule()->getOrInsertFunction( llvm_ir::AsStringRef(callee_name), callee_type)); for (auto attribute : attributes) { callee->addFnAttr(attribute); } - return ir_builder_->CreateCall(callee, llvm_ir::AsArrayRef(operands)); + return b_->CreateCall(callee, llvm_ir::AsArrayRef(operands)); } llvm::Value* GpuElementalIrEmitter::EmitThreadId() const { - llvm::Value* block_id = ir_builder_->CreateIntCast( + llvm::Value* block_id = b_->CreateIntCast( llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, - {}, {}, ir_builder_), - ir_builder_->getIntNTy(128), /*isSigned=*/true, "block.id"); - llvm::Value* thread_id_in_block = ir_builder_->CreateIntCast( + {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "block.id"); + llvm::Value* thread_id_in_block = b_->CreateIntCast( llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, - {}, {}, ir_builder_), - ir_builder_->getIntNTy(128), /*isSigned=*/true, "thread.id"); - llvm::Value* threads_per_block = ir_builder_->CreateIntCast( + {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "thread.id"); + llvm::Value* threads_per_block = b_->CreateIntCast( llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, - {}, {}, ir_builder_), - ir_builder_->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); - return ir_builder_->CreateNSWAdd( - ir_builder_->CreateNSWMul(block_id, threads_per_block), - thread_id_in_block); + {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); + return b_->CreateNSWAdd(b_->CreateNSWMul(block_id, threads_per_block), + thread_id_in_block); } llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( @@ -373,12 +383,12 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( PrimitiveType operand_element_type = operand->shape().element_type(); llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), - "reduce_window_accum_ptr", ir_builder_); + "reduce_window_accum_ptr", b_); { TF_ASSIGN_OR_RETURN(llvm::Value * init_value, operand_to_generator.at(hlo->operand(1))( IrArray::Index(index.GetType()))); - ir_builder_->CreateStore(init_value, accum_ptr); + b_->CreateStore(init_value, accum_ptr); } llvm::Type* index_type = index.GetType(); @@ -386,7 +396,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( return index.GetConstantWithIndexType(c); }; - llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_, index_type); + llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type); std::vector window_size; for (const auto& dim : window.dimensions()) { window_size.push_back(dim.size()); @@ -395,15 +405,15 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( ShapeUtil::MakeShape(operand_element_type, window_size), "window"); CHECK_EQ(window_index.size(), index.size()); - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), ir_builder_); + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_); IrArray::Index input_index(index_type, index.size()); - llvm::Value* in_bounds = ir_builder_->getInt1(true); + llvm::Value* in_bounds = b_->getInt1(true); for (size_t i = 0; i < index.size(); ++i) { - llvm::Value* stridden_index = ir_builder_->CreateNSWMul( + llvm::Value* stridden_index = b_->CreateNSWMul( index[i], index_typed_const(window.dimensions(i).stride())); - input_index[i] = ir_builder_->CreateNSWSub( - ir_builder_->CreateNSWAdd(stridden_index, window_index[i]), + input_index[i] = b_->CreateNSWSub( + b_->CreateNSWAdd(stridden_index, window_index[i]), index_typed_const(window.dimensions(i).padding_low())); // We must check whether 0 ≤ input_index[i] < bound, as otherwise @@ -411,16 +421,16 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( // comparison is equivalent to the unsigned comparison // input_index[i] < bound, as a negative value wraps to a large // positive value. - in_bounds = ir_builder_->CreateAnd( + in_bounds = b_->CreateAnd( in_bounds, - ir_builder_->CreateICmpULT( + b_->CreateICmpULT( input_index[i], index_typed_const(operand->shape().dimensions(i)))); } llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); - SetToFirstInsertPoint(if_data.true_block, ir_builder_); + llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_); + SetToFirstInsertPoint(if_data.true_block, b_); // We are not in pad, so do the computation. TF_ASSIGN_OR_RETURN(llvm::Value * input_value, @@ -428,26 +438,26 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( TF_ASSIGN_OR_RETURN( llvm::Value * accum_value, compute_nested_(*hlo->to_apply(), - {ir_builder_->CreateLoad(accum_ptr), input_value})); - ir_builder_->CreateStore(accum_value, accum_ptr); + {b_->CreateLoad(accum_ptr), input_value})); + b_->CreateStore(accum_value, accum_ptr); - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), ir_builder_); - return ir_builder_->CreateLoad(accum_ptr); + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_); + return b_->CreateLoad(accum_ptr); }; case HloOpcode::kReduce: return [=, &operand_to_generator]( const IrArray::Index& output_index) -> StatusOr { const HloInstruction* operand = hlo->operand(0); llvm::Value* accum_ptr = - ir_builder()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType( + b()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType( hlo->shape().element_type(), module_)); llvm::Type* index_type = output_index.GetType(); TF_ASSIGN_OR_RETURN(llvm::Value * init_value, operand_to_generator.at(hlo->operand(1))( IrArray::Index(index_type))); - ir_builder()->CreateStore(init_value, accum_ptr); + b()->CreateStore(init_value, accum_ptr); - llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_, index_type); + llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type); IrArray::Index input_index = loops.AddLoopsForShapeOnDimensions( operand->shape(), hlo->dimensions(), "reduction_dim"); if (!ShapeUtil::IsScalar(hlo->shape())) { @@ -462,18 +472,17 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( CHECK_EQ(output_index.size(), j); } - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), ir_builder()); + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b()); TF_ASSIGN_OR_RETURN( llvm::Value * input_value, operand_to_generator.at(hlo->operand(0))(input_index)); TF_ASSIGN_OR_RETURN( llvm::Value * accum_value, - compute_nested_( - *hlo->to_apply(), - {ir_builder()->CreateLoad(accum_ptr), input_value})); - ir_builder()->CreateStore(accum_value, accum_ptr); - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), ir_builder()); - return ir_builder()->CreateLoad(accum_ptr); + compute_nested_(*hlo->to_apply(), + {b()->CreateLoad(accum_ptr), input_value})); + b()->CreateStore(accum_value, accum_ptr); + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b()); + return b()->CreateLoad(accum_ptr); }; default: return ElementalIrEmitter::MakeElementGenerator(hlo, diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 91f4d960aa62fff3e0699ece37a8c74d7dcf2f59..e3eacef133cb8b615a645ca2f11dd6dedf9f0176 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -43,7 +43,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { const HloComputation&, tensorflow::gtl::ArraySlice)>; GpuElementalIrEmitter(const HloModuleConfig& hlo_module_config, - llvm::Module* module, llvm::IRBuilder<>* ir_builder, + llvm::Module* module, llvm::IRBuilder<>* b, NestedComputer compute_nested); llvm_ir::ElementGenerator MakeElementGenerator( diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index e14ee6918bf148861ecccac99355fccf7ae93103..0cdddf8bcfd4e849b311bf810eda471d79dbf106 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -107,7 +108,8 @@ FftThunk::FftThunk(FftType fft_type, output_shape_(output_shape) {} Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { + se::Stream* stream, + HloExecutionProfiler* profiler) { VLOG(3) << "FFT type: " << FftTypeToString(fft_type_); VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_); VLOG(3) << "Output shape: " @@ -116,6 +118,7 @@ Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, FftScratchAllocator scratch_allocator(buffer_allocations.device_ordinal(), buffer_allocations.memory_allocator()); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); if (fft_plan_ == nullptr) { const int64 fft_rank = fft_length_.size(); CHECK_LE(fft_rank, 3); diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h index b0a22564f3a09bb67a3c01723f6e37c604656d45..8c53be5077b0c5a88d303c729457139c6cb800f1 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" @@ -72,7 +73,8 @@ class FftThunk : public Thunk { // Does the FFT for the thunk on "stream". Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const se::fft::Type fft_type_; diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index b36539e0cb8d0a2f4758dd90acbdd8fc7181b8ca..2fd2206324e5f763490780a54880825a772b7ea2 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -27,8 +28,11 @@ ForThunk::ForThunk(const int64 loop_limit, const HloInstruction* hlo) : Thunk(Kind::kWhile, hlo), loop_limit_(loop_limit), - body_thunk_sequence_( - MakeUnique(std::move(*body_thunk_sequence), hlo)) {} + body_thunk_sequence_(MakeUnique( + // Pass nullptr as the HloInstruction* to the body_thunk_sequence_ + // constructor because this SequentialThunk is logically "part of" + // this ForThunk, and shouldn't be profiled separately from it. + std::move(*body_thunk_sequence), nullptr)) {} Status ForThunk::Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) { @@ -37,11 +41,17 @@ Status ForThunk::Initialize(const GpuExecutable& executable, } Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { + se::Stream* stream, + HloExecutionProfiler* profiler) { + VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for " + << (hlo_instruction() ? hlo_instruction()->ToString() : ""); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); for (int64 i = 0; i < loop_limit_; ++i) { + profiler->StartHloComputation(); // Invoke loop body thunk sequence. - TF_RETURN_IF_ERROR( - body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); + TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations, + stream, profiler)); + profiler->FinishHloComputation(hlo_instruction()->while_body()); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h index 41ddfe0ceb1d0516c1c64feca53212a925632209..c2d39071b292c6704e9b5857a68bd8b3f3b9a914 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -39,7 +40,8 @@ class ForThunk : public Thunk { Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const int64 loop_limit_; diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 79fca43d022816645b8a07b9e806fe9cc3745e7c..74282c568c09921dbeec2e9cce79b6c73b6ea592 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -31,16 +32,19 @@ namespace { // dimensions. struct MatrixDescriptor { MatrixDescriptor(se::DeviceMemoryBase matrix_data, bool needs_transpose, - int64 matrix_num_rows, int64 matrix_num_cols) + int64 matrix_num_rows, int64 matrix_num_cols, + int64 matrix_batch_size) : data(matrix_data), transpose(needs_transpose), num_rows(matrix_num_rows), - num_cols(matrix_num_cols) {} + num_cols(matrix_num_cols), + batch_size(matrix_batch_size) {} se::DeviceMemoryBase data; bool transpose; // Whether this matrix needs to be transposed. int64 num_rows; int64 num_cols; + int64 batch_size; }; // Performs a gemm call without an explicit algorithm on lhs_matrix and @@ -50,6 +54,9 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, MatrixDescriptor output_matrix, double alpha, se::Stream* stream) { DCHECK(!output_matrix.transpose); + const int64 batch_size = lhs_matrix.batch_size; + CHECK_EQ(batch_size, rhs_matrix.batch_size); + CHECK_EQ(batch_size, output_matrix.batch_size); se::DeviceMemory lhs_data(lhs_matrix.data); se::DeviceMemory rhs_data(rhs_matrix.data); se::DeviceMemory output_data(output_matrix.data); @@ -60,13 +67,30 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, : se::blas::Transpose::kNoTranspose; auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols; + if (batch_size == 1) { + return stream + ->ThenBlasGemm( + lhs_transpose, rhs_transpose, output_matrix.num_rows, + output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha, + lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, + /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, + &output_data, /*leading dim of output=*/output_matrix.num_rows) + .ok(); + } + + int64 lhs_stride = lhs_matrix.num_rows * lhs_matrix.num_cols; + int64 rhs_stride = rhs_matrix.num_rows * rhs_matrix.num_cols; + int64 output_stride = output_matrix.num_rows * output_matrix.num_cols; return stream - ->ThenBlasGemm( + ->ThenBlasGemmStridedBatched( lhs_transpose, rhs_transpose, output_matrix.num_rows, - output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha, - lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, - /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, - &output_data, /*leading dim of output=*/output_matrix.num_rows) + output_matrix.num_cols, /*size of reduce dim=*/k, + /*alpha=*/alpha, lhs_data, + /*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data, + /*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride, + /*beta=*/0.0, &output_data, + /*leading dim of output=*/output_matrix.num_rows, output_stride, + batch_size) .ok(); } @@ -93,6 +117,10 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, se::blas::ProfileResult* output_profile_result) { DCHECK(!output_matrix.transpose); + CHECK_EQ(1, lhs_matrix.batch_size); + CHECK_EQ(1, rhs_matrix.batch_size); + CHECK_EQ(1, output_matrix.batch_size); + se::DeviceMemory lhs_data(lhs_matrix.data); se::DeviceMemory rhs_data(rhs_matrix.data); se::DeviceMemory output_data(output_matrix.data); @@ -141,9 +169,15 @@ StatusOr DoGemmAutotune( alpha, computation_type, algorithm, stream, &profile_result)); - if (profile_result.is_valid() && profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; + if (profile_result.is_valid()) { + VLOG(3) << "cublas gemm algorithm " << algorithm << " took " + << profile_result.elapsed_time_in_ms() << "ms"; + if (profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + } + } else { + VLOG(4) << "cublas gemm algorithm " << algorithm << " failed."; } } @@ -167,6 +201,8 @@ auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm) { return &DoGemm; case F64: return &DoGemm; + case C64: + return &DoGemm>; default: LOG(FATAL) << "Unsupported type."; } @@ -180,6 +216,8 @@ auto GetGemmWithAlgorithmFn(PrimitiveType type) return &DoGemmWithAlgorithm; case F64: return &DoGemmWithAlgorithm; + case C64: + return &DoGemmWithAlgorithm>; default: LOG(FATAL) << "Unsupported type."; } @@ -192,6 +230,8 @@ auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune) { return &DoGemmAutotune; case F64: return &DoGemmAutotune; + case C64: + return &DoGemmAutotune>; default: LOG(FATAL) << "Unsupported type."; } @@ -210,6 +250,8 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) { return se::blas::ComputationType::kF32; case F64: return se::blas::ComputationType::kF64; + case C64: + return se::blas::ComputationType::kComplexF32; default: LOG(FATAL) << "Unsupported type."; } @@ -252,7 +294,8 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, alpha_(alpha) {} Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { + se::Stream* stream, + HloExecutionProfiler* profiler) { VLOG(2) << "Executing a GemmThunk"; se::DeviceMemoryBase lhs_data = @@ -262,12 +305,37 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, se::DeviceMemoryBase output_data = buffer_allocations.GetDeviceAddress(output_buffer_); + DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction()); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), + dim_nums.rhs_batch_dimensions_size()); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, + ShapeUtil::Rank(output_shape_)); + + int64 row_dim = dim_nums.lhs_batch_dimensions_size(); + int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1; + int64 batch_size = std::accumulate(output_shape_.dimensions().begin(), + output_shape_.dimensions().end() - 2, 1, + std::multiplies()); + + // Check that the batch dims don't cover the last two dims. + for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) { + CHECK_NE(row_dim, batch_dim); + CHECK_NE(col_dim, batch_dim); + } + + // Verify that the non-batch dimensions are minor-most. This is required for + // efficient access. + for (const auto* shape : {&lhs_shape_, &rhs_shape_, &output_shape_}) { + CHECK_LT(shape->layout().minor_to_major(row_dim), 2); + CHECK_LT(shape->layout().minor_to_major(col_dim), 2); + } + // BLAS gemm reduces rows of LHS and columns of RHS. The Dot operator between // matrices reduces dimension 1 of LHS and dimension 0 of RHS regardless of // their layout. Therefore, we should treat dimension 0 as row and dimension 1 // as column when mapping a matrix Dot to BLAS gemm. - int64 output_num_rows = output_shape_.dimensions(0); - int64 output_num_cols = output_shape_.dimensions(1); + int64 output_num_rows = output_shape_.dimensions(row_dim); + int64 output_num_cols = output_shape_.dimensions(col_dim); // BLAS gemm expects the inputs and the output are in column-major order. // Therefore, we need to convert dot between row-major matrices to that @@ -290,34 +358,46 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, // the leading dimension of the LHS matrix of gemm is the number of rows in // B^T and thus the number of columns in B. - auto make_descriptor = [this](se::DeviceMemoryBase data, const Shape& shape, - bool transpose) -> MatrixDescriptor { - bool is_row_major = LayoutUtil::Minor(shape.layout(), 0) != 0; - bool layout_mismatch = LayoutUtil::Minor(shape.layout(), 0) != - LayoutUtil::Minor(output_shape_.layout(), 0); - return MatrixDescriptor(data, transpose ^ layout_mismatch, - shape.dimensions(is_row_major), - shape.dimensions(!is_row_major)); + auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape& shape, + bool transpose) -> MatrixDescriptor { + bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0; + bool layout_mismatch = LayoutUtil::Minor(shape.layout(), row_dim) != + LayoutUtil::Minor(output_shape_.layout(), row_dim); + return MatrixDescriptor( + data, transpose ^ layout_mismatch, + shape.dimensions(row_dim + static_cast(is_row_major)), + shape.dimensions(row_dim + static_cast(!is_row_major)), + batch_size); }; - DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction()); - const MatrixDescriptor lhs_descriptor = make_descriptor( - lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == 0); + lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == row_dim); const MatrixDescriptor rhs_descriptor = make_descriptor( - rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == 1); + rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == col_dim); // Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to // autotune this gemm to figure out the best algorithm. - auto launch = [this](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, se::Stream* stream) { + auto launch = [&](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, + MatrixDescriptor output_matrix, se::Stream* stream) { PrimitiveType element_type = output_shape_.element_type(); se::blas::ComputationType computation_type = GetBlasComputationType(element_type); + // TODO(b/112111608): Implement auto tune for batched gemm. + if (batch_size != 1) { + return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, + alpha_, stream); + } + + auto thunk_name = [&] { + return hlo_instruction() != nullptr ? hlo_instruction()->ToString() + : ""; + }; + const string& device_name = stream->parent()->GetDeviceDescription().name(); auto autotune_it = autotune_results_.find(device_name); if (autotune_it == autotune_results_.end()) { + VLOG(3) << "Starting autotune of GemmThunk " << thunk_name(); StatusOr best_algorithm = GetGemmAutotuneFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, alpha_, computation_type, stream); @@ -325,11 +405,11 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, autotune_results_.insert({device_name, best_algorithm}).first; if (autotune_it->second.ok()) { - VLOG(2) << "Autotune on GemmThunk " << this + VLOG(2) << "Autotune on GemmThunk " << thunk_name() << " successful; best algorithm is " << best_algorithm.ValueOrDie(); } else { - VLOG(2) << "Autotune on GemmThunk " << this + VLOG(2) << "Autotune on GemmThunk " << thunk_name() << " unsuccessful. Will use generic gemm."; } } @@ -339,7 +419,7 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, if (best_algorithm.ok()) { auto algorithm = best_algorithm.ValueOrDie(); VLOG(2) << "Using algorithm " << algorithm - << " chosen by autotuning on GemmThunk " << this; + << " chosen by autotuning on GemmThunk " << thunk_name(); return GetGemmWithAlgorithmFn(element_type)( lhs_matrix, rhs_matrix, output_matrix, alpha_, computation_type, algorithm, stream, @@ -352,17 +432,18 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, alpha_, stream); }; + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); bool launch_ok; - if (LayoutUtil::Minor(output_shape_.layout(), 0) == 0) { - launch_ok = launch( - lhs_descriptor, rhs_descriptor, - MatrixDescriptor(output_data, false, output_num_rows, output_num_cols), - stream); + if (LayoutUtil::Minor(output_shape_.layout(), row_dim) == 0) { + launch_ok = launch(lhs_descriptor, rhs_descriptor, + MatrixDescriptor(output_data, false, output_num_rows, + output_num_cols, batch_size), + stream); } else { - launch_ok = launch( - rhs_descriptor, lhs_descriptor, - MatrixDescriptor(output_data, false, output_num_cols, output_num_rows), - stream); + launch_ok = launch(rhs_descriptor, lhs_descriptor, + MatrixDescriptor(output_data, false, output_num_cols, + output_num_rows, batch_size), + stream); } if (!launch_ok) { diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index 7a4830d64e7caef5a1170cbdbf8ab373fdaf16e2..939c7f85e35b4fcb943a25aa6346d72798432920 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -48,7 +49,8 @@ class GemmThunk : public Thunk { // Does the gemm operation for the thunk on "stream", which must be non-null. Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; // Returns true if we'll perform autotuning if run on the given stream. If // so, we want the GPU to be quiescent during autotuning, so as not to diff --git a/tensorflow/compiler/xla/service/gpu/gpu_constants.cc b/tensorflow/compiler/xla/service/gpu/gpu_constants.cc index aa360c7f73de2f0f9cf59c22b552b8e60ddb3a87..7f0b030fece8f25578bd90a538279d455350278a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_constants.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_constants.cc @@ -14,12 +14,23 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" +#include "tensorflow/core/framework/allocator.h" namespace xla { namespace gpu { -// http://docs.nvidia.com/cuda/cuda-c-programming-guide/#device-memory-accesses -const int64 kCudaMallocAlignBytes = 256; +// kEntryParameterAlignBytes is equal to EIGEN_MAX_ALIGN_BYTES, though including +// Eigen headers here to get that symbol may not be a good idea. +// EIGEN_MAX_ALIGN_BYTES may differ between CUDA-enabled builds vs CUDA-disabled +// builds and we don't want the IR generated by XLA:GPU to depend on that. +// +// TODO(b/111767313): Consider raising EIGEN_MAX_ALIGN_BYTES if it helps. +const int64 kEntryParameterAlignBytes = 16; + +const int64 kXlaAllocatedBufferAlignBytes = + tensorflow::Allocator::kAllocatorAlignment; + +const int64 kConstantBufferAlignBytes = kXlaAllocatedBufferAlignBytes; } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_constants.h b/tensorflow/compiler/xla/service/gpu/gpu_constants.h index eb1ca4c6c95a23d2a08f5f9c3cbc85e7d47d4f89..6f5f1fa09c57dfd246d702c0adc92c7e2e76805a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_constants.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_constants.h @@ -21,9 +21,15 @@ limitations under the License. namespace xla { namespace gpu { -// Minimum alignment of cudaMalloc. We require that buffers created by our -// DeviceMemoryAllocator, and all input/output buffers, have this alignment. -extern const int64 kCudaMallocAlignBytes; +// Minimum alignment for buffers passed as incoming arguments by TensorFlow. +extern const int64 kEntryParameterAlignBytes; + +// Minimum alignment for buffers allocated by XLA: the temp buffers and the live +// out (result) buffers. +extern const int64 kXlaAllocatedBufferAlignBytes; + +// Minimum alignment for constant buffers. +extern const int64 kConstantBufferAlignBytes; } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index fbc1303085b579e898d2f503a341754109768567..75f414e47fe3edcc1b10b392ed5cc5038be6c190 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -48,80 +48,17 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(bool changed, generic_copy_insertion.Run(module)); - TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow, - HloDataflowAnalysis::Run(*module)); - - // Make sure all operands of a library call are in memory instead of constants - // in IR. Also, init values of while and conditional nodes cannot be - // constants. Insert copies for any constants found at the operands of these - // nodes. - tensorflow::gtl::FlatSet inserted_copies; + // Check the assumption that the epsilon and feature_index constants of the + // CUDNN batchnorm op are not shared with other ops where we would replace + // them with a copy. These custom op calls are generated with the + // CudnnBatchNormRewriter, so this would only happen if HloCSE merges them. for (HloComputation* computation : module->computations()) { for (HloInstruction* hlo : computation->instructions()) { - // Inserts a copy of hlo->operand(n) if it's a constant. - auto copy_operand_if_constant = [&](int64 n) -> Status { - HloInstruction* operand = hlo->mutable_operand(n); - // Skip the operands that have already been replaced with a copy in a - // previous iteration (which is possible when a constant is used as an - // operand in multiple places). - if (ContainsKey(inserted_copies, operand)) { - return Status::OK(); - } - for (auto& pair : dataflow->GetInstructionValueSet(operand)) { - const HloValueSet& value_set = pair.second; - for (const HloValue* value : value_set.values()) { - if (value->defining_instruction()->IsConstant() && - !ContainsKey(hlo_to_copy_map_, value->defining_instruction())) { - HloInstruction* constant = value->defining_instruction(); - TF_ASSIGN_OR_RETURN(HloInstruction * copy, - FindOrInsertCopy(constant)); - TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy)); - inserted_copies.insert(copy); - changed = true; - } - } - } - return Status::OK(); - }; - - if (IsCustomCallToDnnBatchNorm(*hlo)) { - // The epsilon and feature_index operands to a CUDNN batchnorm op don't - // need to be materialized in memory -- in fact, they must be constants. - // These are the last two operands of all three batchnorm ops. - for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } - } else if (ImplementedAsLibraryCall(*hlo) || - hlo->opcode() == HloOpcode::kCrossReplicaSum || - hlo->opcode() == HloOpcode::kWhile || - hlo->opcode() == HloOpcode::kConditional) { - // For all other library calls, cross-replica-sum, while and conditional - // ops materialize all the operands into memory. (Cross-replica-sum - // gets its constant args materialized even if it's not implemented as a - // libcall to simplify the implementation. It's slower, but we can - // constant fold away constant args *anyway*, so we just need to make it - // work.) - for (int64 i = 0; i < hlo->operand_count(); ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } + if (!IsCustomCallToDnnBatchNorm(*hlo)) { + continue; } - } - } - - if (changed) { - // Check the assumption that the epsilon and feature_index constants of the - // CUDNN batchnorm op are not shared with other ops where we would replace - // them with a copy. These custom op calls are generated with the - // CudnnBatchNormRewriter, so this would only happen if HloCSE merges them. - for (HloComputation* computation : module->computations()) { - for (HloInstruction* hlo : computation->instructions()) { - if (!IsCustomCallToDnnBatchNorm(*hlo)) { - continue; - } - for (int64 i = hlo->operand_count() - 2; i < hlo->operand_count(); - ++i) { - CHECK_EQ(hlo->operand(i)->opcode(), HloOpcode::kConstant); - } + for (int64 i = hlo->operand_count() - 2; i < hlo->operand_count(); ++i) { + CHECK_EQ(hlo->operand(i)->opcode(), HloOpcode::kConstant); } } } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index f20a828bc1a31ad15298a1d77cd79599aa12faf4..bb7736efa65c49766ea88a43fdab2b7102100c11 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -84,7 +85,7 @@ Status GpuExecutable::ExecuteThunks( } // Stream 0 indicates `main_stream` and substreams start from stream 1. - std::vector::SmartPtr> sub_streams; + std::vector sub_streams; sub_streams.reserve(thunk_schedule_->StreamCount() - 1); while (sub_streams.size() + 1 < thunk_schedule_->StreamCount()) { sub_streams.emplace_back(); @@ -136,18 +137,17 @@ Status GpuExecutable::ExecuteThunks( TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone()); } - profiler.StartOperation(); VLOG(2) << "Executing the thunk for " << thunk->hlo_instruction()->ToString() << " on stream " << stream_no; - TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); + TF_RETURN_IF_ERROR( + thunk->ExecuteOnStream(buffer_allocations, stream, &profiler)); if (thunk_schedule_->Depended(thunk)) { auto finish_event = MakeUnique(main_stream->parent()); finish_event->Init(); stream->ThenRecordEvent(finish_event.get()); thunk_to_finish_event[thunk] = std::move(finish_event); } - profiler.FinishOperation(thunk->hlo_instruction()); } main_stream->ThenWaitFor(&sub_streams); @@ -182,6 +182,55 @@ Status GpuExecutable::ExecuteThunks( return Status::OK(); } +StatusOr +GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) { + tensorflow::mutex_lock lock(module_handle_mutex_); + auto it = module_globals_.find(executor); + if (it != module_globals_.end()) { + return &it->second; + } + + se::MultiModuleLoaderSpec module_spec; + if (!cubin().empty()) { + module_spec.AddCudaCubinInMemory(cubin()); + } + module_spec.AddCudaPtxInMemory(ptx().c_str()); + + tensorflow::gtl::FlatMap globals; + se::ModuleHandle module_handle; + executor->LoadModule(module_spec, &module_handle); + + for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); + ++i) { + const BufferAllocation& allocation = assignment_->GetAllocation(i); + if (allocation.is_constant()) { + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase global, + executor->GetUntypedSymbol( + llvm_ir::ConstantBufferAllocationToGlobalName(allocation), + module_handle)); + VLOG(3) << "Resolved global " + << llvm_ir::ConstantBufferAllocationToGlobalName(allocation) + << " to " << global.opaque(); + InsertOrDie(&globals, i, global); + + const Literal& literal = + llvm_ir::LiteralForConstantAllocation(allocation); + CHECK(ShapeUtil::IsArray(literal.shape())); + if (!ShouldEmitLiteralInLlvmIr(literal)) { + VLOG(3) << "H2D memcpy for constant with shape " + << ShapeUtil::HumanString(literal.shape()); + TF_RETURN_IF_ERROR(executor->SynchronousMemcpyH2D( + literal.untyped_data(), allocation.size(), &global)); + } + } + } + + module_handles_.emplace(executor, + se::ScopedModuleHandle(executor, module_handle)); + return &module_globals_.emplace(executor, std::move(globals)).first->second; +} + StatusOr GpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, @@ -193,6 +242,10 @@ StatusOr GpuExecutable::ExecuteOnStream( } BufferAllocations::Builder buffer_allocations_builder; + se::StreamExecutor* executor = run_options->stream()->parent(); + + TF_ASSIGN_OR_RETURN(auto* const globals, ResolveConstantGlobals(executor)); + for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); ++i) { const BufferAllocation& allocation = assignment_->GetAllocation(i); @@ -214,8 +267,12 @@ StatusOr GpuExecutable::ExecuteOnStream( buffer_allocations_builder.RegisterBuffer(i, buffer); } + + if (allocation.is_constant()) { + buffer_allocations_builder.RegisterBuffer(i, FindOrDie(*globals, i)); + } } - se::StreamExecutor* executor = run_options->stream()->parent(); + TF_ASSIGN_OR_RETURN( auto buffer_allocations, buffer_allocations_builder.Build( @@ -236,7 +293,7 @@ StatusOr GpuExecutable::ExecuteOnStream( // the respective location in ShapedBuffer. std::set buffers_in_result; TF_RETURN_IF_ERROR(shaped_buffer.buffers().ForEachMutableElementWithStatus( - [&buffer_allocations, &buffers_in_result, &shaped_buffer, this]( + [&buffer_allocations, &buffers_in_result, this]( const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { const auto& sources = this->GetRootPointsToSet().element(index); // The points-to set is unambiguous so the set should be a diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 80ec38c3ac114fe4ad9d56784330c1144d913db1..c7ce6d0acbbbe594040271c0d45c71c016e36514 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -34,6 +34,8 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -66,7 +68,7 @@ class GpuExecutable : public Executable { } // Returns the compiled PTX for the computation. - tensorflow::StringPiece ptx() const { return ptx_; } + const string& ptx() const { return ptx_; } // Returns the cubin (compiled PTX) stored in this GpuExecutable. May be // empty, in which case compilation is left up to the GPU driver. @@ -98,6 +100,15 @@ class GpuExecutable : public Executable { // computation. Uses points-to analysis from buffer assignment. const PointsToSet& GetRootPointsToSet() const; + using BufferAllocToDeviceMemoryMap = + tensorflow::gtl::FlatMap; + + // Loads the PTX or CUBIN for this executable into `executor` and resolves the + // globals corresponding to constant buffers. Returns a map mapping buffer + // allocation indices to GPU pointers. + StatusOr ResolveConstantGlobals( + stream_executor::StreamExecutor* executor); + // The LLVM IR, in string format, of the unoptimized module generated for this // GpuExecutable. We save a string instead of an llvm::Module* because leaving // llvm::Module* in a singleton can cause the heap checker to emit false @@ -126,6 +137,14 @@ class GpuExecutable : public Executable { // memory for every output/temp buffers. const std::unique_ptr assignment_; + // Cache of module handles and constant buffer allocation maps used by + // `ResolveConstantGlobals`. + tensorflow::mutex module_handle_mutex_; + std::map + module_handles_ GUARDED_BY(module_handle_mutex_); + std::map + module_globals_ GUARDED_BY(module_handle_mutex_); + TF_DISALLOW_COPY_AND_ASSIGN(GpuExecutable); }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 8bf62dde8b9948375fc493fd1a524cfa7b062502..d033faee8d25ed81a1483f8314652ef999ab36c5 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -31,52 +31,58 @@ limitations under the License. namespace xla { namespace gpu { -using stream_executor::dnn::DataLayout; -using stream_executor::dnn::FilterLayout; - -static bool IsVoltaOrLater(const se::StreamExecutor& stream_executor) { - int major, minor; - CHECK(stream_executor.GetDeviceDescription().cuda_compute_capability(&major, - &minor)); - return major >= 7; -} +using se::dnn::DataLayout; +using se::dnn::FilterLayout; // Returns (input, filter, output) layouts. static std::tuple HeuristicLayoutAssignment(const HloInstruction* instr, - stream_executor::StreamExecutor* stream_executor) { + se::StreamExecutor* stream_executor) { // DataLayout and FilterLayout uses weird enum names. Translations: // N <=> Batch or Output // C <=> Depth or Input // H <=> Y // W <=> X // - // Therefore kOutputInputYX means NHWC; kBatchDepthYX means NCHW. + // Therefore kOutputInputYX and kBatchDepthYX mean NCHW. + // + // If you have trouble keeping these straight, consider that all that matters + // is the location of the channel dim: Is it major (NCHW), or minor (NHWC)? + + constexpr auto kAllNCHW = + std::make_tuple(DataLayout::kBatchDepthYX, FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX); + constexpr auto kAllNHWC = + std::make_tuple(DataLayout::kBatchYXDepth, FilterLayout::kOutputYXInput, + DataLayout::kBatchYXDepth); - // As of today, our empirical evidence is that cudnn 7.0 is faster on V100 x - // fp16 with the mostly-NHWC layout. The heuristic may change as cudnn version - // changes, as well as the hardware updates. + // If we're not Volta or not fp16, the decision is easy: Use NCHW. if (!(instr->operand(0)->shape().element_type() == xla::PrimitiveType::F16 && IsVoltaOrLater(*stream_executor))) { - return std::make_tuple(DataLayout::kBatchDepthYX, - FilterLayout::kOutputInputYX, - DataLayout::kBatchDepthYX); + return kAllNCHW; } + VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString(); - // For BackwardInput that has stride, full NHWC layouts run significantly - // slower than (NHWC, NCHW, NCHW) or (NHWC, NCHW, NHWC). + + // Empirically we've found with Volta and cudnn 7 that backward-input convs + // with stride are significantly faster with NCHW layouts. // - // TODO(timshen): more closely compare (NHWC, NCHW, NCHW) and (NHWC, NCHW, - // NHWC). + // We could have used a mixed layout combination, e.g. (NHWC, NCHW, NCHW), + // which on paper gives good performance. However, there are two observations: + // * a mixed layout combination is more cuDNN-bug prone, based on empirical + // envidence. + // * we've also observed that for mixed layouts, cuDNN transposes data back + // and forth from a different layout combination. If we end up with + // transposes anyway, we prefer to have them in XLA, as they can be fused. + // TODO(timshen): Figure out the exact condition. This may be achieved by + // auto-tuning layouts offline. if (instr->custom_call_target() == kCudnnConvBackwardInputCallTarget && window_util::HasStride(instr->window())) { - return std::make_tuple(DataLayout::kBatchYXDepth, - FilterLayout::kOutputInputYX, - DataLayout::kBatchDepthYX); + return kAllNCHW; } - return std::make_tuple(DataLayout::kBatchYXDepth, - FilterLayout::kOutputYXInput, - DataLayout::kBatchYXDepth); + + // For other Volta f16 convolutions, use NHWC. + return kAllNHWC; } // Adds layout constraints on the cudnn custom-call instruction. The layout @@ -170,6 +176,38 @@ Status GpuLayoutAssignment::AddBackendConstraints( TF_RETURN_IF_ERROR( AddBackendConstraintsToDnnConvCustomCall(instruction, constraints)); } + + // For batched dot we require the default layout. + // TODO(b/112111608): This is overly conservative, the only real restriction + // is that batch dimensions must be major. + if (instruction->opcode() == HloOpcode::kDot && + ImplementedAsGemm(*instruction) && + instruction->dot_dimension_numbers().lhs_batch_dimensions_size() > 0) { + // Verify that the batch dims come before the row and col dims. + const DotDimensionNumbers& dim_nums = + instruction->dot_dimension_numbers(); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), + dim_nums.rhs_batch_dimensions_size()); + CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, + ShapeUtil::Rank(instruction->shape())); + for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) { + CHECK_LT(batch_dim, ShapeUtil::Rank(instruction->shape()) - 2); + } + + // Set both inputs and the output to default layout. + Shape op0_shape = instruction->operand(0)->shape(); + LayoutUtil::SetToDefaultLayout(&op0_shape); + Shape op1_shape = instruction->operand(1)->shape(); + LayoutUtil::SetToDefaultLayout(&op1_shape); + Shape output_shape = instruction->shape(); + LayoutUtil::SetToDefaultLayout(&output_shape); + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(op0_shape, instruction, 0)); + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(op1_shape, instruction, 1)); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(output_shape, instruction)); + } } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index e48165c1426ea04839c245bc20b851a0f1710246..286547ebae2f1a4b8d783a06d13b4dd96052b952 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -20,8 +20,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -31,6 +33,8 @@ namespace xla { namespace gpu { namespace { +namespace op = xla::testing::opcode_matchers; + using LayoutAssignmentTest = HloTestBase; TEST_F(LayoutAssignmentTest, Elementwise) { @@ -132,10 +136,10 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { HloInstruction::CreateParameter(4, aux_shape, "variance")); auto* epsilon = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto* feature_index = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(kFeatureIndex))); + LiteralUtil::CreateR0(kFeatureIndex))); auto* batchnorm = builder.AddInstruction(HloInstruction::CreateCustomCall( shape, @@ -201,10 +205,10 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { HloInstruction::CreateParameter(2, offset_scale_shape, "offset")); auto* epsilon = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto* feature_index = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(kFeatureIndex))); + LiteralUtil::CreateR0(kFeatureIndex))); auto* batchnorm = builder.AddInstruction(HloInstruction::CreateCustomCall( batchnorm_shape, {operand, scale, offset, epsilon, feature_index}, @@ -278,10 +282,10 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { HloInstruction::CreateParameter(4, shape, "var")); auto* epsilon = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto* feature_index = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(kFeatureIndex))); + LiteralUtil::CreateR0(kFeatureIndex))); auto* batchnorm = builder.AddInstruction(HloInstruction::CreateCustomCall( @@ -327,6 +331,33 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { } } +TEST_F(LayoutAssignmentTest, DotLayout) { + const char* hlo_text = R"( + HloModule DotLayout + ENTRY dot { + p0 = f32[8,8,256,64]{3,1,2,0} parameter(0) + p1 = f32[8,8,256,64]{3,1,2,0} parameter(1) + ROOT dot.1330.10585 = f32[8,8,256,256]{3,2,1,0} dot(p0, p1), + lhs_batch_dims={0,1}, lhs_contracting_dims={3}, + rhs_batch_dims={0,1}, rhs_contracting_dims={3} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape()); + GpuLayoutAssignment layout_assignment(&computation_layout, + backend().default_stream_executor()); + EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); + + Shape expected_shape = + ShapeUtil::MakeShapeWithLayout(F32, {8, 8, 256, 64}, {3, 2, 1, 0}); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Dot(op::ShapeWithLayout(expected_shape), + op::ShapeWithLayout(expected_shape))); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index 7bb8df6581b49b1bf8c84a972f715e8dc119d8de..a2f53f844613da9fe8166489dc9959e8d30c6332 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include "llvm/IR/DataLayout.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" +#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h" +#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,15 +36,14 @@ limitations under the License. #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { +namespace gpu { // TODO(b/30467474) Once GPU infeed implementation settles, consider // folding back the cpu and gpu infeed implementations into a generic // one if possible. -GpuTransferManager::GpuTransferManager() - : GenericTransferManager( - se::cuda::kCudaPlatformId, - /*pointer_size=*/llvm::DataLayout(gpu::GpuCompiler::kDataLayout) - .getPointerSize(0 /* default address space */)) {} +GpuTransferManager::GpuTransferManager(se::Platform::Id id, + unsigned pointer_size) + : GenericTransferManager(id, pointer_size) {} Status GpuTransferManager::TransferLiteralToInfeed( se::StreamExecutor* executor, const LiteralSlice& literal) { @@ -50,53 +51,28 @@ Status GpuTransferManager::TransferLiteralToInfeed( VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); - if (!ShapeUtil::IsTuple(shape)) { - int64 size = GetByteSizeRequirement(shape); - return TransferBufferToInfeed(executor, size, literal.untyped_data()); - } - - if (ShapeUtil::IsNestedTuple(shape)) { - return Unimplemented( - "Infeed with a nested tuple shape is not supported: %s", - ShapeUtil::HumanString(literal.shape()).c_str()); - } - // For a tuple, we transfer each of its elements to the device and // enqueue the resulting destination device addresses with the // infeed manager. - std::vector buffers; - buffers.reserve(ShapeUtil::TupleElementCount(shape)); - auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() { - for (gpu::InfeedBuffer* b : buffers) { - b->Done(); - } - }); - - for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - const Shape& tuple_element_shape = - ShapeUtil::GetTupleElementShape(shape, i); - int64 tuple_element_size = GetByteSizeRequirement(tuple_element_shape); - TF_ASSIGN_OR_RETURN( - gpu::InfeedBuffer * buffer, - TransferBufferToInfeedInternal(executor, tuple_element_size, - literal.untyped_data({i}))); - buffers.push_back(buffer); - } - - cleanup.release(); - return EnqueueBuffersToInfeed(executor, buffers); -} - -Status GpuTransferManager::TransferBufferToInfeed(se::StreamExecutor* executor, - int64 size, - const void* source) { - TF_ASSIGN_OR_RETURN(gpu::InfeedBuffer * buffer, - TransferBufferToInfeedInternal(executor, size, source)); - return EnqueueBuffersToInfeed(executor, {buffer}); + ShapeTree buffer_tree(shape); + + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + shape, [&](const Shape& literal_subshape, const ShapeIndex& index) { + if (ShapeUtil::IsArray(literal_subshape)) { + int64 tuple_element_size = GetByteSizeRequirement(literal_subshape); + TF_ASSIGN_OR_RETURN( + *buffer_tree.mutable_element(index), + TransferBufferToInfeedInternal(executor, tuple_element_size, + literal.untyped_data(index))); + } + return Status::OK(); + })); + + return EnqueueBuffersToInfeed(executor, std::move(buffer_tree)); } Status GpuTransferManager::EnqueueBuffersToInfeed( - se::StreamExecutor* executor, std::vector buffers) { + se::StreamExecutor* executor, ShapeTree buffers) { gpu::InfeedManager* infeed_manager = gpu::GetOrCreateInfeedManager(); se::Stream* stream = infeed_manager->GetStream(executor); @@ -106,21 +82,18 @@ Status GpuTransferManager::EnqueueBuffersToInfeed( // possible. Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { - for (gpu::InfeedBuffer* b : buffers) { - b->Done(); - } return InternalError("Failed to complete data transfer on stream %p: %s", stream, block_status.error_message().c_str()); } - infeed_manager->EnqueueBuffers(buffers); + infeed_manager->EnqueueDestination(std::move(buffers)); VLOG(2) << "Infeed data transferred"; return Status::OK(); } -StatusOr GpuTransferManager::TransferBufferToInfeedInternal( +StatusOr GpuTransferManager::TransferBufferToInfeedInternal( se::StreamExecutor* executor, int64 size, const void* source) { if (size > std::numeric_limits::max()) { return InvalidArgument("Infeed shape is too large: needs %lld bytes", size); @@ -136,23 +109,85 @@ StatusOr GpuTransferManager::TransferBufferToInfeedInternal( return InternalError("Failed to obtain a stream"); } - gpu::InfeedBuffer* buffer = new gpu::InfeedBuffer(executor, size); - stream->ThenMemcpy(buffer->device_memory(), source, size); + InfeedBuffer buffer(executor, size); + stream->ThenMemcpy(buffer.device_memory(), source, size); VLOG(2) << "Queued infeed data on stream " << stream; - return buffer; + return std::move(buffer); +} + +static void ShapeTreeToLiteral( + ShapeTree>* shape_tree) { + // This is a struct instead of a lambda for std::function-free recursion. + struct Helper { + static void helper( + ShapeTree>* shape_tree, + ShapeIndex* index) { + const Shape& shape = ShapeUtil::GetSubshape(shape_tree->shape(), *index); + if (ShapeUtil::IsArray(shape)) { + (*shape_tree->mutable_element(*index))->WaitUntilAvailable(); + return; + } + + CHECK(ShapeUtil::IsTuple(shape)) + << ShapeUtil::HumanStringWithLayout(shape); + const int64 tuple_element_count = ShapeUtil::TupleElementCount(shape); + index->push_back(0); + for (int64 i = 0; i < tuple_element_count; ++i) { + index->back() = i; + helper(shape_tree, index); + } + index->pop_back(); + } + }; + ShapeIndex index; + Helper::helper(shape_tree, &index); +} + +Status GpuTransferManager::TransferLiteralFromOutfeed( + se::StreamExecutor* /*executor*/, const Shape& literal_shape, + MutableBorrowingLiteral literal) { + ShapeTree> outfeed_buffers( + &literal_shape); + + // First create a tree of literal buffers that the device can write to. + outfeed_buffers.ForEachMutableElement( + [&](const ShapeIndex& index, + std::unique_ptr* buffer) { + const Shape& shape = ShapeUtil::GetSubshape(literal_shape, index); + // Do not transfer tuple index buffers. + if (ShapeUtil::IsTuple(shape)) { + return; + } + *buffer = MakeUnique(GetByteSizeRequirement(shape)); + (*buffer)->set_destination( + MakeUnique(literal, index)); + }); + + // Give the tree of buffers to the outfeed mananger. The device will fill it + // while we're waiting for it below. + gpu::OutfeedManager* outfeed_manager = gpu::GetOrCreateOutfeedManager(); + outfeed_manager->EnqueueDestination(&outfeed_buffers); + + // Now wait for the tree of buffers are written. + ShapeTreeToLiteral(&outfeed_buffers); + return Status::OK(); } +} // namespace gpu } // namespace xla -static std::unique_ptr CreateGpuTransferManager() { - return xla::MakeUnique(); +static std::unique_ptr CreateNVPTXTransferManager() { + return xla::MakeUnique( + /*id=*/stream_executor::cuda::kCudaPlatformId, + /*pointer_size=*/llvm::DataLayout(xla::gpu::NVPTXCompiler::kDataLayout) + .getPointerSize(0 /* default address space */)); } static bool InitModule() { xla::TransferManager::RegisterTransferManager( - stream_executor::cuda::kCudaPlatformId, &CreateGpuTransferManager); + stream_executor::cuda::kCudaPlatformId, &CreateNVPTXTransferManager); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h index 09f8227f508a3159f3def285898e15bfad544552..7929042869763dfeab2fe8f87093b7ea758337d0 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" @@ -28,33 +29,36 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace xla { +namespace gpu { // An implementation of the XLA GenericTransferManager that // handles GPU-specific infeed. class GpuTransferManager : public GenericTransferManager { public: - GpuTransferManager(); + GpuTransferManager(se::Platform::Id id, unsigned pointer_size); ~GpuTransferManager() override {} Status TransferLiteralToInfeed(se::StreamExecutor* executor, const LiteralSlice& literal) override; - Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, - const void* source) override; + Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, + const Shape& literal_shape, + MutableBorrowingLiteral literal) override; private: // Initiates the infeed data transfers. InfeedBuffer->Done() must be // called to clean up the memory allocated for InfeedBuffer. - StatusOr TransferBufferToInfeedInternal( + StatusOr TransferBufferToInfeedInternal( se::StreamExecutor* executor, int64 size, const void* source); // Enqueues infeed data buffers with the infeed manager after their // transfer completes. Status EnqueueBuffersToInfeed(se::StreamExecutor* executor, - std::vector buffers); + ShapeTree buffers); TF_DISALLOW_COPY_AND_ASSIGN(GpuTransferManager); }; +} // namespace gpu } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc index daddd3738e4bb54f3695a96f6f9ffb9accabe97c..17226769302eef0dd01550b0bc5404e889ad78f8 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc @@ -16,21 +16,43 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include +#include +#include #include #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/util/ptr_util.h" namespace xla { namespace gpu { +namespace { +void InitAndStartTimer(std::stack>* timers, + se::Stream* stream) { + timers->push(MakeUnique(stream->parent())); + stream->InitTimer(timers->top().get()).ThenStartTimer(timers->top().get()); +} + +uint64 GetCyclesTaken(std::stack>* timers, + const std::vector& sub_streams, + se::Stream* stream, double clock_rate_ghz) { + CHECK_GT(timers->size(), 0); + stream->ThenWaitFor(&sub_streams); + stream->ThenStopTimer(timers->top().get()); + stream->BlockHostUntilDone().IgnoreError(); + double nanoseconds = timers->top()->Nanoseconds(); + timers->pop(); + return static_cast(nanoseconds * clock_rate_ghz); +} +} // namespace HloExecutionProfiler::HloExecutionProfiler( bool do_profile, HloExecutionProfile* profile, se::Stream* stream, - const std::vector::SmartPtr>& sub_streams, + const std::vector& sub_streams, const HloComputation* computation) : do_profile_(do_profile), profile_(profile), @@ -39,11 +61,7 @@ HloExecutionProfiler::HloExecutionProfiler( computation_(computation) { if (do_profile_) { clock_rate_ghz_ = stream->parent()->GetDeviceDescription().clock_rate_ghz(); - execution_timer_.reset(new se::Timer(stream->parent())); - per_op_timer_.reset(new se::Timer(stream->parent())); - stream->InitTimer(execution_timer_.get()) - .ThenStartTimer(execution_timer_.get()); - stream->InitTimer(per_op_timer_.get()); + InitAndStartTimer(&timers_, stream); } } @@ -51,31 +69,53 @@ void HloExecutionProfiler::FinishExecution() { CHECK(!finished_execution_) << "Call FinishExecution only once!"; finished_execution_ = true; if (do_profile_) { - stream_->ThenWaitFor(&sub_streams_); - stream_->ThenStopTimer(execution_timer_.get()); - stream_->BlockHostUntilDone().IgnoreError(); profile_->set_total_cycles_executed( *computation_, - static_cast(execution_timer_->Nanoseconds() * clock_rate_ghz_)); + GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_)); + } +} + +void HloExecutionProfiler::StartHloComputation() { + if (do_profile_) { + InitAndStartTimer(&timers_, stream_); + } +} + +void HloExecutionProfiler::FinishHloComputation( + const HloComputation* computation) { + if (do_profile_) { + profile_->set_total_cycles_executed( + *computation, + GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_)); } } -void HloExecutionProfiler::StartOperation() { +void HloExecutionProfiler::StartHloInstruction() { if (do_profile_) { - stream_->ThenStartTimer(per_op_timer_.get()); + InitAndStartTimer(&timers_, stream_); } } -void HloExecutionProfiler::FinishOperation( +void HloExecutionProfiler::FinishHloInstruction( const HloInstruction* hlo_instruction) { if (do_profile_) { - stream_->ThenWaitFor(&sub_streams_); - stream_->ThenStopTimer(per_op_timer_.get()); - stream_->BlockHostUntilDone().IgnoreError(); + hlo_instructions_.erase(hlo_instruction); profile_->SetCyclesTakenBy( hlo_instruction, - static_cast(per_op_timer_->Nanoseconds() * clock_rate_ghz_)); + GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_)); + } +} + +std::unique_ptr +HloExecutionProfiler::MakeScopedInstructionProfiler( + const HloInstruction* hlo_instruction) { + if (do_profile_ && hlo_instruction != nullptr) { + // Make sure that we are not already measuring the time for the same + // 'hlo_instruction'. + CHECK(hlo_instructions_.insert(hlo_instruction).second) + << hlo_instruction->name(); } + return MakeUnique(this, hlo_instruction); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h index c9b882ff805c45a57f15df4fe79dc34100c0ceff..80cde75f2bbb555f514fffea58ad92edf92fd0d1 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h @@ -17,51 +17,93 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_EXECUTION_PROFILER_H_ #include +#include +#include #include #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace gpu { +class ScopedInstructionProfiler; + // A helper class for profiling HLO in the course of GPU program execution. // All of the profiling is guarded internally, to avoid the caller needing to // have lots of conditionals sprinkled around. class HloExecutionProfiler { public: // If profiling is enabled, start an execution timer running. - explicit HloExecutionProfiler( - bool do_profile, HloExecutionProfile* profile, se::Stream* stream, - const std::vector::SmartPtr>& sub_streams, - const HloComputation* computation); + explicit HloExecutionProfiler(bool do_profile, HloExecutionProfile* profile, + se::Stream* stream, + const std::vector& sub_streams, + const HloComputation* computation); // If profiling is enabled, sets the total cycle count on the profile from the // execution timer. void FinishExecution(); - // If profiling is enabled, starts the per-operation timer. - void StartOperation(); + // If profiling is enabled, starts a timer for a (sub)computation. + void StartHloComputation(); + + // If profiling is enabled stops the timer for a (sub)computation and records + // the time that the computation took to execute in the profile. + void FinishHloComputation(const HloComputation* computation); + + // If profiling is enabled, starts a per-operation timer. + void StartHloInstruction(); // If profiling is enabled, stops the per-operation timer and records the time // that the hlo_instruction took to execute in the profile. - void FinishOperation(const HloInstruction* hlo_instruction); + void FinishHloInstruction(const HloInstruction* hlo_instruction); + + // Returns a ScopedInstructionProfiler and triggers a call to + // StartHloInstruction(). Once the returned ScopedInstructionProfiler goes + // out of scope, it triggers a call to FinishHloInstruction(). + std::unique_ptr MakeScopedInstructionProfiler( + const HloInstruction* hlo_instruction); private: const bool do_profile_; double clock_rate_ghz_; HloExecutionProfile* profile_; se::Stream* stream_; - const std::vector::SmartPtr>& sub_streams_; + const std::vector& sub_streams_; const HloComputation* computation_; - std::unique_ptr execution_timer_; - std::unique_ptr per_op_timer_; + std::stack> timers_; + // Contains the HLO instructions for which we are currently measuring the + // time. + std::unordered_set hlo_instructions_; bool finished_execution_ = false; }; +// This class can be used within the ExecuteOnStream() implementations of +// Thunks. It ensures that we always have a pair of matching +// StartHloInstruction() and FinishHloInstruction() calls to the profiler. +class ScopedInstructionProfiler { + public: + ScopedInstructionProfiler(HloExecutionProfiler* profiler, + const HloInstruction* hlo_instruction) + : profiler_(profiler), hlo_instruction_(hlo_instruction) { + if (hlo_instruction != nullptr) { + profiler->StartHloInstruction(); + } + } + ~ScopedInstructionProfiler() { + if (hlo_instruction_ != nullptr) { + profiler_->FinishHloInstruction(hlo_instruction_); + } + } + + private: + HloExecutionProfiler* profiler_; + const HloInstruction* hlo_instruction_; +}; + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc index 375709150e08996ea6a40f5e9e66a8f8d9287008..19de37b0fbed15455e8c6a9bfe427ba3d9f0a9dc 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc @@ -100,7 +100,7 @@ GpuHloOrdering::GpuHloOrdering( if (last_instruction_per_stream[stream_no] != nullptr) { immediate_preds.push_back(last_instruction_per_stream[stream_no]); } - predecessor_map->SetReachabilityToUnion(immediate_preds, hlo); + predecessor_map->FastSetReachabilityToUnion(immediate_preds, hlo); last_instruction_per_stream[stream_no] = hlo; } else { // Only parameters and constants don't have an assigned stream, since they diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index d420863b8569771b16a03591b6a0ddd0591f7e2e..8c11cd05419289d82b033c936bb60884f45cb636 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -18,8 +18,10 @@ limitations under the License. #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -39,7 +41,7 @@ void HloToIrBindings::EmitBasePointersForHlos( // I/O HLOs are bound to the arguments of the current IR function. I.e., // // void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) { - llvm::Function* function = ir_builder_->GetInsertBlock()->getParent(); + llvm::Function* function = b_->GetInsertBlock()->getParent(); CHECK_EQ(io_hlos.size() + 1, function->arg_size()); // An HLO can have duplicated operands. This data structure remembers which @@ -79,8 +81,8 @@ void HloToIrBindings::EmitBasePointersForHlos( const int64 offset = slice.offset(); CHECK_NE(nullptr, temp_buffer_base_); // Emit IR for GetTupleElement instruction and bind to emitted value. - llvm::Value* base_ptr = ir_builder_->CreateInBoundsGEP( - temp_buffer_base_, ir_builder_->getInt64(offset)); + llvm::Value* base_ptr = + b_->CreateInBoundsGEP(temp_buffer_base_, b_->getInt64(offset)); BindHloToIrValue(*non_io_hlo, EmitGetTupleElement(non_io_hlo, base_ptr)); } @@ -108,15 +110,20 @@ void HloToIrBindings::EmitBasePointersForHlos( if (slice.allocation()->is_thread_local()) { llvm::Type* pointee_type = llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_); - BindHloToIrValue(*non_io_hlo, - ir_builder_->CreateAlloca(pointee_type), index); + BindHloToIrValue(*non_io_hlo, b_->CreateAlloca(pointee_type), + index); + } else if (slice.allocation()->is_constant()) { + llvm::Value* global_for_constant = + module_->getGlobalVariable(llvm_ir::AsStringRef( + llvm_ir::ConstantBufferAllocationToGlobalName( + *slice.allocation()))); + BindHloToIrValue(*non_io_hlo, global_for_constant); } else { const int64 offset = slice.offset(); CHECK_NE(nullptr, temp_buffer_base_); BindHloToIrValue( *non_io_hlo, - ir_builder_->CreateInBoundsGEP(temp_buffer_base_, - ir_builder_->getInt64(offset)), + b_->CreateInBoundsGEP(temp_buffer_base_, b_->getInt64(offset)), index); } }); @@ -129,11 +136,19 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte, if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) { return llvm_ir::EmitGetTupleElement( gte->shape(), gte->tuple_index(), /*alignment=*/1, - GetTypedIrValue(*gte->operand(0), {}, base_ptr), ir_builder_, module_); + GetTypedIrValue(*gte->operand(0), {}, base_ptr), b_, module_); } return llvm_ir::EmitGetTupleElement( gte->shape(), gte->tuple_index(), /*alignment=*/1, - EmitGetTupleElement(gte->operand(0), base_ptr), ir_builder_, module_); + EmitGetTupleElement(gte->operand(0), base_ptr), b_, module_); +} + +// Returns true if `value` has a name that should not be changed. +static bool HasMeaningfulName(llvm::Value* value) { + if (auto* global = llvm::dyn_cast(value)) { + return global->getLinkage() != llvm::GlobalValue::PrivateLinkage; + } + return false; } llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, @@ -145,14 +160,18 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, llvm::Value* typed_ir_value; if (llvm::isa(ir_value)) { - typed_ir_value = llvm::ConstantExpr::getBitCast( + typed_ir_value = llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast( llvm::cast(ir_value), dest_type); } else { - typed_ir_value = - ir_builder_->CreateBitCast(ir_value, pointee_type->getPointerTo()); + typed_ir_value = b_->CreateBitCast(ir_value, pointee_type->getPointerTo()); + } + if (!HasMeaningfulName(ir_value)) { + ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "raw"))); + } + if (!HasMeaningfulName(typed_ir_value)) { + typed_ir_value->setName( + llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "typed"))); } - ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "raw"))); - typed_ir_value->setName(llvm_ir::AsStringRef(llvm_ir::IrName(&hlo, "typed"))); return typed_ir_value; } diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index a86e6e78c693ac53bb2c70d88b999a4e1273ecad..eee40b0e91fc03013a6978ae3cfe42b87633eed7 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -36,14 +36,13 @@ class HloToIrBindings { public: HloToIrBindings(const HloModule& module, const BufferAssignment* buffer_assignment, - llvm::IRBuilder<>* ir_builder, llvm::Module* llvm_module, + llvm::IRBuilder<>* b, llvm::Module* llvm_module, bool is_nested) : buffer_assignment_(buffer_assignment), is_nested_(is_nested), - ir_builder_(ir_builder), + b_(b), module_(llvm_module), - alias_analysis_(module, *buffer_assignment_, - &ir_builder_->getContext()) {} + alias_analysis_(module, *buffer_assignment_, &b_->getContext()) {} void EmitBasePointersForHlos( tensorflow::gtl::ArraySlice io_hlos, @@ -104,7 +103,7 @@ class HloToIrBindings { const bool is_nested_; - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; llvm::Module* module_; // Stores the underlying llvm::IrArray for each HloInstruction. diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc index ae310beefad0c81c17fd4140b441b3a19a002e2c..c5f0cdf6cd5d3e076bffa875fbba991bf0681ee8 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc @@ -15,76 +15,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" -#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/core/platform/logging.h" namespace xla { namespace gpu { -InfeedManager::InfeedManager() : host_to_device_executor_(nullptr) {} - -void InfeedManager::Reset() { - tensorflow::mutex_lock l(mu_); - CHECK(dequeued_buffer_.empty()); - for (auto buffer : enqueued_buffer_) { - buffer->Done(); - } - enqueued_buffer_.clear(); -} - -void InfeedManager::EnqueueBuffers(const std::vector& buffers) { - tensorflow::mutex_lock l(mu_); - bool was_empty = enqueued_buffer_.empty(); - for (gpu::InfeedBuffer* b : buffers) { - enqueued_buffer_.push_back(b); - } - if (was_empty) { - // This has the potential to suffer from the notified thread - // immediately trying and failing to acquire mu_, but seems - // preferable to the alternative of notifying outside the lock - // on every enqueue. - cv_.notify_one(); - } -} - -InfeedBuffer* InfeedManager::BlockingDequeueBuffer() { - bool became_empty = false; - InfeedBuffer* current_buffer; - { - tensorflow::mutex_lock l(mu_); - while (enqueued_buffer_.empty()) { - cv_.wait(l); - } - current_buffer = enqueued_buffer_.front(); - enqueued_buffer_.pop_front(); - dequeued_buffer_.insert(current_buffer); - if (enqueued_buffer_.empty()) { - became_empty = true; - } - } - if (became_empty) { - for (const auto& callback : on_empty_callbacks_) { - callback(); - } - } - return current_buffer; -} - -void InfeedManager::ReleaseBuffers(const std::vector& buffers) { - { - tensorflow::mutex_lock l(mu_); - for (gpu::InfeedBuffer* b : buffers) { - CHECK(ContainsKey(dequeued_buffer_, b)); - dequeued_buffer_.erase(b); - } - } - for (gpu::InfeedBuffer* b : buffers) { - b->Done(); - } -} - se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) { + tensorflow::mutex_lock l(host_to_device_stream_mu_); if (host_to_device_executor_ == nullptr) { host_to_device_executor_ = executor; host_to_device_stream_ = MakeUnique(executor); @@ -100,10 +37,6 @@ se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) { return host_to_device_stream_.get(); } -void InfeedManager::RegisterOnEmptyCallback(std::function callback) { - on_empty_callbacks_.push_back(std::move(callback)); -} - InfeedManager* GetOrCreateInfeedManager() { static InfeedManager* manager = new InfeedManager; return manager; diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.h b/tensorflow/compiler/xla/service/gpu/infeed_manager.h index a3fc15cfe36a490f38daabca9ff36fbb1012aead..7e418882e051a77e10bd12000bbc9769980f5f14 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.h @@ -20,12 +20,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_MANAGER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_MANAGER_H_ -#include -#include - +#include "tensorflow/compiler/xla/service/gpu/xfeed_queue.h" +#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -47,90 +44,41 @@ namespace gpu { // the client. The client manages the memory of the buffer. class InfeedBuffer { public: + InfeedBuffer() = default; InfeedBuffer(se::StreamExecutor* executor, int64 length) - : executor_(executor), length_(length) { - device_memory_ = executor_->AllocateArray(length); - CHECK(!device_memory_.is_null()); + : device_memory_(executor, executor->AllocateArray(length)), + length_(length) { + CHECK(!device_memory_->is_null()); } - ~InfeedBuffer() { executor_->Deallocate(&device_memory_); } - int64 length() const { return length_; } - // Callback to signal that this buffer is consumed. This helps the - // client to manage memory for the infeed buffers. - void Done() { delete this; } - - se::DeviceMemoryBase* device_memory() { return &device_memory_; } + se::DeviceMemoryBase* device_memory() { return device_memory_.ptr(); } private: - se::StreamExecutor* executor_; // Not owned. - const int64 length_; - se::DeviceMemoryBase device_memory_; + se::ScopedDeviceMemory device_memory_; + int64 length_; }; // Client-side class used to enqueue infeed buffers. -class InfeedManager { +class InfeedManager : public XfeedQueue> { public: - InfeedManager(); - - // Calls the completion callback for any enqueued buffers that have - // not been dequeued by the runtime, and empties the infeed - // queue. Reset may not be called while a runtime computation is - // processing a dequeued buffer. The only safe way to ensure this - // condition is to call Reset when no computation is taking place. - void Reset(); - - // Adds a set of buffers to the infeed queue atomically. buffer->Done - // will be called when the buffer will no longer be accessed by the - // InfeedManager, either as a result of a call to Reset or because the - // runtime has dequeued and used the buffer. - void EnqueueBuffers(const std::vector& buffers); - - // Blocks until the infeed queue is non-empty, then returns the - // buffer at the head of the queue. Adds the current buffer to the - // to-be released set. - InfeedBuffer* BlockingDequeueBuffer(); - - // Releases a set of buffers from the to-be released set. - void ReleaseBuffers(const std::vector& buffers); - // Returns a cached stream associated with an executor. Allocates a // new stream on the first invocation. On subsequent invocations, if // the cached executor is not the same as the requested executor, // returns null. se::Stream* GetStream(se::StreamExecutor* executor); - // Registers a callback that will be called when 'enqueued_buffer_' becomes - // empty. - void RegisterOnEmptyCallback(std::function callback); - private: - // TODO(b/30467474): Revisit if this mutex becomes a point of - // contention. - tensorflow::mutex mu_; - - // Condition variable that is signaled every time a buffer is - // enqueued to an empty queue. - tensorflow::condition_variable cv_; - - // InfeedBuffer* queue contents are not owned, but buffer->Done must - // be called when the buffer is no longer needed by the runtime. - std::deque enqueued_buffer_; - - // Buffers that are dequeued and currently being processed by the - // runtime. Not owned. - tensorflow::gtl::FlatSet dequeued_buffer_; + // Mutex for serializing the creation of host_to_device_stream_. + tensorflow::mutex host_to_device_stream_mu_; // Cached host to device stream for queuing infeed data. - std::unique_ptr host_to_device_stream_; + std::unique_ptr host_to_device_stream_ + GUARDED_BY(host_to_device_stream_mu_); // Executor that the host_to_device_stream belongs to. Not owned. - se::StreamExecutor* host_to_device_executor_; - - // List of callbacks which will be called when 'enqueued_buffer_' becomes - // empty. - std::vector> on_empty_callbacks_; + se::StreamExecutor* host_to_device_executor_ = nullptr; }; // Singleton creator-or-accessor: Returns the GPU infeed manager. diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index 2b63d8727cb11f4369b17adb87bcba18ed2b8b65..fee6d2af3bfd4976f5845edf592e8310b55a3feb 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -27,51 +28,70 @@ InfeedThunk::InfeedThunk( : Thunk(Kind::kInfeed, hlo_instruction), infeed_slices_(infeed_slices) {} Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { - VLOG(2) << "Infeeding to GPU "; + se::Stream* stream, + HloExecutionProfiler* profiler) { + VLOG(2) << "Infeeding to GPU: " << hlo_instruction()->ToString(); + + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); + ShapeTree infeed_buffers = + GetOrCreateInfeedManager()->BlockingGetNextDestination(); + + { + // The infeed buffer has an extra outer tuple with a token. Adjust the index + // accordingly. + ShapeIndex index = {0}; + std::function*)> copy_tuple_contents = + [&](std::vector* tuple_element_addresses) { + const Shape& shape = ShapeUtil::GetSubshape(infeed_buffers.shape(), + ShapeIndexView(index, 1)); + // For the leaf buffers of the tuple copy the elements directly. + if (ShapeUtil::IsArray(shape)) { + const BufferAllocation::Slice& tuple_element_buffer = + infeed_slices_.element(index); + se::DeviceMemoryBase tuple_element_address = + buffer_allocations.GetDeviceAddress(tuple_element_buffer); + + InfeedBuffer* buffer = + infeed_buffers.mutable_element(ShapeIndexView(index, 1)); + stream->ThenMemcpy(&tuple_element_address, + *(buffer->device_memory()), buffer->length()); + tuple_element_addresses->push_back(tuple_element_address.opaque()); + return; + } + + const int64 tuple_element_count = ShapeUtil::TupleElementCount(shape); + index.push_back(0); + std::vector inner_tuple_element_addresses; + for (int64 i = 0; i < tuple_element_count; ++i) { + index.back() = i; + copy_tuple_contents(&inner_tuple_element_addresses); + } + index.pop_back(); + + // Create a buffer of pointers for non-leaf buffers. + CHECK_EQ(tuple_element_count, inner_tuple_element_addresses.size()); + auto host_size = inner_tuple_element_addresses.size() * sizeof(void*); + se::DeviceMemoryBase tuple_address = + buffer_allocations.GetDeviceAddress( + infeed_slices_.element(index)); + stream->ThenMemcpy(&tuple_address, + inner_tuple_element_addresses.data(), host_size); + tuple_element_addresses->push_back(tuple_address.opaque()); + }; - // First copy the infeed data which is element 0 of the infeed instruction's - // two-tuple output (the other element is a token). - se::DeviceMemoryBase data_address = - buffer_allocations.GetDeviceAddress(infeed_slices_.element({0})); - InfeedManager* infeed_manager = GetOrCreateInfeedManager(); - std::vector infeed_buffers; - const Shape& data_shape = - ShapeUtil::GetTupleElementShape(hlo_instruction()->shape(), 0); - if (ShapeUtil::IsTuple(data_shape)) { - CHECK(!ShapeUtil::IsNestedTuple(data_shape)); - // Transfer the tuple elements first. std::vector tuple_element_addresses; - for (int i = 0; i < ShapeUtil::TupleElementCount(data_shape); ++i) { - const BufferAllocation::Slice& tuple_element_buffer = - infeed_slices_.element({0, i}); - se::DeviceMemoryBase tuple_element_address = - buffer_allocations.GetDeviceAddress(tuple_element_buffer); - - InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer(); - infeed_buffers.push_back(buffer); - stream->ThenMemcpy(&tuple_element_address, *(buffer->device_memory()), - buffer->length()); - tuple_element_addresses.push_back(tuple_element_address.opaque()); - } - // Transfer the tuple outer buffer. - auto host_size = tuple_element_addresses.size() * sizeof(void*); - stream->ThenMemcpy(&data_address, tuple_element_addresses.data(), - host_size); - } else { - InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer(); - infeed_buffers.push_back(buffer); - stream->ThenMemcpy(&data_address, *(buffer->device_memory()), - buffer->length()); + copy_tuple_contents(&tuple_element_addresses); + CHECK_EQ(1, tuple_element_addresses.size()); } // Construct top-level tuple of infeed containing the data and the token. Use // a nullptr for the token, it should never be dereferenced. - std::vector infeed_addresses = {data_address.opaque(), nullptr}; + se::DeviceMemoryBase data_address = + buffer_allocations.GetDeviceAddress(infeed_slices_.element({0})); + void* infeed_addresses[] = {data_address.opaque(), nullptr}; se::DeviceMemoryBase top_level_address = buffer_allocations.GetDeviceAddress(infeed_slices_.element({})); - stream->ThenMemcpy(&top_level_address, infeed_addresses.data(), - 2 * sizeof(void*)); + stream->ThenMemcpy(&top_level_address, infeed_addresses, 2 * sizeof(void*)); Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { @@ -79,8 +99,6 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, stream, block_status.error_message().c_str()); } - infeed_manager->ReleaseBuffers(infeed_buffers); - VLOG(2) << "Infeeding to GPU complete"; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h index cb9a6232f3bcdcbf37bc195069bac449a7217401..59487e245b78e66c45409fe712e86d3392e50580 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -40,7 +41,8 @@ class InfeedThunk : public Thunk { InfeedThunk& operator=(const InfeedThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const ShapeTree infeed_slices_; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 64ed3d748febd8281a8e602194b31c937a4a682a..0f2c83aeb2633a007559d8caac78ea2d233539ed 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -73,6 +73,67 @@ bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { } } +// This function limits the maximum number of operands to a fusion. +// +// There's a cap on how many parameters we can pass to a CUDA kernel, but +// exactly what that limit is is hazy, as it depends on (among other things) how +// much GPU constant memory is in use for other purposes. +// +// Moreover, we don't even know at the point that we're running fusion how many +// arguments the CUDA kernel for a fusion node will have: It depends on buffer +// assignment, where we will decide which of the fusion's operands live in XLA's +// big temp buffer versus in other allocations. +// +// As a heuristic, we simply cap the number of fusion operands plus outputs at +// kMaxOperandsAndOutputsPerFusion. This puts an upper bound on the number of +// parameters to the kernel, working around the correctness problem. +// +// This limit is also often good for performance. In a fusion with many +// operands, each GPU thread likely has to do a lot of work, and so possibly +// uses a lot of registers, thus limiting occupancy. +/*static*/ bool GpuInstructionFusion::FusionWouldBeTooLarge( + const HloInstruction* a, const HloInstruction* b) { + // Compute the number of outputs of the (possibly multi-output) fusion node + // we're considering creating. + // + // This isn't precise; we may be off by one if + // - We're creating a multi-output fusion out of two non-MOFs. Creating a + // MOF adds a new buffer, namely, the tuple buffer. + // - We're merging two MOFs. In this case, we should count the tuple buffer + // only once. + // - WLOG there's an edge from `a` to `b` and `b` is the only consumer of + // `a`. In this case the result of `a` is not part of the output of the + // fusion. + // + // But because this is a heuristic and our limit + // kMaxOperandsAndOutputsPerFusion is a large value (so +/- 1 doesn't make a + // big difference), we ignore this small inaccuracy in favor of simplicity. + int64 num_output_buffers = ShapeUtil::SubshapeCount(a->shape()) + + ShapeUtil::SubshapeCount(b->shape()); + + // The new fusion will have no more operands and outputs than + // producer_operands + consumer_operands - 1 + num_output_buffers + // (minus one because we may be fusing a producer->consumer edge between `a` + // and `b`). + // + // This fact may be enough to let us avoid having to compute the true total + // number of operands, which can be expensive. + if (a->operand_count() + b->operand_count() - 1 + num_output_buffers <= + kMaxOperandsAndOutputsPerFusion) { + return false; + } + + // Compute the precise number of operands to the new fusion. + tensorflow::gtl::FlatSet operands( + a->operands().begin(), a->operands().end()); + operands.insert(b->operands().begin(), b->operands().end()); + // If there's an edge between `a` and `b`, don't count it: We're fusing that + // producer -> consumer relationship. + operands.erase(a); + operands.erase(b); + return operands.size() + num_output_buffers > kMaxOperandsAndOutputsPerFusion; +} + bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); @@ -141,6 +202,7 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, IsIEEEFloatingPointScalarConstant(producer->operand(0)) && fused_parameter_users[0]->opcode() == HloOpcode::kMultiply; } + return false; } // Other output fusions are not currently supported on GPUs. @@ -183,8 +245,13 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return true; } - return IsFusile(*producer) && IsFusile(*consumer) && - InstructionFusion::ShouldFuse(consumer, operand_index); + if (!IsFusile(*producer) || !IsFusile(*consumer) || + !InstructionFusion::ShouldFuse(consumer, operand_index)) { + return false; + } + + // We put this check last because it's potentially expensive. + return !FusionWouldBeTooLarge(consumer, producer); } bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer, diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h index f629d9ff2c7165b652369612c30979150f93bd24..c91f6343a69268ca687004dbe0ffbb863271a95c 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h @@ -27,6 +27,19 @@ class GpuInstructionFusion : public InstructionFusion { explicit GpuInstructionFusion(bool may_duplicate) : InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate) {} + // Maximum number of operands plus outputs allowed on a single fusion node. + // Exposed publicly mainly for tests. + static constexpr int64 kMaxOperandsAndOutputsPerFusion = 64; + + // Determines whether the combination of `a` and `b` into a (possibly + // multi-output) fusion would be "too large" -- i.e., have more operands and + // outputs than is allowed. + // + // `ShouldFuse` and `ShouldFuseIntoMultiOutput` call this; it's public so that + // other fusion passes (e.g. GPU multi-output fusion) can also call this. + static bool FusionWouldBeTooLarge(const HloInstruction* a, + const HloInstruction* b); + static bool IsExpensive(const HloInstruction& instruction); bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 1963d9eef72d41fa0a275bea98f959671fa7e737..8d0522bd8fd6659e64d18c52807df8dc7fc2f3b8 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -33,7 +33,7 @@ TEST_F(InstructionFusionTest, CostlyProducerAndOperandElementReusingConsumerNotFused) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0)); HloInstruction* broadcast2 = @@ -53,7 +53,7 @@ TEST_F(InstructionFusionTest, NonCostlyProducerAndOperandElementReusingConsumerFused) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); HloInstruction* negate1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, const0)); HloInstruction* broadcast2 = @@ -73,7 +73,7 @@ TEST_F(InstructionFusionTest, CostlyProducerAndNonOperandElementReusingConsumerFused_Reshape) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0)); HloInstruction* reshape2 = builder.AddInstruction( @@ -92,7 +92,7 @@ TEST_F(InstructionFusionTest, CostlyProducerAndNonOperandElementReusingConsumerFused_Transpose) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0)); HloInstruction* transpose2 = builder.AddInstruction( @@ -606,5 +606,35 @@ TEST_F(InstructionFusionTest, FuseScalarConstant) { op::Parameter())); } +// Check that we limit the number of operands to fusions we create. +TEST_F(InstructionFusionTest, AvoidsLargeFusion) { + constexpr int64 kNumParams = 200; + ASSERT_GT(kNumParams, GpuInstructionFusion::kMaxOperandsAndOutputsPerFusion); + + // Compute p0 + p1 + ... + pN. + HloComputation::Builder b(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {10, 100}); + auto param0 = + b.AddInstruction(HloInstruction::CreateParameter(0, shape, "p")); + auto sum = param0; + for (int64 i = 1; i < kNumParams; ++i) { + auto param = + b.AddInstruction(HloInstruction::CreateParameter(i, shape, "p")); + sum = b.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sum, param)); + } + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(b.Build()); + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + SCOPED_TRACE(module->ToString()); + for (const HloInstruction* instr : computation->instructions()) { + EXPECT_LE(instr->operand_count(), + GpuInstructionFusion::kMaxOperandsAndOutputsPerFusion) + << instr->ToString(); + } +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 388aa35d7dceeef92dbdb6c8a3bb7fb3796a0b61..c349063c71f000435a05306101ad724505f2d197 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -38,24 +38,27 @@ namespace gpu { namespace { // Return whether the given shape is a matrix with no padding. -bool IsRank2WithNoPadding(const Shape& shape) { - return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape); +bool IsRank2WithNoPadding(const Shape& shape, int64 batch_dimensions_size) { + return ShapeUtil::Rank(shape) == batch_dimensions_size + 2 && + !LayoutUtil::IsPadded(shape); } // In a gemm operation where output = lhs * rhs, check whether the given shapes // are valid for the operation. bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape) { + const Shape& output_shape, + int64 batch_dimensions_size) { // The inputs and the output must // 1) be matrices with no padding and a non-zero number of elements, // 2) have an allowed element type. PrimitiveType output_primitive_type = output_shape.element_type(); bool type_is_allowed = (output_primitive_type == F16 || output_primitive_type == F32 || - output_primitive_type == F64); - return type_is_allowed && IsRank2WithNoPadding(lhs_shape) && - IsRank2WithNoPadding(rhs_shape) && - IsRank2WithNoPadding(output_shape) && + output_primitive_type == F64 || output_primitive_type == C64); + return type_is_allowed && + IsRank2WithNoPadding(lhs_shape, batch_dimensions_size) && + IsRank2WithNoPadding(rhs_shape, batch_dimensions_size) && + IsRank2WithNoPadding(output_shape, batch_dimensions_size) && !ShapeUtil::IsZeroElementArray(lhs_shape) && !ShapeUtil::IsZeroElementArray(rhs_shape); } @@ -64,14 +67,15 @@ bool DotImplementedAsGemm(const HloInstruction& dot) { CHECK_EQ(dot.opcode(), HloOpcode::kDot); const Shape& lhs_shape = dot.operand(0)->shape(); const Shape& rhs_shape = dot.operand(1)->shape(); + const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); // If gemm can accept the operand shapes, use it rather than a custom // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape())) { + if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape(), + dim_numbers.lhs_batch_dimensions_size())) { // The size of the reduction dimension should match. The shape inference // guarantees this invariant, so the check here is for programming // errors. - const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); return true; @@ -81,11 +85,6 @@ bool DotImplementedAsGemm(const HloInstruction& dot) { } // namespace bool ImplementedAsGemm(const HloInstruction& hlo) { - // We can only do this if the HLO is unnested. - if (hlo.parent() != hlo.GetModule()->entry_computation()) { - return false; - } - // For certain types of Dot, we can call pre-canned BLAS gemm. if (hlo.opcode() == HloOpcode::kDot) { return DotImplementedAsGemm(hlo); @@ -242,15 +241,17 @@ llvm::Value* EmitPrintf(tensorflow::StringPiece fmt, arguments_ptr}); } -llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset, - llvm::IRBuilder<>* builder) { +llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, + llvm::IRBuilder<>* builder) { int bit_width = value->getType()->getPrimitiveSizeInBits(); + llvm::Value* all_warps_mask = builder->getInt32(-1); // Special case for efficiency if (value->getType()->isFloatTy() && bit_width == 32) { return llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_shfl_down_f32, - {value, offset, builder->getInt32(kWarpSize - 1)}, {}, builder); + llvm::Intrinsic::nvvm_shfl_sync_down_f32, + {all_warps_mask, value, offset, builder->getInt32(kWarpSize - 1)}, {}, + builder); } // We must split values wider than 32 bits as the "shfl" instruction operates @@ -264,10 +265,11 @@ llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset, for (int i = 0; i < num_segments; ++i) { x = builder->CreateInsertElement( x, - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_shfl_down_i32, - {builder->CreateExtractElement(x, i), - offset, builder->getInt32(kWarpSize - 1)}, - {}, builder), + llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_shfl_sync_down_i32, + {all_warps_mask, builder->CreateExtractElement(x, i), offset, + builder->getInt32(kWarpSize - 1)}, + {}, builder), i); } return builder->CreateBitCast( diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 59455f389e733fee2d6cace7486f919a0c5e834e..5d23a3d01842c7b4ff405171cd49c96a19f7e5b0 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -31,6 +31,12 @@ namespace gpu { constexpr int64 kWarpSize = 32; // Returns true if `hlo` will be implemented as a call to BLAS gemm. +// +// Precondition: `hlo` is in an "unnested context", meaning, it lives within the +// entry computation, within the either of a while loop's subcomputations, +// within any of a conditional's subcomputations, etc., but *does not* live +// within a reduce subcomputation, a map subcomputation, a fusion +// subcomputation, etc. It's OK if `hlo` *is* a fusion. bool ImplementedAsGemm(const HloInstruction& hlo); // A call to cuDNN for batch normalization is represented as CustomCall HLO with @@ -125,13 +131,17 @@ llvm::Value* EmitPrintf(tensorflow::StringPiece fmt, llvm::IRBuilder<>* builder); // Emits code to shuffle data between threads of a warp. This has the same -// semantics as the PTX "shfl.down" instruction [0] but works for values of any -// size. The last operand of the emitted "shfl" is `kWarpSize - 1`. +// semantics as the PTX "shfl.sync.down" instruction but works for values that +// aren't 32 bits in size. The last operand of the emitted "shfl" is +// `kWarpSize - 1`. +// +// This function emits a "full-warp" shuffle, which all threads of a warp +// participate in. *Do not use this function from a divergent context:* You +// can't correctly do so on both Volta and earlier GPUs. // -// [0] -// http://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl -llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset, - llvm::IRBuilder<>* builder); +// https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync +llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, + llvm::IRBuilder<>* builder); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index d5e07c3afb7dcb7e7a848b8c02e413c21d8ea155..541cacf6970453033c09a153a2dd320d4ebf3d4a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -57,12 +57,12 @@ IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config, IrEmitterContext* ir_emitter_context, bool is_nested) : ir_emitter_context_(ir_emitter_context), module_(ir_emitter_context->llvm_module()), - ir_builder_(module_->getContext()), + b_(module_->getContext()), bindings_(ir_emitter_context->hlo_module(), - &ir_emitter_context->buffer_assignment(), &ir_builder_, module_, + &ir_emitter_context->buffer_assignment(), &b_, module_, is_nested), hlo_module_config_(hlo_module_config) { - ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags( + b_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config.debug_options() .xla_enable_fast_math())); } @@ -71,30 +71,16 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : hlo->operands()) { operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*operand, *hlo) - .EmitReadArrayElement(index, &ir_builder_); + return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_); }; } return EmitTargetElementLoop( - *hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_, + *hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &b_, GetNestedComputer()) .MakeElementGenerator(hlo, operand_to_generator)); } Status IrEmitter::HandleConstant(HloInstruction* constant) { - const Literal& literal = constant->literal(); - llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, module_); - llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( - *module_, initializer->getType(), - /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer, - /*Name=*/""); - VLOG(2) << "HandleConstant: " << constant->ToString() << std::endl - << " emitted_value: " << llvm_ir::DumpToString(*global_for_const) - << std::endl - << " its type: " - << llvm_ir::DumpToString(*global_for_const->getType()); - bindings_.BindHloToIrValue(*constant, global_for_const); return Status::OK(); } @@ -119,15 +105,10 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) { get_tuple_element->shape(), get_tuple_element->tuple_index(), // TODO(b/26344050): tighten the alignment here // based on the real element type. - /*alignment=*/1, GetBasePointer(*operand), &ir_builder_, module_)); + /*alignment=*/1, GetBasePointer(*operand), &b_, module_)); return Status::OK(); } -Status IrEmitter::HandleSort(HloInstruction*) { - // TODO(b/26783907): Implement sort on GPU. - return Unimplemented("sort"); -} - Status IrEmitter::HandleSend(HloInstruction*) { return Unimplemented("Send is not implemented on GPU"); } @@ -144,13 +125,16 @@ Status IrEmitter::HandleRecvDone(HloInstruction*) { return Unimplemented("Recv-done is not implemented on GPU"); } +Status IrEmitter::HandleScatter(HloInstruction*) { + return Unimplemented("Scatter is not implemented on GPUs."); +} + Status IrEmitter::HandleTuple(HloInstruction* tuple) { std::vector base_ptrs; for (const HloInstruction* operand : tuple->operands()) { base_ptrs.push_back(GetBasePointer(*operand)); } - llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &ir_builder_, - module_); + llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &b_, module_); return Status::OK(); } @@ -171,7 +155,7 @@ Status IrEmitter::EmitCallToNestedComputation( std::vector arguments(operands.begin(), operands.end()); arguments.push_back(output); arguments.push_back(bindings_.GetTempBufferBase()); - ir_builder_.CreateCall(emitted_function, arguments); + b_.CreateCall(emitted_function, arguments); return Status::OK(); } @@ -193,21 +177,20 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( computation.root_instruction()->shape().element_type(); bool is_atomic_integral = element_type == S32 || element_type == U32 || element_type == S64 || element_type == U64; - llvm::Value* source = ir_builder_.CreateLoad(source_address, "source"); + llvm::Value* source = b_.CreateLoad(source_address, "source"); if (root_opcode == HloOpcode::kAdd) { // NVPTX supports atomicAdd on F32 and integer types. if (element_type == F32) { // F32 + F32 llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_atomic_load_add_f32, {output_address, source}, - {output_address->getType()}, &ir_builder_); + {output_address->getType()}, &b_); return true; } if (is_atomic_integral) { // integral + integral - ir_builder_.CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address, - source, - llvm::AtomicOrdering::SequentiallyConsistent); + b_.CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address, source, + llvm::AtomicOrdering::SequentiallyConsistent); return true; } } @@ -218,8 +201,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( auto opcode = primitive_util::IsSignedIntegralType(element_type) ? llvm::AtomicRMWInst::Max : llvm::AtomicRMWInst::UMax; - ir_builder_.CreateAtomicRMW(opcode, output_address, source, - llvm::AtomicOrdering::SequentiallyConsistent); + b_.CreateAtomicRMW(opcode, output_address, source, + llvm::AtomicOrdering::SequentiallyConsistent); return true; } @@ -228,8 +211,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( auto opcode = primitive_util::IsSignedIntegralType(element_type) ? llvm::AtomicRMWInst::Min : llvm::AtomicRMWInst::UMin; - ir_builder_.CreateAtomicRMW(opcode, output_address, source, - llvm::AtomicOrdering::SequentiallyConsistent); + b_.CreateAtomicRMW(opcode, output_address, source, + llvm::AtomicOrdering::SequentiallyConsistent); return true; } @@ -301,20 +284,20 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, llvm::Type* element_address_type = element_type->getPointerTo(); int atomic_size = (element_size < 32) ? 32 : element_size; - llvm::Type* atomic_type = ir_builder_.getIntNTy(atomic_size); + llvm::Type* atomic_type = b_.getIntNTy(atomic_size); llvm::Type* atomic_address_type = atomic_type->getPointerTo(output_address_type->getPointerAddressSpace()); // cas_old_output_address and cas_new_output_address point to the scratch // memory where we store the old and new values for the repeated atomicCAS // operations. - llvm::Value* cas_old_output_address = ir_builder_.CreateAlloca( + llvm::Value* cas_old_output_address = b_.CreateAlloca( atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address"); - llvm::Value* cas_new_output_address = ir_builder_.CreateAlloca( + llvm::Value* cas_new_output_address = b_.CreateAlloca( atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address"); // Emit preparation code to the preheader. - llvm::BasicBlock* loop_preheader_bb = ir_builder_.GetInsertBlock(); + llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock(); llvm::Value* atomic_memory_address; // binop_output_address points to the scratch memory that stores the @@ -325,77 +308,71 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, CHECK_EQ((element_size % sizeof(char)), 0); llvm::Type* address_int_type = module_->getDataLayout().getIntPtrType(output_address_type); - atomic_memory_address = - ir_builder_.CreatePtrToInt(output_address, address_int_type); + atomic_memory_address = b_.CreatePtrToInt(output_address, address_int_type); llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3); - llvm::Value* offset = ir_builder_.CreateAnd(atomic_memory_address, mask); + llvm::Value* offset = b_.CreateAnd(atomic_memory_address, mask); mask = llvm::ConstantInt::get(address_int_type, -4); - atomic_memory_address = ir_builder_.CreateAnd(atomic_memory_address, mask); + atomic_memory_address = b_.CreateAnd(atomic_memory_address, mask); atomic_memory_address = - ir_builder_.CreateIntToPtr(atomic_memory_address, atomic_address_type); - binop_output_address = ir_builder_.CreateAdd( - ir_builder_.CreatePtrToInt(cas_new_output_address, address_int_type), - offset); + b_.CreateIntToPtr(atomic_memory_address, atomic_address_type); + binop_output_address = b_.CreateAdd( + b_.CreatePtrToInt(cas_new_output_address, address_int_type), offset); binop_output_address = - ir_builder_.CreateIntToPtr(binop_output_address, element_address_type); + b_.CreateIntToPtr(binop_output_address, element_address_type); } else { atomic_memory_address = - ir_builder_.CreateBitCast(output_address, atomic_address_type); + b_.CreateBitCast(output_address, atomic_address_type); binop_output_address = - ir_builder_.CreateBitCast(cas_new_output_address, element_address_type); + b_.CreateBitCast(cas_new_output_address, element_address_type); } // Use the value from the memory that atomicCAS operates on to initialize // cas_old_output. llvm::Value* cas_old_output = - ir_builder_.CreateLoad(atomic_memory_address, "cas_old_output"); - ir_builder_.CreateStore(cas_old_output, cas_old_output_address); + b_.CreateLoad(atomic_memory_address, "cas_old_output"); + b_.CreateStore(cas_old_output, cas_old_output_address); llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock( - ir_builder_.GetInsertPoint(), "atomic_op_loop_exit"); - llvm::BasicBlock* loop_body_bb = - llvm::BasicBlock::Create(ir_builder_.getContext(), "atomic_op_loop_body", - ir_builder_.GetInsertBlock()->getParent()); - ir_builder_.SetInsertPoint(loop_body_bb); + b_.GetInsertPoint(), "atomic_op_loop_exit"); + llvm::BasicBlock* loop_body_bb = llvm::BasicBlock::Create( + b_.getContext(), "atomic_op_loop_body", b_.GetInsertBlock()->getParent()); + b_.SetInsertPoint(loop_body_bb); // Change preheader's successor from loop_exit_bb to loop_body_bb. loop_preheader_bb->getTerminator()->setSuccessor(0, loop_body_bb); // Emit the body of the loop that repeatedly invokes atomicCAS. // // Use cas_old_output to initialize cas_new_output. - cas_old_output = - ir_builder_.CreateLoad(cas_old_output_address, "cas_old_output"); - ir_builder_.CreateStore(cas_old_output, cas_new_output_address); + cas_old_output = b_.CreateLoad(cas_old_output_address, "cas_old_output"); + b_.CreateStore(cas_old_output, cas_new_output_address); // Emits code to calculate new_output = operation(old_output, source); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( computation, {binop_output_address, source_address}, binop_output_address)); llvm::Value* cas_new_output = - ir_builder_.CreateLoad(cas_new_output_address, "cas_new_output"); + b_.CreateLoad(cas_new_output_address, "cas_new_output"); // Emit code to perform the atomicCAS operation // (cas_old_output, success) = atomicCAS(memory_address, cas_old_output, // cas_new_output); - llvm::Value* ret_value = ir_builder_.CreateAtomicCmpXchg( + llvm::Value* ret_value = b_.CreateAtomicCmpXchg( atomic_memory_address, cas_old_output, cas_new_output, llvm::AtomicOrdering::SequentiallyConsistent, llvm::AtomicOrdering::SequentiallyConsistent); // Extract the memory value returned from atomicCAS and store it as // cas_old_output. - ir_builder_.CreateStore( - ir_builder_.CreateExtractValue(ret_value, 0, "cas_old_output"), - cas_old_output_address); + b_.CreateStore(b_.CreateExtractValue(ret_value, 0, "cas_old_output"), + cas_old_output_address); // Extract the success bit returned from atomicCAS and generate a // conditional branch on the success bit. - ir_builder_.CreateCondBr( - ir_builder_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb, - loop_body_bb); + b_.CreateCondBr(b_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb, + loop_body_bb); // Set the insertion point to the exit basic block so that the caller of // this method can continue emitting code to the right place. - SetToFirstInsertPoint(loop_exit_bb, &ir_builder_); + SetToFirstInsertPoint(loop_exit_bb, &b_); return Status::OK(); } @@ -421,46 +398,49 @@ Status IrEmitter::EmitAtomicOperationForNestedComputation( Status IrEmitter::HandleSelect(HloInstruction* select) { auto pred = select->operand(0); - auto on_true = select->operand(1); - auto on_false = select->operand(2); TF_RET_CHECK(pred->shape().element_type() == PRED); - - if (ShapeUtil::IsTuple(select->shape())) { - llvm_ir::EmitTupleSelect(GetIrArray(*select, *select), - GetIrArray(*pred, *select), - GetBasePointer(*on_true), - GetBasePointer(*on_false), &ir_builder_, module_); - return Status::OK(); - } - // We must not call the subclass `DefaultAction` method, lest its // `HandleSelect` call `IrEmitter::HandleSelect` and its `DefaultAction` // assume no handler has already been called. return IrEmitter::DefaultAction(select); } +Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) { + auto pred = tuple_select->operand(0); + auto on_true = tuple_select->operand(1); + auto on_false = tuple_select->operand(2); + TF_RET_CHECK(pred->shape().element_type() == PRED); + TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape())); + TF_RET_CHECK(ShapeUtil::IsTuple(tuple_select->shape())); + llvm_ir::EmitTupleSelect(GetIrArray(*tuple_select, *tuple_select), + GetIrArray(*pred, *tuple_select), + GetBasePointer(*on_true), GetBasePointer(*on_false), + &b_, module_); + return Status::OK(); +} + namespace { -llvm::Value* Real(llvm::Value* x, llvm::IRBuilder<>* ir_builder) { - return ir_builder->CreateExtractValue(x, {0}); -} - -llvm::Value* Imag(llvm::Value* x, llvm::IRBuilder<>* ir_builder) { - return ir_builder->CreateExtractValue(x, {1}); -} - -std::pair MultiplyComplex( - llvm::Value* lhs_value, llvm::Value* rhs_value, - llvm::IRBuilder<>* ir_builder) { - llvm::Value* lhs_real = Real(lhs_value, ir_builder); - llvm::Value* lhs_imag = Imag(lhs_value, ir_builder); - llvm::Value* rhs_real = Real(rhs_value, ir_builder); - llvm::Value* rhs_imag = Imag(rhs_value, ir_builder); - llvm::Value* real_result1 = ir_builder->CreateFMul(lhs_real, rhs_real); - llvm::Value* real_result2 = ir_builder->CreateFMul(lhs_imag, rhs_imag); - llvm::Value* real_result = ir_builder->CreateFSub(real_result1, real_result2); - llvm::Value* imag_result1 = ir_builder->CreateFMul(lhs_real, rhs_imag); - llvm::Value* imag_result2 = ir_builder->CreateFMul(lhs_imag, rhs_real); - llvm::Value* imag_result = ir_builder->CreateFAdd(imag_result1, imag_result2); +llvm::Value* Real(llvm::Value* x, llvm::IRBuilder<>* b) { + return b->CreateExtractValue(x, {0}); +} + +llvm::Value* Imag(llvm::Value* x, llvm::IRBuilder<>* b) { + return b->CreateExtractValue(x, {1}); +} + +std::pair MultiplyComplex(llvm::Value* lhs_value, + llvm::Value* rhs_value, + llvm::IRBuilder<>* b) { + llvm::Value* lhs_real = Real(lhs_value, b); + llvm::Value* lhs_imag = Imag(lhs_value, b); + llvm::Value* rhs_real = Real(rhs_value, b); + llvm::Value* rhs_imag = Imag(rhs_value, b); + llvm::Value* real_result1 = b->CreateFMul(lhs_real, rhs_real); + llvm::Value* real_result2 = b->CreateFMul(lhs_imag, rhs_imag); + llvm::Value* real_result = b->CreateFSub(real_result1, real_result2); + llvm::Value* imag_result1 = b->CreateFMul(lhs_real, rhs_imag); + llvm::Value* imag_result2 = b->CreateFMul(lhs_imag, rhs_real); + llvm::Value* imag_result = b->CreateFAdd(imag_result1, imag_result2); return {real_result, imag_result}; } } // namespace @@ -474,27 +454,29 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { const Shape& lhs_shape = lhs_instruction->shape(); const Shape& rhs_shape = rhs_instruction->shape(); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + CHECK_EQ(dnums.lhs_batch_dimensions_size(), + dnums.rhs_batch_dimensions_size()); // TODO(b/110211620): Convert to use i32 index_type when it is possible. - llvm::Type* index_type = ir_builder_.getInt64Ty(); + llvm::Type* index_type = b_.getInt64Ty(); llvm_ir::IrArray::Index element_index(index_type); if (ShapeUtil::IsScalar(lhs_shape) && ShapeUtil::IsScalar(rhs_shape)) { // If the operands are scalar, don't emit any loops. llvm::Value* lhs_value = - lhs_array.EmitReadArrayElement(/*index=*/element_index, &ir_builder_); + lhs_array.EmitReadArrayElement(/*index=*/element_index, &b_); llvm::Value* rhs_value = - rhs_array.EmitReadArrayElement(/*index=*/element_index, &ir_builder_); + rhs_array.EmitReadArrayElement(/*index=*/element_index, &b_); llvm::Value* result; if (ShapeUtil::ElementIsComplex(lhs_shape)) { - auto value = MultiplyComplex(lhs_value, rhs_value, &ir_builder_); + auto value = MultiplyComplex(lhs_value, rhs_value, &b_); result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType()); - result = ir_builder_.CreateInsertValue(result, value.first, {0}); - result = ir_builder_.CreateInsertValue(result, value.second, {1}); + result = b_.CreateInsertValue(result, value.first, {0}); + result = b_.CreateInsertValue(result, value.second, {1}); } else { - result = ir_builder_.CreateFMul(lhs_value, rhs_value); + result = b_.CreateFMul(lhs_value, rhs_value); } - target_array.EmitWriteArrayElement(/*index=*/element_index, result, - &ir_builder_); + target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_); return Status::OK(); } @@ -510,9 +492,15 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { const int64 lhs_reduction_dimension = ShapeUtil::GetDimensionNumber(lhs_shape, -1); const int64 rhs_reduction_dimension = - ShapeUtil::Rank(rhs_shape) >= 2 + ShapeUtil::Rank(rhs_shape) >= 2 + dnums.lhs_batch_dimensions_size() ? ShapeUtil::GetDimensionNumber(rhs_shape, -2) - : 0; + : dnums.lhs_batch_dimensions_size(); + + // Check that the batch dims don't cover the last two dims. + for (int64 batch_dim : dnums.lhs_batch_dimensions()) { + CHECK_NE(lhs_reduction_dimension, batch_dim); + CHECK_NE(rhs_reduction_dimension, batch_dim); + } // Verify the reduction dimension in the two operands are the same size. TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) == @@ -521,11 +509,18 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // Create loop nests which loop through the LHS operand dimensions and the RHS // operand dimensions. The reduction dimension of the LHS and RHS are handled // in a separate innermost loop which performs the sum of products. - llvm_ir::ForLoopNest loop_nest(IrName(dot), &ir_builder_); - llvm_ir::IrArray::Index lhs_index = EmitOperandArrayLoopNest( - lhs_array, lhs_reduction_dimension, "lhs", &loop_nest); - llvm_ir::IrArray::Index rhs_index = EmitOperandArrayLoopNest( - rhs_array, rhs_reduction_dimension, "rhs", &loop_nest); + llvm_ir::ForLoopNest loop_nest(IrName(dot), &b_); + llvm_ir::IrArray::Index lhs_index = loop_nest.EmitOperandArrayLoopNest( + lhs_array, /*dimension_to_skip=*/lhs_reduction_dimension, "lhs"); + llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest( + rhs_array, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs"); + + // We don't have to iterate over the batch dimensions in both arrays, simplify + // the loop nest of the rhs. + for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) { + DCHECK(c_linear_search(dnums.lhs_batch_dimensions(), i)); + rhs_index[i] = lhs_index[i]; + } // Create the reduction loop which does the sum of products reduction. std::unique_ptr reduction_loop = loop_nest.AddLoop( @@ -545,7 +540,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { llvm::Value* accum_address = llvm_ir::EmitAllocaAtFunctionEntry( accum_type, // The pointee type of the alloca instruction. "accum_address", // The name of the alloca instruction. - &ir_builder_); + &b_); // Initialize the accumulator in the preheader to zero. new llvm::StoreInst( @@ -559,27 +554,25 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // updated_accum = accum + lhs_element * rhs_element // *accum_address = updated_accum TF_RET_CHECK(!reduction_loop->GetBodyBasicBlock()->empty()); - ir_builder_.SetInsertPoint( + b_.SetInsertPoint( &*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt()); - llvm::Value* lhs_element = - lhs_array.EmitReadArrayElement(lhs_index, &ir_builder_); - llvm::Value* rhs_element = - rhs_array.EmitReadArrayElement(rhs_index, &ir_builder_); - llvm::Value* accum = ir_builder_.CreateLoad(accum_address); + llvm::Value* lhs_element = lhs_array.EmitReadArrayElement(lhs_index, &b_); + llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &b_); + llvm::Value* accum = b_.CreateLoad(accum_address); llvm::Value* updated_accum; if (ShapeUtil::ElementIsComplex(lhs_shape)) { - auto value = MultiplyComplex(lhs_element, rhs_element, &ir_builder_); - llvm::Value* accum_real = Real(accum, &ir_builder_); - llvm::Value* real_sum = ir_builder_.CreateFAdd(accum_real, value.first); - updated_accum = ir_builder_.CreateInsertValue(accum, real_sum, {0}); - llvm::Value* accum_imag = Imag(accum, &ir_builder_); - llvm::Value* imag_sum = ir_builder_.CreateFAdd(accum_imag, value.second); - updated_accum = ir_builder_.CreateInsertValue(updated_accum, imag_sum, {1}); + auto value = MultiplyComplex(lhs_element, rhs_element, &b_); + llvm::Value* accum_real = Real(accum, &b_); + llvm::Value* real_sum = b_.CreateFAdd(accum_real, value.first); + updated_accum = b_.CreateInsertValue(accum, real_sum, {0}); + llvm::Value* accum_imag = Imag(accum, &b_); + llvm::Value* imag_sum = b_.CreateFAdd(accum_imag, value.second); + updated_accum = b_.CreateInsertValue(updated_accum, imag_sum, {1}); } else { - llvm::Value* product = ir_builder_.CreateFMul(lhs_element, rhs_element); - updated_accum = ir_builder_.CreateFAdd(accum, product); + llvm::Value* product = b_.CreateFMul(lhs_element, rhs_element); + updated_accum = b_.CreateFAdd(accum, product); } - ir_builder_.CreateStore(updated_accum, accum_address); + b_.CreateStore(updated_accum, accum_address); // After the reduction loop exits, store the accumulator into the target // address. The index into the target address is the concatenation of the rhs @@ -591,21 +584,22 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { target_index.push_back(lhs_index[dimension]); } } - for (size_t dimension = 0; dimension < rhs_index.size(); ++dimension) { + // Skip over the batch dimensions to not have them in the index twice. + for (size_t dimension = dnums.lhs_batch_dimensions_size(); + dimension < rhs_index.size(); ++dimension) { if (dimension != rhs_reduction_dimension) { target_index.push_back(rhs_index[dimension]); } } - SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &ir_builder_); + SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &b_); target_array.EmitWriteArrayElement( target_index, - ir_builder_.CreateLoad( - accum_address), // The value written to the target array. - &ir_builder_); + b_.CreateLoad(accum_address), // The value written to the target array. + &b_); // Set the IR builder insert point to the exit basic block of the outer most // loop. This ensures later instructions are inserted after this loop nest. - ir_builder_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); + b_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); return Status::OK(); } @@ -647,11 +641,10 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { [=](const llvm_ir::IrArray::Index& index) -> StatusOr { // Initialize an accumulator with init_value. llvm::AllocaInst* accumulator_addr = - ir_builder_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType( + b_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType( reduce->shape().element_type(), module_)); - ir_builder_.CreateStore( - ir_builder_.CreateLoad(GetBasePointer(*init_value)), - accumulator_addr); + b_.CreateStore(b_.CreateLoad(GetBasePointer(*init_value)), + accumulator_addr); // The enclosing loops go over all the target elements. Now we have to // compute the actual target element. For this, we build a new loop nest @@ -659,12 +652,12 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { // AddLoopsForShapeOnDimensions will return an Index where induction // Value*s are placed for each dimension in dimensions, and all the rest // are nullptrs. - llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &ir_builder_); + llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &b_); const llvm_ir::IrArray::Index reduced_dims_index = loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions, "reduction_dim"); - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &b_); // Build a full index for the input argument, using reduced_dims_index // as the base. In reduced_dims_index only the reduction dimensions are @@ -683,13 +676,12 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { // Apply the reduction function to the loaded value. llvm::Value* input_address = - GetIrArray(*arg, *reduce) - .EmitArrayElementAddress(input_index, &ir_builder_); + GetIrArray(*arg, *reduce).EmitArrayElementAddress(input_index, &b_); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *function, {accumulator_addr, input_address}, accumulator_addr)); - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); - return ir_builder_.CreateLoad(accumulator_addr); + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); + return b_.CreateLoad(accumulator_addr); }); } @@ -702,8 +694,8 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { for (HloInstruction* operand : fusion->operands()) { parameter_arrays.push_back(GetIrArray(*operand, *fusion)); } - GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, - &ir_builder_, GetNestedComputer()); + GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_, + GetNestedComputer()); FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); @@ -733,24 +725,6 @@ Status IrEmitter::HandleOutfeed(HloInstruction*) { return Unimplemented("Outfeed is not supported on GPU."); } -Status IrEmitter::HandleRng(HloInstruction* random) { - ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; - for (const HloInstruction* operand : random->operands()) { - operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*operand, *random) - .EmitReadArrayElement(index, &ir_builder_); - }; - } - // Emits a single-threaded loop because the loop body generated by the element - // generator for Rng can't be parallelized (b/32333178). - return llvm_ir::LoopEmitter( - GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_, - GetNestedComputer()) - .MakeElementGenerator(random, operand_to_generator), - GetIrArray(*random, *random), &ir_builder_) - .EmitLoop(IrName(random)); -} - Status IrEmitter::HandleBatchNormInference(HloInstruction*) { return Unimplemented( "The GPU backend does not implement BatchNormInference directly. It " @@ -774,34 +748,9 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) { "to a cudnn CustomCall using CudnnBatchNormRewriter."); } -llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest( - const llvm_ir::IrArray& operand_array, int64 reduction_dimension, - tensorflow::StringPiece name_suffix, llvm_ir::ForLoopNest* loop_nest) { - // Prepares the dimension list we will use to emit the loop nest. Outermost - // loops are added first. Add loops in major-to-minor order, and skip the - // reduction dimension. - std::vector dimensions; - const Shape& shape = operand_array.GetShape(); - for (int i = 0; i < LayoutUtil::MinorToMajor(shape).size(); ++i) { - int64 dimension = LayoutUtil::Major(shape.layout(), i); - if (dimension != reduction_dimension) { - dimensions.push_back(dimension); - } - } - - // Create loop nest with one for-loop for each dimension of the - // output. - llvm_ir::IrArray::Index index = - loop_nest->AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix); - // Verify every dimension except the reduction dimension was set in the index. - for (size_t dimension = 0; dimension < index.size(); ++dimension) { - if (dimension == reduction_dimension) { - DCHECK_EQ(nullptr, index[dimension]); - } else { - DCHECK_NE(nullptr, index[dimension]); - } - } - return index; +Status IrEmitter::HandleIota(HloInstruction*) { + // TODO(b/64798317): implement iota on GPU. + return Unimplemented("Iota is not implemented on GPU."); } StatusOr IrEmitter::ComputeNestedElement( @@ -810,16 +759,16 @@ StatusOr IrEmitter::ComputeNestedElement( llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType( computation.root_instruction()->shape().element_type(), module_), - "return_buffer", &ir_builder_); + "return_buffer", &b_); std::vector parameter_buffers; for (llvm::Value* parameter_element : parameter_elements) { parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry( - parameter_element->getType(), "parameter_buffer", &ir_builder_)); - ir_builder_.CreateStore(parameter_element, parameter_buffers.back()); + parameter_element->getType(), "parameter_buffer", &b_)); + b_.CreateStore(parameter_element, parameter_buffers.back()); } TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers, return_buffer)); - return ir_builder_.CreateLoad(return_buffer); + return b_.CreateLoad(return_buffer); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index e55dfc6dae844ceb1d28ad389d133c80823bad9a..561c6838798aa92ce2c96b3c45d5ba42fe6edef3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -79,7 +79,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleCrossReplicaSum(HloInstruction* crs) override; Status HandleInfeed(HloInstruction* infeed) override; Status HandleOutfeed(HloInstruction* outfeed) override; - Status HandleSort(HloInstruction* sort) override; Status HandleSend(HloInstruction* send) override; Status HandleSendDone(HloInstruction* send_done) override; Status HandleRecv(HloInstruction* recv) override; @@ -87,14 +86,16 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleParameter(HloInstruction* parameter) override; Status HandleReduce(HloInstruction* reduce) override; Status HandleTuple(HloInstruction* tuple) override; + Status HandleScatter(HloInstruction* scatter) override; Status HandleSelect(HloInstruction* select) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleFusion(HloInstruction* fusion) override; Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call) override; - Status HandleRng(HloInstruction* random) override; Status HandleBatchNormInference(HloInstruction* batch_norm) override; Status HandleBatchNormTraining(HloInstruction* batch_norm) override; Status HandleBatchNormGrad(HloInstruction* batch_norm) override; + Status HandleIota(HloInstruction* iota) override; Status FinishVisit(HloInstruction* root) override { return Status::OK(); } @@ -161,7 +162,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { // The following fields track the IR emission state. According to LLVM memory // management rules, their memory is owned by the module. - llvm::IRBuilder<> ir_builder_; + llvm::IRBuilder<> b_; // Mapping from HLO to its underlying LLVM value. HloToIrBindings bindings_; @@ -170,17 +171,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { const HloModuleConfig& hlo_module_config_; private: - // Emits a series of nested loops for iterating over an operand array in the - // dot operation. Loops are constructed in major to minor dimension layout - // order. No loop is emitted for the given reduction_dimension. The function - // returns an IrArray index for the given operand_array containing the indvars - // of the loops. All dimensions of the index are filled except for the - // reduction dimension. name_suffix is the string to append to the names of - // LLVM constructs (eg, basic blocks) constructed by this method. - llvm_ir::IrArray::Index EmitOperandArrayLoopNest( - const llvm_ir::IrArray& operand_array, int64 reduction_dimension, - tensorflow::StringPiece name_suffix, llvm_ir::ForLoopNest* loop_nest); - // A helper method for EmitAtomicOperationForNestedComputation. Certain // computations, such as floating-point addition and integer maximization, can // be simply implemented using an LLVM atomic instruction. If "computation" is @@ -197,6 +187,13 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* output_address, llvm::Value* source_address); + // A helper method for HandleSort(). It adds the inner comparison loop where + // we compare elements pointed to by 'keys_index' and 'compare_keys_index'. + void EmitCompareLoop(int64 dimension_to_sort, + const llvm_ir::IrArray::Index& keys_index, + const llvm_ir::IrArray::Index& compare_keys_index, + const llvm_ir::IrArray& keys_array); + StatusOr ComputeNestedElement( const HloComputation& computation, tensorflow::gtl::ArraySlice parameter_elements); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index c9574c87a3be208915b3d6a32679553eb425d2f0..5c827e5f9cf3e1c04af444dae338a2ec411ce372 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -70,10 +70,10 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation( argument_dereferenceable_bytes.push_back(root_size); } // The base pointer of the memory block for all pre-allocated temp buffers. - argument_types.push_back(ir_builder_.getInt8PtrTy()); + argument_types.push_back(b_.getInt8PtrTy()); llvm::FunctionType* function_type = - llvm::FunctionType::get(ir_builder_.getVoidTy(), argument_types, false); + llvm::FunctionType::get(b_.getVoidTy(), argument_types, false); llvm::Function* function = llvm::Function::Create( function_type, // The function type. llvm::GlobalValue::InternalLinkage, // The linkage type. @@ -96,8 +96,7 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation( llvm::BasicBlock::Create(function->getContext(), "entry", function); // Emit a "return void" at entry_bb's end, and sets the insert point before // that return instruction. - ir_builder_.SetInsertPoint( - llvm::ReturnInst::Create(function->getContext(), entry_bb)); + b_.SetInsertPoint(llvm::ReturnInst::Create(function->getContext(), entry_bb)); std::vector non_io_hlos; for (const auto* hlo : nested_computation.instructions()) { @@ -127,20 +126,17 @@ Status IrEmitterNested::EmitTargetElementLoop( target_arrays.push_back(GetIrArray(hlo, hlo, {i})); } TF_RETURN_IF_ERROR( - llvm_ir::LoopEmitter(element_generator, target_arrays, &ir_builder_) - .EmitLoop()); + llvm_ir::LoopEmitter(element_generator, target_arrays, &b_).EmitLoop()); std::vector tuple_operand_ptrs; tuple_operand_ptrs.reserve(num_elems); for (const llvm_ir::IrArray& array : target_arrays) { tuple_operand_ptrs.push_back(array.GetBasePointer()); } - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &ir_builder_, - module_); + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_); return Status::OK(); } - return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), - &ir_builder_) + return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_) .EmitLoop(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index bdb9e77da4d4fda23cad128fc6400a1205e7d54b..a093ffc7c1293d5dc9e44de97896add12a3ae510 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -28,11 +28,12 @@ limitations under the License. #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" @@ -48,30 +49,35 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h" #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" -#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -79,6 +85,7 @@ namespace gpu { namespace { +using llvm_ir::IrArray; using llvm_ir::IrName; using tensorflow::gtl::ArraySlice; using tensorflow::gtl::InlinedVector; @@ -164,40 +171,6 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) { return DfsHloVisitor::Postprocess(hlo); } -namespace { -bool ImplementedAsHostToDeviceMemcpy(const BufferAssignment& buffer_assignment, - const HloInstruction& hlo) { - // `hlo` needs to satisfy the following conditions to be implemented as a - // host-to-device cuMemcpy. - // - // 1. `hlo` is a kCopy instruction. - // 2. `hlo`'s only operand is a kConstant instruction. - // 3. `hlo` and its operand have the same shape (thus the same layout too). - // 4. The address of `hlo`'s buffer is known at runtime (without dereferencing - // pointers in a tuple). - return hlo.opcode() == HloOpcode::kCopy && - hlo.operand(0)->opcode() == HloOpcode::kConstant && - ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && - buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok(); -} - -bool ImplementedAsDeviceToDeviceMemcpy( - const BufferAssignment& buffer_assignment, const HloInstruction& hlo) { - // `hlo` needs to satisfy three conditions to be implemented as a - // device-to-device cuMemcpy. - // - // 1. `hlo` is a kCopy instruction. - // 2. `hlo` and its operand have the same shape (thus the same layout too). - // 3. `hlo` and its operand have a statically-known buffer assignment - // (constants do not, for instance), which means the source buffer also - // resides on the device. - return hlo.opcode() == HloOpcode::kCopy && - ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && - buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok() && - buffer_assignment.GetUniqueTopLevelSlice(hlo.operand(0)).ok(); -} -} // namespace - llvm::Function* IrEmitterUnnested::BuildKernelPrototype( const HloInstruction& inst, tensorflow::gtl::ArraySlice args) { @@ -211,7 +184,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( llvm::LLVMContext& context = module->getContext(); llvm::FunctionType* kernel_type = llvm::FunctionType::get( /*Result=*/llvm::Type::getVoidTy(context), - std::vector(args.size(), ir_builder_.getInt8PtrTy()), + std::vector(args.size(), b_.getInt8PtrTy()), /*isVarArg=*/false); llvm::Function* kernel = llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage, @@ -226,9 +199,20 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( ++arg_it; kernel->addDereferenceableAttr(arg_no + 1, alloc->size()); + + const int64 alignment = [&] { + if (alloc->is_entry_computation_parameter()) { + return kEntryParameterAlignBytes; + } else if (alloc->is_constant()) { + return kConstantBufferAlignBytes; + } else { + return kXlaAllocatedBufferAlignBytes; + } + }(); + kernel->addParamAttr( - arg_no, llvm::Attribute::get(context, llvm::Attribute::Alignment, - kCudaMallocAlignBytes)); + arg_no, + llvm::Attribute::get(context, llvm::Attribute::Alignment, alignment)); if (alloc->IsPreallocatedTempBuffer()) { fn_arg->setName("temp_buf"); @@ -247,7 +231,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( nvvm_annotations_node->addOperand(llvm::MDNode::get( context, {llvm::ConstantAsMetadata::get(kernel), llvm::MDString::get(context, "kernel"), - llvm::ConstantAsMetadata::get(ir_builder_.getInt32(1))})); + llvm::ConstantAsMetadata::get(b_.getInt32(1))})); // Update the insert point to the entry basic block. llvm::BasicBlock* entry_bb = @@ -255,7 +239,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( // Emit a "return void" at entry_bb's end, and set the insert point before // that return instruction. - ir_builder_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb)); + b_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb)); return kernel; } @@ -293,7 +277,7 @@ int ComputeMaxUnrollFactor(const HloInstruction* hlo) { // range of i32. // Otherwise, the return type is i64. llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, - llvm::IRBuilder<>* ir_builder) { + llvm::IRBuilder<>* b) { // Find the unnested hlo instructon for which the kernel is generated for. const HloInstruction* unnested_hlo = hlo; const HloComputation* computation = hlo->parent(); @@ -314,7 +298,7 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, return in_range; }; - llvm::Type* i64_ty = ir_builder->getInt64Ty(); + llvm::Type* i64_ty = b->getInt64Ty(); // Check launch dimension if (!IsInt32(launch_size)) { return i64_ty; @@ -343,7 +327,7 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, } } - return ir_builder->getInt32Ty(); + return b->getInt32Ty(); } } // namespace @@ -355,21 +339,18 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { unroll_factor = ComputeMaxUnrollFactor(hlo); } - thunk_sequence_->emplace_back(BuildKernelThunk(hlo, unroll_factor)); + thunk_sequence_->emplace_back(BuildKernelThunk( + hlo, /*implements_whole_instruction=*/true, unroll_factor)); return IrEmitter::DefaultAction(hlo); } Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { - const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); - if (dnums.lhs_batch_dimensions_size() > 0 || - dnums.rhs_batch_dimensions_size() > 0) { - return Unimplemented("Dot with batch dimensions not implemented."); - } if (ImplementedAsGemm(*dot)) { thunk_sequence_->emplace_back(BuildGemmThunk(dot)); return Status::OK(); } - thunk_sequence_->emplace_back(BuildKernelThunk(dot)); + thunk_sequence_->emplace_back( + BuildKernelThunk(dot, /*implements_whole_instruction=*/true)); return IrEmitter::HandleDot(dot); } @@ -379,7 +360,8 @@ Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) { } Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) { - thunk_sequence_->emplace_back(BuildKernelThunk(convolution)); + thunk_sequence_->emplace_back( + BuildKernelThunk(convolution, /*implements_whole_instruction=*/true)); return IrEmitter::HandleConvolution(convolution); } @@ -586,16 +568,17 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { } } CHECK(first_reduce != nullptr); - thunks.push_back(BuildKernelThunk(fusion)); + thunks.push_back( + BuildKernelThunk(fusion, /*implements_whole_instruction=*/false)); thunk_sequence_->emplace_back( MakeUnique(std::move(thunks), fusion)); - std::vector parameter_arrays; + std::vector parameter_arrays; for (HloInstruction* operand : fusion->operands()) { parameter_arrays.push_back(GetIrArray(*operand, *fusion)); } GpuElementalIrEmitter elemental_emitter( - hlo_module_config_, ir_emitter_context_->llvm_module(), - &ir_builder_, GetNestedComputer()); + hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, + GetNestedComputer()); FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); @@ -660,21 +643,22 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // touching the un-updated elements. // Set up kernel thunk and fused ir emitter. - thunk_sequence_->emplace_back(BuildKernelThunk(fusion)); - std::vector operand_arrays; + thunk_sequence_->emplace_back( + BuildKernelThunk(fusion, /*implements_whole_instruction=*/true)); + std::vector operand_arrays; for (HloInstruction* operand : fusion->operands()) { operand_arrays.push_back(GetIrArray(*operand, *fusion)); } GpuElementalIrEmitter elemental_emitter(hlo_module_config_, ir_emitter_context_->llvm_module(), - &ir_builder_, GetNestedComputer()); + &b_, GetNestedComputer()); // Shape of the dynamic-update-slice's "update" operand. Shape update_shape = root->operand(1)->shape(); // Array to write into. Because this is an in-place operation, this is the // same as operand 0's array. - llvm_ir::IrArray output_array = GetIrArray(*fusion, *fusion); + IrArray output_array = GetIrArray(*fusion, *fusion); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( update_shape, ir_emitter_context_->device_description()); @@ -685,346 +669,38 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace( fusion, operand_arrays, output_array, &elemental_emitter, - launch_dimensions, &ir_builder_); + launch_dimensions, &b_); } + if (ImplementedAsGemm(*fusion)) { thunk_sequence_->emplace_back(BuildGemmThunk(fusion)); return Status::OK(); } - CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop); - int unroll_factor = ComputeMaxUnrollFactor(fusion); + CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop); - thunk_sequence_->emplace_back(BuildKernelThunk(fusion, unroll_factor)); - return IrEmitter::HandleFusion(fusion); -} - -namespace { - -// Returns the indices of the first elements of all consecutive subarrays of the -// given array. For example: -// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4} -std::vector ConsecutiveSegments(tensorflow::gtl::ArraySlice xs) { - std::vector is = {0}; - for (size_t i = 1; i < xs.size(); ++i) { - if (1 != xs[i] - xs[i - 1]) { - is.push_back(i); - } - } - return is; -} - -// Merges the sequences of dimensions of the given shape which start at the -// given indices `segs`. -Shape MergeDimensions(tensorflow::gtl::ArraySlice segs, - const Shape& shape) { - std::vector dimensions; - for (size_t i = 1; i <= segs.size(); ++i) { - dimensions.push_back(std::accumulate( - shape.dimensions().begin() + segs[i - 1], - shape.dimensions().begin() + - (segs.size() == i ? shape.dimensions().size() : segs[i]), - 1, std::multiplies())); - } - return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), - dimensions); -} - -// Returns whether the given shapes and permutation are a 0-2-1 transpose, and -// if so, the normalized and rank-reduced shapes. The shapes must have the same -// dimensions, so this considers layout only. -// -// This function recognizes higher-rank transposes which are elementwise -// equivalent to a 0-2-1 transpose. -std::tuple IsTranspose021(const Shape& a, const Shape& b) { - CHECK(ShapeUtil::Compatible(a, b)); - std::vector perm(a.dimensions().size()); - { - auto layout_a_orig = LayoutUtil::MinorToMajor(a); - std::vector layout_a(layout_a_orig.rbegin(), layout_a_orig.rend()); - auto layout_b_orig = LayoutUtil::MinorToMajor(b); - std::vector layout_b(layout_b_orig.rbegin(), layout_b_orig.rend()); - for (size_t i = 0; i < perm.size(); ++i) { - perm[i] = PositionInContainer(layout_b, layout_a[i]); - } + if (CheckAndEmitHloWithTile021(fusion)) { + return Status::OK(); } - auto segs = ConsecutiveSegments(perm); - Shape norm_a = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a); - Shape norm_b = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(b); - if (3 == segs.size() && 0 == perm[0]) { - Shape reduced_a = MergeDimensions(segs, norm_a); - Shape reduced_b = ShapeUtil::MakeShapeWithDescendingLayout( - b.element_type(), - Permute({0, 2, 1}, AsInt64Slice(reduced_a.dimensions()))); - return std::make_tuple(true, reduced_a, reduced_b); - } - return std::make_tuple(false, ShapeUtil::MakeNil(), ShapeUtil::MakeNil()); -} - -// Returns whether the given shapes are potentially of a 0-2-1 transpose. -// As 0-2-1 is a self-inverse permutation, which shape is input or output is -// arbitrary. -bool AreShapesForTranspose021(const Shape& a, const Shape& b) { - return 3 == b.dimensions().size() && - ShapeUtil::Compatible( - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a), - ShapeUtil::PermuteDimensions( - {0, 2, 1}, - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - b))); -} - -// Emits a tiled 0-2-1 transpose, assuming both input and output lain out from -// major to minor. The x- and y- dimensions are tiled in square tiles of edge -// length `tile_size`. Each thread block of `tile_size` x `num_rows` threads -// transposes one tile: each thread copies a row from the input to a shared -// memory tile, then copies a column from the shared memory tile to the output. -// -// `tile_size` should usually be same as warp size. -// -// Returns (number of tiles = number of thread blocks needed). -// -// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient -// to launch fewer blocks so each transposes many tiles, and -// in any case, the number of blocks we can launch is limited. -// -// This is the same algorithm in CUDA: -// https://github.com/tensorflow/tensorflow/blob/d2693c8a70567cc78b2e8a9ac8020d321620ca83/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc#L189 -int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, - const int64 tile_size, const int64 num_rows, - llvm::IRBuilder<>* builder) { - // Adds `addend` to the given `dim` of `index`. - auto offset_dim = [builder](llvm_ir::IrArray::Index index, - llvm::Value* addend, int64 dim) { - index[dim] = builder->CreateAdd(index[dim], addend); - return index; - }; - CHECK(AreShapesForTranspose021(input.GetShape(), output.GetShape())); - - Shape input_shape = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - input.GetShape()); - Shape output_shape = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - output.GetShape()); - input = input.CastToShape(input_shape, builder); - output = output.CastToShape(output_shape, builder); - - llvm::Type* tile_type = llvm::ArrayType::get( - llvm::ArrayType::get(input.GetElementLlvmType(), tile_size), - // One extra here to avoid share memory bank conflict - tile_size + 1); - auto* tile = new llvm::GlobalVariable( - *builder->GetInsertBlock()->getParent()->getParent(), tile_type, - /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, - llvm::UndefValue::get(tile_type), "tile", nullptr, - llvm::GlobalValue::NotThreadLocal, - /*AddressSpace=*/3 /* GPU shared memory */); - - // let x = threadIdx.x - llvm::Value* x = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder); - llvm_ir::AddRangeMetadata(0, num_rows * tile_size, - static_cast(x)); - x = builder->CreateIntCast(x, builder->getInt64Ty(), /*isSigned=*/true, - "thread.id.x"); - - // computing logical thread ids - // logical_x = x % tile_size - auto logical_x = builder->CreateURem(x, builder->getInt64(tile_size)); - - // logical_y = x / tile_size - auto logical_y = builder->CreateUDiv(x, builder->getInt64(tile_size)); - - // `emit_cp` emits equivalent to following pseudocode: - // if (tile_size == tile_width && tile_size == tile_height) { - // unroll for (i in range(0, tile_size, num_rows)) { - // emit_cp_element(index + {0, i, 0}, y + logical_y); - // } - // } else if (x < tile_width) { - // tile_height_upperbound = ceil(tile_height / num_rows) * num_rows; - // for (i in range(0, tile_height_upperbound, num_rows)) { - // y_loc = i + logical_y; - // if (y_loc < tile_height) - // emit_cp_element(index + {0, i, 0}, y_loc); - // } - // } - // - // We use this to emit both the copy from input to tile and the copy from tile - // to output. - // - // `index` is the origin of the row or column in the input or output array. - // - // `emit_cp_element(index, y)` emits code to copy a single element between the - // tile and the input or output array, where `y` is the `y`-position in the - // tile, whether which is row or column is a function of whether we're copying - // from input or to output, and `index` is the index into the input or output - // array. - auto emit_cp_tile = [builder, tile_size, &offset_dim, num_rows, logical_x, - logical_y]( - std::function - emit_cp_element, - llvm::Value* tile_width, llvm::Value* tile_height, - const llvm_ir::IrArray::Index& index, - const string& loop_name) { - llvm_ir::LlvmIfData if_not_last_row = llvm_ir::EmitIfThenElse( - builder->CreateAnd( - builder->CreateICmpEQ(builder->getInt64(tile_size), tile_width), - builder->CreateICmpEQ(builder->getInt64(tile_size), tile_height)), - "not_last_row", builder); - builder->SetInsertPoint(if_not_last_row.true_block->getTerminator()); - for (int64 i = 0; i < tile_size; i += num_rows) { - auto source_idx = offset_dim(index, builder->getInt64(i), /*dim=*/1); - auto y_loc = builder->CreateAdd(builder->getInt64(i), logical_y); - emit_cp_element(source_idx, y_loc); - } - builder->SetInsertPoint(if_not_last_row.false_block->getTerminator()); - llvm_ir::LlvmIfData if_in_tile = llvm_ir::EmitIfThenElse( - builder->CreateICmpULT(logical_x, tile_width), "x_in_tile", builder); - builder->SetInsertPoint(if_in_tile.true_block->getTerminator()); - - // tile_height_upper_bound = ceil(tile_height / num_rows) * num_rows - auto tile_height_upper_bound = builder->CreateMul( - builder->CreateUDiv( - builder->CreateAdd(tile_height, builder->getInt64(num_rows - 1)), - builder->getInt64(num_rows)), - builder->getInt64(num_rows)); - - auto loop = llvm_ir::ForLoop::EmitForLoop( - loop_name, builder->getInt64(0), tile_height_upper_bound, - builder->getInt64(num_rows), builder); - llvm_ir::SetToFirstInsertPoint(loop->GetHeaderBasicBlock(), builder); - builder->SetInsertPoint(loop->GetBodyBasicBlock()->getTerminator()); - - auto y_loc = builder->CreateAdd(loop->GetIndVarValue(), logical_y); - auto if_y_in_tile = llvm_ir::EmitIfThenElse( - builder->CreateICmpULT(y_loc, tile_height), "y_in_tile", builder); - builder->SetInsertPoint(if_y_in_tile.true_block->getTerminator()); - - emit_cp_element(offset_dim(index, loop->GetIndVarValue(), /*dim=*/1), - y_loc); - builder->SetInsertPoint(if_not_last_row.after_block->getTerminator()); - }; - - auto input_dims_in_tiles = input_shape.dimensions(); - // Unpermuted dimensions are untiled. - for (int i = 1; i < 3; ++i) { - input_dims_in_tiles[i] = - CeilOfRatio(input_dims_in_tiles[i], tile_size); - } - int64 num_tiles = - std::accumulate(input_dims_in_tiles.begin(), input_dims_in_tiles.end(), 1, - std::multiplies()); - const llvm_ir::IrArray::Index input_tile_index( - /*linear=*/builder->CreateIntCast( - llvm_ir::AddRangeMetadata( - 0, num_tiles, - static_cast(llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, - builder))), - builder->getInt64Ty(), /*isSigned=*/true, "block.id.x"), - ShapeUtil::MakeShapeWithDescendingLayout( - PRED /*arbitrary*/, AsInt64Slice(input_dims_in_tiles)), - builder); - const llvm_ir::IrArray::Index input_tile_origin = ({ - llvm_ir::IrArray::Index index = input_tile_index; - for (int i = 1; i < 3; ++i) { - index[i] = builder->CreateMul(index[i], builder->getInt64(tile_size), - "tile_origin." + std::to_string(i)); - } - index; - }); - const llvm_ir::IrArray::Index input_index = - offset_dim(offset_dim(input_tile_origin, logical_x, /*dim=*/2), logical_y, - /*dim=*/1); - std::vector tile_dims(input_shape.dimensions().size()); - // Only last row or column may not have full size. - for (int i = 1; i < 3; ++i) { - tile_dims[i] = builder->CreateSelect( - builder->CreateICmpEQ(input_tile_index[i], - builder->getInt64(input_dims_in_tiles[i] - 1)), - builder->getInt64(input_shape.dimensions(i) - - (input_dims_in_tiles[i] - 1) * tile_size), - builder->getInt64(tile_size), "tile_size"); - } - - // Load data from input memory to shared memory tile. - emit_cp_tile( - // tile[y, x] = input_array[index] - [builder, tile, &input, logical_x](const llvm_ir::IrArray::Index& index, - llvm::Value* y) { - builder->CreateStore( - input.EmitReadArrayElement(index, builder, "input_element"), - builder->CreateGEP(tile, {builder->getInt64(0), y, logical_x})); - }, - tile_dims[2], tile_dims[1], input_index, "input"); + int unroll_factor = ComputeMaxUnrollFactor(fusion); - // Wait for all threads to reach this point, lest we copy a value from tile to - // output before the other thread copies it from input to tile. - // This is `__syncthreads` in CUDA. - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, builder); - - const llvm_ir::IrArray::Index output_tile_index( - Permute({0, 2, 1}, input_tile_index.multidim())); - const llvm_ir::IrArray::Index output_tile_origin( - Permute({0, 2, 1}, input_tile_origin.multidim())); - const llvm_ir::IrArray::Index output_index = - offset_dim(offset_dim(output_tile_origin, logical_x, /*dim=*/2), - logical_y, /*dim=*/1); - - // Store data from shared memory tile to output memory. - emit_cp_tile( - // output_array[index] = tile[x, y] - [builder, tile, &output, logical_x](const llvm_ir::IrArray::Index& index, - llvm::Value* y) { - output.EmitWriteArrayElement( - index, - builder->CreateLoad( - builder->CreateGEP(tile, {builder->getInt64(0), logical_x, y}), - "output_element"), - builder); - }, - tile_dims[1], tile_dims[2], output_index, "output"); - - return num_tiles; + thunk_sequence_->emplace_back(BuildKernelThunk( + fusion, /*implements_whole_instruction=*/true, unroll_factor)); + return IrEmitter::HandleFusion(fusion); } -} // namespace - Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { - if (ImplementedAsHostToDeviceMemcpy(ir_emitter_context_->buffer_assignment(), - *copy)) { - thunk_sequence_->emplace_back(BuildHostToDeviceCopyThunk(copy)); - return Status::OK(); - } - if (ImplementedAsDeviceToDeviceMemcpy( - ir_emitter_context_->buffer_assignment(), *copy)) { + CHECK(ShapeUtil::Compatible(copy->operand(0)->shape(), copy->shape())); + const BufferAssignment& buffer_assignment = + ir_emitter_context_->buffer_assignment(); + if (LayoutUtil::Equal(copy->operand(0)->shape().layout(), + copy->shape().layout()) && + buffer_assignment.GetUniqueTopLevelSlice(copy->operand(0)).ok()) { thunk_sequence_->emplace_back(BuildDeviceToDeviceCopyThunk(copy)); return Status::OK(); } - bool is_transpose_021; - Shape reduced_input_shape, reduced_output_shape; - std::tie(is_transpose_021, reduced_input_shape, reduced_output_shape) = - IsTranspose021(copy->operand(0)->shape(), copy->shape()); - if (is_transpose_021 && - reduced_input_shape.dimensions(1) >= kMinDimensionToTransposeTiled && - reduced_input_shape.dimensions(2) >= kMinDimensionToTransposeTiled) { - thunk_sequence_->emplace_back(BuildKernelThunk(copy)); - VLOG(3) << "Emitting tiled 0-2-1 transposition"; - constexpr int64 tile_size = 32; - constexpr int64 num_rows = 8; - int64 num_tiles = EmitTranspose021Tiled( - GetIrArray(*copy->operand(0), *copy) - .CastToShape(reduced_input_shape, &ir_builder_), - GetIrArray(*copy, *copy) - .CastToShape(reduced_output_shape, &ir_builder_), - tile_size, num_rows, &ir_builder_); - UpdateLaunchDimensions(LaunchDimensions(num_tiles, num_rows * tile_size), - LastThunk(), ir_emitter_context_->llvm_module()); + if (CheckAndEmitHloWithTile021(copy)) { return Status::OK(); } @@ -1032,7 +708,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { } Status IrEmitterUnnested::EmitExtraOutputsForReduce( - const HloInstruction* reduce, const llvm_ir::IrArray::Index& index, + const HloInstruction* reduce, const IrArray::Index& index, tensorflow::gtl::ArraySlice< std::pair> extra_output_gens) { @@ -1040,11 +716,11 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( const HloInstruction* output = reduce->parent()->FusionInstruction(); llvm::Value* extra_output_address = GetIrArray(*output, *output, extra_output_gens[i].second) - .EmitArrayElementAddress(index, &ir_builder_, + .EmitArrayElementAddress(index, &b_, "extra_output_element_address"); TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, extra_output_gens[i].first(index)); - ir_builder_.CreateStore(extra_output_ir_value, extra_output_address); + b_.CreateStore(extra_output_ir_value, extra_output_address); } return Status::OK(); } @@ -1074,12 +750,10 @@ Status IrEmitterUnnested::EmitReductionToScalar( LaunchDimensions launch_dimensions = CalculateLaunchDimensions( tiled_input_shape, ir_emitter_context_->device_description()); - llvm::Type* index_ty = GetIndexTypeForKernel( - reduce, - launch_dimensions.block_count() * launch_dimensions.threads_per_block(), - &ir_builder_); + llvm::Type* index_ty = + GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_); - auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; @@ -1121,59 +795,57 @@ Status IrEmitterUnnested::EmitReductionToScalar( // // and threads_per_block is a multiple of warpSize. // reduce_kernel<<>>(); // - auto loop_body_emitter = - [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status { const int num_reduces = reducers.size(); llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); std::vector partial_reduction_result_addresses; for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); - TF_ASSIGN_OR_RETURN( - llvm::Value* const init_ir_value, - init_value_gens[i](llvm_ir::IrArray::Index(index_ty))); - ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + llvm::Value* partial_reduction_result_address = + b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](IrArray::Index(index_ty))); + b_.CreateStore(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } llvm::Value* x_in_tiles = tile_index[0]; - x_in_tiles = ir_builder_.CreateZExtOrTrunc(x_in_tiles, index_ty); + x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty); // Emit an inner for-loop that reduces the elements in the tile. auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { std::unique_ptr tile_element_loop = llvm_ir::ForLoop::EmitForLoop( - "element_id_in_tile", index_typed_const(0), - index_typed_const(kTileSize), index_typed_const(1), &ir_builder_); + "element_id_in_tile", index_typed_constant(0), + index_typed_constant(kTileSize), index_typed_constant(1), &b_); // Emit the body of the partial reduction loop. llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), - &ir_builder_); - llvm::Value* x = ir_builder_.CreateNSWAdd( - ir_builder_.CreateNSWMul(x_in_tiles, index_typed_const(kTileSize)), + &b_); + llvm::Value* x = b_.CreateNSWAdd( + b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)), tile_element_loop->GetIndVarValue()); // Unless we know the tile is entirely in bounds, we have to emit a // x-in-bounds check before reading from the input. if (!tile_in_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpULT(x, index_typed_const(num_elems)), - "x_in_bounds", &ir_builder_); + b_.CreateICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", + &b_); // Emit code that reads the input element and accumulates it to // the partial reduction result. - llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); } - llvm_ir::IrArray::Index input_index( - /*linear=*/x, input_shape, &ir_builder_); - llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); + IrArray::Index input_index( + /*linear=*/x, input_shape, &b_); + llvm::Value* input_address = b_.CreateAlloca(element_ir_type); for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); + b_.CreateStore(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], input_address}, @@ -1184,49 +856,48 @@ Status IrEmitterUnnested::EmitReductionToScalar( // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's // immediately beyond the tile. - llvm::Value* x_end = ir_builder_.CreateNSWAdd( - index_typed_const(kTileSize), - ir_builder_.CreateNSWMul(x_in_tiles, index_typed_const(kTileSize))); + llvm::Value* x_end = b_.CreateNSWAdd( + index_typed_constant(kTileSize), + b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize))); // The tile is entirely in bound if all_threads_in_bounds or // x_end <= num_elems. - llvm::Value* tile_in_bounds = ir_builder_.CreateOr( - ir_builder_.CreateICmpULE(x_end, index_typed_const(num_elems)), - ir_builder_.getInt1(all_threads_in_bounds)); + llvm::Value* tile_in_bounds = + b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(num_elems)), + b_.getInt1(all_threads_in_bounds)); llvm_ir::LlvmIfData if_tile_in_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, - &ir_builder_); + llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &b_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, &b_); TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, - &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, &b_); TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false)); // After the if-then-else statement on tile_in_bounds, emit calls to // shfl_down that accumulate the partial reduction results of all threads // from the warp. - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, - &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, &b_); int bit_width = llvm_ir::GetSizeInBits(element_ir_type); // bitcast cannot be applied to aggregate types (even packed ones), so we // instead bitcast addresses of load/store to intN* of the same bit-width. llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() - ? ir_builder_.getIntNTy(bit_width) + ? b_.getIntNTy(bit_width) : element_ir_type; for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1; shuffle_distance /= 2) { - llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( - element_ir_type, nullptr, "result_from_other_lane"); + llvm::Value* result_from_other_lane = + b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane"); for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), + llvm::Value* partial_reduction_result = b_.CreateLoad( + b_.CreateBitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), "partial_reduction_result"); - ir_builder_.CreateStore( - EmitShuffleDown(partial_reduction_result, - ir_builder_.getInt32(shuffle_distance), - &ir_builder_), - ir_builder_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); + CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) + << "Requires block size a multiple of the warp size, otherwise we " + "will read undefined elements."; + b_.CreateStore( + EmitFullWarpShuffleDown(partial_reduction_result, + b_.getInt32(shuffle_distance), &b_), + b_.CreateBitCast(result_from_other_lane, + shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], result_from_other_lane}, @@ -1240,24 +911,23 @@ Status IrEmitterUnnested::EmitReductionToScalar( // Emit an atomic operation that accumulates the partial reduction result of // lane 0 (which holds the partially accumulated result for its warp) to the // output element. - llvm::Value* lane_id = ir_builder_.CreateURem( - x_in_tiles, index_typed_const(kWarpSize), "lane_id"); + llvm::Value* lane_id = + b_.CreateURem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id"); llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpEQ(lane_id, index_typed_const(0)), - "lane_id_is_zero", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, - &ir_builder_); + b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", + &b_); + llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); for (int i = 0; i != num_reduces; ++i) { llvm::Value* output_address = GetIrArray(*output, *output, reduce_output_shapes[i]) .EmitArrayElementAddress( - llvm_ir::IrArray::Index( - /*linear=*/ir_builder_.getInt64(0), + IrArray::Index( + /*linear=*/b_.getInt64(0), ShapeUtil::GetSubshape(output->shape(), reduce_output_shapes[i]), - &ir_builder_), - &ir_builder_, "output_element_address"); + &b_), + &b_, "output_element_address"); TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( *reducers[i], output_address, partial_reduction_result_addresses[i])); } @@ -1271,7 +941,7 @@ Status IrEmitterUnnested::EmitReductionToScalar( static_cast(LastThunk())->thunks().back().get(), ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, - launch_dimensions, &ir_builder_) + launch_dimensions, &b_) .EmitLoop(IrName(reduce), index_ty); } @@ -1284,8 +954,8 @@ Status IrEmitterUnnested::EmitColumnReduction( tensorflow::gtl::ArraySlice< std::pair> extra_output_gens) { - // Divide the input matrix into tiles of size Kx1. For example, when the - // input matrix is 4x4 and K=2, the tiled matrix looks like + // Divide the input matrix into tiles of size KxL. For example, when the + // input matrix is 4x4, K=2, and L=1 the tiled matrix looks like // // 0123 // 0123 @@ -1297,100 +967,131 @@ Status IrEmitterUnnested::EmitColumnReduction( // // We choose 128 as the tile size based on empirical evidence. It's big enough // to reduce the amount of atomic adds in the end, maximizing the memory - // bandwidth. - constexpr int64 kTileSize = 128; + // bandwidth. A tile width of 2 allows for high memory bandwidth utilization + // on 16b input data. + constexpr int64 kTileHeight = 128; + constexpr int64 kTileWidth = 2; - // If the height is not a multiple of the tile size, we pad the bottom of the + // If the height is not a multiple of kTileHeight, we pad the bottom of the // input matrix. - const int64 height_in_tiles = CeilOfRatio(height, kTileSize); - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), {height_in_tiles, width}, {1, 0}); + const int64 height_in_tiles = CeilOfRatio(height, kTileHeight); + // If width is not a multiple of kTileWidth the rightmost thread will process + // fewer input elements. + const int64 width_in_tiles = CeilOfRatio(width, kTileWidth); + Shape tiled_input_shape = + ShapeUtil::MakeShapeWithLayout(reduce->shape().element_type(), + {height_in_tiles, width_in_tiles}, {1, 0}); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( tiled_input_shape, ir_emitter_context_->device_description()); // TODO(b/110211620): Convert to use i32 index_type when it is possible. - llvm::Type* index_ty = ir_builder_.getInt64Ty(); + llvm::Type* index_ty = b_.getInt64Ty(); - auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; - // linear_index < height_in_tiles * width; + // linear_index < height_in_tiles * width_in_tiles; // linear_index += blockDim.x * gridDim.x) { - // y_in_tiles = linear_index / width; - // x = linear_index % width; + // y_in_tiles = linear_index / width_in_tiles; + // x_in_tiles = linear_index % width_in_tiles; // - // partial_result = init_value; - // if (height % kTileSize == 0 || - // y_in_tiles * kTileSize + kTileSize <= height) { - // for (element_id_in_tile : range(kTileSize)) { - // y = y_in_tiles * kTileSize + element_id_in_tile; - // partial_result = Reducer(partial_result, input[y][x]); + // partial_results[kTileWidth] = init_values; + // tile_in_y_bounds = height % kTileHeight == 0 || + // y_in_tiles * kTileHeight + kTileHeight <= height; + // tile_in_x_bounds = width % kTileWidth == 0 || + // x_in_tiles * kTileWidth + kTileWidth <= width; + // // The implementation handles y and x bound checks separately. + // if (tile_in_y_bounds && tile_in_x_bounds) { + // for (y_offset : range(kTileHeight)) { + // y = y_in_tiles * kTileHeight + y_offset; + // for (x_offset : range(kTileWidth)) { + // x = x_in_tiles * kTileWidth + x_offset; + // partial_result = Reducer(partial_result[x_offset], input[y][x]); + // } // } // } else { - // for (element_id_in_tile : range(kTileSize)) { - // y = y_in_tiles * kTileSize + element_id_in_tile; - // if (y < height) { - // partial_result = Reducer(partial_result, input[y][x]); + // for (y_offset : range(kTileHeight)) { + // y = y_in_tiles * kTileHeight + y_offset; + // for (y_offset : range(kTileHeight)) { + // x = x_in_tiles * kTileWidth + x_offset; + // if (y < height && x < width) { + // partial_result = Reducer(partial_result, input[y][x]); + // } // } // } // } - // AtomicReducer(&output[x], partial_result); + // for (x_offset : range(kTileWidth)) { + // AtomicReducer(&output[x + x_offset], partial_result[x_offset]); + // } // } - auto loop_body_emitter = - [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status { const int num_reduces = reducers.size(); // Emit the loop body that reduces one tile. llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); std::vector partial_reduction_result_addresses; for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); - TF_ASSIGN_OR_RETURN( - llvm::Value* const init_ir_value, - init_value_gens[i](llvm_ir::IrArray::Index(index_ty))); - ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); - partial_reduction_result_addresses.push_back( - partial_reduction_result_address); + for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { + llvm::Value* partial_reduction_result_address = + b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + + llvm::Twine(i * kTileWidth + x_offset)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](IrArray::Index(index_ty))); + b_.CreateStore(init_ir_value, partial_reduction_result_address); + partial_reduction_result_addresses.push_back( + partial_reduction_result_address); + } } // Emit an inner for-loop that partially reduces the elements in the given // tile. llvm::Value* y_in_tiles = tile_index[0]; - llvm::Value* x = tile_index[1]; + llvm::Value* x_in_tiles = tile_index[1]; - y_in_tiles = ir_builder_.CreateZExtOrTrunc(y_in_tiles, index_ty); - x = ir_builder_.CreateZExtOrTrunc(x, index_ty); + y_in_tiles = b_.CreateZExtOrTrunc(y_in_tiles, index_ty); + x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty); - auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { + auto emit_tile_element_loop = [=](bool tile_in_y_bounds, + bool tile_in_x_bounds) -> Status { std::unique_ptr tile_element_loop = llvm_ir::ForLoop::EmitForLoop( - "element_id_in_tile", index_typed_const(0), - index_typed_const(kTileSize), index_typed_const(1), &ir_builder_); + "element_id_in_tile", index_typed_constant(0), + index_typed_constant(kTileHeight), index_typed_constant(1), &b_); // Emit the body of the partial reduction loop. llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), - &ir_builder_); - llvm::Value* y = ir_builder_.CreateNSWAdd( - ir_builder_.CreateNSWMul(y_in_tiles, index_typed_const(kTileSize)), + &b_); + llvm::Value* y = b_.CreateNSWAdd( + b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight)), tile_element_loop->GetIndVarValue()); - // Unless we know the tile is entirely in bounds, we have to emit a - // y-in-bounds check before reading from the input. - if (!tile_in_bounds) { + // Unless we know that y is in bounds, we have to emit a check before + // reading from the input. + if (!tile_in_y_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpULT(y, index_typed_const(height)), - "y_in_bounds", &ir_builder_); + b_.CreateICmpULT(y, index_typed_constant(height)), "y_in_bounds", + &b_); // Emit code that reads the input element and accumulates it to // the partial reduction result. - llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); } - llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); - { + for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { + llvm::Value* x = b_.CreateNSWAdd( + b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)), + index_typed_constant(x_offset)); + // Unless we know that x is in bounds, we have to emit a check before + // reading from the input. + if (!tile_in_x_bounds) { + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + b_.CreateICmpULT(x, index_typed_constant(width)), "x_in_bounds", + &b_); + llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); + } + llvm::Value* input_address = b_.CreateAlloca(element_ir_type); // {y,x} is an index to input_matrix_shape [height,width]. We need to // convert that to an index to input_shape (the shape of the operand of // "reduce"). This conversion is composed of a transposition from @@ -1406,67 +1107,95 @@ Status IrEmitterUnnested::EmitColumnReduction( const Shape input_matrix_shape = ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(), {height, width}); - const llvm_ir::IrArray::Index input_matrix_index( - {y, x}, input_matrix_shape, &ir_builder_); - const llvm_ir::IrArray::Index input_index = + const IrArray::Index input_matrix_index({y, x}, input_matrix_shape, + &b_); + const IrArray::Index input_index = input_matrix_index .SourceIndexOfReshape(input_matrix_shape, - normalized_input_shape, &ir_builder_) + normalized_input_shape, &b_) .SourceIndexOfTranspose(normalized_input_shape, input_shape, - transpose_dimension_mapping, - &ir_builder_); + transpose_dimension_mapping, &b_); for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); + b_.CreateStore(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], - {partial_reduction_result_addresses[i], input_address}, - partial_reduction_result_addresses[i])); + {partial_reduction_result_addresses[i * kTileWidth + x_offset], + input_address}, + partial_reduction_result_addresses[i * kTileWidth + x_offset])); + TF_RETURN_IF_ERROR(EmitExtraOutputsForReduce(reduce, input_index, + extra_output_gens)); } - return EmitExtraOutputsForReduce(reduce, input_index, - extra_output_gens); } + return Status::OK(); }; - // y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's - // immediately beyond the tile. - llvm::Value* y_end = ir_builder_.CreateNSWAdd( - index_typed_const(kTileSize), - ir_builder_.CreateNSWMul(y_in_tiles, index_typed_const(kTileSize))); - llvm::Value* tile_in_bounds = ir_builder_.CreateOr( - ir_builder_.CreateICmpULE(y_end, index_typed_const(height)), - ir_builder_.getInt1(height % kTileSize == 0)); - // The tile is entirely in bound if "height" is a multiple of kTileSize or + // y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location + // that's immediately beyond the tile. + llvm::Value* y_end = b_.CreateNSWAdd( + index_typed_constant(kTileHeight), + b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight))); + // x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location + // that's immediately beyond the tile. + llvm::Value* x_end = b_.CreateNSWAdd( + index_typed_constant(kTileWidth), + b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth))); + llvm::Value* tile_in_y_bounds = + b_.CreateOr(b_.CreateICmpULE(y_end, index_typed_constant(height)), + b_.getInt1(height % kTileHeight == 0)); + llvm::Value* tile_in_x_bounds = + b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(width)), + b_.getInt1(width % kTileWidth == 0)); + // The tile is in y bounds if "height" is a multiple of kTileHeight or // y_end <= height. - llvm_ir::LlvmIfData if_tile_in_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, - &ir_builder_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, - &ir_builder_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false)); - - // After the if-then-else statement on tile_in_bounds, emit atomic - // operations to accumulate the partial reduction result to the output - // element. - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, - &ir_builder_); + llvm_ir::LlvmIfData if_tile_in_y_bounds_data = + llvm_ir::EmitIfThenElse(tile_in_y_bounds, "tile_in_y_bounds", &b_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.true_block, &b_); + // The tile is in x bounds if "width" is a multiple of kTileWidth or + // x_end <= width. + llvm_ir::LlvmIfData if_tile_in_x_bounds_data = + llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true, + /*tile_in_x_bounds=*/true)); + llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true, + /*tile_in_x_bounds=*/false)); + llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.false_block, &b_); + if_tile_in_x_bounds_data = + llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false, + /*tile_in_x_bounds=*/true)); + llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false, + /*tile_in_x_bounds=*/false)); + + // After the nested if-then-else statement on tile_in_y_bounds and + // tile_in_x_bounds, emit atomic operations to accumulate the partial + // reduction result to the output element. + llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.after_block, &b_); const HloInstruction* output = reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; for (int i = 0; i != num_reduces; ++i) { - llvm::Value* output_address = - GetIrArray(*output, *output, reduce_output_shapes[i]) - .EmitArrayElementAddress( - llvm_ir::IrArray::Index( - x, - ShapeUtil::GetSubshape(output->shape(), - reduce_output_shapes[i]), - &ir_builder_), - &ir_builder_, "output_element_address"); - TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, partial_reduction_result_addresses[i])); + for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { + llvm::Value* x = b_.CreateNSWAdd( + b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)), + index_typed_constant(x_offset)); + llvm::Value* output_address = + GetIrArray(*output, *output, reduce_output_shapes[i]) + .EmitArrayElementAddress( + IrArray::Index( + x, + ShapeUtil::GetSubshape(output->shape(), + reduce_output_shapes[i]), + &b_), + &b_, "output_element_address"); + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, + partial_reduction_result_addresses[i * kTileWidth + x_offset])); + } } return Status::OK(); }; @@ -1478,7 +1207,7 @@ Status IrEmitterUnnested::EmitColumnReduction( static_cast(LastThunk())->thunks().back().get(), ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, - launch_dimensions, &ir_builder_) + launch_dimensions, &b_) .EmitLoop(IrName(reduce), index_ty); } @@ -1628,28 +1357,25 @@ Status IrEmitterUnnested::EmitRowReduction( {depth / z_tile_size, height, width_in_tiles}, {2, 1, 0}); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( tiled_input_shape, ir_emitter_context_->device_description()); - llvm::Type* index_ty = GetIndexTypeForKernel( - reduce, - launch_dimensions.block_count() * launch_dimensions.threads_per_block(), - &ir_builder_); + llvm::Type* index_ty = + GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_); - auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; - auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) { + auto loop_body_emitter = [=](const IrArray::Index& tile_index) { const int num_reduces = reducers.size(); llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( input_shape.element_type(), ir_emitter_context_->llvm_module()); std::vector partial_reduction_result_addresses; for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); - TF_ASSIGN_OR_RETURN( - llvm::Value* const init_ir_value, - init_value_gens[i](llvm_ir::IrArray::Index(index_ty))); - ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + llvm::Value* partial_reduction_result_address = + b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](IrArray::Index(index_ty))); + b_.CreateStore(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } @@ -1658,25 +1384,25 @@ Status IrEmitterUnnested::EmitRowReduction( llvm::Value* y = tile_index[1]; llvm::Value* x_tile = tile_index[2]; - x_tile = ir_builder_.CreateZExtOrTrunc(x_tile, index_ty); + x_tile = b_.CreateZExtOrTrunc(x_tile, index_ty); llvm::Value* warp_id = - ir_builder_.CreateUDiv(x_tile, index_typed_const(kWarpSize), "warp_id"); + b_.CreateUDiv(x_tile, index_typed_constant(kWarpSize), "warp_id"); llvm::Value* lane_id = - ir_builder_.CreateURem(x_tile, index_typed_const(kWarpSize), "lane_id"); + b_.CreateURem(x_tile, index_typed_constant(kWarpSize), "lane_id"); // The x-location of the last element in this z-x-tile. // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size); - llvm::Value* last_x = ir_builder_.CreateNSWAdd( - lane_id, ir_builder_.CreateNSWMul( - index_typed_const(kWarpSize), - ir_builder_.CreateNSWAdd( - index_typed_const(x_tile_size - 1), - ir_builder_.CreateNSWMul( - warp_id, index_typed_const(x_tile_size))))); + llvm::Value* last_x = b_.CreateNSWAdd( + lane_id, + b_.CreateNSWMul( + index_typed_constant(kWarpSize), + b_.CreateNSWAdd( + index_typed_constant(x_tile_size - 1), + b_.CreateNSWMul(warp_id, index_typed_constant(x_tile_size))))); KernelSupportLibrary ksl( - &ir_builder_, + &b_, /*unroll_mode=*/xla::llvm_ir::UnrollMode::kFullyUnroll, /*prevent_vectorization=*/false); @@ -1685,22 +1411,22 @@ Status IrEmitterUnnested::EmitRowReduction( auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds, int64 x_tile_loop_bound) -> Status { auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status { - llvm::Value* z = ir_builder_.CreateNSWAdd( + llvm::Value* z = b_.CreateNSWAdd( z_indvar, - ir_builder_.CreateNSWMul(index_typed_const(z_tile_size), z_tile)); + b_.CreateNSWMul(index_typed_constant(z_tile_size), z_tile)); TF_RETURN_IF_ERROR(ksl.For( "x_tile", - /*start=*/index_typed_const(0), - /*end=*/index_typed_const(x_tile_loop_bound), + /*start=*/index_typed_constant(0), + /*end=*/index_typed_constant(x_tile_loop_bound), /*step=*/1, [&](llvm::Value* x_indvar) -> Status { // x = lane_id + // warpSize * (element_id_in_x_tile + warp_id * x_tile_size); - llvm::Value* x = ir_builder_.CreateNSWAdd( + llvm::Value* x = b_.CreateNSWAdd( lane_id, - ir_builder_.CreateNSWMul( - index_typed_const(kWarpSize), - ir_builder_.CreateNSWAdd( - x_indvar, ir_builder_.CreateNSWMul( + b_.CreateNSWMul( + index_typed_constant(kWarpSize), + b_.CreateNSWAdd( + x_indvar, b_.CreateNSWMul( warp_id, llvm::ConstantInt::get( index_ty, x_tile_size))))); @@ -1709,17 +1435,16 @@ Status IrEmitterUnnested::EmitRowReduction( if (!x_tile_in_bounds) { llvm_ir::LlvmIfData if_x_in_bounds_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpULT(x, index_typed_const(width)), - "x_in_bounds", &ir_builder_); - // Points ir_builder_ to the then-block. + b_.CreateICmpULT(x, index_typed_constant(width)), + "x_in_bounds", &b_); + // Points b_ to the then-block. llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, - &ir_builder_); + &b_); } // Emit code that reads the input element and accumulates it // to the partial reduction result. - llvm::Value* input_address = - ir_builder_.CreateAlloca(element_ir_type); + llvm::Value* input_address = b_.CreateAlloca(element_ir_type); { // {z,y,x} is an index to input_3d_tensor_shape // [depth,height,width]. We need to convert that to an index @@ -1737,21 +1462,20 @@ Status IrEmitterUnnested::EmitRowReduction( const Shape input_3d_tensor_shape = ShapeUtil::MakeShapeWithDescendingLayout( input_shape.element_type(), {depth, height, width}); - const llvm_ir::IrArray::Index input_3d_tensor_index( - {z, y, x}, input_3d_tensor_shape, &ir_builder_); - const llvm_ir::IrArray::Index input_index = + const IrArray::Index input_3d_tensor_index( + {z, y, x}, input_3d_tensor_shape, &b_); + const IrArray::Index input_index = input_3d_tensor_index .SourceIndexOfReshape(input_3d_tensor_shape, - normalized_input_shape, - &ir_builder_) + normalized_input_shape, &b_) .SourceIndexOfTranspose( normalized_input_shape, input_shape, - transpose_dimension_mapping, &ir_builder_); + transpose_dimension_mapping, &b_); for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); + b_.CreateStore(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], input_address}, @@ -1765,14 +1489,14 @@ Status IrEmitterUnnested::EmitRowReduction( }; return ksl.For("z_tile", - /*start=*/index_typed_const(0), - /*end=*/index_typed_const(z_tile_size), + /*start=*/index_typed_constant(0), + /*end=*/index_typed_constant(z_tile_size), /*step=*/1, emit_z_tile_element_loop); }; - llvm::Value* tile_in_bounds = ir_builder_.CreateOr( - ir_builder_.getInt1(width % (x_tile_size * kWarpSize) == 0), - ir_builder_.CreateICmpULT(last_x, index_typed_const(width))); + llvm::Value* tile_in_bounds = + b_.CreateOr(b_.getInt1(width % (x_tile_size * kWarpSize) == 0), + b_.CreateICmpULT(last_x, index_typed_constant(width))); TF_RETURN_IF_ERROR( ksl.If(tile_in_bounds, @@ -1795,23 +1519,25 @@ Status IrEmitterUnnested::EmitRowReduction( // bitcast cannot be applied to aggregate types (even packed ones), so we // instead bitcast addresses of load/store to intN* of the same bit-width. llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() - ? ir_builder_.getIntNTy(bit_width) + ? b_.getIntNTy(bit_width) : element_ir_type; for (int shuffle_distance = 16; shuffle_distance >= 1; shuffle_distance /= 2) { - llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( - element_ir_type, nullptr, "result_from_other_lane"); + llvm::Value* result_from_other_lane = + b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane"); for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), + llvm::Value* partial_reduction_result = b_.CreateLoad( + b_.CreateBitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), "partial_reduction_result"); - ir_builder_.CreateStore( - EmitShuffleDown(partial_reduction_result, - ir_builder_.getInt32(shuffle_distance), - &ir_builder_), - ir_builder_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); + CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) + << "Requires block size a multiple of the warp size, otherwise we " + "will read undefined elements."; + b_.CreateStore( + EmitFullWarpShuffleDown(partial_reduction_result, + b_.getInt32(shuffle_distance), &b_), + b_.CreateBitCast(result_from_other_lane, + shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], result_from_other_lane}, @@ -1826,20 +1552,18 @@ Status IrEmitterUnnested::EmitRowReduction( // lane 0 (which holds the partially accumulated result for its warp) to the // output element. llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpEQ(lane_id, index_typed_const(0)), - "lane_id_is_zero", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, - &ir_builder_); + b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", + &b_); + llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); for (int i = 0; i != num_reduces; ++i) { llvm::Value* output_address = GetIrArray(*output, *output, reduce_output_shapes[i]) .EmitArrayElementAddress( - llvm_ir::IrArray::Index( - y, - ShapeUtil::GetSubshape(output->shape(), - reduce_output_shapes[i]), - &ir_builder_), - &ir_builder_, "output_element_address"); + IrArray::Index(y, + ShapeUtil::GetSubshape( + output->shape(), reduce_output_shapes[i]), + &b_), + &b_, "output_element_address"); // We don't need to emit atomic operations if there is only one tile of // results. 'depth' is the z dimension, 'width' is the x dimension. if (z_tile_size >= depth && x_tile_size >= width) { @@ -1863,7 +1587,7 @@ Status IrEmitterUnnested::EmitRowReduction( static_cast(LastThunk())->thunks().back().get(), ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, - launch_dimensions, &ir_builder_) + launch_dimensions, &b_) .EmitLoop(IrName(reduce), index_ty); } @@ -1982,32 +1706,36 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { BuildInitializerThunk(reduce)); std::vector> thunks; thunks.push_back(std::move(initializer_thunk)); - thunks.push_back(BuildKernelThunk(reduce)); + thunks.push_back( + BuildKernelThunk(reduce, /*implements_whole_instruction=*/false)); thunk_sequence_->emplace_back( MakeUnique(std::move(thunks), reduce)); return EmitReductionToVector( - reduce, input->shape(), {[&](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*input, *reduce) - .EmitReadArrayElement(index, &ir_builder_); + reduce, input->shape(), {[&](const IrArray::Index& index) { + return GetIrArray(*input, *reduce).EmitReadArrayElement(index, &b_); }}, - {[&](const llvm_ir::IrArray::Index& index) { + {[&](const IrArray::Index& index) { return GetIrArray(*init_value, *reduce) - .EmitReadArrayElement(index, &ir_builder_); + .EmitReadArrayElement(index, &b_); }}, dimensions_to_reduce, {reducer}, {{}}, {}); } - thunk_sequence_->emplace_back(BuildKernelThunk(reduce)); + thunk_sequence_->emplace_back( + BuildKernelThunk(reduce, /*implements_whole_instruction=*/true)); return IrEmitter::HandleReduce(reduce); } Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { bool all_tuple_elements_have_buffer = c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) { - return ir_emitter_context_->buffer_assignment().HasTopLevelAllocation( - tuple_element); + return ir_emitter_context_->buffer_assignment() + .GetUniqueTopLevelSlice(tuple_element) + .ok(); }); + // TODO(b/111689850): This logic isn't quite correct. + // // Tuples (especially tuples that are the final result of a computation) can // be so huge that if we were to emit a kernel that took each tuple element as // a parameter, we would exceed the max allowable number of parameters to a @@ -2015,9 +1743,9 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { // buffer, we collect their buffer addresses in a host array, and then copy // that array to the tuple's buffer. // - // Some tuple elements (e.g. const or bitcast of const) might not have a - // buffer -- their contents are stored in code. In that case, we fall back to - // emitting kernels which have access to their buffer addresses in code. + // Some tuple elements might not have an unambiguous buffer (like the result + // of a select-tuple). In that case, we fall back to emitting kernels which + // have access to their buffer addresses in code. if (all_tuple_elements_have_buffer) { std::vector tuple_element_buffers; for (const HloInstruction* tuple_element : tuple->operands()) { @@ -2027,7 +1755,8 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { tuple_element_buffers, GetAllocationSlice(*tuple), tuple)); return Status::OK(); } - thunk_sequence_->emplace_back(BuildKernelThunk(tuple)); + thunk_sequence_->emplace_back( + BuildKernelThunk(tuple, /*implements_whole_instruction=*/true)); return IrEmitter::HandleTuple(tuple); } @@ -2052,7 +1781,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( BuildInitializerThunk(select_and_scatter)); std::vector> thunks; thunks.push_back(std::move(initializer_thunk)); - thunks.push_back(BuildKernelThunk(select_and_scatter)); + thunks.push_back(BuildKernelThunk(select_and_scatter, + /*implements_whole_instruction=*/false)); thunk_sequence_->emplace_back( MakeUnique(std::move(thunks), select_and_scatter)); @@ -2065,8 +1795,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( LaunchDimensions launch_dimensions = CalculateLaunchDimensions( source->shape(), ir_emitter_context_->device_description()); llvm::Type* index_type = GetIndexTypeForKernel( - select_and_scatter, launch_dimensions.launch_bound(), &ir_builder_); - auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + select_and_scatter, launch_dimensions.launch_bound(), &b_); + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_type, c); }; @@ -2089,114 +1819,106 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // selected_index = I // initialized_flag = true // output(selected_index) = scatter(output(selected_index), source(S)) - auto loop_body_emitter = - [=](const llvm_ir::IrArray::Index& source_index) -> Status { + auto loop_body_emitter = [=](const IrArray::Index& source_index) -> Status { // Allocate space to keep the currently selected value, its index, and a // boolean flag if the value is initialized. The initialized_flag is set // false. llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(operand_element_type, ir_emitter_context_->llvm_module()), - "selected_value_address", &ir_builder_); + "selected_value_address", &b_); llvm::Value* selected_index_address = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - index_type, index_typed_const(rank), "selected_index_address", - &ir_builder_); + index_type, index_typed_constant(rank), "selected_index_address", + &b_); llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry( - ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_); - ir_builder_.CreateStore(ir_builder_.getInt1(false), - initialized_flag_address); + b_.getInt1Ty(), "initialized_flag_address", &b_); + b_.CreateStore(b_.getInt1(false), initialized_flag_address); // Create the inner loop to iterate over the window. - llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), - &ir_builder_, index_type); + llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_, + index_type); std::vector window_size; for (const auto& dim : window.dimensions()) { window_size.push_back(dim.size()); CHECK_GT(dim.size(), 0); } - const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape( + const IrArray::Index window_index = window_loops.AddLoopsForShape( ShapeUtil::MakeShape(operand_element_type, window_size), "window"); llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(), - &ir_builder_); + &b_); // Compute the operand index to visit and evaluate the condition whether the // operand index is within the bounds. The unsigned comparison includes // checking whether the operand index >= 0. - llvm_ir::IrArray::Index operand_index(index_type, source_index.size()); - llvm::Value* in_bounds_condition = ir_builder_.getInt1(true); + IrArray::Index operand_index(index_type, source_index.size()); + llvm::Value* in_bounds_condition = b_.getInt1(true); for (int64 i = 0; i < rank; ++i) { - llvm::Value* strided_index = ir_builder_.CreateNSWMul( - source_index[i], index_typed_const(window.dimensions(i).stride())); - operand_index[i] = ir_builder_.CreateNSWSub( - ir_builder_.CreateNSWAdd(strided_index, window_index[i]), - index_typed_const(window.dimensions(i).padding_low())); - llvm::Value* index_condition = ir_builder_.CreateICmpULT( + llvm::Value* strided_index = b_.CreateNSWMul( + source_index[i], index_typed_constant(window.dimensions(i).stride())); + operand_index[i] = b_.CreateNSWSub( + b_.CreateNSWAdd(strided_index, window_index[i]), + index_typed_constant(window.dimensions(i).padding_low())); + llvm::Value* index_condition = b_.CreateICmpULT( operand_index[i], - index_typed_const(ShapeUtil::GetDimension(operand->shape(), i))); - in_bounds_condition = - ir_builder_.CreateAnd(in_bounds_condition, index_condition); + index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i))); + in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); } CHECK(in_bounds_condition != nullptr); // Only need to do something if the operand index is within the bounds. // First check if the initialized_flag is set. llvm_ir::LlvmIfData if_in_bounds = - llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &ir_builder_); + llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); + llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_); llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse( - ir_builder_.CreateLoad(initialized_flag_address), "initialized", - &ir_builder_); + b_.CreateLoad(initialized_flag_address), "initialized", &b_); // If the initialized_flag is false, initialize the selected value and index // with the currently visiting operand. - llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_); - const auto save_operand_index = [&]( - const llvm_ir::IrArray::Index& operand_index) { + llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &b_); + const auto save_operand_index = [&](const IrArray::Index& operand_index) { for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - ir_builder_.CreateInBoundsGEP(selected_index_address, - {ir_builder_.getInt32(i)}); - ir_builder_.CreateStore(operand_index[i], selected_index_address_slot); + b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); + b_.CreateStore(operand_index[i], selected_index_address_slot); } }; - llvm_ir::IrArray operand_array = GetIrArray(*operand, *select_and_scatter); + IrArray operand_array = GetIrArray(*operand, *select_and_scatter); llvm::Value* operand_data = - operand_array.EmitReadArrayElement(operand_index, &ir_builder_); - ir_builder_.CreateStore(operand_data, selected_value_address); + operand_array.EmitReadArrayElement(operand_index, &b_); + b_.CreateStore(operand_data, selected_value_address); save_operand_index(operand_index); - ir_builder_.CreateStore(ir_builder_.getInt1(true), - initialized_flag_address); + b_.CreateStore(b_.getInt1(true), initialized_flag_address); // If the initialized_flag is true, call the `select` function to // potentially update the selected value and index with the currently // visiting operand. - llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &b_); const Shape output_shape = ShapeUtil::MakeShape(PRED, {}); llvm::Value* operand_address = - operand_array.EmitArrayElementAddress(operand_index, &ir_builder_); + operand_array.EmitArrayElementAddress(operand_index, &b_); llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(PRED, ir_emitter_context_->llvm_module()), - "select_return_buffer", &ir_builder_); + "select_return_buffer", &b_); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *select_and_scatter->select(), {selected_value_address, operand_address}, select_return_buffer)); - llvm::Value* result = ir_builder_.CreateLoad(select_return_buffer); + llvm::Value* result = b_.CreateLoad(select_return_buffer); // If the 'select' function returns false, update the selected value and the // index to the currently visiting operand. - llvm::Value* cond = ir_builder_.CreateICmpNE( + llvm::Value* cond = b_.CreateICmpNE( result, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType( PRED, ir_emitter_context_->llvm_module()), 0), "boolean_predicate"); llvm_ir::LlvmIfData if_select_lhs = - llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &ir_builder_); - ir_builder_.CreateStore(ir_builder_.CreateLoad(operand_address), - selected_value_address); + llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_); + llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_); + b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address); save_operand_index(operand_index); // After iterating over the window elements, scatter the source element to @@ -2204,20 +1926,19 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // location is computed by calling the `scatter` function with the source // value and the current output value. llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), - &ir_builder_); - llvm_ir::IrArray::Index selected_index(operand_index.GetType()); + &b_); + IrArray::Index selected_index(operand_index.GetType()); for (int64 i = 0; i < rank; ++i) { - llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP( - selected_index_address, {ir_builder_.getInt32(i)}); - selected_index.push_back( - ir_builder_.CreateLoad(selected_index_address_slot)); + llvm::Value* selected_index_address_slot = + b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); + selected_index.push_back(b_.CreateLoad(selected_index_address_slot)); } llvm::Value* source_value_address = GetIrArray(*source, *select_and_scatter) - .EmitArrayElementAddress(source_index, &ir_builder_); + .EmitArrayElementAddress(source_index, &b_); llvm::Value* output_value_address = GetIrArray(*select_and_scatter, *select_and_scatter) - .EmitArrayElementAddress(selected_index, &ir_builder_); + .EmitArrayElementAddress(selected_index, &b_); return EmitAtomicOperationForNestedComputation( *select_and_scatter->scatter(), output_value_address, source_value_address); @@ -2232,7 +1953,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( static_cast(LastThunk())->thunks().back().get(), ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, source->shape(), - launch_dimensions, &ir_builder_) + launch_dimensions, &b_) .EmitLoop(IrName(select_and_scatter), index_type); } @@ -2242,33 +1963,147 @@ Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) { condition->root_instruction()->shape().element_type() == PRED) << "While condition computation must return bool"; // Build ForThunk for conformant while loops, otherwise build WhileThunk. - auto result = CanTransformWhileToFor(xla_while); - if (result.ok()) { - auto tuple = result.ConsumeValueOrDie(); - // loop_trip_count = (limit - start + increment - 1) / increment - const int64 loop_trip_count = - (std::get<1>(tuple) - std::get<0>(tuple) + std::get<2>(tuple) - 1) / - std::get<2>(tuple); - thunk_sequence_->emplace_back(BuildForThunk(xla_while, loop_trip_count)); + // TODO(b/112163966): Move trip count computation earlier in the pipeline. + if (auto loop_trip_count = ComputeWhileLoopTripCount(xla_while)) { + thunk_sequence_->emplace_back(BuildForThunk(xla_while, *loop_trip_count)); VLOG(3) << "Built ForThunk for while: " << xla_while->name(); } else { thunk_sequence_->emplace_back(BuildWhileThunk(xla_while)); - VLOG(3) << "Built WhileThunk for while: " << xla_while->name() - << " while-to-for transform status: " << result.status(); + VLOG(3) << "Built WhileThunk for while: " << xla_while->name(); } return Status::OK(); } -Status IrEmitterUnnested::HandleRng(HloInstruction* random) { - thunk_sequence_->push_back(BuildKernelThunk(random)); - return IrEmitter::HandleRng(random); +Status IrEmitterUnnested::HandleRng(HloInstruction* rng) { + // Build the kernel to generate the random numbers. + // + // Unroll the kernel so that the duplicated computation that calculates the + // 128 bit sample can be optimized away by LLVM. + thunk_sequence_->emplace_back( + BuildKernelThunk(rng, /*implements_whole_instruction=*/false, + ComputeMaxUnrollFactor(rng))); + ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; + for (const HloInstruction* operand : rng->operands()) { + operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { + return GetIrArray(*operand, *rng).EmitReadArrayElement(index, &b_); + }; + } + TF_RETURN_IF_ERROR(EmitTargetElementLoop( + *rng, GpuElementalIrEmitter(hlo_module_config_, module_, &b_, + GetNestedComputer()) + .MakeElementGenerator(rng, operand_to_generator))); + std::unique_ptr rng_thunk = std::move(thunk_sequence_->back()); + thunk_sequence_->pop_back(); + + // Emit a kernel to increment the global state for Philox RNG algorithm. + thunk_sequence_->emplace_back( + BuildKernelThunk(rng, /*implements_whole_instruction=*/false)); + llvm_ir::IncrementVariableForPhiloxRngState(1, module_, &b_); + std::unique_ptr increment_seed_thunk = + std::move(thunk_sequence_->back()); + thunk_sequence_->pop_back(); + + // Build the SequentialThunk for the RNG hlo. + std::vector> thunks; + thunks.reserve(2); + thunks.push_back(std::move(rng_thunk)); + thunks.push_back(std::move(increment_seed_thunk)); + thunk_sequence_->emplace_back( + MakeUnique(std::move(thunks), rng)); + + return Status::OK(); } Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { - thunk_sequence_->push_back(BuildKernelThunk(select)); + thunk_sequence_->push_back( + BuildKernelThunk(select, /*implements_whole_instruction=*/true)); return IrEmitter::HandleSelect(select); } +Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { + std::vector> thunks; + auto keys = sort->operand(0); + auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr; + ShapeIndex keys_shape_index({}); + ShapeIndex values_shape_index({}); + if (values != nullptr) { + keys_shape_index = ShapeIndex({0}); + values_shape_index = ShapeIndex({1}); + } + auto keys_destination = GetAllocationSlice(*sort, keys_shape_index); + auto values_destination = GetAllocationSlice(*sort, values_shape_index); + + if (keys_destination != GetAllocationSlice(*keys)) { + thunks.push_back(MakeUnique( + /*source_address=*/GetAllocationSlice(*keys), + /*destination_buffer=*/keys_destination, + /*mem_size=*/ShapeUtil::ByteSizeOf(keys->shape()), nullptr)); + } + if (values != nullptr && values_destination != GetAllocationSlice(*values)) { + // TODO(b/26783907): Figure out why we never seem to share buffers for + // key/value sort. + thunks.push_back(MakeUnique( + /*source_address=*/GetAllocationSlice(*values), + /*destination_buffer=*/values_destination, + /*mem_size=*/ShapeUtil::ByteSizeOf(values->shape()), nullptr)); + } + + int64 dimension_to_sort = sort->dimensions(0); + int64 dimension_to_sort_bound = keys->shape().dimensions(dimension_to_sort); + int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound); + auto index_type = b_.getInt64Ty(); + + // Naive C++ code for the outer loops: + // + // for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound); + // ++stage) { + // int64 first_xor_mask = (1LL << (stage + 1)) - 1; + // SortInPlace(first_xor_mask); + // for (int64 mask = stage - 1; mask >= 0; --mask) { + // int64 later_xor_mask = 1LL << mask; + // SortInPlace(later_xor_mask); + // } + // } + // + // This follows the algorithm described on Wikipedia: + // https://en.wikipedia.org/wiki/Bitonic_sorter + + for (int64 stage = 0; stage < num_stages; ++stage) { + for (int64 mask = stage; mask >= 0; --mask) { + thunks.push_back( + BuildKernelThunk(sort, /*implements_whole_instruction=*/false)); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + keys->shape(), ir_emitter_context_->device_description()); + UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), + ir_emitter_context_->llvm_module()); + + llvm::Value* xor_mask; + if (mask == stage) { + xor_mask = llvm::ConstantInt::get(index_type, (1LL << (stage + 1)) - 1); + } else { + xor_mask = llvm::ConstantInt::get(index_type, 1LL << mask); + } + + TF_RETURN_IF_ERROR(llvm_ir::EmitSortInPlace( + dimension_to_sort, GetIrArray(*sort, *sort, keys_shape_index), + values != nullptr ? tensorflow::gtl::make_optional( + GetIrArray(*sort, *sort, values_shape_index)) + : tensorflow::gtl::nullopt, + IrName(sort), xor_mask, &b_, &launch_dimensions)); + } + } + + thunk_sequence_->emplace_back( + MakeUnique(std::move(thunks), sort)); + return Status::OK(); +} + +Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) { + thunk_sequence_->push_back( + BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true)); + return IrEmitter::HandleTupleSelect(tuple_select); +} + Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { if (hlo_module_config_.replica_count() != 1) { // TODO(b/33011107): Support nontrivial cross replica sum on GPU. @@ -2304,12 +2139,12 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { thunks.push_back(MakeUnique( /*source_address=*/GetAllocationSlice(*crs->operand(i)), /*destination_buffer=*/tuple_element_buffers.back(), - /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), crs)); + /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr)); } // Output a tuple of the buffers above. thunks.push_back(MakeUnique(tuple_element_buffers, - GetAllocationSlice(*crs), crs)); + GetAllocationSlice(*crs), nullptr)); thunk_sequence_->push_back( MakeUnique(std::move(thunks), crs)); return Status::OK(); @@ -2324,6 +2159,11 @@ Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) { return Status::OK(); } +Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) { + thunk_sequence_->emplace_back(BuildOutfeedThunk(outfeed)); + return Status::OK(); +} + // Figures out how to access the buffers for all subshapes of hlo's operands and // for hlo itself (i.e. all the buffers produced by HLO). // @@ -2411,11 +2251,6 @@ GetHloBufferSlices(const HloInstruction* hlo, // Adds entries for all subshapes of instr to `slices`. auto add_slices_for = [&](const HloInstruction* instr) { - // GPU constants don't have buffers; don't bother looking for one. - if (instr->IsConstant()) { - return; - } - ShapeUtil::ForEachSubshape( instr->shape(), [&](const Shape& /*shape*/, const ShapeIndex& index) { if (slices.count({instr, index})) { @@ -2443,7 +2278,8 @@ GetHloBufferSlices(const HloInstruction* hlo, } std::unique_ptr IrEmitterUnnested::BuildKernelThunk( - const HloInstruction* inst, int unroll_factor) { + const HloInstruction* inst, bool implements_whole_instruction, + int unroll_factor) { const BufferAssignment& buffer_assn = ir_emitter_context_->buffer_assignment(); @@ -2476,21 +2312,25 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( // We'll pass a pointer to each of the elements of `buffers` to our kernel, in // this order. - std::vector buffers(buffers_needed.begin(), - buffers_needed.end()); - std::sort(buffers.begin(), buffers.end(), + std::vector non_constant_buffers; + c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers), + [](const BufferAllocation* allocation) { + return !allocation->is_constant(); + }); + + std::sort(non_constant_buffers.begin(), non_constant_buffers.end(), [](const BufferAllocation* a, const BufferAllocation* b) { return a->index() < b->index(); }); - llvm::Function* kernel = BuildKernelPrototype(*inst, buffers); + llvm::Function* kernel = BuildKernelPrototype(*inst, non_constant_buffers); // Build a map from a BufferAllocation to the corresponding argument in our // kernel. std::unordered_map kernel_args; { auto arg_it = kernel->arg_begin(); - auto buffers_it = buffers.begin(); + auto buffers_it = non_constant_buffers.begin(); for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) { kernel_args[*buffers_it] = arg_it; } @@ -2508,18 +2348,24 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( << " is found in slice " << slice.ToString() << " at GTE index " << gte_index.ToString(); - llvm::Value* loc = - ir_builder_.CreateInBoundsGEP(kernel_args.at(slice.allocation()), - {ir_builder_.getInt64(slice.offset())}); + llvm::Value* loc; + if (slice.allocation()->is_constant()) { + loc = ir_emitter_context_->llvm_module()->getGlobalVariable( + llvm_ir::AsStringRef(llvm_ir::ConstantBufferAllocationToGlobalName( + *slice.allocation()))); + CHECK_NE(loc, nullptr); + } else { + loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()), + {b_.getInt64(slice.offset())}); + } // If gte_index is nonempty, we have to dereference `loc` to get to the // value we're ultimately interested in. llvm::Type* int8_double_pointer = - llvm::PointerType::get(ir_builder_.getInt8PtrTy(), /*AddressSpace=*/0); + llvm::PointerType::get(b_.getInt8PtrTy(), /*AddressSpace=*/0); for (int64 idx : gte_index) { - loc = ir_builder_.CreateBitCast(loc, int8_double_pointer); - loc = ir_builder_.CreateLoad( - ir_builder_.CreateInBoundsGEP(loc, {ir_builder_.getInt64(idx)})); + loc = b_.CreateBitCast(loc, int8_double_pointer); + loc = b_.CreateLoad(b_.CreateInBoundsGEP(loc, {b_.getInt64(idx)})); } bindings_.BindHloToIrValue(*instr, loc, index); @@ -2531,11 +2377,12 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( bindings_.SetTempBufferBase(kernel_args.at(*temp_buffer)); } else { bindings_.SetTempBufferBase( - llvm::ConstantPointerNull::get(ir_builder_.getInt8PtrTy())); + llvm::ConstantPointerNull::get(b_.getInt8PtrTy())); } - return MakeUnique(buffers, llvm_ir::AsString(kernel->getName()), - inst, unroll_factor); + return MakeUnique( + non_constant_buffers, llvm_ir::AsString(kernel->getName()), + implements_whole_instruction ? inst : nullptr, unroll_factor); } std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( @@ -2569,7 +2416,7 @@ std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( ShapeTree slices(inst->shape()); slices.ForEachMutableElement( - [this, inst](const ShapeIndex& index, BufferAllocation::Slice* slice) { + [&](const ShapeIndex& index, BufferAllocation::Slice* slice) { *slice = ir_emitter_context_->buffer_assignment() .GetUniqueSlice(inst, index) .ConsumeValueOrDie(); @@ -2577,6 +2424,23 @@ std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( return MakeUnique(slices, inst); } +std::unique_ptr IrEmitterUnnested::BuildOutfeedThunk( + const HloInstruction* inst) { + CHECK_EQ(HloOpcode::kOutfeed, inst->opcode()); + + ShapeTree slices(inst->operand(0)->shape()); + slices.ForEachMutableElement( + [&](const ShapeIndex& index, BufferAllocation::Slice* slice) { + auto status_or_slice = + ir_emitter_context_->buffer_assignment().GetUniqueSlice( + inst->operand(0), index); + if (status_or_slice.ok()) { + *slice = status_or_slice.ConsumeValueOrDie(); + } + }); + return MakeUnique(std::move(slices), inst); +} + namespace { double GetScalarConstantAsDouble(const Literal& literal) { switch (literal.shape().element_type()) { @@ -2692,6 +2556,11 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( init_value = hlo->operand(init_value->parameter_number()); } + // Initializer thunks don't implement a whole instruction, and we want to + // profile the whole instruction instead of the individual thunks it consists + // of. Therefore we pass nullptr as the HloInstruction* to the thunks we + // generate below. + // // In the common case, the initializer is a constant. In this case, emit a // device-memset call if we can. Currently StreamExecutor only supports // zeroing and 32-bit memsets. @@ -2705,7 +2574,8 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( ArraySlice literal_bytes( reinterpret_cast(literal.untyped_data()), num_bytes); if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { - return {MakeUnique(GetAllocationSlice(*hlo, index), hlo)}; + return { + MakeUnique(GetAllocationSlice(*hlo, index), nullptr)}; } // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by @@ -2723,7 +2593,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( } uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); return {MakeUnique( - pattern32, GetAllocationSlice(*hlo, index), hlo)}; + pattern32, GetAllocationSlice(*hlo, index), nullptr)}; } // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit @@ -2734,12 +2604,13 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( uint32 word; memcpy(&word, literal_bytes.data(), sizeof(word)); return {MakeUnique( - word, GetAllocationSlice(*hlo, index), hlo)}; + word, GetAllocationSlice(*hlo, index), nullptr)}; } } // Otherwise fall back to our slow initializer code. - std::unique_ptr kernel_thunk = BuildKernelThunk(hlo); + std::unique_ptr kernel_thunk = + BuildKernelThunk(hlo, /*implements_whole_instruction=*/false); LaunchDimensions launch_dimensions = CalculateLaunchDimensions(ShapeUtil::GetSubshape(hlo->shape(), index), ir_emitter_context_->device_description()); @@ -2748,15 +2619,24 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( // If the init_value was fused into this reduce we have to generate it first. if (fused && init_value_operand->opcode() != HloOpcode::kParameter) { CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode()); - TF_RETURN_IF_ERROR(HandleConstant(const_cast(init_value))); + + const Literal& literal = init_value_operand->literal(); + llvm::Constant* initializer = + llvm_ir::ConvertLiteralToIrConstant(literal, module_); + + llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( + *module_, initializer->getType(), + /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer, + /*Name=*/""); + global_for_const->setAlignment(kConstantBufferAlignBytes); + bindings_.BindHloToIrValue(*init_value_operand, global_for_const); } TF_RETURN_IF_ERROR(ParallelLoopEmitter( - [=](const llvm_ir::IrArray::Index& index) { + [=](const IrArray::Index& index) { return GetIrArray(*init_value, *hlo) - .EmitReadArrayElement(index, &ir_builder_); + .EmitReadArrayElement(index, &b_); }, - GetIrArray(*hlo, *hlo, index), launch_dimensions, - &ir_builder_) + GetIrArray(*hlo, *hlo, index), launch_dimensions, &b_) .EmitLoop(IrName(hlo))); // Clean up state left behind by emitting the loop above. (This is normally @@ -2867,13 +2747,13 @@ std::unique_ptr IrEmitterUnnested::BuildWhileThunk( HloComputation* condition = hlo->while_condition(); IrEmitterUnnested ir_emitter_condition(hlo_module_config_, condition, ir_emitter_context_); - TF_CHECK_OK(condition->root_instruction()->Accept(&ir_emitter_condition)); + TF_CHECK_OK(condition->Accept(&ir_emitter_condition)); // Generate thunk sequence for while 'body'. HloComputation* body = hlo->while_body(); IrEmitterUnnested ir_emitter_body(hlo_module_config_, body, ir_emitter_context_); - TF_CHECK_OK(body->root_instruction()->Accept(&ir_emitter_body)); + TF_CHECK_OK(body->Accept(&ir_emitter_body)); return MakeUnique( GetAllocationSlice(*condition->root_instruction()), // cond result @@ -2891,7 +2771,7 @@ std::unique_ptr IrEmitterUnnested::BuildForThunk( HloComputation* body = hlo->while_body(); IrEmitterUnnested ir_emitter_body(hlo_module_config_, body, ir_emitter_context_); - TF_CHECK_OK(body->root_instruction()->Accept(&ir_emitter_body)); + TF_CHECK_OK(body->Accept(&ir_emitter_body)); return MakeUnique(loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo); @@ -2907,12 +2787,12 @@ std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( HloComputation* true_computation = hlo->true_computation(); IrEmitterUnnested ir_emitter_true(hlo_module_config_, true_computation, ir_emitter_context_); - TF_CHECK_OK(true_computation->root_instruction()->Accept(&ir_emitter_true)); + TF_CHECK_OK(true_computation->Accept(&ir_emitter_true)); HloComputation* false_computation = hlo->false_computation(); IrEmitterUnnested ir_emitter_false(hlo_module_config_, false_computation, ir_emitter_context_); - TF_CHECK_OK(false_computation->root_instruction()->Accept(&ir_emitter_false)); + TF_CHECK_OK(false_computation->Accept(&ir_emitter_false)); return MakeUnique( GetAllocationSlice(*hlo->operand(0)), @@ -2940,41 +2820,588 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( ir_emitter_context_->llvm_module()); if (!hlo.IsMultiOutputFusion()) { return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo), - launch_dimensions, &ir_builder_, unroll_factor) - .EmitLoop(IrName(&hlo), - GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(), - &ir_builder_)); + launch_dimensions, &b_, unroll_factor) + .EmitLoop( + IrName(&hlo), + GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(), &b_)); } - // For multiple outputs fusion, we need to emit each operand and the root. - std::vector output_arrays; + // For multioutput fusion, we need to emit each operand and the root. + std::vector output_arrays; for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { output_arrays.push_back(GetIrArray(hlo, hlo, {i})); } TF_RETURN_IF_ERROR( ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions, - &ir_builder_, unroll_factor) + &b_, unroll_factor) .EmitLoop(IrName(&hlo), GetIndexTypeForKernel( - &hlo, launch_dimensions.launch_bound(), &ir_builder_))); + &hlo, launch_dimensions.launch_bound(), &b_))); std::vector tuple_operand_ptrs; for (int64 i = 0; i < output_arrays.size(); ++i) { tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); } - ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &ir_builder_, - module_); + b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_); return Status::OK(); } Status IrEmitterUnnested::EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator) { - CHECK(Thunk::Kind::kKernel == LastThunk()->kind()); + CHECK_EQ(Thunk::Kind::kKernel, LastThunk()->kind()); return EmitTargetElementLoopInThunk(hlo, element_generator, static_cast(LastThunk())); } +int IrEmitterUnnested::ConstructIrArrayForOutputs( + const HloInstruction& hlo, std::vector* output_arrays) { + int64 num_outputs = 1; + if (hlo.IsMultiOutputFusion()) { + num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); + output_arrays->reserve(num_outputs); + for (int64 i = 0; i < num_outputs; ++i) { + output_arrays->push_back(GetIrArray(hlo, hlo, {i})); + } + } else { + output_arrays->push_back(GetIrArray(hlo, hlo)); + } + return num_outputs; +} + +int IrEmitterUnnested::ConstructIrArrayForInputs( + const HloInstruction& hlo, std::vector* param_arrays) { + int64 num_params = hlo.operands().size(); + param_arrays->reserve(num_params); + for (const HloInstruction* param : hlo.operands()) { + param_arrays->push_back(GetIrArray(*param, hlo)); + } + return num_params; +} + +int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape( + const HloInstruction& hlo, const std::vector& output_arrays, + tensorflow::gtl::ArraySlice reduced_output_dims, + std::vector* output_reduced_shapes, + std::vector* output_in_reduced_shape_arrays) { + int64 num_outputs = 1; + if (hlo.IsMultiOutputFusion()) { + num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); + output_in_reduced_shape_arrays->reserve(num_outputs); + output_reduced_shapes->reserve(num_outputs); + for (int64 i = 0; i < num_outputs; ++i) { + output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout( + ShapeUtil::GetSubshape(hlo.shape(), {i}).element_type(), + reduced_output_dims)); + output_in_reduced_shape_arrays->push_back( + output_arrays[i].CastToShape((*output_reduced_shapes)[i], &b_)); + } + } else { + output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout( + hlo.shape().element_type(), reduced_output_dims)); + output_in_reduced_shape_arrays->push_back( + output_arrays[0].CastToShape((*output_reduced_shapes)[0], &b_)); + } + return num_outputs; +} + +int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( + const HloInstruction& hlo, const std::vector& param_arrays, + const std::vector& param_buffers, + tensorflow::gtl::ArraySlice reduced_output_dims, + std::vector* param_reduced_shapes, + std::vector* param_in_reduced_shape_arrays) { + int64 num_params = hlo.operands().size(); + param_in_reduced_shape_arrays->reserve(num_params); + param_reduced_shapes->reserve(num_params); + for (int64 id = 0; id < num_params; ++id) { + if (param_buffers[id] == nullptr) { + param_reduced_shapes->push_back(Shape()); + param_in_reduced_shape_arrays->push_back(IrArray()); + continue; + } + const HloInstruction* param = hlo.operand(id); + param_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout( + param->shape().element_type(), + Permute({0, 2, 1}, reduced_output_dims))); + param_in_reduced_shape_arrays->push_back( + param_arrays[id].CastToShape((*param_reduced_shapes)[id], &b_)); + } + return num_params; +} + +namespace { + +// Reads thread_idx.x and converts it to a (y,x) coordinate, assuming that the +// thread lives within a square tile of size tile_size (so thread blocks are of +// size tile_size * tile_size). +std::tuple CalculateYXCoordinateWithinTile( + llvm::IRBuilder<>* builder, llvm::Value* tile_size, + int64 threads_per_tile) { + // Calculate the starting element coordinate within a tile for the current + // thread, (y, x) from thread_id. + llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder); + llvm_ir::AddRangeMetadata(0, threads_per_tile, + llvm::cast(thread_id)); + thread_id = builder->CreateIntCast(thread_id, tile_size->getType(), + /*isSigned=*/true, "thread.id.x"); + auto x = builder->CreateURem(thread_id, tile_size); + auto y = builder->CreateUDiv(thread_id, tile_size); + return std::make_tuple(y, x); +} + +// Reads block_idx.x, casts it to type index_ty, and adds the assumption that +// it's in the range [0, num_blocks]. +llvm::Value* GetBlockIdx(llvm::IRBuilder<>* builder, llvm::Type* index_ty, + int64 num_blocks) { + llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, builder); + llvm_ir::AddRangeMetadata(0, num_blocks, + llvm::cast(block_id)); + return builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true, + "block.id.x"); +} + +// Emits code to process up to (tile_size/num_rows) elements in a tile, given +// `emit_elem_function` is the function to emit code to process one element, `y` +// and `x` are the coordinates for the first element to process, and `index` is +// the index for the origin of the tile. Emits bounds check to ensure that each +// processed element is within the boundary defined by `tile_width` and +// `tile_height`. +void EmitTiledElementalCodeWithBoundsCheck( + int64 tile_size, int64 num_rows, const IrArray::Index& index, + const string& loop_name, KernelSupportLibrary* ksl, + llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, + llvm::Value* tile_width, llvm::Value* tile_height, + const std::function& + emit_elem_function) { + llvm::Type* index_ty = tile_width->getType(); + // Emits a constant value with index type. + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + // Adds `addend` to the given `dim` of `index`. + auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) { + index[dim] = builder->CreateAdd(index[dim], addend); + return index; + }; + + auto emit_full_tile = [&] { + for (int64 i = 0; i < tile_size; i += num_rows) { + auto source_idx = offset_dim(index, index_typed_constant(i), /*dim=*/1); + auto y_loc = builder->CreateAdd(index_typed_constant(i), y); + emit_elem_function(source_idx, y_loc); + } + }; + + auto emit_last_row = [&] { + ksl->IfReturnVoid("x_in_tile", builder->CreateICmpULT(x, tile_width), [&] { + // tile_height_upper_bound = + // ceil(tile_height / num_rows) * num_rows + auto tile_height_upper_bound = builder->CreateMul( + builder->CreateUDiv( + builder->CreateAdd(tile_height, + index_typed_constant(num_rows - 1)), + index_typed_constant(num_rows)), + index_typed_constant(num_rows)); + ksl->ForReturnVoid( + loop_name, /*start=*/index_typed_constant(0), + /*end=*/tile_height_upper_bound, + /*step=*/index_typed_constant(num_rows), [&](llvm::Value* y_indvar) { + auto y_loc = builder->CreateAdd(y_indvar, y); + ksl->IfReturnVoid( + "y_in_tile", builder->CreateICmpULT(y_loc, tile_height), [&] { + emit_elem_function(offset_dim(index, y_indvar, /*dim=*/1), + y_loc); + }); + }); + }); + }; + ksl->IfReturnVoid( + "full_tile", + builder->CreateAnd( + builder->CreateICmpEQ(index_typed_constant(tile_size), tile_width), + builder->CreateICmpEQ(index_typed_constant(tile_size), tile_height)), + emit_full_tile, emit_last_row); +} +} // namespace + +// Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose +// algorithm to improve the memory access patterns for the input parameters +// which have a shape that is a 0-2-1 transpose of the output tensors. +// +// For the purpose of tiling, the output tensors have a logical shape of three +// components 0-2-1 while the relevant input parameters have a logical shape of +// three components 0-1-2 in the order major to minor. The x- and y- dimensions +// of the tensors are tiled in square tiles of edge length `kTileSize`. Each +// thread block of `kTileSize` x `kNumRows` threads transposes one tile: each +// thread copies kTileSize/kNumRows elements from the input to a shared memory +// tile, then the otherwise "regular hlo kernel" reads from the shared memory +// instead of the original input. +// +// This is similar to the following CUDA algorithm in TensorFlow: +// https://goo.gl/MStRV6. +// +// `kTileSize` should usually be same as warp size. We currently choose 32 for +// `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`. +// +// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient +// to launch fewer blocks so each transposes many tiles. +LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( + HloInstruction* hlo, tensorflow::gtl::ArraySlice reduced_output_dims, + tensorflow::gtl::ArraySlice tiled_param_ids) { + // Parameters for the tiling algorithm. + constexpr int64 kTileSize = 32; + constexpr int64 kNumRows = 4; + constexpr int64 kThreadsPerTile = kTileSize * kNumRows; + + // Construct IrArrays for the inputs and outputs. + std::vector output_arrays; + int64 num_outputs = ConstructIrArrayForOutputs(*hlo, &output_arrays); + std::vector param_arrays; + int64 num_params = ConstructIrArrayForInputs(*hlo, ¶m_arrays); + + // Allocate shared memory buffers to store the tiled inputs. + std::vector param_shmem_buffers(num_params, nullptr); + for (int64 id : tiled_param_ids) { + const HloInstruction* param = hlo->operand(id); + // Add 1 to the minor dimension to reduce shared memory bank conflicts. + llvm::Type* tile_type = llvm::ArrayType::get( + llvm::ArrayType::get(llvm_ir::PrimitiveTypeToIrType( + param->shape().element_type(), module_), + kTileSize + 1), + kTileSize); + const int kNVPTXSharedMemoryAddrSpace = 3; + auto* tile_base_ptr = new llvm::GlobalVariable( + *b_.GetInsertBlock()->getParent()->getParent(), tile_type, + /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, + llvm::UndefValue::get(tile_type), + llvm_ir::AsStringRef(IrName(hlo, StrCat("tile", id))), nullptr, + llvm::GlobalValue::NotThreadLocal, kNVPTXSharedMemoryAddrSpace); + param_shmem_buffers[id] = tile_base_ptr; + VLOG(3) << "Added shmem buffer for parameter " << id << ": " + << llvm_ir::DumpToString(*tile_base_ptr); + } + + // The 0-2-1 shape of the tiling scheme is the reduced shape of the HLO result + // for the purpose of tiling. Calculate the logical output dimensions in the + // tile from the reduced output dimensions. + std::vector output_dims_in_tiles = std::vector( + reduced_output_dims.begin(), reduced_output_dims.end()); + CHECK_EQ(output_dims_in_tiles.size(), 3); + for (int i = 1; i < 3; ++i) { + output_dims_in_tiles[i] = + CeilOfRatio(output_dims_in_tiles[i], kTileSize); + } + const int64 num_tiles = + c_accumulate(output_dims_in_tiles, 1, std::multiplies()); + LaunchDimensions launch_dimensions(num_tiles, kThreadsPerTile); + + llvm::Type* index_ty = + GetIndexTypeForKernel(hlo, launch_dimensions.launch_bound(), &b_); + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + + // Cast each output IrArray to its corresponding reduced shape and keep the + // reduced shape live during IR emission. + std::vector output_in_reduced_shape_arrays; + std::vector output_reduced_shapes; + CHECK_EQ(ConstructOutputReducedShapeAndCastOutputIrArrayToShape( + *hlo, output_arrays, reduced_output_dims, &output_reduced_shapes, + &output_in_reduced_shape_arrays), + num_outputs); + + // For each tiled parameter, cast its input IrArray to the corresponding + // reduced shape and keep the reduced shape live during IR emission. + std::vector param_in_reduced_shape_arrays; + std::vector param_reduced_shapes; + CHECK_EQ(ConstructInputReducedShapeAndCastInputIrArrayToShape( + *hlo, param_arrays, param_shmem_buffers, reduced_output_dims, + ¶m_reduced_shapes, ¶m_in_reduced_shape_arrays), + num_params); + + // Calculate the starting element coordinate within a tile for the current + // thread, (y, x) from thread_id. + llvm::Value* x; + llvm::Value* y; + std::tie(y, x) = CalculateYXCoordinateWithinTile( + &b_, index_typed_constant(kTileSize), kThreadsPerTile); + + // Calculate the index for the current output tile from block_id. + const IrArray::Index output_tile_index( + GetBlockIdx(&b_, index_ty, num_tiles), + ShapeUtil::MakeShapeWithDescendingLayout(PRED /*arbitrary*/, + output_dims_in_tiles), + &b_); + + // Output tile origin is the index for the first element of the current output + // tile. + const IrArray::Index output_tile_origin = [&] { + IrArray::Index index = output_tile_index; + for (int i = 1; i < 3; ++i) { + index[i] = + b_.CreateMul(output_tile_index[i], index_typed_constant(kTileSize), + "tile_origin." + std::to_string(i)); + } + return index; + }(); + + // Calculate the input tile origin from the output tile origin. + const IrArray::Index input_tile_origin( + Permute({0, 2, 1}, output_tile_origin.multidim())); + + // Calculate the current output tile bounds in each of the logical dimensions. + std::vector output_tile_bounds(3); + for (int i = 1; i < 3; ++i) { + // Only last row or column may not have full size. + output_tile_bounds[i] = b_.CreateSelect( + b_.CreateICmpEQ(output_tile_index[i], + index_typed_constant(output_dims_in_tiles[i] - 1)), + index_typed_constant(reduced_output_dims[i] - + (output_dims_in_tiles[i] - 1) * kTileSize), + index_typed_constant(kTileSize), "kTileSize"); + } + + KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); + + // Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck. + auto emit_tiled_elemental_code_with_bounds_check = + [&](const IrArray::Index& index, const string& loop_name, + llvm::Value* tile_width, llvm::Value* tile_height, + const std::function& + emit_elem_function) { + EmitTiledElementalCodeWithBoundsCheck( + kTileSize, kNumRows, index, loop_name, &ksl, &b_, y, x, tile_width, + tile_height, emit_elem_function); + }; + + // Adds `addend` to the given `dim` of `index`. + auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) { + index[dim] = b_.CreateAdd(index[dim], addend); + return index; + }; + const IrArray::Index input_index = + offset_dim(offset_dim(input_tile_origin, x, /*dim=*/2), y, /*dim=*/1); + + // Copy input parameter values to shared memory buffers: + // tile[y, x] = input[index] + emit_tiled_elemental_code_with_bounds_check( + input_index, "input", output_tile_bounds[1], output_tile_bounds[2], + [&](const IrArray::Index& index, llvm::Value* y_loc) { + for (int64 id : tiled_param_ids) { + IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id]; + llvm::Value* shmem_buffer = param_shmem_buffers[id]; + // TODO(jlebar): Add AA metadata to this store. Tile buffers are + // global variables, so LLVM can't infer much about it. + b_.CreateStore( + input_in_logical_shape.EmitReadArrayElement(index, &b_, + "input_element"), + b_.CreateGEP(shmem_buffer, {index_typed_constant(0), y_loc, x})); + } + }); + + // Wait for all threads to reach this point, lest we copy a value from tile to + // output before the other thread copies it from input to tile. + // This is `__syncthreads` in CUDA. + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_); + + llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x); + + const IrArray::Index output_index = + offset_dim(offset_dim(output_tile_origin, x, /*dim=*/2), y, /*dim=*/1); + + // Write to output[index] by emitting code like normal, except that values for + // the tiled parameters are read from the shmem buffers. + if (hlo->opcode() == HloOpcode::kCopy) { + emit_tiled_elemental_code_with_bounds_check( + output_index, "output", output_tile_bounds[2], output_tile_bounds[1], + [&](const IrArray::Index& index, llvm::Value* y_loc) { + // TODO(jlebar): Add AA metadata to this load. + llvm::Instruction* load_from_shmem_buffer = b_.CreateLoad( + b_.CreateGEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}), + "output_element"); + output_in_reduced_shape_arrays[0].EmitWriteArrayElement( + index, load_from_shmem_buffer, &b_); + }); + } else { + CHECK_EQ(hlo->opcode(), HloOpcode::kFusion); + emit_tiled_elemental_code_with_bounds_check( + output_index, "output", output_tile_bounds[2], output_tile_bounds[1], + [&](const IrArray::Index& index, llvm::Value* y_loc) { + GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, + GetNestedComputer()); + FusedIrEmitter fused_emitter(param_arrays, &elem_emitter); + tiled_param_info.set_y(y_loc); + fused_emitter.SetTiledParameterInfo(&tiled_param_info); + TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter)); + IrArray::Index untiled_index = llvm_ir::GetUnreducedOutputIndex( + index, output_reduced_shapes[0], output_arrays[0].GetShape(), + &b_); + const llvm_ir::ElementGenerator& output_generator = + fused_emitter.GetRootGenerator(); + llvm::Value* output_value = + output_generator(untiled_index).ValueOrDie(); + if (hlo->IsMultiOutputFusion()) { + CHECK(output_value->getType()->isStructTy()); + CHECK_EQ(output_value->getType()->getStructNumElements(), + output_in_reduced_shape_arrays.size()); + for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) { + output_in_reduced_shape_arrays[i].EmitWriteArrayElement( + index, b_.CreateExtractValue(output_value, i), &b_); + } + } else { + output_in_reduced_shape_arrays[0].EmitWriteArrayElement( + index, output_value, &b_); + } + }); + } + + // For multioutput fusion, emit a tuple with all the individual outputs. + if (hlo->IsMultiOutputFusion()) { + std::vector tuple_operand_ptrs; + for (int64 i = 0; i < output_arrays.size(); ++i) { + tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); + } + llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), tuple_operand_ptrs, &b_, + module_); + } + + return launch_dimensions; +} + +bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { + HloOpcode opcode = hlo->opcode(); + CHECK(opcode == HloOpcode::kFusion || opcode == HloOpcode::kCopy); + CHECK(opcode != HloOpcode::kFusion || + hlo->fusion_kind() == HloInstruction::FusionKind::kLoop) + << "Only loop fusions are supported."; + + const Shape& output_shape = hlo->IsMultiOutputFusion() + ? ShapeUtil::GetSubshape(hlo->shape(), {0}) + : hlo->shape(); + + // If the output_shape is reduced to 021 shape, find all the parameters of the + // hlo that are in the corresponding 012 shape. + std::vector params_012; + optional> reduced_dims_021; + for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); + ++operand_idx) { + HloInstruction* operand = hlo->mutable_operand(operand_idx); + auto find_transpose_result = + llvm_ir::FindTranspose021(operand->shape(), output_shape); + if (!find_transpose_result.has_value()) { + continue; + } + const std::vector& curr_reduced_dims_021 = *find_transpose_result; + if (!reduced_dims_021.has_value()) { + reduced_dims_021 = curr_reduced_dims_021; + } + if (!ContainersEqual(*reduced_dims_021, curr_reduced_dims_021)) { + // There is more than one possible transpose. Instead of picking one + // transpose, we simply give up here. + return false; + } + params_012.push_back(operand_idx); + } + + if (!reduced_dims_021.has_value()) { + return false; + } + + if ((*reduced_dims_021)[1] < kMinDimensionToTransposeTiled || + (*reduced_dims_021)[2] < kMinDimensionToTransposeTiled) { + return false; + } + + // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the + // elements are of size 4 bytes), and CUDA has an architectural limit of 48kb + // shared memory per SM. (This is increased to 96kb in Volta, but we don't + // use this, in part because it eats into our L1 cache space.) + // + // For correctness we need to ensure that we don't make more than 48kb worth + // of shmem tiles per block. And for performance, we'd probably like to use + // significantly less, so that we can fit more than one block at a time on a + // gpu core. + // + // We say without benchmarks that we want at least 3 threads/block, + // corresponding to 3 shmem tiles if the elements are 32 bits wide. We choose + // which params get the shmem transpose treatment arbitrarily; it's not clear + // if there's a Right Choice. + // + // This is only sound if tiled transposes are the only place where we use + // shared memory in fusions. If in the future other fusile ops use shared + // memory, we'll have to adjust this heuristic. + constexpr int kMinBlocksPerCore = 3; + constexpr int64 kShmemPerCore = 48 * 1024; + int64 shmem_used = 0; + for (int64 i = 0; i < params_012.size(); ++i) { + const HloInstruction* operand = hlo->operand(params_012[i]); + shmem_used += + 32 * 33 * + ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type()); + + if (kMinBlocksPerCore * shmem_used > kShmemPerCore) { + // Erase this element and everything after it from params_012. + params_012.resize(i); + break; + } + } + + VLOG(3) << "EmitHlo021Tile Emitting hlo tile 0-2-1" << hlo->ToString(); + thunk_sequence_->emplace_back( + BuildKernelThunk(hlo, /*implements_whole_instruction=*/true)); + const LaunchDimensions launch_dimensions = + EmitHlo021Tile(hlo, *reduced_dims_021, params_012); + UpdateLaunchDimensions(launch_dimensions, LastThunk(), + ir_emitter_context_->llvm_module()); + + return true; +} + +Status IrEmitterUnnested::EmitConstantGlobals() { + for (const BufferAllocation& allocation : + ir_emitter_context_->buffer_assignment().Allocations()) { + if (!allocation.is_constant()) { + continue; + } + + const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation); + const bool should_emit_initializer = ShouldEmitLiteralInLlvmIr(literal); + llvm::ArrayType* global_type = + llvm::ArrayType::get(b_.getInt8Ty(), allocation.size()); + llvm::Constant* initializer = + should_emit_initializer + ? llvm_ir::ConvertLiteralToIrConstant(literal, module_) + : llvm::ConstantAggregateZero::get(global_type); + if (should_emit_initializer) { + VLOG(3) << "Emitted initializer for constant with shape " + << ShapeUtil::HumanString(literal.shape()); + } + + // These globals will be looked up by name by GpuExecutable so we need to + // give them an external linkage. Not all of their uses are visible in the + // LLVM IR (e.g. TupleThunk) so we can't give then a linkage that merely + // preserves their names (like available_externally), we also need to ensure + // that they stick around even if they're "unused". + // + // We may have to be more more clever here in the future if we notice that + // we're keeping around too many globals because of their linkage. + llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( + global_type, /*isConstant=*/should_emit_initializer, + llvm::GlobalValue::ExternalLinkage, + /*Initializer=*/initializer, + llvm_ir::AsStringRef( + llvm_ir::ConstantBufferAllocationToGlobalName(allocation))); + global_for_const->setAlignment(kConstantBufferAlignBytes); + ir_emitter_context_->llvm_module()->getGlobalList().push_back( + global_for_const); + } + + return Status::OK(); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 819060061a9b8bcf0db4f782852b0a7c6530143c..525441990795e160ba0e8facb910d5cc9796c4bb 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h" namespace xla { namespace gpu { @@ -73,8 +74,11 @@ class IrEmitterUnnested : public IrEmitter { Status HandleTuple(HloInstruction* tuple) override; Status HandleWhile(HloInstruction* xla_while) override; Status HandleInfeed(HloInstruction* xla_infeed) override; + Status HandleOutfeed(HloInstruction* outfeed) override; Status HandleRng(HloInstruction* random) override; Status HandleSelect(HloInstruction* select) override; + Status HandleSort(HloInstruction* sort) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; Status HandleAfterAll(HloInstruction* gen_token) override; @@ -88,6 +92,9 @@ class IrEmitterUnnested : public IrEmitter { const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter, KernelThunk* thunk); + // Emits LLVM global variables corresponding to constant instructions. + Status EmitConstantGlobals(); + private: // Builds the appropriate thunk for the instruction hlo and returns the owning // pointer to it. The caller needs to make sure `inst` outlives the lifetime @@ -115,7 +122,7 @@ class IrEmitterUnnested : public IrEmitter { // Emits code that reduces a matrix of shape [height x width] to a vector of // [width]. Other parameters have the same meaning as those of // `EmitReductionToVector`. Note that input shape might not be - // [height x width], but can be bitcast to [height x weight] with "height" + // [height x width], but can be bitcast to [height x width] with "height" // being the major dimension. Status EmitColumnReduction( int64 height, int64 width, HloInstruction* reduce, @@ -131,7 +138,7 @@ class IrEmitterUnnested : public IrEmitter { // Emits code that reduces a 3D tensor of shape [depth x height x width] to a // vector of shape [height]. Other parameters have the same meaning as those // of `EmitReductionToVector`. Note that input shape might not be - // [depth x height x width], but can be bitcast to [depth x height x weight] + // [depth x height x width], but can be bitcast to [depth x height x width] // with "depth" being the most major dimension. Status EmitRowReduction( int64 depth, int64 height, int64 width, HloInstruction* reduce, @@ -182,12 +189,56 @@ class IrEmitterUnnested : public IrEmitter { std::pair> extra_output_gens); + // Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel + // for the hlo instruction. + bool CheckAndEmitHloWithTile021(HloInstruction* hlo); + // Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and + // returns the launch dimensions for the kernel. This is a helper to support + // the implementation of CheckAndEmitHloWithTile021. + LaunchDimensions EmitHlo021Tile( + HloInstruction* hlo, + tensorflow::gtl::ArraySlice reduced_output_dims, + tensorflow::gtl::ArraySlice tiled_param_ids); + // Generates the IrArray for each output of hlo and returns the number of + // outputs. + int ConstructIrArrayForOutputs(const HloInstruction& hlo, + std::vector* output_arrays); + // Generates the IrArray for each input of hlo and returns the number of + // inputs. + int ConstructIrArrayForInputs(const HloInstruction& hlo, + std::vector* param_arrays); + // For each output of the `hlo` instruction, constructs the reduced shape for + // the output with the given `reduced_output_dims` and cast the original + // output IrArray element in `output_arrays` to the reduced shape. Returns + // the number of outputs. + int ConstructOutputReducedShapeAndCastOutputIrArrayToShape( + const HloInstruction& hlo, + const std::vector& output_arrays, + tensorflow::gtl::ArraySlice reduced_output_dims, + std::vector* output_reduced_shapes, + std::vector* output_in_reduced_shape_arrays); + // For each input of the `hlo` instruction, checks its value in + // `param_buffers` to find out whether the input has a reduced shape. If the + // input has a reduced shape, constructs the reduced shape for the input and + // casts the original input IrArray in `param_arrays` to the reduced shape. + // Return the total number of inputs. + int ConstructInputReducedShapeAndCastInputIrArrayToShape( + const HloInstruction& hlo, + const std::vector& param_arrays, + const std::vector& param_buffers, + tensorflow::gtl::ArraySlice reduced_output_dims, + std::vector* param_reduced_shapes, + std::vector* param_in_reduced_shape_arrays); + // Returns a KernelThunk that invokes the kernel emitted for `inst`. The // caller needs to make sure `inst` outlives the lifetime of the returned // Thunk object. The kernel implementation will be unrolled if unroll_factor - // is greater than one. - std::unique_ptr BuildKernelThunk(const HloInstruction* inst, - int unroll_factor = 1); + // is greater than one. 'implements_whole_instruction' specifies whether this + // KernelThunk implements the whole 'inst' HloInstruction. In some cases + // 'inst' will be implemented by a sequence of Thunks. + std::unique_ptr BuildKernelThunk( + const HloInstruction* inst, bool implements_whole_instruction, + int unroll_factor = 1); // Returns a FftThunk that calls cuFFT to implement `inst`. std::unique_ptr BuildFftThunk(const HloInstruction* inst); @@ -208,10 +259,14 @@ class IrEmitterUnnested : public IrEmitter { std::unique_ptr BuildDeviceToDeviceCopyThunk( const HloInstruction* inst); - // Returns an InfeedThunk that performs device-to-device memcpy to implement + // Returns an InfeedThunk that performs a host-to-device memcpy to implement // `inst`. std::unique_ptr BuildInfeedThunk(const HloInstruction* inst); + // Returns an OutfeedThunk that performs a device-to-host memcpy to implement + // `inst`. + std::unique_ptr BuildOutfeedThunk(const HloInstruction* inst); + // Returns a WhileThunk that invokes thunk sequences for 'condition' and // 'body' sub-computations of while instruction 'hlo'. std::unique_ptr BuildWhileThunk(const HloInstruction* hlo); diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index f56c1ce69f11ed79c8be76834269f29de93a9645..e76823ad103dfa5ba61a0d3ba81b2c028dfeb33e 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -75,7 +76,8 @@ void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) { } Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { + se::Stream* stream, + HloExecutionProfiler* profiler) { // Load the kernel. se::StreamExecutor* executor = stream->parent(); LaunchDimensions launch_dimensions; @@ -100,6 +102,7 @@ Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, VLOG(3) << " Arg: alloc #" << arg->index() << ": " << buf.opaque() << " (" << buf.size() << "B)"; } + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); if (!stream->parent()->Launch( stream, se::ThreadDim(launch_dimensions.threads_per_block()), se::BlockDim(launch_dimensions.block_count()), *kernel, diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index 7def27e189b66747569344a3dbe5c0c446f903be..d751de50ad6671b3bf88cd4de49a8feb448e13ba 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -62,7 +63,8 @@ class KernelThunk : public Thunk { // Executes the kernel for the thunk on "stream", which must be non-null. Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: // Buffers passed to the kernel as arguments. diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index 7de8f9e1ee922bdbf65fd1299702482e1843f17e..eb93efc560efbb4c14065ec98b980a1ca78605c6 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -17,12 +17,12 @@ cc_library( name = "llvm_gpu_backend", srcs = [ "dump_ir_pass.cc", - "gpu_backend_lib.cc", + "nvptx_backend_lib.cc", "utils.cc", ], hdrs = [ "dump_ir_pass.h", - "gpu_backend_lib.h", + "nvptx_backend_lib.h", "utils.h", ], deps = [ @@ -34,6 +34,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@llvm//:amdgpu_code_gen", "@llvm//:analysis", "@llvm//:bit_reader", "@llvm//:bit_writer", diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc similarity index 93% rename from tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc rename to tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index a4e4e85bf3d2c197cfc691b7fca0920aa6571729..cf44458a2ed6c069c1469bb975c62565264451c1 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h" #include #include @@ -114,20 +114,20 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, // Gets the GPU name as it's known to LLVM for a given compute capability. If // we see an unrecognized compute capability, we return "sm_30". static string GetSmName(std::pair compute_capability) { - static auto* m = new std::map, int>({{{2, 0}, 20}, - {{2, 1}, 21}, - {{3, 0}, 30}, - {{3, 2}, 32}, - {{3, 5}, 35}, - {{3, 7}, 37}, - {{5, 0}, 50}, - {{5, 2}, 52}, - {{5, 3}, 53}, - {{6, 0}, 60}, - {{6, 1}, 61}, - {{6, 2}, 62}, - // TODO: Change this to 70 once LLVM NVPTX supports it - {{7, 0}, 60}}); + static auto* m = new std::map, int>({ + {{3, 0}, 30}, + {{3, 2}, 32}, + {{3, 5}, 35}, + {{3, 7}, 37}, + {{5, 0}, 50}, + {{5, 2}, 52}, + {{5, 3}, 53}, + {{6, 0}, 60}, + {{6, 1}, 61}, + {{6, 2}, 62}, + {{7, 0}, 70}, + {{7, 2}, 72}, + }); int sm_version = 30; auto it = m->find(compute_capability); if (it != m->end()) { @@ -206,7 +206,7 @@ std::unique_ptr GetTargetMachine( codegen_opt_level = CodeGenOpt::None; } return WrapUnique(target->createTargetMachine( - triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx42", target_options, + triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx60", target_options, Optional(RelocModel), Optional(CMModel), codegen_opt_level)); } @@ -319,8 +319,8 @@ Status LinkLibdeviceIfNecessary(llvm::Module* module, llvm::Linker linker(*module); string libdevice_path = tensorflow::io::JoinPath( - libdevice_dir_path, GetLibdeviceFilename(libdevice_dir_path, - compute_capability)); + libdevice_dir_path, + GetLibdeviceFilename(libdevice_dir_path, compute_capability)); TF_RETURN_IF_ERROR(tensorflow::Env::Default()->FileExists(libdevice_path)); VLOG(1) << "Linking with libdevice from: " << libdevice_path; std::unique_ptr libdevice_module = @@ -328,7 +328,7 @@ Status LinkLibdeviceIfNecessary(llvm::Module* module, if (linker.linkInModule( std::move(libdevice_module), llvm::Linker::Flags::LinkOnlyNeeded, [](Module& M, const StringSet<>& GVS) { - internalizeModule(M, [&M, &GVS](const GlobalValue& GV) { + internalizeModule(M, [&GVS](const GlobalValue& GV) { return !GV.hasName() || (GVS.count(GV.getName()) == 0); }); })) { diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h similarity index 90% rename from tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h rename to tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h index 0a345191d34e6f40db043c559a67a44a6748321c..54e0e140dea1c3a8b21ffde2950c4bc9b703b71c 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ // LLVM-based compiler backend. -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_NVPTX_BACKEND_LIB_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_NVPTX_BACKEND_LIB_H_ #include #include @@ -44,4 +44,4 @@ StatusOr CompileToPtx(llvm::Module* module, } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_NVPTX_BACKEND_LIB_H_ diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc index d4100a898b5bb9eec382c34932c2db104c9e985b..9fd6cf7157ecd659e7eb1d2c5228eca931ff6a01 100644 --- a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc @@ -14,21 +14,27 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" + +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/stream_executor/stream_executor.h" namespace xla { namespace gpu { Status MemzeroThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenMemZero(&dest_data, dest_data.size()); return Status::OK(); } Status Memset32BitValueThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); stream->ThenMemset32(&dest_data, value_, dest_data.size()); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.h b/tensorflow/compiler/xla/service/gpu/memset_thunk.h index 51c332d287d139335b356fc66411b5ffaa448b5a..d1fec0bd76b8a80f4a1e1c2e818f248997da7a75 100644 --- a/tensorflow/compiler/xla/service/gpu/memset_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MEMSET_THUNK_H_ #include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/status.h" @@ -36,7 +37,8 @@ class MemzeroThunk : public Thunk { : Thunk(Kind::kMemzero, hlo), dest_(dest) {} Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const BufferAllocation::Slice dest_; @@ -52,7 +54,8 @@ class Memset32BitValueThunk : public Thunk { : Thunk(Kind::kMemset32BitValue, hlo), value_(value), dest_(dest) {} Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: uint32 value_; diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index 652b5c7687118678a81bcdcd36e74f68decf4b88..c62bae0628f7b2fbfe822104fbe5f3528e0e09c3 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -23,6 +23,8 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -71,7 +73,6 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, // In that case, the operand of the reduce needs to have the same shape // as the other tuple operands, but also we need to compare the output // shapes of the reduces. - // TODO(tjoerg): Allow differences in fp precision. auto* element_instr_1 = get_element_instr(instr1); auto* element_instr_2 = get_element_instr(instr2); if (element_instr_1->opcode() == HloOpcode::kReduce && @@ -80,8 +81,8 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, return false; } // The elementwise output shapes must be the same (including layout). - return ShapeUtil::Equal(get_element_shape(element_instr_1), - get_element_shape(element_instr_2)); + return ShapeUtil::EqualIgnoringFpPrecision( + get_element_shape(element_instr_1), get_element_shape(element_instr_2)); } namespace { @@ -107,16 +108,42 @@ bool IsInputFusibleReduction(HloInstruction* instr) { return IsReductionToVector(*instr); } } + +// The code emitted for reduction suffers from poor data locality if the layouts +// of input parameters differ. In such situtations it is beneficial not to fuse. +// We consider input params with maximum rank only. Params with smaller ranks +// will be broadcasted and have not been observed to cause data locality issues. +// TODO(b/111977086): Improve reduce emitters to remove this limitation. +bool ReduceFriendlyInputLayouts(HloInstruction* instr) { + std::vector params; + if (instr->opcode() == HloOpcode::kFusion) { + params = instr->fused_parameters(); + } else { + for (HloInstruction* operand : instr->operands()) { + params.push_back(operand); + } + } + int64 max_rank = 0; + const Layout* max_rank_layout; + for (HloInstruction* param : params) { + if (ShapeUtil::Rank(param->shape()) > max_rank) { + max_rank = ShapeUtil::Rank(param->shape()); + max_rank_layout = ¶m->shape().layout(); + } + } + return c_all_of(params, [&](HloInstruction* param) { + return (ShapeUtil::Rank(param->shape()) < max_rank) || + (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout)); + }); +} + } // namespace bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { // We can fuse reduces and loop fusions. return IsInputFusibleReduction(instr) || (instr->opcode() == HloOpcode::kFusion && - instr->fusion_kind() == HloInstruction::FusionKind::kLoop && - // TODO(b/110202584): bitcasts make nested fusions, GPU has no support - // for nested fusions. - instr->fused_expression_root()->opcode() != HloOpcode::kBitcast); + instr->fusion_kind() == HloInstruction::FusionKind::kLoop); } int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, @@ -145,16 +172,22 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1, if (!MultiOutputFusion::LegalToFuse(instr1, instr2)) { return false; } + // If we're fusing fusions only do it if the fusion kind matches. Loop fusions // merge into bigger loop fusions and input (reduce) fusions become fusions // with multiple reduce outputs. We could fuse reduce and loop fusions // together too (the result being an input fusion) if we find cases where this // improves things. CHECK(instr1->opcode() == HloOpcode::kFusion); - if (instr2->opcode() == HloOpcode::kFusion) { - return instr1->fusion_kind() == instr2->fusion_kind(); + if ((instr2->opcode() == HloOpcode::kFusion && + instr1->fusion_kind() != instr2->fusion_kind()) || + (instr2->opcode() != HloOpcode::kFusion && + instr1->fusion_kind() == HloInstruction::FusionKind::kLoop)) { + return false; } - return instr1->fusion_kind() != HloInstruction::FusionKind::kLoop; + + // Do this check last, as it may be expensive. + return !GpuInstructionFusion::FusionWouldBeTooLarge(instr1, instr2); } bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { @@ -176,29 +209,41 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { // fusions operands. for (HloInstruction* consumer : computation()->MakeInstructionPostOrder()) { if (consumer->user_count() == 0) { + VLOG(3) << consumer->name() << " has no users."; continue; } if (!IsInputFusibleReduction(consumer)) { + VLOG(3) << consumer->name() << " is not an input-fusable reduction."; continue; } + VLOG(3) << consumer->name() + << " is a fusion candidate. Looking for fuseable operands."; auto consumer_operands = consumer->operands(); for (size_t i = 0; i < consumer_operands.size(); ++i) { HloInstruction* producer = consumer_operands[i]; if (!producer->IsFusable()) { + VLOG(3) << producer->name() << " is not fusable."; continue; } const bool is_loop_fusion = producer->opcode() == HloOpcode::kFusion && producer->fusion_kind() == HloInstruction::FusionKind::kLoop; - if (!is_loop_fusion) { + if (!producer->IsElementwise() && !is_loop_fusion) { + VLOG(3) << producer->name() << " is not a loop fusion."; continue; } if (!ShapesCompatibleForFusion(producer, consumer)) { + VLOG(3) << producer->name() << " has an incompatible shape."; + continue; + } + if (!ReduceFriendlyInputLayouts(producer)) { + VLOG(3) << producer->name() << " has inputs with mixed layouts."; continue; } // If we have already decided to fuse this producer, skip it. if (ContainsKey(to_fuse, producer)) { + VLOG(3) << producer->name() << " will be fused with another consumer."; continue; } // Do not fuse a producer if the other operands of the fusion are @@ -207,6 +252,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { return producer != operand && reachability()->IsReachable(producer, operand); })) { + VLOG(3) << producer->name() << " would introduce a cycle when fused."; break; } to_fuse.insert(producer); diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 979ea79243818c398b1b130254a41c95ced51830..14f157a5e518a0ec82c664c123629d04bd385bbf 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" +#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -27,7 +28,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace gpu { -using InstructionFusionTest = HloTestBase; +using MultiOutputFusionTest = HloTestBase; const char kModulePrefix[] = R"( HloModule test_module @@ -40,10 +41,10 @@ const char kModulePrefix[] = R"( scalar_mul_computation { scalar_lhs.1 = f32[] parameter(0) scalar_rhs.1 = f32[] parameter(1) - ROOT mul.1 = f32[] add(scalar_lhs.1, scalar_rhs.1) + ROOT mul.1 = f32[] multiply(scalar_lhs.1, scalar_rhs.1) })"; -TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { +TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { // Fusion with reduce instruction root and a sibling reduce instruction // sharing the same input param. auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( @@ -72,7 +73,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { op::Tuple(op::Reduce(), op::Reduce())); } -TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceInputShapes) { +TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) { auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[6400]{0} parameter(1) @@ -99,7 +100,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceInputShapes) { ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); } -TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceOutputShapes) { +TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) { auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[10,10]{1,0} parameter(1) @@ -126,7 +127,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceOutputShapes) { ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); } -TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) { +TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceFusions) { // Two sibling fusions with reduce instruction roots sharing the same input // param. auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( @@ -160,7 +161,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) { op::Tuple(op::Reduce(), op::Reduce())); } -TEST_F(InstructionFusionTest, +TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) { // Multi-output fusion with two reduce instructions root and a sibling reduce // instruction sharing the same input param. @@ -193,7 +194,7 @@ TEST_F(InstructionFusionTest, op::Tuple(op::Reduce(), op::Reduce(), op::Reduce())); } -TEST_F(InstructionFusionTest, +TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) { // Verify that if we already have a multi-output fusion that we prefer to pick // a reduce op from its operands for checking shape compatibility. @@ -226,7 +227,7 @@ TEST_F(InstructionFusionTest, ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); } -TEST_F(InstructionFusionTest, MultiOutputFusionTwoLoops) { +TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) { auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( fused_computation_1 { p0.1 = f32[6400]{0} parameter(0) @@ -255,7 +256,27 @@ TEST_F(InstructionFusionTest, MultiOutputFusionTwoLoops) { op::Tuple(op::Multiply(), op::Divide())); } -TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduce) { +TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + exp = f32[2,2,2]{2,1,0} exponential(p0) + reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add_computation + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, exp) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement())); + const HloInstruction* fusion = root->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Exp())); +} + +TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) { auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( fused_add { p0.1 = f32[2,2,2]{2,1,0} parameter(0) @@ -282,7 +303,7 @@ TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduce) { op::Tuple(op::Reduce(), op::Add())); } -TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) { +TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) { auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( fused_select { p1.1 = f32[2,2,2]{2,1,0} parameter(1) @@ -323,7 +344,7 @@ TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) { op::Tuple(op::Reduce(), op::Reduce(), op::Select())); } -TEST_F(InstructionFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) { +TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) { auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( fused_element_wise { p0.1 = f32[2,2,2]{2,1,0} parameter(0) @@ -349,5 +370,128 @@ TEST_F(InstructionFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) { ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); } +TEST_F(MultiOutputFusionTest, + ProducerConsumerFusionFp16LoopFusionAndReduceFusion) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_select { + p1.1 = f16[2,2,2]{2,1,0} parameter(1) + c0 = f16[] constant(0) + broadcast = f16[2,2,2]{2,1,0} broadcast(f16[] c0), dimensions={} + greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast) + p0.1 = f16[2,2,2]{2,1,0} parameter(0) + ROOT select = f16[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f16[2,2,2]{2,1,0} p0.1, f16[2,2,2]{2,1,0} broadcast) + } + fused_reduce { + p0.2 = f16[2,2,2]{2,1,0} parameter(0) + convert = f32[2,2,2]{2,1,0} convert(p0.2) + c1 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(convert, c1), dimensions={2}, to_apply=scalar_add_computation + mul = f32[2,2,2]{2,1,0} multiply(convert, convert) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add_computation + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + ENTRY reduce { + p0 = f16[2,2,2]{2,1,0} parameter(0) + p1 = f16[2,2,2]{2,1,0} parameter(1) + select = f16[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select + fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce + gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0 + gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1 + ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) tuple(gte1, gte1, select) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement(), + op::GetTupleElement())); + const HloInstruction* fusion = root->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce(), op::Select())); +} + +TEST_F(MultiOutputFusionTest, + ProducerConsumerFusionReduceUnfriendlyLoopFusion) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + mixed_input_layouts_computation { + p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0) + p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1) + copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1) + c0 = f16[] constant(0) + broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={} + greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast) + ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast) + } + fused_reduce { + p0.2 = f16[128,1024,32,32]{1,3,2,0} parameter(0) + convert = f32[128,1024,32,32]{1,3,2,0} convert(p0.2) + c0.2 = f32[] constant(0) + ROOT reduce = f32[1024]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add_computation + } + ENTRY reduce { + p0 = f16[128,1024,32,32]{3,2,1,0} parameter(0) + p1 = f16[128,1024,32,32]{1,3,2,0} parameter(1) + loop_fusion = f16[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation + reduce_fusion = f32[1024]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce + ROOT root = (f32[1024]{0}, f16[128,1024,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +// Check that we limit the number of operands to fusions we create. +TEST_F(MultiOutputFusionTest, AvoidsLargeFusion) { + constexpr int64 kNumParams = 200; + ASSERT_GT(kNumParams, GpuInstructionFusion::kMaxOperandsAndOutputsPerFusion); + + // Compute + // p0 * p1, + // p0 * p1 + p1 * p2 + // p0 * p1 + p1 * p2 + p2 * p3 + // ... + // where each of the (pi * pj)'s is represented as a fusion node so that + // multi-output fusion will pay attention to it. + auto module = CreateNewModule(); + HloComputation::Builder b(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {10, 100}); + + std::vector params; + for (int64 i = 0; i < kNumParams; ++i) { + params.push_back( + b.AddInstruction(HloInstruction::CreateParameter(i, shape, "p"))); + } + + // Creates a fusion node that calculates x*y. + auto make_fusion = [&](HloInstruction* x, HloInstruction* y) { + HloComputation::Builder sub_builder("subcomp"); + auto* p0 = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "p")); + auto* p1 = sub_builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "p")); + sub_builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1)); + HloComputation* subcomp = + module->AddEmbeddedComputation(sub_builder.Build()); + return HloInstruction::CreateFusion( + shape, HloInstruction::FusionKind::kLoop, {x, y}, subcomp); + }; + + auto* sum = b.AddInstruction(make_fusion(params[0], params[1])); + for (int64 i = 2; i < kNumParams; ++i) { + sum = b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, sum, + b.AddInstruction(make_fusion(params[i - 1], params[i])))); + } + auto computation = module->AddEntryComputation(b.Build()); + EXPECT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + for (const HloInstruction* instr : computation->instructions()) { + EXPECT_LE(instr->operand_count() + ShapeUtil::SubshapeCount(instr->shape()), + GpuInstructionFusion::kMaxOperandsAndOutputsPerFusion) + << instr->ToString(); + } +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc similarity index 88% rename from tensorflow/compiler/xla/service/gpu/gpu_compiler.cc rename to tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index decfc40dafafe875fa02bab6695f5c54e522f267..76c9b6ab33befa98f03821fac84071bd978ae24d 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" +#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h" #include #include @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" -#include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" @@ -50,11 +49,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" -#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h" #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" +#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h" #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -74,7 +75,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" -#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -96,8 +96,8 @@ limitations under the License. namespace xla { namespace gpu { -/* static */ const char* GpuCompiler::kTargetTriple = "nvptx64-nvidia-cuda"; -/* static */ const char* GpuCompiler::kDataLayout = +/* static */ const char* NVPTXCompiler::kTargetTriple = "nvptx64-nvidia-cuda"; +/* static */ const char* NVPTXCompiler::kDataLayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"; namespace { @@ -146,7 +146,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // support BF16 operations without directly implementing a BF16 lowering for // most ops. pipeline.AddPass(BF16, F32); - pipeline.AddPass(); { auto& pass = @@ -199,6 +198,12 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pipeline.AddInvariantChecker(); pipeline.AddPass(); pipeline.AddPass(); + if (IsVoltaOrLater(*stream_exec)) { + pipeline.AddPass(); + // PadForTensorCores leaves behind unnecessary tuple/get-tuple-element + // pairs that TupleSimplifier fixes. + pipeline.AddPass(); + } TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } @@ -275,14 +280,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, } } - { - // Do an aggressive LICM pass over while loops. In particular, this hoists - // constants that were sunk by WhileLoopConstantSinking. Leaving them in - // the while loop may result in unnecessary copies. - HloPassPipeline pipeline("while-loop-licm"); - pipeline.AddPass(true); - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); - } return Status::OK(); } @@ -354,16 +351,30 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) { return; } + // We need ptxas >= 9.0 as a hard requirement, because we compile targeting + // PTX 6.0. An older ptxas will just fail to compile any of our code. + // // ptxas 9.0 before 9.0.276 and ptxas 9.1 before 9.1.121 miscompile some // address calculations with large offsets (e.g. "load ptr + large_constant"), // b/70245379. - if ((vmaj == 9 && vmin == 0 && vdot < 276) || - (vmaj == 9 && vmin == 1 && vdot < 121)) { - LOG(WARNING) << "*** WARNING *** You are using ptxas " << vmaj << "." - << vmin << "." << vdot - << ", which is in range [9.0.0, 9.0.276) + [9.1.0, 9.1.121). " - "These versions are known to miscompile XLA code, leading " - "to incorrect results or invalid-address errors."; + // + // ptxas 9.1.121 miscompiles some large multioutput fusions, again in a way + // that appears related to address calculations, b/111107644. ptxas 9.2.88 + // appears to work, as far as we can tell. + if (vmaj < 9) { + LOG(ERROR) + << "You are using ptxas 8.x, but XLA requires ptxas 9.x (and strongly " + "prefers >= 9.2.88). Compilation of XLA kernels below will likely " + "fail.\n\nYou do not need to update CUDA; cherry-picking the ptxas " + "binary is sufficient."; + } else if ((vmaj < 9 || vmin < 2 || vdot < 88)) { + LOG(WARNING) + << "*** WARNING *** You are using ptxas " << vmaj << "." << vmin << "." + << vdot + << ", which older than 9.2.88. ptxas 9.x before 9.2.88 is known to " + "miscompile XLA code, leading to incorrect results or " + "invalid-address errors.\n\nYou do not need to update to CUDA " + "9.2.88; cherry-picking the ptxas binary is sufficient."; } } @@ -391,17 +402,18 @@ void WarnIfBadDriverJITVersion() { // - 384.x before 384.108 // - 387.x before 387.40 // - 390.x before 390.10. - auto vmaj = std::get<0>(version); - auto vmin = std::get<1>(version); - if ((vmaj == 384 && vmin < 108) || // - (vmaj == 387 && vmin < 40) || // - (vmaj == 390 && vmin < 10)) { + // + // In addition, only >= 396.20 contains ptxas >= 9.2.88, which contains the + // fix for the "large multioutput fusions" miscompile, b/111107644. + if (version < std::make_tuple(396, 20, 0)) { LOG(WARNING) << "*** WARNING *** Invoking the PTX->SASS JIT from driver version " << se::cuda::DriverVersionToString(version) - << ", which is in range [384.0.0, 384.108.0) + [387.0.0, 387.40.0) + " - "[390.0.0, 390.10.0). These versions are known to miscompile XLA " - "code, leading to incorrect results or invalid-address errors."; + << ", which is older than 396.20.0. These versions are known to " + "miscompile XLA code, leading to incorrect results or " + "invalid-address errors.\nXLA only uses the driver JIT if it " + "cannot find ptxas; you don't need to update your driver if " + "you can point XLA to ptxas 9.2.88 or newer."; } }); } @@ -473,14 +485,14 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, } // namespace -GpuCompiler::GpuCompiler() +NVPTXCompiler::NVPTXCompiler() : pointer_size_(llvm::DataLayout(kDataLayout) .getPointerSize(0 /* default address space */)) {} -StatusOr> GpuCompiler::RunHloPasses( +StatusOr> NVPTXCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) { - XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses"); + XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunHloPasses"); tracing::ScopedActivity activity("HLO Transforms", module->name(), /*is_expensive=*/true); TF_RETURN_IF_ERROR( @@ -488,10 +500,10 @@ StatusOr> GpuCompiler::RunHloPasses( return std::move(module); } -StatusOr> GpuCompiler::RunBackend( +StatusOr> NVPTXCompiler::RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) { - XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend"); + XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunBackend"); TF_RET_CHECK(stream_exec != nullptr); @@ -525,11 +537,13 @@ StatusOr> GpuCompiler::RunBackend( // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( std::unique_ptr buffer_assignment, - BufferAssigner::Run(module.get(), hlo_schedule->ConsumeHloOrdering(), - BufferSizeBytesFunction(), - /*color_alignment=*/[](LogicalBuffer::Color) { - return kCudaMallocAlignBytes; - })); + BufferAssigner::Run( + module.get(), hlo_schedule->ConsumeHloOrdering(), + BufferSizeBytesFunction(), + /*color_alignment=*/ + [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::Stats::ToString() and BufferAssignment::ToString() // include headers, so no need for us to print them ourselves. XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString()); @@ -550,10 +564,12 @@ StatusOr> GpuCompiler::RunBackend( HloComputation* entry_computation = module->entry_computation(); IrEmitterUnnested ir_emitter(module->config(), entry_computation, &ir_emitter_context); + + TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); + { - XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission"); - TF_RETURN_IF_ERROR( - entry_computation->root_instruction()->Accept(&ir_emitter)); + XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunBackend - IR emission"); + TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter)); } if (user_pre_optimization_hook_) { @@ -579,7 +595,8 @@ StatusOr> GpuCompiler::RunBackend( } { - XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - Running LLVM verifier"); + XLA_SCOPED_LOGGING_TIMER( + "NVPTXCompiler::RunBackend - Running LLVM verifier"); std::string err; llvm::raw_string_ostream err_stream(err); @@ -619,7 +636,7 @@ StatusOr> GpuCompiler::RunBackend( string ptx; { - XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - CompileToPtx"); + XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunBackend - CompileToPtx"); TF_ASSIGN_OR_RETURN(ptx, CompileToPtx(&llvm_module, {cc_major, cc_minor}, module->config(), libdevice_dir)); } @@ -688,10 +705,10 @@ StatusOr> GpuCompiler::RunBackend( return std::unique_ptr(gpu_executable); } -std::vector GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx, - int cc_major, - int cc_minor) { - XLA_SCOPED_LOGGING_TIMER("GpuCompiler::CompilePtxOrGetCachedResult"); +std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, + int cc_major, + int cc_minor) { + XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompilePtxOrGetCachedResult"); tracing::ScopedActivity activity("PTX->CUBIN", /*is_expensive=*/true); bool inserted; decltype(compilation_cache_.begin()) iter; @@ -764,12 +781,14 @@ std::vector GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx, } StatusOr>> -GpuCompiler::CompileAheadOfTime(std::vector> module, - const AotCompilationOptions& options) { - return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime"); +NVPTXCompiler::CompileAheadOfTime( + std::vector> module, + const AotCompilationOptions& options) { + return Unimplemented( + "not yet implemented: NVPTXCompiler::CompileAheadOfTime"); } -se::Platform::Id GpuCompiler::PlatformId() const { +se::Platform::Id NVPTXCompiler::PlatformId() const { return se::cuda::kCudaPlatformId; } @@ -779,7 +798,7 @@ se::Platform::Id GpuCompiler::PlatformId() const { static bool InitModule() { xla::Compiler::RegisterCompilerFactory( stream_executor::cuda::kCudaPlatformId, - []() { return xla::MakeUnique(); }); + []() { return xla::MakeUnique(); }); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h similarity index 93% rename from tensorflow/compiler/xla/service/gpu/gpu_compiler.h rename to tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index f3b02ae5d8867bdf1d970e809bff95a15d9f54d2..d4d2909f1b2dc57c3ae0f9d67067e533574369dd 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NVPTX_COMPILER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NVPTX_COMPILER_H_ #include #include @@ -37,10 +37,10 @@ namespace xla { namespace gpu { // The GPU compiler generates efficient GPU executables. -class GpuCompiler : public LLVMCompiler { +class NVPTXCompiler : public LLVMCompiler { public: - GpuCompiler(); - ~GpuCompiler() override {} + NVPTXCompiler(); + ~NVPTXCompiler() override {} // Bring in // StatusOr>> Compile( @@ -64,7 +64,7 @@ class GpuCompiler : public LLVMCompiler { se::Platform::Id PlatformId() const override; HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override { - // Capture just the pointer size, not the entire GpuCompiler object. + // Capture just the pointer size, not the entire NVPTXCompiler object. int64 pointer_size = pointer_size_; return [pointer_size](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape, pointer_size); @@ -146,10 +146,10 @@ class GpuCompiler : public LLVMCompiler { CompilationCacheHash, CompilationCacheEq> compilation_cache_ GUARDED_BY(mutex_); - TF_DISALLOW_COPY_AND_ASSIGN(GpuCompiler); + TF_DISALLOW_COPY_AND_ASSIGN(NVPTXCompiler); }; } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NVPTX_COMPILER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc new file mode 100644 index 0000000000000000000000000000000000000000..4aaf0c9e142106a0e74f319d71dad4c4c96d3f08 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h" + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace gpu { + +OutfeedManager* GetOrCreateOutfeedManager() { + static auto* manager = new OutfeedManager; + return manager; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.h b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..160ba4b691f818ff01b41b8603c11853ea12c253 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_MANAGER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_MANAGER_H_ + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/gpu/xfeed_queue.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/notification.h" + +namespace xla { +namespace gpu { + +// TODO(b/30467474) Once GPU outfeed implementation settles, consider +// folding back the cpu and gpu outfeed implementations into a generic +// one if possible. + +// Defines a buffer holding the destination for an outfeed in host memory and a +// notification when that triggers when the transfer is done. +class OutfeedBuffer { + public: + OutfeedBuffer(int64 length) : length_(length) {} + + // Waits for the device transfer to be finished. + void WaitUntilAvailable() { done_.WaitForNotification(); } + + int64 length() const { return length_; } + void set_destination(std::unique_ptr destination) { + destination_ = std::move(destination); + } + MutableBorrowingLiteral* destination() { return destination_.get(); } + + // Callback to signal that this buffer is consumed. + void Done() { done_.Notify(); } + + private: + std::unique_ptr destination_; + const int64 length_; + tensorflow::Notification done_; +}; + +// Manages a thread-safe queue of buffers. The buffers are supposed to be +// produced by the transfer manager and consumed by the device. +using OutfeedManager = XfeedQueue>*>; + +// Singleton creator-or-accessor: Returns the GPU outfeed manager. +OutfeedManager* GetOrCreateOutfeedManager(); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_MANAGER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc new file mode 100644 index 0000000000000000000000000000000000000000..b99d998c4d7df514c024b1f8d643d08c72059d0e --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc @@ -0,0 +1,107 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +OutfeedThunk::OutfeedThunk(ShapeTree outfeed_slices, + const HloInstruction* hlo_instruction) + : Thunk(Kind::kOutfeed, hlo_instruction), + outfeed_slices_(std::move(outfeed_slices)) {} + +Status OutfeedThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { + VLOG(2) << "Outfeeding from GPU: " << hlo_instruction()->ToString(); + + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); + OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager(); + ShapeTree>* outfeed_buffers = + outfeed_manager->BlockingGetNextDestination(); + + // Nothing to be done for empty tuples. + if (ShapeUtil::IsEmptyTuple(hlo_instruction()->operand(0)->shape())) { + return Status::OK(); + } + CHECK(ShapeUtil::Compatible(hlo_instruction()->operand(0)->shape(), + outfeed_buffers->shape())); + + TF_RETURN_IF_ERROR(outfeed_buffers->ForEachMutableElementWithStatus( + [&](const ShapeIndex& index, std::unique_ptr* buffer) { + if (!*buffer) { // Tuple pointers. + return Status::OK(); + } + + BufferAllocation::Slice slice = outfeed_slices_.element(index); + se::DeviceMemoryBase data_address; + if (slice.allocation()) { + // If we have a static allocation, read it from there. This avoids + // synchronizing the host and device just to read a pointer. + data_address = buffer_allocations.GetDeviceAddress(slice); + } else { + // Otherwise we have to read the tuple pointer first. + CHECK(!index.empty()); + // Copy the parent buffer to the host. + BufferAllocation::Slice tuple_slice = + outfeed_slices_.element(ShapeIndexView(index).ConsumeFront()); + if (!tuple_slice.allocation()) { + return Unimplemented( + "Nested dynamic tuples are not supported on GPU"); + } + se::DeviceMemoryBase tuple_address = + buffer_allocations.GetDeviceAddress(tuple_slice); + CHECK(tuple_slice.size() % sizeof(void*) == 0) + << "Tuple size must be a multiple of pointer size"; + std::vector tuple_element_buffer_addresses(tuple_slice.size() / + sizeof(void*)); + stream->ThenMemcpy(tuple_element_buffer_addresses.data(), + tuple_address, tuple_slice.size()); + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + // The data address is specified by the element of the tuple pointer + // buffer. + data_address = + se::DeviceMemoryBase(tuple_element_buffer_addresses[index.back()], + (*buffer)->length()); + } + + // TODO(b/111309141): Run this on a separate stream so it doesn't block + // the GPU from doing work during the transfer. This could be handled by + // making StreamAssignment do something intelligent with outfeed thunks. + stream + ->ThenMemcpy((*buffer)->destination()->untyped_data(), data_address, + (*buffer)->length()) + .ThenDoHostCallback([buffer]() { (*buffer)->Done(); }); + return Status::OK(); + })); + + Status block_status = stream->BlockHostUntilDone(); + if (!block_status.ok()) { + return InternalError("Failed to complete data transfer on stream %p: %s", + stream, block_status.error_message().c_str()); + } + + VLOG(2) << "Outfeeding from GPU complete"; + return Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h new file mode 100644 index 0000000000000000000000000000000000000000..8ed89f05f0c5bb2e3893e695d413bac3b231112d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h @@ -0,0 +1,52 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_THUNK_H_ + +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// A thunk that outfeeds data. Data must be already resident on the host. This +// thunk performs a host to device copy from the buffer allocated for the +// outfeed op to the host location. +class OutfeedThunk : public Thunk { + public: + // Constructs a OutfeedThunk that copies data to the host-side + // outfeed queue from the buffers in the given shape tree. + OutfeedThunk(ShapeTree outfeed_slices, + const HloInstruction* hlo_instruction); + + OutfeedThunk(const OutfeedThunk&) = delete; + OutfeedThunk& operator=(const OutfeedThunk&) = delete; + + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream, + HloExecutionProfiler* profiler) override; + + private: + const ShapeTree outfeed_slices_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc new file mode 100644 index 0000000000000000000000000000000000000000..79f7d31816baf0b95b967771b956a9c06ac81e91 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc @@ -0,0 +1,233 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" + +namespace xla { +namespace gpu { + +using tensorflow::gtl::ArraySlice; + +// We want the input/output feature counts of an f16 conv to be factors of 8, +// because without this cudnn can't use tensor cores on the conv. +static constexpr int64 kDesiredNumFeaturesFactor = 8; + +// We won't pad a conv if doing so increases the total number of bytes in the +// lhs, rhs, or result by more than this amount. +// +// TODO(jlebar): This number was tuned experimentally. It represents a +// compromise on our current benchmarks; it speeds some up significantly, and +// doesn't slow any down. But we can observe by changing this value that +// there's additional room for speedups. Achieving those speedups without also +// slowing other things down will likely require a more sophisticated heuristic, +// possibly some form of auto-tuning. +static constexpr double kMaxBytesTouchedIncrease = 1.2; + +// Pads the given dimensions in the given shape up to a multiple of +// kDesiredNumFeaturesFactor. +static Shape PadShape(Shape s, ArraySlice dims) { + for (int64 dim : dims) { + int64 dim_to_pad_size = s.dimensions(dim); + int64 new_dim_to_pad_size = + RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor); + s.set_dimensions(dim, new_dim_to_pad_size); + } + return s; +} + +// Creates and returns an HLO that zero-pads one or more dimensions in the given +// instruction so that its shape is equal to the given shape. +// +// Padding is added to the end of each relevant dimension. +// +// If the instruction already has the given shape, simply returns it without an +// intervening pad. +static HloInstruction* PadInstruction(HloInstruction* instr, + const Shape& new_shape) { + HloComputation* comp = instr->parent(); + + const Shape& shape = instr->shape(); + auto* zero = comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(shape.element_type()).CloneToUnique())); + + PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape)); + + bool added_padding = false; + for (int64 dim = 0; dim < ShapeUtil::Rank(shape); ++dim) { + if (shape.dimensions(dim) == new_shape.dimensions(dim)) { + continue; + } + CHECK_GT(new_shape.dimensions(dim), shape.dimensions(dim)); + pad_config.mutable_dimensions(dim)->set_edge_padding_high( + new_shape.dimensions(dim) - shape.dimensions(dim)); + added_padding = true; + } + + if (!added_padding) { + return instr; + } + return comp->AddInstruction( + HloInstruction::CreatePad(new_shape, instr, zero, pad_config)); +} + +// Pads the input/output feature dimensions of the given cudnn convolution +// custom-call to be multiples of kDesiredNumFeaturesFactor. +static StatusOr PadFeaturesDims(HloInstruction* conv) { + CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0)) + << "conv must use 0 scratch bytes, i.e. this pass must be run " + "before CudnnConvolutionAlgorithmPicker."; + + const auto& target = conv->custom_call_target(); + const auto& dnums = conv->convolution_dimension_numbers(); + auto* lhs = conv->mutable_operand(0); + auto* rhs = conv->mutable_operand(1); + const Shape& result_shape = conv->shape().tuple_shapes(0); + + Shape new_lhs_shape = [&] { + if (target == kCudnnConvForwardCallTarget || + target == kCudnnConvBackwardFilterCallTarget) { + // LHS is "input". + return PadShape(lhs->shape(), {dnums.input_feature_dimension()}); + } + CHECK_EQ(target, kCudnnConvBackwardInputCallTarget); + // LHS is "output". + return PadShape(lhs->shape(), {dnums.output_feature_dimension()}); + }(); + + Shape new_rhs_shape = [&] { + if (target == kCudnnConvForwardCallTarget || + target == kCudnnConvBackwardInputCallTarget) { + // RHS is "filter". + return PadShape(rhs->shape(), {dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension()}); + } + CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget); + // RHS is "output". + return PadShape(rhs->shape(), {dnums.output_feature_dimension()}); + }(); + + if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) && + ShapeUtil::Equal(rhs->shape(), new_rhs_shape)) { + VLOG(3) << "No need to pad features of " << conv->ToString(); + return false; + } + + Shape new_result_shape = [&] { + if (target == kCudnnConvForwardCallTarget) { + // Result is "output". + return PadShape(result_shape, {dnums.output_feature_dimension()}); + } + if (target == kCudnnConvBackwardInputCallTarget) { + // Result is "input". + return PadShape(result_shape, {dnums.input_feature_dimension()}); + } + CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget); + // Result is "filter". + return PadShape(result_shape, {dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension()}); + }(); + + // Check that padding wouldn't increase the total bytes read/written by this + // operation too much. + auto check_size_increase = [&](const Shape& old_shape, + const Shape& new_shape) { + int64 old_bytes = ShapeUtil::ByteSizeOf(old_shape); + int64 new_bytes = ShapeUtil::ByteSizeOf(new_shape); + if (new_bytes <= old_bytes * kMaxBytesTouchedIncrease) { + return true; + } + VLOG(3) << "Not padding convolution; doing so would change input / result " + "shape from " + << ShapeUtil::HumanString(old_shape) << " to " + << ShapeUtil::HumanString(new_shape) << ", a size increase of " + << new_bytes / static_cast(old_bytes) << "x > " + << kMaxBytesTouchedIncrease << "x: " << conv->ToString(); + return false; + }; + if (!check_size_increase(lhs->shape(), new_lhs_shape) || + !check_size_increase(rhs->shape(), new_rhs_shape) || + !check_size_increase(result_shape, new_result_shape)) { + return false; + } + + // OK, let's do the transformation! + + auto* new_lhs = PadInstruction(lhs, new_lhs_shape); + auto* new_rhs = PadInstruction(rhs, new_rhs_shape); + CHECK(new_lhs != lhs || new_rhs != rhs) + << "We should have had to pad either LHS or RHS."; + + auto add = [&](std::unique_ptr new_instr) { + return conv->parent()->AddInstruction(std::move(new_instr)); + }; + + Shape new_conv_shape = ShapeUtil::MakeTupleShape( + {new_result_shape, ShapeUtil::MakeShape(U8, {0})}); + auto* new_conv = + add(conv->CloneWithNewOperands(new_conv_shape, {new_lhs, new_rhs})); + + // Slice the new conv result if necessary, keeping in mind that new_conv has + // tuple shape (new_result_shape, u8[0]). + if (!ShapeUtil::Equal(result_shape, new_result_shape)) { + std::vector start_indices(result_shape.dimensions_size(), 0); + std::vector end_indices(result_shape.dimensions().begin(), + result_shape.dimensions().end()); + std::vector strides(result_shape.dimensions_size(), 1); + + auto* new_conv_result = add( + HloInstruction::CreateGetTupleElement(new_result_shape, new_conv, 0)); + auto* empty_temp_buffer = + add(HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); + auto* sliced_result = add(HloInstruction::CreateSlice( + result_shape, new_conv_result, start_indices, end_indices, strides)); + new_conv = + add(HloInstruction::CreateTuple({sliced_result, empty_temp_buffer})); + } + + VLOG(2) << "Padded features of " << conv->ToString() << ", replaced with " + << new_conv->ToString(); + TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(conv, new_conv)); + return true; +} + +static std::vector GetRelevantConvs(HloComputation* comp) { + std::vector convs; + for (HloInstruction* instr : comp->instructions()) { + if (IsCustomCallToDnnConvolution(*instr) && + instr->operand(0)->shape().element_type() == F16) { + convs.push_back(instr); + } + } + return convs; +} + +StatusOr PadForTensorCores::Run(HloModule* module) { + bool changed = false; + for (HloComputation* comp : module->MakeNonfusionComputations()) { + for (HloInstruction* conv : GetRelevantConvs(comp)) { + TF_ASSIGN_OR_RETURN(bool result, PadFeaturesDims(conv)); + changed |= result; + } + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h new file mode 100644 index 0000000000000000000000000000000000000000..192359f026bfb2f1d5436713e4a30725fa0ad6ba --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Ensures that f16 cudnn convolutions have input/output channel dimensions that +// are multiples of 8, inserting pads/slices as necessary. +// +// This is useful primarily for Volta and newer GPUs, where tensor cores can +// only be used if the channel dims are multiples of 8. It's probably the +// opposite of useful on other GPUs, so you should check what GPU you're +// targeting before running this pass. +// +// TODO(jlebar): Also pad dots. +class PadForTensorCores : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { + return "pad for tensor cores"; + } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_ diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..99e7580b826fc5cd6d98a037a5eb064552952e18 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc @@ -0,0 +1,164 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h" + +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace gpu { +namespace { + +namespace op = xla::testing::opcode_matchers; +using ::testing::_; + +using PadForTensorCoresTest = HloVerifiedTestBase; + +TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,41] parameter(0) + filter = f16[2,2,41,40] parameter(1) + ROOT result = (f16[10,20,30,40], u8[0]) custom-call(input, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convForward" + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + + SCOPED_TRACE(module().ToString()); + EXPECT_THAT(root, op::CustomCall(kCudnnConvForwardCallTarget, + op::Pad(op::Parameter(0), _), + op::Pad(op::Parameter(1), _))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(), + ShapeUtil::MakeShape(F16, {10, 20, 30, 48}))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(), + ShapeUtil::MakeShape(F16, {2, 2, 48, 40}))); +} + +TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + output = f16[10,20,30,41] parameter(0) + filter = f16[2,2,40,41] parameter(1) + ROOT result = (f16[10,20,30,40], u8[0]) custom-call(output, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convBackwardInput" + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::CustomCall(kCudnnConvBackwardInputCallTarget, + op::Pad(op::Parameter(0), _), + op::Pad(op::Parameter(1), _))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(), + ShapeUtil::MakeShape(F16, {10, 20, 30, 48}))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(), + ShapeUtil::MakeShape(F16, {2, 2, 40, 48}))); +} + +TEST_F(PadForTensorCoresTest, PadF16ForwardConvOutputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,40] parameter(0) + filter = f16[2,2,40,41] parameter(1) + ROOT result = (f16[10,20,30,41], u8[0]) custom-call(input, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convForward" + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvForwardCallTarget, op::Parameter(0), + op::Pad(op::Parameter(1), _)))), + _)); +} + +TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvInputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + output = f16[10,20,30,40] parameter(0) + filter = f16[2,2,41,40] parameter(1) + result = (f16[10,20,30,41], u8[0]) custom-call(output, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convBackwardInput" + ROOT gte = f16[10,20,30,41] get-tuple-element(result), index=0 + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::GetTupleElement(op::Tuple( + op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvBackwardInputCallTarget, op::Parameter(0), + op::Pad(op::Parameter(1), _)))), + _))); +} + +TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,41] parameter(0) + output = f16[10,20,30,40] parameter(1) + result = (f16[2,2,41,40], u8[0]) custom-call(input, output), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convBackwardFilter" + ROOT gte = f16[2,2,41,40] get-tuple-element(result), index=0 + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::GetTupleElement(op::Tuple( + op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvBackwardFilterCallTarget, + op::Pad(op::Parameter(0), _), op::Parameter(1)))), + _))); +} + +TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) { + ParseAndVerifyModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,40] parameter(0) + output = f16[10,20,30,41] parameter(1) + result = (f16[2,2,40,41], u8[0]) custom-call(input, output), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convBackwardFilter" + ROOT gte = f16[2,2,40,41] get-tuple-element(result), index=0 + })"); + EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); + auto* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::GetTupleElement(op::Tuple( + op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvBackwardFilterCallTarget, + op::Parameter(0), op::Pad(op::Parameter(1), _)))), + _))); +} + +} // anonymous namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index c8f0d4185c63c5bafca6f30acab31cbe8e987277..b22040eee167e784bed58dbc0d0ad2ae042037f3 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" @@ -68,7 +69,7 @@ HloInstruction* MaybePaddedAndSlicedInput( PrimitiveType element_type = input->shape().element_type(); HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(Literal::Zero(element_type)))); + MakeUnique(LiteralUtil::Zero(element_type)))); input = MakePadHlo(input, padding, padding_config).ValueOrDie(); } @@ -125,7 +126,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, PrimitiveType element_type = kernel->shape().element_type(); HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(Literal::Zero(element_type)))); + MakeUnique(LiteralUtil::Zero(element_type)))); return MakePadHlo(kernel, padding, padding_config).ValueOrDie(); } } // namespace @@ -234,9 +235,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // Create a new backward convolution replacing the old one. HloComputation* computation = backward_conv->parent(); HloInstruction* output = backward_conv->mutable_operand(1); - HloInstruction* padding = - computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(Literal::Zero(input->shape().element_type())))); + HloInstruction* padding = computation->AddInstruction( + HloInstruction::CreateConstant(MakeUnique( + LiteralUtil::Zero(input->shape().element_type())))); HloInstruction* padded_input = MakePadHlo(input, padding, input_padding_config).ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index cd833ec7bd858aabee84ac306d198e80eb112506..3838fee674566196e10ddd98462c1a1aa7835e1a 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -32,27 +32,27 @@ namespace gpu { ParallelLoopEmitter::ParallelLoopEmitter( BodyEmitter body_emitter, const Shape& shape, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder, + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b, int unroll_factor) - : LoopEmitter(body_emitter, shape, ir_builder), + : LoopEmitter(body_emitter, shape, b), launch_dimensions_(launch_dimensions), unroll_factor_(unroll_factor) {} ParallelLoopEmitter::ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice target_arrays, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder, + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b, int unroll_factor) - : LoopEmitter(target_element_generator, target_arrays, ir_builder), + : LoopEmitter(target_element_generator, target_arrays, b), launch_dimensions_(launch_dimensions), unroll_factor_(unroll_factor) {} ParallelLoopEmitter::ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, const llvm_ir::IrArray& target_array, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder, + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b, int unroll_factor) - : LoopEmitter(target_element_generator, target_array, ir_builder), + : LoopEmitter(target_element_generator, target_array, b), launch_dimensions_(launch_dimensions), unroll_factor_(unroll_factor) {} @@ -74,29 +74,27 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( CHECK_NE(index_type, nullptr); std::vector array_indices; llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, ir_builder_); + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_count(), static_cast(block_id)); - block_id = ir_builder_->CreateZExtOrTrunc(block_id, index_type, "block_id"); + block_id = b_->CreateZExtOrTrunc(block_id, index_type, "block_id"); // Per the PTX documentation: // "It is guaranteed that [...] 0 <= %tid.x < %ntid.x" // // %ntid.x is currently specified as 1024. llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, ir_builder_); + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.threads_per_block(), static_cast(thread_id)); - thread_id = - ir_builder_->CreateZExtOrTrunc(thread_id, index_type, "thread_id"); - - llvm::Value* linear_index_base = ir_builder_->CreateAdd( - ir_builder_->CreateMul( - block_id, - llvm::ConstantInt::get(index_type, - launch_dimensions_.threads_per_block()), - "", - /*HasNUW=*/true, /*HasNSW=*/true), + thread_id = b_->CreateZExtOrTrunc(thread_id, index_type, "thread_id"); + + llvm::Value* linear_index_base = b_->CreateAdd( + b_->CreateMul(block_id, + llvm::ConstantInt::get( + index_type, launch_dimensions_.threads_per_block()), + "", + /*HasNUW=*/true, /*HasNSW=*/true), thread_id, "linear_index", /*HasNUW=*/true, /*HasNSW=*/true); // Add an @llvm.assume(linear_index < threads_per_block * num_blocks). @@ -109,41 +107,41 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( // conditions in the same basic block as their operands. llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::assume, - {ir_builder_->CreateICmpULT( + {b_->CreateICmpULT( linear_index_base, llvm::ConstantInt::get(index_type, launch_dimensions_.threads_per_block() * launch_dimensions_.block_count()), "linear_index_in_range")}, - {}, ir_builder_); + {}, b_); if (unroll_factor_ > 1) { - linear_index_base = ir_builder_->CreateMul( + linear_index_base = b_->CreateMul( linear_index_base, llvm::ConstantInt::get(index_type, unroll_factor_), "linear_index_base", /*HasNUW=*/true, /*HasNSW=*/true); } - array_indices.emplace_back(linear_index_base, shape_, ir_builder_); + array_indices.emplace_back(linear_index_base, shape_, b_); for (int i = 1; i < unroll_factor_; ++i) { - llvm::Value* linear_index = ir_builder_->CreateAdd( - linear_index_base, llvm::ConstantInt::get(index_type, i), - "linear_index", - /*HasNUW=*/true, /*HasNSW=*/true); - array_indices.emplace_back(linear_index, shape_, ir_builder_); + llvm::Value* linear_index = + b_->CreateAdd(linear_index_base, llvm::ConstantInt::get(index_type, i), + "linear_index", + /*HasNUW=*/true, /*HasNSW=*/true); + array_indices.emplace_back(linear_index, shape_, b_); } auto if_in_bounds = llvm_ir::EmitIfThenElse( - ir_builder_->CreateICmpULT( + b_->CreateICmpULT( linear_index_base, llvm::ConstantInt::get(index_type, ShapeUtil::ElementsIn(shape_))), - llvm_ir::IrName(loop_name, "in_bounds"), ir_builder_, false); + llvm_ir::IrName(loop_name, "in_bounds"), b_, false); // Set exit_bb_ to the exit block of the if structure. exit_bb_ = if_in_bounds.after_block; CHECK_NE(nullptr, exit_bb_); // Set IR builder insertion point to the body of the if structure. - llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, b_); return array_indices; } diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index 302e1bf1bc8e90f2eebd838f156a1552e86185ac..b82a23419df08cafdc69b6d2f14528484b95dc73 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -34,13 +34,13 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { // The meanings of other parameters are the same as LoopEmitter. ParallelLoopEmitter(BodyEmitter body_emitter, const Shape& shape, const LaunchDimensions& launch_dimensions, - llvm::IRBuilder<>* ir_builder, int unroll_factor = 1); + llvm::IRBuilder<>* b, int unroll_factor = 1); // Constructs a ParallelLoopEmitter from an element generator that generates // each element of the given target array. ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator, const llvm_ir::IrArray& target_array, const LaunchDimensions& launch_dimensions, - llvm::IRBuilder<>* ir_builder, int unroll_factor = 1); + llvm::IRBuilder<>* b, int unroll_factor = 1); // Constructs a loop emitter for a loop that generates on element of each of N // arrays on each iteration. @@ -50,7 +50,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice target_arrays, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder, + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b, int unroll_factor = 1); ParallelLoopEmitter(const ParallelLoopEmitter&) = delete; diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc index 88cb10883e97ae663dc492ad088e6daf9133d7f5..84285be70a4ba94101040a639c39b3eaecbb5bb3 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { @@ -33,9 +34,12 @@ Status SequentialThunk::Initialize(const GpuExecutable& executable, } Status SequentialThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream, + HloExecutionProfiler* profiler) { + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); for (const auto& thunk : thunks_) { - TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); + TF_RETURN_IF_ERROR( + thunk->ExecuteOnStream(buffer_allocations, stream, profiler)); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h index 135f79e413dfaa27f2f2264e0daa3beb3c305e0f..3c4de1d1a6c912ba31f56c29b10ca004d1e56da6 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -41,7 +42,8 @@ class SequentialThunk : public Thunk { Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: // The list of sub-thunks. diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc index e4cfc6999f2da04dd7e7a34d854fdb3d75b8bfc6..0806dd51614f4d2da12f3fbbc9fb98df5273d5c8 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc @@ -33,13 +33,13 @@ int StreamAssignment::StreamNumberForHlo(const HloInstruction& hlo) const { } void StreamAssignment::AssignStreamToHlo(const HloInstruction* hlo, - int stream_no) { - CHECK_GE(stream_no, 0); - if (stream_no >= stream_count_) { - stream_count_ = stream_no + 1; + int stream_num) { + CHECK_GE(stream_num, 0); + if (stream_num >= stream_count_) { + stream_count_ = stream_num + 1; } - InsertOrDie(&hlo_to_stream_number_, hlo, stream_no); - VLOG(2) << "Assign stream #" << stream_no << " to " << hlo->ToString(); + InsertOrDie(&hlo_to_stream_number_, hlo, stream_num); + VLOG(2) << "Assign stream #" << stream_num << " to " << hlo->ToString(); } namespace { @@ -51,6 +51,12 @@ bool CanRunConcurrently(const HloInstruction& a, const HloInstruction& b, return !reachability.IsConnected(&a, &b); } +constexpr int kInvalidStreamNum = -1; +// Returns true iff `stream_num` is an invalid stream number. +inline bool IsStreamNumValid(int stream_num) { + return stream_num != kInvalidStreamNum; +} + // Returns which existing stream to assign to `hlo`, or -1 if a stream is not // needed. `stream_assignment` is the existing stream assignment for all // instructions topologically before `hlo`. `seen_gemms` contains all GEMMs that @@ -62,7 +68,7 @@ int ComputeStreamToAssign( if (hlo.opcode() == HloOpcode::kParameter || hlo.opcode() == HloOpcode::kConstant) { // kParameter and kConstant do not need a thunk. - return -1; + return kInvalidStreamNum; } if (hlo.GetModule() @@ -75,17 +81,17 @@ int ComputeStreamToAssign( if (!ImplementedAsGemm(hlo)) { // If `hlo` is not implemented as a GEMM, keep it close to its operands to // avoid excessive synchronization. - int stream_no = -1; + int stream_num = -1; for (const auto* operand : hlo.operands()) { if (stream_assignment.HasStreamAssigned(*operand)) { - stream_no = - std::max(stream_no, stream_assignment.StreamNumberForHlo(*operand)); + stream_num = std::max(stream_num, + stream_assignment.StreamNumberForHlo(*operand)); } } - if (stream_no == -1) { - stream_no = 0; + if (!IsStreamNumValid(stream_num)) { + stream_num = 0; } - return stream_no; + return stream_num; } // Assign different streams to concurrent GEMMs. The code below uses a @@ -94,17 +100,17 @@ int ComputeStreamToAssign( // `hlo` a different stream. std::set forbidden_stream_numbers; for (const auto* seen_gemm : seen_gemms) { - int stream_no = stream_assignment.StreamNumberForHlo(*seen_gemm); - if (!forbidden_stream_numbers.count(stream_no) && + int stream_num = stream_assignment.StreamNumberForHlo(*seen_gemm); + if (!forbidden_stream_numbers.count(stream_num) && CanRunConcurrently(*seen_gemm, hlo, reachability)) { - forbidden_stream_numbers.insert(stream_no); + forbidden_stream_numbers.insert(stream_num); } } - for (int stream_no = 0; stream_no < stream_assignment.StreamCount(); - ++stream_no) { - if (!forbidden_stream_numbers.count(stream_no)) { - return stream_no; + for (int stream_num = 0; stream_num < stream_assignment.StreamCount(); + ++stream_num) { + if (!forbidden_stream_numbers.count(stream_num)) { + return stream_num; } } return stream_assignment.StreamCount(); @@ -118,11 +124,27 @@ std::unique_ptr AssignStreams(const HloModule& module) { std::unique_ptr reachability = computation.ComputeReachability(); std::vector seen_gemms; + // The execution of different RNG Hlo instructions in the same module updates + // a common global variable. To avoid a race condition, we simply assign all + // RNG kernels to the same stream to make them run sequentially. + // + // TODO(b/111791052): If we remove such a common variable, we will need to + // clean up the code here. + int stream_num_for_rng = kInvalidStreamNum; for (const auto* hlo : computation.MakeInstructionPostOrder()) { - int stream_no = ComputeStreamToAssign(*hlo, *stream_assignment, - *reachability, seen_gemms); - if (stream_no != -1) { - stream_assignment->AssignStreamToHlo(hlo, stream_no); + // If we ever enable fusion of RNG instructions, we will need to extend this + // code to look inside a fused instruction. + int stream_num = (hlo->opcode() == HloOpcode::kRng && + IsStreamNumValid(stream_num_for_rng)) + ? stream_num_for_rng + : ComputeStreamToAssign(*hlo, *stream_assignment, + *reachability, seen_gemms); + if (IsStreamNumValid(stream_num)) { + stream_assignment->AssignStreamToHlo(hlo, stream_num); + if (hlo->opcode() == HloOpcode::kRng && + !IsStreamNumValid(stream_num_for_rng)) { + stream_num_for_rng = stream_num; + } } if (ImplementedAsGemm(*hlo)) { seen_gemms.push_back(hlo); diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc index a50ddf6ac63c7fa7ccace94bc7f40f438aedccf8..05b305ea4cdfdbaeb42544b626a6b9990bb42f57 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -20,10 +20,17 @@ limitations under the License. namespace xla { namespace gpu { -using stream_executor::dnn::DataLayout; -using stream_executor::dnn::DataLayoutString; -using stream_executor::dnn::FilterLayout; -using stream_executor::dnn::FilterLayoutString; +using se::dnn::DataLayout; +using se::dnn::DataLayoutString; +using se::dnn::FilterLayout; +using se::dnn::FilterLayoutString; + +bool IsVoltaOrLater(const se::StreamExecutor& stream_executor) { + int major, minor; + CHECK(stream_executor.GetDeviceDescription().cuda_compute_capability(&major, + &minor)); + return major >= 7; +} StatusOr> StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h index 39a6a38d001f502b2abb8de6efe2ce623b478c71..1fc46bafa10e7ba6c896f081d5c836bd400886c9 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -25,18 +26,20 @@ limitations under the License. namespace xla { namespace gpu { +// Returns true if the given StreamExecutor is for a Volta or newer nvidia GPU. +bool IsVoltaOrLater(const se::StreamExecutor& stream_exec); + // Returns (input, filter, output) XLA Layout protos given the StreamExecutor // layouts. StatusOr> StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, - stream_executor::dnn::DataLayout input, - stream_executor::dnn::FilterLayout filter, - stream_executor::dnn::DataLayout output); + se::dnn::DataLayout input, + se::dnn::FilterLayout filter, + se::dnn::DataLayout output); // Returns (input, filter, output) StreamExecutor layouts given the XLA layouts. -StatusOr> +StatusOr< + std::tuple> XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, const Layout& input, const Layout& filter, const Layout& output); diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..4fad3f46cf953945e4f395e751e5ba76db97ecc4 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -0,0 +1,223 @@ +# Description: GPU-specific XLA tests. For example, codegen tests that +# verify the IR emitted. +# +# TODO(jlebar): None of these tests actually use the GPU, so they should not +# need to run on machines with GPUs present. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = [":friends"]) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +cc_library( + name = "gpu_codegen_test", + testonly = True, + srcs = ["gpu_codegen_test.cc"], + hdrs = ["gpu_codegen_test.h"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service/gpu:gpu_executable", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "gpu_copy_test", + srcs = ["gpu_copy_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_ftz_test", + srcs = ["gpu_ftz_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_index_test", + srcs = ["gpu_index_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_infeed_test", + srcs = ["infeed_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_kernel_tiling_test", + srcs = ["gpu_kernel_tiling_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_ldg_test", + srcs = ["gpu_ldg_test.cc"], + tags = ["requires-gpu-sm35"], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_noalias_test", + srcs = ["gpu_noalias_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_fusion_test", + srcs = ["gpu_fusion_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_unrolling_test", + srcs = ["gpu_unrolling_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "gpu_alignment_test", + testonly = True, + srcs = ["gpu_alignment_test.cc"], + tags = [ + "requires-gpu-sm35", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", + "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_alignment_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_alignment_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..672c68e59b59dff19f0c5575db26dea455c45053 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_alignment_test.cc @@ -0,0 +1,54 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuAlignmentTest : public GpuCodegenTest {}; + +TEST_F(GpuAlignmentTest, Test) { + const char* hlo_string = R"( +HloModule GpuAlignmentTest + +ENTRY main { + zero = f32[] constant(0) + tok = token[] after-all() + a = f32[100] parameter(0) + b_tup = (f32[200], token[]) infeed(tok) + b = f32[200] get-tuple-element(b_tup), index=0 + a_padded = f32[150] pad(a, zero), padding=0_50 + b_sliced = f32[150] slice(b), slice={[0:150]} + ROOT c = f32[150] add(a_padded, b_sliced) +} +)"; + + CompileAndVerifyIr(hlo_string, R"( +CHECK: @fusion(i8* align 64 dereferenceable(600) %alloc0, i8* align 16 dereferenceable(400) %alloc1, i8* align 64 dereferenceable(864) %temp_buf) +)"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4b8415fe9106137e588f345a3492f93e46aeb5b6 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace gpu { + +std::unique_ptr GpuCodegenTest::CreateNewModuleWithFTZ(bool ftz) { + HloModuleConfig config; + auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); + debug_options.set_xla_gpu_ftz(ftz); + debug_options.set_xla_gpu_max_kernel_unroll_factor(1); + // TODO(b/38354253): Change tests to use Parameters instead of Constants. + debug_options.add_xla_disable_hlo_passes("constant_folding"); + config.set_debug_options(debug_options); + + return MakeUnique(TestName(), config); +} + +void GpuCodegenTest::CompileAndVerifyPtx(std::unique_ptr hlo_module, + const string& pattern) { + std::unique_ptr executable = + std::move(CompileToExecutable(std::move(hlo_module)).ValueOrDie()); + string ptx_str = + std::string(static_cast(executable.get())->ptx()); + StatusOr filecheck_result = RunFileCheck(ptx_str, pattern); + ASSERT_TRUE(filecheck_result.ok()); + EXPECT_TRUE(filecheck_result.ValueOrDie()); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h new file mode 100644 index 0000000000000000000000000000000000000000..e4a3573babb7ed746504c1466f85b582aa4d044f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TESTS_GPU_CODEGEN_TEST_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TESTS_GPU_CODEGEN_TEST_H_ + +#include + +#include "tensorflow/compiler/xla/tests/llvm_irgen_test_base.h" + +namespace xla { +namespace gpu { + +// Tests that verify IR or PTX emitted by the GPU backend is as expected. +class GpuCodegenTest : public LlvmIrGenTestBase { + protected: + // Like HloTestBase::CreateNewModule(), with a flag for configuring the ftz + // option. + std::unique_ptr CreateNewModuleWithFTZ(bool ftz); + + // Compiles the given HLO module to PTX and verifies the PTX matches the given + // FileCheck pattern. (See http://llvm.org/docs/CommandGuide/FileCheck.html). + void CompileAndVerifyPtx(std::unique_ptr hlo_module, + const string& pattern); +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TESTS_GPU_CODEGEN_TEST_H_ diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ce69e058e64aab1f3c292b2ad7c7b529d4666b35 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { + +class GpuCopyTest : public GpuCodegenTest {}; + +// The GPU backend should not emit a copy kernel for the kCopy instruction in +// this test. Instead, it should generate a CopyThunk which invokes cuMemcpy at +// runtime. +TEST_F(GpuCopyTest, UseMemcpy) { + HloComputation::Builder builder(TestName()); + + std::unique_ptr literal = + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kCopy, constant)); + + std::unique_ptr computation = builder.Build(); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(std::move(computation)); + + // There should not be any kernel prefixed "copy". + CompileAndVerifyIr(std::move(hlo_module), "; CHECK-NOT: define void @_copy", + /*match_optimized_ir=*/false); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..177b94934c7f519172508b5cc6e088f908401193 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc @@ -0,0 +1,119 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" + +// Check that the ftz (flush denormals to zero) flag is reflected in PTX as +// expected. + +namespace xla { +namespace gpu { +namespace { + +class GpuFtzTest : public GpuCodegenTest { + public: + explicit GpuFtzTest(bool ftz) : ftz_(ftz) {} + + // Creates an HLO module that performs the given binary operation on some + // data. + std::unique_ptr CreateBinaryOpModule(HloOpcode op) { + HloComputation::Builder builder(TestName()); + + Shape param_shape = ShapeUtil::MakeShapeWithLayout( + F32, /*dimensions=*/{100, 100}, /*minor_to_major=*/{1, 0}); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /* parameter_number=*/0, param_shape, "x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /* parameter_number=*/1, param_shape, "y")); + builder.AddInstruction(HloInstruction::CreateBinary(param_shape, op, x, y)); + + auto hlo_module = CreateNewModuleWithFTZ(ftz_); + hlo_module->AddEntryComputation(builder.Build()); + return hlo_module; + } + + // Creates an HLO module that performs the given unary operation on some data. + std::unique_ptr CreateUnaryOpModule(HloOpcode op) { + HloComputation::Builder builder(TestName()); + + Shape param_shape = ShapeUtil::MakeShapeWithLayout( + F32, /*dimensions=*/{100, 100}, /*minor_to_major=*/{1, 0}); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /* parameter_number=*/0, param_shape, "x")); + builder.AddInstruction(HloInstruction::CreateUnary(param_shape, op, x)); + + auto hlo_module = CreateNewModuleWithFTZ(ftz_); + hlo_module->AddEntryComputation(builder.Build()); + return hlo_module; + } + + bool ftz_; +}; + +class GpuFtzEnabledTest : public GpuFtzTest { + public: + GpuFtzEnabledTest() : GpuFtzTest(/*ftz=*/true) {} +}; + +class GpuFtzDisabledTest : public GpuFtzTest { + public: + GpuFtzDisabledTest() : GpuFtzTest(/*ftz=*/false) {} +}; + +// Check that we emit mul.ftz.f32 when in ftz mode, and plain mul.f32 otherwise. +TEST_F(GpuFtzEnabledTest, MultiplyFtz) { + CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"( + CHECK-NOT: mul.f32 + CHECK: mul.ftz.f32 + CHECK-NOT: mul.f32 + )"); +} +TEST_F(GpuFtzDisabledTest, MultiplyFtz) { + CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"( + CHECK-NOT: mul.ftz.f32 + CHECK: mul.f32 + CHECK-NOT: mul.ftz.f32 + )"); +} + +// In NVPTX, exp(float) is implemented in libdevice, and consults __nvvm_reflect +// to determine whether or not ftz is enabled. The implementation uses two +// calls to ex2.approx. When ftz is on, we get two calls to the ftz version; +// when ftz is off, we get one call to the ftz version and one call to the +// regular version. +TEST_F(GpuFtzEnabledTest, ExpFtz) { + CompileAndVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"( + CHECK-NOT: ex2.approx.f32 + CHECK: ex2.approx.ftz.f32 + CHECK-NOT: ex2.approx.f32 + CHECK: ex2.approx.ftz.f32 + CHECK-NOT: ex2.approx.f32 + CHECK-NOT: ex2.approx.ftz.f32 + )"); +} + +TEST_F(GpuFtzDisabledTest, ExpFtz) { + CompileAndVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"( + CHECK-NOT: ex2.approx.f32 + CHECK-DAG: ex2.approx.ftz.f32 + CHECK-DAG: ex2.approx.f32 + CHECK-NOT: ex2.approx.f32 + CHECK-NOT: ex2.approx.ftz.f32 + )"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_fusion_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..674b436a8e3135a5dfe3731647897696bf1321cd --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_fusion_test.cc @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuFusionTest : public GpuCodegenTest {}; + +TEST_F(GpuFusionTest, FusedReshape) { + const char* hlo_text = R"( + HloModule test_module + + fused_computation { + p0.param_0 = f32[4,1,1]{2,1,0} parameter(0) + p1.param_1 = f32[4,1]{1,0} parameter(1) + reshape = f32[4,1]{1,0} reshape(p0.param_0) + ROOT add = f32[4,1] add(reshape, p1.param_1) + } + + ENTRY BroadcastIntoAdd { + p0 = f32[4,1,1]{2,1,0} parameter(0) + p1 = f32[4,1]{1,0} parameter(1) + ROOT fusion = f32[4,1]{1,0} fusion(p0, p1), kind=kLoop, + calls=fused_computation + } +)"; + + CompileAndVerifyIr(hlo_text, + R"( +; CHECK-LABEL: @fusion +; CHECK: fadd +; CHECK: } + )"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e5958165eff21d82faf821213e50fe30a11059a4 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc @@ -0,0 +1,147 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { + +// This file tests the index expressions used to reference source tensors. When +// the destination tensor and source tensor have compatible shapes, the linear +// index is used to access the source tensor. Otherwise, dimensional indices +// computed from the linear index are used to access the source tensor. + +class GpuIndexTest : public GpuCodegenTest {}; + +TEST_F(GpuIndexTest, CompatibleUseLinearIndex) { + HloComputation::Builder builder(TestName()); + + auto param_shape = ShapeUtil::MakeShape(F32, {5, 7, 2}); + HloInstruction* param_x = builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "x")); + HloInstruction* param_y = builder.AddInstruction( + HloInstruction::CreateParameter(1, param_shape, "y")); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {5, 7, 2}), HloOpcode::kGe, param_x, param_y)); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(builder.Build()); + + // Check the optimized IR as the unoptimized IR contains dead udiv and urem. + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-NOT: udiv +; CHECK-NOT: urem + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshape) { + HloModuleConfig config; + config.set_debug_options(HloTestBase::GetDebugOptionsForTest()); + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY CompatibleUseLinearIndexWithReshape { + x = f32[5,7,2]{2,1,0} parameter(0) + y = f32[5,14]{1,0} parameter(1) + reshape = f32[5,7,2]{2,1,0} reshape(y) + ROOT gte = pred[5,7,2]{2,1,0} greater-than-or-equal-to(x, reshape) + })", + config) + .ValueOrDie(); + + // Check the optimized IR as the unoptimized IR contains dead udiv and urem. + CompileAndVerifyIr(std::move(module), + R"( +; CHECK-NOT: udiv +; CHECK-NOT: urem + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithReshapeAndBroadcast) { + HloModuleConfig config; + config.set_debug_options(HloTestBase::GetDebugOptionsForTest()); + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY CompatibleUseLinearIndexWithReshape { + x = f32[5,7,2]{2,1,0} parameter(0) + y = f32[14]{0} parameter(1) + reshape = f32[7,2]{1,0} reshape(y) + broadcast = f32[5,7,2]{2,1,0} broadcast(reshape), dimensions={1,2} + ROOT gte = pred[5,7,2]{2,1,0} greater-than-or-equal-to(x, broadcast) + })", + config) + .ValueOrDie(); + + // Check the optimized IR reuses the linear index by calculating modulo 14. + CompileAndVerifyIr(std::move(module), + R"( +; CHECK: %[[urem1:.*]] = urem i{{[0-9]*}} %[[linear_index:.*]], 14 +; CHECK: %[[bitcast:.*]] = bitcast i8 addrspace(1)* %[[alloc:.*]] to float addrspace(1)* +; CHECK: %[[idx1:.*]] = zext i{{[0-9]*}} %[[urem1]] to i64 +; CHECK: getelementptr inbounds float, float addrspace(1)* %[[bitcast]], i64 %[[idx1]] + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuIndexTest, CompatibleUseLinearIndexWithSizeOneDimensions) { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_max_kernel_unroll_factor(1); + config.set_debug_options(debug_options); + + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY CompatibleUseLinearIndexWithSizeOneDimensions { + x = f32[1,1024,1,256]{3,2,1,0} parameter(0) + ROOT y = f16[1,1024,1,256]{2,3,1,0} convert(x) + })", + config) + .ValueOrDie(); + + // Check that the unoptimized IR reuses the linear index. + CompileAndVerifyIr(std::move(module), + R"( +; CHECK-LABEL: @fusion +; CHECK: udiv i32 %[[linear_index:.*]], 262144 +; CHECK: %[[ld_addr:.*]] = getelementptr inbounds float, float* {{.*}}, i32 %[[linear_index]] +; CHECK: load float, float* %[[ld_addr]] +; CHECK: %[[st_addr:.*]] = getelementptr inbounds half, half* {{.*}}, i32 %[[linear_index]] +; CHECK: store half {{.*}}, half* %[[st_addr]] + )", + /*match_optimized_ir=*/false); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cca35316f0c472d2a17c466f8cd1af7f22575a8b --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -0,0 +1,177 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuKernelTilingTest : public GpuCodegenTest { + protected: + GpuKernelTilingTest() { + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + config_.set_debug_options(debug_options); + // Disable layout_assignment to use the preassigned layouts. + debug_options.add_xla_disable_hlo_passes("layout_assignment"); + } + HloModuleConfig config_; +}; + +TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) { + const char *const kHloString = R"( + HloModule unnested_transpose_1 + + ENTRY unnested_transpose_1 { + para0 = f16[32,3,64]{2,1,0} parameter(0) + ROOT copy1 = f16[32,3,64]{1,0,2} copy(para0) + })"; + + // Check that a call to llvm.nvvm.barrier0 is generated. + auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @copy +; CHECK: tail call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); + + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0})); +} + +TEST_F(GpuKernelTilingTest, UnnestedTransposeWithSmallDimensionsNotTiled) { + const char *const kHloString = R"( + HloModule unnested_transpose_2 + + ENTRY unnested_transpose_2 { + para0 = f16[2,3,64]{2,1,0} parameter(0) + ROOT copy1 = f16[2,3,64]{1,0,2} copy(para0) + })"; + + // Check that a call to llvm.nvvm.barrier0 is not generated. + auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @copy +; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) { + const char *const kHloString = R"( + HloModule multiple_output_fusion_1 + fused_computation.1 { + param0 = f32[4,5,6,7,8]{4,3,2,1,0} parameter(0) + copy = f32[4,5,6,7,8]{2,1,4,3,0} copy(param0) + ROOT convert = f16[4,5,6,7,8]{2,1,4,3,0} convert(copy) + } + + ENTRY copy_in_fusion_run_without_hlo_passes { + para0 = f32[4,5,6,7,8]{4,3,2,1,0} parameter(0) + ROOT fusion.1 = f16[4,5,6,7,8]{2,1,4,3,0} fusion(para0), kind=kLoop, + calls=fused_computation.1 + })"; + + // Check that a call to llvm.nvvm.barrier0 is generated. + auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: tail call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); + + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0})); +} + +TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) { + const char *const kHloString = R"( + HloModule multiple_output_fusion_1 + fused_computation.1 { + param0 = f16[8,31,31,65]{3,2,1,0} parameter(0) + param1 = f16[8,31,31,65]{3,2,1,0} parameter(1) + copy0 = f16[8,31,31,65]{2,1,3,0} copy(param0) + copy1 = f16[8,31,31,65]{2,1,3,0} copy(param1) + ROOT tuple1 = (f16[8,31,31,65]{2,1,3,0}, f16[8,31,31,65]{2,1,3,0}) + tuple(copy0, copy1) + } + + ENTRY multiple_output_fusion_1 { + para0 = f16[8,31,31,65]{3,2,1,0} parameter(0) + para1 = f16[8,31,31,65]{3,2,1,0} parameter(1) + ROOT fusion.1 = (f16[8,31,31,65]{2,1,3,0}, f16[8,31,31,65]{2,1,3,0}) + fusion(para0,para1), kind=kLoop, calls=fused_computation.1 + })"; + + // Check that a call to llvm.nvvm.barrier0 is generated. + auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK: tail call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); + + // Check that the kernel runs correctly. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0})); +} + +TEST_F(GpuKernelTilingTest, + MultipleOutputFusionWithTwoPossibleTransposesNotTiled) { + const char *const kHloString = R"( + HloModule multiple_output_fusion_2 + fused_computation.1 { + param0 = f16[8,31,31,65]{3,2,1,0} parameter(0) + param1 = f16[8,31,31,65]{1,3,2,0} parameter(1) + copy2 = f16[8,31,31,65]{2,1,3,0} copy(param0) + copy3 = f16[8,31,31,65]{2,1,3,0} copy(param1) + ROOT tuple1 = (f16[8,31,31,65]{2,1,3,0}, f16[8,31,31,65]{2,1,3,0}) + tuple(copy2, copy3) + } + + ENTRY multiple_output_fusion_2 { + para0 = f16[8,31,31,65]{3,2,1,0} parameter(0) + para1 = f16[8,31,31,65]{1,3,2,0} parameter(1) + ROOT fusion1 = (f16[8,31,31,65]{2,1,3,0}, f16[8,31,31,65]{2,1,3,0}) + fusion(para0,para1), kind=kLoop, calls=fused_computation.1 + })"; + + // Check that a call to llvm.nvvm.barrier0 is not generated. + auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6c9ae7bada5e7545b558b6fcb872ece60850cbe9 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc @@ -0,0 +1,141 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests that we emit ld.global.nc (the PTX instruction corresponding to CUDA's +// __ldg builtin) for reads of buffers that don't change during a kernel's +// execution. + +#include +#include + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { + +class GpuLdgTest : public GpuCodegenTest {}; + +// Parameters are never overwritten, so parameter reads should get ld.global.nc +// reads. +TEST_F(GpuLdgTest, LdgForParamRead) { + HloComputation::Builder builder(TestName()); + + auto shape = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); + std::unique_ptr computation = builder.Build(); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(std::move(computation)); + + CompileAndVerifyPtx(std::move(hlo_module), R"( + CHECK-NOT: ld.global.f32 + CHECK: ld.global.nc.f32 + )"); +} + +// Check that reading a buffer produced by a non-parameter HLO also results in +// ld.global.nc, if that buffer isn't modified within the instruction that reads +// it. +TEST_F(GpuLdgTest, LdgForNonParamRead) { + HloComputation::Builder builder(TestName()); + + auto shape = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); + HloInstruction* square = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, add, add)); + builder.AddInstruction(HloInstruction::CreateTuple({add, square})); + std::unique_ptr computation = builder.Build(); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(std::move(computation)); + + CompileAndVerifyPtx(std::move(hlo_module), R"( + CHECK: { + CHECK-NOT: ld.global.f32 + CHECK: ld.global.nc.f32 + CHECK: } + )"); +} + +// Check that reading a buffer that's modified in-place does not produce +// ld.global.nc. +// +// We do this by creating a reduce that feeds into a sin. We don't currently +// fuse sin into reduce, and the sin is elementwise, so it reuses its input +// buffer as its output. +// +// It seems like a fair bet that we won't start fusing sin into the output of +// reduce in the foreseeable future. But if that turns out to be wrong, I give +// you, future reader, permission to delete this test. +TEST_F(GpuLdgTest, NoLdgWhenSharingBuffer) { + auto hlo_module = CreateNewModule(); + HloComputation::Builder builder(TestName()); + + HloComputation* reduce_computation; + { + auto embedded_builder = HloComputation::Builder("add"); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + reduce_computation = + hlo_module->AddEmbeddedComputation(embedded_builder.Build()); + } + + auto param_shape = ShapeUtil::MakeShape(F32, {2, 2}); + auto reduce_shape = ShapeUtil::MakeShape(F32, {2}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "x")); + HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce( + reduce_shape, + builder.AddInstruction(HloInstruction::CreateBinary( + param_shape, HloOpcode::kAdd, param, param)), + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), + {0}, reduce_computation)); + builder.AddInstruction( + HloInstruction::CreateUnary(reduce_shape, HloOpcode::kSin, reduce)); + + std::unique_ptr computation = builder.Build(); + hlo_module->AddEntryComputation(std::move(computation)); + + CompileAndVerifyPtx(std::move(hlo_module), R"( + CHECK-LABEL: .entry sin + CHECK: { + CHECK-NOT: ld.global.nc.f32 + CHECK: ld.global.f32 + CHECK: } + )"); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c42e5704a4d2e611a203293e60a86ba4104bca46 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { + +class GpuNoAliasTest : public GpuCodegenTest {}; + +TEST_F(GpuNoAliasTest, Concat) { + HloComputation::Builder builder(TestName()); + + auto param_shape = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* param_x = builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "x")); + HloInstruction* param_y = builder.AddInstruction( + HloInstruction::CreateParameter(1, param_shape, "y")); + HloInstruction* concat = + builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(F32, {2, 4}), {param_x, param_y}, 1)); + builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(F32, {2, 6}), {concat, param_x}, 1)); + + std::unique_ptr computation = builder.Build(); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(std::move(computation)); + + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK: %[[x_gep:.*]] = getelementptr inbounds [2 x [2 x float]], [2 x [2 x float]]* %x{{.*}}, i32 0 +; CHECK: load float, float* %[[x_gep]], {{.*}}, !noalias ![[param_noalias:.*]] +; CHECK: %[[y_gep:.*]] = getelementptr inbounds [2 x [2 x float]], [2 x [2 x float]]* %y{{.*}}, i32 0 +; CHECK: load float, float* %[[y_gep]], {{.*}}, !noalias ![[param_noalias]] +; CHECK: %[[result_ptr:.*]] = bitcast [2 x [6 x float]]* %fusion{{.*}} to float* +; CHECK: %[[result_gep:.*]] = getelementptr inbounds float, float* %[[result_ptr]] +; CHECK: store float {{.*}}, float* %[[result_gep]], !alias.scope ![[param_noalias]] +; CHECK: ![[param_noalias]] = !{![[retval_buffer:.*]]} + )", + /*match_optimized_ir=*/false); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..962293630683fcbbce3941f622061a2ff0f02dda --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc @@ -0,0 +1,185 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuUnrollingTest : public GpuCodegenTest {}; + +const char *const kAddModule = R"( + HloModule test_module + + fused_computation { + p0.param_0 = f32[2,2]{1,0} parameter(0) + p1.param_1 = f32[2,2]{1,0} parameter(1) + ROOT add = f32[2,2] add(p0.param_0, p1.param_1) + } + + ENTRY BroadcastIntoAdd { + p0 = f32[2,2]{1,0} parameter(0) + p1 = f32[2,2]{1,0} parameter(1) + ROOT fusion = f32[2,2]{1,0} fusion(p0, p1), kind=kLoop, + calls=fused_computation + })"; + +TEST_F(GpuUnrollingTest, DoNotUnroll) { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_max_kernel_unroll_factor(1); + config.set_debug_options(debug_options); + auto hlo_module = ParseHloString(kAddModule, config).ValueOrDie(); + + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: @fusion +; CHECK: fadd +; CHECK-NOT: fadd +; CHECK: } + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuUnrollingTest, UnrollFourTimes) { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + // We request a factor of 8, but the computation works on 4 elements, limiting + // the maximum unroll factor. + debug_options.set_xla_gpu_max_kernel_unroll_factor(8); + config.set_debug_options(debug_options); + auto hlo_module = ParseHloString(kAddModule, config).ValueOrDie(); + + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: @fusion +; CHECK: fadd +; CHECK: fadd +; CHECK: fadd +; CHECK: fadd +; CHECK-NOT: fadd +; CHECK: } + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuUnrollingTest, UnrollDefaultTimes) { + // The default unrolling factor is 4. + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + auto hlo_module = ParseHloString(kAddModule, config).ValueOrDie(); + + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: @fusion +; CHECK: load <4 x float> +; CHECK: fadd +; CHECK: fadd +; CHECK: fadd +; CHECK: fadd +; CHECK-NOT: fadd +; CHECK: store <4 x float> +; CHECK: } + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuUnrollingTest, UnrollUnfusedAdd) { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_max_kernel_unroll_factor(4); + config.set_debug_options(debug_options); + + const char *const kUnfusedAddModule = R"( + HloModule test_module + + ENTRY AddFunc { + p0 = f32[2,2]{1,0} parameter(0) + p1 = f32[2,2]{1,0} parameter(1) + ROOT add = f32[2,2]{1,0} add(p0, p1) + })"; + auto hlo_module = ParseHloString(kUnfusedAddModule, config).ValueOrDie(); + + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: @add +; CHECK: load <4 x float> +; CHECK: fadd +; CHECK: fadd +; CHECK: fadd +; CHECK: fadd +; CHECK-NOT: fadd +; CHECK: store <4 x float> +; CHECK: } + )", + /*match_optimized_ir=*/true); +} + +TEST_F(GpuUnrollingTest, UnrollMultiOutputFusion) { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_max_kernel_unroll_factor(2); + config.set_debug_options(debug_options); + + const char *const kMultiOutputFusionModule = R"( + HloModule test_module + + fused_computation { + p0.param_0 = f32[2,2]{1,0} parameter(0) + p1.param_1 = f32[2,2]{1,0} parameter(1) + add = f32[2,2]{1,0} add(p0.param_0, p1.param_1) + mul = f32[2,2]{1,0} multiply(p0.param_0, p1.param_1) + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(add, mul) + } + + ENTRY BroadcastIntoAdd { + p0 = f32[2,2]{1,0} parameter(0) + p1 = f32[2,2]{1,0} parameter(1) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p0, p1), kind=kLoop, + calls=fused_computation + })"; + auto hlo_module = + ParseHloString(kMultiOutputFusionModule, config).ValueOrDie(); + + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: @fusion +; CHECK: load <2 x float> +; CHECK: load <2 x float> +; CHECK-NOT: load <2 x float> +; CHECK: fadd +; CHECK: fmul +; CHECK: fadd +; CHECK: fmul +; CHECK: store <2 x float> +; CHECK: store <2 x float> +; CHECK-NOT: store <2 x float> +; CHECK-NOT: fadd +; CHECK-NOT: fmul +; CHECK: } + )", + /*match_optimized_ir=*/true); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9072b30317d253fd6d50e9d98949cad4eaebfe7b --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc @@ -0,0 +1,121 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class InfeedTest : public ClientLibraryTestBase { + protected: + // Transfers the given literal to the infeed interface of the device, and + // check if the returned data from Infeed HLO is same as the literal. + void TestInfeedRoundTrip(const Literal& literal) { + // TODO(b/30481585) Explicitly reset the Infeed state so that the + // test is not affected by the state from the previous tests. + ASSERT_IS_OK(client_->TransferToInfeed(literal)); + XlaBuilder builder(TestName()); + Infeed(&builder, literal.shape()); + if (ShapeUtil::IsTuple(literal.shape())) { + // TODO(b/30609564): Use ComputeAndCompareLiteral instead. + ComputeAndCompareTuple(&builder, literal, {}); + } else { + ComputeAndCompareLiteral(&builder, literal, {}); + } + } +}; + +TEST_F(InfeedTest, SingleInfeedR0Bool) { + TestInfeedRoundTrip(*LiteralUtil::CreateR0(true)); +} + +TEST_F(InfeedTest, SingleInfeedR1U32) { + TestInfeedRoundTrip(*LiteralUtil::CreateR1({1, 2, 3})); +} + +TEST_F(InfeedTest, SingleInfeedR2F32) { + TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); +} + +TEST_F(InfeedTest, SingleInfeedR3F32) { + TestInfeedRoundTrip( + *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); +} + +TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) { + const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2}); + const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0}); + + TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, + r3_dim0minor)); + + TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, + r3_dim0major)); +} + +TEST_F(InfeedTest, SingleInfeedR4S32) { + TestInfeedRoundTrip(*LiteralUtil::CreateR4( + {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, + {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); +} + +// Tests that a large infeed can be handled. +TEST_F(InfeedTest, LargeInfeed) { + Array4D array(80, 100, 8, 128); + array.FillIota(1.0f); + TestInfeedRoundTrip(*LiteralUtil::CreateR4FromArray4D(array)); +} + +TEST_F(InfeedTest, SingleInfeedTuple) { + TestInfeedRoundTrip( + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), + LiteralUtil::CreateR0(false).get()})); +} + +TEST_F(InfeedTest, SingleInfeedEmptyTuple) { + TestInfeedRoundTrip(*LiteralUtil::MakeTuple({})); +} + +// Tests that a large tuple infeed can be handled. +TEST_F(InfeedTest, SingleInfeedLargeTuple) { + Array4D array(40, 100, 8, 128); + array.FillIota(1.0f); + TestInfeedRoundTrip(*LiteralUtil::MakeTuple( + {LiteralUtil::CreateR4FromArray4D(array).get(), + LiteralUtil::CreateR0(5).get()})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/thunk.cc b/tensorflow/compiler/xla/service/gpu/thunk.cc new file mode 100644 index 0000000000000000000000000000000000000000..c78605cebbc671272b8df9faf0e0cc54be2f5b1c --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/thunk.cc @@ -0,0 +1,59 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/thunk.h" + +namespace xla { +namespace gpu { + +std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) { + switch (kind) { + case Thunk::kConditional: + return os << "kConditional"; + case Thunk::kConvolution: + return os << "kConvolution"; + case Thunk::kCopy: + return os << "kCopy"; + case Thunk::kCudnnBatchNormBackward: + return os << "kCudnnBatchNormBackward"; + case Thunk::kCudnnBatchNormForwardInference: + return os << "kCudnnBatchNormForwardInference"; + case Thunk::kCudnnBatchNormForwardTraining: + return os << "kCudnnBatchNormForwardTraining"; + case Thunk::kFft: + return os << "kFft"; + case Thunk::kGemm: + return os << "kGemm"; + case Thunk::kInfeed: + return os << "kInfeed"; + case Thunk::kKernel: + return os << "kKernel"; + case Thunk::kMemset32BitValue: + return os << "kMemset32BitValue"; + case Thunk::kMemzero: + return os << "kMemzero"; + case Thunk::kOutfeed: + return os << "kOutfeed"; + case Thunk::kSequential: + return os << "kSequential"; + case Thunk::kTuple: + return os << "kTuple"; + case Thunk::kWhile: + return os << "kWhile"; + } +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 931c0bffab850362dbd2df975657dd47d9cbd3ae..4df0bb005b623e5ac79a4dfcb7c5a8a7a400940c 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -40,7 +41,7 @@ class GpuExecutable; // This is thread-compatible. class Thunk { public: - enum class Kind { + enum Kind { kConditional, kConvolution, kCopy, @@ -53,6 +54,7 @@ class Thunk { kKernel, kMemset32BitValue, kMemzero, + kOutfeed, kSequential, kTuple, kWhile, @@ -94,11 +96,12 @@ class Thunk { // Execute the kernel for the thunk on the given stream. This method must be // called after Initialize and can be called multiple times over Thunk's - // lifetime. Stream argument must be non-null. + // lifetime. 'stream' and 'profiler' must be non-null. // // Precondition: Initialize(stream->parent()) has been called. virtual Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) = 0; + se::Stream* stream, + HloExecutionProfiler* profiler) = 0; private: Kind kind_; @@ -108,6 +111,8 @@ class Thunk { // A sequence of thunks. using ThunkSequence = std::vector>; +std::ostream& operator<<(std::ostream& os, Thunk::Kind kind); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc index 97cb04c38fbf18e516857f5269c984696ca204c3..8579b1545fd24f80621ac0f53b997e33586cbabe 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc @@ -15,30 +15,41 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" namespace xla { namespace gpu { Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { - std::vector tuple_element_buffer_addresses; - for (BufferAllocation::Slice tuple_element_buffer : tuple_element_buffers_) { - tuple_element_buffer_addresses.push_back( - buffer_allocations.GetDeviceAddress(tuple_element_buffer).opaque()); + se::Stream* stream, + HloExecutionProfiler* profiler) { + auto size = tuple_element_buffers_.size(); + auto tuple_element_buffer_addresses = MakeUnique(size); + for (int i = 0; i != size; ++i) { + tuple_element_buffer_addresses[i] = + buffer_allocations.GetDeviceAddress(tuple_element_buffers_[i]).opaque(); } se::DeviceMemory dest_buffer_address( buffer_allocations.GetDeviceAddress(dest_buffer_)); - auto host_size = tuple_element_buffer_addresses.size() * sizeof(void*); + auto host_size = size * sizeof(void*); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); if (!stream ->ThenMemcpy(&dest_buffer_address, - tuple_element_buffer_addresses.data(), host_size) + tuple_element_buffer_addresses.get(), host_size) .ok()) { return InternalError( "Unable to launch MemcpyH2D from %p to %p with size %lu", - tuple_element_buffer_addresses.data(), dest_buffer_address.opaque(), - sizeof(void*) * tuple_element_buffer_addresses.size()); + tuple_element_buffer_addresses.get(), dest_buffer_address.opaque(), + host_size); + } + // Free the tuple address buffer when memcpy is done. + auto* buffers_raw = tuple_element_buffer_addresses.release(); + if (!stream->ThenDoHostCallback([buffers_raw] { delete[] buffers_raw; }) + .ok()) { + delete[] buffers_raw; + return InternalError("Unable to enqueue host callback!"); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h index 951f809b51937c97a6e7de0345ec58a8b66a4242..2d5735d6c40ccd26f0e527f1a02403910db4c812 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -46,7 +47,8 @@ class TupleThunk : public Thunk { TupleThunk& operator=(const TupleThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const std::vector tuple_element_buffers_; diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index 30b9640c4c75dae61e9a90da5fb10e9d4a90cd26..d81d87e7dc54cd752000b85f3ec173d66d7195e4 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -29,10 +30,14 @@ WhileThunk::WhileThunk( const HloInstruction* hlo) : Thunk(Kind::kWhile, hlo), condition_result_buffer_index_(condition_result_buffer_index), + // Pass nullptr as the HloInstruction* to the condition_thunk_sequence_ + // and body_thunk_sequence_ constructors because these SequentialThunks + // are logically "part of" this WhileThunk, and shouldn't be profiled + // separately from it. condition_thunk_sequence_(MakeUnique( - std::move(*condition_thunk_sequence), hlo)), - body_thunk_sequence_( - MakeUnique(std::move(*body_thunk_sequence), hlo)) {} + std::move(*condition_thunk_sequence), nullptr)), + body_thunk_sequence_(MakeUnique( + std::move(*body_thunk_sequence), nullptr)) {} Status WhileThunk::Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) { @@ -43,18 +48,24 @@ Status WhileThunk::Initialize(const GpuExecutable& executable, } Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) { + se::Stream* stream, + HloExecutionProfiler* profiler) { se::DeviceMemoryBase condition_result_data = buffer_allocations.GetDeviceAddress(condition_result_buffer_index_); + auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); while (true) { // Invoke thunk sequence for while 'condition' computation. - TF_RETURN_IF_ERROR( - condition_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); + profiler->StartHloComputation(); + VLOG(3) << "Executing condition computation"; + TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream( + buffer_allocations, stream, profiler)); + profiler->FinishHloComputation(hlo_instruction()->while_condition()); // Copy the result of condition computation and break the loop if 'false'. bool condition_result; stream->ThenMemcpy(&condition_result, condition_result_data, sizeof(bool)); + VLOG(3) << "condition_result = " << condition_result; Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError( @@ -66,9 +77,15 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, break; } - // Invoke thunk sequence for while 'body' computation. - TF_RETURN_IF_ERROR( - body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); + // We measure the time of one execution of the while body computation. The + // while body may be executed more than once, the last measurement "wins". + profiler->StartHloComputation(); + VLOG(3) << "Executing body computation"; + // Invoke thunk sequence for while 'body' computation, and pass on + // 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'. + TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations, + stream, profiler)); + profiler->FinishHloComputation(hlo_instruction()->while_body()); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h index 22176685a92df9c95b10f755b209309843c0fa3a..9270f95ee67cf0bd3ab8082452a9d8703cb4304e 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -48,7 +49,8 @@ class WhileThunk : public Thunk { Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - se::Stream* stream) override; + se::Stream* stream, + HloExecutionProfiler* profiler) override; private: const BufferAllocation::Slice condition_result_buffer_index_; diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc deleted file mode 100644 index 7749201cbceece216a2db2569936949eb7de5125..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ /dev/null @@ -1,521 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" - -#include -#include - -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace xla { -namespace gpu { - -namespace { - -// TODO(b/33483676) Use an expression tree to specify computations to pattern -// match for while transformations. - -// ExprTree is a simple recursive data structure used to express computation -// patterns to match. -// -// Each ExprTree node is comprised of an HloOpcode, and a set of operands (each -// of type ExprTree). Operands can be added by specifying the index and -// HloOpcode of the operand. -// -// For example, the following computation: -// -// Parameter -// | -// Const GetTupleElement -// \ / -// Add (root) -// -// Can be matched with the following expression tree: -// -// ExprTree add(HloOpcode::kAdd, -// ExprTree(HloOpcode::kConstant), -// ExprTree(HloOpcode::kGetTupleElement, -// tuple_index, ExprTree(HloOpcode::kParameter))); -// -// Match the ExprTree root against an Hlo graph: -// -// ExprTree::TaggedInstructionMap tagged_instructions; -// TF_RETURN_IF_ERROR(add.Match(computation_->root_instruction(), -// &tagged_instructions)); -// -// Instructions that are "tagged" with a context-specific string will -// be returned in 'tagged_instructions' for further processing (i.e. parsing -// constants or recording the tuple_index). -// -class ExprTree { - public: - explicit ExprTree(HloOpcode opcode) : opcode_(opcode) {} - ExprTree(HloOpcode opcode, const string& tag) : opcode_(opcode), tag_(tag) {} - ExprTree(HloOpcode opcode, const ExprTree& operand0) : opcode_(opcode) { - SetOperand(0, operand0); - } - ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0) - : opcode_(opcode) { - SetOperand(index0, operand0); - } - ExprTree(HloOpcode opcode, int64 index0, const ExprTree& operand0, - int64 index1, const ExprTree& operand1) - : opcode_(opcode) { - SetOperand(index0, operand0); - SetOperand(index1, operand1); - } - ExprTree(HloOpcode opcode, const string& tag, const ExprTree& operand0) - : opcode_(opcode), tag_(tag) { - SetOperand(0, operand0); - } - ExprTree(HloOpcode opcode, const ExprTree& operand0, const ExprTree& operand1) - : opcode_(opcode) { - SetOperand(0, operand0); - SetOperand(1, operand1); - } - - ExprTree(const ExprTree& to_copy) { - opcode_ = to_copy.opcode_; - tag_ = to_copy.tag_; - if (to_copy.fused_root_tree_ != nullptr) { - fused_root_tree_.reset(new ExprTree(*to_copy.fused_root_tree_)); - } - for (auto& pair : to_copy.operands_) { - CHECK(operands_.find(pair.first) == operands_.end()); - operands_.insert(std::make_pair( - pair.first, std::unique_ptr(new ExprTree(*pair.second)))); - } - } - - void SetFusedRoot(const ExprTree& fused_root) { - fused_root_tree_.reset(new ExprTree(fused_root)); - } - - typedef std::unordered_map - TaggedInstructionMap; - - // Matches 'instruction' HloOpcode against 'opcode_'. - // Recursively matches each operand in 'operands_'. - // Recursively matches fused instructions starting at 'fused_root_tree_' - // if 'opcode_ == kFusion'. - // Returns OK status, and instructions in 'tagged_instructions' for each - // matched ExprTree node with a non-empty 'tag_'. - // Returns error message on failure. - Status Match(const HloInstruction* instruction, - TaggedInstructionMap* tagged_instructions) const { - if (opcode_ != instruction->opcode()) { - return InvalidArgument("got opcode %s, want %s", - HloOpcodeString(instruction->opcode()).c_str(), - HloOpcodeString(opcode_).c_str()); - } - - VLOG(2) << "Matched " << HloOpcodeString(opcode_) << ": " << tag_; - if (!tag_.empty()) { - tagged_instructions->insert({tag_, instruction}); - } - - if (instruction->opcode() == HloOpcode::kFusion) { - CHECK(fused_root_tree_ != nullptr); - // Match fused instructions for this node starting a 'fused_root_tree'. - TF_RETURN_IF_ERROR(fused_root_tree_->Match( - instruction->fused_expression_root(), tagged_instructions)); - } - - // Match each operand in 'operands_'. - for (auto& pair : operands_) { - TF_RETURN_IF_ERROR(pair.second->Match(instruction->operand(pair.first), - tagged_instructions)); - } - return Status::OK(); - } - - private: - void SetOperand(int64 index, const ExprTree& operand) { - CHECK_EQ(0, operands_.count(index)); - operands_.insert(std::make_pair(index, MakeUnique(operand))); - } - - HloOpcode opcode_; - std::unordered_map> operands_; - std::unique_ptr fused_root_tree_; - string tag_; -}; - -// MatcherBase is a base class that provides common functionality for -// sub-classes which match specific target sub-computations (i.e. loop -// induction variable initialization, comparison and update). -class MatcherBase { - public: - MatcherBase() {} - virtual ~MatcherBase() {} - - // Attempts to match each ExprTree in 'expr_trees_'. - // Returns OK on the first successful match, error status otherwise. - virtual Status Run() { - Status status; - for (const ExprTree& expr_tree : expr_trees_) { - status = MatchExprTree(expr_tree); - if (status.ok()) { - return status; - } - } - return status; - } - - virtual Status MatchExprTree(const ExprTree& expr_tree) = 0; - - // Returns the constant value parsed form kConstant 'instruction'. - // Returns error status otherwise. - Status ParseConstInteger(const HloInstruction* instruction, - int64* const_value) const { - CHECK_EQ(HloOpcode::kConstant, instruction->opcode()); - PrimitiveType element_type = instruction->shape().element_type(); - if (element_type != S32 && element_type != S64) { - return InvalidArgument("Expected constant of integral type."); - } - const Literal& literal = instruction->literal(); - PrimitiveType type = literal.shape().element_type(); - if (type != S32 && type != S64) { - return InvalidArgument("Must use S32 or S64 integral types."); - } - if (type == S32) { - *const_value = static_cast(literal.GetFirstElement()); - } else if (type == S64) { - *const_value = literal.GetFirstElement(); - } - return Status::OK(); - } - - StatusOr GetTaggedInstruction( - const string& tag, - const ExprTree::TaggedInstructionMap& tagged_instructions) { - auto it = tagged_instructions.find(tag); - if (it == tagged_instructions.end()) { - return InvalidArgument("Cound not find instruction for tag: %s", - tag.c_str()); - } - return it->second; - } - - protected: - std::vector expr_trees_; - - private: - TF_DISALLOW_COPY_AND_ASSIGN(MatcherBase); -}; - -// WhileConditionComputationMatcher attempts to match a target computation -// pattern in the while condition sub-computation. -// If the target pattern is matched, two pieces of information are extracted -// from 'tagged' instructions returned by the matcher: -// -// *) 'tuple_index': -// *) The loop induction variable tuple_index from the GetTupleElement -// instruction of the matched computation. -// *) Used in subsequent matching passes of while init operand and body -// computations to select loop induction variable tuple element. -// -// *) 'loop_limit': -// *) The integral value from Constant root operand in matched computation. -// *) Used as the constant for the loop limit. -// -class WhileConditionComputationMatcher : public MatcherBase { - public: - explicit WhileConditionComputationMatcher(const HloComputation* computation) - : computation_(computation) { - expr_trees_.emplace_back(BuildCondExprTree()); - } - - int64 loop_limit() const { return loop_limit_; } - int64 tuple_index() const { return tuple_index_; } - - private: - // Builds expression tree for the following condition computation: - // - // Const Parameter - // \ / - // Fusion ------------> FusionParam FusionParam - // \ / - // GTE / - // \ / - // LessThan (fused root) - // - ExprTree BuildCondExprTree() { - // Build ExprTree for fused instructions. - ExprTree fused_root( - HloOpcode::kLt, - ExprTree(HloOpcode::kGetTupleElement, "gte", - ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0")), - ExprTree(HloOpcode::kParameter)); - - // Build top-level computation. - ExprTree root(HloOpcode::kFusion, - ExprTree(HloOpcode::kConstant, "loop_limit"), - ExprTree(HloOpcode::kParameter, "param0")); - - root.SetFusedRoot(fused_root); - return root; - } - - Status MatchExprTree(const ExprTree& expr_tree) override { - VLOG(2) << "MATCHING while condition"; - ExprTree::TaggedInstructionMap tagged_instructions; - TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(), - &tagged_instructions)); - - // Get tagged GTE instruction and set 'tuple_index_'. - TF_ASSIGN_OR_RETURN(const HloInstruction* gte, - GetTaggedInstruction("gte", tagged_instructions)); - tuple_index_ = gte->tuple_index(); - - // Get tagged Constant instruction and parse 'loop_limit_'. - TF_ASSIGN_OR_RETURN( - const HloInstruction* const_hlo, - GetTaggedInstruction("loop_limit", tagged_instructions)); - TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_limit_)); - - // Get tagged "param0" instruction, and check that it matches - // 'computation_' parameter 0. - TF_ASSIGN_OR_RETURN(const HloInstruction* param0, - GetTaggedInstruction("param0", tagged_instructions)); - if (param0 != computation_->parameter_instruction(0)) { - return InvalidArgument("Unexpected Parameter0 instruction : %s", - param0->name().c_str()); - } - - // Get tagged 'gte.fusion_param.param0', find its associated fusion operand, - // and compare it to 'computation_' parameter0. - TF_ASSIGN_OR_RETURN( - const HloInstruction* gte_fusion_param0, - GetTaggedInstruction("gte.fusion_param.param0", tagged_instructions)); - CHECK_EQ(HloOpcode::kParameter, gte_fusion_param0->opcode()); - CHECK(gte_fusion_param0->IsFused()); - if (gte_fusion_param0->parent()->FusionInstruction()->operand( - gte_fusion_param0->parameter_number()) != - computation_->parameter_instruction(0)) { - return InvalidArgument("Could not match fusion param: %s", - gte_fusion_param0->name().c_str()); - } - - return Status::OK(); - } - - const HloComputation* computation_; - - int64 loop_limit_ = -1; - int64 tuple_index_ = -1; - - TF_DISALLOW_COPY_AND_ASSIGN(WhileConditionComputationMatcher); -}; - -// WhileInitOperandMatcher matches a target computation pattern of the -// while instructions 'init' operand, indexing the tuple at 'tuple_index'. -// On success, parses constant 'loop_start' which represents the loop induction -// variable start values, then returns OK. -// Returns error status otherwise. -class WhileInitOperandMatcher : public MatcherBase { - public: - WhileInitOperandMatcher(const HloInstruction* while_hlo, - const int64 tuple_index) - : while_hlo_(while_hlo), tuple_index_(tuple_index) { - expr_trees_.emplace_back(BuildInitExprTree()); - } - - int64 loop_start() const { return loop_start_; } - - private: - // Builds expression tree for the following while init operand subcomputation: - // - // Const - // | - // Copy - // | - // Tuple0 - // | - // While - // - ExprTree BuildInitExprTree() { - return ExprTree( - HloOpcode::kWhile, "while", - ExprTree(HloOpcode::kTuple, tuple_index_, - ExprTree(HloOpcode::kCopy, - ExprTree(HloOpcode::kConstant, "loop_start")))); - } - - Status MatchExprTree(const ExprTree& expr_tree) override { - VLOG(2) << "MATCHING while init"; - ExprTree::TaggedInstructionMap tagged_instructions; - TF_RETURN_IF_ERROR(expr_tree.Match(while_hlo_, &tagged_instructions)); - - // Get tagged while instruction check against 'while_hlo_'. - TF_ASSIGN_OR_RETURN(const HloInstruction* while_hlo, - GetTaggedInstruction("while", tagged_instructions)); - if (while_hlo != while_hlo_) { - return InvalidArgument("Expected While for instruction : %s", - while_hlo->name().c_str()); - } - - // Get tagged Constant instruction and parse 'loop_start_'. - TF_ASSIGN_OR_RETURN( - const HloInstruction* const_hlo, - GetTaggedInstruction("loop_start", tagged_instructions)); - TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_start_)); - - return Status::OK(); - } - - const HloInstruction* while_hlo_; - const int64 tuple_index_; - - int64 loop_start_ = -1; - - TF_DISALLOW_COPY_AND_ASSIGN(WhileInitOperandMatcher); -}; - -// WhileBodyComputationMatcher matches a target computation pattern for -// the loop induction variable update. Matching proceeds from the while body -// computation root[tuple_index] to param[tuple_index], where 'tuple_index' -// If the target pattern is matched, parses a constant which represents the -// loop induction variable increment value, then returns status OK. -// Returns error status otherwise. -class WhileBodyComputationMatcher : public MatcherBase { - public: - WhileBodyComputationMatcher(const HloComputation* computation, - const int64 tuple_index) - : computation_(computation), tuple_index_(tuple_index) { - expr_trees_.emplace_back(BuildBodyExprTree(0, 1)); - expr_trees_.emplace_back(BuildBodyExprTree(1, 0)); - } - - int64 loop_increment() const { return loop_increment_; } - - private: - // Builds expression tree for the following while body computation: - // - // - // FusionParam FusionParam - // \ / - // Const Param \ GTE1 - // \ / \ / - // Fusion -----------> Add - // | - // Copy - // | - // Tuple0 - // - ExprTree BuildBodyExprTree(const int64 const_index, const int64 gte_index) { - // Build ExprTree for fused instructions. - ExprTree gte1 = - ExprTree(HloOpcode::kGetTupleElement, "gte", - ExprTree(HloOpcode::kParameter, "gte.fusion_param.param0")); - ExprTree fused_root(HloOpcode::kAdd, const_index, - ExprTree(HloOpcode::kParameter), gte_index, gte1); - - // Build fusion instruction (and set fused root). - ExprTree fusion(HloOpcode::kFusion, 0, - ExprTree(HloOpcode::kConstant, "loop_increment"), 1, - ExprTree(HloOpcode::kParameter, "param0")); - fusion.SetFusedRoot(fused_root); - - // Build top-level computation. - ExprTree tuple0(HloOpcode::kTuple, tuple_index_, - ExprTree(HloOpcode::kCopy, fusion)); - return tuple0; - } - - Status MatchExprTree(const ExprTree& expr_tree) override { - VLOG(2) << "MATCHING while body"; - ExprTree::TaggedInstructionMap tagged_instructions; - TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(), - &tagged_instructions)); - - for (const auto& pair : tagged_instructions) { - const auto& tag = pair.first; - const auto& inst = pair.second; - - if (tag == "gte" && inst->tuple_index() != tuple_index_) { - // Check that the matched GTE instruction is at the 'tuple_index' we - // matched in the while condition computation. - return InvalidArgument("Unexpected tuple index instruction : %s", - inst->name().c_str()); - } else if (tag == "loop_increment") { - // ParseHloString the constant which represents the loop induction - // variable increment value. - TF_RETURN_IF_ERROR(ParseConstInteger(inst, &loop_increment_)); - } else if (tag == "param0" && - inst != computation_->parameter_instruction(0)) { - // Check that the matched parameter == parameter 0 from 'computation_'. - return InvalidArgument("Unexpected Parameter0 instruction : %s", - inst->name().c_str()); - } else if (tag == "gte.fusion_param.param0") { - // Fusion parameter: lookup and compare with associated fusion operand. - CHECK_EQ(HloOpcode::kParameter, inst->opcode()); - CHECK(inst->IsFused()); - if (inst->parent()->FusionInstruction()->operand( - inst->parameter_number()) != - computation_->parameter_instruction(0)) { - return InvalidArgument("Could not match fusion param: %s", - inst->name().c_str()); - } - } - } - return Status::OK(); - } - - const HloComputation* computation_; - const int64 tuple_index_; - - int64 loop_increment_ = -1; - - TF_DISALLOW_COPY_AND_ASSIGN(WhileBodyComputationMatcher); -}; - -} // namespace - -StatusOr> CanTransformWhileToFor( - const HloInstruction* while_hlo) { - if (while_hlo->opcode() != HloOpcode::kWhile) { - return InvalidArgument("Expected While instruction."); - } - - WhileConditionComputationMatcher cond_matcher(while_hlo->while_condition()); - TF_RETURN_IF_ERROR(cond_matcher.Run()); - - WhileInitOperandMatcher init_matcher(while_hlo, cond_matcher.tuple_index()); - TF_RETURN_IF_ERROR(init_matcher.Run()); - - WhileBodyComputationMatcher body_matcher(while_hlo->while_body(), - cond_matcher.tuple_index()); - TF_RETURN_IF_ERROR(body_matcher.Run()); - - // Check for valid For loop parameters. - if (init_matcher.loop_start() >= cond_matcher.loop_limit()) { - return InvalidArgument("Loop start must be less than loop limit."); - } - if (body_matcher.loop_increment() <= 0) { - return InvalidArgument("Loop increment must greater than zero."); - } - return std::make_tuple(init_matcher.loop_start(), cond_matcher.loop_limit(), - body_matcher.loop_increment()); -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.h b/tensorflow/compiler/xla/service/gpu/while_transformer.h deleted file mode 100644 index fe3a954e1828ee4a323872eea81f64c7e780ad24..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ - -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/statusor.h" - -namespace xla { -namespace gpu { - -// Runs an analysis of the while loop instruction 'while_hlo' (and its -// associated sub-computations) to determine if it can be transformed into an -// equivalent "for" loop with the following "for" loop parameters: -// -// *) 'loop_start': loop induction variable starting value. -// *) 'loop_limit': loop induction variable limit value. -// *) 'loop_increment': loop induction variable per-iteration increment value. -// -// Returns an std::tuple = (loop_start, loop_limit, loop_increment) on success. -// The values in the returned tuple are values extracted from the 'while_hlo' -// operand (and its sub-computations) during analysis. -// Returns an error status on failure. -StatusOr> CanTransformWhileToFor( - const HloInstruction* while_hlo); - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index 2f290f61bd527e9827472a78256f015e066e44be..c5f3906356d821e059d2b1213c9083c4408a4d1c 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -13,11 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/while_transformer.h" - #include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -42,7 +41,7 @@ class WhileTransformerTest : public HloTestBase { const int64 tuple_index, const int64 limit) { auto builder = HloComputation::Builder(TestName() + ".Condition"); auto limit_const = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(limit))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(limit))); auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter( 0, GetLoopStateShape(tuple_index), "loop_state")); auto induction_variable = @@ -65,8 +64,8 @@ class WhileTransformerTest : public HloTestBase { auto induction_variable = builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, ind_var_tuple_index)); - auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(increment))); + auto inc = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(increment))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); // Update data GTE(data_tuple_index). @@ -89,10 +88,12 @@ class WhileTransformerTest : public HloTestBase { const int64 ind_var_tuple_index, const int64 ind_var_init) { auto builder = HloComputation::Builder(TestName() + ".While"); - auto induction_var_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(ind_var_init))); - auto data_init = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); + auto induction_var_init = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(ind_var_init))); + auto data_init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1( + {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); auto loop_state_init = ind_var_tuple_index == 0 ? builder.AddInstruction( @@ -108,12 +109,12 @@ class WhileTransformerTest : public HloTestBase { void RunFusionPasses() { // Run standard fusion passes. - EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/false) - .Run(module_.get()) - .ValueOrDie()); - EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module_.get()) - .ValueOrDie()); + TF_ASSERT_OK(gpu::GpuInstructionFusion(/*may_duplicate=*/false) + .Run(module_.get()) + .status()); + TF_ASSERT_OK(gpu::GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module_.get()) + .status()); } void RunCopyInsertionPass() { @@ -139,10 +140,7 @@ class WhileTransformerTest : public HloTestBase { Shape condition_result_shape_; }; -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) { +TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) { // Build computation with induction variable at tuple element 0. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); @@ -151,18 +149,13 @@ TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - TF_ASSERT_OK(result.status()); - // Check results. - EXPECT_THAT(result.ConsumeValueOrDie(), - Eq(std::tuple(0, 10, 1))); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_TRUE(result); + EXPECT_EQ(10, *result); } -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) { +TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) { // Build computation with induction variable at tuple element 1. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(1, 10)); @@ -171,19 +164,14 @@ TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - TF_ASSERT_OK(result.status()); - // Check results. - EXPECT_THAT(result.ConsumeValueOrDie(), - Eq(std::tuple(0, 10, 1))); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_TRUE(result); + EXPECT_EQ(10, *result); } -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) { - // Build computation with invalid loop limit. +TEST_F(WhileTransformerTest, ImpossibleLoopLimit) { + // Build computation with an impossible loop limit. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 5)); auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, 1)); @@ -191,17 +179,13 @@ TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - ASSERT_FALSE(result.ok()); - EXPECT_THAT(result.status().error_message(), - HasSubstr("Loop start must be less than loop limit.")); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_TRUE(result); + EXPECT_EQ(0, *result); } -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) { +TEST_F(WhileTransformerTest, InvalidLoopIncrement) { // Build computation with invalid loop increment. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); @@ -210,11 +194,9 @@ TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) { // Run HLO Optimization passes. RunFusionPasses(); RunCopyInsertionPass(); - // Run WhileTransformer. - auto result = gpu::CanTransformWhileToFor(while_hlo); - ASSERT_FALSE(result.ok()); - EXPECT_THAT(result.status().error_message(), - HasSubstr("Loop increment must greater than zero.")); + + auto result = ComputeWhileLoopTripCount(while_hlo); + ASSERT_FALSE(result); } } // namespace diff --git a/tensorflow/compiler/xla/service/gpu/xfeed_queue.h b/tensorflow/compiler/xla/service/gpu/xfeed_queue.h new file mode 100644 index 0000000000000000000000000000000000000000..dd46ff433ba0ad6bfa3999b96845fdaebe148aca --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/xfeed_queue.h @@ -0,0 +1,90 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_ + +#include +#include +#include + +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace xla { +namespace gpu { + +// TODO(b/30467474) Once GPU outfeed implementation settles, consider +// folding back the cpu and gpu outfeed implementations into a generic +// one if possible. + +// Manages a thread-safe queue of buffers. +template +class XfeedQueue { + public: + // Adds a tree of buffers to the queue. The individual buffers correspond to + // the elements of a tuple and may be nullptr if the buffer is a tuple index + // buffer. + void EnqueueDestination(BufferType buffers) { + tensorflow::mutex_lock l(mu_); + enqueued_buffers_.push_back(std::move(buffers)); + cv_.notify_one(); + } + + // Blocks until the queue is non-empty, then returns the buffer at the head of + // the queue. + BufferType BlockingGetNextDestination() { + bool became_empty; + BufferType current_buffer; + { + tensorflow::mutex_lock l(mu_); + while (enqueued_buffers_.empty()) { + cv_.wait(l); + } + current_buffer = std::move(enqueued_buffers_.front()); + enqueued_buffers_.pop_front(); + became_empty = enqueued_buffers_.empty(); + } + if (became_empty) { + for (const auto& callback : on_empty_callbacks_) { + callback(); + } + } + return current_buffer; + } + + void RegisterOnEmptyCallback(std::function callback) { + on_empty_callbacks_.push_back(std::move(callback)); + } + + private: + tensorflow::mutex mu_; + + // Condition variable that is signaled every time a buffer is enqueued. + tensorflow::condition_variable cv_; + + // The queue of trees of buffers. Buffer* queue contents are not owned. + std::deque enqueued_buffers_ GUARDED_BY(mu_); + + // List of callbacks which will be called when 'enqueued_buffers_' becomes + // empty. + std::vector> on_empty_callbacks_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_ diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index acf661148699dab18916e3065ee647d37fda6208..aa89567ee86e59e197045c0b51eed3b9aa59fef7 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -47,7 +48,7 @@ HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) { auto x_value = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "x_value")); auto half = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.5))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.5))); builder.AddInstruction(HloInstruction::CreateBinary( half->shape(), HloOpcode::kAdd, x_value, half)); return module->AddEmbeddedComputation(builder.Build()); @@ -122,7 +123,7 @@ std::unique_ptr MakeBigGraph() { auto rng = builder.AddInstruction( HloInstruction::CreateRng(vshape, RNG_UNIFORM, {param_m, param_m})); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto add_computation = ScalarSumComputation(module.get()); builder.AddInstruction( HloInstruction::CreateReduce(vshape, rng, one, {1}, add_computation)); diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 3849b565e3136924b2d2b1929353885f85b1a043..b41dc66fe9f5e869a114be96b7cc01fc1a3d59da 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -239,7 +239,7 @@ class HeapSimulatorTest : public HloTestBase { TEST_F(HeapSimulatorTest, ScalarConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); // Constants aren't assigned. See b/32248867 HeapSimulatorTracker tracker(TestName(), builder.Build(), {const0}); @@ -674,7 +674,7 @@ class HeapAlgorithmTestBase : public ::testing::Test { const BufferValue* DummyBufferValue() { const BufferValue::Id id = buffers_.size(); auto const0 = builder_.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); buffers_.emplace_back(MakeUnique(id, const0, ShapeIndex{})); return buffers_.back().get(); } diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index d2417910606fdd13223076d33ff1bda1dd291d98..be9098f555e78f3cabfe55481356f8b6841a3a2b 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -151,10 +151,20 @@ message HloInstructionProto { // Backend configuration for the instruction. Has backend-specific meaning. string backend_config = 43; - // Cross Replica Sum fields. + // Cross replica op fields. + // TODO(b/112107579): remove replica_group_ids field and always use + // replica_groups. repeated int64 replica_group_ids = 44; + repeated ReplicaGroup replica_groups = 49; int64 all_reduce_id = 45; string cross_replica_sum_barrier = 46; + + // Whether this Send/Recv instruction transfers data to/from the host. Only + // present for Send and Recv instructions and their SendDone and RecvDone + // partners. + bool is_host_transfer = 47; + + xla.ScatterDimensionNumbers scatter_dimension_numbers = 48; } // Serialization of HloComputation. @@ -239,8 +249,9 @@ message BufferAllocationProto { int64 index = 1; int64 size = 2; bool is_thread_local = 3; - bool is_reusable = 4; + bool is_tuple = 11; bool is_entry_computation_parameter = 5; + bool is_constant = 12; int64 parameter_number = 6; repeated int64 parameter_shape_index = 10; bool maybe_live_out = 7; diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h index afb0c20f0cdf3eb92f72ab8bc368b4b8d723459e..1fea544730c27efdaa260f55ea81c163165f7ed5 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h @@ -42,7 +42,7 @@ class HloAliasAnalysis { static StatusOr> Run( HloModule* module, const HloDataflowAnalysis::FusionCanShareBufferFunction& - fusion_can_share_buffer = nullptr); + fusion_can_share_buffer); string ToString() const; diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 8f18d50f6e033fab1c01f42017b951c224c22799..da94ab5346e5628b4a603b3ac2d84071904d1e65 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -47,7 +47,9 @@ class HloAliasAnalysisTest : public HloTestBase { // reference to the generated analysis stored in analysis_. HloAliasAnalysis& RunAnalysis() { hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis"); - analysis_ = HloAliasAnalysis::Run(module_.get()).ConsumeValueOrDie(); + analysis_ = HloAliasAnalysis::Run(module_.get(), + /*fusion_can_share_buffer=*/nullptr) + .ConsumeValueOrDie(); return *analysis_; } @@ -116,9 +118,9 @@ TEST_F(HloAliasAnalysisTest, BinaryOperation) { // Test the analysis on a single binary operation (Add). auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, constant1, constant2)); module_->AddEntryComputation(builder.Build()); @@ -228,9 +230,9 @@ TEST_F(HloAliasAnalysisTest, SingleCall) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); module_->AddEntryComputation(builder.Build()); @@ -267,9 +269,9 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto call1 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); auto call2 = builder.AddInstruction(HloInstruction::CreateCall( @@ -346,15 +348,15 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) { auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( @@ -439,15 +441,15 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while0 = builder.AddInstruction( @@ -498,7 +500,7 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); return cond_builder.Build(); }; // Build separate condition computations so the call graph is flat. The @@ -543,9 +545,9 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto entry_while = builder.AddInstruction( @@ -608,17 +610,17 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2, constant3})); auto xla_while = builder.AddInstruction( @@ -654,19 +656,18 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) { } TEST_F(HloAliasAnalysisTest, TupleSelect) { - // Test a kSelect of a tuple value. Non-top-level element flow through the - // instruction. + // Test a kTupleSelect. Non-top-level element flow through the instruction. auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(4.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = @@ -677,13 +678,13 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) { builder.AddInstruction(HloInstruction::CreateTuple({constant4})); const Shape tuple_shape = tuple1->shape(); auto select11 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple1)); + tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple1)); auto select12 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple2)); auto select34 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple3, tuple4)); + tuple_shape, HloOpcode::kTupleSelect, pred, tuple3, tuple4)); auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, select12, select34)); + tuple_shape, HloOpcode::kTupleSelect, pred, select12, select34)); module_->AddEntryComputation(builder.Build()); @@ -718,7 +719,7 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) { } TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { - // Test a tuple-shaped kSelect feeding a kWhile instruction. HLO: + // Test a tuple-shaped kTupleSelect feeding a kWhile instruction. HLO: // // body((F32[], F32[]) %tuple_param): // %negate = Negate(%tuple_param{0}) @@ -754,22 +755,22 @@ TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({constant2})); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple2)); auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, select)); @@ -806,7 +807,7 @@ TEST_F(HloAliasAnalysisTest, Bitcast) { // Bitcasting a value should not produce a new buffer. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape_, HloOpcode::kBitcast, constant)); @@ -825,7 +826,7 @@ TEST_F(HloAliasAnalysisTest, BitcastInterference) { // interference. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape_, HloOpcode::kBitcast, constant)); builder.AddInstruction(HloInstruction::CreateTuple({constant, bitcast})); @@ -844,13 +845,13 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) { // the other use of the init. auto builder = HloComputation::Builder(TestName()); auto init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto cond_builder = HloComputation::Builder("condition"); auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, init->shape(), "param")); auto cond_root = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 34b18b0e21fbf6ce5d406cae9dbd64b9744f5a83..441288da1a6859a3f393a298ee02eb4b435e42e0 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -284,9 +284,8 @@ void HloComputation::set_root_instruction( if (!IsFusionComputation()) { CHECK(ShapeUtil::Compatible(new_root_instruction->shape(), root_instruction_->shape())) - << new_root_instruction->shape().ShortDebugString() - << " is incompatible with " - << root_instruction_->shape().ShortDebugString(); + << new_root_instruction->shape() << " is incompatible with " + << root_instruction_->shape(); } bool root_found = false; for (auto& instruction : instructions_) { @@ -529,8 +528,10 @@ HloInstruction* HloComputation::CreateFusionInstruction( } StatusOr HloComputation::DeepCopyHelper( - HloInstruction* instruction, const ShapeTree* indices_to_copy, - ShapeTree* copies_added, ShapeIndex* index) { + HloInstruction* instruction, ShapeIndex* index, + const std::function< + HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* computation)>& copy_leaf) { if (ShapeUtil::IsTuple(instruction->shape())) { std::vector elements; for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); @@ -541,9 +542,8 @@ StatusOr HloComputation::DeepCopyHelper( instruction, i)); index->push_back(i); - TF_ASSIGN_OR_RETURN( - HloInstruction * element, - DeepCopyHelper(gte, indices_to_copy, copies_added, index)); + TF_ASSIGN_OR_RETURN(HloInstruction * element, + DeepCopyHelper(gte, index, copy_leaf)); elements.push_back(element); index->pop_back(); } @@ -557,19 +557,7 @@ StatusOr HloComputation::DeepCopyHelper( // Array shape. TF_RET_CHECK(ShapeUtil::IsArray(instruction->shape())); - if (indices_to_copy == nullptr || indices_to_copy->element(*index)) { - // Use kCopy to copy array elements - HloInstruction* copy = AddInstruction(HloInstruction::CreateUnary( - instruction->shape(), HloOpcode::kCopy, instruction)); - if (copies_added != nullptr) { - *copies_added->mutable_element(*index) = copy; - } - return copy; - } else { - // Elements which are not to be copied are passed through - // transparently. - return instruction; - } + return copy_leaf(instruction, *index, this); } StatusOr HloComputation::DeepCopyInstruction( @@ -591,7 +579,36 @@ StatusOr HloComputation::DeepCopyInstruction( } ShapeIndex index; - return DeepCopyHelper(instruction, indices_to_copy, copies_added, &index); + auto copy_leaf = [indices_to_copy, copies_added]( + HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* computation) { + if (indices_to_copy == nullptr || indices_to_copy->element(leaf_index)) { + HloInstruction* copy = computation->AddInstruction( + HloInstruction::CreateUnary(leaf->shape(), HloOpcode::kCopy, leaf)); + if (copies_added != nullptr) { + *copies_added->mutable_element(leaf_index) = copy; + } + return copy; + } + // Elements which are not to be copied are passed through + // transparently. + return leaf; + }; + return DeepCopyHelper(instruction, &index, copy_leaf); +} + +StatusOr HloComputation::DeepCopyInstructionWithCustomCopier( + HloInstruction* instruction, + const std::function< + HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* computation)>& copy_leaf) { + if (instruction->parent() != this) { + return FailedPrecondition( + "Can't deep copy instruction %s: instruction is not in computation %s", + instruction->name().c_str(), name().c_str()); + } + ShapeIndex index; + return DeepCopyHelper(instruction, &index, copy_leaf); } ProgramShape HloComputation::ComputeProgramShape() const { @@ -664,7 +681,7 @@ std::unique_ptr HloComputation::ComputeReachability() inputs.assign(hlo->operands().begin(), hlo->operands().end()); inputs.insert(inputs.end(), hlo->control_predecessors().begin(), hlo->control_predecessors().end()); - result->SetReachabilityToUnion(inputs, hlo); + result->FastSetReachabilityToUnion(inputs, hlo); } return result; } @@ -881,4 +898,13 @@ void HloComputation::UniquifyName(NameUniquer* name_uniquer) { name_ = name_uniquer->GetUniqueName(name_); } +HloInstruction* HloComputation::GetInstructionWithName( + tensorflow::StringPiece name) { + auto instructions_in_computation = instructions(); + auto it = c_find_if(instructions_in_computation, [&](HloInstruction* instr) { + return instr->name() == name; + }); + return it == instructions_in_computation.end() ? nullptr : *it; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index c1c3e79ebc789eff0873515c5fffd11089b92043..49ed65910f519810740b89760ad815f287e59a91 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_ +#include #include #include #include @@ -254,6 +255,14 @@ class HloComputation { const ShapeTree* indices_to_copy = nullptr, ShapeTree* copies_added = nullptr); + // As above, but uses a custom function to copy the leaf nodes, which could + // create alternative HLOs other than kCopy, or even pass-throughs. + StatusOr DeepCopyInstructionWithCustomCopier( + HloInstruction* instruction, + const std::function< + HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* computation)>& copy_leaf); + // Computes and returns the ProgramShape of this computation (shape of // parameters and result with layout). ProgramShape ComputeProgramShape() const; @@ -356,6 +365,10 @@ class HloComputation { unique_id_ = id; } + // Returns the instruction in this computation that has name `name`. Returns + // null if there is no such computation. + HloInstruction* GetInstructionWithName(tensorflow::StringPiece name); + int64 unique_id() const { return unique_id_; } private: @@ -378,8 +391,10 @@ class HloComputation { // Internal helper for recursive copying of an instruction. Creates and // returns a deep copy of the given instruction. StatusOr DeepCopyHelper( - HloInstruction* instruction, const ShapeTree* indices_to_copy, - ShapeTree* copies_added, ShapeIndex* index); + HloInstruction* instruction, ShapeIndex* index, + const std::function< + HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* computation)>& copy_leaf); // Internal helper to collect unreachable roots. std::vector CollectUnreachableRoots() const; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index a8f3f0e9c2dca8fb97ebc8f8c9dd80fcf7f4de4a..e4c547033139185d5dd4ef37db2d22a6431c1102 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -118,7 +118,7 @@ TEST_F(HloComputationTest, PostOrderSingleton) { // Test GetInstructionPostOrder for a computation with one instruction. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant)); @@ -129,7 +129,7 @@ TEST_F(HloComputationTest, PostOrderSimple) { // instructions. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto negate1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto negate2 = builder.AddInstruction( @@ -144,7 +144,7 @@ TEST_F(HloComputationTest, PostOrderTrace) { // Test GetInstructionPostOrder for a computation with a trace instruction. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto negate1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto trace = @@ -163,13 +163,13 @@ TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) { // which are not connected. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), @@ -181,11 +181,11 @@ TEST_F(HloComputationTest, PostOrderWithMultipleRoots) { // which are not connected. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -205,11 +205,11 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { // computation has multiple roots (dead code). auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); // Add three disconnected add expressions. builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, constant1, constant2)); @@ -256,7 +256,7 @@ TEST_F(HloComputationTest, DeepCopyArray) { // Test that DeepCopyInstruction properly copies an array. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(constant).ValueOrDie(); @@ -268,9 +268,9 @@ TEST_F(HloComputationTest, DeepCopyTuple) { // Test that DeepCopyInstruction properly copies a tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -289,7 +289,7 @@ TEST_F(HloComputationTest, DeepCopyArrayAtIndices) { // copy are specified. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto computation = builder.Build(); { @@ -314,9 +314,9 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { // specified by the given indices. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto computation = builder.Build(); @@ -375,7 +375,7 @@ TEST_F(HloComputationTest, DeepCopyToken) { // Test that DeepCopyInstruction properly handles tokens which should not be // copied. auto builder = HloComputation::Builder(TestName()); - auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(token).ValueOrDie(); @@ -388,9 +388,9 @@ TEST_F(HloComputationTest, DeepCopyTokenTuple) { // Test that DeepCopyInstruction properly handles tokens which should not be // copied. auto builder = HloComputation::Builder(TestName()); - auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({token, constant})); auto module = CreateNewModule(); @@ -407,7 +407,7 @@ TEST_F(HloComputationTest, CycleDetection) { // Test whether the visitor can detect cycles in the graph. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto add = builder.AddInstruction( @@ -433,7 +433,7 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { // twice. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto dead_negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto dead_add = builder.AddInstruction(HloInstruction::CreateBinary( @@ -456,9 +456,9 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { TEST_F(HloComputationTest, CloneWithControlDependency) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); @@ -502,9 +502,9 @@ TEST_F(HloComputationTest, Reachability) { // There is a control dependency from 'add' to 'exp'. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); auto negate = builder.AddInstruction( @@ -607,13 +607,14 @@ TEST_F(HloComputationTest, Stringification) { auto* computation = module->AddEntryComputation(builder.Build()); auto options = HloPrintOptions().set_print_metadata(false); - EXPECT_EQ(computation->ToString(options), - R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + const string expected_computation = + R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { %x = f32[5,10]{1,0} parameter(0) %y = f32[20,10]{1,0} parameter(1) %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"); +})"; + EXPECT_EQ(computation->ToString(options), expected_computation); } TEST_F(HloComputationTest, StringificationIndent) { @@ -639,13 +640,14 @@ TEST_F(HloComputationTest, StringificationIndent) { auto options = HloPrintOptions().set_print_metadata(false).set_indent_amount(2); - EXPECT_EQ(computation->ToString(options), - R"( %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + const string expected_computation = + R"( %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { %x = f32[5,10]{1,0} parameter(0) %y = f32[20,10]{1,0} parameter(1) %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} - })"); + })"; + EXPECT_EQ(computation->ToString(options), expected_computation); } TEST_F(HloComputationTest, StringificationCanonical) { @@ -670,21 +672,23 @@ TEST_F(HloComputationTest, StringificationCanonical) { auto* computation = module->AddEntryComputation(builder.Build()); auto options = HloPrintOptions().set_print_metadata(false); - EXPECT_EQ(computation->ToString(options), - R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + const string expected_computation1 = + R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { %x = f32[5,10]{1,0} parameter(0) %y = f32[20,10]{1,0} parameter(1) %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"); +})"; + EXPECT_EQ(computation->ToString(options), expected_computation1); options = HloPrintOptions().Canonical(); - EXPECT_EQ(computation->ToString(options), R"(TransposeDot { + const string expected_computation2 = R"(TransposeDot { tmp_0 = f32[5,10]{1,0} parameter(0) tmp_1 = f32[20,10]{1,0} parameter(1) tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"); +})"; + EXPECT_EQ(computation->ToString(options), expected_computation2); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 35ecd4428d0dfde2de445ea34472d2c78148c6c9..7229031c0c7f8bd374cfb495c7d8c11e9ca8b95e 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" @@ -51,14 +51,18 @@ StatusOr HloConstantFolding::Run(HloModule* module) { computation->root_instruction() != instruction) { continue; } - // Skip Constant, Parameter, Reduce operation. + // Skip Constant, Parameter, Reduce, and AfterAll operation. // TODO(b/35975797): Enable Reduce operation once arbitrary computation // are supported by the evaluator. // TODO(b/64407269): Enable Tuple once the timeout issue is resolved. + // TODO(b/110532604): Enable AfterAll once AfterAll requires at least one + // operand in which case constant folding will be impossible and this + // special case is not necessary. if (instruction->opcode() == HloOpcode::kParameter || instruction->opcode() == HloOpcode::kConstant || instruction->opcode() == HloOpcode::kTuple || - instruction->opcode() == HloOpcode::kReduce) { + instruction->opcode() == HloOpcode::kReduce || + instruction->opcode() == HloOpcode::kAfterAll) { continue; } // Skip instructions with non-constant operands. diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 5d05ccfc0b223d8749a2577ba1bf96b1ab3e761b..64a42c1efc0c788ae8e66fb72b2d9aecec179082 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -41,7 +41,7 @@ using HloConstantFoldingTest = HloTestBase; TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input)); @@ -62,7 +62,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); @@ -82,8 +82,8 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { HloComputation::Builder builder(TestName()); - HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({42.0f, 19.0f}))); + HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({42.0f, 19.0f}))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input)); @@ -120,7 +120,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) { for (auto csize : test_config.concat_sizes) { dimensions[test_config.concat_dimension] = csize; concat_size += csize; - auto literal = Literal::CreateFromDimensions(F32, dimensions); + auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions); HloInstruction* insn = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); operands.push_back(insn); @@ -149,7 +149,7 @@ TEST_F(HloConstantFoldingTest, Slice) { const int64 slice_limits[] = {10, 8, 6, 5, 9}; const int64 slice_strides[] = {1, 1, 1, 1, 1}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - Literal::CreateRandomLiteral( + LiteralUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -172,7 +172,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { HloComputation::Builder builder(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - Literal::CreateRandomLiteral( + LiteralUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); auto literal_clone = literal->Literal::CloneToUnique(); HloInstruction* literal_instruction = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 8955e26d5cd1bf30f965395750f5078d070a6906..1bbb0ff08e26f626f4c3992a5f20ec4990f7db2d 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -49,9 +49,9 @@ Status HloCostAnalysis::Preprocess(const HloInstruction* hlo) { // The default number of bytes accessed for an instruction is the sum of the // sizes of the inputs and outputs. The default ShapeUtil::ByteSizeOf does not // handle opaque types. - float bytes_accessed = shape_size_(hlo->shape()); + float bytes_accessed = GetShapeSize(hlo->shape()); for (const HloInstruction* operand : hlo->operands()) { - bytes_accessed += shape_size_(operand->shape()); + bytes_accessed += GetShapeSize(operand->shape()); } current_properties_[kBytesAccessedKey] = bytes_accessed; @@ -121,6 +121,13 @@ Status HloCostAnalysis::HandleElementwiseOp( } } +int64 HloCostAnalysis::GetShapeSize(const Shape& shape) const { + if (!LayoutUtil::HasLayout(shape)) { + return 0; + } + return shape_size_(shape); +} + Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) { return HandleElementwiseOp(hlo); } @@ -155,6 +162,10 @@ Status HloCostAnalysis::HandleConstant(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleIota(const HloInstruction*) { + return Status::OK(); +} + Status HloCostAnalysis::HandleGetTupleElement(const HloInstruction*) { // GetTupleElement forwards a pointer and does not touch each element in the // output. @@ -164,7 +175,11 @@ Status HloCostAnalysis::HandleGetTupleElement(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleSelect(const HloInstruction*) { +Status HloCostAnalysis::HandleSelect(const HloInstruction* hlo) { + return HandleElementwiseOp(hlo); +} + +Status HloCostAnalysis::HandleTupleSelect(const HloInstruction*) { return Status::OK(); } @@ -173,21 +188,21 @@ Status HloCostAnalysis::HandleReverse(const HloInstruction*) { } Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) { - current_properties_[kBytesAccessedKey] = shape_size_(slice->shape()) * 2; + current_properties_[kBytesAccessedKey] = GetShapeSize(slice->shape()) * 2; return Status::OK(); } Status HloCostAnalysis::HandleDynamicSlice( const HloInstruction* dynamic_slice) { current_properties_[kBytesAccessedKey] = - shape_size_(dynamic_slice->shape()) * 2; + GetShapeSize(dynamic_slice->shape()) * 2; return Status::OK(); } Status HloCostAnalysis::HandleDynamicUpdateSlice( const HloInstruction* dynamic_update_slice) { current_properties_[kBytesAccessedKey] = - shape_size_(dynamic_update_slice->operand(1)->shape()) * 2; + GetShapeSize(dynamic_update_slice->operand(1)->shape()) * 2; return Status::OK(); } @@ -196,7 +211,7 @@ Status HloCostAnalysis::HandleTuple(const HloInstruction* tuple) { // through them). The memory touched is then only the size of the output // index table of the tuple. - current_properties_[kBytesAccessedKey] = shape_size_(tuple->shape()); + current_properties_[kBytesAccessedKey] = GetShapeSize(tuple->shape()); return Status::OK(); } @@ -518,12 +533,25 @@ Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) { // TODO(b/33004697): Compute correct cost here, taking the actual number of // replicas into account. double flops = 0.0; - ShapeUtil::ForEachSubshape( - crs->shape(), [&, this](const Shape& subshape, const ShapeIndex&) { - if (ShapeUtil::IsArray(subshape)) { - flops += ShapeUtil::ElementsIn(subshape); - } - }); + ShapeUtil::ForEachSubshape(crs->shape(), + [&](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsArray(subshape)) { + flops += ShapeUtil::ElementsIn(subshape); + } + }); + current_properties_[kFlopsKey] = flops; + return Status::OK(); +} + +Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) { + // TODO(b/110096724): Compute correct cost here. + double flops = 0.0; + ShapeUtil::ForEachSubshape(hlo->shape(), + [&](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsArray(subshape)) { + flops += ShapeUtil::ElementsIn(subshape); + } + }); current_properties_[kFlopsKey] = flops; return Status::OK(); } @@ -538,15 +566,9 @@ Status HloCostAnalysis::HandleRng(const HloInstruction* random) { } Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { - // Compute the properties of the fused expression and attribute them to the - // fusion node. Use a dummy shape_size to avoid any errors from trying to - // calculate the size of a shape that does not have a layout, since nodes - // inside fusion nodes do not necessarily have a layout assigned. - ShapeSizeFunction shape_size = [](const Shape& shape) { return 0; }; TF_ASSIGN_OR_RETURN( current_properties_, - ProcessSubcomputation(fusion->fused_instructions_computation(), - &shape_size)); + ProcessSubcomputation(fusion->fused_instructions_computation())); // Fusion nodes that produce a tuple also produce the entries in the tuple. // Ignore the memory accessed inside fused ops, since fusion is supposed to @@ -555,11 +577,11 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { ShapeUtil::ForEachSubshape( fusion->shape(), [this](const Shape& subshape, const ShapeIndex& /*shape_index*/) { - current_properties_[kBytesAccessedKey] += shape_size_(subshape); + current_properties_[kBytesAccessedKey] += GetShapeSize(subshape); }); for (const HloInstruction* operand : fusion->operands()) { - current_properties_[kBytesAccessedKey] += shape_size_(operand->shape()); + current_properties_[kBytesAccessedKey] += GetShapeSize(operand->shape()); } return Status::OK(); @@ -640,6 +662,11 @@ Status HloCostAnalysis::HandleGather(const HloInstruction* gather) { return Status::OK(); } +Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) { + // TODO(b/32945756): Compute the properties of the sub-computation. + return Status::OK(); +} + Status HloCostAnalysis::FinishVisit(const HloInstruction*) { return Status::OK(); } @@ -677,11 +704,8 @@ float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const { } StatusOr HloCostAnalysis::ProcessSubcomputation( - HloComputation* computation, const ShapeSizeFunction* shape_size) { - if (shape_size == nullptr) { - shape_size = &shape_size_; - } - HloCostAnalysis visitor(*shape_size, per_second_rates_); + HloComputation* computation) { + HloCostAnalysis visitor(shape_size_, per_second_rates_); TF_RETURN_IF_ERROR(computation->Accept(&visitor)); return visitor.properties(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 44e5df587c4bf0b3004c8d624c45d42d258c3661..193a04bea0831de2b3aca19b17a445ad73e02e49 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -52,9 +52,11 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleElementwiseUnary(const HloInstruction* hlo) override; Status HandleElementwiseBinary(const HloInstruction* hlo) override; Status HandleConstant(const HloInstruction* constant) override; + Status HandleIota(const HloInstruction* iota) override; Status HandleGetTupleElement( const HloInstruction* get_tuple_element) override; - Status HandleSelect(const HloInstruction* select) override; + Status HandleSelect(const HloInstruction* hlo) override; + Status HandleTupleSelect(const HloInstruction* hlo) override; Status HandleCompare(const HloInstruction* compare) override; Status HandleClamp(const HloInstruction* clamp) override; Status HandleReducePrecision(const HloInstruction* hlo) override; @@ -69,6 +71,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleConvolution(const HloInstruction* convolution) override; Status HandleFft(const HloInstruction* fft) override; Status HandleCrossReplicaSum(const HloInstruction* crs) override; + Status HandleAllToAll(const HloInstruction* hlo) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; Status HandleHostCompute(const HloInstruction* host_compute) override; @@ -102,6 +105,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleWhile(const HloInstruction* xla_while) override; Status HandleConditional(const HloInstruction* conditional) override; Status HandleGather(const HloInstruction* gather) override; + Status HandleScatter(const HloInstruction* scatter) override; Status FinishVisit(const HloInstruction* root) override; Status Preprocess(const HloInstruction* hlo) override; @@ -147,11 +151,8 @@ class HloCostAnalysis : public ConstDfsHloVisitor { const Properties& per_second_rates); // Returns the properties computed from visiting the computation rooted at the - // given hlo. Uses shape_size_ to calculate shape sizes if shape_size is null, - // otherwise uses shape_size_. - StatusOr ProcessSubcomputation( - HloComputation* computation, - const ShapeSizeFunction* shape_size = nullptr); + // given hlo. + StatusOr ProcessSubcomputation(HloComputation* computation); // Utility function to handle all element-wise operations. Status HandleElementwiseOp(const HloInstruction* hlo_instruction); @@ -168,6 +169,10 @@ class HloCostAnalysis : public ConstDfsHloVisitor { static float GetPropertyForHlo(const HloInstruction& hlo, const string& key, const HloToProperties& hlo_to_properties); + // Decorates shape_size_ by returning 0 immediately if the shape does not have + // a layout. + int64 GetShapeSize(const Shape& shape) const; + // Function which computes the size of the top-level of a given shape (not // including nested elements, if any). If null then bytes_accessed methods // return an error. diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 9fc4c48226fa5307f5e030a612f3957756827e37..2c854eea18642eb7cb081b4fdfe3bc83627e41ae 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/service.h" @@ -338,13 +338,13 @@ TEST_F(FusionCostAnalysis, LoopFusion) { // tuple = Tuple({sub, sub, mul, C1}) HloComputation::Builder builder(TestName()); auto c1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2))); auto c2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2))); auto c3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( /*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1, c2)); @@ -391,9 +391,9 @@ TEST_F(FusionCostAnalysis, NoLayout) { HloComputation::Builder builder(TestName()); auto c1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR4FromArray4D(Array4D(2, 3, 4, 5)))); + LiteralUtil::CreateR4FromArray4D(Array4D(2, 3, 4, 5)))); auto c2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); auto broadcast = builder.AddInstruction( HloInstruction::CreateBroadcast(shape_without_layout, c2, {1})); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 0fb65c845a6d4407c81171f6c1569fee98b1d16d..90d2be118d94d52135820e5b8138fcb06389c684 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -261,9 +262,9 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, padding_config_dim.set_edge_padding_high(zeros_to_append); *padding_config.add_dimensions() = padding_config_dim; - HloInstruction* zero = - computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(Literal::Zero(operand->shape().element_type())))); + HloInstruction* zero = computation->AddInstruction( + HloInstruction::CreateConstant(MakeUnique( + LiteralUtil::Zero(operand->shape().element_type())))); return MakePadHlo(operand, zero, padding_config); } @@ -272,7 +273,7 @@ StatusOr BroadcastZeros( ArraySlice broadcast_dimensions) { HloInstruction* zero = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(Literal::Zero(element_type)))); + MakeUnique(LiteralUtil::Zero(element_type)))); return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{}, /*result_shape_bounds=*/broadcast_dimensions); } diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index 7e7c4f95fed737f40064224717f409b934e4ff27..60d3e71757d5ce31e025c744e089ff56091d9a43 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -60,8 +60,8 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, evaluator.Evaluate>( - *module, {Literal::CreateR1({3, 4})})); - CHECK_EQ(*result_literal, *Literal::CreateR1({3, 4})); + *module, {LiteralUtil::CreateR1({3, 4})})); + CHECK_EQ(*result_literal, *LiteralUtil::CreateR1({3, 4})); } TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { @@ -82,10 +82,10 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { std::unique_ptr result_literal, evaluator.Evaluate>( *module, - {Literal::CreateR3( + {LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})})); CHECK_EQ(*result_literal, - *Literal::CreateR2( + *LiteralUtil::CreateR2( {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}})); } @@ -103,10 +103,11 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { entry_computation->set_root_instruction(with_1_degenerate_dim_prepended); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {Literal::CreateR1({9, 10})})); - CHECK_EQ(*result_literal, *Literal::CreateR2({{9, 10}})); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {LiteralUtil::CreateR1({9, 10})})); + CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{9, 10}})); } TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { @@ -123,10 +124,11 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { entry_computation->set_root_instruction(with_2_degenerate_dims_prepended); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {Literal::CreateR1({9, 10})})); - CHECK_EQ(*result_literal, *Literal::CreateR3({{{9, 10}}})); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {LiteralUtil::CreateR1({9, 10})})); + CHECK_EQ(*result_literal, *LiteralUtil::CreateR3({{{9, 10}}})); } TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { @@ -145,8 +147,8 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, evaluator.Evaluate>( - *module, {Literal::CreateR0(9)})); - CHECK_EQ(*result_literal, *Literal::CreateR2({{9}})); + *module, {LiteralUtil::CreateR0(9)})); + CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{9}})); } TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { @@ -166,9 +168,9 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr result_literal, evaluator.Evaluate>( - *module, {Literal::CreateR1({1, 2, 3, 4, 5, 6})})); + *module, {LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6})})); CHECK_EQ(*result_literal, - *Literal::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); + *LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); } TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { @@ -188,8 +190,8 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, evaluator.Evaluate>( - *module, {Literal::CreateR1({3, 4})})); - CHECK_EQ(*result_literal, *Literal::CreateR1({0, 0, 0, 3, 4, 0})); + *module, {LiteralUtil::CreateR1({3, 4})})); + CHECK_EQ(*result_literal, *LiteralUtil::CreateR1({0, 0, 0, 3, 4, 0})); } TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { @@ -209,8 +211,8 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, evaluator.Evaluate>( - *module, {Literal::CreateR0(0)})); - CHECK_EQ(*result_literal, *Literal::CreateR2({{0, 0}, {0, 0}})); + *module, {LiteralUtil::CreateR0(0)})); + CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{0, 0}, {0, 0}})); } TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { @@ -230,9 +232,9 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, evaluator.Evaluate>( - *module, {Literal::CreateR0(0.0f)})); + *module, {LiteralUtil::CreateR0(0.0f)})); CHECK_EQ(*result_literal, - *Literal::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); + *LiteralUtil::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index a0ee8896230d6dcacb5a8eb607fc00ae5226cfa5..06484f4012fc091f70df7bc8ec231ce3fcf89669 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -24,7 +24,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_domain_map.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -143,10 +143,8 @@ StatusOr HloCSE::Run(HloModule* module) { if (instruction->operand_count() == 0) { continue; } - // Skip instructions which have side effects or are a domain (which must - // not be CSE-ed). - if (instruction->HasSideEffect() || - instruction->opcode() == HloOpcode::kDomain) { + // Skip instructions which have side effects. + if (instruction->HasSideEffect()) { continue; } diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 16db374566c727f1f3efe2a6d419f1f3caf0aaf1..90fbaa37c5a70a78a9a818b4a8968f3406c671b1 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -53,9 +53,9 @@ TEST_F(HloCseTest, CombineTwoConstants) { // Test that two identical constants are commoned. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -72,7 +72,7 @@ TEST_F(HloCseTest, CombineTwoConstants) { EXPECT_EQ(42.0f, constant->literal().Get({})); auto result = ExecuteAndTransfer(std::move(module), {}); - auto expected = Literal::CreateR0(84.0); + auto expected = LiteralUtil::CreateR0(84.0); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } @@ -81,10 +81,10 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { // the pass is not layout sensitive. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -104,7 +104,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { EXPECT_THAT(add, op::Add(first_operand, first_operand)); auto result = ExecuteAndTransfer(std::move(module), {}); - auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); + auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } @@ -113,10 +113,10 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { // if the pass is layout sensitive. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -134,7 +134,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { EXPECT_THAT(add, op::Add(constant1, constant2)); auto result = ExecuteAndTransfer(std::move(module), {}); - auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); + auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } @@ -144,20 +144,20 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) { auto builder = HloComputation::Builder(TestName()); std::vector constants; constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42)))); constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42)))); constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0)))); constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0)))); constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0)))); constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)))); // Duplicate the float constant to verify something happens. constants.push_back(builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f)))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)))); const Shape shape_r0 = ShapeUtil::MakeShape(F32, {}); for (int64 i = 0; i < constants.size(); ++i) { @@ -188,13 +188,13 @@ TEST_F(HloCseTest, NonscalarConstants) { // Test that identical nonscalar constants are merged. auto builder = HloComputation::Builder(TestName()); auto common_constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto common_constant2 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); // Create a constant which has the same shape but a different value. auto uncommon_constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}))); + LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}))); // Tie the constants together with a tuple. This makes it easier to refer to // the constant instructions via their use. @@ -223,7 +223,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { // Test that three identical instructions are commoned. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kExp, constant)); auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -239,7 +239,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { EXPECT_EQ(5, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3)); - HloCSE cse(/*is_layout_sensitive=*/false); + HloCSE cse(/*is_layout_sensitive=*/true); EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); @@ -248,12 +248,189 @@ TEST_F(HloCseTest, IdenticalInstructions) { EXPECT_THAT(tuple, op::Tuple(first_operand, first_operand, first_operand)); } +// Test two identical while loops with same inputs +TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) { + auto module = ParseHloString(R"( + HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput + + %body (param: (f32[], f32[])) -> (f32[], f32[]) { + %param = (f32[], f32[]) parameter(0) + %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param), +index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param), +index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1) + ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add) + } + + %condition (param.1: (f32[], f32[])) -> pred[] { + %param.1 = (f32[], f32[]) parameter(0) + ROOT %constant = pred[] constant(false) + } + + %condition.1 (param.2: (f32[], f32[])) -> pred[] { + %param.2 = (f32[], f32[]) parameter(0) + ROOT %constant.1 = pred[] constant(false) + } + + ENTRY %WhileLoopsIdenticalConditionsAndBodiesSameInput () -> (f32[], f32[]) +{ %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) %tuple.1 = +(f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3) %while = (f32[], +f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT +%while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), +condition=%condition.1, body=%body + } + )") + .ValueOrDie(); + + auto computation = module->entry_computation(); + + EXPECT_EQ(5, computation->instruction_count()); + HloCSE cse(true); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_EQ(4, computation->instruction_count()); +} + +// Test two while loops with same conditions, same inputs, but different +// bodies +TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) { + auto module = ParseHloString(R"( + HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies + + %body (param: (f32[], f32[])) -> (f32[], f32[]) { + %param = (f32[], f32[]) parameter(0) + %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param), +index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param), +index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1) + ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add) + } + + %body2 (param.1: (f32[], f32[])) -> (f32[], f32[]) { + %param.1 = (f32[], f32[]) parameter(0) + %get-tuple-element.2 = f32[] get-tuple-element((f32[], f32[]) %param.1), +index=0 %get-tuple-element.3 = f32[] get-tuple-element((f32[], f32[]) %param.1), +index=1 %sub = f32[] subtract(f32[] %get-tuple-element.2, f32[] +%get-tuple-element.3) ROOT %tuple.2 = (f32[], f32[]) tuple(f32[] +%get-tuple-element.2, f32[] %sub) + } + + %condition (param.2: (f32[], f32[])) -> pred[] { + %param.2 = (f32[], f32[]) parameter(0) + ROOT %constant = pred[] constant(false) + } + + %condition.1 (param.3: (f32[], f32[])) -> pred[] { + %param.3 = (f32[], f32[]) parameter(0) + ROOT %constant.1 = pred[] constant(false) + } + + ENTRY %WhileLoopsIdenticalConditionsSameInputAndDifferentBodies () -> +(f32[], f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) + %tuple.1 = (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3) + %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), +condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], +f32[]) %tuple.1), condition=%condition.1, body=%body2 + } + )") + .ValueOrDie(); + + auto computation = module->entry_computation(); + + EXPECT_EQ(5, computation->instruction_count()); + HloCSE cse(true); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_EQ(5, computation->instruction_count()); +} + +// Test two identical while loops with different inputs +TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) { + auto module = ParseHloString(R"( + HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput + + %body (param: (f32[], f32[])) -> (f32[], f32[]) { + %param = (f32[], f32[]) parameter(0) + %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param), +index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param), +index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1) + ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add) + } + + %condition (param.1: (f32[], f32[])) -> pred[] { + %param.1 = (f32[], f32[]) parameter(0) + ROOT %constant = pred[] constant(false) + } + + %condition.1 (param.2: (f32[], f32[])) -> pred[] { + %param.2 = (f32[], f32[]) parameter(0) + ROOT %constant.1 = pred[] constant(false) + } + + ENTRY %WhileLoopsIdenticalConditionsAndBodiesDifferentInput () -> (f32[], +f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) + %tuple.1 = (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3) + %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), +condition=%condition, body=%body %constant.4 = f32[] constant(1) %constant.5 = +f32[] constant(2) %tuple.2 = (f32[], f32[]) tuple(f32[] %constant.4, f32[] +%constant.5) ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.2), +condition=%condition.1, body=%body + } + + )") + .ValueOrDie(); + + auto computation = module->entry_computation(); + + EXPECT_EQ(8, computation->instruction_count()); + HloCSE cse(true); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_EQ(8, computation->instruction_count()); +} + +// Test two while loops with identical bodies and same inputs, but different +// conditions +TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferntConditions) { + auto module = ParseHloString(R"( + HloModule WhileLoopsIdenticalBodiesAndInputDifferntConditions + + %body (param: (f32[], f32[])) -> (f32[], f32[]) { + %param = (f32[], f32[]) parameter(0) + %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param), +index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param), +index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1) + ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add) + } + + %condition (param.1: (f32[], f32[])) -> pred[] { + %param.1 = (f32[], f32[]) parameter(0) + ROOT %constant = pred[] constant(false) + } + + %condition.1 (param.2: (f32[], f32[])) -> pred[] { + %param.2 = (f32[], f32[]) parameter(0) + ROOT %constant.1 = pred[] constant(true) + } + + ENTRY %WhileLoopsIdenticalBodiesAndInputDifferntConditions () -> (f32[], +f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) + %tuple.1 = (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3) + %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), +condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], +f32[]) %tuple.1), condition=%condition.1, body=%body + })") + .ValueOrDie(); + + auto computation = module->entry_computation(); + + EXPECT_EQ(5, computation->instruction_count()); + HloCSE cse(true); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_EQ(5, computation->instruction_count()); +} + TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { // Test that two identical instructions with different layouts are *not* // commoned if the pass is layout sensitive. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kExp, constant)); @@ -284,7 +461,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { // the pass is layout insensitive. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kExp, constant)); @@ -362,7 +539,7 @@ TEST_F(HloCseTest, IdenticalExpressions) { // The *1 instructions should be merged with the *2 instructions. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto negate1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kNegate, constant)); @@ -400,9 +577,9 @@ TEST_F(HloCseTest, DoNotCombineRng) { // Test that two RNG ops are not commoned. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); auto rng1 = builder.AddInstruction(HloInstruction::CreateRng( ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM, {constant1, constant2})); @@ -442,9 +619,9 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); auto builder = HloComputation::Builder(TestName() + "_rng_fun"); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); auto rng = builder.AddInstruction(HloInstruction::CreateRng( scalar_shape, RandomDistribution::RNG_UNIFORM, {constant1, constant2})); auto param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -459,7 +636,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({5.0f}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({5.0f}))); auto rng1 = builder.AddInstruction( HloInstruction::CreateMap(constant->shape(), {constant}, rng_function)); auto rng2 = builder.AddInstruction( @@ -521,9 +698,9 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { // in this case) are not collapsed. auto builder = HloComputation::Builder(TestName()); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -536,5 +713,40 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { EXPECT_EQ(2, computation->instruction_count()); } +TEST_F(HloCseTest, Domain) { + auto module = ParseHloString(R"( +HloModule module +ENTRY %entry { + %param = f32[] parameter(0), sharding={maximal device=0} + %domain.0 = f32[] domain(%param), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + %domain.1 = f32[] domain(%param), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + %domain.2 = f32[] domain(%param), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=2}} + %negate.0 = f32[] negate(%domain.0) + %negate.1 = f32[] negate(%domain.1) + %negate.2 = f32[] negate(%domain.2) + %domain.3 = f32[] domain(%negate.0), + domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}} + %domain.4 = f32[] domain(%negate.1), + domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}} + %domain.5 = f32[] domain(%negate.2), + domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}} + %add = f32[] add(%domain.3, %domain.4) + ROOT %sub = f32[] subtract(%add, %domain.5) +})") + .ValueOrDie(); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + LOG(INFO) << "AAAAA " << module->ToString(); + const HloInstruction* sub = module->entry_computation()->root_instruction(); + const HloInstruction* add = sub->operand(0); + EXPECT_EQ(add->operand(0), add->operand(1)); + EXPECT_NE(add->operand(0), sub->operand(1)); + EXPECT_NE(add->operand(1), sub->operand(1)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 8a4a9b59868eb436842c9a819ffa8d6ec2054eee..bbfb0c253f583b633c4b2c34b2f068b563d3d9e0 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -398,18 +398,17 @@ bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) { bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) { CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone); bool changed = false; - // RecvDone forwards the operand value at {0} to the output. + // RecvDone forwards the operand value at {0} to element {0} of its output. for (auto& pair : GetInstructionValueSet(recv_done)) { ShapeIndex& index = pair.first; HloValueSet& value_set = pair.second; - ShapeIndex operand_index = {0}; - for (int64 i : index) { - operand_index.push_back(i); + if (index.empty() || index[0] != 0) { + continue; } const HloValueSet& operand_value_set = - GetValueSet(recv_done->operand(0), operand_index); + GetValueSet(recv_done->operand(0), index); if (value_set != operand_value_set) { value_set = operand_value_set; changed = true; @@ -578,17 +577,17 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) { } } -bool HloDataflowAnalysis::UpdateSelectValueSet(HloInstruction* select) { - CHECK_EQ(select->opcode(), HloOpcode::kSelect); - // A phi value is not defined at a kSelect instruction because kSelect does - // not create a new value. Rather it forwards a value from its operands. This - // contrasts with kWhile instruction (which does define a phi value) which has - // in-place update semantics. +bool HloDataflowAnalysis::UpdateTupleSelectValueSet(HloInstruction* select) { + CHECK_EQ(select->opcode(), HloOpcode::kTupleSelect); + // A phi value is not defined at a kTupleSelect instruction because + // kTupleSelect does not create a new value. Rather it forwards a value from + // its operands. This contrasts with kWhile instruction (which does define a + // phi value) which has in-place update semantics. bool changed = false; for (auto& pair : GetInstructionValueSet(select)) { const ShapeIndex& index = pair.first; if (index.empty()) { - // kSelect copies (not forwards) the top-level value. + // kTupleSelect copies (not forwards) the top-level value. continue; } HloValueSet& value_set = pair.second; @@ -650,8 +649,8 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( return UpdateCopyValueSet(instruction); case HloOpcode::kGetTupleElement: return UpdateGetTupleElementValueSet(instruction); - case HloOpcode::kSelect: - return UpdateSelectValueSet(instruction); + case HloOpcode::kTupleSelect: + return UpdateTupleSelectValueSet(instruction); case HloOpcode::kTuple: return UpdateTupleValueSet(instruction); case HloOpcode::kParameter: @@ -850,21 +849,25 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { } break; case HloOpcode::kCopy: - case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: case HloOpcode::kTuple: // These instructions only define their top-level values. Any other // values flow from their operands. define_top_level_only(); break; case HloOpcode::kRecvDone: - // RecvDone aliases its input tuple element {0}, therefore does not - // define any values. + // RecvDone produces a two-element tuple. Element zero aliases its + // input tuple element {0}; element one is a token. + define_value_at(/*index=*/{}); + define_value_at(/*index=*/{1}); break; case HloOpcode::kSend: - // Send produces a tuple of {aliased operand, U32 context}, therefore - // only defines the top-level tuple and the tuple element at {1}. + // Send produces a tuple of {aliased operand, U32 context, token}, + // therefore only defines the top-level tuple and the tuple elements + // at {1} and {2}. define_value_at(/*index=*/{}); define_value_at(/*index=*/{1}); + define_value_at(/*index=*/{2}); break; default: define_all_values(); @@ -1014,19 +1017,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( } if (user->opcode() == HloOpcode::kFusion) { + if (fusion_can_share_buffer_ != nullptr) { + return fusion_can_share_buffer_(user, operand); + } // Get the parameter associated with 'operand'; HloInstruction* fusion_param = user->fused_parameter(user->operand_index(operand)); const HloValue& value = GetValueDefinedAt(fusion_param, operand_index); - if (value.uses().size() != 1) { - if (MultiDynamicSliceUseShareSameIndices(value.uses())) { - return true; - } - return false; + if (MultiDynamicSliceUseShareSameIndices(value.uses())) { + return true; } - const HloUse& use = value.uses()[0]; - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop || user->fusion_kind() == HloInstruction::FusionKind::kInput) { if (user->fused_expression_root()->opcode() == @@ -1036,13 +1037,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // Returns true iff there is exactly one use of 'operand' at shape index // 'operand_index', and this singleton use is the fused root at operand // index 0. - return use.instruction == user->fused_expression_root() && - use.operand_number == 0; - } else { - return AreTransitiveUsesElementwiseOrTuple(fusion_param); + if (value.uses().size() == 1) { + const HloUse& use = value.uses()[0]; + return use.instruction == user->fused_expression_root() && + use.operand_number == 0; + } + return false; } - } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && - user->fused_expression_root()->opcode() == HloOpcode::kAdd) { + return AreTransitiveUsesElementwiseOrTuple(fusion_param); + } + if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && + user->fused_expression_root()->opcode() == HloOpcode::kAdd) { // Output fusion with kAdd fused root. // Check if one operand of kAdd fused root is kDot or kConvolution. @@ -1063,11 +1068,12 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // Returns true iff there is exactly one use of 'operand' at shape index // 'operand_index', and this singleton use is the fused root (at operand // index 'other_add_operand_index'). - return use.instruction == user->fused_expression_root() && - use.operand_number == other_add_operand_index; - } else if (fusion_can_share_buffer_ != nullptr && - fusion_can_share_buffer_(user, operand)) { - return true; + if (value.uses().size() == 1) { + const HloUse& use = value.uses()[0]; + return use.instruction == user->fused_expression_root() && + use.operand_number == other_add_operand_index; + } + return false; } } @@ -1078,6 +1084,21 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( std::vector operand_indices = user->OperandIndices(operand); return operand_indices.size() == 1 && operand_indices[0] == 0; } + if (user->opcode() == HloOpcode::kSort) { + // Only valid if there are no other users. + if (operand->users().size() != 1) { + return false; + } + // If we only sort keys, the output of sort is not a tuple, so we can always + // share the buffer. + if (user->operand_count() == 1) { + return true; + } + CHECK(!user_index.empty()); + // Only share with the right tuple element buffer. + std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && user_index[0] == operand_indices[0]; + } if (user->opcode() == HloOpcode::kCall) { // Get all uses of value defined by 'operand' at 'operand_index'. const auto& uses = GetValueDefinedAt(operand, operand_index).uses(); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 9fea218af0c4ac8a512bea5c187564a8219d041f..f4abc7a7c7dcfb223067fe946bec0c5ef32f206b 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -189,7 +189,7 @@ class HloDataflowAnalysis { bool UpdateGetTupleElementValueSet(HloInstruction* gte); bool UpdateParameterValueSet(HloInstruction* parameter); bool UpdateRecvDoneValueSet(HloInstruction* recv_done); - bool UpdateSelectValueSet(HloInstruction* select); + bool UpdateTupleSelectValueSet(HloInstruction* select); bool UpdateSendValueSet(HloInstruction* send); bool UpdateTupleValueSet(HloInstruction* tuple); bool UpdateWhileValueSet(HloInstruction* xla_while); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 0ea8bdcab680a40fd9301f2dcd5e0e176ac73d15..4755c4a0cf8d268b1c47e596a14605eb2c60b36c 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -101,9 +101,9 @@ TEST_P(HloDataflowAnalysisTest, BinaryOperation) { // Test the dataflow for a simple binary operation (Add). auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, constant1, constant2)); module_->AddEntryComputation(builder.Build()); @@ -198,9 +198,9 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) { // Verify the dataflow through a nested tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto nested_tuple = builder.AddInstruction( @@ -259,9 +259,9 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); module_->AddEntryComputation(builder.Build()); @@ -308,9 +308,9 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto call1 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); auto call2 = builder.AddInstruction(HloInstruction::CreateCall( @@ -362,9 +362,9 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto call1 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); auto call2 = builder.AddInstruction(HloInstruction::CreateCall( @@ -426,9 +426,9 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, outer_computation)); module_->AddEntryComputation(builder.Build()); @@ -493,15 +493,15 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( @@ -594,15 +594,15 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while0 = builder.AddInstruction( @@ -653,7 +653,7 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); @@ -691,9 +691,9 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto entry_while = builder.AddInstruction( @@ -780,15 +780,15 @@ TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) { auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( @@ -840,11 +840,11 @@ TEST_P(HloDataflowAnalysisTest, ArraySelect) { // Test a kSelect of an array value. auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( scalar_shape_, HloOpcode::kSelect, pred, constant1, constant2)); @@ -860,19 +860,18 @@ TEST_P(HloDataflowAnalysisTest, ArraySelect) { } TEST_P(HloDataflowAnalysisTest, TupleSelect) { - // Test a kSelect of a tuple value. Non-top-level element flow through the - // instruction. + // Test a kTupleSelect. Non-top-level element flow through the instruction. auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(4.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = @@ -883,20 +882,20 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { builder.AddInstruction(HloInstruction::CreateTuple({constant4})); const Shape tuple_shape = tuple1->shape(); auto select11 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple1)); + tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple1)); auto select12 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple2)); auto select34 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple3, tuple4)); + tuple_shape, HloOpcode::kTupleSelect, pred, tuple3, tuple4)); auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, select12, select34)); + tuple_shape, HloOpcode::kTupleSelect, pred, select12, select34)); module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - // Top-level value is always defined by a kSelect. + // Top-level value is always defined by a kTupleSelect. EXPECT_TRUE(analysis.ValueIsDefinedAt(select11)); EXPECT_TRUE(analysis.ValueIsDefinedAt(select12)); EXPECT_TRUE(analysis.ValueIsDefinedAt(select34)); @@ -937,20 +936,20 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { } TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) { - // Test kSelect of a nested tuple. + // Test kTupleSelect of a nested tuple. auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(4.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4.0))); auto constant5 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0))); auto inner_tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant2, constant3})); auto tuple1 = builder.AddInstruction( @@ -960,7 +959,7 @@ TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) { auto tuple2 = builder.AddInstruction( HloInstruction::CreateTuple({constant4, inner_tuple2})); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); module_->AddEntryComputation(builder.Build()); @@ -983,7 +982,7 @@ TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) { } TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) { - // Test a tuple-shaped kSelect feeding a kWhile instruction. HLO: + // Test a tuple-shaped kTupleSelect feeding a kWhile instruction. HLO: // // body((F32[], F32[]) %tuple_param): // %add = Add(%tuple_param{0}, %tuple_param{1}) @@ -1026,24 +1025,24 @@ TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({constant2})); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); auto gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_shape_, select, 0)); auto tuple = @@ -1089,7 +1088,7 @@ TEST_P(HloDataflowAnalysisTest, BitcastDefinesValue) { // Test the bitcast_defines_value flag to the dataflow analysis. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape_, HloOpcode::kBitcast, constant)); @@ -1158,44 +1157,50 @@ TEST_P(HloDataflowAnalysisTest, SendAndSendDone) { auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param0")); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); auto send = builder.AddInstruction( - HloInstruction::CreateSend(param, /*channel_id=*/0)); + HloInstruction::CreateSend(param, token, /*channel_id=*/0)); auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - EXPECT_EQ(analysis.values().size(), 4); + EXPECT_EQ(analysis.values().size(), 6); EXPECT_TRUE(analysis.ValueIsDefinedAt(param)); EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{})); EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0})); EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{2})); EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done)); EXPECT_THAT(HloValuesAt(send, /*index=*/{0}), UnorderedElementsAre(analysis.GetValueDefinedAt(param))); } TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) { - // Test that a RecvDone forwards its operand tuple element at {0} to the - // output. + // Test that a RecvDone forwards its operand tuple element at {0} to element + // {0} of the output. auto builder = HloComputation::Builder(TestName()); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); auto recv = builder.AddInstruction( - HloInstruction::CreateRecv(scalar_shape_, /*channel_id=*/0)); + HloInstruction::CreateRecv(scalar_shape_, token, /*channel_id=*/0)); auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - EXPECT_EQ(analysis.values().size(), 3); + EXPECT_EQ(analysis.values().size(), 7); EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{})); EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0})); EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1})); - EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done)); - EXPECT_THAT(HloValuesAt(recv_done), + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{2})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{0})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{1})); + EXPECT_THAT(HloValuesAt(recv_done, /*index=*/{0}), UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0}))); EXPECT_TRUE( analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module()); @@ -1304,13 +1309,13 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { auto body_param = body_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "body_param")); auto constant = body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto exp = body_builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant)); auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, exp, body_param)); auto dead_constant = body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto dead_negate = body_builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape_, HloOpcode::kNegate, dead_constant)); HloComputation* body = module_->AddEmbeddedComputation( @@ -1320,7 +1325,7 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "cond_param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); HloComputation* condition = module_->AddEmbeddedComputation(cond_builder.Build()); @@ -1571,11 +1576,11 @@ TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) { auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(56.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(12.0f))); auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( scalar_shape_, pred, constant1, true_computation, constant2, false_computation)); @@ -1662,11 +1667,11 @@ TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) { auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(56.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(12.0f))); auto tuple_operand = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( @@ -1792,15 +1797,15 @@ TEST_P(HloDataflowAnalysisTest, NestedConditionals) { // Build entry computation. auto builder = HloComputation::Builder(TestName()); auto pred1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); auto pred2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.2f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.2f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.3f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.3f))); auto tuple_operand = builder.AddInstruction( HloInstruction::CreateTuple({pred2, constant1, constant2})); auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( @@ -1938,9 +1943,9 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -2043,7 +2048,7 @@ TEST_F(CanShareOperandBufferWithUserTest, Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -2071,7 +2076,7 @@ TEST_F(CanShareOperandBufferWithUserTest, auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "param0")); auto index = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 0}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 0}))); auto ds = builder.AddInstruction( HloInstruction::CreateDynamicSlice(slice_shape, param, index, {1, 2, 2})); @@ -2139,9 +2144,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -2179,9 +2184,9 @@ TEST_F(CanShareOperandBufferWithUserTest, // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape_bf16, convert1, update, starts)); @@ -2227,14 +2232,56 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); } +TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + auto sort = + builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); + Shape values_shape = ShapeUtil::MakeShape(F32, {8}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + auto values = builder.AddInstruction( + HloInstruction::CreateParameter(1, values_shape, "values")); + auto sort = builder.AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // The buffer for the keys can be shared with the first tuple entry. + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {0})); + // The buffer for the values can be shared with the second tuple entry. + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(values, {}, sort, {1})); + // Verify that the buffers are not shared with the "wrong" tuple entry. + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {1})); + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(values, {}, sort, {0})); +} + TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { auto builder = HloComputation::Builder(TestName()); Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto a = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + LiteralUtil::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); auto b = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); @@ -2243,7 +2290,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto add_operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -2265,7 +2312,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -2273,7 +2320,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { HloInstruction::CreateReverse(data_shape, operand, {0, 1})); auto two = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); @@ -2293,13 +2340,13 @@ TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) { Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( data_shape, HloOpcode::kMultiply, operand, operand)); auto two = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, mul, two)); @@ -2318,7 +2365,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) { TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - auto make_cond = [this, &data_shape]() { + auto make_cond = [&data_shape]() { auto builder = HloComputation::Builder(TestName() + ".Cond"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); @@ -2327,7 +2374,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { return builder.Build(); }; - auto make_body = [this, &data_shape]() { + auto make_body = [&data_shape]() { auto builder = HloComputation::Builder(TestName() + ".Body"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); @@ -2365,7 +2412,7 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { auto sub_param = sub_builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "sub_param")); auto one = sub_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto ones = sub_builder.AddInstruction( HloInstruction::CreateBroadcast(shape, one, {1})); auto add = sub_builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 2822ecd788f624ff4e289f4b2d32fb83caf8bd77..26e3736e01270dbc6ca67647e814843aba2d1e3d 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -53,9 +53,9 @@ TEST_F(HloDceTest, NoDeadCode) { // Verify that no dead code is removed from a computation with no dead code. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(123.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -74,20 +74,21 @@ TEST_F(HloDceTest, InstructionsWithSideEffect) { // Verify that side-effect instructions (Send in this test) are not removed. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); builder.AddInstruction( - HloInstruction::CreateSend(constant, /*channel_id=*/0)); + HloInstruction::CreateSend(constant, token, /*channel_id=*/0)); builder.AddInstruction(HloInstruction::CreateTuple({})); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(3, computation->instruction_count()); + EXPECT_EQ(4, computation->instruction_count()); HloDCE dce; EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); - EXPECT_EQ(3, computation->instruction_count()); + EXPECT_EQ(4, computation->instruction_count()); } TEST_F(HloDceTest, DeadParameters) { @@ -126,9 +127,9 @@ TEST_F(HloDceTest, ControlDependencies) { // Verify that instructions with control dependencies are not removed. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(123.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); // Create two dead instructions: a negate and an add. auto dead_negate = builder.AddInstruction(HloInstruction::CreateUnary( @@ -223,7 +224,7 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) { auto param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "cond_param")); auto constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); cond_builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, constant)); } @@ -234,8 +235,7 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) { { auto param = body_builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); - auto token = - body_builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto token = body_builder.AddInstruction(HloInstruction::CreateToken()); auto infeed = body_builder.AddInstruction( HloInstruction::CreateInfeed(shape, token, "")); body_builder.AddInstruction( @@ -279,8 +279,8 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) { { auto param = nested_callee_builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); - auto token = nested_callee_builder.AddInstruction( - HloInstruction::CreateAfterAll({})); + auto token = + nested_callee_builder.AddInstruction(HloInstruction::CreateToken()); nested_callee_builder.AddInstruction( HloInstruction::CreateOutfeed(shape, param, token, "")); } @@ -345,12 +345,12 @@ TEST_F(HloDceTest, RemoveDeadSubcomputation) { builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0")), builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), /*dimensions_to_reduce=*/{0}, reduce_subcomp)); // Add another instruction as the root of the computation. builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); module->AddEntryComputation(builder.Build()); EXPECT_EQ(module->MakeComputationPostOrder().size(), 2); @@ -386,7 +386,7 @@ TEST_F(HloDceTest, KeepUsedSubcomputation) { builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0")), builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), /*dimensions_to_reduce=*/{0}, reduce_subcomp)); // Add another instruction as the root of the computation that also uses @@ -396,7 +396,7 @@ TEST_F(HloDceTest, KeepUsedSubcomputation) { builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {100}), "param1")), builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), /*dimensions_to_reduce=*/{0}, reduce_subcomp)); module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index ebd5adb5d573ce4b556046f85eb26a6ad59efcb9..9e096320db5048457435199627a1ef1fe1572177 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -41,11 +41,15 @@ namespace xla { bool HloDomainMap::InSameDomain(HloInstruction* instruction1, HloInstruction* instruction2) const { - int64 domain_id1 = FindOrDefault(instruction_to_domain_, instruction1, -1); - int64 domain_id2 = FindOrDefault(instruction_to_domain_, instruction2, -1); + int64 domain_id1 = GetDomainId(instruction1); + int64 domain_id2 = GetDomainId(instruction2); return domain_id1 >= 0 && domain_id1 == domain_id2; } +int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const { + return FindOrDefault(instruction_to_domain_, instruction, -1); +} + Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain); // We only check operands, so we are sure to not process the empty domain from @@ -58,6 +62,11 @@ Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } } + if (instruction == instruction->parent()->root_instruction()) { + auto domain = MakeUnique(); + domain->enter_domains.insert(instruction); + TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h index e62ef763fb3881ab6030b1f6a66266ac80a3d84d..1ca71597253eecfb45ae8f384240033a57045277 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.h +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -65,6 +65,10 @@ class HloDomainMap { // currently processing. bool IsDomainInstruction(HloInstruction* instruction) const; + // Retrieves the domain identifier of the instruction, or -1 in case + // instruction is not found within any domain. + int64 GetDomainId(HloInstruction* instruction) const; + private: HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {} diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h index aa0308100a21f109579de75788fce7d242d6a6b0..f855f2a1fc944fcc11c9afed278bef4af87813da 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -71,12 +71,6 @@ class DomainMetadata { // Returns a string representation of the metadata. virtual string ToString() const = 0; - - // Given a reachable set (the set of instructions which are reachable from - // each other via user/operand pathways, without crossing a kDomain - // instruciton), makes sure that all of them have metadata attributes which - // are coherent with this metadata object. - virtual Status NormalizeInstructions(const Domain& domain) const = 0; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.cc b/tensorflow/compiler/xla/service/hlo_domain_remover.cc index 1d06040b0e7c92b03f4cb5481bdee73a0f74f939..67fad0769f5eb5ceca64ebd2aa78c6469f2c813d 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_remover.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_domain_remover.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" #include "tensorflow/compiler/xla/service/hlo_domain_map.h" +#include "tensorflow/compiler/xla/service/hlo_domain_verifier.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -43,54 +43,16 @@ class HloDomainRemover::RunContext { Status HloDomainRemover::RunContext::VerifyAndNormalizeDomain( const DomainMetadata::Domain& domain) { - // Verify that the whole kDomain frontier bounding the instruction reach set, - // has matching metadata. - // A kDomain instruction has two sides of metadata, a user facing and an - // operand facing. - // A reachable instruction set can make contact with a kDomain instruction on - // a user facing side (the kDomain is operand of the instruction), or on a - // operand facing side (the kDomain is user of the instruction). - // And depending on the contact side, the proper metadata object - // (user_side_metadata() vs. operand_side_metadata()) needs to be used for - // consistency checks. - const DomainMetadata* ref_metadata = nullptr; - VLOG(4) << "Reach set:"; - for (HloInstruction* instruction : domain.instructions) { - VLOG(4) << " " << instruction->name(); - } - VLOG(4) << " Domains:"; - for (HloInstruction* instruction : domain.enter_domains) { - const DomainMetadata& meta = instruction->user_side_metadata(); - VLOG(4) << " User side: " << instruction->name(); - VLOG(4) << " " << meta.ToString(); - if (ref_metadata == nullptr) { - ref_metadata = &meta; - } else { - TF_RET_CHECK(meta.Matches(*ref_metadata)) - << "Metadata mismatch at instruction " << instruction->name() << " : " - << meta.ToString() << " vs " << ref_metadata->ToString(); - } - } - for (HloInstruction* instruction : domain.exit_domains) { - const DomainMetadata& meta = instruction->operand_side_metadata(); - VLOG(4) << " Operand side: " << instruction->name(); - VLOG(4) << " " << meta.ToString(); - if (ref_metadata == nullptr) { - ref_metadata = &meta; - } else { - TF_RET_CHECK(meta.Matches(*ref_metadata)) - << "Metadata mismatch at instruction " << instruction->name() << " : " - << meta.ToString() << " vs " << ref_metadata->ToString(); - } - } + TF_ASSIGN_OR_RETURN(const DomainMetadata* ref_metadata, + HloDomainVerifier::VerifyDomain(domain)); if (ref_metadata != nullptr) { VLOG(4) << "Applying domain normalization: " << ref_metadata->ToString(); - TF_RETURN_IF_ERROR(ref_metadata->NormalizeInstructions(domain)); + TF_RETURN_IF_ERROR(remover_->normalizer_(domain, ref_metadata)); } else { // No kDomain instruction was present within this domain, so call the // generic normalization functions and have them apply their heuristic. VLOG(2) << "Applying domain-less normalization"; - TF_RETURN_IF_ERROR(remover_->normalizer_(domain)); + TF_RETURN_IF_ERROR(remover_->normalizer_(domain, nullptr)); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h index 0c71dd34fd4d2944037dc965a2c9ad2c592d6e3e..c859e05f02e54d601804b641094ecdd11bbe1aed 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_remover.h +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h @@ -35,9 +35,10 @@ class HloDomainRemover : public HloPassInterface { // instructions in it with the same attributes (ie, sharding), a normalizer // function is tasked at applying attribute normalization on the instructions // within such domain. - HloDomainRemover( - tensorflow::StringPiece kind, - std::function normalizer) + HloDomainRemover(tensorflow::StringPiece kind, + std::function + normalizer) : kind_(kind.ToString()), normalizer_(std::move(normalizer)) {} tensorflow::StringPiece name() const override { return "domain_remover"; } @@ -48,7 +49,9 @@ class HloDomainRemover : public HloPassInterface { class RunContext; string kind_; - std::function normalizer_; + std::function + normalizer_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index abc5b1c8effe03e39a2683eb2876ad0a27293921..ffc18a0f886df86d87944d9c284a6faf8afe4c60 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -97,12 +97,6 @@ class OpNameMetadata : public DomainMetadata { string ToString() const override { return opname_; } - Status NormalizeInstructions( - const DomainMetadata::Domain& domain) const override { - // For the purposes of this test, nothing to do. - return Status::OK(); - } - static tensorflow::StringPiece KindName() { return "opname"; } private: @@ -124,7 +118,8 @@ std::unique_ptr OpNameDomainCreator(HloInstruction* instruction, std::move(user_side_metadata)); } -Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain) { +Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain, + const DomainMetadata* metadata) { // Nothing to do for the particular use this test make of the OpName domains. return Status::OK(); } @@ -159,7 +154,7 @@ ENTRY entry { EXPECT_FALSE(HasDomainEdge(module, "e", "d")); HloDomainRemover remover(ShardingMetadata::KindName(), - NormalizeShardingDomain); + ShardingMetadata::NormalizeShardingDomain); TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); EXPECT_TRUE(remover_changed); @@ -201,12 +196,14 @@ HloModule Module ENTRY entry { p0 = (f32[4]) parameter(0) a = f32[4] get-tuple-element(p0), index=0 - b = (f32[4], u32[]) send(a), channel_id=1, sharding={maximal device=0} - c = () send-done(b), channel_id=1, sharding={maximal device=0} - d = (f32[4], u32[]) recv(), channel_id=2, sharding={maximal device=0} - e = f32[4] recv-done(d), channel_id=2, sharding={maximal device=0} - f = f32[4] add(a, e) - g = f32[4] subtract(a, e) + token = token[] after-all() + b = (f32[4], u32[], token[]) send(a, token), channel_id=1, sharding={maximal device=0} + c = token[] send-done(b), channel_id=1, sharding={maximal device=0} + d = (f32[4], u32[], token[]) recv(token), channel_id=2, sharding={maximal device=0} + e = (f32[4], token[]) recv-done(d), channel_id=2, sharding={maximal device=0} + e_element = f32[4] get-tuple-element(e), index=0, sharding={maximal device=0} + f = f32[4] add(a, e_element) + g = f32[4] subtract(a, e_element) ROOT h = (f32[4], f32[4]) tuple(f, g) } )"; @@ -219,18 +216,18 @@ ENTRY entry { EXPECT_TRUE(isolator_changed); EXPECT_TRUE(HasDomainEdge(module, "b", "a")); - EXPECT_TRUE(HasDomainEdge(module, "f", "e")); + EXPECT_TRUE(HasDomainEdge(module, "f", "e_element")); EXPECT_FALSE(HasDomainEdge(module, "a", "p0")); EXPECT_FALSE(HasDomainEdge(module, "c", "b")); EXPECT_FALSE(HasDomainEdge(module, "e", "d")); HloDomainRemover remover(ShardingMetadata::KindName(), - NormalizeShardingDomain); + ShardingMetadata::NormalizeShardingDomain); TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); EXPECT_TRUE(remover_changed); EXPECT_FALSE(HasDomainEdge(module, "b", "a")); - EXPECT_FALSE(HasDomainEdge(module, "f", "e")); + EXPECT_FALSE(HasDomainEdge(module, "f", "e_element")); } TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { @@ -238,11 +235,13 @@ TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { HloModule Module ENTRY entry { - a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=-1} - b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=-1} - c = f32[4] add(b, b), sharding={maximal device=-1} - d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=-1} - ROOT e = () send-done(d), channel_id=2, sharding={maximal device=-1} + token = token[] after-all(), sharding={maximal device=-1} + a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=-1} + b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=-1} + b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=-1} + c = f32[4] add(b_element, b_element), sharding={maximal device=-1} + d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=-1} + ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=-1} } )"; @@ -259,11 +258,13 @@ TEST_F(HloDomainTest, CheckNormalizationOnPureIOComputation) { HloModule Module ENTRY entry { - a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=0} - b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=0} - c = f32[4] add(b, b) - d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=0} - ROOT e = () send-done(d), channel_id=2, sharding={maximal device=0} + token = token[] after-all(), sharding={maximal device=0} + a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=0} + b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=0} + b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=0} + c = f32[4] add(b_element, b_element) + d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=0} + ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=0} } )"; @@ -271,7 +272,7 @@ ENTRY entry { LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainRemover remover(ShardingMetadata::KindName(), - NormalizeShardingDomain); + ShardingMetadata::NormalizeShardingDomain); TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); EXPECT_FALSE(remover_changed); @@ -318,7 +319,7 @@ ENTRY entry { EXPECT_FALSE(HasDomainEdge(module, "e", "d")); HloDomainRemover sharding_remover(ShardingMetadata::KindName(), - NormalizeShardingDomain); + ShardingMetadata::NormalizeShardingDomain); TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, sharding_remover.Run(module)); EXPECT_TRUE(sharding_remover_changed); @@ -405,7 +406,7 @@ ENTRY entry { } HloDomainRemover remover(ShardingMetadata::KindName(), - NormalizeShardingDomain); + ShardingMetadata::NormalizeShardingDomain); TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); EXPECT_TRUE(remover_changed); @@ -430,6 +431,44 @@ ENTRY entry { HloSharding::AssignDevice(0)})); } +TEST_F(HloDomainTest, EmptyRootDomain) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + %param = f32[1] parameter(0), sharding={maximal device=0} + %tuple = (f32[1]) tuple(%param), + sharding={maximal device=1} + ROOT %gte = f32[1] get-tuple-element(%tuple), index=0, + sharding={maximal device=1} +})"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + EXPECT_TRUE(isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module, "tuple", "param")); + EXPECT_FALSE(HasDomainEdge(module, "gte", "tuple")); + + // Remove %tuple and %gte (tuple simplification) + HloInstruction* gte = FindInstruction(module, "gte"); + HloInstruction* tuple = FindInstruction(module, "tuple"); + module->entry_computation()->set_root_instruction(tuple->mutable_operand(0)); + TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(gte)); + TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(tuple)); + + HloDomainRemover remover(ShardingMetadata::KindName(), + ShardingMetadata::NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + EXPECT_TRUE(remover_changed); + + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_TRUE(root->has_sharding()); + EXPECT_EQ(root->sharding(), HloSharding::AssignDevice(1)); +} + // Tests that text dumps of domain instructions can be parsed back, in the // specific case of null shardings. TEST_F(HloDomainTest, DumpParseNullSharding) { diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.cc b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc new file mode 100644 index 0000000000000000000000000000000000000000..751fc677e2d955fd3d9f8970f7c0370a22c054bf --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc @@ -0,0 +1,124 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_domain_verifier.h" + +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +class HloDomainVerifier::RunContext { + public: + RunContext(HloModule* module, HloDomainVerifier* verifier) + : module_(module), verifier_(verifier) {} + + Status Run(); + + private: + // If the verifier caller passed an empty vector for kinds, we collect all the + // avalable domain types. + Status PopulateDomainKinds(); + + HloModule* module_; + HloDomainVerifier* verifier_; +}; + +Status HloDomainVerifier::RunContext::PopulateDomainKinds() { + if (verifier_->kinds_.empty()) { + // The caller specified no domain kinds, collect all the ones available. + std::set kinds; + for (HloComputation* computation : module_->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kDomain) { + TF_RET_CHECK(instruction->user_side_metadata().Kind() == + instruction->operand_side_metadata().Kind()) + << instruction->ToString(); + kinds.insert(instruction->user_side_metadata().Kind().ToString()); + } + } + } + verifier_->kinds_.insert(verifier_->kinds_.end(), kinds.begin(), + kinds.end()); + } + return Status::OK(); +} + +Status HloDomainVerifier::RunContext::Run() { + VLOG(4) << "Running HLO Domain Verifier"; + TF_RETURN_IF_ERROR(PopulateDomainKinds()); + for (HloComputation* computation : module_->computations()) { + for (auto& kind : verifier_->kinds_) { + // First create the domain instruciton sets. A domain instruction set is + // the set of instructions whose edges never cross a kDomain instruction. + TF_ASSIGN_OR_RETURN(std::unique_ptr domain_map, + HloDomainMap::Create(computation, kind)); + // Verify every domain populated within the map. + for (auto& domain : domain_map->GetDomains()) { + TF_RETURN_IF_ERROR(VerifyDomain(*domain).status()); + } + } + } + return Status::OK(); +} + +StatusOr HloDomainVerifier::Run(HloModule* module) { + RunContext run_context(module, this); + TF_RETURN_IF_ERROR(run_context.Run()); + return false; +} + +StatusOr HloDomainVerifier::VerifyDomain( + const DomainMetadata::Domain& domain) { + const DomainMetadata* ref_metadata = nullptr; + VLOG(4) << "Reach set:"; + for (HloInstruction* instruction : domain.instructions) { + VLOG(4) << " " << instruction->name(); + } + VLOG(4) << " Domains:"; + for (HloInstruction* instruction : domain.enter_domains) { + const DomainMetadata& meta = instruction->user_side_metadata(); + VLOG(4) << " User side: " << instruction->name(); + VLOG(4) << " " << meta.ToString(); + if (ref_metadata == nullptr) { + ref_metadata = &meta; + } else { + TF_RET_CHECK(meta.Matches(*ref_metadata)) + << "Metadata mismatch at instruction " << instruction->name() << " : " + << meta.ToString() << " vs " << ref_metadata->ToString(); + } + } + for (HloInstruction* instruction : domain.exit_domains) { + const DomainMetadata& meta = instruction->operand_side_metadata(); + VLOG(4) << " Operand side: " << instruction->name(); + VLOG(4) << " " << meta.ToString(); + if (ref_metadata == nullptr) { + ref_metadata = &meta; + } else { + TF_RET_CHECK(meta.Matches(*ref_metadata)) + << "Metadata mismatch at instruction " << instruction->name() << " : " + << meta.ToString() << " vs " << ref_metadata->ToString(); + } + } + return ref_metadata; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.h b/tensorflow/compiler/xla/service/hlo_domain_verifier.h new file mode 100644 index 0000000000000000000000000000000000000000..8e53cf97f8ba9a88140a909ad20c1a938aec8c1f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.h @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_VERIFIER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_VERIFIER_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/core/status.h" + +namespace xla { + +// Verifies that the domain instructions are consistent, and the each domain is +// surrounded by the same metadata. +class HloDomainVerifier : public HloPassInterface { + public: + HloDomainVerifier(std::vector kinds) : kinds_(std::move(kinds)) {} + + tensorflow::StringPiece name() const override { return "domain_verifier"; } + + StatusOr Run(HloModule* module) override; + + // Verify that the whole kDomain frontier bounding the instruction reach set, + // has matching metadata. + // A kDomain instruction has two sides of metadata, a user facing and an + // operand facing. + // A reachable instruction set can make contact with a kDomain instruction on + // a user facing side (the kDomain is operand of the instruction), or on a + // operand facing side (the kDomain is user of the instruction). + // And depending on the contact side, the proper metadata object + // (user_side_metadata() vs. operand_side_metadata()) needs to be used for + // consistency checks. + // Returns the DomainMetadata pointer which surrounds the domain, and + // represents the common metadata within such domain. If the returned + // DomainMetadata pointer is nullptr, the input domain had no kDomain + // boundary. + static StatusOr VerifyDomain( + const DomainMetadata::Domain& domain); + + private: + class RunContext; + + std::vector kinds_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_VERIFIER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index 4ed1508d7067684a15d0fb7d86e69b055bc1333b..b9244b8e9e5f34e7ac5113c8eacb6f8243eea314 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" @@ -144,6 +144,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { opcode == HloOpcode::kCrossReplicaSum || opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap || opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow || + opcode == HloOpcode::kScatter || opcode == HloOpcode::kSelectAndScatter || opcode == HloOpcode::kConditional) { continue; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index deb7f28d849ba3df5f54289ddca7357b8af416fe..51353eea6e72d5a131897f3c3ae312046051103e 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -135,7 +136,6 @@ StatusOr> Compare( } // namespace - HloEvaluator::HloEvaluator(int64 max_loop_iterations) : max_loop_iterations_(max_loop_iterations) { typed_visitors_[PRED] = MakeUnique>(this); @@ -330,6 +330,24 @@ StatusOr> HloEvaluator::EvaluateElementwiseUnaryOp( return result; } +StatusOr> HloEvaluator::EvaluateDotOp( + const DotDimensionNumbers& dim_numbers, const Literal& lhs, + const Literal& rhs) { + std::unique_ptr lhs_instr = + HloInstruction::CreateConstant(lhs.CloneToUnique()); + std::unique_ptr rhs_instr = + HloInstruction::CreateConstant(rhs.CloneToUnique()); + + TF_ASSIGN_OR_RETURN( + Shape dot_shape, + ShapeInference::InferDotOpShape(lhs.shape(), rhs.shape(), dim_numbers)); + + std::unique_ptr cloned_instruction = + HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(), + dim_numbers); + return Evaluate(cloned_instruction.get()); +} + Status HloEvaluator::HandleParameter(HloInstruction* parameter) { CHECK_LT(parameter->parameter_number(), arg_literals_.size()); const Literal* input_literal = arg_literals_[parameter->parameter_number()]; @@ -382,7 +400,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { ShapeUtil::GetDimension(operand_shape, concat_dim); } - auto result_literal = Literal::CreateFromDimensions( + auto result_literal = LiteralUtil::CreateFromDimensions( reference_shape.element_type(), concat_dimensions); DimensionVector source_indices(rank, 0); DimensionVector dest_indices(concat_dimensions.size(), 0); @@ -533,7 +551,7 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) { operand_literals.push_back(&GetEvaluatedLiteralFor(operand)); } - evaluated_[tuple] = Literal::MakeTuple(operand_literals); + evaluated_[tuple] = LiteralUtil::MakeTuple(operand_literals); return Status::OK(); } @@ -757,6 +775,12 @@ class OutputWindowIndexToInputIndex { return ArraySlice(input_index_); } + // Returns for a given 'input_dim' the corresponding output dimension index, + // or -1 if 'input_dim' is an elided window dimension. + int64 input_dim_value_to_output_index(int64 input_dim) { + return input_dim_value_to_output_index_[input_dim]; + } + private: // Propagates window dimensions from the output index to input_index_ by // mutating input_index_ in place. @@ -774,7 +798,7 @@ class OutputWindowIndexToInputIndex { // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of // the input index from the output index. See - // PropagateOutputIndexToInputIndex. + // PropagateOutputIndexWindowDimsToInputIndex. std::vector input_dim_value_to_output_index_; // The result computed by this functor. operator() returns an ArraySlice into @@ -827,6 +851,8 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { // corresponding index in the input shape. std::vector input_index(operand.shape().dimensions_size()); std::vector output_index(gather->shape().dimensions_size()); + std::vector input_gather_index_clamped( + operand.shape().dimensions_size()); OutputGatherIndexToInputIndex output_gather_index_to_input_index( &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(), @@ -848,14 +874,26 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { output_index[i] = output_gather_index[i] + output_window_index[i]; DCHECK_LT(output_index[i], shape.dimensions(i)); } + for (int i = 0, e = input_gather_index.size(); i < e; i++) { + int64 output_dim = + output_window_index_to_input_index.input_dim_value_to_output_index(i); + // If 'output_dim' is -1, it means 'i' is an elided window dim. This means + // we set the iteration index to 0, so for the purpose of the following + // calculations we can consider the output dimension size to be 1. + int64 output_dim_size = + output_dim == -1 ? 1 : shape.dimensions(output_dim); + // Clamp the gather index so that the gather region fits in the operand. + // input_gather_index_clamped[i] = clamp(input_gather_index[i], 0, + // operand_shape.dimensions(i) - + // output_dim_size); + input_gather_index_clamped[i] = + std::min(operand_shape.dimensions(i) - output_dim_size, + std::max(0LL, input_gather_index[i])); + } for (int i = 0, e = input_index.size(); i < e; i++) { - // TODO(b/74360564): We should implement whatever out of bounds behavior - // we decide for dynamic-slice here as well. - input_index[i] = (input_gather_index[i] + input_window_index[i]) % - operand_shape.dimensions(i); - if (input_index[i] < 0) { - input_index[i] += operand_shape.dimensions(i); - } + input_index[i] = input_gather_index_clamped[i] + input_window_index[i]; + DCHECK_GE(input_index[i], 0); + DCHECK_LT(input_index[i], operand_shape.dimensions(i)); } TF_RETURN_IF_ERROR( result->CopyElementFrom(operand, input_index, output_index)); @@ -903,7 +941,7 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { } Status HloEvaluator::HandleAfterAll(HloInstruction* token) { - evaluated_[token] = Literal::CreateToken(); + evaluated_[token] = LiteralUtil::CreateToken(); return Status::OK(); } @@ -1024,8 +1062,6 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) { const auto& on_false = GetEvaluatedLiteralFor(select->operand(2)); // If predicate is of scalar type, no element-wise selection would be needed. - // This would also handle output array of tuple types as the DefaultAction - // would go through the HloEvaluatorTypedVisitor which doesn't handle tuples. if (ShapeUtil::IsScalar(pred.shape())) { if (pred.Get({})) { evaluated_[select] = on_true.CloneToUnique(); @@ -1038,6 +1074,19 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) { return DefaultAction(select); } +Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) { + const auto& pred = GetEvaluatedLiteralFor(tuple_select->operand(0)); + const auto& on_true = GetEvaluatedLiteralFor(tuple_select->operand(1)); + const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2)); + + if (pred.Get({})) { + evaluated_[tuple_select] = on_true.CloneToUnique(); + } else { + evaluated_[tuple_select] = on_false.CloneToUnique(); + } + return Status::OK(); +} + Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { HloComputation* cond_comp = while_hlo->while_condition(); HloComputation* body_comp = while_hlo->while_body(); @@ -1068,6 +1117,161 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { return Status::OK(); } +// Key-value sort is a special snowflake: it's templated on two different +// element types, one for the keys, and one for the values. Jump through some +// hoops to make this work. +namespace { +template +StatusOr> EvaluateSortInternal( + HloInstruction* sort, const Literal& keys_literal, + const Literal& values_literal) { + auto rank = ShapeUtil::Rank(keys_literal.shape()); + TF_RET_CHECK( + ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape())) + << "Sort keys and values must have the same dimensions"; + TF_RET_CHECK(rank > 0 && rank <= 2) + << "Sort is only supported for rank-1 and rank-2 shapes, rank is: " + << rank; + TF_RET_CHECK(sort->operand_count() == 2) << "Expected key-value sort"; + // We need to sort and array of keys and an array of values, where the + // sorted order of the values is determined by the keys. The simplest(?) + // way to do this is to go to an array-of-pairs representation, sort the + // array using the keys, and then go back to pair-of-arrays. + VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); + VLOG(3) << "HandleSort values_literal: " << values_literal.ToString(); + + auto sort_r1 = [](const Literal& keys_literal, + const Literal& values_literal) { + const auto& keys_data = keys_literal.data(); + const auto& values_data = values_literal.data(); + + using kv_pair = std::pair; + std::vector key_value_vector; + CHECK_EQ(keys_data.size(), values_data.size()); + key_value_vector.reserve(keys_data.size()); + for (int i = 0; i < keys_data.size(); ++i) { + key_value_vector.push_back(std::make_pair(keys_data[i], values_data[i])); + } + std::sort(key_value_vector.begin(), key_value_vector.end(), + [](const kv_pair& a, const kv_pair& b) { + return SafeLess(a.first, b.first); + }); + std::vector result_keys; + std::vector result_values; + for (const auto& key_value : key_value_vector) { + result_keys.push_back(key_value.first); + result_values.push_back(key_value.second); + } + auto result_keys_literal = MakeUnique(keys_literal.shape()); + result_keys_literal->PopulateR1( + tensorflow::gtl::ArraySlice(result_keys)); + auto result_values_literal = MakeUnique(values_literal.shape()); + result_values_literal->PopulateR1( + tensorflow::gtl::ArraySlice(result_values)); + return std::make_pair(std::move(result_keys_literal), + std::move(result_values_literal)); + }; + + std::unique_ptr result_tuple; + if (rank == 1) { + auto result_pair = sort_r1(keys_literal, values_literal); + result_tuple = LiteralUtil::MakeTuple( + {result_pair.first.get(), result_pair.second.get()}); + } else { + // For R2 sort, the desired semantics are to sort each matrix row + // independently. + auto keys_result_literal = MakeUnique(keys_literal.shape()); + auto values_result_literal = MakeUnique(values_literal.shape()); + int64 r1_length = keys_literal.shape().dimensions(1); + for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) { + TF_ASSIGN_OR_RETURN(auto keys_r1_slice, + keys_literal.Slice({row, 0}, {row + 1, r1_length}) + ->Reshape({r1_length})); + TF_ASSIGN_OR_RETURN(auto values_r1_slice, + values_literal.Slice({row, 0}, {row + 1, r1_length}) + ->Reshape({r1_length})); + auto r1_result_pair = sort_r1(*keys_r1_slice, *values_r1_slice); + TF_ASSIGN_OR_RETURN(auto sorted_keys, + r1_result_pair.first->Reshape({1, r1_length})); + TF_ASSIGN_OR_RETURN(auto sorted_values, + r1_result_pair.second->Reshape({1, r1_length})); + TF_RETURN_IF_ERROR(keys_result_literal->CopySliceFrom( + *sorted_keys, {0, 0}, {row, 0}, {1, r1_length})); + TF_RETURN_IF_ERROR(values_result_literal->CopySliceFrom( + *sorted_values, {0, 0}, {row, 0}, {1, r1_length})); + } + result_tuple = LiteralUtil::MakeTuple( + {keys_result_literal.get(), values_result_literal.get()}); + } + + VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString(); + return std::move(result_tuple); +} + +template +StatusOr> EvaluateSortCurried( + HloInstruction* sort, const Literal& keys_literal, + const Literal& values_literal) { + switch (sort->operand(1)->shape().element_type()) { + case F32: + return EvaluateSortInternal(sort, keys_literal, + values_literal); + case U32: + return EvaluateSortInternal(sort, keys_literal, + values_literal); + case S32: + return EvaluateSortInternal(sort, keys_literal, + values_literal); + case BF16: + return EvaluateSortInternal(sort, keys_literal, + values_literal); + default: + return InvalidArgument("Unsupported type for Sort"); + } +} + +StatusOr> EvaluateSort(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { + switch (sort->operand(0)->shape().element_type()) { + case F32: + return EvaluateSortCurried(sort, keys_literal, values_literal); + case U32: + return EvaluateSortCurried(sort, keys_literal, values_literal); + case S32: + return EvaluateSortCurried(sort, keys_literal, values_literal); + case BF16: + return EvaluateSortCurried(sort, keys_literal, values_literal); + default: + return InvalidArgument("Unsupported type for Sort"); + } +} +} // namespace + +Status HloEvaluator::HandleSort(HloInstruction* sort) { + const int64 sort_dim = sort->dimensions(0); + const int64 rank = ShapeUtil::Rank(sort->operand(0)->shape()); + if (sort_dim != rank - 1) { + return Unimplemented( + "Trying to support along dimension %lld, which is not the last " + "dimension", + sort_dim); + } + + if (!ShapeUtil::IsTuple(sort->shape())) { + return DefaultAction(sort); + } else { + auto result = EvaluateSort(sort, GetEvaluatedLiteralFor(sort->operand(0)), + GetEvaluatedLiteralFor(sort->operand(1))); + if (result.ok()) { + evaluated_[sort] = std::move(result.ValueOrDie()); + return Status::OK(); + } else { + return result.status(); + } + } +} + Status HloEvaluator::Preprocess(HloInstruction* hlo) { VLOG(2) << "About to visit HLO: " << hlo->ToString(); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 2ad56080d862deb357dec128ca43c989f0d14d14..a4c37ef32827892194da070ee05ec6dc4f4c306f 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -115,6 +116,10 @@ class HloEvaluator : public DfsHloVisitorWithDefault { StatusOr> EvaluateElementwiseUnaryOp( HloOpcode opcode, const Literal& operand); + StatusOr> EvaluateDotOp( + const DotDimensionNumbers& dim_numbers, const Literal& lhs, + const Literal& rhs); + protected: // Make HloEvaluatorTypedVisitor a friend because it is logically part of this // class. @@ -172,10 +177,14 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleSelect(HloInstruction* select) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; + Status HandleBroadcast(HloInstruction* broadcast) override; Status HandleAfterAll(HloInstruction* token) override; + Status HandleSort(HloInstruction* sort) override; + // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be // returned directly without looking up the cache. diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 42770d848a83b2e27b87bc963d259e2b7af664a4..251109c89f1ae971cec95057d604c396d4f99522 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" @@ -112,9 +112,9 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, // Verifies that HloEvaluator evaluates a HLO instruction that performs clamp // with 3 operands. TEST_P(HloEvaluatorTest, DoesClamp) { - auto low = Literal::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); - auto value = Literal::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); - auto high = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); + auto low = LiteralUtil::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); + auto value = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); + auto high = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); Shape shape = low->shape(); HloComputation::Builder b(TestName()); @@ -127,15 +127,15 @@ TEST_P(HloEvaluatorTest, DoesClamp) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({{0, 4}, {2, 4}}); + auto expected = LiteralUtil::CreateR2({{0, 4}, {2, 4}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { - auto low = Literal::CreateR0(0.f); - auto value = Literal::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); - auto high = Literal::CreateR0(1.f); + auto low = LiteralUtil::CreateR0(0.f); + auto value = LiteralUtil::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); + auto high = LiteralUtil::CreateR0(1.f); Shape shape = value->shape(); HloComputation::Builder b(TestName()); @@ -148,7 +148,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({{0, 0}, {1, 1}}); + auto expected = LiteralUtil::CreateR2({{0, 0}, {1, 1}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -156,9 +156,9 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { // Verifies that HloEvaluator evaluates a HLO instruction that performs select // with 3 operands. TEST_P(HloEvaluatorTest, DoesSelect) { - auto pred = Literal::CreateR2({{true, false}, {false, true}}); - auto on_true = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); - auto on_false = Literal::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); + auto pred = LiteralUtil::CreateR2({{true, false}, {false, true}}); + auto on_true = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); + auto on_false = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); Shape shape = on_true->shape(); HloComputation::Builder b(TestName()); @@ -173,7 +173,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) { std::unique_ptr result = Evaluate({}); - auto expected = Literal::CreateR2({{2, 5}, {0, 4}}); + auto expected = LiteralUtil::CreateR2({{2, 5}, {0, 4}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -181,46 +181,46 @@ TEST_P(HloEvaluatorTest, DoesSelect) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise addition with 2 operands. TEST_P(HloEvaluatorTest, DoesAdd) { - auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); - auto expected = Literal::CreateR2({{3, 4}, {-96, 8}}); + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto expected = LiteralUtil::CreateR2({{3, 4}, {-96, 8}}); TestBinaryOp(HloOpcode::kAdd, std::move(expected), std::move(lhs), std::move(rhs)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise and with 2 operands. TEST_P(HloEvaluatorTest, DoesAnd) { - auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); - auto expected = Literal::CreateR2({{0, 0}, {4, 4}}); + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto expected = LiteralUtil::CreateR2({{0, 0}, {4, 4}}); TestBinaryOp(HloOpcode::kAnd, std::move(expected), std::move(lhs), std::move(rhs)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise or with 2 operands. TEST_P(HloEvaluatorTest, DoesOr) { - auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); - auto expected = Literal::CreateR2({{3, 4}, {-100, 4}}); + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto expected = LiteralUtil::CreateR2({{3, 4}, {-100, 4}}); TestBinaryOp(HloOpcode::kOr, std::move(expected), std::move(lhs), std::move(rhs)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise or with 2 operands. TEST_P(HloEvaluatorTest, DoesXor) { - auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); - auto expected = Literal::CreateR2({{3, 4}, {-104, 0}}); + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto expected = LiteralUtil::CreateR2({{3, 4}, {-104, 0}}); TestBinaryOp(HloOpcode::kXor, std::move(expected), std::move(lhs), std::move(rhs)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise multiply with 2 operands. TEST_P(HloEvaluatorTest, DoesMultiply) { - auto lhs = Literal::CreateR2({{-1, 0}, {-100, 4}}); - auto rhs = Literal::CreateR2( + auto lhs = LiteralUtil::CreateR2({{-1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2( {{std::numeric_limits::min(), 4}, {4, 4}}); - auto expected = Literal::CreateR2( + auto expected = LiteralUtil::CreateR2( {{std::numeric_limits::min(), 0}, {-400, 16}}); TestBinaryOp(HloOpcode::kMultiply, std::move(expected), std::move(lhs), std::move(rhs)); @@ -228,17 +228,17 @@ TEST_P(HloEvaluatorTest, DoesMultiply) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise divide with 2 operands. TEST_P(HloEvaluatorTest, DoesDivideInt64) { - auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); - auto expected = Literal::CreateR2({{0, 0}, {-25, 1}}); + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto expected = LiteralUtil::CreateR2({{0, 0}, {-25, 1}}); TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs), std::move(rhs)); } TEST_P(HloEvaluatorTest, DoesDivideDouble) { - auto lhs = Literal::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); - auto rhs = Literal::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); + auto lhs = LiteralUtil::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); + auto rhs = LiteralUtil::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); auto expected = - Literal::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); + LiteralUtil::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs), std::move(rhs)); } @@ -246,54 +246,54 @@ TEST_P(HloEvaluatorTest, DoesDivideDouble) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise abs op with 1 operand. TEST_P(HloEvaluatorTest, DoesAbsR2) { - auto operand = Literal::CreateR2({{1, -20}, {-100, 4}}); - auto expected = Literal::CreateR2({{1, 20}, {100, 4}}); + auto operand = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); + auto expected = LiteralUtil::CreateR2({{1, 20}, {100, 4}}); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } TEST_P(HloEvaluatorTest, DoesAbsR0) { - auto operand = Literal::CreateR0(-1.0f); - auto expected = Literal::CreateR0(1.0f); + auto operand = LiteralUtil::CreateR0(-1.0f); + auto expected = LiteralUtil::CreateR0(1.0f); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } TEST_P(HloEvaluatorTest, DoesAbsR1WithZeroSize) { - auto operand = Literal::CreateR1({}); - auto expected = Literal::CreateR1({}); + auto operand = LiteralUtil::CreateR1({}); + auto expected = LiteralUtil::CreateR1({}); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } TEST_P(HloEvaluatorTest, DoesNegateR2) { - auto operand = Literal::CreateR2( + auto operand = LiteralUtil::CreateR2( {{0, std::numeric_limits::min()}, {-1, 4}}); - auto expected = - Literal::CreateR2({{0, std::numeric_limits::min()}, {1, -4}}); + auto expected = LiteralUtil::CreateR2( + {{0, std::numeric_limits::min()}, {1, -4}}); TestUnaryOp(HloOpcode::kNegate, std::move(expected), std::move(operand)); } TEST_P(HloEvaluatorTest, DoesCosR2) { - auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); - auto expected = Literal::CreateR2({{1, -1}, {-1, 1}}); + auto operand = LiteralUtil::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); + auto expected = LiteralUtil::CreateR2({{1, -1}, {-1, 1}}); TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand), use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } TEST_P(HloEvaluatorTest, DoesSinR2) { - auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); - auto expected = Literal::CreateR2({{0, 0}, {0, 0}}); + auto operand = LiteralUtil::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); + auto expected = LiteralUtil::CreateR2({{0, 0}, {0, 0}}); TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand), use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } TEST_P(HloEvaluatorTest, DoesNotR2) { auto operand = - Literal::CreateR2({{0, std::numeric_limits::min()}, - {-1, std::numeric_limits::max()}}); + LiteralUtil::CreateR2({{0, std::numeric_limits::min()}, + {-1, std::numeric_limits::max()}}); auto expected = - Literal::CreateR2({{-1, std::numeric_limits::max()}, - {0, std::numeric_limits::min()}}); + LiteralUtil::CreateR2({{-1, std::numeric_limits::max()}, + {0, std::numeric_limits::min()}}); TestUnaryOp(HloOpcode::kNot, std::move(expected), std::move(operand)); } // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor // constant operands. TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { - auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); - auto rhs2 = Literal::CreateR2({{1, -20}, {-100, 4}}); + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto rhs2 = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); std::vector args = {lhs.get(), rhs.get(), rhs2.get()}; Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); @@ -314,7 +314,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { std::unique_ptr result = Evaluate(args); - auto expected = Literal::CreateR2({{4, -16}, {-196, 12}}); + auto expected = LiteralUtil::CreateR2({{4, -16}, {-196, 12}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -324,7 +324,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { HloComputation::Builder b(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - Literal::CreateRandomLiteral( + LiteralUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); auto literal_clone = literal->CloneToUnique(); HloInstruction* literal_instruction = @@ -349,8 +349,8 @@ TEST_P(HloEvaluatorTest, DoesReshape) { // Verifies Broadcast operation is correctly evaluated. TEST_P(HloEvaluatorTest, DoesBroadcast) { HloComputation::Builder b(TestName()); - auto input_literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); - auto output_literal = Literal::CreateR3( + auto input_literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); + auto output_literal = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{1, 2}, {3, 4}, {5, 6}}}); HloInstruction* literal_instruction = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); @@ -365,8 +365,8 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { HloComputation::Builder b(TestName()); - auto input_literal = Literal::CreateR0(111); - auto output_literal = Literal::CreateR2( + auto input_literal = LiteralUtil::CreateR0(111); + auto output_literal = LiteralUtil::CreateR2( {{111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}}); HloInstruction* literal_instruction = b.AddInstruction( @@ -386,9 +386,9 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { HloComputation::Builder b(TestName()); HloInstruction* operand1 = b.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{-1, -2}, {100, 200}}))); + LiteralUtil::CreateR2({{-1, -2}, {100, 200}}))); HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{-2, -3}, {-100, -200}}))); + LiteralUtil::CreateR2({{-2, -3}, {-100, -200}}))); std::vector operands = {operand1, operand2}; @@ -399,8 +399,8 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { std::unique_ptr result = Evaluate(); - auto expected = - Literal::CreateR2({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); + auto expected = LiteralUtil::CreateR2( + {{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -408,9 +408,9 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { HloComputation::Builder b(TestName()); HloInstruction* operand1 = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({100, 200}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({100, 200}))); HloInstruction* operand2 = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); std::vector operands = {operand1, operand2}; @@ -421,16 +421,16 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR1({100, 200}); + auto expected = LiteralUtil::CreateR1({100, 200}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { HloComputation::Builder b(TestName()); - auto input_literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); + auto input_literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); auto expected = - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}); ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), expected->shape())); @@ -447,9 +447,9 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { HloComputation::Builder b(TestName()); - auto input_literal = Literal::CreateR2WithLayout( + auto input_literal = LiteralUtil::CreateR2WithLayout( {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1})); - auto expected = Literal::CreateR2WithLayout( + auto expected = LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0})); ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), expected->shape())); @@ -478,13 +478,13 @@ PaddingConfig CreatePaddingConfig( } TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { - auto operand = Literal::CreateR2({{}, {}}); + auto operand = LiteralUtil::CreateR2({{}, {}}); HloComputation::Builder b(TestName()); auto operand_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); constexpr int32 kPadValue = 10; - auto pad_value = Literal::CreateR0(kPadValue); + auto pad_value = LiteralUtil::CreateR0(kPadValue); auto padding_value_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(pad_value))); @@ -496,7 +496,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2( + auto expected = LiteralUtil::CreateR2( {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); @@ -506,11 +506,11 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { HloComputation::Builder b(TestName()); Array4D input_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); - auto input = Literal::CreateR4FromArray4D(input_array); + auto input = LiteralUtil::CreateR4FromArray4D(input_array); HloInstruction* input_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); constexpr float kPadValue = 1.5; - auto pad_value = Literal::CreateR0(kPadValue); + auto pad_value = LiteralUtil::CreateR0(kPadValue); HloInstruction* pad_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(pad_value))); @@ -532,7 +532,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { (*expected_array)(7, 0, 0, 0) = 5.0f; (*expected_array)(7, 2, 0, 0) = 6.0f; - auto expected = Literal::CreateR4FromArray4D(*expected_array); + auto expected = LiteralUtil::CreateR4FromArray4D(*expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -549,12 +549,12 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { // } auto input_array = MakeUnique>(4, 3); input_array->FillUnique(1.0f); - auto input = Literal::CreateR2FromArray2D(*input_array); + auto input = LiteralUtil::CreateR2FromArray2D(*input_array); HloInstruction* input_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); auto pad_value_instruction = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.718f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.718f))); auto r2_padding_on_dim0_dim1 = CreatePaddingConfig({{{-1, -2, 0}}, {{-2, 4, 0}}}); @@ -574,7 +574,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { (*expected_array)(0, 2) = 2.718f; (*expected_array)(0, 3) = 2.718f; (*expected_array)(0, 4) = 2.718f; - auto expected = Literal::CreateR2FromArray2D(*expected_array); + auto expected = LiteralUtil::CreateR2FromArray2D(*expected_array); EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250))); } @@ -590,12 +590,12 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { // } auto input_array = MakeUnique>(4, 3); input_array->FillUnique(1.0f); - auto input = Literal::CreateR2FromArray2D(*input_array); + auto input = LiteralUtil::CreateR2FromArray2D(*input_array); HloInstruction* input_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); auto pad_value_instruction = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.718f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.718f))); PaddingConfig padding_config = MakeNoPaddingConfig(2); @@ -613,7 +613,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { std::unique_ptr result = Evaluate(); auto expected_array = MakeUnique>(0, 9); - auto expected = Literal::CreateR2FromArray2D(*expected_array); + auto expected = LiteralUtil::CreateR2FromArray2D(*expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -630,13 +630,13 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { // } auto lhs_array = MakeUnique>(4, 1); lhs_array->FillUnique(1.0f); - auto lhs_literal = Literal::CreateR2FromArray2D(*lhs_array); + auto lhs_literal = LiteralUtil::CreateR2FromArray2D(*lhs_array); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); // rhs: // f32[2] { 1, 2 }, - auto rhs_literal = Literal::CreateR2({{1, 2}}); + auto rhs_literal = LiteralUtil::CreateR2({{1, 2}}); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -658,7 +658,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { {4.f, 8.f}, }); // clang-format on - auto expected = Literal::CreateR2FromArray2D(expected_array); + auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -669,7 +669,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { // lhs: // f32[3] // { 1, 2, 3 }, - auto lhs_literal = Literal::CreateR1({1, 2, 3}); + auto lhs_literal = LiteralUtil::CreateR1({1, 2, 3}); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); @@ -681,7 +681,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { // } auto rhs_array = MakeUnique>(3, 2); rhs_array->FillUnique(1.0f); - auto rhs_literal = Literal::CreateR2FromArray2D(*rhs_array); + auto rhs_literal = LiteralUtil::CreateR2FromArray2D(*rhs_array); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -695,7 +695,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR1({22.f, 28.f}); + auto expected = LiteralUtil::CreateR1({22.f, 28.f}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -712,7 +712,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { // } auto lhs_array = MakeUnique>(4, 3); lhs_array->FillUnique(1.0f); - auto lhs_literal = Literal::CreateR2FromArray2D(*lhs_array); + auto lhs_literal = LiteralUtil::CreateR2FromArray2D(*lhs_array); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); @@ -724,7 +724,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { // } auto rhs_array = MakeUnique>(3, 2); rhs_array->FillUnique(1.0f); - auto rhs_literal = Literal::CreateR2FromArray2D(*rhs_array); + auto rhs_literal = LiteralUtil::CreateR2FromArray2D(*rhs_array); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -744,7 +744,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { {94.f, 124.f}, {130.f, 172.f}, }); - auto expected = Literal::CreateR2FromArray2D(expected_array); + auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -753,12 +753,12 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { HloComputation::Builder b(TestName()); Array3D lhs_array = {{{1, 2, 3}}}; - auto lhs_literal = Literal::CreateR3FromArray3D(lhs_array); + auto lhs_literal = LiteralUtil::CreateR3FromArray3D(lhs_array); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); Array3D rhs_array = {{{3.f, 4.f}}}; - auto rhs_literal = Literal::CreateR3FromArray3D(rhs_array); + auto rhs_literal = LiteralUtil::CreateR3FromArray3D(rhs_array); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -792,7 +792,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { std::unique_ptr result = Evaluate(); Array3D expected_array = {{{11.f, 18.f, 9.f}}}; - auto expected = Literal::CreateR3FromArray3D(expected_array); + auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -809,7 +809,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { {13, 14, 15, 16}, })); // clang-format on - auto lhs_literal = Literal::CreateR4FromArray4D(lhs_array); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(lhs_array); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); @@ -820,7 +820,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { {7, 8}, })); // clang-format on - auto rhs_literal = Literal::CreateR4FromArray4D(rhs_array); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(rhs_array); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -854,7 +854,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { {149, 160, 171, 80}, })); // clang-format on - auto expected = Literal::CreateR4FromArray4D(expected_array); + auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -884,11 +884,11 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { }}); // clang-format on - auto lhs_literal = Literal::CreateR4FromArray4D(input); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(input); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); - auto rhs_literal = Literal::CreateR4FromArray4D(weight); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(weight); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); rhs_instruction = b.AddInstruction(HloInstruction::CreateReverse( @@ -933,7 +933,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { Array4D expected_array({{{{2514, 2685}}}}); Array4D expected_array_bf16({{{{2512, 2672}}}}); // clang-format on - auto expected = Literal::CreateR4FromArray4D( + auto expected = LiteralUtil::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); @@ -964,11 +964,11 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { }}); // clang-format on - auto lhs_literal = Literal::CreateR4FromArray4D(input); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(input); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); - auto rhs_literal = Literal::CreateR4FromArray4D(weight); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(weight); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -1010,7 +1010,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { Array4D expected_array({{{{2514, 2685}}}}); Array4D expected_array_bf16({{{{2512, 2672}}}}); // clang-format on - auto expected = Literal::CreateR4FromArray4D( + auto expected = LiteralUtil::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); @@ -1028,7 +1028,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { {13, 14, 15, 16}, })); // clang-format on - auto lhs_literal = Literal::CreateR4FromArray4D(lhs_array); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(lhs_array); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); @@ -1039,7 +1039,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { {7, 8}, })); // clang-format on - auto rhs_literal = Literal::CreateR4FromArray4D(rhs_array); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(rhs_array); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -1074,7 +1074,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { {91, 112, 98, 120, 105, 128, 112}, {65, 84, 70, 90, 75, 96, 80}, })); - auto expected = Literal::CreateR4FromArray4D(expected_array); + auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -1091,7 +1091,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { {13, 14, 15, 16}, })); // clang-format on - auto lhs_literal = Literal::CreateR4FromArray4D(lhs_array); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(lhs_array); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); @@ -1102,7 +1102,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { {7, 8}, })); // clang-format on - auto rhs_literal = Literal::CreateR4FromArray4D(rhs_array); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(rhs_array); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -1138,7 +1138,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { {104, 91, 112, 98, 120, 105, 128, 112}, {78, 65, 84, 70, 90, 75, 96, 80}, })); - auto expected = Literal::CreateR4FromArray4D(expected_array); + auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -1156,7 +1156,7 @@ TEST_P(HloEvaluatorTest, {13, 14, 15, 16}, })); // clang-format on - auto lhs_literal = Literal::CreateR4FromArray4D(lhs_array); + auto lhs_literal = LiteralUtil::CreateR4FromArray4D(lhs_array); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); @@ -1167,7 +1167,7 @@ TEST_P(HloEvaluatorTest, {8, 9, 10}, })); // clang-format on - auto rhs_literal = Literal::CreateR4FromArray4D(rhs_array); + auto rhs_literal = LiteralUtil::CreateR4FromArray4D(rhs_array); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); @@ -1210,7 +1210,7 @@ TEST_P(HloEvaluatorTest, {0, 0, 0}, {91, 98, 105}, })); - auto expected = Literal::CreateR4FromArray4D(expected_array); + auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -1225,9 +1225,9 @@ TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) { constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 std::vector v(kNumElements, 1.0f); HloInstruction* arg_instruction = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1(v))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1(v))); HloInstruction* init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); HloComputation::Builder add_computation("add"); Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1262,9 +1262,9 @@ void BM_ReducePrecisely(int num_iters) { constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 std::vector v(kNumElements, 1.0f); HloInstruction* arg_instruction = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1(v))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1(v))); auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); HloComputation::Builder add_computation("add"); Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1299,13 +1299,13 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { // } auto arg_array = MakeUnique>(2, 3); arg_array->FillUnique(1.0f); - auto arg_literal = Literal::CreateR2FromArray2D(*arg_array); + auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); HloInstruction* arg_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); HloComputation::Builder add_computation("add"); Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1326,7 +1326,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR1({6, 18}); + auto expected = LiteralUtil::CreateR1({6, 18}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -1341,13 +1341,13 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { // } auto arg_array = MakeUnique>(2, 3); arg_array->FillUnique(1.0f); - auto arg_literal = Literal::CreateR2FromArray2D(*arg_array); + auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); HloInstruction* arg_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); HloComputation::Builder max_computation("max"); Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1378,7 +1378,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({{6, 7}}); + auto expected = LiteralUtil::CreateR2({{6, 7}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -1392,13 +1392,13 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { // } auto arg_array = MakeUnique>(2, 3); arg_array->FillUnique(1.0f); - auto arg_literal = Literal::CreateR2FromArray2D(*arg_array); + auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); HloInstruction* arg_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); HloComputation::Builder add_computation("add"); Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1435,7 +1435,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({{1, 3, 5}, {5, 11, 13}}); + auto expected = LiteralUtil::CreateR2({{1, 3, 5}, {5, 11, 13}}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } @@ -1445,13 +1445,13 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time. std::vector input_dims(6, 4); std::unique_ptr arg_literal = - Literal::CreateFullWithDescendingLayout(input_dims, 1.0f); + LiteralUtil::CreateFullWithDescendingLayout(input_dims, 1.0f); HloInstruction* arg_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); HloComputation::Builder add_computation("add"); Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1498,7 +1498,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { std::vector output_dims = {4, 3, 3, 3, 4, 4}; std::unique_ptr result_literal = - Literal::CreateFullWithDescendingLayout(output_dims, 8.0f); + LiteralUtil::CreateFullWithDescendingLayout(output_dims, 8.0f); EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result)); } @@ -1513,7 +1513,8 @@ TEST_P(HloEvaluatorTest, StridedSlice) { // } auto operand_array = MakeUnique>(3, 5); operand_array->FillUnique(1.0f); - auto operand_literal = Literal::CreateR2FromArray2D(*operand_array); + auto operand_literal = + LiteralUtil::CreateR2FromArray2D(*operand_array); HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); @@ -1527,7 +1528,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({ + auto expected = LiteralUtil::CreateR2({ {3}, {19}, }); @@ -1545,13 +1546,14 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { // } auto operand_array = MakeUnique>(2, 4); operand_array->FillUnique(1.0f); - auto operand_literal = Literal::CreateR2FromArray2D(*operand_array); + auto operand_literal = + LiteralUtil::CreateR2FromArray2D(*operand_array); HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); auto start_indices = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 1}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 1}))); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, @@ -1560,7 +1562,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({ + auto expected = LiteralUtil::CreateR2({ {2, 3, 4}, {6, 7, 8}, }); @@ -1580,13 +1582,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { // } auto operand_array = MakeUnique>(2, 4); operand_array->FillUnique(1.0f); - auto operand_literal = Literal::CreateR2FromArray2D(*operand_array); + auto operand_literal = + LiteralUtil::CreateR2FromArray2D(*operand_array); HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); auto start_indices = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2, 1}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2, 1}))); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, @@ -1595,7 +1598,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({ + auto expected = LiteralUtil::CreateR2({ {2, 3, 4}, {6, 7, 8}, }); @@ -1613,16 +1616,17 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { // } auto operand_array = MakeUnique>(2, 3); operand_array->FillUnique(1.0); - auto operand_literal = Literal::CreateR2FromArray2D(*operand_array); + auto operand_literal = + LiteralUtil::CreateR2FromArray2D(*operand_array); HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); auto start_indices = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 1}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 1}))); auto update = b.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{-2.0, -3.0}, {-6.0, -7.0}}))); + LiteralUtil::CreateR2({{-2.0, -3.0}, {-6.0, -7.0}}))); Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( @@ -1631,7 +1635,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({ + auto expected = LiteralUtil::CreateR2({ {1, -2, -3}, {5, -6, -7}, }); @@ -1649,12 +1653,13 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { // } auto operand_array = MakeUnique>(2, 3); operand_array->FillUnique(1.0); - auto operand_literal2 = Literal::CreateR2FromArray2D(*operand_array); + auto operand_literal2 = + LiteralUtil::CreateR2FromArray2D(*operand_array); HloInstruction* operand2 = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal2))); HloInstruction* operand1 = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 1}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 1}))); auto tuple = b.AddInstruction(HloInstruction::CreateTuple({operand1, operand2})); @@ -1666,7 +1671,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { std::unique_ptr result = Evaluate(); - auto expected = Literal::CreateR2({ + auto expected = LiteralUtil::CreateR2({ {1, 2, 3}, {5, 6, 7}, }); @@ -1686,9 +1691,9 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { operand_array->FillUnique(1.0); HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2FromArray2D(*operand_array))); + LiteralUtil::CreateR2FromArray2D(*operand_array))); HloInstruction* operand1 = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 1}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 1}))); auto tuple1 = b.AddInstruction(HloInstruction::CreateTuple({operand1, operand2})); @@ -1706,8 +1711,8 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { std::unique_ptr result = Evaluate(); auto result_inner_literal = - Literal::CreateR2FromArray2D(*operand_array); - auto expected = Literal::MakeTuple({ + LiteralUtil::CreateR2FromArray2D(*operand_array); + auto expected = LiteralUtil::MakeTuple({ result_inner_literal.get(), result_inner_literal.get(), }); @@ -1735,7 +1740,7 @@ TEST_P(HloEvaluatorTest, Reverse) { {{23.0f}, {24.0f}}}, }); // clang-format on - auto operand_literal = Literal::CreateR4FromArray4D(input); + auto operand_literal = LiteralUtil::CreateR4FromArray4D(input); HloInstruction* operand = b.AddInstruction( HloInstruction::CreateConstant(std::move(operand_literal))); @@ -1746,7 +1751,7 @@ TEST_P(HloEvaluatorTest, Reverse) { std::unique_ptr result = Evaluate(); // clang-format off - auto expected = Literal::CreateR4FromArray4D({ + auto expected = LiteralUtil::CreateR4FromArray4D({ {{{23.0f}, {24.0f}}, {{21.0f}, {22.0f}}, {{19.0f}, {20.0f}}}, @@ -1782,11 +1787,11 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { // Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}. HloEvaluator evaluator; auto result = evaluator.EvaluateWithSubstitutions( - add, {{param0, Literal::CreateR1({1, 2, 3, 4}).get()}, - {square, Literal::CreateR1({10, 20, 30, 40}).get()}}); + add, {{param0, LiteralUtil::CreateR1({1, 2, 3, 4}).get()}, + {square, LiteralUtil::CreateR1({10, 20, 30, 40}).get()}}); TF_ASSERT_OK(result.status()); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); + *LiteralUtil::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); } // Check that EvaluateWithSubstitutions works if one of the operands to the op @@ -1799,18 +1804,18 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param0")); HloInstruction* square = b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kMultiply, param0, param0)); - HloInstruction* constant = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); + HloInstruction* constant = b.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4}))); HloInstruction* add = b.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, constant, square)); // Evaluate add with square = {10, 20, 30, 40}. HloEvaluator evaluator; auto result = evaluator.EvaluateWithSubstitutions( - add, {{square, Literal::CreateR1({10, 20, 30, 40}).get()}}); + add, {{square, LiteralUtil::CreateR1({10, 20, 30, 40}).get()}}); TF_ASSERT_OK(result.status()); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); + *LiteralUtil::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { @@ -1830,11 +1835,12 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); - EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}), - *Evaluate({operand.get(), gather_indices.get()}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({0, 2}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{1, 2, 3}, {7, 8, 9}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { @@ -1854,10 +1860,11 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR2({{1, 3}, {4, 6}, {7, 9}}), + *LiteralUtil::CreateR2({{1, 3}, {4, 6}, {7, 9}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -1878,11 +1885,11 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 2}, {2, 1}}); + LiteralUtil::CreateR2({{0, 2}, {2, 1}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR3( + *LiteralUtil::CreateR3( {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -1904,13 +1911,13 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = - Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{-1, 1}, {-4, 4}}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{-1, 1}, {-4, 4}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -1932,13 +1939,13 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = - Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{-2, 2}, {-1, 1}}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{-2, 2}, {-1, 1}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -1959,10 +1966,11 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({1, 1}); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{5}}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{5}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -1983,11 +1991,11 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR2({{2, 1}, {1, 1}}); + LiteralUtil::CreateR2({{2, 1}, {1, 1}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR3({{{8}}, {{5}}}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR3({{{8}}, {{5}}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -2007,10 +2015,11 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = Literal::CreateR2({{}, {}, {}}); - std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{}, {}}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{}, {}}), *Evaluate({operand.get(), gather_indices.get()}))); } @@ -2031,26 +2040,479 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = Literal::CreateR1({0, 1, 2}); + std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); std::unique_ptr gather_indices = - Literal::CreateR3({{{0}, {1}}, {{2}, {1}}}); + LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{0, 1}, {2, 1}}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{0, 1}, {2, 1}}), *Evaluate({operand.get(), gather_indices.get()}))); } +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { + const char* hlo_text = R"( +HloModule TensorFlowScatterV1 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) { + const char* hlo_text = R"( +HloModule TensorFlowScatterV2 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[3,2] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={0}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 30}, {40, 60}, {70, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=mul_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_F32) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +add_f32 (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(f32[] lhs, f32[] rhs) +} + +ENTRY main { + operand = f32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = f32[2,3] parameter(2) + ROOT scatter = f32[3,3] scatter(operand, indices, updates), + to_apply=add_f32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = LiteralUtil::CreateR2( + {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({2, 1}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); + EXPECT_TRUE(LiteralTestUtil::Near( + *LiteralUtil::CreateR2( + {{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}), + ErrorSpec{0.1, 0.01})); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({1, 1}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) { + const char* hlo_text = R"( +HloModule TensorFlowScatterMultipleBatchDims + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,3,2] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=2 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + std::unique_ptr updates = LiteralUtil::CreateR3( + {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}), + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNd + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); + std::unique_ptr expected = + LiteralUtil::CreateR3({{{-10, 10}, {-2, 2}, {-3, 3}}, // + {{-40, 40}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *expected, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, + EvaluateScatter_TensorFlowScatterNd_NonDefaultIndexVectorDim) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNdNonDefaultIndexVectorDim + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); + std::unique_ptr expected = + LiteralUtil::CreateR3({{{-20, 20}, {-10, 10}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *expected, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) { + const char* hlo_text = R"( +HloModule DynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[1,2] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={0,1}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({1, 1}); + std::unique_ptr updates = LiteralUtil::CreateR2({{10}}); + std::unique_ptr expected = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *expected, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) { + const char* hlo_text = R"( +HloModule BatchDynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,1,1] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + std::unique_ptr updates = + LiteralUtil::CreateR3({{{10}}, {{20}}}); + std::unique_ptr expected = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *expected, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) { + const char* hlo_text = R"( +HloModule TensorFlowScatter_ZeroDimBounds + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,0] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,0] parameter(2) + ROOT scatter = s32[3,0] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR1({0, 2}); + std::unique_ptr updates = LiteralUtil::CreateR2({{}, {}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *operand, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) { + const string hlo_text = R"( +HloModule Scatter_NoUpdateWindowDims + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3] parameter(0) + indices = s32[2,2,1] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=2 +} +)"; + ParseAndVerifyModule(hlo_text); + + std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); + std::unique_ptr scatter_indices = + LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); + std::unique_ptr updates = + LiteralUtil::CreateR2({{10, 20}, {30, 40}}); + std::unique_ptr expected = + LiteralUtil::CreateR1({10, 61, 32}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *expected, + *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); +} + // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise comparison with 2 bfloat16 operands. TEST_P(HloEvaluatorTest, DoesCompareBF16) { // lhs >= rhs - auto lhs = Literal::CreateR2( + auto lhs = LiteralUtil::CreateR2( {{bfloat16(0.25), bfloat16(0.35), bfloat16(0.125)}, {bfloat16(-0.25), bfloat16(-0.35), bfloat16(-0.125)}}); - auto rhs = Literal::CreateR2( + auto rhs = LiteralUtil::CreateR2( {{bfloat16(0.5), bfloat16(0.125), bfloat16(0.125)}, {bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}}); auto expected = - Literal::CreateR2({{false, true, true}, {false, true, true}}); + LiteralUtil::CreateR2({{false, true, true}, {false, true, true}}); TestBinaryOp(HloOpcode::kGe, std::move(expected), std::move(lhs), std::move(rhs)); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 8b08756c641fa8de6c7739fb4dd94ceceeb53311..4dc03fd06d6917ed74371aa9926a1059e502826d 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/core/lib/core/casts.h" @@ -34,6 +35,37 @@ using is_complex_t = std::is_same; template using is_complex64_t = std::is_same; +// It's UB to use std::sort with std::less, because of NaNs. Define +// "safe" less functions which are actually strict weak orders. +template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> +bool SafeLess(const NativeT& a, const NativeT& b) { + return a < b; +} + +template ::value || + std::is_same::value>::type* = nullptr> +bool SafeLess(const NativeT& a, const NativeT& b) { + if (std::isnan(b)) { + return !std::isnan(a); + } else { + return a < b; + } +} + +template ::value>::type* = nullptr> +bool SafeLess(const NativeT& a, const NativeT& b) { + if (Eigen::half_impl::isnan(b)) { + return !Eigen::half_impl::isnan(a); + } else { + return a < b; + } +} + // Templated DfsHloVisitor for use by HloEvaluator. // // Typically ReturnT here indicates the resulting literal type of each evaluated @@ -269,6 +301,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleFloor(floor); } + Status HandleImag(HloInstruction* imag) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[imag], + ElementWiseUnaryOp(imag, [](ElementwiseT elem_operand) { + return std::imag(elem_operand); + })); + return Status::OK(); + } + Status HandleLog(HloInstruction* log) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { @@ -572,6 +612,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + Status HandleReal(HloInstruction* real) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[real], + ElementWiseUnaryOp(real, [](ElementwiseT elem_operand) { + return std::real(elem_operand); + })); + return Status::OK(); + } + template < typename NativeT, typename std::enable_if::value>::type* = nullptr> @@ -1025,83 +1073,47 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { CHECK_EQ(dnums.lhs_batch_dimensions_size(), dnums.rhs_batch_dimensions_size()); - std::vector lhs_non_contracting_dims; + DimensionVector lhs_index(lhs_rank); + DimensionVector rhs_index(rhs_rank); + + // result_index_locations[i] contains one or two pointers to the locations + // in lhs_index or rhs_index where the i'th result index should go. + tensorflow::gtl::InlinedVector, kInlineRank> + result_index_locations; + result_index_locations.reserve(lhs_rank + rhs_rank - 2); + + // The first components in the output shape are the LHS and RHS batch + // dimensions: + for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); i++) { + result_index_locations.push_back( + {&lhs_index[dnums.lhs_batch_dimensions(i)], + &rhs_index[dnums.rhs_batch_dimensions(i)]}); + } + + // Then we have the LHS and RHS non-contracting dimensions, if any: for (int64 i = 0; i < lhs_rank; i++) { - if (i != lhs_contracting_dimension) { - lhs_non_contracting_dims.push_back(i); + if (i != lhs_contracting_dimension && + !ArrayContains(AsInt64Slice(dnums.lhs_batch_dimensions()), i)) { + result_index_locations.push_back({&lhs_index[i], nullptr}); } } - - std::vector rhs_non_batch_non_contracting_dims; - tensorflow::gtl::FlatSet batch_dims_set( - dnums.rhs_batch_dimensions().begin(), - dnums.rhs_batch_dimensions().end()); for (int64 i = 0; i < rhs_rank; i++) { - if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) { - rhs_non_batch_non_contracting_dims.push_back(i); + if (i != rhs_contracting_dimension && + !ArrayContains(AsInt64Slice(dnums.rhs_batch_dimensions()), i)) { + result_index_locations.push_back({&rhs_index[i], nullptr}); } } - const int64 batch_dim_size = dnums.lhs_batch_dimensions_size(); - const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size(); - - DimensionVector lhs_index(lhs_rank); - DimensionVector rhs_index(rhs_rank); auto result = MakeUnique(dot->shape()); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice result_index) { ElementwiseT result_val = static_cast(0); - // Find the corresponding non-contracting indices for lhs and rhs. - // - // For `result_index`, its batch dimension, if exists, will be at the - // same dimension as the batch dimension of lhs and rhs. More - // specifically: - // - For lhs, the non-contracting dimensions, including the batch - // dimension have the same index as the `result_index`. - // - For rhs, the batch dimension is set seperately from other - // non-contracting dimensions, since these other non-contracting - // dimensions in rhs follow the non-contracting dimensions of lhs in - // the resulting index. - // - // As an example, for a resulting index: - // result_index [result_batch, result_x, result_y] - // the effecting lhs and rhs indices are: - // lhs [result_batch, lhs_non_contracting_dim, contracting_dim - // rhs [result_batch, contracting_dim, rhs_non_contracting_dim] - // `result_x` is only affected by the lhs_non_contracting_dim and - // likewise `result_y` only depends on rhs_non_contracting_dim. - // - // so we can look up the lhs and rhs indices by: - // - // lhs: - // batch index is the same as `result_batch`. - // non-contracting dimension is the same as - // result_index[lhs_non_contracting_dim] - // rhs: - // batch index: the same as `result_batch`. - // non-contracting dimension index: *not* the same as - // result_index[rhs_non_contractng_dim], since the - // non-contracting dimensions of lhs are included in the - // result_index first. Instead, the non_contracting_dim of rhs must - // be calculated as following: - // lhs_non_contracting_dimensions_size + - // (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1 - // - // Note that (rhs_non_batch_contracting_dim - batch_dim_size) is - // the index offset to the result_index that only depends on - // the non_batch and non-contracting dimensions of rhs. -1 at the - // end translates size to index. - for (auto i : lhs_non_contracting_dims) { - lhs_index[i] = result_index[i]; - } - for (auto i : dnums.rhs_batch_dimensions()) { - rhs_index[i] = result_index[i]; - } - for (auto i : rhs_non_batch_non_contracting_dims) { - const int64 rhs_non_batch_non_contracting_dim = - lhs_non_contracting_size + (i - batch_dim_size) - 1; - rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim]; + for (int64 i = 0; i < result_index.size(); i++) { + *result_index_locations[i].first = result_index[i]; + if (result_index_locations[i].second) { + *result_index_locations[i].second = result_index[i]; + } } // Accumulates resulting product along the contracted dimension. @@ -1321,7 +1333,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { parent_->GetEvaluatedLiteralFor(operand); auto curr_val = arg_literal.Get(multi_index); - auto curr_val_literal = Literal::CreateR0(curr_val); + auto curr_val_literal = LiteralUtil::CreateR0(curr_val); arg_literals.push_back(std::move(curr_val_literal)); } @@ -1402,24 +1414,49 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { !is_complex_t::value && !std::is_same::value>::type* = nullptr> Status HandleSort(HloInstruction* sort) { - TF_RET_CHECK(ShapeUtil::Rank(sort->shape()) == 1) - << "Sort is only supported for R1 shapes"; - - auto arg = sort->operand(0); - const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg); - VLOG(3) << "HandleSort arg_literal: " << arg_literal.ToString(); - const auto& arg_data = arg_literal.data(); + auto keys = sort->operand(0); + auto rank = ShapeUtil::Rank(keys->shape()); + TF_RET_CHECK(rank > 0 && rank <= 2) + << "Sort is only supported for R1 and R2 shapes"; + TF_RET_CHECK(sort->operand_count() == 1) + << "Typed visitor does not support key-value sort"; + + const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys); + + auto sort_r1 = [this](const Literal& keys_literal) { + VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); + const auto& keys_data = keys_literal.data(); + + std::vector result_data(keys_data.begin(), keys_data.end()); + std::sort(result_data.begin(), result_data.end(), + [](const ReturnT& a, const ReturnT& b) { + return SafeLess(a, b); + }); + auto result_literal = MakeUnique(keys_literal.shape()); + result_literal->PopulateR1( + tensorflow::gtl::ArraySlice(result_data)); + VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); + return result_literal; + }; - std::vector return_data(arg_data.begin(), arg_data.end()); - std::sort(return_data.begin(), return_data.end(), - [](const ReturnT& a, const ReturnT& b) { - return SafeLess(a, b); - }); - auto result_literal = MakeUnique(sort->shape()); - result_literal->PopulateR1( - tensorflow::gtl::ArraySlice(return_data)); - VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); - parent_->evaluated_[sort] = std::move(result_literal); + if (rank == 1) { + parent_->evaluated_[sort] = std::move(sort_r1(keys_literal)); + } else { + // For R2 sort, the desired semantics are to sort each matrix row + // independently. + auto result_literal = MakeUnique(keys_literal.shape()); + int64 r1_length = keys->shape().dimensions(1); + for (int64 row = 0; row < keys->shape().dimensions(0); ++row) { + TF_ASSIGN_OR_RETURN(auto r1_slice, + keys_literal.Slice({row, 0}, {row + 1, r1_length}) + ->Reshape({r1_length})); + auto r1_result = sort_r1(*r1_slice); + TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length})); + TF_RETURN_IF_ERROR(result_literal->CopySliceFrom( + *r1_result, {0, 0}, {row, 0}, {1, r1_length})); + } + parent_->evaluated_[sort] = std::move(result_literal); + } return Status::OK(); } @@ -1444,8 +1481,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { ShapeUtil::Rank(arg->shape()) - dimensions.size()); TF_ASSIGN_OR_RETURN(auto inferred_return_shape, ShapeInference::InferReduceShape( - /*arg=*/arg->shape(), - /*init_value=*/init_value->shape(), + {&arg->shape(), &init_value->shape()}, /*dimensions_to_reduce=*/dimensions, /*to_apply=*/function->ComputeProgramShape())); TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) @@ -1507,8 +1543,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto curr_val = arg_literal.Get(input_index); // Evaluate computation with specified literal operands. - auto curr_val_literal = Literal::CreateR0(curr_val); - auto result_val_literal = Literal::CreateR0(result_val); + auto curr_val_literal = LiteralUtil::CreateR0(curr_val); + auto result_val_literal = + LiteralUtil::CreateR0(result_val); std::unique_ptr computed_result = embedded_evaluator @@ -1586,10 +1623,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Used in the dual IterateThroughWindow lambdas below. Hoisted to avoid // dynamic memory allocations. - auto curr_val_literal = Literal::CreateR0(ReturnT()); - auto selected_val_literal = Literal::CreateR0(ReturnT()); - auto source_literal_scatter = Literal::CreateR0(ReturnT()); - auto scattered_literal = Literal::CreateR0(ReturnT()); + auto curr_val_literal = LiteralUtil::CreateR0(ReturnT()); + auto selected_val_literal = LiteralUtil::CreateR0(ReturnT()); + auto source_literal_scatter = LiteralUtil::CreateR0(ReturnT()); + auto scattered_literal = LiteralUtil::CreateR0(ReturnT()); do { // For each element in `source`, we place a window in `operand`. For each // window placement, we iterate inside the window twice: @@ -1710,9 +1747,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Evaluate computation with specified literal operands. const auto curr_val_literal = - Literal::CreateR0(curr_val); + LiteralUtil::CreateR0(curr_val); const auto result_val_literal = - Literal::CreateR0(result_val); + LiteralUtil::CreateR0(result_val); std::unique_ptr computed_result = embedded_evaluator .Evaluate( @@ -1734,6 +1771,388 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + // Reshapes the scatter indices input to have a trailing degenerate `1` + // dimension if necessary. Hands over the ownership of the newly created + // literal (if there is one) to `reshaped_indices`. + StatusOr> ReshapedScatterIndices( + int64 index_vector_dim, const Literal& indices, + std::unique_ptr* reshaped_indices) { + if (indices.shape().dimensions_size() != index_vector_dim) { + return std::cref(indices); + } + + std::vector new_shape(indices.shape().dimensions().begin(), + indices.shape().dimensions().end()); + new_shape.push_back(1); + TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape)); + return std::cref(**reshaped_indices); + } + + // Returns an ShapeUtil::IndexIterationSpace that iterates over the update + // scatter dimensions while keeping the rest of the update dimensions clamped + // to 0. + ShapeUtil::IndexIterationSpace IterationSpaceForUpdateScatterIndices( + const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { + int64 updates_rank = updates_shape.dimensions_size(); + std::vector index_base(updates_rank, 0); + std::vector index_count(updates_rank, 1); + for (int64 i = 0; i < updates_rank; i++) { + bool is_update_scatter_dim = + !c_binary_search(dim_numbers.update_window_dims(), i); + if (is_update_scatter_dim) { + index_count[i] = updates_shape.dimensions(i); + } + } + return {std::move(index_base), std::move(index_count), + std::vector(updates_rank, 1)}; + } + + // Return an ShapeUtil::IndexIterationSpace that iterates over the update + // window dimensions while keeping the rest of the update dimensions clamped + // to 0. + ShapeUtil::IndexIterationSpace IterationSpaceForUpdateWindowIndices( + const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { + int64 updates_rank = updates_shape.dimensions_size(); + std::vector index_base(updates_rank, 0); + std::vector index_count(updates_rank, 1); + for (int64 i = 0; i < updates_rank; i++) { + bool is_update_window_dim = + c_binary_search(dim_numbers.update_window_dims(), i); + if (is_update_window_dim) { + index_count[i] = updates_shape.dimensions(i); + } + } + return {std::move(index_base), std::move(index_count), + std::vector(updates_rank, 1)}; + } + + // This functor computes the contribution of scatter_indices to an input index + // corresponding to an update index. That is, given an update index I, it + // picks out the scatter indices in I and uses them to look up a scatter + // index, S, from the scatter indices tensor, and expands S into the input + // space according to scatter_dims_to_operand_dims. + // + // This is similar to the class HloEvaluator::OutputGatherIndexToInputIndex + // that does the corresponding function for Gather. + class UpdateScatterIndexToInputIndex { + public: + // The constructor does some setup work that is amortized across all + // iterations. + explicit UpdateScatterIndexToInputIndex( + const ScatterDimensionNumbers* dim_numbers, const Shape& input_shape, + const Shape& updates_shape, const Literal* scatter_indices) + : dim_numbers_(*dim_numbers), scatter_indices_(*scatter_indices) { + for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { + update_dim_is_scatter_dims_.push_back( + !c_binary_search(dim_numbers_.update_window_dims(), i)); + } + + for (int64 i = 0; i < input_shape.dimensions_size(); i++) { + int64 index_of_input_dim_in_index_vector = + FindIndex(dim_numbers_.scatter_dims_to_operand_dims(), i); + if (index_of_input_dim_in_index_vector == + dim_numbers_.scatter_dims_to_operand_dims_size()) { + input_dim_value_to_index_vector_.push_back(-1); + } else { + input_dim_value_to_index_vector_.push_back( + index_of_input_dim_in_index_vector); + } + } + + index_vector_index_.resize(scatter_indices_.shape().dimensions_size()); + input_index_.resize(input_shape.dimensions_size()); + int64 index_vector_size = + scatter_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); + index_vector_.resize(index_vector_size); + } + + // Returns the contribution of scatter_indices to the input index + // corresponding to update_index. See scatter_inner_loop_body. + // + // This is conceptually a stateless transformation from update_index to the + // scatter input index, but: + // + // - Instead of allocating memory to represent the scatter input index on + // every invocation we reuse the same storage for the result + // (input_index_), mutating it in place. + // - Instead of allocating buffers for temporary values like + // index_vector_index_ and index_vector on every invocation, we reuse the + // same storage for all invocations. + // + // This returns an arrayslice into memory owned by the class. + StatusOr> operator()( + tensorflow::gtl::ArraySlice update_index) { + PropagateUpdateIndexScatterDimsToIndexVectorIndex(update_index); + TF_RETURN_IF_ERROR(FetchIndexVector()); + PropagateIndexVectorToInputIndex(); + return tensorflow::gtl::ArraySlice(input_index_); + } + + private: + // Propagates the scatter index dimensions from the update index into + // index_vector_index_ by mutating index_vector_index_ in place. Does not + // update the dim_numbers.index_vector_dim() dimension -- that's the + // dimension we iterate over in FetchIndexVector. + void PropagateUpdateIndexScatterDimsToIndexVectorIndex( + tensorflow::gtl::ArraySlice update_index) { + int64 index_vector_index_i = 0; + for (int64 i = 0, e = update_index.size(); i < e; i++) { + if (!update_dim_is_scatter_dims_[i]) { + continue; + } + + if (index_vector_index_i == dim_numbers_.index_vector_dim()) { + index_vector_index_i++; + } + + index_vector_index_[index_vector_index_i++] = update_index[i]; + } + } + + // Populates index_vector_ by iterating over scatter_indices_ according to + // index_vector_index_. + Status FetchIndexVector() { + int64 index_vector_dim = dim_numbers_.index_vector_dim(); + for (int64 i = 0, e = index_vector_.size(); i < e; i++) { + index_vector_index_[index_vector_dim] = i; + TF_ASSIGN_OR_RETURN(index_vector_[i], scatter_indices_.GetIntegralAsS64( + index_vector_index_)); + } + return Status::OK(); + } + + // Populates input_index_. + void PropagateIndexVectorToInputIndex() { + for (int64 i = 0, e = input_index_.size(); i < e; i++) { + if (input_dim_value_to_index_vector_[i] != -1) { + input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]]; + } + + // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i] + // remains 0, as set by the constructor. + } + } + + // input_dim_value_to_index_vector_[i] tells us how to compute dimension i + // of the input index from the index vector. See + // PropagateIndexVectorToInputIndex. + std::vector input_dim_value_to_index_vector_; + + // update_dim_is_scatter_dims_[i] is true iff the update index i is a + // scatter dimension. + std::vector update_dim_is_scatter_dims_; + + // The buffer into which we construct an index into scatter_indices_ to + // fetch the index vector. + std::vector index_vector_index_; + + // The index vector fetched from scatter_indices_. + std::vector index_vector_; + + // The result computed by this functor. operator() returns an ArraySlice + // into this vector. + std::vector input_index_; + + const ScatterDimensionNumbers& dim_numbers_; + const Literal& scatter_indices_; + }; + + // This functor computes the contribution of the window indices in an update + // index to an input index. That is, given an update index I it picks out the + // update window indices in I and expands it into a window index into the + // input shape. + // + // This is similar to the class HloEvaluator::OutputWindowIndexToInputIndex + // that does the corresponding function for Gather. + class UpdateWindowIndexToInputIndex { + public: + // The constructor does some setup work that is amortized across all + // iterations. + explicit UpdateWindowIndexToInputIndex( + const ScatterDimensionNumbers& dim_numbers, const Shape& input_shape, + const Shape& updates_shape) { + std::vector window_index_to_update_index; + int64 update_index_count = 0; + for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { + if (c_binary_search(dim_numbers.update_window_dims(), i)) { + window_index_to_update_index.push_back(update_index_count++); + } else { + update_index_count++; + } + } + + int64 window_dim_count = 0; + for (int64 i = 0; i < input_shape.dimensions_size(); i++) { + if (c_binary_search(dim_numbers.inserted_window_dims(), i)) { + input_dim_value_to_update_index_.push_back(-1); + } else { + input_dim_value_to_update_index_.push_back( + window_index_to_update_index[window_dim_count++]); + } + } + + input_index_.resize(input_shape.dimensions_size()); + } + + // Returns the contribution of the window indices to the input index + // corresponding to update_index. See scatter_inner_loop_body. + // + // This is conceptually a stateless transformation from update_index to the + // window input index, but instead of allocating memory to represent the + // scatter input index on every invocation we reuse the same storage for the + // result (input_index_), mutating it in place. + // + // This returns an arrayslice into memory owned by the class. + StatusOr> operator()( + tensorflow::gtl::ArraySlice update_index) { + PropagateUpdateIndexWindowDimsToInputIndex(update_index); + return tensorflow::gtl::ArraySlice(input_index_); + } + + // Returns for a given 'input_dim' the corresponding update dimension index, + // or -1 if 'input_dim' is an elided window dimension. + int64 input_dim_value_to_update_index(int64 input_dim) { + return input_dim_value_to_update_index_[input_dim]; + } + + private: + // Propagates window dimensions from the update index to input_index_ by + // mutating input_index_ in place. + void PropagateUpdateIndexWindowDimsToInputIndex( + tensorflow::gtl::ArraySlice update_index) { + for (int64 i = 0, e = input_index_.size(); i < e; i++) { + if (input_dim_value_to_update_index_[i] != -1) { + input_index_[i] = update_index[input_dim_value_to_update_index_[i]]; + } + + // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i] + // remains 0, as set by the constructor. + } + } + + // input_dim_value_to_index_vector_[i] tells us how to compute dimension i + // of the input index from the update index. See + // PropagateUpdateIndexWindowDimsToInputIndex. + std::vector input_dim_value_to_update_index_; + + // The result computed by this functor. operator() returns an ArraySlice + // into this vector. + std::vector input_index_; + }; + + Status HandleScatter(HloInstruction* scatter) override { + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + const Literal& operand = + parent_->GetEvaluatedLiteralFor(scatter->operand(0)); + std::unique_ptr reshaped_scatter_indices; + TF_ASSIGN_OR_RETURN(const Literal& scatter_indices, + ReshapedScatterIndices(dim_numbers.index_vector_dim(), + parent_->GetEvaluatedLiteralFor( + scatter->operand(1)), + &reshaped_scatter_indices)); + const Literal& updates = + parent_->GetEvaluatedLiteralFor(scatter->operand(2)); + const Shape& updates_shape = updates.shape(); + const Shape& operand_shape = operand.shape(); + + ShapeUtil::IndexIterationSpace scatter_indices_iteration_space = + IterationSpaceForUpdateScatterIndices(updates_shape, dim_numbers); + ShapeUtil::IndexIterationSpace window_indices_iteration_space = + IterationSpaceForUpdateWindowIndices(updates_shape, dim_numbers); + + std::vector input_index(operand_shape.dimensions_size()); + std::vector update_index(updates_shape.dimensions_size()); + std::vector input_scatter_index_clamped( + operand_shape.dimensions_size()); + + UpdateScatterIndexToInputIndex update_scatter_index_to_input_index( + &scatter->scatter_dimension_numbers(), /*input_shape=*/operand_shape, + updates_shape, &scatter_indices); + UpdateWindowIndexToInputIndex update_window_index_to_input_index( + scatter->scatter_dimension_numbers(), /*input_shape=*/operand_shape, + updates_shape); + + // Initialize the result with the operand. This makes it easier to handle + // the updates even when the indices are repeated. + std::unique_ptr result = operand.CloneToUnique(); + HloEvaluator embedded_evaluator; + auto scatter_inner_loop_body = + [&](tensorflow::gtl::ArraySlice update_window_index, + tensorflow::gtl::ArraySlice input_scatter_index, + tensorflow::gtl::ArraySlice update_scatter_index) + -> StatusOr { + TF_ASSIGN_OR_RETURN( + tensorflow::gtl::ArraySlice input_window_index, + update_window_index_to_input_index(update_window_index)); + for (int i = 0, e = update_index.size(); i < e; i++) { + update_index[i] = update_scatter_index[i] + update_window_index[i]; + DCHECK_LT(update_index[i], updates_shape.dimensions(i)); + } + for (int i = 0, e = input_scatter_index.size(); i < e; i++) { + int64 update_dim = + update_window_index_to_input_index.input_dim_value_to_update_index( + i); + // If 'update_dim' is -1, it means 'i' is an elided window dim. This + // means we set the iteration index to 0, so for the purpose of the + // following calculations we can consider the update dimension size to + // be 1. + int64 update_dim_size = + update_dim == -1 ? 1 : updates_shape.dimensions(update_dim); + // Clamp the scatter index so that the scatter region fits in the + // operand. input_scatter_index_clamped[i] = + // clamp(input_scatter_index[i], 0, + // operand_shape.dimensions(i) - + // update_dim_size); + input_scatter_index_clamped[i] = + std::min(operand_shape.dimensions(i) - update_dim_size, + std::max(0LL, input_scatter_index[i])); + } + for (int i = 0, e = input_index.size(); i < e; i++) { + input_index[i] = input_scatter_index_clamped[i] + input_window_index[i]; + DCHECK_GE(input_index[i], 0); + DCHECK_LT(input_index[i], operand_shape.dimensions(i)); + } + + auto result_value_literal = + LiteralUtil::CreateR0(result->Get(input_index)); + auto update_value_literal = + LiteralUtil::CreateR0(updates.Get(update_index)); + std::unique_ptr updated_result = + embedded_evaluator + .Evaluate( + *scatter->to_apply(), + {result_value_literal.get(), update_value_literal.get()}) + .ConsumeValueOrDie(); + // Clear visit states so that the we can use the evaluate again on the + // same computation. + embedded_evaluator.ResetVisitStates(); + result->Set(input_index, updated_result->Get({})); + return true; + }; + + auto scatter_outer_loop_body = + [&](tensorflow::gtl::ArraySlice update_scatter_index) + -> StatusOr { + TF_ASSIGN_OR_RETURN( + tensorflow::gtl::ArraySlice input_scatter_index, + update_scatter_index_to_input_index(update_scatter_index)); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + updates_shape, window_indices_iteration_space, + [&](tensorflow::gtl::ArraySlice update_window_index) { + return scatter_inner_loop_body( + update_window_index, input_scatter_index, update_scatter_index); + })); + return true; + }; + + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + updates_shape, scatter_indices_iteration_space, + scatter_outer_loop_body)); + parent_->evaluated_[scatter] = std::move(result); + return Status::OK(); + } + Status HandleSlice(HloInstruction* slice) override { auto operand = slice->operand(0); const Shape& shape = slice->shape(); @@ -1757,7 +2176,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return operand_literal.Get(operand_index); }; - auto result = Literal::CreateFromDimensions( + auto result = LiteralUtil::CreateFromDimensions( shape.element_type(), AsInt64Slice(shape.dimensions())); TF_RETURN_IF_ERROR(result->Populate(func)); parent_->evaluated_[slice] = std::move(result); @@ -1959,6 +2378,30 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleReducePrecision(reduce_precision); } + template ::value || + std::is_same::value || + std::is_same::value>::type* = nullptr> + Status HandleIota(HloInstruction* iota) { + auto result = MakeUnique(iota->shape()); + auto data = result->data(); + std::iota(data.begin(), data.end(), 0); + parent_->evaluated_[iota] = std::move(result); + return Status::OK(); + } + template ::value || + std::is_same::value || + std::is_same::value)>::type* = nullptr> + Status HandleIota(HloInstruction* iota) { + return InvalidArgument("Unsupported type for iota"); + } + Status HandleIota(HloInstruction* iota) override { + return HandleIota(iota); + } + private: // Creates a vector of multipliers which can be used to create a linear index // into shape. @@ -2016,10 +2459,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { start_indices_typed.end()); // Clamp the start indices so the slice is in-bounds w.r.t the operand. - - // TODO(b/74360564): This is implementation defined behavior, but is - // currently respected by all implementations. Change this if we ever decide - // to officially document different behavior. for (int64 i = 0; i < start.size(); ++i) { start[i] = std::min( std::max(int64{0}, start[i]), @@ -2053,10 +2492,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { start_indices_typed.end()); // Clamp the update start indices so the slice is in-bounds w.r.t the // operand. - - // TODO(b/74360564): This is implementation defined behavior, but is - // currently respected by all implementations. Change this if we ever decide - // to oficially document different behavior. for (int64 i = 0; i < rank; ++i) { start[i] = std::min( std::max(0, start[i]), @@ -2175,38 +2610,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return rhs_unsigned >= lhs_size_unsigned; } - // It's UB to use std::sort with std::less, because of NaNs. Define - // "safe" less functions which are actually strict weak orders. - template ::value>::type* = - nullptr> - static bool SafeLess(const NativeT& a, const NativeT& b) { - return a < b; - } - - template ::value || - std::is_same::value>::type* = nullptr> - static bool SafeLess(const NativeT& a, const NativeT& b) { - if (std::isnan(b)) { - return !std::isnan(a); - } else { - return a < b; - } - } - - template ::value>::type* = nullptr> - static bool SafeLess(const NativeT& a, const NativeT& b) { - if (Eigen::half_impl::isnan(b)) { - return !Eigen::half_impl::isnan(a); - } else { - return a < b; - } - } - HloEvaluator* parent_; }; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 8856723f67cf22c44e5ee482777a6a0908d1725d..1efa6eb5bda7e1cb90874e0466aafd2c788a3fbf 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -27,7 +27,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -844,7 +844,10 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( *elem_count *= dim; } } - if (elem_count.has_value() && *elem_count <= 8) { + // Allow HloDotDumper to print HloInstruction reconstructed from HloProto + // collected from profiling tools. Those constants may not have a valid + // literal. + if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) { return Printf("%s (%s)", constant->literal().ToString(), ShapeUtil::HumanString(constant->shape())); } @@ -948,6 +951,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kGe: case HloOpcode::kGt: case HloOpcode::kImag: + case HloOpcode::kIota: case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: @@ -966,6 +970,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kRemainder: case HloOpcode::kRng: case HloOpcode::kRoundNearestAfz: + case HloOpcode::kSelect: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: @@ -1001,7 +1006,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kPad: case HloOpcode::kReshape: case HloOpcode::kReverse: - case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: case HloOpcode::kTranspose: // De-emphasize scalar-shaped data movement ops and all data movement ops // inside fusion nodes, both of which are essentially free. @@ -1017,6 +1022,8 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kWhite; } return kGreen; + case HloOpcode::kScatter: + // Do not de-emphasize Scatter, since it involves significant work. case HloOpcode::kCopy: // Emphasize copy nodes, which are either physical transposes (and thus // significant), or copies of read-only buffers (and thus dead weight). @@ -1041,6 +1048,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kMap: return kGray; case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kRecv: diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 68f41a1cbb4db228f5dcf8b4a6130f05e81262a8..1d7a062c55696de9db4b187efd86bce191279083 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -120,7 +121,7 @@ TEST(HloGraphDumperTest, NestedFusion) { TEST(HloGraphDumperTest, Constant) { HloComputation::Builder b("b"); auto instruction = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-42))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(-42))); instruction->SetAndSanitizeName("i_am_a_constant_root_instruction"); HloModuleConfig config; HloModule m(TestName(), config); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 5aaeec802fa64fe5aefc76fb7556c0ad0487d179..8690f2cdaa9b45d126e91b123c6992cbe2f27e1d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -22,7 +22,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" @@ -112,29 +112,30 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kSend: - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "Send instruction should have 1 operand but sees " + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Send instruction should have 2 operand but sees " << proto.operand_ids_size(); - instruction = CreateSend(operands(0), proto.channel_id()); + instruction = CreateSend(operands(0), operands(1), proto.channel_id(), + proto.is_host_transfer()); break; case HloOpcode::kSendDone: TF_RET_CHECK(proto.operand_ids_size() == 1) << "SendDone instruction should have 1 operand but sees " << proto.operand_ids_size(); - instruction = CreateSendDone(operands(0)); + instruction = CreateSendDone(operands(0), proto.is_host_transfer()); break; case HloOpcode::kRecv: - TF_RET_CHECK(proto.operand_ids_size() == 0) - << "Recv instruction should have 0 operand but sees " + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Recv instruction should have 1 operand but sees " << proto.operand_ids_size(); - instruction = - CreateRecv(proto.shape().tuple_shapes(0), proto.channel_id()); + instruction = CreateRecv(proto.shape().tuple_shapes(0), operands(0), + proto.channel_id(), proto.is_host_transfer()); break; case HloOpcode::kRecvDone: TF_RET_CHECK(proto.operand_ids_size() == 1) << "RecvDone instruction should have 1 operand but sees " << proto.operand_ids_size(); - instruction = CreateRecvDone(operands(0)); + instruction = CreateRecvDone(operands(0), proto.is_host_transfer()); break; case HloOpcode::kReverse: TF_RET_CHECK(proto.operand_ids_size() == 1) @@ -163,6 +164,20 @@ StatusOr> HloInstruction::CreateFromProto( proto.dimensions().end()), computations(0)); break; + case HloOpcode::kSort: { + TF_RET_CHECK(proto.operand_ids_size() == 1 || + proto.operand_ids_size() == 2) + << "Sort instruction should have 1 or 2 operands but has " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.dimensions().size() == 1) + << "Sort instruction should have 1 dimension"; + HloInstruction* keys = operands(0); + HloInstruction* values = + proto.operand_ids_size() == 2 ? operands(1) : nullptr; + instruction = + CreateSort(proto.shape(), proto.dimensions(0), keys, values); + break; + } case HloOpcode::kTranspose: TF_RET_CHECK(proto.operand_ids_size() == 1) << "Transpose instruction should have 1 operand but sees " @@ -271,7 +286,7 @@ StatusOr> HloInstruction::CreateFromProto( // converted to take tokens. instruction = CreateInfeed(data_shape, proto.infeed_config()); } else { - CHECK_EQ(proto.operand_ids_size(), 2); + CHECK_EQ(proto.operand_ids_size(), 1); instruction = CreateInfeed(data_shape, operands(0), proto.infeed_config()); } @@ -305,6 +320,15 @@ StatusOr> HloInstruction::CreateFromProto( /*all_reduce_id=*/all_reduce_id); break; } + case HloOpcode::kAllToAll: { + instruction = CreateAllToAll( + proto.shape(), all_operands(), + /*replica_groups=*/ + std::vector(proto.replica_groups().begin(), + proto.replica_groups().end()), + /*barrier=*/proto.cross_replica_sum_barrier()); + break; + } case HloOpcode::kConvolution: TF_RET_CHECK(proto.operand_ids_size() == 2) << "Convolution instruction should have 2 operands but sees " @@ -372,6 +396,39 @@ StatusOr> HloInstruction::CreateFromProto( slice_sizes); break; } + case HloOpcode::kGather: { + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Gather instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_gather_dimension_numbers()) + << "Gather instruction should have GatherDimensionNumbers set."; + std::unique_ptr gather_dimension_numbers = + MakeUnique(proto.gather_dimension_numbers()); + std::vector gather_window_bounds; + for (int64 bound : proto.gather_window_bounds()) { + gather_window_bounds.push_back(bound); + } + instruction = + CreateGather(proto.shape(), operands(0), operands(1), + *gather_dimension_numbers, gather_window_bounds); + break; + } + case HloOpcode::kScatter: { + TF_RET_CHECK(proto.operand_ids_size() == 3) + << "Scatter instruction should have 3 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_scatter_dimension_numbers()) + << "Scatter instruction should have ScatterDimensionNumbers set."; + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Scatter instruction should have 1 called computation but sees " + << proto.called_computation_ids_size(); + auto scatter_dimension_numbers = MakeUnique( + proto.scatter_dimension_numbers()); + instruction = + CreateScatter(proto.shape(), operands(0), operands(1), operands(2), + computations(0), *scatter_dimension_numbers); + break; + } default: { instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { @@ -413,13 +470,6 @@ StatusOr> HloInstruction::CreateFromProto( instruction->set_sharding(sharding); } - if (proto.has_gather_dimension_numbers()) { - instruction->gather_dimension_numbers_ = - MakeUnique(proto.gather_dimension_numbers()); - } - for (int64 bound : proto.gather_window_bounds()) { - instruction->gather_window_bounds_.push_back(bound); - } return std::move(instruction); } @@ -438,6 +488,11 @@ StatusOr> HloInstruction::CreateFromProto( return MakeUnique(std::move(literal)); } +/* static */ std::unique_ptr HloInstruction::CreateIota( + const Shape& shape) { + return WrapUnique(new HloInstruction(HloOpcode::kIota, shape)); +} + /* static */ std::unique_ptr HloInstruction::CreateGetTupleElement(const Shape& shape, HloInstruction* operand, int64 index) { @@ -489,7 +544,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: - case HloOpcode::kSort: case HloOpcode::kTanh: break; default: @@ -542,8 +596,9 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, // Only certain opcodes are supported with CreateTernary: opcodes of ternary // instructions with no auxiliary fields. switch (opcode) { - case (HloOpcode::kClamp): - case (HloOpcode::kSelect): + case HloOpcode::kClamp: + case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: break; default: LOG(FATAL) << "Invalid ternary instruction opcode " @@ -625,6 +680,14 @@ HloInstruction::CreateCrossReplicaSum( all_reduce_id); } +/* static */ std::unique_ptr HloInstruction::CreateAllToAll( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + const std::vector& replica_groups, + tensorflow::StringPiece barrier) { + return MakeUnique(shape, operands, replica_groups, + barrier); +} + /* static */ std::unique_ptr HloInstruction::CreateInfeed( const Shape& infeed_shape, HloInstruction* token_operand, const string& config) { @@ -651,29 +714,33 @@ HloInstruction::CreateCrossReplicaSum( } /* static */ std::unique_ptr HloInstruction::CreateSend( - HloInstruction* operand, int64 channel_id) { - return MakeUnique(operand, channel_id); + HloInstruction* operand, HloInstruction* token, int64 channel_id, + bool is_host_transfer) { + return MakeUnique(operand, token, channel_id, + is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateSendDone( - HloInstruction* operand) { + HloInstruction* operand, bool is_host_transfer) { auto send_operand = DynCast(operand); CHECK(send_operand != nullptr) << "SendDone must take the context operand from Send"; - return MakeUnique(send_operand); + return MakeUnique(send_operand, is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateRecv( - const Shape& shape, int64 channel_id) { - return MakeUnique(shape, channel_id); + const Shape& shape, HloInstruction* token, int64 channel_id, + bool is_host_transfer) { + return MakeUnique(shape, token, channel_id, + is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateRecvDone( - HloInstruction* operand) { + HloInstruction* operand, bool is_host_transfer) { auto recv_operand = DynCast(operand); CHECK(recv_operand != nullptr) << "RecvDone must take the context operand from Recv"; - return MakeUnique(recv_operand); + return MakeUnique(recv_operand, is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateReverse( @@ -684,6 +751,7 @@ HloInstruction::CreateCrossReplicaSum( /* static */ std::unique_ptr HloInstruction::CreateAfterAll( tensorflow::gtl::ArraySlice operands) { + CHECK(!operands.empty()); auto instruction = WrapUnique( new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); for (auto operand : operands) { @@ -692,6 +760,11 @@ HloInstruction::CreateCrossReplicaSum( return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateToken() { + return WrapUnique( + new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); +} + /* static */ std::unique_ptr HloInstruction::CreateWhile( const Shape& shape, HloComputation* condition, HloComputation* body, HloInstruction* init) { @@ -772,11 +845,25 @@ HloInstruction::CreateBitcastConvert(const Shape& shape, } /* static */ std::unique_ptr HloInstruction::CreateReduce( - const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + const Shape& shape, HloInstruction* operand, HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reduce_computation) { + auto instruction = WrapUnique(new HloReduceInstruction( + shape, {operand, init_value}, dimensions_to_reduce, reduce_computation)); + return std::move(instruction); +} + +/* static */ std::unique_ptr HloInstruction::CreateReduce( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::gtl::ArraySlice init_values, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation) { - return MakeUnique( - shape, arg, init_value, dimensions_to_reduce, reduce_computation); + std::vector all_args; + all_args.reserve(operands.size() * 2); + all_args.insert(all_args.end(), operands.begin(), operands.end()); + all_args.insert(all_args.end(), init_values.begin(), init_values.end()); + return MakeUnique(shape, all_args, dimensions_to_reduce, + reduce_computation); } /* static */ std::unique_ptr HloInstruction::CreateReduceWindow( @@ -908,6 +995,12 @@ HloInstruction::CreateBroadcastSequence( return MakeUnique(shape, operand, dimensions); } +/* static */ std::unique_ptr HloInstruction::CreateSort( + const Shape& shape, int64 dimension, HloInstruction* keys, + HloInstruction* values) { + return MakeUnique(shape, dimension, keys, values); +} + /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { return MakeUnique(shape, fusion_kind, fused_root); @@ -952,6 +1045,8 @@ bool HloInstruction::HasSideEffectNoRecurse() const { case HloOpcode::kTrace: case HloOpcode::kHostCompute: return true; + case HloOpcode::kCrossReplicaSum: + return all_reduce_id().has_value(); default: return false; } @@ -1010,34 +1105,18 @@ bool HloInstruction::HasSideEffect() const { const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice window_bounds) { - std::unique_ptr instruction = - WrapUnique(new HloInstruction(HloOpcode::kGather, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(gather_indices); - instruction->gather_dimension_numbers_ = - MakeUnique(gather_dim_numbers); - c_copy(window_bounds, std::back_inserter(instruction->gather_window_bounds_)); - return instruction; + return MakeUnique(shape, operand, gather_indices, + gather_dim_numbers, window_bounds); } -/* static */ GatherDimensionNumbers HloInstruction::MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice output_window_dims, - tensorflow::gtl::ArraySlice elided_window_dims, - tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, - int64 index_vector_dim) { - GatherDimensionNumbers gather_dim_numbers; - for (int64 output_window_dim : output_window_dims) { - gather_dim_numbers.add_output_window_dims(output_window_dim); - } - for (int64 elided_window_dim : elided_window_dims) { - gather_dim_numbers.add_elided_window_dims(elided_window_dim); - } - for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) { - gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); - } - - gather_dim_numbers.set_index_vector_dim(index_vector_dim); - return gather_dim_numbers; +/* static */ std::unique_ptr HloInstruction::CreateScatter( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers) { + return MakeUnique(shape, operand, scatter_indices, + updates, update_computation, + scatter_dim_numbers); } /* static */ std::unique_ptr HloInstruction::CreateDomain( @@ -1091,6 +1170,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kGetTupleElement: case HloOpcode::kReducePrecision: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kConvolution: @@ -1100,6 +1180,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kHostCompute: case HloOpcode::kPad: case HloOpcode::kDynamicSlice: + case HloOpcode::kSort: + case HloOpcode::kGather: + case HloOpcode::kScatter: + case HloOpcode::kIota: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. @@ -1122,7 +1206,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: - case HloOpcode::kSort: case HloOpcode::kTanh: CHECK_EQ(new_operands.size(), 1); clone = CreateUnary(shape, opcode_, new_operands[0]); @@ -1156,6 +1239,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( // Ternary ops. case HloOpcode::kClamp: case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: CHECK_EQ(new_operands.size(), 3); clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1], new_operands[2]); @@ -1201,11 +1285,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( true_computation(), new_operands[2], false_computation()); break; - case HloOpcode::kGather: - CHECK_EQ(new_operands.size(), 2); - clone = CreateGather(shape, new_operands[0], new_operands[1], - *gather_dimension_numbers_, gather_window_bounds_); - break; case HloOpcode::kDomain: CHECK_EQ(new_operands.size(), 1); clone = @@ -1213,7 +1292,11 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( user_side_metadata_->Clone()); break; case HloOpcode::kAfterAll: - clone = CreateAfterAll(new_operands); + if (new_operands.empty()) { + clone = CreateToken(); + } else { + clone = CreateAfterAll(new_operands); + } break; } SetupDerivedInstruction(clone.get()); @@ -1495,11 +1578,10 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kSubtract: case HloOpcode::kTanh: case HloOpcode::kTuple: + case HloOpcode::kTupleSelect: return true; - // These opcodes have complex or special behavior so just return false. - case HloOpcode::kDomain: - case HloOpcode::kWhile: + // This opcode has complex or special behavior so just return false. case HloOpcode::kAfterAll: return false; @@ -1508,11 +1590,6 @@ bool HloInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals(dot_dimension_numbers(), other.dot_dimension_numbers()); - case HloOpcode::kGather: - return protobuf_util::ProtobufEquals(gather_dimension_numbers(), - other.gather_dimension_numbers()) && - gather_window_bounds() == other.gather_window_bounds(); - // Remaining instructions with special values. case HloOpcode::kCall: return eq_computations(to_apply(), other.to_apply()); @@ -1520,9 +1597,17 @@ bool HloInstruction::IdenticalSlowPath( return eq_computations(true_computation(), other.true_computation()) && eq_computations(false_computation(), other.false_computation()); - // These opcodes are not yet supported. - case HloOpcode::kSort: + case HloOpcode::kWhile: { + if (eq_computations(while_body(), other.while_body()) && + eq_computations(while_condition(), other.while_condition())) { + return true; + } return false; + } + + case HloOpcode::kDomain: + return operand_side_metadata().Matches(other.operand_side_metadata()) && + user_side_metadata().Matches(other.user_side_metadata()); // Ops migrated to subclasses should never come to this line. // TODO(b/80131774): Remove this switch when migration is complete. @@ -1537,11 +1622,13 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kReverse: case HloOpcode::kConcatenate: case HloOpcode::kReduce: + case HloOpcode::kSort: case HloOpcode::kTranspose: case HloOpcode::kBroadcast: case HloOpcode::kMap: case HloOpcode::kSlice: case HloOpcode::kConstant: + case HloOpcode::kIota: case HloOpcode::kTrace: case HloOpcode::kFusion: case HloOpcode::kRng: @@ -1551,6 +1638,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: case HloOpcode::kConvolution: case HloOpcode::kCustomCall: case HloOpcode::kReduceWindow: @@ -1558,9 +1646,12 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kHostCompute: case HloOpcode::kPad: case HloOpcode::kDynamicSlice: + case HloOpcode::kGather: + case HloOpcode::kScatter: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } + return false; } void HloInstruction::RemoveUser(HloInstruction* user) { @@ -1610,8 +1701,8 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num, TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(), new_operand->shape())) - << old_operand->shape().ShortDebugString() << " is not compatible with " - << new_operand->shape().ShortDebugString(); + << old_operand->shape() << " is not compatible with " + << new_operand->shape(); operands_[operand_num] = new_operand; VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with " @@ -1663,6 +1754,7 @@ HloComputation* HloInstruction::to_apply() const { case HloOpcode::kReduceWindow: case HloOpcode::kReduce: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kScatter: CHECK_EQ(called_computations_.size(), 1); return called_computations_[0]; default: @@ -1681,6 +1773,7 @@ void HloInstruction::set_to_apply(HloComputation* computation) { case HloOpcode::kReduceWindow: case HloOpcode::kReduce: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kScatter: CHECK_EQ(called_computations_.size(), 1); called_computations_[0] = computation; break; @@ -1820,7 +1913,6 @@ bool HloInstruction::IsElementwiseImpl( // Ternary elementwise operations. case HloOpcode::kSelect: - return !ShapeUtil::IsTuple(shape_); case HloOpcode::kClamp: return true; @@ -1832,6 +1924,10 @@ bool HloInstruction::IsElementwiseImpl( } } +bool HloInstruction::IsCrossModuleAllReduce() const { + return opcode() == HloOpcode::kCrossReplicaSum && all_reduce_id(); +} + string HloInstruction::ToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { @@ -1924,11 +2020,6 @@ std::vector HloInstruction::ExtraAttributesToString( if (dot_dimension_numbers_ != nullptr) { extra.push_back(DotDimensionNumbersToString()); } - if (gather_dimension_numbers_ != nullptr) { - extra.push_back(GatherDimensionNumbersToString()); - extra.push_back( - StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")); - } if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kNameOnly) { @@ -1949,7 +2040,8 @@ std::vector HloInstruction::ExtraAttributesToString( } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || opcode() == HloOpcode::kReduceWindow || opcode() == HloOpcode::kReduce || - opcode() == HloOpcode::kCrossReplicaSum) { + opcode() == HloOpcode::kCrossReplicaSum || + opcode() == HloOpcode::kScatter) { extra.push_back( StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if (!called_computations().empty()) { @@ -1985,6 +2077,7 @@ std::vector HloInstruction::ExtraAttributesToString( case HloOpcode::kReduceWindow: case HloOpcode::kReduce: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kScatter: extra.push_back( StrCat("to_apply=\n", to_apply()->ToString(new_options))); break; @@ -2015,8 +2108,8 @@ std::vector HloInstruction::ExtraAttributesToString( } if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(), - "\", entry=", operand_side_metadata_->ToString(), - ", exit=", user_side_metadata_->ToString(), "}")); + "\", entry=", user_side_metadata_->ToString(), + ", exit=", operand_side_metadata_->ToString(), "}")); } return extra; @@ -2058,14 +2151,6 @@ HloInstructionProto HloInstruction::ToProto() const { if (dot_dimension_numbers_ != nullptr) { *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; } - if (gather_dimension_numbers_ != nullptr) { - *proto.mutable_gather_dimension_numbers() = *gather_dimension_numbers_; - } - if (opcode() == HloOpcode::kGather) { - for (int64 bound : gather_window_bounds()) { - proto.add_gather_window_bounds(bound); - } - } if (has_sharding()) { *proto.mutable_sharding() = sharding().ToProto(); @@ -2191,12 +2276,16 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleRemainder(this); case HloOpcode::kSelect: return visitor->HandleSelect(this); + case HloOpcode::kTupleSelect: + return visitor->HandleTupleSelect(this); case HloOpcode::kConvolution: return visitor->HandleConvolution(this); case HloOpcode::kFft: return visitor->HandleFft(this); case HloOpcode::kCrossReplicaSum: return visitor->HandleCrossReplicaSum(this); + case HloOpcode::kAllToAll: + return visitor->HandleAllToAll(this); case HloOpcode::kTuple: return visitor->HandleTuple(this); case HloOpcode::kMap: @@ -2289,10 +2378,14 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleSendDone(this); case HloOpcode::kGather: return visitor->HandleGather(this); + case HloOpcode::kScatter: + return visitor->HandleScatter(this); case HloOpcode::kDomain: return visitor->HandleDomain(this); case HloOpcode::kAfterAll: return visitor->HandleAfterAll(this); + case HloOpcode::kIota: + return visitor->HandleIota(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -2824,26 +2917,6 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } -string HloInstruction::GatherDimensionNumbersToString() const { - CHECK_NE(gather_dimension_numbers_.get(), nullptr); - string output_window_dims = - StrCat("output_window_dims={", - Join(gather_dimension_numbers_->output_window_dims(), ","), "}"); - string elided_window_dims = - StrCat("elided_window_dims={", - Join(gather_dimension_numbers_->elided_window_dims(), ","), "}"); - string gather_dims_to_operand_dims = StrCat( - "gather_dims_to_operand_dims={", - Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); - string index_vector_dim = StrCat( - "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); - - return Join>( - {output_window_dims, elided_window_dims, gather_dims_to_operand_dims, - index_vector_dim}, - ", "); -} - bool HloInstruction::CouldBeBitcast() const { switch (opcode_) { case HloOpcode::kTranspose: @@ -3087,12 +3160,23 @@ const std::vector& HloInstruction::replica_group_ids() const { return Cast(this)->replica_group_ids(); } +const std::vector& HloInstruction::replica_groups() const { + return Cast(this)->replica_groups(); +} + string HloInstruction::cross_replica_sum_barrier() const { - return Cast(this)->cross_replica_sum_barrier(); + if (opcode() == HloOpcode::kCrossReplicaSum) { + return Cast(this)->cross_replica_sum_barrier(); + } + return Cast(this)->cross_replica_sum_barrier(); } void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) { - return Cast(this)->set_cross_replica_sum_barrier( + if (opcode() == HloOpcode::kCrossReplicaSum) { + return Cast(this)->set_cross_replica_sum_barrier( + barrier); + } + return Cast(this)->set_cross_replica_sum_barrier( barrier); } @@ -3157,4 +3241,19 @@ int64 HloInstruction::slice_sizes(int64 dimension) const { const std::vector& HloInstruction::dynamic_slice_sizes() const { return Cast(this)->dynamic_slice_sizes(); } + +const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const { + return Cast(this)->gather_dimension_numbers(); +} + +tensorflow::gtl::ArraySlice HloInstruction::gather_window_bounds() + const { + return Cast(this)->gather_window_bounds(); +} + +const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers() + const { + return Cast(this)->scatter_dimension_numbers(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 59a383218c3ce462b66596208b1fcd6233376b39..3c575ae6ea8e60f48def4debcd9cfbea63e396b2 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -33,7 +33,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/iterator_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -346,6 +346,9 @@ class HloInstruction { static std::unique_ptr CreateConstant( std::unique_ptr literal); + // Creates an Iota instruction. + static std::unique_ptr CreateIota(const Shape& shape); + // Creates a get tuple element instruction. static std::unique_ptr CreateGetTupleElement( const Shape& shape, HloInstruction* operand, int64 index); @@ -444,8 +447,27 @@ class HloInstruction { HloComputation* reduce_computation, tensorflow::gtl::ArraySlice replica_group_ids, tensorflow::StringPiece barrier, - const tensorflow::gtl::optional& all_reduce_id = - tensorflow::gtl::nullopt); + const tensorflow::gtl::optional& all_reduce_id); + + // This op handles the communication of an Alltoall operation. On each core, + // the operands are N ops in the same shape, where N is the number of cores + // participating the Alltoall. Then the N operands are scattered to N cores, + // e.g., the ith operand is sent to the ith core. Then each core gathers the + // received data into a tuple. + // + // - `replica_groups`: each ReplicaGroup contains a list of replica id. If + // empty, all replicas belong to one group in the order of 0 - (n-1). Alltoall + // will be applied within subgroups in the specified order. For example, + // replica groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied + // within replica 1, 2, 3, and in the gather phase, the received blocks will + // be concatenated in the order of 1, 2, 3; another Alltoall will be applied + // within replica 4, 5, 0, and the concatenation order is 4, 5, 0. + // + // TODO(b/110096724): This is NOT YET ready to use. + static std::unique_ptr CreateAllToAll( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + const std::vector& replica_groups, + tensorflow::StringPiece barrier); // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. @@ -477,7 +499,7 @@ class HloInstruction { const Shape& outfeed_shape, HloInstruction* operand, HloInstruction* token_operand, tensorflow::StringPiece outfeed_config); // Overload which does not require a token. - // TODO(b/80000000): Remove this overload when all uses of infeed are + // TODO(b/80000000): Remove this overload when all uses of outfeed are // converted to take tokens. static std::unique_ptr CreateOutfeed( const Shape& outfeed_shape, HloInstruction* operand, @@ -485,25 +507,30 @@ class HloInstruction { // Creates an asynchronous send instruction with the given channel id, which // initiates sending the operand data to a unique receive instruction in - // another computation that has the same channel id. - static std::unique_ptr CreateSend(HloInstruction* operand, - int64 channel_id); + // another computation that has the same channel id. If is_host_transfer is + // true, then this Send operation transfers data to the host. + static std::unique_ptr CreateSend( + HloInstruction* operand, HloInstruction* token, int64 channel_id, + bool is_host_transfer = false); // Blocks until data transfer for the Send instruction (operand) is complete. // The operand must be kSend. static std::unique_ptr CreateSendDone( - HloInstruction* operand); + HloInstruction* operand, bool is_host_transfer = false); // Creates an asynchronous receive instruction with the given channel id, // which allocates resources to receive data of the given shape from a unique - // send instruction in another computation that has the same channel id. - static std::unique_ptr CreateRecv(const Shape& shape, - int64 channel_id); + // send instruction in another computation that has the same channel id. If + // is_host_transfer is true, then this Send operation transfers data from the + // host. + static std::unique_ptr CreateRecv( + const Shape& shape, HloInstruction* token, int64 channel_id, + bool is_host_transfer = false); // Blocks until data transfer for the Recv instruction (operand) is complete // and returns the receive buffer. The operand must be kRecv. static std::unique_ptr CreateRecvDone( - HloInstruction* operand); + HloInstruction* operand, bool is_host_transfer = false); // Creates a slice instruction, where the operand is sliced by the given // start/limit indices. @@ -534,17 +561,34 @@ class HloInstruction { int64 dimension); // Creates a reduce instruction, where the computation (given by the handle) - // is applied successively to every element in operand. That is, if f is the - // function to apply (which either takes 2 [accumulator, value] or 3 - // [accumulator, index, value] arguments) and init is a reduction operator - // specified initial value (for example, 0 for addition), then this operation - // will compute: - // f(f(init, [index0], value0), [index1], value1), ...) + // is applied successively to every element in operand. For example, let f be + // the function to apply, which takes 2 arguments, an accumulator and the + // current value. Let init be an initial value (which is normally chosen to be + // the identity element for f, e.g. 0 if f is addition). + // Then the reduce HLO will compute: + // f(f(init, value0), value1), ...) static std::unique_ptr CreateReduce( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation); + // A more general, multiple-argument version of the above. + // The function to apply, f, now takes N arguments: + // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ..., + // init_valueN], and returns an N-tuple. The performed computation is (for + // commutative and associative f operators) equivalent to: + // + // f_1 = f(init0, ... initN, input0.value0, ..., inputN.value0) + // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1, + // ..., inputN.value1) + // ... + // TODO(b/112040122): Add support to this in HLO passes and in backends. + static std::unique_ptr CreateReduce( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::gtl::ArraySlice init_values, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reduce_computation); + // Creates a reduce-window instruction, where the computation (given // by the handle) is applied window-wise at each valid window // position in the operand. @@ -611,6 +655,11 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); + // Creates a sort op, with a keys operand, and an optional values operand. + static std::unique_ptr CreateSort( + const Shape& shape, int64 dimension, HloInstruction* keys, + HloInstruction* values = nullptr); + // Creates a while instruction, given a condition computation, a body // computation, and the initial value for the input of the computations. For // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1 @@ -632,6 +681,12 @@ class HloInstruction { const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice window_bounds); + static std::unique_ptr CreateScatter( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers); + // Creates a kDomain instruction which delimits an HLO domain which have // the provided user and operand side metadata. static std::unique_ptr CreateDomain( @@ -680,17 +735,18 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); - // Creates a token instruction used for joining or creating new values of - // token type which thread through side-effecting operations. + // Creates a Afterall instruction used for joining or creating new values of + // token type which thread through side-effecting operations. Operands must + // all be tokens, and there must be at least one operand. static std::unique_ptr CreateAfterAll( tensorflow::gtl::ArraySlice operands); - // Creates an instance of GatherDimensionNumbers. - static GatherDimensionNumbers MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice output_window_dims, - tensorflow::gtl::ArraySlice elided_window_dims, - tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, - int64 index_vector_dim); + // Creates an AfterAll instruction which creates a token type out of thin air + // (no operands). This is a separate method from CreateAfterAll to facility + // the removal of operand-less AfterAll instructions. + // TODO(b/110532604): Remove this capability of creating a token from nothing + // when we plumb a primordial token from the entry computation. + static std::unique_ptr CreateToken(); // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } @@ -1001,9 +1057,7 @@ class HloInstruction { if (sharding_ == nullptr) { return tensorflow::gtl::optional(); } - auto device = sharding_->UniqueDevice(); - return device.ok() ? device.ValueOrDie() - : tensorflow::gtl::optional(); + return sharding_->UniqueDevice(); } // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. @@ -1066,19 +1120,6 @@ class HloInstruction { // Returns the dump string of the dot dimension numbers. string DotDimensionNumbersToString() const; - const GatherDimensionNumbers& gather_dimension_numbers() const { - CHECK(gather_dimension_numbers_ != nullptr); - return *gather_dimension_numbers_; - } - - tensorflow::gtl::ArraySlice gather_window_bounds() const { - CHECK_EQ(opcode(), HloOpcode::kGather); - return gather_window_bounds_; - } - - // Returns the dump string of the gather dimension numbers. - string GatherDimensionNumbersToString() const; - // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of @@ -1133,6 +1174,9 @@ class HloInstruction { // Returns true if this instruction is elementwise on all its operands. bool IsElementwise() const; + // Returns true if this is an cross module all-reduce instrucion. + bool IsCrossModuleAllReduce() const; + // Returns true if this elementwise instruction implicitly broadcasts operand // `operand_idx`. // @@ -1390,6 +1434,9 @@ class HloInstruction { // Delegates to HloAllReduceInstruction::replica_group_ids. const std::vector& replica_group_ids() const; + // Delegates to HloAllToAllInstruction::replica_groups. + const std::vector& replica_groups() const; + // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier. string cross_replica_sum_barrier() const; void set_cross_replica_sum_barrier(const string& barrier); @@ -1445,6 +1492,15 @@ class HloInstruction { // Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes. const std::vector& dynamic_slice_sizes() const; + + // Delegates to HloGatherInstruction::gather_dimension_numbers. + const GatherDimensionNumbers& gather_dimension_numbers() const; + // Delegates to HloGatherInstruction::gather_window_bounds. + tensorflow::gtl::ArraySlice gather_window_bounds() const; + + // Delegates to HloScatterInstruction::scatter_dimension_numbers(). + const ScatterDimensionNumbers& scatter_dimension_numbers() const; + // Old methods kept for smooth subclassing transition END. protected: @@ -1588,9 +1644,6 @@ class HloInstruction { // Describes the dimension numbers used for a dot. std::unique_ptr dot_dimension_numbers_; - std::unique_ptr gather_dimension_numbers_; - std::vector gather_window_bounds_; - // Used to tag kCopy instructions that are eligible for copy elision. bool copy_elision_allowed_ = true; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index d8ca99dfd12ef95ab5e1ea61093d8bf3ea97a5e2..8a694dde8066ab9a1138b9f7981153d451ddb89e 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -20,10 +20,11 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -249,7 +250,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperands) { auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r0f32_, "param1")); auto c0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto addleft = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param0, c0)); auto addright = builder.AddInstruction( @@ -294,7 +295,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) { auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r0f32_, "param1")); auto c0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto neg1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, c0)); auto addleft = builder.AddInstruction( @@ -334,7 +335,7 @@ TEST_F(HloInstructionTest, TrivialMap) { auto param = embedded_builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "x")); auto value = embedded_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); embedded_builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, value)); auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build()); @@ -383,9 +384,9 @@ TEST_F(HloInstructionTest, TrivialReduce) { auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, f32a100x10, "p")); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto reduce = builder.AddInstruction( HloInstruction::CreateReduce(f32v100, param0, const0, /*dimensions_to_reduce=*/{1}, add_f32)); @@ -626,7 +627,7 @@ TEST_F(HloInstructionTest, SingletonFusionOp) { HloComputation::Builder builder(TestName()); // Create a fusion instruction containing a single unary operation. auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto exp = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); auto module = CreateNewModule(); @@ -642,9 +643,9 @@ TEST_F(HloInstructionTest, BinaryFusionOp) { HloComputation::Builder builder(TestName()); // Create a fusion instruction containing a single binary operation. auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.1f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); auto module = CreateNewModule(); @@ -661,7 +662,7 @@ TEST_F(HloInstructionTest, ChainFusionOp) { HloComputation::Builder builder(TestName()); // Create a chain of fused unary ops. auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto exp1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); auto exp2 = builder.AddInstruction( @@ -682,7 +683,7 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { HloComputation::Builder builder(TestName()); // Create a chain of fused unary ops. auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto exp1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); auto exp2 = builder.AddInstruction( @@ -710,13 +711,13 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) { HloComputation::Builder builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({ + HloInstruction::CreateConstant(LiteralUtil::CreateR2({ {1, 2}, {3, 4}, }))); auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}); - auto token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); auto outfeed10 = builder.AddInstruction( HloInstruction::CreateOutfeed(shape10, constant, token, "")); auto outfeed01 = builder.AddInstruction( @@ -732,7 +733,7 @@ TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) { TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) { HloComputation::Builder builder(TestName()); auto* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({ + HloInstruction::CreateConstant(LiteralUtil::CreateR2({ {1, 2}, {3, 4}, }))); @@ -763,7 +764,7 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { HloComputation::Builder builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto map_1_x = builder.AddInstruction( HloInstruction::CreateMap(scalar_shape, {constant}, computation_x)); auto map_2_x = builder.AddInstruction( @@ -798,11 +799,11 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { // Notable complexities are repeated operands in the same instruction, // different shapes, use of value in different expressions. auto c1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto c2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.1f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.1f))); auto c3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(9.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(9.0f))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c2)); @@ -873,11 +874,11 @@ TEST_F(HloInstructionTest, IdenticalInstructions) { // Create a set of random constant operands to use below. Make them matrices // so dimensions are interesting. auto operand1 = HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); auto operand2 = HloInstruction::CreateConstant( - Literal::CreateR2({{10.0, 20.0}, {30.0, 40.0}})); - auto vector_operand = - HloInstruction::CreateConstant(Literal::CreateR1({42.0, 123.0})); + LiteralUtil::CreateR2({{10.0, 20.0}, {30.0, 40.0}})); + auto vector_operand = HloInstruction::CreateConstant( + LiteralUtil::CreateR1({42.0, 123.0})); Shape shape = operand1->shape(); // Convenient short names for the operands. @@ -1234,9 +1235,9 @@ TEST_F(HloInstructionTest, NestedFusionEquality) { // Build a nested fusion computation. Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto a = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + LiteralUtil::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); auto b = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); auto b_t = builder.AddInstruction( HloInstruction::CreateTranspose(data_shape, b, {1, 0})); DotDimensionNumbers dot_dnums; @@ -1245,7 +1246,7 @@ TEST_F(HloInstructionTest, NestedFusionEquality) { auto dot = builder.AddInstruction( HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto add_operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); auto add = builder.AddInstruction(HloInstruction::CreateBinary( @@ -1342,7 +1343,7 @@ TEST_F(HloInstructionTest, Stringification) { "condition=%TransposeDot, body=%TransposeDot"); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); HloInstruction* conditional = builder.AddInstruction(HloInstruction::CreateConditional( sout, pred, x, computation, x, computation)); @@ -1369,7 +1370,7 @@ TEST_F(HloInstructionTest, StringifyGather_0) { HloInstruction* gather_instruction = builder.AddInstruction(HloInstruction::CreateGather( gather_result_shape, input, gather_indices, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1405,7 +1406,7 @@ TEST_F(HloInstructionTest, StringifyGather_1) { HloInstruction* gather_instruction = builder.AddInstruction(HloInstruction::CreateGather( gather_result_shape, input, gather_indices, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1424,6 +1425,55 @@ TEST_F(HloInstructionTest, StringifyGather_1) { "index_vector_dim=2, window_bounds={30,29,28,27,26}"); } +TEST_F(HloInstructionTest, StringifyScatter) { + Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); + Shape scatter_indices_tensor_shape = + ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6}); + Shape scatter_updates_shape = + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}); + + HloComputation::Builder builder("Scatter"); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); + HloInstruction* scatter_indices = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, scatter_indices_tensor_shape, "scatter_indices")); + HloInstruction* scatter_updates = + builder.AddInstruction(HloInstruction::CreateParameter( + 2, scatter_updates_shape, "scatter_updates")); + + HloComputation::Builder update_builder("Scatter.update"); + update_builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p1")); + update_builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p2")); + + auto module = CreateNewModule(); + auto* update_computation = + module->AddEmbeddedComputation(update_builder.Build()); + + HloInstruction* scatter_instruction = + builder.AddInstruction(HloInstruction::CreateScatter( + input_tensor_shape, input, scatter_indices, scatter_updates, + update_computation, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 8}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2))); + module->AddEntryComputation(builder.Build()); + + EXPECT_EQ( + scatter_instruction->ToString(), + "%scatter = f32[50,49,48,47,46]{4,3,2,1,0} " + "scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " + "s64[10,9,5,7,6]{4,3,2,1,0} %scatter_indices, " + "f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %scatter_updates), " + "update_window_dims={4,5,6,7,8}, inserted_window_dims={}, " + "scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=2, " + "to_apply=%Scatter.update"); +} + TEST_F(HloInstructionTest, CanonnicalStringificationFusion) { // Tests stringification of a simple op, fusion, while, and conditional. const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); @@ -1455,15 +1505,15 @@ TEST_F(HloInstructionTest, CanonnicalStringificationFusion) { HloInstruction* fusion = computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kLoop); - EXPECT_EQ( - fusion->ToString(options), + const string expected_fusion = R"(f32[5,20]{1,0} fusion(f32[5,10]{1,0}, f32[20,10]{1,0}), kind=kLoop, calls= { tmp_0 = f32[5,10]{1,0} parameter(0) tmp_1 = f32[20,10]{1,0} parameter(1) tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"); +})"; + EXPECT_EQ(fusion->ToString(options), expected_fusion); } TEST_F(HloInstructionTest, CanonnicalStringificationWhile) { @@ -1495,8 +1545,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationWhile) { HloInstruction::CreateWhile(sout, computation, computation, x)); auto options = HloPrintOptions().Canonical(); - EXPECT_EQ(loop->ToString(options), - R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition= + const string expected_loop = + R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition= { tmp_0 = f32[5,10]{1,0} parameter(0) tmp_1 = f32[20,10]{1,0} parameter(1) @@ -1518,7 +1568,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationWhile) { tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} } -})"); +})"; + EXPECT_EQ(loop->ToString(options), expected_loop); } TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { @@ -1550,13 +1601,12 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { HloInstruction::CreateWhile(sout, computation, computation, x)); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); HloInstruction* conditional = builder.AddInstruction(HloInstruction::CreateConditional( sout, pred, x, computation, x, computation)); auto options = HloPrintOptions().Canonical(); - EXPECT_EQ( - conditional->ToString(options), + const string expected_conditional = R"(f32[5,20]{1,0} conditional(pred[], f32[5,10]{1,0}, f32[5,10]{1,0}), true_computation= { tmp_0 = f32[5,10]{1,0} parameter(0) @@ -1579,7 +1629,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} } -})"); +})"; + EXPECT_EQ(conditional->ToString(options), expected_conditional); } TEST_F(HloInstructionTest, CheckDeepClone) { diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index e2f43f581091af49a4bdb96c8c42eb52035ce6fd..1de5032670ff47cda5599cf736bbd3529cfcaba9 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -180,8 +181,11 @@ std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, - int64 channel_id) - : HloInstruction(opcode, shape), channel_id_(channel_id) {} + int64 channel_id, + bool is_host_transfer) + : HloInstruction(opcode, shape), + channel_id_(channel_id), + is_host_transfer_(is_host_transfer) {} HloInstructionProto HloSendRecvInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); @@ -191,7 +195,12 @@ HloInstructionProto HloSendRecvInstruction::ToProto() const { std::vector HloSendRecvInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("channel_id=", channel_id_)}; + std::vector attrs; + attrs.push_back(StrCat("channel_id=", channel_id_)); + if (is_host_transfer()) { + attrs.push_back("is_host_transfer=true"); + } + return attrs; } bool HloSendRecvInstruction::IdenticalSlowPath( @@ -204,26 +213,32 @@ bool HloSendRecvInstruction::IdenticalSlowPath( // Send instruction produces a tuple of {aliased operand, U32 context}. HloSendInstruction::HloSendInstruction(HloInstruction* operand, - int64 channel_id) + HloInstruction* token, int64 channel_id, + bool is_host_transfer) : HloSendRecvInstruction( HloOpcode::kSend, - ShapeUtil::MakeTupleShape( - {CHECK_NOTNULL(operand)->shape(), ShapeUtil::MakeShape(U32, {})}), - channel_id) { + ShapeUtil::MakeTupleShape({CHECK_NOTNULL(operand)->shape(), + ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeTokenShape()}), + channel_id, is_host_transfer) { AppendOperand(operand); + AppendOperand(token); } std::unique_ptr HloSendInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - CHECK_EQ(new_operands.size(), 1); - return MakeUnique(new_operands[0], channel_id()); + CHECK_EQ(new_operands.size(), 2); + return MakeUnique(new_operands[0], new_operands[1], + channel_id(), is_host_transfer()); } -HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand) - : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil(), - CHECK_NOTNULL(operand)->channel_id()) { +HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand, + bool is_host_transfer) + : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(), + CHECK_NOTNULL(operand)->channel_id(), + is_host_transfer) { AppendOperand(operand); } @@ -234,30 +249,39 @@ HloSendDoneInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return MakeUnique( - Cast(new_operands[0])); + Cast(new_operands[0]), is_host_transfer()); } // Recv instruction produces a tuple of {receive buffer, U32 context}. -HloRecvInstruction::HloRecvInstruction(const Shape& shape, int64 channel_id) +HloRecvInstruction::HloRecvInstruction(const Shape& shape, + HloInstruction* token, int64 channel_id, + bool is_host_transfer) : HloSendRecvInstruction( HloOpcode::kRecv, - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}), - channel_id) {} + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeTokenShape()}), + channel_id, is_host_transfer) { + AppendOperand(token); +} std::unique_ptr HloRecvInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { - CHECK_EQ(new_operands.size(), 0); + CHECK_EQ(new_operands.size(), 1); return MakeUnique( - ShapeUtil::GetTupleElementShape(shape, 0), channel_id()); + ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(), + is_host_transfer()); } -HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand) +HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand, + bool is_host_transfer) : HloSendRecvInstruction( HloOpcode::kRecvDone, - ShapeUtil::GetTupleElementShape(operand->shape(), 0), - CHECK_NOTNULL(operand)->channel_id()) { + ShapeUtil::MakeTupleShape( + {ShapeUtil::GetTupleElementShape(operand->shape(), 0), + ShapeUtil::MakeTokenShape()}), + CHECK_NOTNULL(operand)->channel_id(), is_host_transfer) { AppendOperand(operand); } @@ -268,7 +292,7 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); return MakeUnique( - Cast(new_operands[0])); + Cast(new_operands[0]), is_host_transfer()); } HloAllReduceInstruction::HloAllReduceInstruction( @@ -281,8 +305,6 @@ HloAllReduceInstruction::HloAllReduceInstruction( replica_group_ids_(replica_group_ids.begin(), replica_group_ids.end()), cross_replica_sum_barrier_(barrier.begin(), barrier.end()), all_reduce_id_(all_reduce_id) { - // TODO(b/79737069): Remove the CHECK when supported. - CHECK(!all_reduce_id_); for (auto operand : operands) { AppendOperand(operand); } @@ -337,6 +359,67 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl( cross_replica_sum_barrier(), all_reduce_id()); } +HloAllToAllInstruction::HloAllToAllInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + const std::vector& replica_groups, + tensorflow::StringPiece barrier) + : HloInstruction(HloOpcode::kAllToAll, shape), + replica_groups_(replica_groups), + cross_replica_sum_barrier_(barrier.begin(), barrier.end()) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +bool HloAllToAllInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return ContainersEqual(replica_groups(), casted_other.replica_groups(), + [](const ReplicaGroup& a, const ReplicaGroup& b) { + return ContainersEqual(a.replica_ids(), + b.replica_ids()); + }) && + cross_replica_sum_barrier() == + casted_other.cross_replica_sum_barrier(); +} + +std::unique_ptr +HloAllToAllInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* /*context*/) const { + return MakeUnique( + shape, new_operands, replica_groups(), cross_replica_sum_barrier()); +} + +std::vector HloAllToAllInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector result; + std::vector replica_group_str; + for (const ReplicaGroup& group : replica_groups()) { + replica_group_str.push_back( + StrCat("{", Join(group.replica_ids(), ","), "}")); + } + result.push_back( + StrCat("replica_groups={", Join(replica_group_str, ","), "}")); + + if (!cross_replica_sum_barrier().empty()) { + result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); + } + + return result; +} + +HloInstructionProto HloAllToAllInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_replica_groups() = {replica_groups_.begin(), + replica_groups_.end()}; + proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); + return proto; +} + HloReverseInstruction::HloReverseInstruction( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions) @@ -416,13 +499,14 @@ HloConcatenateInstruction::CloneWithNewOperandsImpl( } HloReduceInstruction::HloReduceInstruction( - const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + const Shape& shape, tensorflow::gtl::ArraySlice args, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation) : HloInstruction(HloOpcode::kReduce, shape), dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) { - AppendOperand(arg); - AppendOperand(init_value); + for (HloInstruction* arg : args) { + AppendOperand(arg); + } AppendComputation(reduce_computation); } @@ -455,8 +539,48 @@ std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique( - shape, new_operands[0], new_operands[1], dimensions(), to_apply()); + return MakeUnique(shape, new_operands, dimensions(), + to_apply()); +} + +HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension, + HloInstruction* keys, + HloInstruction* values) + : HloInstruction(HloOpcode::kSort, shape), dimensions_({dimension}) { + AppendOperand(keys); + if (values) { + AppendOperand(values); + } +} + +HloInstructionProto HloSortInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloSortInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloSortInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr HloSortInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + HloInstruction* keys = new_operands[0]; + HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr; + return MakeUnique(shape, dimensions(0), keys, values); } HloTransposeInstruction::HloTransposeInstruction( @@ -757,7 +881,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( HloTraceInstruction::HloTraceInstruction(const string& tag, HloInstruction* operand) : HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()), - literal_(Literal::CreateR1U8(tag)) { + literal_(LiteralUtil::CreateR1U8(tag)) { AppendOperand(operand); operand->set_tracing(this); } @@ -1043,8 +1167,6 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build())); clone = fused_expression_root(); } else { - clone = fused_instructions_computation()->AddInstruction( - instruction_to_fuse->Clone(/*suffix=*/"")); // When add_output is false, instruction_to_fuse is necessarily an operand // of the fusion instruction. After fusion this will no longer be the // case. Remove the operand from the operand list and remove its @@ -1054,6 +1176,16 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( bool in_operand_list = std::find(operands().begin(), operands().end(), instruction_to_fuse) != operands().end(); CHECK(add_output || in_operand_list); + if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { + // We assume all uses of a kTuple operation are GTE ops, not another + // fusion node. In this case, we don't need to clone + // 'instruction_to_fuse'. + CHECK(!in_operand_list); + clone = instruction_to_fuse; + } else { + clone = fused_instructions_computation()->AddInstruction( + instruction_to_fuse->Clone(/*suffix=*/"")); + } const std::vector& fused_parameters = fused_instructions_computation()->parameter_instructions(); for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { @@ -1150,9 +1282,10 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( } int64 index = tuple_elements.size(); if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { - index -= instruction_to_fuse->operand_count(); + CHECK_EQ(clone, instruction_to_fuse); + index -= clone->operand_count(); std::vector to_be_removed; - for (auto old_gte : instruction_to_fuse->users()) { + for (auto old_gte : clone->users()) { CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement); int64 old_tuple_index = old_gte->tuple_index(); HloInstruction* new_gte = @@ -1164,7 +1297,6 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( for (auto old_gte : to_be_removed) { TF_CHECK_OK(parent()->RemoveInstruction(old_gte)); } - TF_CHECK_OK(fused_instructions_computation()->RemoveInstruction(clone)); } else { HloInstruction* new_gte = parent()->AddInstruction(HloInstruction::CreateGetTupleElement( @@ -1173,7 +1305,9 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal( } } - VLOG(2) << "New clone:\n" << clone->ToString(); + if (clone != instruction_to_fuse) { + VLOG(2) << "New clone:\n" << clone->ToString(); + } return clone; } @@ -1854,4 +1988,180 @@ HloDynamicSliceInstruction::CloneWithNewOperandsImpl( return MakeUnique( shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); } + +HloGatherInstruction::HloGatherInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice window_bounds) + : HloInstruction(HloOpcode::kGather, shape) { + AppendOperand(operand); + AppendOperand(gather_indices); + gather_dimension_numbers_ = + MakeUnique(gather_dim_numbers); + c_copy(window_bounds, std::back_inserter(gather_window_bounds_)); +} + +string HloGatherInstruction::GatherDimensionNumbersToString() const { + CHECK(gather_dimension_numbers_ != nullptr); + string output_window_dims = + StrCat("output_window_dims={", + Join(gather_dimension_numbers_->output_window_dims(), ","), "}"); + string elided_window_dims = + StrCat("elided_window_dims={", + Join(gather_dimension_numbers_->elided_window_dims(), ","), "}"); + string gather_dims_to_operand_dims = StrCat( + "gather_dims_to_operand_dims={", + Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); + string index_vector_dim = StrCat( + "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); + + return Join>( + {output_window_dims, elided_window_dims, gather_dims_to_operand_dims, + index_vector_dim}, + ", "); +} + +/* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers( + tensorflow::gtl::ArraySlice output_window_dims, + tensorflow::gtl::ArraySlice elided_window_dims, + tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, + int64 index_vector_dim) { + GatherDimensionNumbers gather_dim_numbers; + for (int64 output_window_dim : output_window_dims) { + gather_dim_numbers.add_output_window_dims(output_window_dim); + } + for (int64 elided_window_dim : elided_window_dims) { + gather_dim_numbers.add_elided_window_dims(elided_window_dim); + } + for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) { + gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); + } + + gather_dim_numbers.set_index_vector_dim(index_vector_dim); + return gather_dim_numbers; +} + +HloInstructionProto HloGatherInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_gather_dimension_numbers() = gather_dimension_numbers(); + for (int64 bound : gather_window_bounds()) { + proto.add_gather_window_bounds(bound); + } + return proto; +} + +std::vector HloGatherInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {GatherDimensionNumbersToString(), + StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")}; +} + +bool HloGatherInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return protobuf_util::ProtobufEquals( + gather_dimension_numbers(), + casted_other.gather_dimension_numbers()) && + gather_window_bounds() == casted_other.gather_window_bounds(); +} + +std::unique_ptr HloGatherInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique( + shape, new_operands[0], new_operands[1], gather_dimension_numbers(), + gather_window_bounds()); +} + +HloScatterInstruction::HloScatterInstruction( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers) + : HloInstruction(HloOpcode::kScatter, shape) { + AppendOperand(operand); + AppendOperand(scatter_indices); + AppendOperand(updates); + AppendComputation(update_computation); + scatter_dimension_numbers_ = + MakeUnique(scatter_dim_numbers); +} + +string HloScatterInstruction::ScatterDimensionNumbersToString() const { + string update_window_dims = + StrCat("update_window_dims={", + Join(scatter_dimension_numbers().update_window_dims(), ","), "}"); + string inserted_window_dims = StrCat( + "inserted_window_dims={", + Join(scatter_dimension_numbers().inserted_window_dims(), ","), "}"); + string scatter_dims_to_operand_dims = StrCat( + "scatter_dims_to_operand_dims={", + Join(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","), + "}"); + string index_vector_dim = StrCat( + "index_vector_dim=", scatter_dimension_numbers().index_vector_dim()); + + return Join>( + {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, + index_vector_dim}, + ", "); +} + +/* static */ ScatterDimensionNumbers +HloScatterInstruction::MakeScatterDimNumbers( + tensorflow::gtl::ArraySlice update_window_dims, + tensorflow::gtl::ArraySlice inserted_window_dims, + tensorflow::gtl::ArraySlice scatter_dims_to_operand_dims, + int64 index_vector_dim) { + ScatterDimensionNumbers scatter_dim_numbers; + for (int64 update_window_dim : update_window_dims) { + scatter_dim_numbers.add_update_window_dims(update_window_dim); + } + for (int64 inserted_window_dim : inserted_window_dims) { + scatter_dim_numbers.add_inserted_window_dims(inserted_window_dim); + } + for (int64 scatter_dim_to_operand_dim : scatter_dims_to_operand_dims) { + scatter_dim_numbers.add_scatter_dims_to_operand_dims( + scatter_dim_to_operand_dim); + } + scatter_dim_numbers.set_index_vector_dim(index_vector_dim); + return scatter_dim_numbers; +} + +HloInstructionProto HloScatterInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers(); + return proto; +} + +std::vector HloScatterInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {ScatterDimensionNumbersToString()}; +} + +bool HloScatterInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return protobuf_util::ProtobufEquals( + scatter_dimension_numbers(), + casted_other.scatter_dimension_numbers()) && + eq_computations(to_apply(), casted_other.to_apply()); +} + +std::unique_ptr HloScatterInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 3); + return MakeUnique( + shape, new_operands[0], new_operands[1], new_operands[2], to_apply(), + scatter_dimension_numbers()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index ec8a42bd3b965f3aad373afd25e76506b2ff3964..9586ad667345111d05015e035c93fe6578e3b665 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -141,12 +141,15 @@ class HloSendRecvInstruction : public HloInstruction { // channel. int64 channel_id() const { return channel_id_; } + // Returns whether this send/recv instruction sends data to/from the host. + bool is_host_transfer() const { return is_host_transfer_; } + // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; protected: explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, - int64 channel_id); + int64 channel_id, bool is_host_transfer); private: std::vector ExtraAttributesToStringImpl( @@ -157,11 +160,15 @@ class HloSendRecvInstruction : public HloInstruction { eq_computations) const override; // Represents a unique identifier for each Send/Recv instruction pair. int64 channel_id_; + + // Whether this send/recv instruction sends data to/from the host. + bool is_host_transfer_; }; class HloSendInstruction : public HloSendRecvInstruction { public: - explicit HloSendInstruction(HloInstruction* operand, int64 channel_id); + explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token, + int64 channel_id, bool is_host_transfer); private: // Implementation for non-common logic of CloneWithNewOperands. @@ -173,7 +180,8 @@ class HloSendInstruction : public HloSendRecvInstruction { class HloSendDoneInstruction : public HloSendRecvInstruction { public: - explicit HloSendDoneInstruction(HloSendInstruction* operand); + explicit HloSendDoneInstruction(HloSendInstruction* operand, + bool is_host_transfer); private: // Implementation for non-common logic of CloneWithNewOperands. @@ -185,7 +193,8 @@ class HloSendDoneInstruction : public HloSendRecvInstruction { class HloRecvInstruction : public HloSendRecvInstruction { public: - explicit HloRecvInstruction(const Shape& shape, int64 channel_id); + explicit HloRecvInstruction(const Shape& shape, HloInstruction* token, + int64 channel_id, bool is_host_transfer); private: // Implementation for non-common logic of CloneWithNewOperands. @@ -197,7 +206,8 @@ class HloRecvInstruction : public HloSendRecvInstruction { class HloRecvDoneInstruction : public HloSendRecvInstruction { public: - explicit HloRecvDoneInstruction(HloRecvInstruction* operand); + explicit HloRecvDoneInstruction(HloRecvInstruction* operand, + bool is_host_transfer); private: // Implementation for non-common logic of CloneWithNewOperands. @@ -214,8 +224,7 @@ class HloAllReduceInstruction : public HloInstruction { HloComputation* reduce_computation, tensorflow::gtl::ArraySlice replica_group_ids, tensorflow::StringPiece barrier, - const tensorflow::gtl::optional& all_reduce_id = - tensorflow::gtl::nullopt); + const tensorflow::gtl::optional& all_reduce_id); // Returns the group ids of each replica for CrossReplicaSum op. const std::vector& replica_group_ids() const { @@ -264,6 +273,47 @@ class HloAllReduceInstruction : public HloInstruction { tensorflow::gtl::optional all_reduce_id_; }; +class HloAllToAllInstruction : public HloInstruction { + public: + explicit HloAllToAllInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operand, + const std::vector& replica_groups, + tensorflow::StringPiece barrier); + + const std::vector& replica_groups() const { + return replica_groups_; + } + + // TODO(b/110096724): rename this. + void set_cross_replica_sum_barrier(string barrier) { + cross_replica_sum_barrier_ = barrier; + } + string cross_replica_sum_barrier() const { + return cross_replica_sum_barrier_; + } + + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector replica_groups_; + + // The string representation of the barrier config. + string cross_replica_sum_barrier_; +}; + class HloReverseInstruction : public HloInstruction { public: explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand, @@ -322,7 +372,7 @@ class HloConcatenateInstruction : public HloInstruction { class HloReduceInstruction : public HloInstruction { public: explicit HloReduceInstruction( - const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + const Shape& shape, tensorflow::gtl::ArraySlice args, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation); // Returns the dimension sizes or numbers associated with this instruction. @@ -331,6 +381,47 @@ class HloReduceInstruction : public HloInstruction { // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + // Returns the input tensors to be reduced. + tensorflow::gtl::ArraySlice inputs() const { + return tensorflow::gtl::ArraySlice(operands(), 0, + operand_count() / 2); + } + + // Returns the init values of the reduction. + tensorflow::gtl::ArraySlice init_values() const { + return tensorflow::gtl::ArraySlice( + operands(), operand_count() / 2, operand_count()); + } + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloSortInstruction : public HloInstruction { + public: + explicit HloSortInstruction(const Shape& shape, int64 dimension, + HloInstruction* keys, + HloInstruction* values = nullptr); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns the sort dimension for this instruction + int64 sort_dimension() { return dimensions(0); } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -496,6 +587,8 @@ class HloConstantInstruction : public HloInstruction { explicit HloConstantInstruction(const Shape& shape); // Returns the literal associated with this instruction. const Literal& literal() const { return *literal_; } + // Returns whether there is literal associated with this instruction. + bool HasLiteral() const { return literal_ != nullptr; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -1117,6 +1210,88 @@ class HloDynamicSliceInstruction : public HloInstruction { // ('start' is specified dynamically in the second operand of the operation). std::vector dynamic_slice_sizes_; }; + +class HloGatherInstruction : public HloInstruction { + public: + explicit HloGatherInstruction( + const Shape& shape, HloInstruction* operand, + HloInstruction* gather_indices, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice window_bounds); + const GatherDimensionNumbers& gather_dimension_numbers() const { + CHECK(gather_dimension_numbers_ != nullptr); + return *gather_dimension_numbers_; + } + tensorflow::gtl::ArraySlice gather_window_bounds() const { + return gather_window_bounds_; + } + // Returns the dump string of the gather dimension numbers. + string GatherDimensionNumbersToString() const; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Creates an instance of GatherDimensionNumbers. + static GatherDimensionNumbers MakeGatherDimNumbers( + tensorflow::gtl::ArraySlice output_window_dims, + tensorflow::gtl::ArraySlice elided_window_dims, + tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, + int64 index_vector_dim); + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::unique_ptr gather_dimension_numbers_; + std::vector gather_window_bounds_; +}; + +class HloScatterInstruction : public HloInstruction { + public: + explicit HloScatterInstruction( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers); + const ScatterDimensionNumbers& scatter_dimension_numbers() const { + CHECK(scatter_dimension_numbers_ != nullptr); + return *scatter_dimension_numbers_; + } + // Returns the dump string of the scatter dimension numbers. + string ScatterDimensionNumbersToString() const; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Creates an instance of ScatterDimensionNumbers. + static ScatterDimensionNumbers MakeScatterDimNumbers( + tensorflow::gtl::ArraySlice update_window_dims, + tensorflow::gtl::ArraySlice inserted_window_dims, + tensorflow::gtl::ArraySlice scatter_dims_to_operand_dims, + int64 index_vector_dim); + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::unique_ptr scatter_dimension_numbers_; +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index f0d9fdbc8f86da0bb9d7f9235239df677c9506bc..71b44507cc704344ff6fe5269ea498bb32cfb8a6 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -299,9 +299,12 @@ TokKind HloLexer::LexNumberOrPattern() { static LazyRE2 int_pattern = {R"([-]?\d+)"}; if (RE2::Consume(&consumable, *int_pattern)) { current_ptr_ = consumable.begin(); - tensorflow::strings::safe_strto64( - StringPieceFromPointers(token_start_, current_ptr_), &int64_val_); - return TokKind::kInt; + auto slice = StringPieceFromPointers(token_start_, current_ptr_); + if (tensorflow::strings::safe_strto64(slice, &int64_val_)) { + return TokKind::kInt; + } + LOG(ERROR) << "Failed to parse int literal: " << slice; + return TokKind::kError; } static LazyRE2 neg_inf = {"-inf"}; diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc index 0275294a1a86cef13e5b267ad578f30cc18858dc..01b625c29ca2823b2a2490b30a9d4d5128b4c22e 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index 9a3010cf1ff75e840130d8442bbe26d6041cef25..7961aece541faeb66875885b380158756c503250 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -75,8 +76,10 @@ TEST(HloMatchersTest, Test) { } TEST(HloMatchersTest, CustomCallMatcher) { - auto c1 = HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3})); - auto c2 = HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3})); + auto c1 = + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3})); + auto c2 = + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3})); auto call = HloInstruction::CreateCustomCall( ShapeUtil::MakeShape(F32, {1}), {c1.get(), c2.get()}, "foo_target"); @@ -154,9 +157,8 @@ TEST(HloMatchersTest, ShardingMatcher) { Array assignment({2}); assignment.SetValues({0, 1}); auto sharding = HloSharding::Tuple( - tuple_shape, - {HloSharding::Tile(ShapeUtil::MakeShape(F32, {5}), assignment), - HloSharding::AssignDevice(1), HloSharding::Replicate()}); + tuple_shape, {HloSharding::Tile(assignment), HloSharding::AssignDevice(1), + HloSharding::Replicate()}); p2->set_sharding(sharding); EXPECT_THAT(p0.get(), op::NoSharding()); @@ -169,8 +171,7 @@ TEST(HloMatchersTest, ShardingMatcher) { EXPECT_THAT( p2.get(), - op::Sharding( - "{{f32[5] devices=[2]0,1}, {maximal device=1}, {replicated}}")); + op::Sharding("{{devices=[2]0,1}, {maximal device=1}, {replicated}}")); EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))), "%param.0 = f32[5]{0} parameter(0) has no sharding (expected: " diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 39bc25ba42c2cb6a9f77e2726405311ba13b3edc..55ff073d3faf34aa0f1b8f0886946837e7a49bcc 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -537,10 +537,11 @@ uint64 HloModule::RandomNew64() const { HloComputation* HloModule::GetComputationWithName( tensorflow::StringPiece name) { - auto it = c_find_if(computations(), [&](HloComputation* computation) { + auto computations_in_module = computations(); + auto it = c_find_if(computations_in_module, [&](HloComputation* computation) { return computation->name() == name; }); - return it == computations().end() ? nullptr : *it; + return it == computations_in_module.end() ? nullptr : *it; } /* static */ std::atomic HloModule::next_unique_module_id_(0); diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index bf33640db16638803f4f8e6c66f35d6bb6e2c9fe..10bf9ffd6c1960df5ca2a3555d120b0874407f15 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -75,10 +77,23 @@ Status HloModuleGroupMetadata::Build() { if (tracked == nullptr) { return Status::OK(); } - // Add the parent computation of this channel instruction and its peer - // computation (both must be while computations) as companions. + + std::vector peers; if (IsChannelInstruction(hlo)) { - HloComputation* peer_computation = PeerComputation(hlo); + peers.push_back(PeerComputation(hlo)); + } else if (hlo->IsCrossModuleAllReduce()) { + for (HloInstruction* instr : GetAllReduceGroup(*hlo->all_reduce_id())) { + if (instr == hlo) { + continue; + } + peers.push_back(instr->parent()); + } + } + + // Add the parent computation of this channel (or all-reduce) instruction + // and its peer computation(s) (both must be while computations) as + // companions. + for (HloComputation* peer_computation : peers) { const TrackedInstruction* peer_tracked = GetTrackedInstruction(peer_computation); TF_RET_CHECK(peer_tracked != nullptr) @@ -162,8 +177,12 @@ bool HloModuleGroupMetadata::IsChannelInstruction( case HloOpcode::kSend: case HloOpcode::kRecv: case HloOpcode::kSendDone: - case HloOpcode::kRecvDone: - return true; + case HloOpcode::kRecvDone: { + const HloSendRecvInstruction* send_recv_instr = + DynCast(instruction); + CHECK(send_recv_instr != nullptr); + return !send_recv_instr->is_host_transfer(); + } default: return false; } @@ -175,7 +194,8 @@ bool HloModuleGroupMetadata::IsCompanionInstruction(HloInstruction* hlo) const { bool HloModuleGroupMetadata::InstructionCommunicates( HloInstruction* hlo) const { - return IsChannelInstruction(hlo) || IsCompanionInstruction(hlo); + return IsChannelInstruction(hlo) || IsCompanionInstruction(hlo) || + hlo->IsCrossModuleAllReduce(); } const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel( @@ -200,6 +220,13 @@ HloComputation* HloModuleGroupMetadata::PeerComputation( } } +const std::vector& HloModuleGroupMetadata::GetAllReduceGroup( + int64 all_reduce_id) const { + auto it = all_reduce_map_.find(all_reduce_id); + CHECK(it != all_reduce_map_.end()); + return it->second; +} + std::vector HloModuleGroupMetadata::GetCompanionsPath(const HloInstruction* hlo) const { std::vector path; @@ -278,10 +305,27 @@ Status HloModuleGroupMetadata::RecordInstructions() { tracked_instructions_[hlo->to_apply()] = TrackedInstruction(hlo, ComputationKind::kCallFunction); } + + // Group cross module all-reduce instructions by the all_reduce id. + if (hlo->IsCrossModuleAllReduce()) { + TF_RET_CHECK(channel_id_map_.find(*hlo->all_reduce_id()) == + channel_id_map_.end()) + << "all_reduce_id " << *hlo->all_reduce_id() + << " is already used by a send/recv instruction"; + all_reduce_map_[*hlo->all_reduce_id()].push_back(hlo); + max_channel_id_ = std::max(max_channel_id_, *hlo->all_reduce_id()); + return Status::OK(); + } + if (!IsChannelInstruction(hlo)) { return Status::OK(); } + TF_RET_CHECK(all_reduce_map_.find(hlo->channel_id()) == + all_reduce_map_.end()) + << "channel id " << hlo->channel_id() + << " is already used by an all-reduce instruction"; + // Add a new channel if needed. if (channel_id_map_.find(hlo->channel_id()) == channel_id_map_.end()) { channels_.emplace_back(); @@ -324,6 +368,7 @@ Status HloModuleGroupMetadata::RecordInstructions() { } } VLOG(2) << "Created " << channels_.size() << " channels"; + VLOG(2) << "Created " << all_reduce_map_.size() << " all-reduce groups"; return Status::OK(); } @@ -382,7 +427,8 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { // Check if the shapes match for each channel. for (const Channel& channel : channels_) { const Shape& send_shape = channel.send->operand(0)->shape(); - const Shape& recv_shape = channel.recv_done->shape(); + const Shape& recv_shape = + ShapeUtil::GetTupleElementShape(channel.recv_done->shape(), 0); if (!ShapeUtil::Compatible(send_shape, recv_shape)) { return FailedPrecondition("send/recv shapes do not match"); } diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index ffde3a332dfc141ca928a44cfdf4686900e9f47b..84f2d3f5fbc1a6ff1df8ba3c0babd122e5701148 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -92,7 +92,7 @@ class HloModuleGroupMetadata { ComputationKind kind_ = ComputationKind::kInvalid; }; - // Represents a channel and the 4 instructions that form the channel. + // Represents a channel and the instructions that form the channel. struct Channel { int64 id = -1; HloInstruction* send = nullptr; @@ -118,13 +118,17 @@ class HloModuleGroupMetadata { // comment above on companion instructions. bool IsCompanionInstruction(HloInstruction* hlo) const; - // Returns true if the instruction is either a channel instruction or a - // companion instruction. + // Returns true if the instruction is either a channel instruction, a + // cross-module all-reduce instruction, or a companion instruction. bool InstructionCommunicates(HloInstruction* hlo) const; // Returns the Channel instance for the given channel id. const Channel& GetChannel(int64 channel_id) const; + // Returns the all-reduce instructions with the same all_reduce_id. + const std::vector& GetAllReduceGroup( + int64 all_reduce_id) const; + // Returns the computation that contains the peer channel instructions for // the given instruction. // @@ -187,13 +191,14 @@ class HloModuleGroupMetadata { // Returns all channels in the module group. const std::vector& channels() const { return channels_; } - // Returns the maximum channel id used in the module group. + // Returns the maximum channel id or all_reduce_id used in the module group. int64 max_channel_id() const { return max_channel_id_; } private: Status Build(); - // Record all channel instructions and While instructions. + // Record all channel instructions, cross-module AllReduce instructions, and + // While/Conditional/Call instructions. Status RecordInstructions(); // Verifies the given HloModules are well-formed and follow the specification, @@ -255,6 +260,9 @@ class HloModuleGroupMetadata { // Map from channel ids to the index in channels_. tensorflow::gtl::FlatMap channel_id_map_; + // Map from all-reduce ids to the all reduce instructions. + tensorflow::gtl::FlatMap> all_reduce_map_; + // The maximum channel id used in the module group. int64 max_channel_id_ = -1; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 21a9b7291acc9e0066a9061facd13ab5acbf0bac..9fd0ade153109c6c809c37aa08257f83a82c44d5 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -56,12 +56,17 @@ std::vector HloModuleGroupUtil::GlobalPredecessors( }; // If the given instruction is a companion instruction, we need to find the - // predecessors of all of its companion instructions. + // predecessors of all of its companion instructions. If the instruction is an + // all-reduce, we need to find the predecessors of all the peer all-reduce + // instructions. std::vector instruction_group; if (metadata_.IsCompanionInstruction(instruction)) { for (HloInstruction* companion : metadata_.Companions(instruction)) { instruction_group.push_back(companion); } + } else if (instruction->IsCrossModuleAllReduce()) { + instruction_group = + metadata_.GetAllReduceGroup(*instruction->all_reduce_id()); } else { instruction_group.push_back(instruction); } @@ -112,12 +117,17 @@ std::vector HloModuleGroupUtil::GlobalSuccessors( }; // If the given instruction is a companion instruction, we need to find the - // successors of all of its companion instructions. + // successors of all of its companion instructions. If the instruction is an + // all-reduce, we need to find the successors of all its peer all-reduce + // instructions. std::vector instruction_group; if (metadata_.IsCompanionInstruction(instruction)) { for (HloInstruction* companion : metadata_.Companions(instruction)) { instruction_group.push_back(companion); } + } else if (instruction->IsCrossModuleAllReduce()) { + instruction_group = + metadata_.GetAllReduceGroup(*instruction->all_reduce_id()); } else { instruction_group.push_back(instruction); } @@ -170,15 +180,17 @@ Status HloModuleGroupUtil::VisitTopologicalOrder( HloInstruction* hlo = stack.top(); // Find the instruction group of the currently visited instruction. The - // instruction group represents all companion instructions of the - // current instruction, and are considered to be a single entity for the - // purpose of the traversal (i.e., they must always be in the same visit - // state). + // instruction group represents all companion instructions of the current + // instruction, or all the all-reduce instructions that belong to the same + // group, or are considered to be a single entity for the purpose of the + // traversal (i.e., they must always be in the same visit state). std::vector instruction_group; if (metadata_.IsCompanionInstruction(hlo)) { for (HloInstruction* companion : metadata_.Companions(hlo)) { instruction_group.push_back(companion); } + } else if (hlo->IsCrossModuleAllReduce()) { + instruction_group = metadata_.GetAllReduceGroup(*hlo->all_reduce_id()); } else { instruction_group.push_back(hlo); } @@ -292,7 +304,7 @@ HloModuleGroupUtil::ComputeReachability( } auto reachability = MakeUnique(post_order); for (HloInstruction* hlo : post_order) { - reachability->SetReachabilityToUnion(GlobalPredecessors(hlo), hlo); + reachability->FastSetReachabilityToUnion(GlobalPredecessors(hlo), hlo); } return std::move(reachability); } diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 7f28a804bfec9c2f1bbb5fa08f7dd4e68be14d35..236f4500860a8673e61cbd2f861a8fc40c7861f7 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -38,7 +38,7 @@ class HloModuleTest : public HloTestBase { std::unique_ptr CreateConstantComputation() { auto builder = HloComputation::Builder("Constant"); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); return builder.Build(); } @@ -122,7 +122,7 @@ TEST_F(HloModuleTest, CloneHasFusion) { { auto b = HloComputation::Builder("Entry"); auto input = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); b.AddInstruction( HloInstruction::CreateFusion(r0f32_, HloInstruction::FusionKind::kInput, /*operands=*/{input}, fused_computation)); @@ -173,7 +173,7 @@ TEST_F(HloModuleTest, LargeConstantToString) { auto builder = HloComputation::Builder("Constant"); std::vector values(16, 42.0); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1(values))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1(values))); module->AddEntryComputation(builder.Build()); EXPECT_EQ( diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 05e47a698f3b1d6345b183fb88b588a413063595..ec279867e595b66a22882703cc06046e3e916c96 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -47,6 +47,7 @@ namespace xla { #define HLO_OPCODE_LIST(V) \ V(kAbs, "abs") \ V(kAdd, "add") \ + V(kAllToAll, "all-to-all") \ V(kAtan2, "atan2") \ V(kBatchNormGrad, "batch-norm-grad") \ V(kBatchNormInference, "batch-norm-inference") \ @@ -87,6 +88,7 @@ namespace xla { V(kHostCompute, "host-compute") \ V(kImag, "imag") \ V(kInfeed, "infeed") \ + V(kIota, "iota") \ V(kIsFinite, "is-finite") \ V(kLe, "less-than-or-equal-to", kHloOpcodeIsComparison) \ V(kLog, "log") \ @@ -117,6 +119,7 @@ namespace xla { V(kReverse, "reverse") \ V(kRng, "rng") \ V(kRoundNearestAfz, "round-nearest-afz") \ + V(kScatter, "scatter") \ V(kSelect, "select") \ V(kSelectAndScatter, "select-and-scatter") \ V(kSend, "send") \ @@ -133,6 +136,7 @@ namespace xla { V(kTrace, "trace") \ V(kTranspose, "transpose") \ V(kTuple, "tuple", kHloOpcodeIsVariadic) \ + V(kTupleSelect, "tuple-select") \ V(kWhile, "while") enum class HloOpcode { diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index cfe5dace05ac03f1573f90b2ce664c94837837b4..126d3a2d9c70bff1d2a022e395652049768d6d21 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -57,7 +57,7 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { auto builder_c = HloComputation::Builder("C"); HloInstruction* c = builder_c.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); HloComputation* computation_c = module->AddEmbeddedComputation(builder_c.Build()); @@ -145,7 +145,7 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(scalar_shape, condition, body, constant)); module->AddEntryComputation(builder.Build()); @@ -208,7 +208,7 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(scalar_shape, condition, body, constant)); auto add = builder.AddInstruction(HloInstruction::CreateBinary( diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 57d17064c1390f60b7c4bd019b21740f95a615e7..2a8c6ecd9248b9bf77153781d9c169306c9a9197 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -15,8 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -117,11 +119,13 @@ class HloParser { // Types of attributes. enum class AttrTy { + kBool, kInt64, kInt32, kFloat, kString, kBracedInt64List, + kBracedInt64ListList, kHloComputation, kFftType, kWindow, @@ -202,6 +206,10 @@ class HloParser { bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result); + // 'parse_and_add_item' is an lambda to parse an element in the list and add + // the parsed element to the result. It's supposed to capture the result. + bool ParseList(const TokKind start, const TokKind end, const TokKind delim, + const std::function& parse_and_add_item); bool ParseParamListToShape(Shape* shape, LocTy* shape_loc); bool ParseParamList(); @@ -489,6 +497,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction::CreateConstant(std::move(literal))); break; } + case HloOpcode::kIota: { + if (!ParseOperands(&operands, /*expected_size=*/0) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateIota(shape)); + break; + } // Unary ops. case HloOpcode::kAbs: case HloOpcode::kRoundNearestAfz: @@ -509,7 +525,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: - case HloOpcode::kSort: case HloOpcode::kTanh: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -552,7 +567,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } // Ternary ops. case HloOpcode::kClamp: - case HloOpcode::kSelect: { + case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: { if (!ParseOperands(&operands, /*expected_size=*/3) || !ParseAttributes(attrs)) { return false; @@ -608,6 +624,28 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } break; } + case HloOpcode::kAllToAll: { + optional>> tmp_groups; + optional barrier; + attrs["replica_groups"] = {/*required=*/false, + AttrTy::kBracedInt64ListList, &tmp_groups}; + attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + std::vector replica_groups; + if (tmp_groups) { + c_transform(*tmp_groups, std::back_inserter(replica_groups), + [](const std::vector& ids) { + ReplicaGroup group; + *group.mutable_replica_ids() = {ids.begin(), ids.end()}; + return group; + }); + } + instruction = builder->AddInstruction(HloInstruction::CreateAllToAll( + shape, operands, replica_groups, barrier ? *barrier : "")); + break; + } case HloOpcode::kReshape: { if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -621,8 +659,38 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = - builder->AddInstruction(HloInstruction::CreateAfterAll(operands)); + if (operands.empty()) { + instruction = builder->AddInstruction(HloInstruction::CreateToken()); + } else { + instruction = + builder->AddInstruction(HloInstruction::CreateAfterAll(operands)); + } + break; + } + case HloOpcode::kSort: { + auto loc = lexer_.GetLoc(); + + optional> dimensions; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs) || + dimensions->size() != 1) { + return false; + } + switch (operands.size()) { + case 1: + instruction = builder->AddInstruction(HloInstruction::CreateSort( + shape, dimensions->at(0), /*keys=*/operands[0])); + break; + case 2: + instruction = builder->AddInstruction(HloInstruction::CreateSort( + shape, dimensions->at(0), + /*keys=*/operands[0], /*values=*/operands[1])); + break; + default: + return Error(loc, StrCat("expects either 1 or 2 operands, but has ", + operands.size(), " operands")); + } break; } case HloOpcode::kTuple: { @@ -649,18 +717,27 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kRecv: { optional channel_id; + // If the is_host_transfer attribute is not present then default to false. + optional is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; - if (!ParseOperands(&operands, /*expected_size=*/0) || + attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool, + &is_host_transfer}; + if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateRecv(shape.tuple_shapes(0), *channel_id)); + // If the is_host_transfer attribute is not present then default to false. + instruction = builder->AddInstruction(HloInstruction::CreateRecv( + shape.tuple_shapes(0), operands[0], *channel_id, *is_host_transfer)); break; } case HloOpcode::kRecvDone: { optional channel_id; + // If the is_host_transfer attribute is not present then default to false. + optional is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; + attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool, + &is_host_transfer}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { return false; @@ -668,24 +745,32 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (channel_id != operands[0]->channel_id()) { return false; } - instruction = - builder->AddInstruction(HloInstruction::CreateRecvDone(operands[0])); + instruction = builder->AddInstruction( + HloInstruction::CreateRecvDone(operands[0], *is_host_transfer)); break; } case HloOpcode::kSend: { optional channel_id; + // If the is_host_transfer attribute is not present then default to false. + optional is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; - if (!ParseOperands(&operands, /*expected_size=*/1) || + attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool, + &is_host_transfer}; + if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateSend(operands[0], *channel_id)); + instruction = builder->AddInstruction(HloInstruction::CreateSend( + operands[0], operands[1], *channel_id, *is_host_transfer)); break; } case HloOpcode::kSendDone: { optional channel_id; + // If the is_host_transfer attribute is not present then default to false. + optional is_host_transfer = false; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; + attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool, + &is_host_transfer}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { return false; @@ -693,8 +778,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (channel_id != operands[0]->channel_id()) { return false; } - instruction = - builder->AddInstruction(HloInstruction::CreateSendDone(operands[0])); + instruction = builder->AddInstruction( + HloInstruction::CreateSendDone(operands[0], *is_host_transfer)); break; } case HloOpcode::kGetTupleElement: { @@ -807,18 +892,28 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kReduce: { + auto loc = lexer_.GetLoc(); + optional reduce_computation; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &reduce_computation}; optional> dimensions_to_reduce; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions_to_reduce}; - if (!ParseOperands(&operands, /*expected_size=*/2) || - !ParseAttributes(attrs)) { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } + if (operands.size() % 2) { + return Error(loc, StrCat("expects an even number of operands, but has ", + operands.size(), " operands")); + } instruction = builder->AddInstruction(HloInstruction::CreateReduce( - shape, /*operand=*/operands[0], /*init_value=*/operands[1], + shape, /*operands=*/ + tensorflow::gtl::ArraySlice(operands, 0, + operands.size() / 2), + /*init_values=*/ + tensorflow::gtl::ArraySlice( + operands, operands.size() / 2, operands.size()), *dimensions_to_reduce, *reduce_computation)); break; } @@ -1074,13 +1169,24 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kCustomCall: { optional custom_call_target; + optional window; + optional dnums; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; + attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; + attrs["dim_labels"] = {/*required=*/false, + AttrTy::kConvolutionDimensionNumbers, &dnums}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( shape, operands, *custom_call_target)); + if (window.has_value()) { + instruction->set_window(*window); + } + if (dnums.has_value()) { + instruction->set_convolution_dimension_numbers(*dnums); + } break; } case HloOpcode::kHostCompute: { @@ -1161,17 +1267,54 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, return false; } - GatherDimensionNumbers dim_numbers = HloInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/*output_window_dims, - /*elided_window_dims=*/*elided_window_dims, - /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims, - /*index_vector_dim=*/*index_vector_dim); + GatherDimensionNumbers dim_numbers = + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/*output_window_dims, + /*elided_window_dims=*/*elided_window_dims, + /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims, + /*index_vector_dim=*/*index_vector_dim); instruction = builder->AddInstruction(HloInstruction::CreateGather( shape, /*operand=*/operands[0], /*gather_indices=*/operands[1], dim_numbers, *window_bounds)); break; } + case HloOpcode::kScatter: { + optional> update_window_dims; + attrs["update_window_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &update_window_dims}; + optional> inserted_window_dims; + attrs["inserted_window_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims}; + optional> scatter_dims_to_operand_dims; + attrs["scatter_dims_to_operand_dims"] = {/*required=*/true, + AttrTy::kBracedInt64List, + &scatter_dims_to_operand_dims}; + optional index_vector_dim; + attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, + &index_vector_dim}; + + optional update_computation; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &update_computation}; + + if (!ParseOperands(&operands, /*expected_size=*/3) || + !ParseAttributes(attrs)) { + return false; + } + + ScatterDimensionNumbers dim_numbers = + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/*update_window_dims, + /*inserted_window_dims=*/*inserted_window_dims, + /*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims, + /*index_vector_dim=*/*index_vector_dim); + + instruction = builder->AddInstruction(HloInstruction::CreateScatter( + shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1], + /*updates=*/operands[2], *update_computation, dim_numbers)); + break; + } case HloOpcode::kDomain: { DomainData domain; attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain}; @@ -1180,8 +1323,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, return false; } instruction = builder->AddInstruction(HloInstruction::CreateDomain( - shape, operands[0], std::move(domain.entry_metadata), - std::move(domain.exit_metadata))); + shape, operands[0], std::move(domain.exit_metadata), + std::move(domain.entry_metadata))); break; } case HloOpcode::kTrace: @@ -1267,7 +1410,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, bool replicated = false; std::vector devices; std::vector tile_assignment_dimensions; - Shape tile_shape; while (lexer_.GetKind() != TokKind::kRbrace) { switch (lexer_.GetKind()) { case TokKind::kw_maximal: @@ -1318,7 +1460,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, break; } case TokKind::kShape: - tile_shape = lexer_.GetShapeVal(); + // TODO(b/112302613): Left here for backward compatibility to ignore the + // removed tile shape data. lexer_.Lex(); break; case TokKind::kRbrace: @@ -1333,19 +1476,12 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return Error(loc, "replicated shardings should not have any devices assigned"); } - if (!ShapeUtil::Equal(tile_shape, Shape())) { - return Error(loc, - "replicated shardings should not have any tile shape set"); - } sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED); } else if (maximal) { if (devices.size() != 1) { return Error(loc, "maximal shardings should have exactly one device assigned"); } - if (!ShapeUtil::Equal(tile_shape, Shape())) { - return Error(loc, "maximal shardings should not have any tile shape set"); - } sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); sharding->add_tile_assignment_devices(devices[0]); } else { @@ -1353,9 +1489,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return Error( loc, "non-maximal shardings must have more than one device assigned"); } - if (ShapeUtil::Equal(tile_shape, Shape())) { - return Error(loc, "non-maximal shardings should have a tile shape set"); - } if (tile_assignment_dimensions.empty()) { return Error( loc, @@ -1363,7 +1496,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, "dimensions"); } sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER); - *sharding->mutable_tile_shape() = tile_shape; for (tensorflow::int64 dim : tile_assignment_dimensions) { sharding->add_tile_assignment_dimensions(dim); } @@ -1520,6 +1652,24 @@ bool HloParser::SetValueInLiteralHelper(ParsedElemT value, "value ", value, " is out of range for literal's primitive type ", PrimitiveType_Name(literal->shape().element_type()))); } + } else if (std::is_unsigned::value) { + CHECK((std::is_same::value || + std::is_same::value)) + << "Unimplemented checking for ParsedElemT"; + + ParsedElemT upper_bound; + if (sizeof(LiteralNativeT) >= sizeof(ParsedElemT)) { + upper_bound = std::numeric_limits::max(); + } else { + upper_bound = + static_cast(std::numeric_limits::max()); + } + if (value > upper_bound || value < 0) { + // Value is out of range for LiteralNativeT. + return TokenError(StrCat( + "value ", value, " is out of range for literal's primitive type ", + PrimitiveType_Name(literal->shape().element_type()))); + } } else if (value > static_cast( std::numeric_limits::max()) || value < static_cast( @@ -1588,7 +1738,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, } } } - *literal = Literal::MakeTupleOwned(std::move(elements)); + *literal = LiteralUtil::MakeTupleOwned(std::move(elements)); return ParseToken(TokKind::kRparen, StrCat("expects ')' at the end of the tuple with ", ShapeUtil::TupleElementCount(shape), "elements")); @@ -1616,8 +1766,8 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, } // Create a literal with the given shape in default layout. - *literal = Literal::CreateFromDimensions(shape.element_type(), - AsInt64Slice(shape.dimensions())); + *literal = LiteralUtil::CreateFromDimensions( + shape.element_type(), AsInt64Slice(shape.dimensions())); tensorflow::int64 nest_level = 0; tensorflow::int64 linear_index = 0; // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for @@ -2010,6 +2160,14 @@ bool HloParser::ParseAttributeHelper( bool success = [&] { LocTy attr_loc = lexer_.GetLoc(); switch (attr_type) { + case AttrTy::kBool: { + bool result; + if (!ParseBool(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } case AttrTy::kInt64: { tensorflow::int64 result; if (!ParseInt64(&result)) { @@ -2113,6 +2271,26 @@ bool HloParser::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kBracedInt64ListList: { + std::vector> result; + auto parse_and_add_item = [&]() { + std::vector item; + if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, + TokKind::kComma, &item)) { + return false; + } + result.push_back(item); + return true; + }; + if (!ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item)) { + return false; + } + static_cast>>*>( + attr_out_ptr) + ->emplace(result); + return true; + } case AttrTy::kSliceRanges: { SliceRanges result; if (!ParseSliceRanges(&result)) { @@ -2455,6 +2633,26 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end, end, StrCat("expects an int64 list to end with ", TokKindToString(end))); } +bool HloParser::ParseList(const TokKind start, const TokKind end, + const TokKind delim, + const std::function& parse_and_add_item) { + if (!ParseToken(start, StrCat("expects a list starting with ", + TokKindToString(start)))) { + return false; + } + if (lexer_.GetKind() == end) { + // empty + } else { + do { + if (!parse_and_add_item()) { + return false; + } + } while (EatIfPresent(delim)); + } + return ParseToken( + end, StrCat("expects a list to end with ", TokKindToString(end))); +} + // param_list_to_shape ::= param_list '->' shape bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) { if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index da1a34ae3c083ebef4ccc2cfe4daf224cf87e64e..4cd21841f4c25071d222cd291ed56aad2d266ca7 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -277,12 +277,28 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { "SendRecv", R"(HloModule TwoSendRecvBothWayRecvFist_module -ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %recv = (f32[], u32[]) recv(), channel_id=15, sharding={maximal device=1} - ROOT %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15, sharding={maximal device=1} +ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { + %token = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15, sharding={maximal device=1} + ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1} + %constant = f32[] constant(2.1), sharding={maximal device=0} + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} + %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0} +} + +)" +}, +{ +"SendRecvWithHostTransfer", +R"(HloModule HostTransferSendRecv_module + +ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) { + %token = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15, is_host_transfer=true + ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, is_host_transfer=true %constant = f32[] constant(2.1), sharding={maximal device=0} - %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} - %send-done = () send-done((f32[], u32[]) %send), channel_id=16, sharding={maximal device=0} + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, is_host_transfer=true + %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, is_host_transfer=true } )" @@ -742,6 +758,46 @@ ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5 ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} } +)" +}, +{ +"scatter", +R"(HloModule StringifyScatter + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] { + %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) + %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2) + ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=%add_F32.v3 +} + +)" +}, +{ + "ConstantUnsignedNoUnderflow", + R"(HloModule ConstantUnsignedNoUnderflow_module + +ENTRY %ConstantUnsignedNoUnderflow () -> u64[] { + ROOT %constant = u64[] constant(1) +} + +)" +}, + +{ + "ConstantUnsignedNoOverflow", + R"(HloModule ConstantUnsignedNoOverflow_module + +ENTRY %ConstantUnsignedNoOverflow () -> u64[] { + ROOT %constant = u64[] constant(9223372036854775807) +} + )" }, }); @@ -787,6 +843,32 @@ ENTRY ReduceR3ToR2.v3 { ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 } +)" +}, +// tuple reduce +{ +"TupleReduce", +R"(HloModule TupleReduce + +max_argmax { + value = f32[] parameter(2) + prev_max = f32[] parameter(0) + is_next_larger = pred[] greater-than-or-equal-to(value, prev_max) + max = f32[] select(is_next_larger, value, prev_max) + index = s32[] parameter(3) + prev_argmax = s32[] parameter(1) + argmax = s32[] select(is_next_larger, index, prev_argmax) + ROOT pair = (f32[], s32[]) tuple(max, argmax) +} + +ENTRY reduce_entry { + values = f32[1024]{0} parameter(0) + indices = f32[1024]{0} parameter(1) + init_value = f32[] constant(-inf) + init_index = s32[] constant(-1) + ROOT result = (f32[], s32[]) reduce(values, indices, init_value, init_index), dimensions={0}, to_apply=max_argmax +} + )" }, // infeed/outfeed @@ -830,6 +912,56 @@ ENTRY ReducePrecision { ROOT reduce-precision = f32[1]{0} reduce-precision(constant), exponent_bits=8, mantissa_bits=10 } +)" +}, +// Sort (Key) +{ +"SortKey", +R"(HloModule sort + +ENTRY Sort { + x = f32[1024]{0} parameter(0) + ROOT sorted = f32[1024]{0} sort(x), dimensions={0} +} + +)" +}, +// Sort (Key, Value) +{ +"SortKeyValue", +R"(HloModule sort + +ENTRY Sort { + keys = f32[1024]{0} parameter(0) + values = s32[1024]{0} parameter(1) + ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0} +} + +)" +}, +// R2 Sort (Key) +{ +"SortKeyR2", +R"(HloModule sort + +ENTRY Sort { + x = f32[1024,16]{0,1} parameter(0) + ROOT sorted = f32[1024,16]{0,1} sort(x), dimensions={0} +} + +)" +}, +// R2 Sort (Key, Value) +{ +"SortKeyValueR2", +R"(HloModule sort + +ENTRY Sort { + keys = f32[1024,16]{0,1} parameter(0) + values = s32[1024,16]{0,1} parameter(1) + ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0} +} + )" }, // Conditional @@ -938,6 +1070,52 @@ ENTRY CrossReplicaSumWithSubgroups { ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_group_ids={0,0,1,1}, barrier="abc", to_apply=add } +)" +}, +// all-to-all +{ +"AllToAll", +R"(HloModule AllToAll + +ENTRY AllToAll { + input = f32[128,32]{0,1} parameter(0) + ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={} +} + +)" +}, +// all-to-all with subgroups +{ +"AllToAllWithSubgroups", +R"(HloModule AllToAllWithSubgroups + +ENTRY AllToAllWithSubgroups { + input = f32[128,32]{0,1} parameter(0) + ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}}, barrier="abc" +} + +)" +}, +// Iota +{ +"Iota", +R"(HloModule iota + +ENTRY Iota { + ROOT iota = f32[100]{0} iota() +} + +)" +}, +// custom-call with window and dim_labels +{ +"CustomCallWithWindowAndDimLabels", +R"(HloModule CustomCallWithWindowAndDimLabels + +ENTRY Computation { + ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="target" +} + )" } }); @@ -1136,6 +1314,40 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] { "is out of range for literal's primitive type F16"); } +TEST_F(HloParserTest, ConstantUnsignedUnderflow) { + const string original = R"( + HloModule ConstantUnsignedUnderflow_module + ENTRY %ConstantUnsignedUnderflow () -> u64[] { + ROOT %constant = u64[] constant(-1) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); + ExpectHasSubstr(result.status().error_message(), + "is out of range for literal's primitive type U64"); +} + +TEST_F(HloParserTest, ConstantUnsignedOverflow) { + const string original = R"( + HloModule ConstantUnsignedOverflow_module + ENTRY %ConstantUnsignedOverflow () -> u32[] { + ROOT %constant = u32[] constant(4294967296) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); + ExpectHasSubstr(result.status().error_message(), + "is out of range for literal's primitive type U32"); +} + +TEST_F(HloParserTest, ConstantUnsignedInt64Overflow) { + const string original = R"( + HloModule ConstantUnsignedOverflow_module + ENTRY %ConstantUnsignedOverflow () -> u64[] { + ROOT %constant = u64[] constant(9223372036854775808) + })"; + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); +} + TEST_F(HloParserTest, ConstantWithExp) { const string original = R"(HloModule ConstantWithExp_module @@ -1196,11 +1408,12 @@ TEST_F(HloParserTest, UnexpectedAttribute) { const string original = R"(HloModule unexpected_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %recv = (f32[], u32[]) recv(), channel_id=15 - %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 + %token = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(2.1) - %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, calls=%recv - %send-done = () send-done((f32[], u32[]) %send), channel_id=16 + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, calls=%recv + %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } )"; @@ -1212,11 +1425,12 @@ TEST_F(HloParserTest, MissingAttribute) { const string original = R"(HloModule missing_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %recv = (f32[], u32[]) recv(), channel_id=15 - %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 + %token = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(-2.1) - %send = (f32[], u32[]) send(f32[] %constant) - %send-done = () send-done((f32[], u32[]) %send), channel_id=16 + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token) + %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } )"; @@ -1228,11 +1442,12 @@ TEST_F(HloParserTest, PredecessorUndefined) { const string original = R"(HloModule pre_not_found_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %recv = (f32[], u32[]) recv(), channel_id=15 - %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 + %token = token[] after-all() + %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15 + %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15 ROOT %constant = f32[] constant(2.1) - %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, control-predecessors={%done} - %send-done = () send-done((f32[], u32[]) %send), channel_id=16 + %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, control-predecessors={%done} + %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16 } )"; diff --git a/tensorflow/compiler/xla/service/hlo_pass_fix.h b/tensorflow/compiler/xla/service/hlo_pass_fix.h index b3d0a07add39968c6310392ea01daeab8a7dd9af..28194deb0e32252b372a328b006dabaf250fa2c7 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_fix.h +++ b/tensorflow/compiler/xla/service/hlo_pass_fix.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_FIX_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_FIX_H_ +#include + #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,9 +36,19 @@ class HloPassFix : public Pass { StatusOr Run(HloModule* module) override { bool changed = false; bool changed_this_iteration = true; + int64 iteration_count = 0; + int64 limit = + std::max(static_cast(1000), module->instruction_count()); while (changed_this_iteration) { TF_ASSIGN_OR_RETURN(changed_this_iteration, Pass::Run(module)); changed |= changed_this_iteration; + ++iteration_count; + if (iteration_count == limit) { + LOG(ERROR) + << "Unexpectedly number of iterations in HLO passes (" + << iteration_count + << ")\nIf compilation hangs here, please file a bug with XLA."; + } } return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index 2418c19f3de7b036d7ef52d3a6db11de6316203b..2a07b6fcbc243d955e136ccdf097c8155a115845 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_query.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc index 657a9ee83d29e72b95660325f9139f44159d6508..585c95972b0e01abc14543205af71b4b0c0bdf3c 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc @@ -39,15 +39,15 @@ TEST_F(HloReachabilityTest, Reachability) { */ auto builder = HloComputation::Builder(TestName()); auto a = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto b = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto c = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto d = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto e = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); builder.Build(); HloReachabilityMap reachability({a, b, c, d, e}); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 62c07d7fac93618a83b3b6111aec1e93309a0761..cf0be30c7ad5cbeb7fd3d71c7c649b6b448360b8 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -1203,7 +1203,7 @@ StatusOr HloRematerialization::RematerializeComputation( StatusOr HloRematerialization::Run( HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence, int64 memory_limit_bytes, RematerializationSizes* sizes, - bool run_copy_elision) { + CopyInsertion* copy_insertion) { // The sequence is constructed entirely by this method. TF_RET_CHECK(sequence->empty()); @@ -1238,13 +1238,14 @@ StatusOr HloRematerialization::Run( return size_function_(buffer.shape()); }, scheduler_algorithm_)); - if (run_copy_elision) { + if (copy_insertion) { // We run a separate pass of copy elision here because the sequential // ordering from the HLO schedule allows for more copies to be eliminated. // TODO(b/80249101): Instead of a separate copy elision pass, use the // ordering from the HLO schedule directly for copy insertion. SequentialHloOrdering ordering(module, *sequence); - TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, {}, module)); + TF_RETURN_IF_ERROR( + copy_insertion->RemoveUnnecessaryCopies(ordering, module)); } // Compute peak memory usage of all computations in the module called in a @@ -1349,10 +1350,10 @@ StatusOr HloRematerialization::Run( int64 memory_limit_bytes, HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, - RematerializationSizes* sizes, bool run_copy_elision) { + RematerializationSizes* sizes, CopyInsertion* copy_insertion) { HloRematerialization remat(scheduler_algorithm, size_function); return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes, - run_copy_elision); + copy_insertion); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 59b4cf5dcc761f70767ce4d7ff0959448f29939a..2ec004350ad88ff31ece90ec419d90a55b965166 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -17,6 +17,7 @@ #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -57,8 +58,9 @@ class HloRematerialization { // sizes: Optional outparam that indicates the peak memory usage of the HLO // module before/after rematerialization. // - // run_copy_elision: Enable copy elision. This pass is used to eliminate - // copies that were inserted before HLO scheduling. + // copy_insertion: If non-null, run copy elision after scheduling. This + // pass is used to eliminate copies that were inserted by copy insertion + // before HLO scheduling. // // TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy // insertion is integrated with HLO scheduling. @@ -74,7 +76,7 @@ class HloRematerialization { const ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, - RematerializationSizes* sizes, bool run_copy_elision = true); + RematerializationSizes* sizes, CopyInsertion* copy_insertion = nullptr); protected: HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm, @@ -90,7 +92,7 @@ class HloRematerialization { StatusOr Run(HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence, int64 memory_limit, RematerializationSizes* sizes, - bool run_copy_elision); + CopyInsertion* copy_insertion); // Rematerializes instructions within the given computation. 'order' is the // order in which the computation's instructions will be emitted in the diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 7a46da6efe0df23129d56e16355cf66aceb68ffe..ac8c97d380953764b66135ad1c5fcee0d481c004 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -132,7 +132,7 @@ class HloRematerializationTest : public HloTestBase { builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); return builder.Build(); } @@ -147,7 +147,7 @@ class HloRematerializationTest : public HloTestBase { TF_EXPECT_OK(verifier().Run(module).status()); return HloRematerialization::RematerializeAndSchedule( ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler, - sequence, /*sizes=*/nullptr, /*run_copy_elision=*/false); + sequence, /*sizes=*/nullptr); } // Various shapes used in the canned computations. @@ -226,7 +226,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); HloComputation* while_cond = module->AddEmbeddedComputation(cond_builder.Build()); @@ -263,7 +263,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); HloComputation* while_cond = module->AddEmbeddedComputation(cond_builder.Build()); @@ -296,7 +296,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); HloComputation* while_cond = module->AddEmbeddedComputation(cond_builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index c6d3909af6103949daf4b0ab6be9b74724461e30..27cc5361cde2fa021b9489f98217ae5648afc2ad 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -567,6 +567,7 @@ StatusOr ScheduleComputationsInModule( sequence[computation] = std::move(one_computation_sequence); } } + VLOG(1) << "Module schedule:\n" << sequence; return sequence; } diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index 73f22f81f4e9cf597db8b184642acff2fdaaf2b0..9ec983c2bc353955cb23d441d200ac8aa36951b1 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -168,8 +168,9 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { auto cond_builder = HloComputation::Builder("WhileCond"); HloInstruction* cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "cond_param")); - HloInstruction* zero_vector = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{0, 0, 0, 0}}))); + HloInstruction* zero_vector = + cond_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{0, 0, 0, 0}}))); cond_builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); @@ -179,16 +180,18 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { auto body_builder = HloComputation::Builder("WhileBody"); HloInstruction* body_param = body_builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "body_param")); - HloInstruction* one_vector = body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + HloInstruction* one_vector = + body_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 1, 1, 1}}))); body_builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kSubtract, body_param, one_vector)); auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); // transpose(matrix) + bcast(while) auto builder = HloComputation::Builder(TestName()); - HloInstruction* while_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + HloInstruction* while_init = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 1, 1, 1}}))); // Creates 16 bytes, ignoring subcomputations HloInstruction* while_loop = builder.AddInstruction(HloInstruction::CreateWhile( @@ -199,7 +202,7 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { HloInstruction::CreateBroadcast(r2f32, while_loop, {0})); HloInstruction* matrix = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2( + HloInstruction::CreateConstant(LiteralUtil::CreateR2( {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}}))); // Creates 32 bytes HloInstruction* transpose = builder.AddInstruction( @@ -257,7 +260,7 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { // Wrap lit in abs because constants are considered free by // IgnoreInstruction, and it skews the accounting. auto lit = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1, 1, 1, 1, 1, 1}))); + LiteralUtil::CreateR1({1, 1, 1, 1, 1, 1}))); auto abs_const = builder.AddInstruction( HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, lit)); @@ -279,7 +282,7 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { TF_ASSERT_OK_AND_ASSIGN( SequentialHloOrdering::HloModuleSequence sequence, ScheduleComputationsInModule(*module, - [&TUPLE_SIZE](const BufferValue& buffer) { + [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf( buffer.shape(), TUPLE_SIZE); }, @@ -300,11 +303,11 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { HloComputation::Builder builder(TestName()); auto c1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1, 1, 1, 1, 1}))); + LiteralUtil::CreateR1({1, 1, 1, 1, 1}))); auto c2 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1, 2, 3, 4, 5}))); + LiteralUtil::CreateR1({1, 2, 3, 4, 5}))); auto c3 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({0, 2, 4, 6, 8}))); + LiteralUtil::CreateR1({0, 2, 4, 6, 8}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, c1, c2)); @@ -354,8 +357,9 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { auto cond_builder = HloComputation::Builder("WhileCond"); HloInstruction* cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "cond_param")); - HloInstruction* zero_vector = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{0, 0, 0, 0}}))); + HloInstruction* zero_vector = + cond_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{0, 0, 0, 0}}))); cond_builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); @@ -365,15 +369,17 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { auto body_builder = HloComputation::Builder("WhileBody"); HloInstruction* body_param = body_builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "body_param")); - HloInstruction* one_vector = body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + HloInstruction* one_vector = + body_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 1, 1, 1}}))); body_builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kSubtract, body_param, one_vector)); auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); auto builder = HloComputation::Builder(TestName()); - HloInstruction* while_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + HloInstruction* while_init = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1, 1, 1, 1}}))); // Creates 16 bytes, ignoring subcomputations builder.AddInstruction(HloInstruction::CreateWhile( r1f32, cond_computation, body_computation, while_init)); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 268b4727bcbed42ba71526f1d5ef5c887e941930..879fb3bbab2ada0f924282f16b3d9ccb4c2cb203 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -31,12 +31,9 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { CHECK_EQ(1, ShapeUtil::Rank(input_shape)); CHECK_GT(num_tiles, 1); std::vector dimensions(1, num_tiles); - Shape tile_shape = input_shape; - auto& tile_dimension = (*tile_shape.mutable_dimensions())[0]; - tile_dimension = CeilOfRatio(static_cast(tile_dimension), num_tiles); Array assignment(dimensions); std::iota(assignment.begin(), assignment.end(), 0); - return HloSharding(tile_shape, assignment); + return HloSharding(assignment); } HloSharding HloSharding::Tuple(const ShapeTree& sub_shardings) { @@ -60,6 +57,9 @@ HloSharding HloSharding::Tuple( const Shape& tuple_shape, tensorflow::gtl::ArraySlice shardings) { CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); + for (auto& sharding : shardings) { + CHECK(!sharding.IsTuple()) << sharding.ToString(); + } std::vector flattened_list(shardings.begin(), shardings.end()); CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape)) << "Flat list has " << flattened_list.size() << ", required " @@ -67,6 +67,24 @@ HloSharding HloSharding::Tuple( return HloSharding(flattened_list); } +HloSharding HloSharding::SingleTuple(const Shape& tuple_shape, + const HloSharding& sharding) { + CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); + CHECK(!sharding.IsTuple()) << sharding.ToString(); + int64 leaf_count = ShapeUtil::GetLeafCount(tuple_shape); + std::vector flattened_list; + flattened_list.reserve(leaf_count); + for (int64 i = 0; i < leaf_count; ++i) { + flattened_list.push_back(sharding); + } + return HloSharding(flattened_list); +} + +HloSharding HloSharding::Single(const Shape& shape, + const HloSharding& sharding) { + return ShapeUtil::IsTuple(shape) ? SingleTuple(shape, sharding) : sharding; +} + string HloSharding::ToString() const { if (IsTuple()) { std::vector parts; @@ -83,8 +101,7 @@ string HloSharding::ToString() const { return StrCat( "{maximal device=", static_cast(*tile_assignment_.begin()), "}"); } else { - return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ", "devices=[", - Join(tile_assignment_.dimensions(), ","), "]", + return StrCat("{devices=[", Join(tile_assignment_.dimensions(), ","), "]", Join(tile_assignment_, ","), "}"); } } @@ -106,15 +123,15 @@ std::map HloSharding::UsedDevices(int64* count) const { if (IsTuple()) { for (auto& tuple_element_sharding : tuple_elements()) { auto unique_device = tuple_element_sharding.UniqueDevice(); - if (unique_device.ok()) { - device_map[unique_device.ValueOrDie()] += 1; + if (unique_device) { + device_map[*unique_device] += 1; } } element_count = tuple_elements().size(); } else { auto unique_device = UniqueDevice(); - if (unique_device.ok()) { - device_map[unique_device.ValueOrDie()] += 1; + if (unique_device) { + device_map[*unique_device] += 1; } } if (count != nullptr) { @@ -124,7 +141,6 @@ std::map HloSharding::UsedDevices(int64* count) const { } std::vector HloSharding::TileIndexForDevice(int64 device) const { - CHECK(!ShapeUtil::IsTuple(tile_shape_)); CHECK(!maximal_); CHECK(!IsTuple()); std::vector ret_index; @@ -144,32 +160,43 @@ int64 HloSharding::DeviceForTileIndex( if (maximal_) { return *tile_assignment_.begin(); } - CHECK_EQ(ShapeUtil::Rank(tile_shape_), tile_assignment_.dimensions().size()); return tile_assignment_(index); } -std::vector HloSharding::TileOffsetForDevice(int64 device) const { +std::vector HloSharding::TileOffsetForDevice(const Shape& shape, + int64 device) const { CHECK(!IsTuple()); - std::vector index = TileIndexForDevice(device); if (maximal_) { - // Index will always be all zeroes if we're maximal, and tile_shape_ is not - // valid. - return index; + return std::vector(shape.dimensions_size(), 0); } + + CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions()); + std::vector index = TileIndexForDevice(device); for (int64 i = 0; i < index.size(); ++i) { - index[i] *= tile_shape_.dimensions(i); + const int64 shape_dim = shape.dimensions(i); + index[i] = std::min( + index[i] * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), shape_dim); } return index; } -std::vector HloSharding::TileLimitForDevice(int64 device) const { +std::vector HloSharding::TileLimitForDevice(const Shape& shape, + int64 device) const { CHECK(!IsTuple()); - CHECK(!maximal_); // Maximal shardings do not have a valid tile shape. + if (maximal_) { + return std::vector(shape.dimensions().begin(), + shape.dimensions().end()); + } + + CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions()); std::vector index = TileIndexForDevice(device); for (int64 i = 0; i < index.size(); ++i) { - index[i] = (index[i] + 1) * tile_shape_.dimensions(i); + const int64 shape_dim = shape.dimensions(i); + index[i] = std::min( + (index[i] + 1) * CeilOfRatio(shape_dim, tile_assignment_.dim(i)), + shape_dim); } return index; } @@ -217,40 +244,31 @@ StatusOr HloSharding::GetTupleSharding(const Shape& shape) const { return Tuple(ShapeTree(shape, *this)); } -StatusOr HloSharding::UniqueDevice() const { +tensorflow::gtl::optional HloSharding::UniqueDevice() const { if (IsTuple()) { if (tuple_elements_.empty()) { - return tensorflow::errors::InvalidArgument( - "UniqueDevice() called on empty tuple"); + return tensorflow::gtl::nullopt; } - std::vector> results; - std::transform(tuple_elements_.begin(), tuple_elements_.end(), - std::back_inserter(results), - [](const HloSharding& s) { return s.UniqueDevice(); }); - if (std::all_of(results.begin(), results.end(), - [&](const StatusOr& s) { - return s.ok() && results[0].ok() && - s.ValueOrDie() == results[0].ValueOrDie(); - })) { - return results[0]; - } else { - return tensorflow::errors::InvalidArgument( - "Tuple did not contain a unique device"); + tensorflow::gtl::optional unique_device; + for (auto& tuple_sharding : tuple_elements_) { + auto device = tuple_sharding.UniqueDevice(); + if (!device || (unique_device && *device != *unique_device)) { + return tensorflow::gtl::nullopt; + } + unique_device = device; } + return unique_device; } - if (!replicated_ && maximal_ && !IsTuple()) { + if (!replicated_ && maximal_) { return static_cast(*tile_assignment_.begin()); } - return tensorflow::errors::InvalidArgument( - "UniqueDevice() called on sharding that executes on multiple devices"); + return tensorflow::gtl::nullopt; } -bool HloSharding::HasUniqueDevice() const { - if (IsTuple()) { - return UniqueDevice().status().ok(); - } else { - return !IsReplicated() && IsTileMaximal(); - } +int64 HloSharding::GetUniqueDevice() const { + auto device = UniqueDevice(); + CHECK(device) << "Sharding does not have a unique device: " << *this; + return *device; } Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const { @@ -324,11 +342,12 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, return Status::OK(); } - // The tile rank must be the same as the input rank. - if (ShapeUtil::Rank(shape) != ShapeUtil::Rank(tile_shape_)) { + // The tile assignment tensor must have the same rank as the input. + if (ShapeUtil::Rank(shape) != tile_assignment_.num_dimensions()) { return tensorflow::errors::InvalidArgument( - "Tile rank is different to the input rank. sharding=", ToString(), - ", input_shape=", ShapeUtil::HumanString(shape)); + "Number of tile assignment dimensions is different to the input rank. " + "sharding=", + ToString(), ", input_shape=", ShapeUtil::HumanString(shape)); } // The correct constructor have to be used to create tile maximal shardings. @@ -338,20 +357,6 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, "sharding was intended, use HloSharding::Replicated(). If a device " "placement was intended, use HloSharding::AssignDevice()"); } - - // The tile assignment tensor must contain enough element to cover the full - // shape with tiles of the specified size. - for (int64 i = 0, e = tile_assignment_.dimensions().size(); i != e; ++i) { - int64 total_tile_size = tile_assignment_.dim(i) * tile_shape_.dimensions(i); - if (shape.dimensions(i) > total_tile_size) { - return tensorflow::errors::InvalidArgument( - StrCat("Tile assignment tensor has too few element to cover the full " - "shape. Dimension ", - i, ", shape ", shape.dimensions(i), ", total size ", - total_tile_size)); - } - } - return Status::OK(); } @@ -381,7 +386,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, proto.tile_assignment_dimensions().end())); std::copy(proto.tile_assignment_devices().begin(), proto.tile_assignment_devices().end(), tile_assignment.begin()); - return HloSharding(proto.tile_shape(), tile_assignment); + return HloSharding(tile_assignment); } OpSharding HloSharding::ToProto() const { @@ -395,7 +400,6 @@ OpSharding HloSharding::ToProto() const { return result; } - *result.mutable_tile_shape() = tile_shape_; for (int64 dim : tile_assignment_.dimensions()) { result.add_tile_assignment_dimensions(dim); } @@ -412,30 +416,16 @@ OpSharding HloSharding::ToProto() const { return result; } -HloSharding HloSharding::TransformShardedTileShape( - const Shape& new_shape, - const std::function& transform) const { - CHECK(!IsTuple()); +Shape HloSharding::TileShape(const Shape& shape) const { if (IsTileMaximal()) { - return *this; + return shape; } - CHECK_EQ(ShapeUtil::Rank(new_shape), ShapeUtil::Rank(tile_shape())); - Shape new_tile_shape; - new_tile_shape.set_element_type(tile_shape().element_type()); - for (int64 i = 0; i < ShapeUtil::Rank(new_shape); ++i) { - int64 dim; - if (tile_assignment().dim(i) == 1) { - dim = new_shape.dimensions(i); - } else if (transform) { - dim = transform(i, tile_shape().dimensions(i)); - } else { - dim = tile_shape().dimensions(i); - } - new_tile_shape.add_dimensions(dim); + Shape result_shape = shape; + for (int64 i = 0; i < shape.dimensions_size(); ++i) { + (*result_shape.mutable_dimensions())[i] = + CeilOfRatio(shape.dimensions(i), tile_assignment_.dim(i)); } - TF_CHECK_OK( - LayoutUtil::CopyLayoutBetweenShapes(tile_shape_, &new_tile_shape)); - return HloSharding::Tile(new_tile_shape, tile_assignment()); + return result_shape; } HloSharding HloSharding::GetSubSharding(const Shape& shape, @@ -477,9 +467,6 @@ size_t HloSharding::Hash() const { for (uint32 v : tile_assignment_) { h = tensorflow::Hash64Combine(h, std::hash{}(v)); } - for (uint32 v : tile_shape_.dimensions()) { - h = tensorflow::Hash64Combine(h, std::hash{}(v)); - } return h; } diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 34324d2058efe804cda486600dabd8a62cb84fda..894783e5d1538fa4e8e91b65827121f32040af83 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -24,7 +24,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -48,22 +48,10 @@ class HloSharding { // the input shape (one tile) assigned to a single device. static HloSharding AssignDevice(int64 device_id); - // Creates a new sharding which splits a shape into tiles each with shape - // `tile_shape`. Each tile is assigned to one device, which is specified by - // `tile_assignment`. Any tensor not a multiple of the tile size in any - // dimension is implicitly padded to the tile size. - // - // e.g. Tile({2, 2}, {0, 1}) on a tensor of shape {3, 2} would look like: - // 2 1 padding - // <------><-> - // +----+----+ - // | 0 | 1 | - // +----+----+ - // - // Split into two tiles, one of which is implicitly padded by one. - static HloSharding Tile(const Shape& tile_shape, - const Array& tile_assignment) { - return HloSharding(tile_shape, tile_assignment); + // Creates a new sharding which splits a shape into tiles amongst the devices + // specified by `tile_assignment`. + static HloSharding Tile(const Array& tile_assignment) { + return HloSharding(tile_assignment); } // Creates a new sharding which splits a one-dimensional input shape into @@ -80,6 +68,15 @@ class HloSharding { static HloSharding Tuple(const Shape& tuple_shape, tensorflow::gtl::ArraySlice shardings); + // Creates a new sharding for a tuple type, with a single input sharding + // repeated on each leaf. + static HloSharding SingleTuple(const Shape& tuple_shape, + const HloSharding& sharding); + + // If shape is an array, returns sharding, otherwise returns the tuple shaped + // sharding with all the leaf nodes having the same input sharding. + static HloSharding Single(const Shape& shape, const HloSharding& sharding); + // Create a new sharding from a protobuf OpSharding. static StatusOr FromProto(const OpSharding& proto); @@ -137,24 +134,30 @@ class HloSharding { // REQUIRES: !IsTuple() int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice index) const; - // Given a device ID, returns the offset within the input space of the + // Given a device ID, returns the offset within the specified shape of the // tile that should be executed on the given core. This returns the lower // extent of the tile in the input space. // REQUIRES: !IsTuple() - std::vector TileOffsetForDevice(int64 device) const; + std::vector TileOffsetForDevice(const Shape& shape, + int64 device) const; - // Given a device ID, returns the limit within the input space of the + // Given a device ID, returns the limit within the specified shape of the // tile that should be executed on the given core. This returns the upper // extent of the tile in the input space. // REQUIRES: !IsTuple() - std::vector TileLimitForDevice(int64 device) const; + std::vector TileLimitForDevice(const Shape& shape, int64 device) const; + + // Returns the single device this op operates on. If the sharding does not + // span a single device, the return value will be empty. + // In order for a sharding to span a single device, every leaf sharding must + // be maximal and not replicated, and the used device must match. + tensorflow::gtl::optional UniqueDevice() const; - // Returns the single device this op operates on. - // REQUIRES: !IsTuple&& !Replicated() && IsTileMaximal() - StatusOr UniqueDevice() const; + // Retrieves the unique device or fails with a CHECK. + int64 GetUniqueDevice() const; // Returns true if this op only uses a single device. - bool HasUniqueDevice() const; + bool HasUniqueDevice() const { return UniqueDevice().has_value(); } // Returns the ShapeTree containing the shardings for each element of this // tuple, if IsTuple, or a ShapeTree with a single element containing this @@ -183,7 +186,6 @@ class HloSharding { bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && - ShapeUtil::Compatible(tile_shape_, other.tile_shape_) && tile_assignment_ == other.tile_assignment_ && tuple_elements_ == other.tuple_elements_; } @@ -197,9 +199,6 @@ class HloSharding { } }; - // Gets the tile shape. - // REQUIRES: !IsTileMaximal() && !IsTuple() - const Shape& tile_shape() const { return tile_shape_; } // Gets the tile assignment tensor. // REQUIRES: !IsReplicated() && !IsTuple() const Array& tile_assignment() const { return tile_assignment_; } @@ -211,25 +210,15 @@ class HloSharding { return tuple_elements_; } - // Return a new sharding that can apply to the given new shape. - // If this sharding is tile-maximal, the returned sharding will be the same as - // this sharding. If this sharding is not tile-maximal, the returned - // sharding's tile size will differ: - // - Non-sharded dimensions will be adapted to be the same as `new_shape`; - // tile_dimension(i) = new_shape.dimensions(i); - // - Sharded dimensions will be kept the same unless `transform` is supplied - // in which case tile_dimension(i) = transform(i, tile_dimension(i)); - // REQUIRES: !IsTuple(). - HloSharding TransformShardedTileShape( - const Shape& new_shape, - const std::function& transform = nullptr) const; + // Gets the tile shape. + // REQUIRES: !IsTuple() + Shape TileShape(const Shape& shape) const; private: HloSharding() : replicated_(true), maximal_(true), tuple_(false), - tile_shape_(), tile_assignment_({0}) {} // device_id values: // -2: magic number to mean unassigned device, used by spatial partitioning @@ -241,15 +230,13 @@ class HloSharding { : replicated_(false), maximal_(true), tuple_(false), - tile_shape_(), tile_assignment_({1}, device_id) {} - HloSharding(const Shape& tile_shape, const Array& tile_assignment) + explicit HloSharding(const Array& tile_assignment) : replicated_(false), maximal_(false), tuple_(false), - tile_shape_(tile_shape), tile_assignment_(tile_assignment) {} - HloSharding(const std::vector& tuple_shardings) + explicit HloSharding(const std::vector& tuple_shardings) : replicated_(false), maximal_(false), tuple_(true), @@ -272,7 +259,6 @@ class HloSharding { bool replicated_; bool maximal_; bool tuple_; - Shape tile_shape_; Array tile_assignment_; // Only non-empty when tuple_ is true, but because empty tuples are allowed // may also be empty even then. This is a flattened list of all the leaf diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index 39036e205e76979e7da08246cd030ebd17e52f76..94f5a3b273b2fd7e545472c42f3863f549dd3db1 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -88,6 +88,12 @@ std::vector LocatePassThroughDomainLinks( VLOG(2) << " " << instruction->ToString(); } } + if (instruction == instruction->parent()->root_instruction()) { + pass_through.emplace_back(nullptr, instruction); + VLOG(2) << "Found passthrough domain link:"; + VLOG(2) << " "; + VLOG(2) << " " << instruction->ToString(); + } } return pass_through; } @@ -101,8 +107,12 @@ Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain, HloInstruction::CreateGetTupleElement(pass_through.operand->shape(), tuple, 0)); gte->set_sharding(sharding); - TF_RETURN_IF_ERROR( - pass_through.operand->ReplaceUseWith(pass_through.user, gte)); + if (pass_through.user != nullptr) { + TF_RETURN_IF_ERROR( + pass_through.operand->ReplaceUseWith(pass_through.user, gte)); + } else { + pass_through.operand->parent()->set_root_instruction(gte); + } } return Status::OK(); } @@ -235,21 +245,6 @@ StatusOr ApplyDomainShardingPass(const DomainMetadata::Domain& domain, Status ApplyDomainSharding(const DomainMetadata::Domain& domain, const HloSharding& sharding) { - // Here is the place to call external sharding normalizers, which are - // implemented in other modules (ie, spatial partitioning). - // The signature of the external normalizer function should be something - // like: - // - // StatusOr Normalizer(const DomainMetadata::Domain&, - // const HloSharding& sharding); - // - // The function should return true if it has processed the domain - // normalization, false if domain was not one recognized by it, or an error. - // We will call the functions in order below, and fall back to local code if - // none of the external normalizers acted on the domain. - // External normalizers should not handle the cases that are already handled - // locally. - // None of the external normalizers handled the domain sharding, try to see // whether this is a single sharding first. auto single_sharding = sharding.ExtractSingleSharding(); @@ -380,25 +375,36 @@ string ShardingMetadata::ToString() const { return sharding_ != nullptr ? sharding_->ToString() : "{}"; } -Status ShardingMetadata::NormalizeInstructions( - const DomainMetadata::Domain& domain) const { - if (sharding_ != nullptr) { - VLOG(4) << "Normalizing sharding to " << sharding_->ToString() << ":"; - TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding_)); - TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding_)); +/*static*/ StatusOr +ShardingMetadata::ToShardingMetadata(const DomainMetadata* metadata) { + if (metadata->Kind() != ShardingMetadata::KindName()) { + return Status( + tensorflow::error::INVALID_ARGUMENT, + "ShardingMetadata normalizer called with incorrect domain metadata"); } - return Status::OK(); + return static_cast(metadata); } -Status NormalizeShardingDomain(const DomainMetadata::Domain& domain) { - TF_ASSIGN_OR_RETURN(std::unique_ptr sharding, - ExtractOriginalCommonSharding(domain.instructions)); - if (sharding != nullptr) { - VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString() - << ":"; - TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding)); +Status ShardingMetadata::NormalizeShardingDomain( + const DomainMetadata::Domain& domain, const DomainMetadata* metadata) { + if (metadata != nullptr) { + TF_ASSIGN_OR_RETURN(const auto& sharding_metadata, + ToShardingMetadata(metadata)); + const HloSharding* sharding = sharding_metadata->sharding(); + if (sharding != nullptr) { + VLOG(4) << "Normalizing sharding to " << sharding->ToString() << ":"; + TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding)); + TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding)); + } } else { - VLOG(1) << "Unable to find common sharding"; + TF_ASSIGN_OR_RETURN(std::unique_ptr sharding, + ExtractOriginalCommonSharding(domain.instructions)); + if (sharding != nullptr) { + VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString(); + TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding)); + } else { + VLOG(1) << "Unable to find common sharding"; + } } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h index ec162c34904ee2dfac3daeeee37133282a9c9698..5e01fc0e22ae8f3421c2cb5790adf44b1200a804 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h @@ -38,23 +38,26 @@ class ShardingMetadata : public DomainMetadata { string ToString() const override; - Status NormalizeInstructions( - const DomainMetadata::Domain& domain) const override; + const HloSharding* sharding() const { return sharding_.get(); } static tensorflow::StringPiece KindName() { return "sharding"; } + static StatusOr ToShardingMetadata( + const DomainMetadata* metadata); + + // Apply the specified domain metadata onto the specified domain. If no + // metadata is specified then apply sharding heuristics and normalize the + // instructions whose sharding deviates from the one which is inferred as to + // be the original one. Policy wise, HLO passes are allowed to create new + // unassigned instructions, but if they do create assigned ones, they have to + // conform to the ones around. + static Status NormalizeShardingDomain(const DomainMetadata::Domain& domain, + const DomainMetadata* metadata); + private: std::unique_ptr sharding_; }; -// Within a set of instructions which had common sharding attributes before -// entring the HLO passes pipeline, apply sharding heuristics and normalize the -// instructions whose sharding deviates from the one which is inferred as to be -// the original one. -// Policy wise, HLO passes are allowed to create new unassigned instructions, -// but if they do create assigned ones, they have to conform to the ones around. -Status NormalizeShardingDomain(const DomainMetadata::Domain& domain); - // Given an HLO graph edge between instruction and one of its operands, creates // a ShardingMetadata based kDomain instruction if the sharding between // instruction and operand changes. Returns nullptr if there is no need for a diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 54b7402b866361748d9eb35182b0bf486c4c9bdc..45fc300fcaf5a301fe11768da77a7c0907919c39 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -39,7 +39,6 @@ Array MakeArray(tensorflow::gtl::ArraySlice dimensions, class HloShardingTest : public HloTestBase {}; TEST_F(HloShardingTest, Replicate) { - Shape tile_shape = ShapeUtil::MakeShape(U32, {4}); HloSharding sharding = HloSharding::Replicate(); EXPECT_TRUE(sharding.IsReplicated()); EXPECT_TRUE(sharding.IsTileMaximal()); @@ -51,7 +50,7 @@ TEST_F(HloShardingTest, Replicate) { EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4}), /*num_devices=*/2)); - EXPECT_IS_NOT_OK(sharding.UniqueDevice()); + EXPECT_FALSE(sharding.HasUniqueDevice()); } TEST_F(HloShardingTest, DevicePlacement) { @@ -60,7 +59,7 @@ TEST_F(HloShardingTest, DevicePlacement) { EXPECT_TRUE(sharding.IsTileMaximal()); EXPECT_FALSE(sharding.UsesDevice(0)); EXPECT_TRUE(sharding.UsesDevice(5)); - EXPECT_EQ(5, sharding.UniqueDevice().ValueOrDie()); + EXPECT_EQ(5, sharding.GetUniqueDevice()); HloSharding other = HloSharding::Replicate(); EXPECT_NE(other, sharding); @@ -79,37 +78,22 @@ TEST_F(HloShardingTest, DevicePlacement) { TEST_F(HloShardingTest, Tile) { { // Test should fail because of a duplicate tile assignment. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 0, 2, 3})); + HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 0, 2, 3})); EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {4, 6}), /*num_devices=*/4)); } { // Test should fail because of more devices used then `num_device`. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); + HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 1, 2, 3})); EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4, 6}), /*num_devices=*/2)); } - { - // Test should fail because the total tiled size in dimension 0 is 4 but we - // have 6 elements along that dimensions. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); - EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {6, 3}), - /*num_devices=*/4)); - } - { // Test should pass. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); + Shape shape = ShapeUtil::MakeShape(U32, {4, 5}); + HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1})); EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {3, 5}), /*num_devices=*/5)); @@ -118,12 +102,16 @@ TEST_F(HloShardingTest, Tile) { EXPECT_EQ(2, sharding.DeviceForTileIndex({1, 0})); EXPECT_EQ(1, sharding.DeviceForTileIndex({1, 1})); - EXPECT_EQ(sharding.TileOffsetForDevice(0), (std::vector{0, 0})); - EXPECT_EQ(sharding.TileOffsetForDevice(3), (std::vector{0, 3})); - EXPECT_EQ(sharding.TileOffsetForDevice(2), (std::vector{2, 0})); - EXPECT_EQ(sharding.TileOffsetForDevice(1), (std::vector{2, 3})); + EXPECT_EQ(sharding.TileOffsetForDevice(shape, 0), + (std::vector{0, 0})); + EXPECT_EQ(sharding.TileOffsetForDevice(shape, 3), + (std::vector{0, 3})); + EXPECT_EQ(sharding.TileOffsetForDevice(shape, 2), + (std::vector{2, 0})); + EXPECT_EQ(sharding.TileOffsetForDevice(shape, 1), + (std::vector{2, 3})); - EXPECT_IS_NOT_OK(sharding.UniqueDevice()); + EXPECT_FALSE(sharding.HasUniqueDevice()); } } @@ -135,8 +123,7 @@ TEST_F(HloShardingTest, NestedTuple) { ShapeUtil::MakeShape(F32, {4, 6}), }); - HloSharding tiled_sharding = HloSharding::Tile( - ShapeUtil::MakeShape(F32, {4, 3}), Array({{0, 1}})); + HloSharding tiled_sharding = HloSharding::Tile(Array({{0, 1}})); OpSharding proto; proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE); *proto.add_tuple_shardings() = HloSharding::Replicate().ToProto(); @@ -187,32 +174,11 @@ TEST_F(HloShardingTest, Hash) { } { - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding1 = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); - HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}), - MakeArray({2, 2}, {0, 3, 2, 1})); - EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); - } - - { - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding1 = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); - HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}), - MakeArray({2, 2}, {0, 3, 2, 1})); + HloSharding sharding1 = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1})); + HloSharding sharding2 = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1})); EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); } - { - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding1 = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); - HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}), - MakeArray({2, 2}, {0, 3, 1, 2})); - EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); - } - HloSharding default_sharding = HloSharding::Replicate(); { ShapeTree shape_tree(ShapeUtil::MakeTupleShape({}), @@ -259,19 +225,6 @@ TEST_F(HloShardingTest, Hash) { } } -TEST_F(HloShardingTest, TransformShardedTileShapeTest) { - HloSharding sharding = - HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}), - Array4D({{{{0, 1}, {2, 3}}}})); - HloSharding result = sharding.TransformShardedTileShape( - ShapeUtil::MakeShape(F32, {13, 15, 17, 19}), - [](int dim, int value) { return dim * 111; }); - HloSharding expected = - HloSharding::Tile(ShapeUtil::MakeShape(F32, {13, 15, 222, 333}), - Array4D({{{{0, 1}, {2, 3}}}})); - EXPECT_EQ(result, expected); -} - TEST_F(HloShardingTest, ToStringReplicatedTest) { HloSharding sharding = HloSharding::Replicate(); EXPECT_EQ(sharding.ToString(), "{replicated}"); @@ -284,9 +237,8 @@ TEST_F(HloShardingTest, ToStringAssignDeviceTest) { TEST_F(HloShardingTest, ToStringTiledTest) { HloSharding sharding = - HloSharding::Tile(ShapeUtil::MakeShape(S32, {7, 11, 13}), - Array3D({{{2, 3}}, {{5, 7}}})); - EXPECT_EQ(sharding.ToString(), "{s32[7,11,13] devices=[2,1,2]2,3,5,7}"); + HloSharding::Tile(Array3D({{{2, 3}}, {{5, 7}}})); + EXPECT_EQ(sharding.ToString(), "{devices=[2,1,2]2,3,5,7}"); } TEST_F(HloShardingTest, ToStringTupleTest) { @@ -294,21 +246,18 @@ TEST_F(HloShardingTest, ToStringTupleTest) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5}), ShapeUtil::MakeShape(U32, {7, 25}), ShapeUtil::MakeShape(S32, {9, 11})}), - {HloSharding::Replicate(), - HloSharding::Tile(ShapeUtil::MakeShape(U32, {7, 13}), - Array2D({{3, 5}})), + {HloSharding::Replicate(), HloSharding::Tile(Array2D({{3, 5}})), HloSharding::AssignDevice(3)}); EXPECT_EQ(sharding.ToString(), - "{{replicated}, {u32[7,13] devices=[1,2]3,5}, {maximal device=3}}"); + "{{replicated}, {devices=[1,2]3,5}, {maximal device=3}}"); } TEST_F(HloShardingTest, OstreamTest) { HloSharding sharding = - HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}), - Array4D({{{{0, 1}, {2, 3}}}})); + HloSharding::Tile(Array4D({{{{0, 1}, {2, 3}}}})); std::ostringstream oss; oss << sharding; - EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}"); + EXPECT_EQ(oss.str(), "{devices=[1,1,2,2]0,1,2,3}"); } TEST_F(HloShardingTest, ParseHloString) { @@ -319,8 +268,7 @@ TEST_F(HloShardingTest, ParseHloString) { }; check(HloSharding::Replicate()); check(HloSharding::AssignDevice(2)); - check(HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), - Array4D({{{{0}, {1}}}}))); + check(HloSharding::Tile(Array4D({{{{0}, {1}}}}))); // Empty tuple. One sharding is required for empty tuples, as we need to be // able to assign sharding to them, even though they have no leaves. check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}), @@ -332,8 +280,7 @@ TEST_F(HloShardingTest, ParseHloString) { ShapeUtil::MakeShape(F32, {3, 5, 7}), ShapeUtil::MakeShape(F32, {3, 7})}); check(HloSharding::Tuple( - tuple_shape, {HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), - Array4D({{{{0}, {1}}}})), + tuple_shape, {HloSharding::Tile(Array4D({{{{0}, {1}}}})), HloSharding::Replicate(), HloSharding::AssignDevice(1)})); } { @@ -343,8 +290,7 @@ TEST_F(HloShardingTest, ParseHloString) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5, 7}), ShapeUtil::MakeShape(F32, {3, 7})})}); std::vector leaf_shardings = { - HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), - Array4D({{{{0}, {1}}}})), + HloSharding::Tile(Array4D({{{{0}, {1}}}})), HloSharding::Replicate(), HloSharding::AssignDevice(1)}; ShapeTree sharding_tree(tuple_shape, HloSharding::Replicate()); // Assign leaf_shardings to sharding_tree leaves. diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc index 7b601f9a9578cfa6b293cf7f002255f7db8b1257..45c684d66752862eec301b8943d350804f070309 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc @@ -75,7 +75,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { module->AddEmbeddedComputation(CreateR0S32IdentityComputation()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); auto x = builder.AddInstruction( HloInstruction::CreateCall(r0s32_, {constant}, callee1)); auto y = builder.AddInstruction( @@ -112,9 +112,9 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { module->AddEmbeddedComputation(CreateR0S32AdditionComputation()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(5))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3))); auto x = builder.AddInstruction( HloInstruction::CreateCall(r0s32_, {constant1, constant2}, callee1)); auto y = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 3dc733940fc89952bd5e75a9b28d9cbf356f8000..b78bfa0cdf4db605576fa11e18ce6c654c6a0b6d 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -101,11 +101,11 @@ const string& HloTfGraphBuilder::GetNodeNameForInstruction( } }; string node_name; - if (debug_options_.xla_hlo_tfgraph_device_scopes() && - instruction->has_sharding() && - instruction->sharding().HasUniqueDevice()) { - node_name = StrCat( - "dev", instruction->sharding().UniqueDevice().ConsumeValueOrDie()); + if (debug_options_.xla_hlo_tfgraph_device_scopes()) { + auto device = instruction->sharding_unique_device(); + if (device) { + node_name = StrCat("dev", *device); + } } // If an instruction is fused, put it in the subgraph of the fusion; // otherwise, put it in the computation subgraph. @@ -215,10 +215,10 @@ Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) { NodeDef* node_def = graph_def_.add_node(); node_def->set_name(GetNodeNameForInstruction(instruction)); node_def->set_op(GetOpDefName(instruction)); - if (instruction->has_sharding() && - instruction->sharding().HasUniqueDevice()) { - TF_ASSIGN_OR_RETURN(int64 device, instruction->sharding().UniqueDevice()); - node_def->set_device(GetDeviceName(device)); + + auto device = instruction->sharding_unique_device(); + if (device) { + node_def->set_device(GetDeviceName(*device)); } SetNodeAttrs(instruction, node_def); if (instruction->opcode() == HloOpcode::kFusion) { diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index be156d765dc10d54eaf301e90883babbc5693e28..1e2b31a1f2bb4865faafc3d14e2b194e3aa171a1 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -90,7 +90,7 @@ TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) { TEST_F(HloTfGraphBuilderTest, CheckScalarValue) { auto builder = HloComputation::Builder("Const"); HloInstruction *instruction = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(123))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); OpMetadata metadata; metadata.set_op_name("x"); metadata.set_op_type("y"); diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 7b27dbfec376b8ba16d00285f10e2cc291e07a61..7fd99fc93050b386c5ad24e6dcd2fea1bf652c3f 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -125,7 +125,7 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, // transparently. CHECK_EQ(operand_number, 0); return index.empty(); - case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: // Select does not use any nested elements of its selected-from operands // (operand 1 and 2) CHECK_GE(operand_number, 0); @@ -283,8 +283,7 @@ std::ostream& operator<<(std::ostream& out, string InstructionValueSet::ToString() const { string out = StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n"); - ForEachElement([this, &out](const ShapeIndex& index, - const HloValueSet& value_set) { + ForEachElement([&out](const ShapeIndex& index, const HloValueSet& value_set) { StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n"); }); return out; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index fb39c6f085339a65ffdffc8e2757fe5a259a02b0..3fae61f704ae73bb66b151c12cf7d900ffe42f49 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -41,6 +41,10 @@ Status ShapeVerifier::HandleSelect(HloInstruction* select) { return CheckTernaryShape(select); } +Status ShapeVerifier::HandleTupleSelect(HloInstruction* tuple_select) { + return CheckTernaryShape(tuple_select); +} + Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) { std::vector operand_shapes; for (const HloInstruction* operand : concatenate->operands()) { @@ -101,6 +105,15 @@ Status ShapeVerifier::HandleCrossReplicaSum(HloInstruction* crs) { ShapeInference::InferCrossReplicaSumShape(operand_shapes)); } +Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { + std::vector operand_shapes; + for (const HloInstruction* operand : hlo->operands()) { + operand_shapes.push_back(&operand->shape()); + } + return CheckShape(hlo, + ShapeInference::InferAllToAllTupleShape(operand_shapes)); +} + Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( reduce_precision->operand(0)->shape(), @@ -108,17 +121,45 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { reduce_precision->mantissa_bits())); } +namespace { + +Status CheckIsTokenOperand(const HloInstruction* instruction, + int64 operand_no) { + const HloInstruction* token = instruction->operand(operand_no); + if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) { + return InternalError( + "Expected operand %lld to be token-shaped, actual shape is " + "%s:\n%s", + operand_no, ShapeUtil::HumanString(token->shape()).c_str(), + instruction->ToString().c_str()); + } + return Status::OK(); +} + +Status CheckOperandAndParameter(const HloInstruction* instruction, + int64 operand_number, + const HloComputation* computation, + int64 parameter_number) { + const HloInstruction* operand = instruction->operand(operand_number); + const HloInstruction* parameter = + computation->parameter_instruction(parameter_number); + if (!ShapeUtil::Compatible(operand->shape(), parameter->shape())) { + return InternalError("Operand %s shape does not match parameter's %s in %s", + operand->ToString().c_str(), + parameter->ToString().c_str(), + instruction->ToString().c_str()); + } + return Status::OK(); +} + +} // namespace + Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { HloInfeedInstruction* infeed = Cast(instruction); // Infeed has an optional single token operand. // TODO(b/80000000): Update when token is not optional. - if (infeed->operand_count() == 1 && - !ShapeUtil::Equal(infeed->operand(0)->shape(), - ShapeUtil::MakeTokenShape())) { - return InternalError( - "Expected infeed operand to be token-shaped, actual shape is %s:\n%s", - ShapeUtil::HumanString(infeed->operand(0)->shape()).c_str(), - infeed->ToString().c_str()); + if (infeed->operand_count() == 1) { + TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); } // The output of infeed is a tuple containing the data value and a token. @@ -131,13 +172,8 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { HloOutfeedInstruction* outfeed = Cast(instruction); // Outfeed has an optional token operand (operand 1). // TODO(b/80000000): Update when token is not optional. - if (outfeed->operand_count() == 2 && - !ShapeUtil::Equal(outfeed->operand(1)->shape(), - ShapeUtil::MakeTokenShape())) { - return InternalError( - "Expected operand 1 of outfeed to be a token, actual shape is %s:\n%s", - ShapeUtil::HumanString(outfeed->operand(1)->shape()).c_str(), - outfeed->ToString().c_str()); + if (outfeed->operand_count() == 2) { + TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1)); } // Outfeed has a separate shape field for the value which is outfed to the @@ -167,13 +203,28 @@ Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { } Status ShapeVerifier::HandleSort(HloInstruction* sort) { - return CheckUnaryShape(sort); + if (sort->operand_count() == 2 && + !ShapeUtil::SameDimensions(sort->operand(0)->shape(), + sort->operand(1)->shape())) { + return InternalError( + "Expected sort to have to have the same dimensions for the keys and " + "the values. Keys shape is: %s\n, Values shape is: %s", + ShapeUtil::HumanString(sort->operand(0)->shape()).c_str(), + ShapeUtil::HumanString(sort->operand(1)->shape()).c_str()); + } + return CheckVariadicShape(sort); } Status ShapeVerifier::HandleConstant(HloInstruction* constant) { return CheckShape(constant, constant->literal().shape()); } +Status ShapeVerifier::HandleIota(HloInstruction* iota) { + return ShapeUtil::Rank(iota->shape()) == 1 + ? Status::OK() + : InternalError("Iota only supports arrays of rank 1."); +} + Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { return CheckShape(get_tuple_element, ShapeInference::InferGetTupleElementShape( @@ -182,10 +233,13 @@ Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { } Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { + if (!ShapeUtil::IsArray(reduce->shape())) { + return InvalidArgument("Variadic reduce is not supported."); + } return CheckShape( reduce, ShapeInference::InferReduceShape( - reduce->operand(0)->shape(), reduce->operand(1)->shape(), + {&reduce->operand(0)->shape(), &reduce->operand(1)->shape()}, reduce->dimensions(), reduce->to_apply()->ComputeProgramShape())); } @@ -233,8 +287,11 @@ Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleCall(HloInstruction* call) { + for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) { + TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i)); + } // The shape of kCall should match the shape of the computation it calls. - return CheckShape(call, call->to_apply()->ComputeProgramShape().result()); + return CheckShape(call, call->to_apply()->root_instruction()->shape()); } Status ShapeVerifier::HandleCustomCall(HloInstruction*) { return Status::OK(); } @@ -303,19 +360,37 @@ Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) { } Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { + TF_RETURN_IF_ERROR( + CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0)); + TF_RETURN_IF_ERROR( + CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0)); + const Shape& conditional_shape = + xla_while->while_condition()->root_instruction()->shape(); + if (!ShapeUtil::Compatible(conditional_shape, + ShapeUtil::MakeShape(PRED, {}))) { + return InternalError( + "Conditional computation shape does not lead to a scalar predicate " + "shape: %s", + ShapeUtil::HumanString(conditional_shape).c_str()); + } // The shape of kWhile should match the shape of the body computation it // calls. return CheckShape(xla_while, - xla_while->while_body()->ComputeProgramShape().result()); + xla_while->while_body()->root_instruction()->shape()); } Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { + TF_RETURN_IF_ERROR(CheckOperandAndParameter( + conditional, 1, conditional->true_computation(), 0)); + TF_RETURN_IF_ERROR(CheckOperandAndParameter( + conditional, 2, conditional->false_computation(), 0)); + TF_RETURN_IF_ERROR( + CheckShape(conditional, + conditional->true_computation()->root_instruction()->shape())); TF_RETURN_IF_ERROR(CheckShape( conditional, - conditional->true_computation()->ComputeProgramShape().result())); - return CheckShape( - conditional, - conditional->false_computation()->ComputeProgramShape().result()); + conditional->false_computation()->root_instruction()->shape())); + return Status::OK(); } Status ShapeVerifier::HandlePad(HloInstruction* pad) { @@ -325,39 +400,29 @@ Status ShapeVerifier::HandlePad(HloInstruction* pad) { } Status ShapeVerifier::HandleSend(HloInstruction* send) { - TF_RET_CHECK(send->users().size() == 1); - const HloInstruction* send_done = send->users().front(); - TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); - TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); - return CheckShape( - send, ShapeUtil::MakeTupleShape( - {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})})); + return CheckShape(send, + ShapeUtil::MakeTupleShape({send->operand(0)->shape(), + ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeTokenShape()})); } Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { - TF_RET_CHECK(send_done->operands().size() == 1); - const HloInstruction* send = send_done->operand(0); - TF_RET_CHECK(send->opcode() == HloOpcode::kSend); - TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); - return CheckShape(send_done, ShapeUtil::MakeNil()); + return CheckShape(send_done, ShapeUtil::MakeTokenShape()); } Status ShapeVerifier::HandleRecv(HloInstruction* recv) { - TF_RET_CHECK(recv->users().size() == 1); - const HloInstruction* recv_done = recv->users().front(); - TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); - TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); - return CheckShape(recv, - ShapeUtil::MakeTupleShape( - {recv_done->shape(), ShapeUtil::MakeShape(U32, {})})); + return CheckShape( + recv, ShapeUtil::MakeTupleShape( + {ShapeUtil::GetTupleElementShape(recv->shape(), 0), + ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})); } Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { - TF_RET_CHECK(recv_done->operands().size() == 1); - const HloInstruction* recv = recv_done->operand(0); - TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv); - TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); - return CheckShape(recv_done, recv->shape().tuple_shapes(0)); + return CheckShape( + recv_done, + ShapeUtil::MakeTupleShape( + {ShapeUtil::GetTupleElementShape(recv_done->operand(0)->shape(), 0), + ShapeUtil::MakeTokenShape()})); } Status ShapeVerifier::HandleBatchNormTraining( @@ -416,6 +481,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kRecvDone: case HloOpcode::kReducePrecision: case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: case HloOpcode::kSend: case HloOpcode::kSendDone: case HloOpcode::kTuple: @@ -456,6 +522,15 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { gather->gather_dimension_numbers(), gather->gather_window_bounds())); } +Status ShapeVerifier::HandleScatter(HloInstruction* scatter) { + return CheckShape( + scatter, ShapeInference::InferScatterShape( + scatter->operand(0)->shape(), scatter->operand(1)->shape(), + scatter->operand(2)->shape(), + scatter->to_apply()->ComputeProgramShape(), + scatter->scatter_dimension_numbers())); +} + Status ShapeVerifier::HandleAfterAll(HloInstruction* token) { std::vector operand_shapes; for (const HloInstruction* operand : token->operands()) { @@ -478,16 +553,10 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, // We treat BF16 and F32 as compatible types if mixed precision is allowed, // but only when the instruction defines the BF16/F32 buffer. switch (instruction->opcode()) { - case HloOpcode::kSelect: - if (ShapeUtil::IsTuple(inferred_shape) || !allow_mixed_precision_) { - // Select only defines the top-level buffer, which in this case is the - // tuple, so we cannot allow mixed precision. - compatible = - ShapeUtil::Compatible(instruction->shape(), inferred_shape); - } else { - compatible = ShapeUtil::CompatibleIgnoringFpPrecision( - instruction->shape(), inferred_shape); - } + case HloOpcode::kTupleSelect: + // TupleSelect only defines the top-level buffer, which in this case is + // the tuple, so we cannot allow mixed precision. + compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape); break; case HloOpcode::kGetTupleElement: case HloOpcode::kTuple: @@ -568,19 +637,6 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) { instruction->opcode(), instruction->operands())); } -// Checks if the given two instructions shares the same channel id. -Status ShapeVerifier::CheckSameChannel(const HloInstruction* instr1, - const HloInstruction* instr2) { - if (instr1->channel_id() != instr2->channel_id()) { - return InternalError( - "Expected to have the same channel id, actual channel ids are: %s " - "(%lld), %s (%lld)", - instr1->ToString().c_str(), instr1->channel_id(), - instr2->ToString().c_str(), instr2->channel_id()); - } - return Status::OK(); -} - string ComputationsToString( tensorflow::gtl::ArraySlice computations) { return tensorflow::str_util::Join( @@ -780,33 +836,23 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { "While loop must have exactly one operand; had %lld : %s", instruction->operand_count(), instruction->ToString().c_str()); } - auto* init = instruction->operand(0); - auto* cond_param = while_cond->parameter_instruction(0); - if (!ShapeUtil::Compatible(init->shape(), cond_param->shape())) { - return FailedPrecondition( - "While condition's parameter must have the same shape as the " - "loop's 'init'. init: %s, param: %s", - init->ToString().c_str(), cond_param->ToString().c_str()); - } - auto* cond_root = while_cond->root_instruction(); - if (!ShapeUtil::Compatible(cond_root->shape(), - ShapeUtil::MakeShape(PRED, {}))) { - return FailedPrecondition("While condition should have shape PRED: %s", - cond_root->ToString().c_str()); - } - auto* body_param = while_body->parameter_instruction(0); - if (!ShapeUtil::Compatible(init->shape(), body_param->shape())) { + return Status::OK(); +} + +Status HloVerifier::CheckConditionalInstruction(HloInstruction* instruction) { + if (instruction->true_computation()->num_parameters() != 1) { return FailedPrecondition( - "While body's parameter must have the same shape as the loop's" - " 'init'. init: %s, param: %s", - init->ToString().c_str(), body_param->ToString().c_str()); + "True computation %s of %s must have 1 parameter insted of %lld", + instruction->true_computation()->name().c_str(), + instruction->ToString().c_str(), + instruction->true_computation()->num_parameters()); } - auto* body_root = while_body->root_instruction(); - if (!ShapeUtil::Compatible(init->shape(), body_root->shape())) { + if (instruction->false_computation()->num_parameters() != 1) { return FailedPrecondition( - "While body should have same shape as the loop's 'init'." - "init: %s, body: %s", - init->ToString().c_str(), body_root->ToString().c_str()); + "False computation %s of %s must have 1 parameter insted of %lld", + instruction->false_computation()->name().c_str(), + instruction->ToString().c_str(), + instruction->false_computation()->num_parameters()); } return Status::OK(); } @@ -859,10 +905,105 @@ Status VerifyEntryAndExitShapes(const HloModule& module) { return Status::OK(); } +// Checks if the given two instructions share the same channel id. +Status CheckSameChannel(const HloInstruction* instr1, + const HloInstruction* instr2) { + if (instr1->channel_id() != instr2->channel_id()) { + return InternalError( + "Expected to have the same channel id, actual channel ids are: %s " + "(%lld), %s (%lld)", + instr1->ToString().c_str(), instr1->channel_id(), + instr2->ToString().c_str(), instr2->channel_id()); + } + return Status::OK(); +} + +// Checks if the given two instructions have the same is_host_transfer attribute +// value. Intsructions must be send/recv instructions or their 'done' variant. +Status CheckSameIsHostTransfer(const HloInstruction* instr1, + const HloInstruction* instr2) { + const HloSendRecvInstruction* send_recv1 = + DynCast(instr1); + const HloSendRecvInstruction* send_recv2 = + DynCast(instr2); + TF_RET_CHECK(send_recv1 != nullptr); + TF_RET_CHECK(send_recv2 != nullptr); + if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) { + return InternalError( + "Expected instructions to have the same is-host-transfer property: %s, " + "%s ", + instr1->ToString().c_str(), instr2->ToString().c_str()); + } + return Status::OK(); +} + +// Checks various invariants of send and recv instructions. +Status VerifySendsAndRecvs(const HloModule& module) { + tensorflow::gtl::FlatMap host_channels; + // Host send/recv instructions must have their own unique channel. + auto check_unique_host_channel = [&](const HloInstruction* instruction) { + const HloSendRecvInstruction* sendrecv = + DynCast(instruction); + if (sendrecv->is_host_transfer()) { + auto it_inserted = + host_channels.insert({sendrecv->channel_id(), sendrecv}); + if (!it_inserted.second) { + return FailedPrecondition( + "Channel %lld is used for multiple host send/recv instructions: %s " + "and " + "%s", + sendrecv->channel_id(), sendrecv->ToString().c_str(), + it_inserted.first->second->ToString().c_str()); + } + } + + return Status::OK(); + }; + + // Send/Recv instruction must have a single user: the corresponding + // SendDone/RecvDone. with matching channel. + for (const HloComputation* computation : module.computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + switch (instruction->opcode()) { + case HloOpcode::kSend: { + TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); + TF_RET_CHECK(instruction->users().size() == 1); + const HloInstruction* send_done = instruction->users().front(); + TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); + TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done)); + TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done)); + break; + } + case HloOpcode::kRecv: { + TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); + TF_RET_CHECK(instruction->users().size() == 1); + const HloInstruction* recv_done = instruction->users().front(); + TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); + TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done)); + TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done)); + break; + } + case HloOpcode::kSendDone: + TF_RET_CHECK(instruction->operands().size() == 1); + TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend); + break; + case HloOpcode::kRecvDone: + TF_RET_CHECK(instruction->operands().size() == 1); + TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv); + break; + default: + break; + } + } + } + return Status::OK(); +} + } // namespace StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(VerifyHloStructure(module)); + TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module)); tensorflow::gtl::FlatMap instructions; @@ -902,6 +1043,8 @@ StatusOr HloVerifier::Run(HloModule* module) { << " != " << ShapeUtil::Rank(instruction->operand(0)->shape()); } else if (instruction->opcode() == HloOpcode::kWhile) { TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction)); + } else if (instruction->opcode() == HloOpcode::kConditional) { + TF_RETURN_IF_ERROR(CheckConditionalInstruction(instruction)); } else if (instruction->opcode() != HloOpcode::kRng /* Rng operands are always scalar. */ && instruction->IsElementwise()) { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index da6b5d222206fe9bfcbf5157dc524ed46edaaac7..5a56a44f355a7b8c43d22433404d80f672024a55 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -35,7 +35,9 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleElementwiseBinary(HloInstruction* hlo) override; Status HandleClamp(HloInstruction* clamp) override; Status HandleSelect(HloInstruction* select) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleConcatenate(HloInstruction* concatenate) override; + Status HandleIota(HloInstruction* iota) override; Status HandleConvert(HloInstruction* convert) override; Status HandleBitcastConvert(HloInstruction* convert) override; Status HandleCopy(HloInstruction* copy) override; @@ -43,6 +45,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleAllToAll(HloInstruction* hlo) override; Status HandleReducePrecision(HloInstruction* reduce_precision) override; Status HandleInfeed(HloInstruction*) override; Status HandleOutfeed(HloInstruction*) override; @@ -81,6 +84,7 @@ class ShapeVerifier : public DfsHloVisitor { HloInstruction* batch_norm_inference) override; Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleGather(HloInstruction* gather) override; + Status HandleScatter(HloInstruction* scatter) override; Status HandleAfterAll(HloInstruction* token) override; Status FinishVisit(HloInstruction*) override { return Status::OK(); } @@ -101,10 +105,6 @@ class ShapeVerifier : public DfsHloVisitor { Status CheckTernaryShape(const HloInstruction* instruction); Status CheckVariadicShape(const HloInstruction* instruction); - // Checks if the given two instructions share the same channel id. - Status CheckSameChannel(const HloInstruction* instr1, - const HloInstruction* instr2); - private: // Whether the inputs and output of an instruction can contain both F32s and // BF16s. Tuples that include both F32s and BF16s are allowed regardless of @@ -145,6 +145,8 @@ class HloVerifier : public HloPassInterface { Status CheckWhileInstruction(HloInstruction* instruction); + Status CheckConditionalInstruction(HloInstruction* instruction); + // Checks that the non-scalar operand shapes are compatible to the output // shape, i.e., that there are no implicit broadcasts of size-one dimensions. Status CheckElementwiseInstruction(HloInstruction* instruction); diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index c92db0be14dceb32ea86521dcc99b8f63738e4a5..04c6ba3eeb92bad2b5b69f7f56e73e1f7a8148aa 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -123,5 +124,55 @@ TEST_F(HloVerifierTest, ResetsShapeVerifierState) { EXPECT_FALSE(verifier().Run(module.get()).status().ok()); } +TEST_F(HloVerifierTest, CheckCallOperandParameterShapesMismatch) { + const char* const hlo_string = R"( +HloModule Module + +callme { + ROOT param = (s32[], f32[4]) parameter(0) +} + +ENTRY entry { + p0 = (f32[4], s32[]) parameter(0) + ROOT mycall = (s32[], f32[4]) call(p0), to_apply=callme +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("shape does not match parameter")); +} + +TEST_F(HloVerifierTest, CheckConditionalOperandParameterShapesMismatch) { + const char* const hlo_string = R"( +HloModule Module + +true_branch { + tparam = (s32[], f32[4]) parameter(0) + ROOT tgte1 = f32[4] get-tuple-element(tparam), index=1 +} + +false_branch { + fparam = (s32[], f32[4]) parameter(0) + ROOT fgte1 = f32[4] get-tuple-element(fparam), index=1 +} + +ENTRY entry { + p0 = (f32[4], s32[]) parameter(0) + constant = pred[] constant(true) + ROOT conditional = f32[4] conditional(constant, p0, p0), + true_computation=true_branch, false_computation=false_branch +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("shape does not match parameter")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index d7458c338e9f1df9fac90270845aae0b8f779ee2..bb5b40a8a87c5eab5a5b1599581a81bbd064511b 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -36,7 +36,8 @@ string HumanReadableProfileBuilder::ToString() const { computation_name_.c_str(), HumanReadableElapsedTime(CyclesToSeconds(total_cycles_)).c_str()); - auto print_op = [&](const OpInfo& op) { + int64 cumulative_cycles = 0; + auto print_op = [&](const OpInfo& op, bool is_total = false) { // Skip ops with 0 optimal seconds and 0 actual cycles. These are ops that // were expected to be free and are actually free -- things like (on most // backends) kParameter or kConstant HLOs. There's no need to clutter the @@ -59,27 +60,44 @@ string HumanReadableProfileBuilder::ToString() const { } } + double cumulative_cycles_percent = 0; double cycles_percent = 0; + if (!is_total) { + cumulative_cycles += op.cycles; + } if (total_cycles_ > 0) { cycles_percent = op.cycles / static_cast(total_cycles_) * 100; + cumulative_cycles_percent = + cumulative_cycles / static_cast(total_cycles_) * 100; + } + + string cycles_percent_str; + if (is_total) { + // Leaving off the two trailing decimal points of "100.%" lets us save two + // columns in the output. + cycles_percent_str = "100.% 100Σ"; + } else { + cycles_percent_str = + Printf("%5.2f%% %2.0fΣ", cycles_percent, cumulative_cycles_percent); } double nsecs = op.cycles / clock_rate_ghz_; - Appendf(&s, - "%15lld cycles (%6.2f%%) :: %12.1f usec %22s :: %18s " - ":: %18s :: %14s :: %16s :: %s\n", - op.cycles, cycles_percent, CyclesToMicroseconds(op.cycles), - op.optimal_seconds < 0 - ? "" - : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(), - op.flop_count <= 0 - ? "" - : HumanReadableNumFlops(op.flop_count, nsecs).c_str(), - op.transcendental_count <= 0 ? "" - : HumanReadableNumTranscendentalOps( - op.transcendental_count, nsecs) - .c_str(), - bytes_per_sec.c_str(), bytes_per_cycle.c_str(), op.name.c_str()); + Appendf( + &s, + "%15lld cycles (%s) :: %12.1f usec %22s :: %18s :: %18s :: %14s :: " + "%16s :: %s\n", + op.cycles, cycles_percent_str.c_str(), CyclesToMicroseconds(op.cycles), + op.optimal_seconds < 0 + ? "" + : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(), + op.flop_count <= 0 + ? "" + : HumanReadableNumFlops(op.flop_count, nsecs).c_str(), + op.transcendental_count <= 0 + ? "" + : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs) + .c_str(), + bytes_per_sec.c_str(), bytes_per_cycle.c_str(), op.name.c_str()); }; float optimal_seconds_sum = 0.0; @@ -98,7 +116,8 @@ string HumanReadableProfileBuilder::ToString() const { VLOG(1) << "Total floating point ops: " << total_flops; print_op({"[total]", "[total]", /*category=*/"", total_cycles_, total_flops, - total_transcendentals, total_bytes, optimal_seconds_sum}); + total_transcendentals, total_bytes, optimal_seconds_sum}, + /*is_total=*/true); // Sort ops in decreasing order of cycles, and print them. std::vector sorted_ops(op_infos_); diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc index 8c7b38dd1bf73e0be7b669d7215812aaef1cee17..f85d31d5225b8012b68f851b2bfec219d736ba0d 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 1985d20578677ae68b244023c4640454b004bf49..3531b7223fb11df212fa8d30e3adba6aac6c5679 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -160,6 +161,12 @@ StatusOr IndexedArrayAnalysis::ComputeArrayFor( computed_array, ComputeArrayForReshape(instr->shape(), FindOrDie(cache_, instr->operand(0)))); + } else if (instr->opcode() == HloOpcode::kDot) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForDot(instr->shape(), instr->dot_dimension_numbers(), + FindOrDie(cache_, instr->operand(0)), + FindOrDie(cache_, instr->operand(1)))); } else { computed_array = nullptr; } @@ -290,8 +297,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForGather( } if (auto* indexed = dynamic_cast(source)) { - auto it = c_find(indexed->output_dims(), source_dim); - if (it != indexed->output_dims().end()) { + if (c_linear_search(indexed->output_dims(), source_dim)) { return FoldGatherOfGather(indexed, indices, source_dim, output_dims, shape); } @@ -441,7 +447,7 @@ int64 FindSourcePositionForPassthroughResultDim(ArraySlice operand_shape, int64 indexed_source_subarray_size = std::accumulate(operand_shape.begin() + source_passthrough_dim + 1, - operand_shape.end(), 1, std::multiplies()); + operand_shape.end(), 1LL, std::multiplies()); return FindSuffixWithProduct(result_shape, indexed_source_subarray_size); } @@ -758,7 +764,7 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node, scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim())); - CHECK_EQ(c_accumulate(new_scalar_indexed_source_shape, 1l, + CHECK_EQ(c_accumulate(new_scalar_indexed_source_shape, 1LL, std::multiplies()), ShapeUtil::ElementsIn(scalar_indexed_source_shape)); @@ -956,11 +962,177 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode, return Construct( new_source, scalar_indexed_const->indices(), scalar_indexed_const->source_dim(), - std::vector(scalar_indexed_const->output_dims().begin(), - scalar_indexed_const->output_dims().end()), + ArraySliceToVector(scalar_indexed_const->output_dims()), scalar_indexed_const->shape()); } +namespace { + +// Returns the non-contracting non-batch dimension (as per `contracting_dims` +// and `batch_dims`) if there is exactly one, otherwise returns nullopt. +gtl::optional GetOnlyNonContractingNonBatchDim( + int64 rank, ArraySlice contracting_dims, + ArraySlice batch_dims) { + gtl::optional result; + for (int64 dim = 0; dim < rank; dim++) { + if (!ArrayContains(contracting_dims, dim) && + !ArrayContains(batch_dims, dim)) { + if (result.has_value()) { + return gtl::nullopt; + } + result = dim; + } + } + return result; +} + +// Returns true if `indexed_array`, which is either the LHS or the RHS of a Dot +// HLO, can be folded into the dot operation. For now these conditions are both +// necessary and sufficient. +// +// `tag` describes the caller. Used only for logging. +// +// `contracting_dims` and `batch_dims` are the contracting and batch dimensions +// of whatever operand `indexed_array` is to the dot (LHS or RHS). +bool CanFoldDotIntoIndexedArray( + tensorflow::StringPiece tag, + Analysis::ScalarIndexedConstantArray* indexed_array, + ArraySlice contracting_dims, ArraySlice batch_dims) { + gtl::optional non_contracting_non_batch_dim = + GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()), + contracting_dims, batch_dims); + if (!non_contracting_non_batch_dim.has_value()) { + VLOG(3) << tag << ": multiple or no non-contracting non-batch dimensions"; + return false; + } + + if (indexed_array->output_dims().size() != 1 || + indexed_array->output_dims()[0] != *non_contracting_non_batch_dim) { + VLOG(3) << tag << ": output dims != the lhs non-contracting non-batch dim"; + return false; + } + + int64 indexed_array_rank = ShapeUtil::Rank(indexed_array->shape()); + if (indexed_array->source_dim() < (indexed_array_rank - 2)) { + // This restriction can be lifted by inserting reshape nodes. + VLOG(3) << tag + << ": source dim is not in the low two dims, won't be able to form " + "a matmul"; + return false; + } + + return true; +} + +} // namespace + +StatusOr +IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( + const Shape& shape, const DotDimensionNumbers& dim_numbers, + ScalarIndexedConstantArray* lhs, ConstantArray* rhs) { + VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " " + << ToString(rhs); + if (!CanFoldDotIntoIndexedArray( + "ComputeArrayForDotWithIndexedLhs", lhs, /*contracting_dims=*/ + AsInt64Slice(dim_numbers.lhs_contracting_dimensions()), + /*batch_dims=*/AsInt64Slice(dim_numbers.lhs_batch_dimensions()))) { + return nullptr; + } + + int64 lhs_rank = ShapeUtil::Rank(lhs->shape()); + DotDimensionNumbers new_dim_numbers = dim_numbers; + new_dim_numbers.set_lhs_contracting_dimensions( + 0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1)); + + TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateDotOp( + new_dim_numbers, lhs->literal(), *rhs->literal()))); + + // The new source dimension is wherever the non-batch non-contracting LHS + // dimension "went". + int64 new_source_dim = dim_numbers.lhs_batch_dimensions_size() + + dim_numbers.rhs_batch_dimensions_size(); + + ConstantArray* new_source = Construct(literal_for_new_source); + return Construct( + new_source, lhs->indices(), new_source_dim, + ArraySliceToVector(lhs->output_dims()), shape); +} + +StatusOr +IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( + const Shape& shape, const DotDimensionNumbers& dim_numbers, + ConstantArray* lhs, ScalarIndexedConstantArray* rhs) { + VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " " + << ToString(rhs); + if (!CanFoldDotIntoIndexedArray( + "ComputeArrayForDotWithIndexedRhs", rhs, /*contracting_dims=*/ + AsInt64Slice(dim_numbers.rhs_contracting_dimensions()), + /*batch_dims=*/AsInt64Slice(dim_numbers.rhs_batch_dimensions()))) { + return nullptr; + } + + int64 rhs_rank = ShapeUtil::Rank(rhs->shape()); + + DotDimensionNumbers new_dim_numbers = dim_numbers; + new_dim_numbers.set_rhs_contracting_dimensions( + 0, rhs->source_dim() == (rhs_rank - 1) ? (rhs_rank - 2) : (rhs_rank - 1)); + + TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateDotOp( + new_dim_numbers, *lhs->literal(), rhs->literal()))); + + // The new source dimension is wherever the non-batch non-contracting RHS + // dimension "went". + int64 new_source_dim = dim_numbers.lhs_batch_dimensions_size() + + dim_numbers.rhs_batch_dimensions_size() + 1; + + ConstantArray* new_source = Construct(literal_for_new_source); + return Construct( + new_source, rhs->indices(), new_source_dim, + ArraySliceToVector(rhs->output_dims()), shape); +} + +StatusOr IndexedArrayAnalysis::ComputeArrayForDot( + const Shape& shape, const DotDimensionNumbers& dim_numbers, Array* lhs, + Array* rhs) { + // Intuitively, if + // + // - The LHS of a dot product is a gathered sequence of rows from a constant + // array (i.e. LHS[I,J] = Const[Indices[I],J]) and the RHS is a constant + // + // OR + // + // - If the RHS of a dot product is a gathered sequence of columns from a + // constant array (i.e. RHS[I,J] = Const[I, Indices[J]]) and the LHS is a + // constant + // + // then the result of the dot product itself is a gather from a constant + // array. E.g. Dot(LHS, ConstRhs) where LHS[I,J] = Const[Indices[I],J] can be + // rewritten as Result where Result[I,J] = Dot(Const, ConstRhs)[Indices[I], + // J]. + // + // We do a general version of this rewrite here. + VLOG(3) << "ComputeArrayForDot(" << ToString(lhs) << " " << ToString(rhs); + if (auto* lhs_indexed_array = + dynamic_cast(lhs)) { + if (auto* rhs_constant = dynamic_cast(rhs)) { + return ComputeArrayForDotWithIndexedLhs(shape, dim_numbers, + lhs_indexed_array, rhs_constant); + } + } + + if (auto* rhs_indexed_array = + dynamic_cast(rhs)) { + if (auto* lhs_constant = dynamic_cast(lhs)) { + return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers, lhs_constant, + rhs_indexed_array); + } + } + + return nullptr; +} + tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const { return "indexed-array-analysis-printer-pass"; } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index 8684430231c1929f82508e3675f1c275c42b6149..e923dc39f7f464a8d3c400294499a6f5efda3991 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -268,6 +268,18 @@ class IndexedArrayAnalysis { tensorflow::gtl::ArraySlice window_bounds, Array* source, Array* indices); + StatusOr ComputeArrayForDotWithIndexedLhs( + const Shape& shape, const DotDimensionNumbers& dim_numbers, + ScalarIndexedConstantArray* lhs, ConstantArray* rhs); + + StatusOr ComputeArrayForDotWithIndexedRhs( + const Shape& shape, const DotDimensionNumbers& dim_numbers, + ConstantArray* lhs, ScalarIndexedConstantArray* rhs); + + StatusOr ComputeArrayForDot(const Shape& shape, + const DotDimensionNumbers& dim_numbers, + Array* lhs, Array* rhs); + // This tries to fold a ScalarIndexedArray which has another // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a // ScalarIndexedArray as indices. If `source` happened to be a diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index fc2befe05b18651502c42b9892e766145d85f2e8..5f4b42799b1c26ea544f9d4447cc45b5ae9d5a48 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -799,5 +799,170 @@ ENTRY main { AssertArrayForRootExpressionIs(hlo_text, "%add"); } +TEST_F(IndexedArrayAnalysisTest, DotOpBasic_0) { + string hlo_text = R"( +HloModule DotOp + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + indices = s32[5] parameter(0) + dot_lhs = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"( +(scalar-indexed-const + (constant s32[3,3] s32[3,3] { + { 70, 80, 90 }, + { 158, 184, 210 }, + { 246, 288, 330 } }) + %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, DotOpBasic_1) { + string hlo_text = R"( +HloModule DotOp + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[3,3] constant(s32[3,3]{{1,2,3},{4,5,6},{7,8,9}}) + indices = s32[5] parameter(0) + dot_lhs = s32[3,5] gather(gather_operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3,1} + ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={0}, rhs_contracting_dims={0} +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"( +(scalar-indexed-const + (constant s32[4,3] s32[4,3] { + { 84, 99, 114 }, + { 96, 114, 132 }, + { 108, 129, 150 }, + { 120, 144, 168 } }) + %indices 0->[1]))"); +} + +TEST_F(IndexedArrayAnalysisTest, DotOpBasic_2) { + string hlo_text = R"( +HloModule DotOp + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + indices = s32[5] parameter(0) + dot_rhs = s32[3,5] gather(gather_operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3,1} + ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"( +(scalar-indexed-const + (constant s32[4,4] s32[4,4] { + { 38, 44, 50, 56 }, + { 83, 98, 113, 128 }, + { 128, 152, 176, 200 }, + { 173, 206, 239, 272 } }) + %indices 1->[1]) +)"); +} + +TEST_F(IndexedArrayAnalysisTest, DotOpBasic_3) { + string hlo_text = R"( +HloModule DotOp + +ENTRY main { + gather_operand = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) + indices = s32[5] parameter(0) + dot_rhs = s32[5,3] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} + ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={1} +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"( +(scalar-indexed-const + (constant s32[4,4] s32[4,4] { + { 14, 32, 50, 68 }, + { 32, 77, 122, 167 }, + { 50, 122, 194, 266 }, + { 68, 167, 266, 365 } }) + %indices 1->[0]) +)"); +} + +TEST_F(IndexedArrayAnalysisTest, DotOpWithBatch) { + string hlo_text = R"( +HloModule DotOp + +ENTRY main { + gather_operand = s32[2,3,2] constant(s32[2,3,2]{{{1,2},{3,4},{5,6}},{{7,8},{9,10},{11,12}}}) + dot_lhs_constant = s32[2,2,3] constant(s32[2,2,3]{{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}}) + indices = s32[4] parameter(0) + dot_rhs = s32[2,3,4] gather(gather_operand, indices), + output_window_dims={0,1}, + elided_window_dims={2}, + gather_dims_to_operand_dims={2}, + index_vector_dim=1, + window_bounds={2,3,1} + ROOT dot = s32[2,2,4] dot(dot_lhs_constant, dot_rhs), + lhs_contracting_dims={2}, rhs_contracting_dims={1}, + lhs_batch_dims={0}, rhs_batch_dims={0} +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"( +(scalar-indexed-const + (constant s32[2,2,2] s32[2,2,2] { + { { 22, 28 }, + { 49, 64 } }, + { { 220, 244 }, + { 301, 334 } } }) + %indices 3->[2]) +)"); +} + +TEST_F(IndexedArrayAnalysisTest, DotOpNegative) { + string hlo_text = R"( +HloModule DotOp + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}}) + dot_rhs_constant = s32[2,3] constant(s32[2,3]{{1,2,3},{4,5,6}}) + indices = s32[2] parameter(0) + dot_lhs = s32[3,2] gather(gather_operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3,1} + ROOT dot = s32[3,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, "%dot"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index d2af261008f40ee83e0676cfc7e67c45f8be1844..32937b33b3737482f07d4c7607f7f1c5c183a56b 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -51,10 +51,10 @@ TEST_F(InlinerTest, MapMax) { auto max_f32 = max_builder.Build(); auto builder = HloComputation::Builder("MapMaxFunction"); - auto lhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); - auto rhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({4, 3, 2, 1}))); + auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4}))); + auto rhs = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({4, 3, 2, 1}))); builder.AddInstruction( HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); @@ -70,7 +70,7 @@ TEST_F(InlinerTest, MapMax) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - auto expected = Literal::CreateR1({4, 3, 3, 4}); + auto expected = LiteralUtil::CreateR1({4, 3, 3, 4}); EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } @@ -83,12 +83,12 @@ TEST_F(InlinerTest, MapConstant) { HloInstruction::CreateParameter(0, r0f32, "x")); (void)param1; const2_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); auto const2_f32 = const2_builder.Build(); auto builder = HloComputation::Builder("MapConstFunction"); auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}}))); + LiteralUtil::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}}))); builder.AddInstruction( HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get())); @@ -104,7 +104,7 @@ TEST_F(InlinerTest, MapConstant) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - auto expected = Literal::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); + auto expected = LiteralUtil::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } @@ -123,10 +123,10 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { auto max_f32 = max_builder.Build(); auto builder = HloComputation::Builder("MapSubFunction"); - auto lhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); - auto rhs = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({4, 3, 2, 1}))); + auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4}))); + auto rhs = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({4, 3, 2, 1}))); builder.AddInstruction( HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); @@ -142,7 +142,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - auto expected = Literal::CreateR1({3, 1, -1, -3}); + auto expected = LiteralUtil::CreateR1({3, 1, -1, -3}); EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 088cc2622695c7724dae2b6cde28fecd40547445..f33942d67907d8f40811bde5041350a2e1e1f1fc 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -73,6 +73,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kGt: case HloOpcode::kImag: case HloOpcode::kInfeed: + case HloOpcode::kIota: case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLt: @@ -100,6 +101,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kAfterAll: case HloOpcode::kTranspose: case HloOpcode::kTuple: + case HloOpcode::kTupleSelect: return false; // Cheap instructions for reals, but expensive for complex. @@ -118,6 +120,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kConditional: case HloOpcode::kConvolution: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: case HloOpcode::kCustomCall: case HloOpcode::kDivide: case HloOpcode::kDomain: @@ -139,6 +142,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kReduceWindow: case HloOpcode::kRemainder: case HloOpcode::kRng: + case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: case HloOpcode::kSend: case HloOpcode::kSendDone: diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 21db2338995960bde00ec9c4b325e5562fc3a592..9e7a15f0330d3f06779c850a4b575f84fe0b9505 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -167,7 +167,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) { builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1")); HloInstruction* binary1 = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); - builder.AddInstruction(HloInstruction::CreateSend(binary1, 0)); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); + builder.AddInstruction(HloInstruction::CreateSend(binary1, token, 0)); HloInstruction* unary = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); @@ -258,7 +259,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { add = f32[4,3]{1,0} add(p0, p0) abs1 = f32[4,3]{1,0} abs(add) log = f32[4,3]{1,0} log(abs1) - send = f32[4,3]{1,0} send(log), channel_id=0 + token = token[] after-all() + send = f32[4,3]{1,0} send(log, token), channel_id=0 abs2 = f32[4,3]{1,0} abs(log) ROOT root = f32[4,3]{1,0} subtract(abs2, add) })") @@ -288,7 +290,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { p0 = f32[4,3]{1,0} parameter(0) add1 = f32[4,3]{1,0} add(p0, p0) log = f32[4,3]{1,0} log(p0) - send = f32[4,3]{1,0} send(log), channel_id=0 + token = token[] after-all() + send = f32[4,3]{1,0} send(log, token), channel_id=0 add2 = f32[4,3]{1,0} add(log, add1) ROOT root = f32[4,3]{1,0} subtract(add1, add2) })") @@ -321,7 +324,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { add1 = f32[4,3]{1,0} add(p0, p0) add2 = f32[4,3]{1,0} add(add1, add1) log = f32[4,3]{1,0} log(add2) - send = f32[4,3]{1,0} send(log), channel_id=0 + token = token[] after-all() + send = f32[4,3]{1,0} send(log, token), channel_id=0 sub1 = f32[4,3]{1,0} subtract(log, add2) sub2 = f32[4,3]{1,0} subtract(add2, add1) ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2) @@ -352,7 +356,8 @@ TEST_F(InstructionFusionTest, AllowUnaryDuplication) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0")); HloInstruction* unary1 = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kFloor, param0)); - builder.AddInstruction(HloInstruction::CreateSend(unary1, 0)); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); + builder.AddInstruction(HloInstruction::CreateSend(unary1, token, 0)); HloInstruction* unary2 = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1)); @@ -375,7 +380,8 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1")); HloInstruction* binary1 = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); - builder.AddInstruction(HloInstruction::CreateSend(binary1, 0)); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); + builder.AddInstruction(HloInstruction::CreateSend(binary1, token, 0)); HloInstruction* unary = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 524d3234eb4eff9c7d000eca1a0d9f5c4fae90af..8652599dc6d48ff8c2aaa703fead161f891a57d1 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -74,7 +74,7 @@ cc_library( hdrs = ["executable.h"], deps = [ ":executor", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 9816acf6507a0ed5391cf4f1c94ccd0f27f5227a..8d40c08d555a232b7cf3b81cc0f9970804c2f896 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 36fdfa868dfbfaf9fbf353dd6623058d518fec04..b5a9d6e8e7d66ae0c560226a79578d85eaf55644 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -30,10 +30,12 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" @@ -59,7 +61,6 @@ namespace xla { // anonymous namespace, instead of three or four spread all over this file. namespace { - } // namespace std::ostream& operator<<(std::ostream& out, @@ -113,14 +114,18 @@ LayoutConstraints::LayoutConstraints( HloComputation* computation) : points_to_analysis_(points_to_analysis), computation_(computation) { // Gather all array-shaped logical buffers into unconstrained_buffer_ids. - for (LogicalBuffer::Id id = 0; id < points_to_analysis_.num_logical_buffers(); - id++) { - auto& buffer = points_to_analysis_.logical_buffer(id); - // The points to analysis is computed per module, restrict constraints to - // array buffers in this computation. - if (buffer.IsArray() && buffer.instruction()->parent() == computation) { - unconstrained_buffer_ids_.insert(buffer.id()); - } + for (HloInstruction* inst : computation_->instructions()) { + points_to_analysis_.GetPointsToSet(inst).ForEachElement( + [&](const ShapeIndex&, const PointsToSet::BufferList& buffers) { + for (const LogicalBuffer* buffer : buffers) { + // The points to analysis is computed per module, restrict + // constraints to array buffers in this computation. + if (buffer->IsArray() && + buffer->instruction()->parent() == computation) { + unconstrained_buffer_ids_.insert(buffer->id()); + } + } + }); } } @@ -392,6 +397,43 @@ string LayoutConstraints::ToString() const { return output; } +namespace { + +bool IsHostSendRecv(const HloInstruction* instruction) { + const HloSendRecvInstruction* send_recv_instr = + DynCast(instruction); + return send_recv_instr != nullptr && send_recv_instr->is_host_transfer(); +} + +} // namespace + +Status LayoutAssignment::BuildHostChannelConstraints( + HloComputation* computation) { + for (auto* instruction : computation->instructions()) { + const HloSendRecvInstruction* send_recv_instr = + DynCast(instruction); + if (send_recv_instr == nullptr || !send_recv_instr->is_host_transfer()) { + continue; + } + + // For host transfers the Send and Recv instruction carry the layout. + if (instruction->opcode() == HloOpcode::kSend || + instruction->opcode() == HloOpcode::kRecv) { + const Shape& data_shape = + ShapeUtil::GetTupleElementShape(send_recv_instr->shape(), 0); + TF_RET_CHECK(ShapeUtil::IsArray(data_shape)); + TF_RET_CHECK(LayoutUtil::HasLayout(data_shape)); + const Layout* prev_layout = host_channel_constraints_.ConstrainChannel( + send_recv_instr->channel_id(), data_shape.layout()); + TF_RET_CHECK(prev_layout == nullptr) + << "Cannot constrain host transfer layout as it was set to " + << LayoutUtil::HumanString(*prev_layout) << ": " + << send_recv_instr->ToString(); + } + } + return Status::OK(); +} + Status LayoutAssignment::AddMandatoryConstraints( const ComputationLayout* computation_layout, ChannelLayoutConstraints* channel_constraints, HloComputation* computation, @@ -399,6 +441,11 @@ Status LayoutAssignment::AddMandatoryConstraints( VLOG(3) << "Adding mandatory layout constraints to computation " << computation->name(); + auto get_channel_constraints = [&](const HloInstruction* instruction) { + return IsHostSendRecv(instruction) ? &host_channel_constraints_ + : channel_constraints; + }; + // Constrain layouts of instructions which define values with pre-existing // layouts. for (auto* instruction : computation->instructions()) { @@ -435,18 +482,21 @@ Status LayoutAssignment::AddMandatoryConstraints( if (instruction->opcode() == HloOpcode::kSend || instruction->opcode() == HloOpcode::kRecv) { - CHECK(channel_constraints) + CHECK(get_channel_constraints(instruction)) << "Multi-module layout assignment requires ChannelLayoutConstraints"; int64 channel_id = instruction->channel_id(); - if (!channel_constraints->IsChannelConstrained(channel_id)) { + if (!get_channel_constraints(instruction) + ->IsChannelConstrained(channel_id)) { continue; } if (instruction->opcode() == HloOpcode::kSend) { // TODO(b/68493863): Change to use SetOperandLayout(). const Shape send_buffer_shape = instruction->operand(0)->shape(); TF_RET_CHECK(ShapeUtil::IsArray(send_buffer_shape)); - Shape new_buffer_shape = channel_constraints->LayoutShapeForChannel( - send_buffer_shape, instruction->channel_id()); + Shape new_buffer_shape = + get_channel_constraints(instruction) + ->LayoutShapeForChannel(send_buffer_shape, + instruction->channel_id()); TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( new_buffer_shape, instruction->operand(0))); } else { @@ -457,8 +507,9 @@ Status LayoutAssignment::AddMandatoryConstraints( const LogicalBuffer* buffer, constraints->points_to_analysis().GetBufferDefinedAt(instruction, {0})); - Shape new_shape = channel_constraints->LayoutShapeForChannel( - recv_buffer_shape, instruction->channel_id()); + Shape new_shape = get_channel_constraints(instruction) + ->LayoutShapeForChannel( + recv_buffer_shape, instruction->channel_id()); TF_RETURN_IF_ERROR( constraints->SetBufferLayout(new_shape.layout(), *buffer)); } @@ -823,8 +874,8 @@ void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction, // HostCompute module. // Otherwise it is preferable to leave the new instruction without device, // and let the automatic device placer to choose the best location. - if (!sharding.HasUniqueDevice() || - HloSharding::IsReservedDevice(sharding.UniqueDevice().ValueOrDie())) { + auto device = sharding.UniqueDevice(); + if (!device || HloSharding::IsReservedDevice(*device)) { copy->set_sharding(sharding); } } @@ -1177,7 +1228,7 @@ Status LayoutAssignment::PropagateUseConstraintToDefs( const PointsToSet& points_to_set = constraints->points_to_analysis().GetPointsToSet(instruction); return points_to_set.ForEachElementWithStatus( - [this, &shape_layout, constraints]( + [&shape_layout, constraints]( const ShapeIndex& index, const PointsToSet::BufferList& buffers) -> Status { if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) { @@ -1535,6 +1586,10 @@ Status LayoutAssignment::RunOnComputation( ChannelLayoutConstraints* channel_constraints) { VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name() << ")"; + + // Must be run before clearing layouts. + TF_RETURN_IF_ERROR(BuildHostChannelConstraints(computation)); + TF_RETURN_IF_ERROR(ClearComputationLayouts(computation)); if (computation_layout != nullptr) { auto it = computation_layouts_.find(computation); @@ -1624,13 +1679,20 @@ Status LayoutAssignment::RunOnComputation( Status LayoutAssignment::ConstrainChannelLayouts( HloComputation* computation, ChannelLayoutConstraints* channel_constraints) { + auto get_channel_constraints = [&](const HloInstruction* instruction) { + return IsHostSendRecv(instruction) ? &host_channel_constraints_ + : channel_constraints; + }; // We go through the kRecvDone before. These must either impose their layout, - // of find a matching one already existing (ConstrainChannel() returns + // or find a matching one already existing (ConstrainChannel() returns // nullptr). for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kRecvDone) { - const Layout* layout = channel_constraints->ConstrainChannel( - instruction->channel_id(), instruction->shape().layout()); + const Layout* layout = + get_channel_constraints(instruction) + ->ConstrainChannel( + instruction->channel_id(), + ShapeUtil::GetSubshape(instruction->shape(), {0}).layout()); TF_RET_CHECK(layout == nullptr) << instruction->ToString() << " cannot constrain layout as it was set to " @@ -1643,11 +1705,12 @@ Status LayoutAssignment::ConstrainChannelLayouts( for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { if (instruction->opcode() == HloOpcode::kSend) { HloInstruction* operand = instruction->mutable_operand(0); - const Layout* layout = channel_constraints->ConstrainChannel( - instruction->channel_id(), operand->shape().layout()); + const Layout* layout = get_channel_constraints(instruction) + ->ConstrainChannel(instruction->channel_id(), + operand->shape().layout()); if (layout != nullptr) { // We found an already constrained layout which does not match the one - // the kSend wants to impose. Eitehr add a new kCopy, or use the + // the kSend wants to impose. Either add a new kCopy, or use the // existing one to marshal the correct shape. Shape shape = operand->shape(); *shape.mutable_layout() = *layout; diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index b75ecb311a07b996562460fc5d6fbd8e70ac056b..f9e8dbea2f8aa224318adf3cf4b5e493792d3093 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -488,6 +488,9 @@ class LayoutAssignment : public HloPassInterface { } } + // Adds constraints related to host Send/Recv instructions. + Status BuildHostChannelConstraints(HloComputation* computation); + // Map containing the layouts of all computations assigned so // far. Computations are handled in a topological sort where computations are // handled before their caller instructions so the layouts of caller @@ -507,6 +510,10 @@ class LayoutAssignment : public HloPassInterface { // computations/instructions. ChannelLayoutConstraints channel_constraints_; + // Layout constraints for send/recv instructions which communicate with the + // host. + ChannelLayoutConstraints host_channel_constraints_; + // The set of HLO instructions which lacked any layout constraint, thus // receiving propagated default layouts. tensorflow::gtl::FlatSet diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 67e2cf6c777b3ecc86cfa408145b9c3cd0c31df9..a16fa75e3032cfa4257d9b5608dd176fdb4ddbdb 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -141,9 +141,9 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { std::vector> minor_to_majors = {{0, 1}, {1, 0}}; for (auto& minor_to_major : minor_to_majors) { auto builder = HloComputation::Builder(TestName()); - auto constant_literal1 = Literal::CreateR2WithLayout( + auto constant_literal1 = LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major)); - auto constant_literal2 = Literal::CreateR2WithLayout( + auto constant_literal2 = LiteralUtil::CreateR2WithLayout( {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major)); Shape ashape = constant_literal1->shape(); @@ -192,10 +192,10 @@ TEST_F(LayoutAssignmentTest, TupleLayout) { // match their source). auto builder = HloComputation::Builder(TestName()); auto constant0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant0, constant1})); @@ -229,10 +229,10 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { // Verify layouts of a select with tuple operands is assigned properly. auto builder = HloComputation::Builder(TestName()); auto constant0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto tuple0 = builder.AddInstruction( HloInstruction::CreateTuple({constant0, constant1})); @@ -240,7 +240,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { HloInstruction::CreateTuple({constant0, constant1})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); @@ -274,7 +274,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { // tuple and assigning the layouts of the copied arrays as needed. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto inner_tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant})); auto nested_tuple = builder.AddInstruction( @@ -584,7 +584,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { auto builder = HloComputation::Builder(TestName()); Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7}); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); auto broadcast = builder.AddInstruction( HloInstruction::CreateBroadcast(input_shape, constant, {})); auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -770,8 +770,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { false_builder.AddInstruction( HloInstruction::CreateParameter(0, tshape, "param")); // Using infeed as layout assignment does not mess up with it. - auto token = - false_builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto token = false_builder.AddInstruction(HloInstruction::CreateToken()); auto infeed = false_builder.AddInstruction( HloInstruction::CreateInfeed(xshape, token, "")); auto infeed_data = false_builder.AddInstruction( @@ -803,7 +802,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { auto builder = HloComputation::Builder(TestName()); auto constant0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); builder.AddInstruction(HloInstruction::CreateUnary( constant0->shape(), HloOpcode::kBitcast, constant0)); @@ -829,12 +828,14 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { ENTRY entry_computation { param = (f32[2,2]) parameter(0) gte = f32[2,2] get-tuple-element(param), index=0 - recv = (f32[2,2], u32[]) recv(), channel_id=1, sharding={maximal device=1} - ROOT recv-done = f32[2,2] recv-done(recv), channel_id=1, + token = token[] after-all() + recv = (f32[2,2], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=1} + recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1, sharding={maximal device=1} - send = (f32[2,2], u32[]) send(gte), channel_id=1, + ROOT root = f32[2,2] get-tuple-element(recv-done), index=0 + send = (f32[2,2], u32[], token[]) send(gte, token), channel_id=1, sharding={maximal device=0} - send-done = () send-done(send), channel_id=1, sharding={maximal device=0} + send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0} } )"; @@ -853,7 +854,7 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { AssignLayouts(module.get(), &computation_layout, &channel_constraints); EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "recv-done"), ElementsAre(1, 0)); + EXPECT_THAT(LayoutOf(module.get(), "root"), ElementsAre(1, 0)); EXPECT_TRUE( ShapeUtil::Equal(ShapeUtil::GetSubshape( FindInstruction(module.get(), "send")->shape(), {0}), diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index f1e7fc29532ce7e6841010a5258f4000a7c70383..cdd3daf73b8ac1a4d1ec3c81224c2c0bfe8e5811 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -21,6 +21,11 @@ filegroup( ]), ) +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + cc_library( name = "alias_analysis", srcs = ["alias_analysis.cc"], @@ -37,12 +42,25 @@ cc_library( ], ) +tf_cc_test( + name = "alias_analysis_test", + srcs = ["alias_analysis_test.cc"], + deps = [ + ":alias_analysis", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/core:test", + ], +) + cc_library( name = "llvm_util", srcs = ["llvm_util.cc"], hdrs = ["llvm_util.h"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -106,12 +124,31 @@ cc_library( ], ) +cc_library( + name = "kernel_tiling", + srcs = ["kernel_tiling.cc"], + hdrs = ["kernel_tiling.h"], + deps = [ + ":ir_array", + ":llvm_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + cc_library( name = "fused_ir_emitter", srcs = ["fused_ir_emitter.cc"], hdrs = ["fused_ir_emitter.h"], deps = [ ":ir_array", + ":kernel_tiling", ":llvm_util", ":loop_emitter", ":tuple_ops", @@ -127,9 +164,9 @@ cc_library( ) cc_library( - name = "ops", - srcs = ["ops.cc"], - hdrs = ["ops.h"], + name = "dynamic_update_slice_util", + srcs = ["dynamic_update_slice_util.cc"], + hdrs = ["dynamic_update_slice_util.h"], deps = [ ":fused_ir_emitter", ":ir_array", @@ -143,6 +180,23 @@ cc_library( ], ) +cc_library( + name = "sort_util", + srcs = ["sort_util.cc"], + hdrs = ["sort_util.h"], + deps = [ + ":ir_array", + ":llvm_loop", + ":llvm_util", + ":loop_emitter", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", + "//tensorflow/compiler/xla/service/gpu:partition_assignment", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + cc_library( name = "tuple_ops", srcs = ["tuple_ops.cc"], @@ -169,3 +223,22 @@ cc_library( "@llvm//:core", ], ) + +cc_library( + name = "buffer_assignment_util", + srcs = ["buffer_assignment_util.cc"], + hdrs = ["buffer_assignment_util.h"], + deps = [ + "//tensorflow/compiler/xla/service:buffer_assignment", + ], +) + +cc_library( + name = "math_ops", + srcs = ["math_ops.cc"], + hdrs = ["math_ops.h"], + deps = [ + ":llvm_util", + "@llvm//:core", + ], +) diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index f200a08a3cd7e33351ec4607d67d40e7ab28f3b9..e5370eca56f2e3a891523ba2b72961d66ec809aa 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -28,16 +28,16 @@ namespace llvm_ir { // Sentry allocation used to represent parameters of the entry computation in // alias_scope_metadata_ and noalias_metadata_. static const BufferAllocation* kParameterAllocation = new BufferAllocation( - /*index=*/-1, /*size=*/0, /*is_thread_local=*/false, /*is_reusable=*/false, - LogicalBuffer::Color(0)); + /*index=*/-1, /*size=*/0, LogicalBuffer::Color(0)); void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, llvm_ir::IrArray* array, const ShapeIndex& index) { BufferAllocation::Slice buffer_slice; - if (hlo.opcode() == HloOpcode::kParameter) { - // Parameters may alias with each other but may not alias with our temporary - // buffers. + if (hlo.opcode() == HloOpcode::kParameter && + hlo.parent() == hlo.parent()->parent()->entry_computation()) { + // Entry computation parameters may alias with each other but may not alias + // with our temporary buffers. buffer_slice = BufferAllocation::Slice(kParameterAllocation, 0, 0); } else { const std::set slices = diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fe5ec1cc66d06e85ce70625ef7cf764a37b29166 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace cpu { +namespace { +class AliasAnalysisTest : public CpuCodegenTest {}; + +void FakeCustomCallTarget(float* out, float** in) {} + +REGISTER_CUSTOM_CALL_TARGET(FakeCustomCallTarget); + +TEST_F(AliasAnalysisTest, EmbeddedComputationParamsMayAliasTemps) { + const char* hlo_string = R"( +HloModule while + +body { + const.0.125 = f32[] constant(0.125) + body.state = f32[] parameter(0) + ROOT add.2.2 = f32[] add(const.0.125, body.state) +} + +condition { + const.100 = f32[] constant(100) + condition.state = f32[] parameter(0) + addend = f32[] custom-call(condition.state), custom_call_target="FakeCustomCallTarget" + add = f32[] add(addend, condition.state) + ROOT greater-than = pred[] greater-than(const.100, add) +} + +ENTRY while3 { + const.0 = f32[] constant(0) + ROOT while = f32[] while(const.0), condition=condition, body=body +} +)"; + + CompileAndVerifyIr(hlo_string, R"( +; CHECK-LABEL: @body(i8* %retval +; CHECK: %[[add_result:.*]] = fadd fast float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]] +; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:[0-9]+]] +; +; CHECK-LABEL: @condition(i8* %retval, i8* noalias %run_options, i8** noalias %params +; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %temps, i64 0 +; CHECK: %[[cond_state_buf_untyped:.*]] = load i8*, i8** %[[cond_state_buf_ptr]] +; CHECK: %[[cond_state_buf_typed:.*]] = bitcast i8* %[[cond_state_buf_untyped]] to float* +; CHECK: load float, float* %[[cond_state_buf_typed]], !alias.scope ![[alias_scope_md_for_store]], !noalias ![[noalias_md_for_load:.*]] +; +; CHECK-LABEL: @while3( + +![[alias_scope_md_for_store]] = !{![[buffer_idx_0:.*]]} +![[buffer_idx_0]] = !{!"buffer: {index:0, offset:0, size:4}", ![[aa_md_root:.*]]} +![[aa_md_root]] = !{!"XLA global AA domain"} +![[buffer_idx_1:.*]] = !{!"buffer: {index:1, offset:0, size:4}", !3} +![[buffer_idx_1_offset_16:.*]] = !{!"buffer: {index:1, offset:16, size:1}", !3} +![[noalias_md_for_load]] = !{![[buffer_idx_1_offset_16]], ![[buffer_idx_1]]} +} +)"); +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..4eb5d9fb4750927ca189e02f312b2d6be7fdd418 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" + +namespace xla { +namespace llvm_ir { +static const HloInstruction& InstrForConstantBufferAllocation( + const BufferAllocation& allocation) { + CHECK(allocation.is_constant()); + HloInstruction* const_instr = nullptr; + for (const auto& buffer_offset_pair : allocation.assigned_buffers()) { + const LogicalBuffer* buffer = buffer_offset_pair.first; + // BufferAssignment may have assigned non-constant instructions to this + // allocation too so we can't CHECK this condition. E.g. for + // + // while(init = constant, body = identity, cond = ...) + // + // the LogicalBuffer for the kWhile instruction will have the same + // BufferAllocation as the LogicalBuffer for the (init) constant. + if (buffer->instruction()->opcode() == HloOpcode::kConstant) { + CHECK_EQ(const_instr, nullptr) + << const_instr->ToString() << " " << buffer->ToString(); + const_instr = buffer->instruction(); + } + } + CHECK_NE(const_instr, nullptr); + return *const_instr; +} + +string ConstantBufferAllocationToGlobalName( + const BufferAllocation& allocation) { + string instr_name = InstrForConstantBufferAllocation(allocation).name(); + for (char& c : instr_name) { + if (c == '.') { + c = '_'; + } + } + return tensorflow::strings::StrCat("buffer_for_", instr_name); +} + +const Literal& LiteralForConstantAllocation( + const BufferAllocation& allocation) { + return InstrForConstantBufferAllocation(allocation).literal(); +} +} // namespace llvm_ir +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h new file mode 100644 index 0000000000000000000000000000000000000000..bfb6eecb87f6a1b756b3a8da3377f608dd7f0be7 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_BUFFER_ASSIGNMENT_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_BUFFER_ASSIGNMENT_UTIL_H_ + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" + +namespace xla { +namespace llvm_ir { +// In XLA:GPU we map constant buffer allocations to globals in the generated +// LLVM IR. This function gives us the name of the global variable a constant +// buffer is mapped to. Not used on XLA:CPU. +string ConstantBufferAllocationToGlobalName(const BufferAllocation& allocation); + +// Returns the Literal corresponding to `allocation`, which must be a constant +// allocation. +const Literal& LiteralForConstantAllocation(const BufferAllocation& allocation); +} // namespace llvm_ir +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_BUFFER_ASSIGNMENT_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc similarity index 75% rename from tensorflow/compiler/xla/service/llvm_ir/ops.cc rename to tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index 3b298f4746d6177da52ba0227705d07fbeba5c19..27fbb11e2ede66a1268e7e949634b2c7d29cbc1c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h" #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" @@ -38,16 +38,16 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice, // Emits a sequential loop if launch_dimensions is null. static Status EmitDynamicUpdateSliceInPlaceImpl( const Shape& update_shape, const ElementGenerator& start_indices_generator, - ElementGenerator update_array_generator, const IrArray& output_array, - const gpu::LaunchDimensions* launch_dimensions, - tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder) { + bool is_signed, ElementGenerator update_array_generator, + const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions, + tensorflow::StringPiece name, llvm::IRBuilder<>* b) { const Shape& output_shape = output_array.GetShape(); // Read start indices from start_indices_generator. const int64 rank = ShapeUtil::Rank(output_shape); - IrArray::Index start_index(ir_builder->getInt64Ty(), rank); + IrArray::Index start_index(b->getInt64Ty(), rank); for (int64 i = 0; i < rank; ++i) { - IrArray::Index dim_index({ir_builder->getInt64(i)}); + IrArray::Index dim_index({b->getInt64(i)}); TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index)); llvm::Value* output_dim_size = llvm::ConstantInt::get( start_index[i]->getType(), output_shape.dimensions(i)); @@ -56,21 +56,19 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( // Clamp the start index so that the update region fits in the operand. // start_index = clamp(start_index, 0, output_dim_size - update_dim_size) - - // TODO(b/74360564): This is implementation defined behavior, but is - // currently respected by all implementations. Change this if we ever decide - // to oficially document different behavior. - llvm::Value* max_bound = - ir_builder->CreateSub(output_dim_size, update_dim_size); + llvm::Value* max_bound = b->CreateSub(output_dim_size, update_dim_size); llvm::Value* zero = llvm::ConstantInt::get(start_index[i]->getType(), 0); - start_index[i] = ir_builder->CreateSelect( - ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SGE, zero, start_index[i]), - zero, start_index[i]); - - start_index[i] = ir_builder->CreateSelect( - ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SLE, max_bound, - start_index[i]), - max_bound, start_index[i]); + start_index[i] = + b->CreateSelect(b->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE + : llvm::ICmpInst::ICMP_UGE, + zero, start_index[i]), + zero, start_index[i]); + + start_index[i] = + b->CreateSelect(b->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE + : llvm::ICmpInst::ICMP_ULE, + max_bound, start_index[i]), + max_bound, start_index[i]); } auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status { @@ -81,31 +79,30 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( // IrArray::Index output_index(start_index.GetType(), rank); for (int64 i = 0; i < rank; ++i) { - llvm::Value* start_index0 = ir_builder->CreateSExtOrBitCast( - start_index[i], update_index[i]->getType()); - output_index[i] = ir_builder->CreateAdd(start_index0, update_index[i]); + llvm::Value* start_index0 = + b->CreateSExtOrBitCast(start_index[i], update_index[i]->getType()); + output_index[i] = b->CreateAdd(start_index0, update_index[i]); } // Do output[output_index] = update[update_index]. TF_ASSIGN_OR_RETURN(llvm::Value * update_data, update_array_generator(update_index)); - output_array.EmitWriteArrayElement(output_index, update_data, ir_builder); + output_array.EmitWriteArrayElement(output_index, update_data, b); return Status::OK(); }; if (launch_dimensions != nullptr) { return gpu::ParallelLoopEmitter(loop_body_emitter, update_shape, - *launch_dimensions, ir_builder) + *launch_dimensions, b) .EmitLoop(name); } - return LoopEmitter(loop_body_emitter, update_shape, ir_builder) - .EmitLoop(name); + return LoopEmitter(loop_body_emitter, update_shape, b).EmitLoop(name); } Status EmitDynamicUpdateSliceInPlace( tensorflow::gtl::ArraySlice operand_arrays, const IrArray& output_array, tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder) { + llvm::IRBuilder<>* b) { VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name; // No need to use operand_arrays[0], the input array of the @@ -116,15 +113,16 @@ Status EmitDynamicUpdateSliceInPlace( Shape update_shape = update_array.GetShape(); ElementGenerator start_indices_generator = [&](const IrArray::Index& index) { - return start_indices_array.EmitReadArrayElement(index, ir_builder); + return start_indices_array.EmitReadArrayElement(index, b); }; ElementGenerator update_array_generator = [&](const IrArray::Index& index) { - return update_array.EmitReadArrayElement(index, ir_builder); + return update_array.EmitReadArrayElement(index, b); }; + bool is_signed = ShapeUtil::ElementIsSigned(start_indices_array.GetShape()); return EmitDynamicUpdateSliceInPlaceImpl( - update_shape, start_indices_generator, update_array_generator, - output_array, /*launch_dimensions=*/nullptr, name, ir_builder); + update_shape, start_indices_generator, is_signed, update_array_generator, + output_array, /*launch_dimensions=*/nullptr, name, b); } // Shared implementation for EmitFusedDynamicUpdateSliceInPlace and @@ -135,8 +133,7 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( HloInstruction* fusion, tensorflow::gtl::ArraySlice fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, - const gpu::LaunchDimensions* launch_dimensions, - llvm::IRBuilder<>* ir_builder) { + const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) { CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); VLOG(2) << "EmitFusedDynamicUpdateSliceInPlace for " << fusion->ToShortString(); @@ -170,30 +167,30 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( ElementGenerator start_indices_generator = fused_emitter.GetGenerator(start_indices); + bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape()); return EmitDynamicUpdateSliceInPlaceImpl( - update_shape, start_indices_generator, update_array_generator, - fusion_output_array, launch_dimensions, IrName(fusion), ir_builder); + update_shape, start_indices_generator, is_signed, update_array_generator, + fusion_output_array, launch_dimensions, IrName(fusion), b); } Status EmitFusedDynamicUpdateSliceInPlace( HloInstruction* fusion, tensorflow::gtl::ArraySlice fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, - llvm::IRBuilder<>* ir_builder) { + llvm::IRBuilder<>* b) { return EmitFusedDynamicUpdateSliceInPlaceImpl( fusion, fusion_operand_arrays, fusion_output_array, elemental_emitter, - /*launch_dimensions=*/nullptr, ir_builder); + /*launch_dimensions=*/nullptr, b); } Status EmitParallelFusedDynamicUpdateSliceInPlace( HloInstruction* fusion, tensorflow::gtl::ArraySlice fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, - const gpu::LaunchDimensions& launch_dimensions, - llvm::IRBuilder<>* ir_builder) { + const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) { return EmitFusedDynamicUpdateSliceInPlaceImpl( fusion, fusion_operand_arrays, fusion_output_array, elemental_emitter, - &launch_dimensions, ir_builder); + &launch_dimensions, b); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.h b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h similarity index 91% rename from tensorflow/compiler/xla/service/llvm_ir/ops.h rename to tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h index 175b081e84d31779b15560cb0998011fe046ca01..3502577d236a099e0b721b98217b758696966821 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_DYNAMIC_UPDATE_SLICE_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_DYNAMIC_UPDATE_SLICE_UTIL_H_ #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" @@ -66,7 +66,7 @@ inline bool CanEmitFusedDynamicUpdateSliceInPlace( Status EmitDynamicUpdateSliceInPlace( tensorflow::gtl::ArraySlice operand_arrays, const IrArray& output_array, tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); // Given a loop-fusion node whose root is a dynamic-update-slice op whose // array-to-be-updated and output share the same buffer slice, emits @@ -76,7 +76,7 @@ Status EmitFusedDynamicUpdateSliceInPlace( HloInstruction* fusion, tensorflow::gtl::ArraySlice fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); // Same as EmitFusedDynamicUpdateSliceInPlace, except emits a parallel loop with // the given launch dimensions. @@ -84,10 +84,9 @@ Status EmitParallelFusedDynamicUpdateSliceInPlace( HloInstruction* fusion, tensorflow::gtl::ArraySlice fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, - const gpu::LaunchDimensions& launch_dimensions, - llvm::IRBuilder<>* ir_builder); + const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b); } // namespace llvm_ir } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_DYNAMIC_UPDATE_SLICE_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index d909845a3a21fc55e44b0037371fca30e577980f..72ede377e1a505d5e4916915e18827e1a0f3fdf9 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -52,7 +52,7 @@ Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) { // that would be regenerated without caching. But this might increase the // JIT compilation time. if (generated_value_bb == nullptr || - generated_value_bb == ir_builder_->GetInsertBlock()) { + generated_value_bb == b_->GetInsertBlock()) { VLOG(3) << "The cached generated value is reused."; return generated_value; } @@ -60,8 +60,7 @@ Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) { "a different BB (" << llvm_ir::AsString(generated_value_bb->getName()) << ") from the current insertion block (" - << llvm_ir::AsString(ir_builder_->GetInsertBlock()->getName()) - << ")."; + << llvm_ir::AsString(b_->GetInsertBlock()->getName()) << ")."; } TF_ASSIGN_OR_RETURN( @@ -77,14 +76,14 @@ Status FusedIrEmitter::HandleConstant(HloInstruction* constant) { llvm::Constant* initializer = llvm_ir::ConvertLiteralToIrConstant(literal, module_); llvm::GlobalVariable* global = new llvm::GlobalVariable( - *ir_builder_->GetInsertBlock()->getModule(), initializer->getType(), + *b_->GetInsertBlock()->getModule(), initializer->getType(), /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer, /*Name=*/""); llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( global, llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); generators_[constant] = [=](const IrArray::Index& index) { return IrArray(shape_constant, constant->shape()) - .EmitReadArrayElement(index, ir_builder_); + .EmitReadArrayElement(index, b_); }; return Status::OK(); @@ -104,7 +103,7 @@ Status FusedIrEmitter::HandleGetTupleElement( // Emit code to lookup tuple element pointer, and store it in 'gte_values_'. llvm::Value* tuple_element_ptr = llvm_ir::EmitGetTupleElement( get_tuple_element->shape(), get_tuple_element->tuple_index(), - /*alignment=*/1, it->second, ir_builder_, module_); + /*alignment=*/1, it->second, b_, module_); gte_values_.insert(std::make_pair(get_tuple_element, tuple_element_ptr)); // Emit code to read base tuple element array (if non-tuple shaped). if (!ShapeUtil::IsTuple(get_tuple_element->shape())) { @@ -112,16 +111,32 @@ Status FusedIrEmitter::HandleGetTupleElement( [=](const IrArray::Index& index) -> StatusOr { // TODO(b/34080002) Add aliasing information to tuple element IrArray. return IrArray(tuple_element_ptr, get_tuple_element->shape()) - .EmitReadArrayElement(index, ir_builder_); + .EmitReadArrayElement(index, b_); }; } return Status::OK(); } Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) { - generators_[parameter] = [=](const IrArray::Index& index) { + generators_[parameter] = [=](const IrArray::Index& index) -> llvm::Value* { + if (tiled_parameter_info_) { + if (llvm::Value* param_tile_buffer = + tiled_parameter_info_->GetBufferForParameter( + parameter->parameter_number())) { + // TODO(jlebar): Add AA metadata to this load. Tile buffers are global + // variables, so LLVM's points-to analysis doesn't help us much. And we + // want the AA info to be present before address spaces are inferred + // (which is pretty late in the pipeline), so even if we had + // address-space-based AA in LLVM, it wouldn't help us much here. + return b_->CreateLoad( + b_->CreateGEP(param_tile_buffer, {index.GetConstantWithIndexType(0), + tiled_parameter_info_->x(), + tiled_parameter_info_->y()}), + "tiled_buffer"); + } + } return parameter_arrays_[parameter->parameter_number()] - .EmitReadArrayElement(index, ir_builder_); + .EmitReadArrayElement(index, b_); }; // Store ir value for fusion operand associated with fusion parameter to be // accessed by subsequent fused GetTupleElement instructions. @@ -140,11 +155,11 @@ Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) { } generators_[tuple] = [=](const IrArray::Index& index) -> StatusOr { - llvm::Value* ret = llvm::UndefValue::get(llvm::StructType::get( - ir_builder_->getContext(), operand_elemental_ir_types)); + llvm::Value* ret = llvm::UndefValue::get( + llvm::StructType::get(b_->getContext(), operand_elemental_ir_types)); for (size_t i = 0; i < ShapeUtil::TupleElementCount(tuple->shape()); ++i) { TF_ASSIGN_OR_RETURN(llvm::Value * val_i, generators_[operands[i]](index)); - ret = ir_builder_->CreateInsertValue(ret, val_i, i); + ret = b_->CreateInsertValue(ret, val_i, i); } return ret; }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index b3b6026ef17daa184c0a015fdea618597ef068b3..30471480c4fb3ce3bf3226a28e9d2ffa79ae5f29 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -56,8 +57,9 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { FusedIrEmitter(tensorflow::gtl::ArraySlice parameter_arrays, ElementalIrEmitter* elemental_emitter) : parameter_arrays_(parameter_arrays), + tiled_parameter_info_(nullptr), elemental_emitter_(elemental_emitter), - ir_builder_(elemental_emitter->ir_builder()), + b_(elemental_emitter->b()), module_(elemental_emitter->module()) {} Status DefaultAction(HloInstruction* hlo) override; @@ -86,9 +88,14 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { return it->second; } + void SetTiledParameterInfo(const llvm_ir::TiledParameterInfo* info) { + tiled_parameter_info_ = info; + } + private: // Arrays of parameters of fusion instruction tensorflow::gtl::ArraySlice parameter_arrays_; + const llvm_ir::TiledParameterInfo* tiled_parameter_info_; ElementalIrEmitter* elemental_emitter_; @@ -96,7 +103,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { const HloInstruction* fused_root_ = nullptr; // Borrowed - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; llvm::Module* module_; // Map from instruction pointers to functions to generate elements of their diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index ea10cef49a4a9aa048b3e0ea443f052645c4912a..2b6caee6aa72f426cf85c8c56c3ef500ff8c5d3d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -31,7 +31,7 @@ namespace llvm_ir { void IrArray::Index::Delinearize(std::vector* multidim, llvm::Value* linear, const Shape& shape, - llvm::IRBuilder<>* ir_builder) const { + llvm::IRBuilder<>* b) const { int64 divisor = 1; const Layout& layout = shape.layout(); for (int64 i = 0; i < layout.minor_to_major_size(); ++i) { @@ -48,10 +48,9 @@ void IrArray::Index::Delinearize(std::vector* multidim, // useful because cuda-memcheck can't help us much in XLA: Most of our // memory lives in one big allocation, so cuda-memcheck can't detect // out-of-bounds accesses. - auto* quot = - ir_builder->CreateUDiv(linear, GetConstantWithIndexType(divisor)); + auto* quot = b->CreateUDiv(linear, GetConstantWithIndexType(divisor)); if (i < layout.minor_to_major_size() - 1) { - (*multidim)[dimension] = ir_builder->CreateURem( + (*multidim)[dimension] = b->CreateURem( quot, GetConstantWithIndexType(size_of_current_dimension)); } else { (*multidim)[dimension] = quot; @@ -61,7 +60,7 @@ void IrArray::Index::Delinearize(std::vector* multidim, } IrArray::Index::Index(llvm::Value* linear, const Shape& shape, - llvm::IRBuilder<>* ir_builder) + llvm::IRBuilder<>* b) : multidim_(ShapeUtil::Rank(shape)), linear_(linear), layout_(shape.layout()), @@ -71,7 +70,7 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape, CHECK(LayoutUtil::HasLayout(shape)) << "Shape " << ShapeUtil::HumanStringWithLayout(shape) << " should have a layout."; - Delinearize(&multidim_, linear, shape, ir_builder); + Delinearize(&multidim_, linear, shape, b); } IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, @@ -94,7 +93,7 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, } IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, - const Shape& shape, llvm::IRBuilder<>* ir_builder) + const Shape& shape, llvm::IRBuilder<>* b) : multidim_(multidim.begin(), multidim.end()), layout_(shape.layout()), dims_(shape.dimensions().begin(), shape.dimensions().end()) { @@ -328,6 +327,7 @@ llvm::Value* IrArray::Index::Linearize( llvm::IRBuilder<>* builder) const { // Each dimension is multiplied by the product of the sizes of all // earlier dimensions and added to the accumulator logical_linear_index. + CHECK_EQ(size(), dimensions.size()); llvm::Value* logical_linear_index = GetConstantWithIndexType(0); int64 multiplier = 1; for (ssize_t i = size() - 1; i >= 0; --i) { @@ -343,7 +343,7 @@ llvm::Value* IrArray::Index::Linearize( } llvm::Value* IrArray::EmitArrayElementAddress( - const IrArray::Index& index, llvm::IRBuilder<>* ir_builder, + const IrArray::Index& index, llvm::IRBuilder<>* b, tensorflow::StringPiece name) const { if (ShapeUtil::IsScalar(*shape_)) { // Special handling of scalars: a scalar pretends to have the same value for @@ -354,12 +354,11 @@ llvm::Value* IrArray::EmitArrayElementAddress( CHECK_EQ(index.size(), ShapeUtil::Rank(*shape_)); if (index.LinearValidOnShape(*shape_)) { - llvm::Module* module = - ir_builder->GetInsertBlock()->getParent()->getParent(); - return ir_builder->CreateInBoundsGEP( - ir_builder->CreateBitCast( - base_ptr_, PrimitiveTypeToIrType(shape_->element_type(), module) - ->getPointerTo()), + llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); + return b->CreateInBoundsGEP( + b->CreateBitCast(base_ptr_, + PrimitiveTypeToIrType(shape_->element_type(), module) + ->getPointerTo()), {index.linear()}, llvm_ir::AsStringRef(name)); } @@ -385,8 +384,8 @@ llvm::Value* IrArray::EmitArrayElementAddress( int64 dimension = LayoutUtil::Major(shape_->layout(), i); gep_indices.push_back(actual_index[dimension]); } - return ir_builder->CreateInBoundsGEP(base_ptr_, gep_indices, - llvm_ir::AsStringRef(name)); + return b->CreateInBoundsGEP(base_ptr_, gep_indices, + llvm_ir::AsStringRef(name)); } void IrArray::AnnotateLoadStoreInstructionWithMetadata( @@ -402,37 +401,37 @@ void IrArray::AnnotateLoadStoreInstructionWithMetadata( } llvm::Value* IrArray::EmitReadArrayElement(const Index& index, - llvm::IRBuilder<>* ir_builder, + llvm::IRBuilder<>* b, tensorflow::StringPiece name) const { - llvm::Value* element_address = - EmitArrayElementAddress(index, ir_builder, name); - llvm::LoadInst* load = ir_builder->CreateLoad(element_address); + llvm::Value* element_address = EmitArrayElementAddress(index, b, name); + llvm::LoadInst* load = b->CreateLoad(element_address); AnnotateLoadStoreInstructionWithMetadata(load); return load; } void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value, - llvm::IRBuilder<>* ir_builder) const { - llvm::Value* element_address = EmitArrayElementAddress(index, ir_builder); - llvm::StoreInst* store = ir_builder->CreateStore(value, element_address); + llvm::IRBuilder<>* b) const { + llvm::Value* element_address = EmitArrayElementAddress(index, b); + llvm::StoreInst* store = b->CreateStore(value, element_address); AnnotateLoadStoreInstructionWithMetadata(store); } IrArray IrArray::CastToShape(const Shape& new_shape, - llvm::IRBuilder<>* ir_builder) const { - llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent(); + llvm::IRBuilder<>* b) const { + llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module); - return IrArray( - ir_builder->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()), - new_shape); + IrArray new_irarray( + b->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()), new_shape); + new_irarray.metadata_ = metadata_; + return new_irarray; } /* static */ IrArray::Index IrArray::BumpIndex(const Index& index, int64 which_dimension, int64 addend, - llvm::IRBuilder<>* ir_builder) { + llvm::IRBuilder<>* b) { Index new_index = index; - new_index[which_dimension] = ir_builder->CreateAdd( + new_index[which_dimension] = b->CreateAdd( index[which_dimension], llvm::ConstantInt::get(index[which_dimension]->getType(), addend), "", /*HasNUW=*/true, diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 4648c6d7ac089dbea7e660dd9889d557c8ad7318..28ca793e3eeaed86664bfa6aa859a38f2c4dc6f3 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -87,20 +87,19 @@ class IrArray { } // Constructs an index from linear index "linear" and computes the - // multi-dimensional index from "linear" and "shape". "ir_builder" is the IR + // multi-dimensional index from "linear" and "shape". "b" is the IR // builder to emit the index of each dimension in the multi-dimensional // index. // // Precondition: "shape" has a layout. - Index(llvm::Value* linear, const Shape& shape, - llvm::IRBuilder<>* ir_builder); + Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b); // Constructs an index from the given multi-dimensional index and the shape // that it indexes into. // // Precondition: "shape" has a layout. Index(tensorflow::gtl::ArraySlice multidim, - const Shape& shape, llvm::IRBuilder<>* ir_builder); + const Shape& shape, llvm::IRBuilder<>* b); // Constructs an index from both a multi-dimensional index and a linear // index. "shape" has the same meaning as that in the constructor that takes @@ -114,19 +113,19 @@ class IrArray { size_t size() const { return multidim().size(); } llvm::Value* operator[](size_t i) const { return multidim()[i]; } - llvm::Value*& operator[](size_t i) { return multidim()[i]; } + llvm::Value*& operator[](size_t i) { return mutable_multidim()[i]; } - void push_back(llvm::Value* value) { multidim().push_back(value); } + void push_back(llvm::Value* value) { mutable_multidim().push_back(value); } void InsertAt(int64 index, llvm::Value* value) { CHECK_LE(index, size()); - multidim().insert(multidim().begin() + index, value); + mutable_multidim().insert(mutable_multidim().begin() + index, value); } using iterator = std::vector::iterator; using const_iterator = std::vector::const_iterator; - iterator begin() { return multidim().begin(); } - iterator end() { return multidim().end(); } + iterator begin() { return mutable_multidim().begin(); } + iterator end() { return mutable_multidim().end(); } const_iterator begin() const { return multidim().begin(); } const_iterator end() const { return multidim().end(); } @@ -185,13 +184,13 @@ class IrArray { private: // Changing the multi-dimensional index invalidates the linear index. - std::vector& multidim() { + std::vector& mutable_multidim() { linear_ = nullptr; return multidim_; } void Delinearize(std::vector* multidim, llvm::Value* linear, - const Shape& shape, llvm::IRBuilder<>* ir_builder) const; + const Shape& shape, llvm::IRBuilder<>* b) const; std::vector multidim_; @@ -240,8 +239,7 @@ class IrArray { // // The optional name is useful for debugging when looking at // the emitted LLVM IR. - llvm::Value* EmitArrayElementAddress(const Index& index, - llvm::IRBuilder<>* ir_builder, + llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b, tensorflow::StringPiece name = "") const; // Attach metadata this IrArray instance knows about to "instruction". @@ -255,18 +253,16 @@ class IrArray { // // The optional name is useful for debugging when looking at // the emitted LLVM IR. - llvm::Value* EmitReadArrayElement(const Index& index, - llvm::IRBuilder<>* ir_builder, + llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b, tensorflow::StringPiece name = "") const; // Emit IR to write the given value to the array element at the given index. void EmitWriteArrayElement(const Index& index, llvm::Value* value, - llvm::IRBuilder<>* ir_builder) const; + llvm::IRBuilder<>* b) const; // Returns a new IrArray whose shape is "new_shape" and base pointer is a // bitcast of the base pointer of "this" IrArray. - IrArray CastToShape(const Shape& new_shape, - llvm::IRBuilder<>* ir_builder) const; + IrArray CastToShape(const Shape& new_shape, llvm::IRBuilder<>* b) const; void AddAliasScopeMetadata(llvm::MDNode* alias_scope) { CHECK_NE(alias_scope, nullptr); @@ -312,7 +308,7 @@ class IrArray { // Bumps the "which_dimension" value within the provided index by the provided // addend. static Index BumpIndex(const Index& index, int64 which_dimension, - int64 addend, llvm::IRBuilder<>* ir_builder); + int64 addend, llvm::IRBuilder<>* b); private: // Add the specified LLVM IR metadata to loads/stores associated with this diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index 1f6e3c829f890d68aa251b101f0402c120a19d61..b79567369aa532c4963e3941f6cb9844cd1476dd 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -22,9 +22,9 @@ Status KernelSupportLibrary::For( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - return If(ir_builder_->CreateICmpSLT(start, end), [&]() -> Status { + return If(b_->CreateICmpSLT(start, end), [&]() -> Status { TF_RETURN_IF_ERROR(for_body_generator(start, /*is_first_iteration=*/true)); - return For(name, ir_builder_->CreateAdd(start, step), end, step, + return For(name, b_->CreateAdd(start, step), end, step, [&](llvm::Value* iv) { return for_body_generator(iv, false); }); }); } @@ -37,44 +37,44 @@ Status KernelSupportLibrary::For( if (peel_first_iteration) { return For(name, start, end, step, true, [&](llvm::Value* indvar, bool is_first_iteration) -> Status { - return for_body_generator( - indvar, ir_builder_->getInt1(is_first_iteration)); + return for_body_generator(indvar, + b_->getInt1(is_first_iteration)); }); } else { std::unique_ptr loop = llvm_ir::ForLoop::EmitForLoop( - name, start, end, step, ir_builder_, + name, start, end, step, b_, /*unroll_mode=*/unroll_mode_, /*prevent_vectorization=*/prevent_vectorization_); - ir_builder_->SetInsertPoint(&loop->GetBodyBasicBlock()->back()); + b_->SetInsertPoint(&loop->GetBodyBasicBlock()->back()); TF_RETURN_IF_ERROR( for_body_generator(loop->GetIndVarValue(), - /*is_first_iteration=*/ir_builder_->CreateICmpEQ( + /*is_first_iteration=*/b_->CreateICmpEQ( loop->GetIndVarValue(), start))); - llvm_ir::SetToLastInsertPoint(loop->GetExitBasicBlock(), ir_builder_); + llvm_ir::SetToLastInsertPoint(loop->GetExitBasicBlock(), b_); return Status::OK(); } } Status KernelSupportLibrary::If( - llvm::Value* condition, const std::function& true_block_generator, + tensorflow::StringPiece name, llvm::Value* condition, + const std::function& true_block_generator, const std::function& false_block_generator) { - llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(condition, "", ir_builder_); - ir_builder_->SetInsertPoint(&if_data.true_block->back()); + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(condition, name, b_); + b_->SetInsertPoint(&if_data.true_block->back()); TF_RETURN_IF_ERROR(true_block_generator()); - ir_builder_->SetInsertPoint(&if_data.false_block->back()); + b_->SetInsertPoint(&if_data.false_block->back()); TF_RETURN_IF_ERROR(false_block_generator()); - llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_); + llvm_ir::SetToLastInsertPoint(if_data.after_block, b_); return Status::OK(); } void KernelSupportLibrary::EmitAndCallOutlinedKernel( - bool enable_fast_math, bool optimize_for_size, - llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, + tensorflow::StringPiece kernel_name, KernelSupportLibrary::ArgumentVector arguments, const std::function& kernel_body_generator) { - llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); + llvm::Module* module = b->GetInsertBlock()->getModule(); llvm::Function* function = module->getFunction(llvm_ir::AsStringRef(kernel_name)); @@ -97,22 +97,22 @@ void KernelSupportLibrary::EmitAndCallOutlinedKernel( std::back_inserter(arg_types), [](llvm::Value* arg) { return arg->getType(); }); - auto* function_type = llvm::FunctionType::get( - ir_builder->getVoidTy(), arg_types, /*isVarArg=*/false); + auto* function_type = + llvm::FunctionType::get(b->getVoidTy(), arg_types, /*isVarArg=*/false); function = llvm_ir::CreateFunction( function_type, llvm::GlobalValue::InternalLinkage, /*enable_fast_math=*/enable_fast_math, /*optimize_for_size=*/optimize_for_size, kernel_name, module); - llvm::IRBuilder<>::InsertPointGuard guard(*ir_builder); + llvm::IRBuilder<>::InsertPointGuard guard(*b); auto* entry_bb = - llvm::BasicBlock::Create(ir_builder->getContext(), "entry", function); - auto* return_inst = llvm::ReturnInst::Create(ir_builder->getContext(), + llvm::BasicBlock::Create(b->getContext(), "entry", function); + auto* return_inst = llvm::ReturnInst::Create(b->getContext(), /*retVal=*/nullptr, entry_bb); // Set the insert point to before return_inst. - ir_builder->SetInsertPoint(return_inst); + b->SetInsertPoint(return_inst); std::vector arg_values; /* @@ -132,7 +132,7 @@ void KernelSupportLibrary::EmitAndCallOutlinedKernel( VLOG(3) << "Re-using kernel for " << kernel_name; } - ir_builder->CreateCall(function, llvm_ir::AsArrayRef(sanitized_args)); + b->CreateCall(function, llvm_ir::AsArrayRef(sanitized_args)); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index 6f7a9d94e3b9e59b2dfe12b9673335a904ae78b6..b00f903d56a83c5b76188007702470c44c55c213 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -30,14 +30,14 @@ namespace xla { // flow more readable. class KernelSupportLibrary { public: - // `ir_builder` is the llvm::IRBuilder instance used to generate LLVM IR. + // `b` is the llvm::IRBuilder instance used to generate LLVM IR. // `unroll_mode` specifies the desired LLVM unrolling behavior for every loop // generated by this instance of KernelSupportLibrary. explicit KernelSupportLibrary( - llvm::IRBuilder<>* ir_builder, + llvm::IRBuilder<>* b, llvm_ir::UnrollMode unroll_mode = llvm_ir::UnrollMode::kNoUnroll, bool prevent_vectorization = true) - : ir_builder_(ir_builder), + : b_(b), unroll_mode_(unroll_mode), prevent_vectorization_(prevent_vectorization) {} @@ -71,18 +71,18 @@ class KernelSupportLibrary { const std::function& for_body_generator) { - return For(name, /*start=*/ir_builder_->getInt64(start), - /*end=*/ir_builder_->getInt64(end), - /*step=*/ir_builder_->getInt64(step), for_body_generator); + return For(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } void ForReturnVoid( tensorflow::StringPiece name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - ForReturnVoid(name, /*start=*/ir_builder_->getInt64(start), - /*end=*/ir_builder_->getInt64(end), - /*step=*/ir_builder_->getInt64(step), for_body_generator); + ForReturnVoid(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } // Generates the following control flow structure if `peel_first_iteration` is @@ -184,17 +184,17 @@ class KernelSupportLibrary { Status For( tensorflow::StringPiece name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - return For(name, /*start=*/ir_builder_->getInt64(start), - /*end=*/ir_builder_->getInt64(end), - /*step=*/ir_builder_->getInt64(step), for_body_generator); + return For(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } void ForReturnVoid( tensorflow::StringPiece name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - ForReturnVoid(name, /*start=*/ir_builder_->getInt64(start), - /*end=*/ir_builder_->getInt64(end), - /*step=*/ir_builder_->getInt64(step), for_body_generator); + ForReturnVoid(name, /*start=*/b_->getInt64(start), + /*end=*/b_->getInt64(end), + /*step=*/b_->getInt64(step), for_body_generator); } // Generates the following control flow structure: @@ -203,16 +203,30 @@ class KernelSupportLibrary { // `true_block_generator()`; // else // `false_block_generator()`; - Status If(llvm::Value* condition, + Status If(tensorflow::StringPiece name, llvm::Value* condition, const std::function& true_block_generator, const std::function& false_block_generator = []() -> Status { return Status::OK(); }); + Status If(llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = + []() -> Status { return Status::OK(); }) { + return If("", condition, true_block_generator, false_block_generator); + } + void IfReturnVoid(llvm::Value* condition, const std::function& true_block_generator, const std::function& false_block_generator = []() { }) { - TF_CHECK_OK(If(condition, + IfReturnVoid("", condition, true_block_generator, false_block_generator); + } + + void IfReturnVoid(tensorflow::StringPiece name, llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() { + }) { + TF_CHECK_OK(If(name, condition, [&]() { true_block_generator(); return Status::OK(); @@ -244,41 +258,39 @@ class KernelSupportLibrary { // in a nullptr llvm::Value* in its position to `kernel_body_generator`. // Currently we only support at most one nullptr value in `arguments`. static void EmitAndCallOutlinedKernel( - bool enable_fast_math, bool optimize_for_size, - llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, - ArgumentVector arguments, + bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, + tensorflow::StringPiece kernel_name, ArgumentVector arguments, const std::function& kernel_body_generator); // Thin wrappers around the more general EmitAndCallOutlinedKernel above. static void EmitAndCallOutlinedKernel( - bool enable_fast_math, bool optimize_for_size, - llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, - llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, + bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, + tensorflow::StringPiece kernel_name, llvm::Value* arg0, llvm::Value* arg1, + llvm::Value* arg2, const std::function& kernel_body_generator) { EmitAndCallOutlinedKernel( - enable_fast_math, optimize_for_size, ir_builder, kernel_name, - {arg0, arg1, arg2}, [&](ArgumentVector args) { + enable_fast_math, optimize_for_size, b, kernel_name, {arg0, arg1, arg2}, + [&](ArgumentVector args) { kernel_body_generator(args[0], args[1], args[2]); }); } static void EmitAndCallOutlinedKernel( - bool enable_fast_math, bool optimize_for_size, - llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, - llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, - llvm::Value* arg3, + bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, + tensorflow::StringPiece kernel_name, llvm::Value* arg0, llvm::Value* arg1, + llvm::Value* arg2, llvm::Value* arg3, const std::function& kernel_body_generator) { EmitAndCallOutlinedKernel( - enable_fast_math, optimize_for_size, ir_builder, kernel_name, + enable_fast_math, optimize_for_size, b, kernel_name, {arg0, arg1, arg2, arg3}, [&](ArgumentVector args) { kernel_body_generator(args[0], args[1], args[2], args[3]); }); } private: - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; llvm_ir::UnrollMode unroll_mode_; bool prevent_vectorization_; }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc new file mode 100644 index 0000000000000000000000000000000000000000..35b394127288d816952b48c84b193257bab0bcda --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -0,0 +1,118 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace llvm_ir { + +namespace { +// Returns the indices of the first elements of all consecutive subarrays of the +// given array. For example: +// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4} +std::vector ConsecutiveSegments(tensorflow::gtl::ArraySlice xs) { + std::vector is = {0}; + for (size_t i = 1; i < xs.size(); ++i) { + if (1 != xs[i] - xs[i - 1]) { + is.push_back(i); + } + } + return is; +} + +// Merges the sequences of dimensions of the given shape which start at the +// given indices `segs`. +Shape MergeDimensions(tensorflow::gtl::ArraySlice segs, + const Shape& shape) { + std::vector dimensions; + for (size_t i = 1; i <= segs.size(); ++i) { + dimensions.push_back(std::accumulate( + shape.dimensions().begin() + segs[i - 1], + shape.dimensions().begin() + + (segs.size() == i ? shape.dimensions().size() : segs[i]), + 1, std::multiplies())); + } + return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), + dimensions); +} +} // namespace + +tensorflow::gtl::optional > FindTranspose021( + const Shape& a, const Shape& b) { + if (!ShapeUtil::CompatibleIgnoringElementType(a, b)) { + return tensorflow::gtl::nullopt; + } + + std::vector perm(a.dimensions().size()); + { + auto layout_a_orig = LayoutUtil::MinorToMajor(a); + std::vector layout_a(layout_a_orig.rbegin(), layout_a_orig.rend()); + auto layout_b_orig = LayoutUtil::MinorToMajor(b); + std::vector layout_b(layout_b_orig.rbegin(), layout_b_orig.rend()); + for (size_t i = 0; i < perm.size(); ++i) { + perm[i] = PositionInContainer(layout_b, layout_a[i]); + } + } + auto segs = ConsecutiveSegments(perm); + if ((3 == segs.size() && 0 == perm[0]) || 2 == segs.size()) { + Shape norm_a = + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a); + Shape reduced_a = MergeDimensions(segs, norm_a); + auto reduced_a_dims = reduced_a.dimensions(); + std::vector dims_021; + if (2 == segs.size()) { + // The logical component-0 is of size one. + dims_021 = {1, reduced_a_dims[1], reduced_a_dims[0]}; + } else { + dims_021 = {reduced_a_dims[0], reduced_a_dims[2], reduced_a_dims[1]}; + } + + return dims_021; + } + + return tensorflow::gtl::nullopt; +} + +IrArray::Index GetUnreducedOutputIndex( + const IrArray::Index& reduced_output_index, + const Shape& reduced_output_shape, const Shape& unreduced_output_shape, + llvm::IRBuilder<>* b) { + auto bounds = reduced_output_shape.dimensions(); + auto minor_to_major = reduced_output_shape.layout().minor_to_major(); + llvm::Value* linear_index = reduced_output_index.GetConstantWithIndexType(0); + int64 multiplier = 1; + for (int i = 0; i < reduced_output_index.size(); ++i) { + int64 dim = minor_to_major[i]; + llvm::Value* addend = + b->CreateMul(reduced_output_index[dim], + reduced_output_index.GetConstantWithIndexType(multiplier), + "linearizing", + /*HasNUW=*/true, /*HasNSW=*/true); + linear_index = b->CreateAdd(linear_index, addend, "", + /*HasNUW=*/true, /*HasNSW=*/true); + multiplier *= bounds[dim]; + } + + return IrArray::Index(linear_index, unreduced_output_shape, b); +} + +} // namespace llvm_ir +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h new file mode 100644 index 0000000000000000000000000000000000000000..ccb9b8ba3e6b0079664f2da92ce67224e176fa1d --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h @@ -0,0 +1,80 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_TILING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_TILING_H_ + +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" + +namespace xla { +namespace llvm_ir { + +// About 0-2-1 transpose: +// +// If a shape can be viewed as three logical components 0-1-2 in the order of +// major to minor, a 0-2-1-transpose changes the order of such logical +// components to 0-2-1. We call the shape being transposed the input shape and +// the transposed shape the output shape. The logical view of the input and +// output shapes for the transpose are called the 0-1-2 shape or reduced input +// shape and the 0-2-1 shape or the reduced output shape respectively. The +// original input and output shapes are called the unreduced input and output +// shapes. + +// If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the +// reduced shape of `b` or the 0-2-1 shape. +tensorflow::gtl::optional > FindTranspose021(const Shape& a, + const Shape& b); + +// Return the unreduced output index corresponding to the given reduced output +// index. +IrArray::Index GetUnreducedOutputIndex( + const IrArray::Index& reduced_output_index, + const Shape& reduced_output_shape, const Shape& unreduced_output_shape, + llvm::IRBuilder<>* b); + +// A class to represent information for tiled parameters to support IR emission +// for 021 transpose. +class TiledParameterInfo { + public: + TiledParameterInfo(tensorflow::gtl::ArraySlice param_buffers, + llvm::Value* y, llvm::Value* x) + : param_buffers_(param_buffers), y_(y), x_(x) {} + + llvm::Value* x() const { return x_; } + llvm::Value* y() const { return y_; } + + void set_x(llvm::Value* x) { x_ = x; } + void set_y(llvm::Value* y) { y_ = y; } + + llvm::Value* GetBufferForParameter(int64 index) const { + return param_buffers_[index]; + } + + private: + // Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr + // if the parameter is not tiled. + tensorflow::gtl::ArraySlice param_buffers_; + // The y coordinate within a tile. + llvm::Value* y_; + // The x coordinate within a tile. + llvm::Value* x_; +}; + +} // namespace llvm_ir +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_TILING_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index c9ae7d3afd5cdc21157732f6d0dfa824268e86bd..ba7f94834c7fd04d97cec012537244323308b8ce 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -47,27 +47,27 @@ ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, /* static */ std::unique_ptr ForLoop::EmitForLoop( tensorflow::StringPiece prefix, llvm::Value* start_index, - llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, + llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode, bool prevent_vectorization) { std::unique_ptr loop(new ForLoop(prefix, /*suffix=*/"", start_index, end_index, step, unroll_mode, prevent_vectorization)); - loop->Emit(ir_builder); + loop->Emit(b); return loop; } -void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { +void ForLoop::Emit(llvm::IRBuilder<>* b) { // The preheader block is the block the builder is currently emitting // code into. - preheader_bb_ = ir_builder->GetInsertBlock(); + preheader_bb_ = b->GetInsertBlock(); - llvm::BasicBlock::iterator insert_point = ir_builder->GetInsertPoint(); + llvm::BasicBlock::iterator insert_point = b->GetInsertPoint(); if (insert_point == preheader_bb_->end()) { // We're emitting the loop at the end of a basic block. Verify there is no // terminator (eg, branch) in the basic block. CHECK_EQ(nullptr, preheader_bb_->getTerminator()); - exit_bb_ = CreateLoopBB("loop_exit", ir_builder); + exit_bb_ = CreateLoopBB("loop_exit", b); } else { // We're emitting the loop into the middle of a basic block. splitBasicBlock // requires that this basic block be well-formed (have a terminator). @@ -86,51 +86,50 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { insert_before_bb_ = exit_bb_; // Create remaining basic block which form the inside of the loop. - header_bb_ = CreateLoopBB("loop_header", ir_builder); - body_bb_ = CreateLoopBB("loop_body", ir_builder); + header_bb_ = CreateLoopBB("loop_header", b); + body_bb_ = CreateLoopBB("loop_body", b); // Function entry basic block. // Emit alloca for the induction variable. We do this at the entry to the // basic block to ensure the alloc only executes once per function (we could // be emitting a nested loop). llvm::Function* func = preheader_bb_->getParent(); - ir_builder->SetInsertPoint(&func->getEntryBlock(), - func->getEntryBlock().getFirstInsertionPt()); + b->SetInsertPoint(&func->getEntryBlock(), + func->getEntryBlock().getFirstInsertionPt()); llvm::Value* indvar_address = - ir_builder->CreateAlloca(start_index_->getType(), nullptr, - AsStringRef(GetQualifiedName("invar_address"))); + b->CreateAlloca(start_index_->getType(), nullptr, + AsStringRef(GetQualifiedName("invar_address"))); // Preheader basic block. // Initialize induction variable starting index. Create branch to the header. - ir_builder->SetInsertPoint(preheader_bb_); - ir_builder->CreateStore(start_index_, indvar_address); + b->SetInsertPoint(preheader_bb_); + b->CreateStore(start_index_, indvar_address); // The preheader should not have a branch yet. CHECK_EQ(preheader_bb_->getTerminator(), nullptr); - ir_builder->CreateBr(header_bb_); + b->CreateBr(header_bb_); // Header basic block. // Emit the loop conditional branch. Load and compare indvar with ending // index and jump to loop exit if equal. Jump to body otherwise. - ir_builder->SetInsertPoint(header_bb_); - indvar_ = ir_builder->CreateLoad(indvar_address, - AsStringRef(GetQualifiedName("indvar"))); - llvm::Value* exit_cond = ir_builder->CreateICmpUGE(indvar_, end_index_); - ir_builder->CreateCondBr(/*Cond=*/exit_cond, - /*True=*/exit_bb_, /*False=*/body_bb_); + b->SetInsertPoint(header_bb_); + indvar_ = + b->CreateLoad(indvar_address, AsStringRef(GetQualifiedName("indvar"))); + llvm::Value* exit_cond = b->CreateICmpUGE(indvar_, end_index_); + b->CreateCondBr(/*Cond=*/exit_cond, + /*True=*/exit_bb_, /*False=*/body_bb_); // Body basic block. // Increment indvar, store indvar, and jump to header. - ir_builder->SetInsertPoint(body_bb_); + b->SetInsertPoint(body_bb_); llvm::Value* step = step_; llvm::Value* indvar = indvar_; - llvm::Value* indvar_inc = - ir_builder->CreateAdd(indvar, step, "invar.inc", - /*HasNUW=*/true, /*HasNSW=*/true); - ir_builder->CreateStore(indvar_inc, indvar_address); - llvm::BranchInst* back_branch = ir_builder->CreateBr(header_bb_); + llvm::Value* indvar_inc = b->CreateAdd(indvar, step, "invar.inc", + /*HasNUW=*/true, /*HasNSW=*/true); + b->CreateStore(indvar_inc, indvar_address); + llvm::BranchInst* back_branch = b->CreateBr(header_bb_); - std::vector loop_metadata = GetLoopMetadata(ir_builder); + std::vector loop_metadata = GetLoopMetadata(b); if (!loop_metadata.empty()) { llvm::LLVMContext* ctx = &start_index_->getContext(); auto temp_node = llvm::MDNode::getTemporary(*ctx, llvm::None); @@ -141,11 +140,10 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { } // Re-point the IR builder to the loop exit block. - ir_builder->SetInsertPoint(exit_bb_); + b->SetInsertPoint(exit_bb_); } -std::vector ForLoop::GetLoopMetadata( - llvm::IRBuilder<>* ir_builder) { +std::vector ForLoop::GetLoopMetadata(llvm::IRBuilder<>* b) { const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable"; const char* const kLlvmLoopUnrollFullMDName = "llvm.loop.unroll.full"; const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable"; @@ -160,7 +158,7 @@ std::vector ForLoop::GetLoopMetadata( if (prevent_vectorization_) { result.push_back(llvm::MDNode::get( *ctx, {llvm::MDString::get(*ctx, kLlvmLoopVectorizeMDName), - llvm::ConstantAsMetadata::get(ir_builder->getFalse())})); + llvm::ConstantAsMetadata::get(b->getFalse())})); } if (unroll_mode_ == xla::llvm_ir::UnrollMode::kFullyUnroll) { @@ -175,9 +173,8 @@ string ForLoop::GetQualifiedName(tensorflow::StringPiece name) { } llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder) { - return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name), - ir_builder); + llvm::IRBuilder<>* b) { + return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name), b); } std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, @@ -197,12 +194,12 @@ std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, bool prevent_vectorization) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. - ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); + b_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); } std::unique_ptr loop(new ForLoop( /*prefix=*/name_, suffix, start_index, end_index, stride, unroll_mode, prevent_vectorization)); - loop->Emit(ir_builder_); + loop->Emit(b_); if (outer_loop_preheader_bb_ == nullptr) { outer_loop_preheader_bb_ = loop->GetPreheaderBasicBlock(); @@ -262,5 +259,35 @@ IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( return index; } +IrArray::Index ForLoopNest::EmitOperandArrayLoopNest( + const llvm_ir::IrArray& operand_array, int64 dimension_to_skip, + tensorflow::StringPiece name_suffix) { + // Prepares the dimension list we will use to emit the loop nest. Outermost + // loops are added first. Add loops in major-to-minor order, and skip the + // 'dimension_to_skip' dimension. + std::vector dimensions; + const Shape& shape = operand_array.GetShape(); + for (int64 dimension : LayoutUtil::MinorToMajor(shape)) { + if (dimension != dimension_to_skip) { + dimensions.push_back(dimension); + } + } + + // Create loop nest with one for-loop for each dimension of the + // output. + llvm_ir::IrArray::Index index = + AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix); + // Verify every dimension except the 'dimension_to_skip' dimension was set in + // the index. + for (size_t dimension = 0; dimension < index.size(); ++dimension) { + if (dimension == dimension_to_skip) { + DCHECK_EQ(nullptr, index[dimension]); + } else { + DCHECK_NE(nullptr, index[dimension]); + } + } + return index; +} + } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index 0dd5b9d3b2656af68f76c2adfcb1f3a1385eeb91..a4fed5c8dc55d38d25031252e3960404a5bf84e6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -79,7 +79,7 @@ class ForLoop { // loop. static std::unique_ptr EmitForLoop( tensorflow::StringPiece prefix, llvm::Value* start_index, - llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, + llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode = llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); @@ -138,10 +138,10 @@ class ForLoop { UnrollMode unroll_mode, bool prevent_vectorization); // Emit the loop at the insert point of the builder. - void Emit(llvm::IRBuilder<>* ir_builder); + void Emit(llvm::IRBuilder<>* b); llvm::BasicBlock* CreateLoopBB(tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); // Creates a name for an LLVM construct, appending prefix_ and suffix_, if // they are set. @@ -149,7 +149,7 @@ class ForLoop { // Return a list of metadata nodes that should be associated with the // llvm::Loop for this `ForLoop`. - std::vector GetLoopMetadata(llvm::IRBuilder<>* ir_builder); + std::vector GetLoopMetadata(llvm::IRBuilder<>* b); string prefix_; string suffix_; @@ -177,19 +177,18 @@ class ForLoop { // A simple class for constructing nested for-loops. class ForLoopNest { public: - explicit ForLoopNest(llvm::IRBuilder<>* ir_builder, - llvm::Type* index_ty = nullptr) - : ForLoopNest(/*name=*/"", ir_builder) { + explicit ForLoopNest(llvm::IRBuilder<>* b, llvm::Type* index_ty = nullptr) + : ForLoopNest(/*name=*/"", b) { SetIndexType(index_ty); } - ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder, + ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* b, llvm::Type* index_ty = nullptr) : name_(std::string(name)), outer_loop_preheader_bb_(nullptr), outer_loop_exit_bb_(nullptr), inner_loop_body_bb_(nullptr), - ir_builder_(ir_builder) { + b_(b) { SetIndexType(index_ty); } @@ -248,6 +247,17 @@ class ForLoopNest { const Shape& shape, tensorflow::gtl::ArraySlice dimensions, tensorflow::StringPiece suffix); + // Emits a series of nested loops for iterating over an operand array. Loops + // are constructed in major to minor dimension layout order. No loop is + // emitted for the given 'dimension_to_skip'. The function returns an IrArray + // index for the given operand_array containing the indvars of the loops. All + // dimensions of the index are filled except for 'dimension_to_skip'. + // name_suffix is the string to append to the names of LLVM constructs (eg, + // basic blocks) constructed by this method. + IrArray::Index EmitOperandArrayLoopNest(const llvm_ir::IrArray& operand_array, + int64 dimension_to_skip, + tensorflow::StringPiece name_suffix); + // Convenience methods which return particular basic blocks of the outermost // or innermost loops. These methods return nullptr if no loops have been // added yet. @@ -259,7 +269,7 @@ class ForLoopNest { private: void SetIndexType(llvm::Type* index_ty) { - index_type_ = index_ty == nullptr ? ir_builder_->getInt64Ty() : index_ty; + index_type_ = index_ty == nullptr ? b_->getInt64Ty() : index_ty; } llvm::Constant* GetConstantWithIndexType(int64 c) const { @@ -278,7 +288,7 @@ class ForLoopNest { // has been added yet. llvm::BasicBlock* inner_loop_body_bb_; - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; llvm::Type* index_type_; diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 97bacc34b59118e60100e4749638d469a1ef1378..e6126881af8b8123e08a4eaa934b52a7fd378ce6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -26,7 +26,7 @@ limitations under the License. #include "llvm/Target/TargetOptions.h" #include "llvm/Transforms/Utils/Cloning.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" @@ -48,8 +48,8 @@ namespace { // Note, this function is only useful in an insertion context; in a global // (e.g. constants) context it will CHECK fail. -llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* ir_builder) { - auto block = CHECK_NOTNULL(ir_builder->GetInsertBlock()); +llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* b) { + auto block = CHECK_NOTNULL(b->GetInsertBlock()); auto fn = CHECK_NOTNULL(block->getParent()); auto module = CHECK_NOTNULL(fn->getParent()); return module; @@ -87,41 +87,41 @@ llvm::Value* EmitCallToIntrinsic( llvm::Intrinsic::ID intrinsic_id, tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice overloaded_types, - llvm::IRBuilder<>* ir_builder) { - llvm::Module* module = ModuleFromIRBuilder(ir_builder); + llvm::IRBuilder<>* b) { + llvm::Module* module = ModuleFromIRBuilder(b); llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration( module, intrinsic_id, AsArrayRef(overloaded_types)); - return ir_builder->CreateCall(intrinsic, AsArrayRef(operands)); + return b->CreateCall(intrinsic, AsArrayRef(operands)); } llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, - llvm::IRBuilder<>* ir_builder) { - if (ir_builder->getFastMathFlags().noNaNs()) { - auto cmp = ir_builder->CreateFCmpUGE(lhs_value, rhs_value); - return ir_builder->CreateSelect(cmp, lhs_value, rhs_value); + llvm::IRBuilder<>* b) { + if (b->getFastMathFlags().noNaNs()) { + auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value); + return b->CreateSelect(cmp, lhs_value, rhs_value); } else { - auto cmp_ge = ir_builder->CreateFCmpOGE(lhs_value, rhs_value); - auto lhs_is_nan = ir_builder->CreateFCmpUNE(lhs_value, lhs_value); - auto sel_lhs = ir_builder->CreateOr(cmp_ge, lhs_is_nan); - return ir_builder->CreateSelect(sel_lhs, lhs_value, rhs_value); + auto cmp_ge = b->CreateFCmpOGE(lhs_value, rhs_value); + auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value); + auto sel_lhs = b->CreateOr(cmp_ge, lhs_is_nan); + return b->CreateSelect(sel_lhs, lhs_value, rhs_value); } } llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value, - llvm::IRBuilder<>* ir_builder) { - if (ir_builder->getFastMathFlags().noNaNs()) { - auto cmp = ir_builder->CreateFCmpULE(lhs_value, rhs_value); - return ir_builder->CreateSelect(cmp, lhs_value, rhs_value); + llvm::IRBuilder<>* b) { + if (b->getFastMathFlags().noNaNs()) { + auto cmp = b->CreateFCmpULE(lhs_value, rhs_value); + return b->CreateSelect(cmp, lhs_value, rhs_value); } else { - auto cmp_le = ir_builder->CreateFCmpOLE(lhs_value, rhs_value); - auto lhs_is_nan = ir_builder->CreateFCmpUNE(lhs_value, lhs_value); - auto sel_lhs = ir_builder->CreateOr(cmp_le, lhs_is_nan); - return ir_builder->CreateSelect(sel_lhs, lhs_value, rhs_value); + auto cmp_le = b->CreateFCmpOLE(lhs_value, rhs_value); + auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value); + auto sel_lhs = b->CreateOr(cmp_le, lhs_is_nan); + return b->CreateSelect(sel_lhs, lhs_value, rhs_value); } } llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index, - llvm::IRBuilder<>* ir_builder) { + llvm::IRBuilder<>* b) { llvm::Type* array_type = array->getType(); CHECK(array_type->isPointerTy()); llvm::PointerType* array_type_as_pointer = @@ -131,16 +131,16 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index, << " array=" << llvm_ir::DumpToString(*array) << " index=" << llvm_ir::DumpToString(*index); - return ir_builder->CreateInBoundsGEP( + return b->CreateInBoundsGEP( array_type_as_pointer->getElementType(), array, llvm::isa(array) - ? llvm::ArrayRef({ir_builder->getInt64(0), index}) + ? llvm::ArrayRef({b->getInt64(0), index}) : index); } llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index, - llvm::IRBuilder<>* ir_builder) { - return EmitBufferIndexingGEP(array, ir_builder->getInt64(index), ir_builder); + llvm::IRBuilder<>* b) { + return EmitBufferIndexingGEP(array, b->getInt64(index), b); } llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, @@ -232,14 +232,15 @@ llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) { return result_type; } -StatusOr EncodeSelfDescribingShapeConstant( - const Shape& shape, int32* shape_size, llvm::IRBuilder<>* ir_builder) { +StatusOr EncodeSelfDescribingShapeConstant(const Shape& shape, + int32* shape_size, + llvm::IRBuilder<>* b) { string encoded_shape = shape.SerializeAsString(); if (encoded_shape.size() > std::numeric_limits::max()) { return InternalError("Encoded shape size exceeded int32 size limit."); } *shape_size = static_cast(encoded_shape.size()); - return ir_builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(encoded_shape)); + return b->CreateGlobalStringPtr(llvm_ir::AsStringRef(encoded_shape)); } StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, @@ -262,59 +263,57 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder, + llvm::IRBuilder<>* b, int alignment) { - return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, ir_builder, - alignment); + return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, b, alignment); } llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder, int alignment) { - llvm::IRBuilder<>::InsertPoint insert_point = ir_builder->saveIP(); - llvm::Function* function = ir_builder->GetInsertBlock()->getParent(); - ir_builder->SetInsertPoint(&function->getEntryBlock(), - function->getEntryBlock().getFirstInsertionPt()); + llvm::IRBuilder<>* b, int alignment) { + llvm::IRBuilder<>::InsertPoint insert_point = b->saveIP(); + llvm::Function* function = b->GetInsertBlock()->getParent(); + b->SetInsertPoint(&function->getEntryBlock(), + function->getEntryBlock().getFirstInsertionPt()); llvm::AllocaInst* alloca = - ir_builder->CreateAlloca(type, element_count, AsStringRef(name)); + b->CreateAlloca(type, element_count, AsStringRef(name)); if (alignment != 0) { alloca->setAlignment(alignment); } - ir_builder->restoreIP(insert_point); + b->restoreIP(insert_point); return alloca; } llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder) { + llvm::IRBuilder<>* b) { return llvm::BasicBlock::Create( - /*Context=*/ir_builder->getContext(), + /*Context=*/b->getContext(), /*Name=*/AsStringRef(name), - /*Parent=*/ir_builder->GetInsertBlock()->getParent(), + /*Parent=*/b->GetInsertBlock()->getParent(), /*InsertBefore*/ insert_before); } LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder, bool emit_else) { + llvm::IRBuilder<>* b, bool emit_else) { llvm_ir::LlvmIfData if_data; - if_data.if_block = ir_builder->GetInsertBlock(); - if_data.true_block = CreateBasicBlock( - nullptr, tensorflow::strings::StrCat(name, "-true"), ir_builder); + if_data.if_block = b->GetInsertBlock(); + if_data.true_block = + CreateBasicBlock(nullptr, tensorflow::strings::StrCat(name, "-true"), b); if_data.false_block = - emit_else ? CreateBasicBlock(nullptr, - tensorflow::strings::StrCat(name, "-false"), - ir_builder) + emit_else ? CreateBasicBlock( + nullptr, tensorflow::strings::StrCat(name, "-false"), b) : nullptr; // Add a terminator to the if block, if necessary. if (if_data.if_block->getTerminator() == nullptr) { - ir_builder->SetInsertPoint(if_data.if_block); + b->SetInsertPoint(if_data.if_block); if_data.after_block = CreateBasicBlock( - nullptr, tensorflow::strings::StrCat(name, "-after"), ir_builder); - ir_builder->CreateBr(if_data.after_block); + nullptr, tensorflow::strings::StrCat(name, "-after"), b); + b->CreateBr(if_data.after_block); } else { if_data.after_block = if_data.if_block->splitBasicBlock( - ir_builder->GetInsertPoint(), + b->GetInsertPoint(), AsStringRef(tensorflow::strings::StrCat(name, "-after"))); } @@ -322,39 +321,37 @@ LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name, // we're going to replace it with a conditional branch. if_data.if_block->getTerminator()->eraseFromParent(); - ir_builder->SetInsertPoint(if_data.if_block); - ir_builder->CreateCondBr( - condition, if_data.true_block, - emit_else ? if_data.false_block : if_data.after_block); + b->SetInsertPoint(if_data.if_block); + b->CreateCondBr(condition, if_data.true_block, + emit_else ? if_data.false_block : if_data.after_block); - ir_builder->SetInsertPoint(if_data.true_block); - ir_builder->CreateBr(if_data.after_block); + b->SetInsertPoint(if_data.true_block); + b->CreateBr(if_data.after_block); if (emit_else) { - ir_builder->SetInsertPoint(if_data.false_block); - ir_builder->CreateBr(if_data.after_block); + b->SetInsertPoint(if_data.false_block); + b->CreateBr(if_data.after_block); } - ir_builder->SetInsertPoint(if_data.after_block, - if_data.after_block->getFirstInsertionPt()); + b->SetInsertPoint(if_data.after_block, + if_data.after_block->getFirstInsertionPt()); return if_data; } llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate, llvm::Value* lhs_value, llvm::Value* rhs_value, - llvm::IRBuilder<>* ir_builder) { + llvm::IRBuilder<>* b) { llvm::Value* comparison_result; if (lhs_value->getType()->isIntegerTy()) { - comparison_result = ir_builder->CreateICmp(predicate, lhs_value, rhs_value); + comparison_result = b->CreateICmp(predicate, lhs_value, rhs_value); } else { - comparison_result = ir_builder->CreateFCmp(predicate, lhs_value, rhs_value); + comparison_result = b->CreateFCmp(predicate, lhs_value, rhs_value); } // comparison_result is i1, but the NVPTX codegen incorrectly lowers i1 // arrays. So we extend it to i8 so that it's addressable. - return ir_builder->CreateZExt( - comparison_result, - llvm_ir::PrimitiveTypeToIrType(PRED, ModuleFromIRBuilder(ir_builder))); + return b->CreateZExt(comparison_result, llvm_ir::PrimitiveTypeToIrType( + PRED, ModuleFromIRBuilder(b))); } // Internal helper that is called from emitted code to log an int64 value with a @@ -363,17 +360,14 @@ static void LogS64(const char* tag, int64 value) { LOG(INFO) << tag << " (int64): " << value; } -void EmitLogging(const char* tag, llvm::Value* value, - llvm::IRBuilder<>* ir_builder) { +void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b) { llvm::FunctionType* log_function_type = llvm::FunctionType::get( - ir_builder->getVoidTy(), - {ir_builder->getInt64Ty(), ir_builder->getInt64Ty()}, /*isVarArg=*/false); - ir_builder->CreateCall( + b->getVoidTy(), {b->getInt64Ty(), b->getInt64Ty()}, /*isVarArg=*/false); + b->CreateCall( log_function_type, - ir_builder->CreateIntToPtr( - ir_builder->getInt64(tensorflow::bit_cast(&LogS64)), - log_function_type->getPointerTo()), - {ir_builder->getInt64(tensorflow::bit_cast(tag)), value}); + b->CreateIntToPtr(b->getInt64(tensorflow::bit_cast(&LogS64)), + log_function_type->getPointerTo()), + {b->getInt64(tensorflow::bit_cast(tag)), value}); } void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) { @@ -663,5 +657,56 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { } } +std::pair UMulLowHigh32(llvm::IRBuilder<>* b, + llvm::Value* src0, + llvm::Value* src1) { + CHECK_EQ(src0->getType()->getPrimitiveSizeInBits(), 32); + CHECK_EQ(src1->getType()->getPrimitiveSizeInBits(), 32); + llvm::Type* int64_ty = b->getInt64Ty(); + src0 = b->CreateZExt(src0, int64_ty); + src1 = b->CreateZExt(src1, int64_ty); + return SplitInt64ToInt32s(b, b->CreateMul(src0, src1)); +} + +std::pair SplitInt64ToInt32s( + llvm::IRBuilder<>* b, llvm::Value* value_64bits) { + CHECK_EQ(value_64bits->getType()->getPrimitiveSizeInBits(), 64); + llvm::Type* int32_ty = b->getInt32Ty(); + llvm::Value* low_32bits = b->CreateTrunc(value_64bits, int32_ty); + llvm::Value* high_32bits = + b->CreateTrunc(b->CreateLShr(value_64bits, 32), int32_ty); + return std::make_pair(low_32bits, high_32bits); +} + +llvm::GlobalVariable* GetOrCreateVariableForPhiloxRngState( + llvm::Module* module, llvm::IRBuilder<>* b) { + static const char* kPhiloxRngStateVariableName = "philox_rng_state"; + llvm::GlobalVariable* state_ptr = + module->getNamedGlobal(kPhiloxRngStateVariableName); + if (!state_ptr) { + state_ptr = new llvm::GlobalVariable( + /*M=*/*module, + /*Ty=*/b->getInt64Ty(), + /*isConstant=*/false, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/b->getInt64(0), + /*Name=*/kPhiloxRngStateVariableName); + } + return state_ptr; +} + +void IncrementVariableForPhiloxRngState(int64 value, llvm::Module* module, + llvm::IRBuilder<>* builder) { + llvm::GlobalVariable* state_ptr = + GetOrCreateVariableForPhiloxRngState(module, builder); + llvm::Value* state_value_old = builder->CreateLoad(state_ptr, "load_state"); + // If the 64-bit value overflows, we use the wraparound value. This should + // be fine in practice as we only add one to the value each time when a RNG is + // executed. + llvm::Value* state_value_new = builder->CreateAdd( + state_value_old, builder->getInt64(value), "inc_state"); + builder->CreateStore(state_value_new, state_ptr); +} + } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 4a10ec466dae6fdb56546fb8d8b353dcff6a5b8d..09583985342033d486d50910b6f5ca732a9a3756 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -27,7 +27,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "llvm/Support/raw_ostream.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/types.h" @@ -105,26 +105,26 @@ llvm::Value* EmitCallToIntrinsic( llvm::Intrinsic::ID intrinsic_id, tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice overloaded_types, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); // Emit float max. Emit maxnum intrinsic is fast math is disabled, or // fcmp+select otherwise llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); // Emit float min. Emit minnum intrinsic is fast math is disabled, or // fcmp+select otherwise llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); // Convenience methods for emitting a GEP instruction that indexes into a buffer // (1-dimensional array), equivalent to array[index]. The type is automatically // determined from the element type of the array. The int64 index overload // wraps the index in a i64 llvm::Value. llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); // Returns the LLVM type which represents the given XLA primitive type. llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, @@ -139,8 +139,9 @@ llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module); // Returns a value that represents a pointer to a global string constant that // encodes the shape as a serialized protobuf. -StatusOr EncodeSelfDescribingShapeConstant( - const Shape& shape, int32* shape_size, llvm::IRBuilder<>* ir_builder); +StatusOr EncodeSelfDescribingShapeConstant(const Shape& shape, + int32* shape_size, + llvm::IRBuilder<>* b); // Inverses the encoding of a Shape protobuf into an LLVM global variable. // @@ -164,21 +165,21 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, // through a loop. llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder, + llvm::IRBuilder<>* b, int alignment = 0); // As EmitAllocaAtFunctionEntry, but allocates element_count entries // instead of a single element. llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder, int alignment = 0); + llvm::IRBuilder<>* b, int alignment = 0); // Creates a basic block with the same context and function as for the // builder. Inserts at the end of the function if insert_before is // null. llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); // Struct with data on a conditional branch in a diamond shape created // via EmitIfThenElse. @@ -210,13 +211,13 @@ struct LlvmIfData { // block with a terminator. If you need to use this for a // non-terminated block, just make the function able to do that too. LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name, - llvm::IRBuilder<>* ir_builder, bool emit_else = true); + llvm::IRBuilder<>* b, bool emit_else = true); // Emits a compare operation between "lhs" and "rhs" with the given predicate, // and then converts the result to i8 so that it is addressable. llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate, llvm::Value* lhs, llvm::Value* rhs, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); // Emits a call that logs the given value with the given tag as a prefix. // The provided tag and value are passed to a runtime logging call that is @@ -228,8 +229,7 @@ llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate, // Precondition: value must be an int64. // Precondition: tag must be a stable pointer for the lifetime of the generated // program (the constant pointer is burned in to the program). -void EmitLogging(const char* tag, llvm::Value* value, - llvm::IRBuilder<>* ir_builder); +void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b); // Adds alignment metadata to a load instruction using the given alignment. // The alignment refers to the result of the load, not the load itself. @@ -292,6 +292,27 @@ llvm::Function* CreateFunction(llvm::FunctionType* function_type, // don't start with xla_ to LLVM. void InitializeLLVMCommandLineOptions(const HloModuleConfig& config); +// Zero-extends two 32-bit values to 64 bits, multiplies them, and returns the +// result as a pair of (low 32 bits, high 32 bits). +std::pair UMulLowHigh32(llvm::IRBuilder<>* b, + llvm::Value* src0, + llvm::Value* src1); +// Splits the 64-bit integer value into its high and low 32 bits. +std::pair SplitInt64ToInt32s( + llvm::IRBuilder<>* b, llvm::Value* value_64bits); + +// Checks whether a global variable is already created to represent a +// state passed between RNG calls implemented with Philox algorithm. If not, +// creates such a variable. Returns the global variable. +llvm::GlobalVariable* GetOrCreateVariableForPhiloxRngState( + llvm::Module* module, llvm::IRBuilder<>* b); + +// Adds a value to the global state variable each time when a RNG hlo is +// executed. The value of this global state variable is added to the seed +// of the Philox RNG algorithm so that calling the same RNG Hlo multiple times +// should rarely produce the same result. +void IncrementVariableForPhiloxRngState(int64 value, llvm::Module* module, + llvm::IRBuilder<>* b); } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index e8b0605b9d75677b34f0973d88d269a5795b7629..36f5fa195224c20e30a14f72b32eb42a681bb5e9 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -33,26 +33,24 @@ namespace xla { namespace llvm_ir { LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, - llvm::IRBuilder<>* ir_builder) - : body_emitter_(body_emitter), shape_(shape), ir_builder_(ir_builder) {} + llvm::IRBuilder<>* b) + : body_emitter_(body_emitter), shape_(shape), b_(b) {} LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, - const IrArray& target_array, - llvm::IRBuilder<>* ir_builder) + const IrArray& target_array, llvm::IRBuilder<>* b) : body_emitter_([=](const llvm_ir::IrArray::Index array_index) -> Status { // Convert target_element_generator to a BodyEmitter. TF_ASSIGN_OR_RETURN(llvm::Value * target_element, target_element_generator(array_index)); - target_array.EmitWriteArrayElement(array_index, target_element, - ir_builder); + target_array.EmitWriteArrayElement(array_index, target_element, b); return Status::OK(); }), shape_(target_array.GetShape()), - ir_builder_(ir_builder) {} + b_(b) {} static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion( const ElementGenerator& target_element_generator, - const std::vector& target_arrays, llvm::IRBuilder<>* ir_builder) { + const std::vector& target_arrays, llvm::IRBuilder<>* b) { return [=](const llvm_ir::IrArray::Index array_index) { TF_ASSIGN_OR_RETURN(llvm::Value * target_element, target_element_generator(array_index)); @@ -64,8 +62,7 @@ static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion( for (int64 i = 0; i < target_arrays.size(); ++i) { target_arrays[i].EmitWriteArrayElement( - array_index, ir_builder->CreateExtractValue(target_element, i), - ir_builder); + array_index, b->CreateExtractValue(target_element, i), b); } return Status::OK(); }; @@ -73,13 +70,12 @@ static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion( LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice target_arrays, - llvm::IRBuilder<>* ir_builder) + llvm::IRBuilder<>* b) : body_emitter_(MakeBodyEmitterForMultiOutputFusion( target_element_generator, - std::vector(target_arrays.begin(), target_arrays.end()), - ir_builder)), + std::vector(target_arrays.begin(), target_arrays.end()), b)), shape_(target_arrays[0].GetShape()), - ir_builder_(ir_builder) { + b_(b) { // Sanity check: In multi-output fusion, all shapes produced must have the // same dimensions. for (const IrArray& array : target_arrays) { @@ -102,7 +98,7 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( // Loops are added from outermost to innermost order with the ForLoopNest // class so emit loops in order from most-major dimension down to most-minor // dimension (of the target shape). - ForLoopNest loop_nest(loop_name, ir_builder_); + ForLoopNest loop_nest(loop_name, b_); IrArray::Index array_index(index_type, shape_.dimensions_size()); for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) { int64 dimension = LayoutUtil::Major(shape_.layout(), i); @@ -116,8 +112,8 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( // Set IR builder insertion point to the loop body basic block of the // innermost loop. llvm::BasicBlock* innermost_body_bb = loop_nest.GetInnerLoopBodyBasicBlock(); - ir_builder_->SetInsertPoint(innermost_body_bb, - innermost_body_bb->getFirstInsertionPt()); + b_->SetInsertPoint(innermost_body_bb, + innermost_body_bb->getFirstInsertionPt()); // Set exit_bb_ to the exit block of the loop nest. exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock(); @@ -129,7 +125,7 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name, llvm::Type* index_type) { if (index_type == nullptr) { - index_type = ir_builder_->getInt64Ty(); + index_type = b_->getInt64Ty(); } for (const IrArray::Index& array_index : @@ -137,10 +133,10 @@ Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name, TF_RETURN_IF_ERROR(body_emitter_(array_index)); } - // Set the insertion point of ir_builder_ to the loop exit, so that + // Set the insertion point of b_ to the loop exit, so that // code emitted for later instructions will be correctly placed. if (exit_bb_ != nullptr) { - ir_builder_->SetInsertPoint(exit_bb_); + b_->SetInsertPoint(exit_bb_); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index 6be1c2fba2cbd78a02865901ef8c5b7e2b2a74e6..c4f5c82086ccfa233e0be118b1de10cce55a51b1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -41,11 +41,11 @@ class LoopEmitter { using BodyEmitter = std::function; LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); // Constructs a LoopEmitter from an element generator that generates each // element of the given target array. LoopEmitter(const ElementGenerator& target_element_generator, - const IrArray& target_array, llvm::IRBuilder<>* ir_builder); + const IrArray& target_array, llvm::IRBuilder<>* b); // Constructs a LoopEmitter that emits one element into each of N separate // arrays on each iteration of the loop. @@ -54,7 +54,7 @@ class LoopEmitter { // produce an LLVM struct with N elements. LoopEmitter(const ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice target_arrays, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* b); LoopEmitter(const LoopEmitter&) = delete; LoopEmitter& operator=(const LoopEmitter&) = delete; @@ -65,8 +65,7 @@ class LoopEmitter { // specifies the element, will return multiple indices if the loop is // unrolled. std::vector EmitIndexAndSetExitBasicBlock() { - return EmitIndexAndSetExitBasicBlock(/*loop_name=*/"", - ir_builder_->getInt64Ty()); + return EmitIndexAndSetExitBasicBlock(/*loop_name=*/"", b_->getInt64Ty()); } virtual std::vector EmitIndexAndSetExitBasicBlock( @@ -87,7 +86,7 @@ class LoopEmitter { // scalar, no loops are emitted and exit_bb_ is nullptr in that case. llvm::BasicBlock* exit_bb_; - llvm::IRBuilder<>* ir_builder_; + llvm::IRBuilder<>* b_; }; } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/math_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/math_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..0e115cdabf4b290617700276dba8f2e5648a7c07 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/math_ops.cc @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" + +namespace xla { +namespace llvm_ir { + +llvm::Value* EmitFastTanh(llvm::IRBuilder<>* b, llvm::Value* input) { + llvm::Type* type = input->getType(); + + // Clamp the input to [-9, 9]. + llvm::Value* input_clamped = llvm_ir::EmitFloatMin( + llvm_ir::EmitFloatMax(input, llvm::ConstantFP::get(type, -9.0), b), + llvm::ConstantFP::get(type, 9.0), b); + + static constexpr std::array numerator_coeffs{ + -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, + 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f, + 4.89352455891786e-03f}; + + static constexpr std::array denominator_coeffs{ + 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, + 4.89352518554385e-03f}; + + llvm::Value* input_squared = b->CreateFMul(input_clamped, input_clamped); + llvm::Value* numerator = llvm::ConstantFP::get(type, numerator_coeffs[0]); + for (int i = 1; i < numerator_coeffs.size(); i++) { + numerator = b->CreateFAdd(b->CreateFMul(input_squared, numerator), + llvm::ConstantFP::get(type, numerator_coeffs[i])); + } + + numerator = b->CreateFMul(input_clamped, numerator); + + llvm::Value* denominator = llvm::ConstantFP::get(type, denominator_coeffs[0]); + for (int i = 1; i < denominator_coeffs.size(); i++) { + denominator = + b->CreateFAdd(b->CreateFMul(input_squared, denominator), + llvm::ConstantFP::get(type, denominator_coeffs[i])); + } + + return b->CreateFDiv(numerator, denominator); +} + +} // namespace llvm_ir +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/math_ops.h b/tensorflow/compiler/xla/service/llvm_ir/math_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..6c8bc3a076367eae2f1829966be2872e5f258178 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/math_ops.h @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_MATH_OPS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_MATH_OPS_H_ + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" + +namespace xla { +namespace llvm_ir { + +// Emits an approximation of tanh. The implementation uses the same rational +// interpolant as implemented in Eigen3. +llvm::Value* EmitFastTanh(llvm::IRBuilder<>* b, llvm::Value* input); + +} // namespace llvm_ir +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_MATH_OPS_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..e546f5cc4ae305b40c1bdbcae090daadee11241b --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -0,0 +1,161 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" + +// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace llvm_ir { + +namespace { +// Adds the inner comparison loop where we compare elements pointed to by +// 'keys_index' and 'compare_keys_index'. +void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, + const IrArray::Index& compare_keys_index, + const IrArray& keys_array, + const tensorflow::gtl::optional& values_array, + llvm::IRBuilder<>* b) { + // if (is_smaller_index && + // compare_keys[dimension_to_sort] < dimension_to_sort_bound) + llvm::Value* is_smaller_index = b->CreateICmpSLT( + keys_index[dimension_to_sort], compare_keys_index[dimension_to_sort]); + int64 dimension_to_sort_bound = + keys_array.GetShape().dimensions(dimension_to_sort); + auto if_data = EmitIfThenElse( + b->CreateAnd(is_smaller_index, + b->CreateICmpSLT(compare_keys_index[dimension_to_sort], + keys_index.GetConstantWithIndexType( + dimension_to_sort_bound))), + "smaller_comparison_index", b, /*emit_else=*/false); + SetToFirstInsertPoint(if_data.true_block, b); + auto key1 = keys_array.EmitReadArrayElement(keys_index, b); + auto key2 = keys_array.EmitReadArrayElement(compare_keys_index, b); + auto key_type = keys_array.GetShape().element_type(); + auto comparison = + primitive_util::IsFloatingPointType(key_type) + // TODO(b/26783907): Figure out how to handle NaNs. + ? b->CreateFCmp(llvm::FCmpInst::FCMP_ULT, key2, key1) + : b->CreateICmp(primitive_util::IsSignedIntegralType(key_type) + ? llvm::ICmpInst::ICMP_SLT + : llvm::ICmpInst::ICMP_ULT, + key2, key1); + // If key2 < key1 + auto if_smaller_data = + EmitIfThenElse(comparison, "is_smaller_than", b, /*emit_else=*/false); + SetToFirstInsertPoint(if_smaller_data.true_block, b); + // Swap key1 with key2. + keys_array.EmitWriteArrayElement(keys_index, key2, b); + keys_array.EmitWriteArrayElement(compare_keys_index, key1, b); + if (values_array.has_value()) { + // Also swap the values. + auto value1 = values_array.value().EmitReadArrayElement(keys_index, b); + auto value2 = + values_array.value().EmitReadArrayElement(compare_keys_index, b); + values_array.value().EmitWriteArrayElement(keys_index, value2, b); + values_array.value().EmitWriteArrayElement(compare_keys_index, value1, b); + } +} +} // namespace + +Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, + const tensorflow::gtl::optional& values_array, + tensorflow::StringPiece name, llvm::Value* xor_mask, + llvm::IRBuilder<>* b, + const gpu::LaunchDimensions* launch_dimensions) { + const Shape& keys_shape = keys_array.GetShape(); + + // Create loop nests which loop through the operand dimensions. The sort + // dimension is handled in the innermost loop which performs the sorting. + ForLoopNest loop_nest(name, b); + IrArray::Index keys_index = + loop_nest.EmitOperandArrayLoopNest(keys_array, dimension_to_sort, "keys"); + if (loop_nest.GetInnerLoopBodyBasicBlock() != nullptr) { + SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), b); + } + + // 'compare_keys_index' is the index of the element that 'keys_index' should + // be compared to. + IrArray::Index compare_keys_index(keys_index.GetType()); + for (size_t dimension = 0; dimension < keys_index.size(); ++dimension) { + if (dimension != dimension_to_sort) { + compare_keys_index.push_back(keys_index[dimension]); + } else { + compare_keys_index.push_back(nullptr); + } + } + + // Naive C++ code for the inner compare loop: + // + // for (int64 i = 0; i < dimension_to_sort_bound; ++i) { + // int64 j = i ^ xor_mask; + // if (i < j && j < dimension_to_sort_bound) { + // int64 min_key = std::min(keys[i], keys[j]); + // keys[j] = std::max(keys[i], keys[j]); + // keys[i] = min_key; + // } + // } + // + // This follows the algorithm described on Wikipedia: + // https://en.wikipedia.org/wiki/Bitonic_sorter + + int64 dimension_to_sort_bound = + keys_array.GetShape().dimensions(dimension_to_sort); + Shape compare_shape = ShapeUtil::MakeShape(keys_shape.element_type(), + {dimension_to_sort_bound}); + auto compare_loop_body_emitter = + [&](const IrArray::Index& compare_index) -> Status { + keys_index[dimension_to_sort] = compare_index[0]; + compare_keys_index[dimension_to_sort] = + b->CreateXor(compare_index[0], xor_mask); + EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index, + keys_array, values_array, b); + return Status::OK(); + }; + if (launch_dimensions != nullptr) { + TF_RETURN_IF_ERROR(gpu::ParallelLoopEmitter(compare_loop_body_emitter, + compare_shape, + *launch_dimensions, b) + .EmitLoop(name)); + } else { + TF_RETURN_IF_ERROR(LoopEmitter(compare_loop_body_emitter, compare_shape, b) + .EmitLoop(name)); + } + + // Set the IR builder insert point to the exit basic block of the outer most + // loop. This ensures later instructions are inserted after this loop nest. + b->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); + + return Status::OK(); +} + +} // namespace llvm_ir +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h new file mode 100644 index 0000000000000000000000000000000000000000..8458744c6bc0e50a1c1cc8d3e66e29c7d4f74d73 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ + +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace llvm_ir { +// Emits llvm IR to do pairwise comparisons/swaps in the 'dimension_to_sort' +// dimension of 'keys_array'. All other dimensions are kept as-is. This +// implements the inner loop of BitonicSort. If 'launch_dimensions' is nullptr, +// the inner compare loop will not be parallelized. +Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, + const tensorflow::gtl::optional& values_array, + tensorflow::StringPiece name, llvm::Value* xor_mask, + llvm::IRBuilder<>* b, + const gpu::LaunchDimensions* launch_dimensions); +} // namespace llvm_ir +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index 5fc08aab916e377b245b6221108956c06da70767..11ed6ee59f1bf8e7004b8bef7319b37ef41a304c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -31,12 +31,12 @@ namespace llvm_ir { void EmitTupleSelect(const IrArray& select, const IrArray& pred, llvm::Value* on_true, llvm::Value* on_false, - llvm::IRBuilder<>* ir_builder, llvm::Module* module) { + llvm::IRBuilder<>* b, llvm::Module* module) { CHECK(ShapeUtil::IsScalar(pred.GetShape())); llvm::LoadInst* pred_value = - ir_builder->CreateLoad(pred.GetBasePointer(), "load_predicate_value"); - llvm::Value* pred_cond = ir_builder->CreateICmpNE( + b->CreateLoad(pred.GetBasePointer(), "load_predicate_value"); + llvm::Value* pred_cond = b->CreateICmpNE( pred_value, llvm::ConstantInt::get(PrimitiveTypeToIrType(PRED, module), 0), "boolean_predicate"); @@ -46,47 +46,42 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred, VLOG(2) << " pred_cond: " << DumpToString(*pred_cond); for (int i = 0; i < ShapeUtil::TupleElementCount(select.GetShape()); ++i) { - llvm::Value* const element_index[] = {ir_builder->getInt64(0), - ir_builder->getInt64(i)}; + llvm::Value* const element_index[] = {b->getInt64(0), b->getInt64(i)}; llvm::Value* on_true_element_address = - ir_builder->CreateInBoundsGEP(on_true, element_index); - llvm::Value* on_true_element = ir_builder->CreateLoad( + b->CreateInBoundsGEP(on_true, element_index); + llvm::Value* on_true_element = b->CreateLoad( on_true_element_address, "on_true_element_" + llvm::Twine(i)); llvm::Value* on_false_element_address = - ir_builder->CreateInBoundsGEP(on_false, element_index); - llvm::Value* on_false_element = ir_builder->CreateLoad( + b->CreateInBoundsGEP(on_false, element_index); + llvm::Value* on_false_element = b->CreateLoad( on_false_element_address, "on_false_element_" + llvm::Twine(i)); llvm::Value* output_element_address = - ir_builder->CreateInBoundsGEP(select.GetBasePointer(), element_index); - ir_builder->CreateStore( - ir_builder->CreateSelect(pred_cond, on_true_element, on_false_element, - "select_output_element_" + llvm::Twine(i)), - output_element_address); + b->CreateInBoundsGEP(select.GetBasePointer(), element_index); + b->CreateStore(b->CreateSelect(pred_cond, on_true_element, on_false_element, + "select_output_element_" + llvm::Twine(i)), + output_element_address); } } void EmitTuple(const IrArray& tuple, tensorflow::gtl::ArraySlice operands, - llvm::IRBuilder<>* ir_builder, llvm::Module* module) { + llvm::IRBuilder<>* b, llvm::Module* module) { for (size_t i = 0; i < operands.size(); ++i) { - auto* store = ir_builder->CreateStore( - ir_builder->CreatePointerCast(operands[i], - PrimitiveTypeToIrType(TUPLE, module)), - ir_builder->CreateInBoundsGEP( - tuple.GetBasePointer(), - {ir_builder->getInt64(0), ir_builder->getInt64(i)})); + auto* store = b->CreateStore( + b->CreatePointerCast(operands[i], PrimitiveTypeToIrType(TUPLE, module)), + b->CreateInBoundsGEP(tuple.GetBasePointer(), + {b->getInt64(0), b->getInt64(i)})); tuple.AnnotateLoadStoreInstructionWithMetadata(store); } } llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, int alignment, llvm::Value* operand, - llvm::IRBuilder<>* ir_builder, - llvm::Module* module) { - llvm::Value* element_ptr = ir_builder->CreateInBoundsGEP( - operand, {ir_builder->getInt64(0), ir_builder->getInt64(index)}); - llvm::LoadInst* src_buffer = ir_builder->CreateLoad(element_ptr); + llvm::IRBuilder<>* b, llvm::Module* module) { + llvm::Value* element_ptr = + b->CreateInBoundsGEP(operand, {b->getInt64(0), b->getInt64(index)}); + llvm::LoadInst* src_buffer = b->CreateLoad(element_ptr); // Mark the loaded pointer as dereferenceable if we know its shape. if (!ShapeUtil::IsOpaque(target_shape)) { @@ -98,7 +93,7 @@ llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, llvm::Type* element_type = ShapeToIrType(target_shape, module); llvm::Value* ret_val = - ir_builder->CreateBitCast(src_buffer, element_type->getPointerTo()); + b->CreateBitCast(src_buffer, element_type->getPointerTo()); return ret_val; } diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h index 352d34ebf839c6c2465abade7c3d3eb3b7a34506..cf6bf5d0b14ba71cbed67f9a1dc728c0eef5e393 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h @@ -61,13 +61,13 @@ namespace llvm_ir { // output[i] = pred ? tuple_on_true[i] : tuple_on_false[i] void EmitTupleSelect(const IrArray& select, const IrArray& pred, llvm::Value* on_true, llvm::Value* on_false, - llvm::IRBuilder<>* ir_builder, llvm::Module* module); + llvm::IRBuilder<>* b, llvm::Module* module); // A tuple is an array of pointers, one for each operand. Each pointer points to // the output buffer of its corresponding operand. void EmitTuple(const IrArray& tuple, tensorflow::gtl::ArraySlice operands, - llvm::IRBuilder<>* ir_builder, llvm::Module* module); + llvm::IRBuilder<>* b, llvm::Module* module); // A tuple is an array of pointers, one for each operand. Each pointer points to // the output buffer of its corresponding operand. A GetTupleElement instruction @@ -75,8 +75,7 @@ void EmitTuple(const IrArray& tuple, // Returns an llvm value representing a pointer to the tuple element buffer. llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, int alignment, llvm::Value* operand, - llvm::IRBuilder<>* ir_builder, - llvm::Module* module); + llvm::IRBuilder<>* b, llvm::Module* module); } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 53efc30c3653879709fceae3dcdd4f679740f622..5e02096ee501b23a7976a50f13bb7e7f3c5e2d34 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 39d6734c3fc06df6832cf67edddbc7c14c815cd1..8f707ea9046a00a15cac469672a7a992f20bf483 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/executable_build_options.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index f410921b4b5337192bdeae5924631d9c06b7d5a5..d631fb5ee42df6525681a5cd1fe1a8241824121d 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -131,18 +131,23 @@ Status LogicalBufferAnalysis::HandleDomain(HloInstruction*) { return Status::OK(); } -Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction*) { - // RecvDone doesn't create a new buffer but rather aliases its input (Recv) - // tuple element at {0} to its output. +Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction* recv_done) { + // RecvDone produces a two-element tuple containing the data value (which + // aliases part of its operand) and a token. Only the tuple index table and + // the token are defined by the RecvDone. + NewLogicalBuffer(recv_done, /*index=*/{}); + NewLogicalBuffer(recv_done, /*index=*/{1}); return Status::OK(); } Status LogicalBufferAnalysis::HandleSend(HloInstruction* send) { - // Send creates new buffers for the top-level tuple and the context (tuple - // element at {1}). Tuple element at {0} is an alias of the Send operand, so - // we don't need to create a new Logical Buffer for that. + // Send creates new buffers for the top-level tuple, the context (tuple + // element at {1}), and the token (tuple element at {2}). Tuple element at {0} + // is an alias of the Send operand, so we don't need to create a new Logical + // Buffer for that. NewLogicalBuffer(send, /*index=*/{}); NewLogicalBuffer(send, /*index=*/{1}); + NewLogicalBuffer(send, /*index=*/{2}); return Status::OK(); } @@ -152,10 +157,10 @@ Status LogicalBufferAnalysis::HandleTuple(HloInstruction* tuple) { return Status::OK(); } -Status LogicalBufferAnalysis::HandleSelect(HloInstruction* select) { +Status LogicalBufferAnalysis::HandleTupleSelect(HloInstruction* tuple_select) { // Select allocates a new buffer and then shallow copies the on_true or // on_false buffer into this new buffer. - NewLogicalBuffer(select, /*index=*/{}); + NewLogicalBuffer(tuple_select, /*index=*/{}); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h index b5ef3967875a58b35631d5f69c210f5cbcd91250..81f524d84a8091e1fff13dc7c55b401143a02753 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h @@ -63,7 +63,7 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; - Status HandleSelect(HloInstruction* select) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; // A map from the buffer ID to the logical buffer std::vector> logical_buffers_; diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index 79b5a442aa0ecd0f67ffe4dad50465627d8913fd..4166ef5baf9c891968b584a0c498005e9ae87784 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -115,39 +115,18 @@ HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1, HloInstruction* fused = instr2; // Make sure that if only one of the instructions is a fusion, or if only one // of the instructions is a multi-output fusion, it's what will be fused into. - // - // An invariant is that no bitcast nodes will show up in the middle of a - // fusion node. This invariant must hold in order for us to lower it. Given - // that, we require that during multi-output fusion, a fusion node ending with - // bitcast to preserve its structure as a nested fusion instead being - // merged and flattened. - if (fused->opcode() == HloOpcode::kFusion && - fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) { + if (fused->opcode() == HloOpcode::kFusion) { std::swap(remaining, fused); } if (fused->IsMultiOutputFusion()) { std::swap(remaining, fused); } - if (fused->opcode() == HloOpcode::kFusion && - fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) { + if (fused->opcode() == HloOpcode::kFusion) { remaining->MergeFusionInstructionIntoMultiOutput(fused); } else { - if (remaining->opcode() == HloOpcode::kFusion && - remaining->fused_expression_root()->opcode() == HloOpcode::kBitcast) { - auto parent_computation = remaining->parent(); - // Create a nested fusion node. - auto remaining_nested_fused = - parent_computation->AddInstruction(HloInstruction::CreateFusion( - remaining->shape(), HloInstruction::FusionKind::kLoop, - remaining)); - TF_CHECK_OK(parent_computation->ReplaceInstruction( - remaining, remaining_nested_fused)); - remaining = remaining_nested_fused; - } remaining->FuseInstructionIntoMultiOutput(fused); } - return remaining; } diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index d23822e33e11ede0c5cac97e9fe2b0c3dc88cf3d..0019cd725417d81900974b462c3b05075ce3e893 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -78,6 +78,10 @@ class MultiOutputFusion : public HloPassInterface { // Test if it's legal to fuse instr1 and instr2 into one fusion instruction. virtual bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2); + // Fuse HloInstrctuion instr1 and instr2 and return the fused instruction. + // The other instruction is removed from its parent computation. + virtual HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2); + // Recompute reachability for the current computation. void RecomputeReachability(); @@ -101,10 +105,6 @@ class MultiOutputFusion : public HloPassInterface { virtual bool DoProducerConsumerMultiOutputFusion(); private: - // Fuse HloInstrctuion instr1 and instr2 and return the fused instruction. - // The other instruction is removed from its parent computation. - HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2); - // Update the internal data structures after instr1 and instr2 are fused into // one fusion instruction. void Update(HloInstruction* instr1, HloInstruction* instr2); diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index 2515222cf2db3d9699c85c13f4fe72b3488fa217..ac6ea4c72f61a47726b3ae7dd000837d3fba1b93 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -86,8 +86,8 @@ namespace xla { // are provided below. // // Example nullary instruction: -// Recv() == Op().WithOpcode(HloOpcode::kRecv) -// Recv(&a) == Op(&a).WithOpcode(HloOpcode::kRecv) +// Param() == Op().WithOpcode(HloOpcode::kParam) +// Param(&a) == Op(&a).WithOpcode(HloOpcode::kParam) // // Example unary instruction: // Abs() == Op().WithOpcode(HloOpcode::kAbs) @@ -726,6 +726,32 @@ class HloInstructionPatternFusionKindImpl { ::xla::HloInstruction::FusionKind kind_; }; +// An HloInstructionPattern implementation that matches only if the instruction +// is a kGetTupleElement with a particular tuple index. +template +class HloInstructionPatternTupleIndexImpl { + public: + explicit constexpr HloInstructionPatternTupleIndexImpl( + const Previous& previous, int64 tuple_index) + : previous_(previous), tuple_index_(tuple_index) {} + + bool Match(const ::xla::HloInstruction* inst) const { + return previous_.Match(inst) && + inst->opcode() == HloOpcode::kGetTupleElement && + inst->tuple_index() == tuple_index_; + } + + bool Match(::xla::HloInstruction* inst) const { + return previous_.Match(inst) && + inst->opcode() == HloOpcode::kGetTupleElement && + inst->tuple_index() == tuple_index_; + } + + private: + Previous previous_; + int64 tuple_index_; +}; + // A pattern that matches HloInstructions. template class HloInstructionPattern { @@ -841,6 +867,17 @@ class HloInstructionPattern { HloInstructionPatternFusionKindImpl(impl_, kind), matched_inst_); } + // Modifies the pattern to match only if the instruction is a + // get-tuple-element with the given tuple index. + constexpr HloInstructionPattern> + WithTupleIndex(int64 tuple_index) const { + return HloInstructionPattern>( + HloInstructionPatternTupleIndexImpl(impl_, tuple_index), + matched_inst_); + } + private: Impl impl_; HloInstructionType** matched_inst_; @@ -880,9 +917,7 @@ Op(::xla::HloInstruction** matched_inst) { return Op(matched_inst).WithOpcode(HloOpcode::k##NAME); \ } XLA_NULLOP_PATTERN(Constant) -XLA_NULLOP_PATTERN(Infeed) XLA_NULLOP_PATTERN(Parameter) -XLA_NULLOP_PATTERN(Recv) #undef XLA_NULLOP_PATTERN // Helpers for unary instructions. @@ -919,18 +954,21 @@ XLA_UNOP_PATTERN(Cos) XLA_UNOP_PATTERN(Exp) XLA_UNOP_PATTERN(Fft) XLA_UNOP_PATTERN(Floor) +XLA_UNOP_PATTERN(GetTupleElement) XLA_UNOP_PATTERN(Imag) +XLA_UNOP_PATTERN(Infeed) XLA_UNOP_PATTERN(IsFinite) XLA_UNOP_PATTERN(Log) XLA_UNOP_PATTERN(Not) XLA_UNOP_PATTERN(Negate) -XLA_UNOP_PATTERN(Outfeed) XLA_UNOP_PATTERN(Real) +XLA_UNOP_PATTERN(Recv) +XLA_UNOP_PATTERN(RecvDone) XLA_UNOP_PATTERN(Reduce) XLA_UNOP_PATTERN(ReducePrecision) XLA_UNOP_PATTERN(Reshape) XLA_UNOP_PATTERN(Reverse) -XLA_UNOP_PATTERN(Send) +XLA_UNOP_PATTERN(SendDone) XLA_UNOP_PATTERN(Sign) XLA_UNOP_PATTERN(Sin) XLA_UNOP_PATTERN(Sort) @@ -981,8 +1019,10 @@ XLA_BINOP_PATTERN(Maximum) XLA_BINOP_PATTERN(Minimum) XLA_BINOP_PATTERN(Multiply) XLA_BINOP_PATTERN(Ne) +XLA_BINOP_PATTERN(Outfeed) XLA_BINOP_PATTERN(Power) XLA_BINOP_PATTERN(Remainder) +XLA_BINOP_PATTERN(Send) XLA_BINOP_PATTERN(Subtract) XLA_BINOP_PATTERN(And) XLA_BINOP_PATTERN(Or) @@ -1040,6 +1080,32 @@ inline auto NonConstant(HloInstructionType** matched_inst) return Op(matched_inst).IsNonConstant(); } +// Add overloads for GetTupleElement which take a int64 specifying which tuple +// element is selected. +template +inline auto GetTupleElement(Arg&& arg, int64 tuple_index) + -> decltype(Op().WithOpcode(HloOpcode::kGetTupleElement) + .WithOperand(0, std::forward(arg)) + .WithTupleIndex(tuple_index)) { + return Op() + .WithOpcode(HloOpcode::kGetTupleElement) + .WithOperand(0, std::forward(arg)) + .WithTupleIndex(tuple_index); +} + +template +inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg, + int64 tuple_index) + -> decltype(Op(matched_inst) + .WithOpcode(HloOpcode::kGetTupleElement) + .WithOperand(0, std::forward(arg)) + .WithTupleIndex(tuple_index)) { + return Op(matched_inst) + .WithOpcode(HloOpcode::kGetTupleElement) + .WithOperand(0, std::forward(arg)) + .WithTupleIndex(tuple_index); +} + } // namespace match } // namespace xla diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index fef3c132b0f3467a01b02f2be88b419459179277..a530581c34bf1d699eae3c53203c197f7943cc53 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -193,5 +193,23 @@ TEST(PatternMatcherTest, FusionKind) { HloInstruction::FusionKind::kLoop))); } +TEST(PatternMatcherTest, GetTupleElement) { + constexpr char kModuleStr[] = R"( + HloModule test_module + + ENTRY while.v11 { + p0 = (f32[], f32[], f32[]) parameter(0) + ROOT gte = f32[] get-tuple-element(p0), index=1 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + + auto* root = hlo_module->entry_computation()->root_instruction(); + EXPECT_FALSE(Match(root, match::Op().WithTupleIndex(0))); + EXPECT_TRUE(Match(root, match::Op().WithTupleIndex(1))); + EXPECT_FALSE(Match(root, match::Op().WithTupleIndex(2))); + EXPECT_FALSE(Match(root, match::GetTupleElement(match::Op(), 0))); + EXPECT_TRUE(Match(root, match::GetTupleElement(match::Op(), 1))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 7c63c0acc7764d558b2151190f0fa79fac355cbf..39fe3c7835d1c74c0f1e5bc0ebf5916ec734c24a 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -75,19 +75,6 @@ PlatformUtil::GetSupportedPlatforms() { auto* platform = platform_pair.second; auto compiler_status = Compiler::GetForPlatform(platform); if (compiler_status.ok()) { - if (platform->VisibleDeviceCount() > 0) { - LOG(INFO) << "platform " << platform->Name() << " present with " - << platform->VisibleDeviceCount() << " visible devices"; - } else { - LOG(WARNING) << "platform " << platform->Name() << " present but no " - << "visible devices found"; - } - // Note: currently we call zero device platforms "supported" on the basis - // that, if the platform support was linked in, it was probably intended - // to be used for execution, and this way we can flag an error. - // - // TODO(b/33730287) If we want an alternative version of this behavior we - // could add an --xla_fallback_to_host flag. platforms.push_back(platform); } else { LOG(INFO) << "platform " << platform->Name() << " present but no " diff --git a/tensorflow/compiler/xla/service/pool.h b/tensorflow/compiler/xla/service/pool.h deleted file mode 100644 index 8e710ebb6dc17e0e204ba6ab3c6c159627cd9d3b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/pool.h +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_POOL_H_ -#define TENSORFLOW_COMPILER_XLA_POOL_H_ - -#include -#include - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/core/platform/mutex.h" - -namespace xla { - -// Pool of values, which are created as needed and destroyed when the `Pool` is -// destroyed -template -class Pool { - public: - struct Deleter { - void operator()(T* ptr) { pool->Deallocate(ptr); } - - Pool* pool; - }; - - // A pointer to a taken element of a `Pool` which returns it to the pool on - // destruction - using SmartPtr = std::unique_ptr; - - // Constructs a `Pool` with given factory function, which need not be - // thread-safe. - explicit Pool(std::function()> factory) - : factory_(factory) {} - - explicit Pool() : Pool([]() { return MakeUnique(); }) {} - - // Returns a pointer to a value in the pool, creating a new value if none is - // free. The returned smart pointer returns the element to the pool on - // destruction. - // - // This method is thread-safe. - SmartPtr Allocate() { - tensorflow::mutex_lock lock(mu_); - T* ptr; - if (!xs_.empty()) { - ptr = std::move(xs_.back()).release(); - xs_.pop_back(); - } else { - ptr = factory_().release(); - } - Deleter del = {this}; - return std::unique_ptr(ptr, del); - } - - private: - // Puts a pointer to a value back into the pool, leaving it free for future - // use. - // - // This method is thread-safe. - void Deallocate(T* ptr) { - tensorflow::mutex_lock lock(mu_); - xs_.push_back(std::unique_ptr(ptr)); - } - - const std::function()> factory_ GUARDED_BY(mu_); - std::vector> xs_ GUARDED_BY(mu_); - tensorflow::mutex mu_; -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_POOL_H_ diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index 49ec38eb62c7b51c7a2d301d882cef032b288036..ca86c5d13e98a98c62d0c9e8e32e28fe99e0fa1f 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -38,7 +38,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index 13e2d3258e3b92f52320201c382594962c0e3b2b..ad3b662c20ac53b0a6d634b16b3b908f730f3d2d 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -175,8 +175,9 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); - auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{true, true, false}, {false, false, true}}))); + auto const0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2( + {{true, true, false}, {false, false, true}}))); auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param1")); @@ -255,12 +256,12 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {3, 2}); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const0)); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1)); @@ -309,7 +310,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param0")); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); builder.AddInstruction(HloInstruction::CreateBinary( @@ -348,7 +349,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) { auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1, 3}), "param0")); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({9, 8, 7}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({9, 8, 7}))); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); auto reshape1 = diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index da3b622bfae8ac5132f9f95070ee41674e79b5b8..433560e32258a01c10ade85ec2faf52884271497 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/service/source_map_util.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -55,7 +56,6 @@ limitations under the License. using ::tensorflow::strings::Printf; using ::tensorflow::strings::StrCat; -using ::xla::source_map_util::InvalidParameterArgument; namespace xla { @@ -169,7 +169,8 @@ Service::Service(const ServiceOptions& options, Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg, CreateChannelHandleResponse* result) { - *result->mutable_channel() = channel_tracker_.NewChannel(); + TF_ASSIGN_OR_RETURN(*result->mutable_channel(), + channel_tracker_.NewChannel(arg->channel_type())); return Status::OK(); } @@ -375,7 +376,7 @@ Service::ExecuteParallelAndRegisterResult( ExecutionProfile* profile) { // Streams where the computation are launched, so we can wait on the streams // to complete. - std::vector::SmartPtr> streams; + std::vector streams; std::vector> timers; // Global data handles for the computation results, one for each computation. @@ -402,7 +403,7 @@ Service::ExecuteParallelAndRegisterResult( CHECK_EQ(replicas.size(), arguments[i].size()); std::vector result_buffers; for (int64 replica = 0; replica < replicas.size(); ++replica) { - TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, + TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream, backend->BorrowStream(replicas[replica])); streams.push_back(std::move(stream)); @@ -514,13 +515,13 @@ StatusOr Service::ExecuteAndRegisterResult( arguments, Backend* backend, const string& result_tag, ExecutionProfile* profile) { // Set up streams. - std::vector::SmartPtr> streams; + std::vector streams; TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, SingleComputationDeviceHandle())); TF_RET_CHECK(!replicas.empty()); for (se::StreamExecutor* executor : replicas) { - TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, + TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream, backend->BorrowStream(executor)); streams.push_back(std::move(stream)); } @@ -532,7 +533,7 @@ StatusOr Service::ExecuteAndRegisterResult( // Set up run options. std::vector run_options; - for (const Pool::SmartPtr& stream : streams) { + for (const StreamPool::Ptr& stream : streams) { ExecutableRunOptions options; options.set_stream(stream.get()); options.set_device_ordinal(stream->parent()->device_ordinal()); @@ -1051,11 +1052,12 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, executor = replicas[arg->replica_id()]; } - Literal literal; + auto literal = Literal::CreateFromShape(arg->shape_with_layout()); + TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( - executor, arg->shape_with_layout(), &literal)); - *result->mutable_literal() = literal.ToProto(); + executor, arg->shape_with_layout(), *literal)); + *result->mutable_literal() = literal->ToProto(); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h index 7f3910cdb0366078b97fb5f6a2dc498b37570926..dbfed628bfcabffe66bef41a82e0e2430897d80d 100644 --- a/tensorflow/compiler/xla/service/service_executable_run_options.h +++ b/tensorflow/compiler/xla/service/service_executable_run_options.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_EXECUTABLE_RUN_OPTIONS_H_ #include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/stream_executor/stream_executor.h" @@ -27,8 +27,7 @@ namespace xla { // data, now only a stream cache for GPU backend. class ServiceExecutableRunOptions { public: - using StreamBorrower = - std::function::SmartPtr>(int)>; + using StreamBorrower = std::function(int)>; ServiceExecutableRunOptions() : ServiceExecutableRunOptions(ExecutableRunOptions()) {} @@ -51,7 +50,7 @@ class ServiceExecutableRunOptions { // Borrows a stream and returns a smart pointer which returns the stream on // destruction. - StatusOr::SmartPtr> BorrowStream(int device_ordinal) const { + StatusOr BorrowStream(int device_ordinal) const { return borrow_stream_ ? borrow_stream_(device_ordinal) : Status(tensorflow::error::UNIMPLEMENTED, "No stream cache"); diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 096bbde92220cdd3feba4d7c51932f2bc135d3ec..a4ea2b28f4dbf41d61702f1af2d65c4d2c86d578 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -58,66 +58,101 @@ Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) { return Status::OK(); } -Status VerifyReducerShape(const ProgramShape& reducer_shape, - const Shape& init_value_shape, - const PrimitiveType& input_element_type) { - if (reducer_shape.parameters_size() != 2) { - return InvalidArgument( - "Reduction function must take 2 parameters, but " +Status VerifyReducerShape( + const ProgramShape& reducer_shape, + tensorflow::gtl::ArraySlice init_value_shapes, + tensorflow::gtl::ArraySlice input_element_types, + int64 inputs) { + if (reducer_shape.parameters_size() != inputs * 2) { + return InvalidArgument( + "Reduction function must take %lld parameters, but " "takes %d parameter(s).", - reducer_shape.parameters_size()); + inputs * 2, reducer_shape.parameters_size()); } const Shape& accumulator_shape = reducer_shape.result(); - if (ShapeUtil::Rank(accumulator_shape) != 0) { - return InvalidArgument( - "Reduction function must have rank 0 (rank %lld reduction function " - "given).", - ShapeUtil::Rank(accumulator_shape)); - } - - // Check that the accumulator can be passed in as the first argument. - // Note: comparing here and below with Compatible since we don't care about - // layout in scalars - see b/26668201 for a longer-term vision. - if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(0))) { + std::vector accumulator_subshapes; + if (ShapeUtil::IsArray(accumulator_shape)) { + if (inputs != 1) { + return InvalidArgument( + "Reduction function must produce a tuple with %lld elements, but " + "produces a scalar", + inputs); + } + accumulator_subshapes.push_back(&accumulator_shape); + } else if (ShapeUtil::IsTuple(accumulator_shape)) { + if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) { + return InvalidArgument( + "Reduction function must produce a tuple with %lld elements, but has " + "%lld elements", + inputs, ShapeUtil::TupleElementCount(accumulator_shape)); + } + for (const Shape& element_shape : accumulator_shape.tuple_shapes()) { + accumulator_subshapes.push_back(&element_shape); + } + } else { return InvalidArgument( - "Reduction function's first parameter shape differs from the " - "result shape: %s vs %s", - ShapeUtil::HumanString(reducer_shape.parameters(0)).c_str(), + "Reduction function must produce a scalar or tuple of scalars, but has " + "shape: %s", ShapeUtil::HumanString(accumulator_shape).c_str()); } - // Check that init_value's shape is suitable for reducer_shape. - if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape, - init_value_shape)) { - return InvalidArgument( - "Reduction function's accumulator shape differs from the " - "init_value shape: %s vs %s", - ShapeUtil::HumanString(accumulator_shape).c_str(), - ShapeUtil::HumanString(init_value_shape).c_str()); - } - - // Check that the inputs can be passed in as the second argument. - const Shape& input_element_shape = - ShapeUtil::MakeShape(input_element_type, {}); - if (!ShapeUtil::CompatibleIgnoringFpPrecision(input_element_shape, - reducer_shape.parameters(1))) { - return InvalidArgument( - "Reduction function's second parameter shape differs from the " - "input type element type: %s vs %s", - ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(), - ShapeUtil::HumanString(input_element_shape).c_str()); + for (const Shape* element_shape : accumulator_subshapes) { + if (ShapeUtil::Rank(*element_shape) != 0) { + return InvalidArgument( + "Reduction function must return a scalar or tuple of scalars but " + "returns shape: %s", + ShapeUtil::HumanString(accumulator_shape).c_str()); + } } - // Currently the accumulator and inputs must be the same type, - // though that restriction could be relaxed. - if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape, - reducer_shape.parameters(1))) { - return InvalidArgument( - "Reduction function's second parameter shape must " - "match the result shape, but got %s vs %s.", - ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(), - ShapeUtil::HumanString(accumulator_shape).c_str()); + for (int64 i = 0; i < inputs; ++i) { + // Check that the accumulator can be passed in as the first argument. + // Note: comparing here and below with Compatible since we don't care about + // layout in scalars - see b/26668201 for a longer-term vision. + if (!ShapeUtil::Compatible(*accumulator_subshapes[i], + reducer_shape.parameters(i))) { + return InvalidArgument( + "Reduction function's %lld-th parameter shape differs from the " + "result shape: %s vs %s", + i, ShapeUtil::HumanString(reducer_shape.parameters(i)).c_str(), + ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str()); + } + // Check that init_value's shapes are suitable for reducer_shape. + if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i], + *init_value_shapes[i])) { + return InvalidArgument( + "Reduction function's accumulator shape at index %lld differs from " + "the init_value shape: %s vs %s", + i, ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str(), + ShapeUtil::HumanString(*init_value_shapes[i]).c_str()); + } + // Check that the inputs can be passed in as the non-accumulator arguments. + const Shape input_element_shape = + ShapeUtil::MakeShape(input_element_types[i], {}); + if (!ShapeUtil::CompatibleIgnoringFpPrecision( + input_element_shape, reducer_shape.parameters(inputs + i))) { + return InvalidArgument( + "Reduction function's %lld-th parameter shape differs from the " + "input type element type: %s vs %s", + inputs + i, + ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(), + ShapeUtil::HumanString(input_element_shape).c_str()); + } + // Check that the accumulator and inputs to the reducer function match. + // If the accumulator is scalar, it must have the same type as the inputs + // (up to fp precision). If it is a tuple, then the k-th element of the + // tuple must have the same type as the K-th input (again, up to fp + // precision.) + if (!ShapeUtil::CompatibleIgnoringFpPrecision( + *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) { + return InvalidArgument( + "Reduction function's %lld-th parameter shape must " + "match the result shape, but got %s vs %s.", + inputs + i, + ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(), + ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str()); + } } return Status::OK(); @@ -222,13 +257,16 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return shape; case HloOpcode::kReal: case HloOpcode::kImag: - if (!ShapeUtil::ElementIsComplex(shape)) { + if (ShapeUtil::ElementIsComplex(shape)) { + return ShapeUtil::ComplexComponentShape(shape); + } else if (ShapeUtil::ElementIsFloating(shape)) { + return shape; + } else { return InvalidArgument( - "Expected element type in shape to be complex for real/imag " - "operation; got %s.", + "Expected element type in shape to be floating or complex for " + "real/imag operation; got %s.", PrimitiveType_Name(shape.element_type()).c_str()); } - return ShapeUtil::ChangeElementType(shape, F32); case HloOpcode::kAbs: if (ShapeUtil::ElementIsComplex(shape)) { return ShapeUtil::ChangeElementType( @@ -239,7 +277,6 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, case HloOpcode::kNegate: case HloOpcode::kRoundNearestAfz: case HloOpcode::kSign: - case HloOpcode::kSort: return shape; case HloOpcode::kNot: @@ -930,6 +967,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InferClampShape(lhs, rhs, ehs); case HloOpcode::kSelect: return InferSelectShape(lhs, rhs, ehs); + case HloOpcode::kTupleSelect: + return InferTupleSelectShape(lhs, rhs, ehs); default: return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode).c_str()); @@ -962,6 +1001,23 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } return result; } + case HloOpcode::kSort: { + if (operand_shapes.size() == 1) { + return *operand_shapes[0]; + } else if (operand_shapes.size() == 2) { + if (!ShapeUtil::SameDimensions(*operand_shapes[0], + *operand_shapes[1])) { + return InvalidArgument( + "Sort keys and values dimensions must match. " + "Keys shape is: %s\n, Values shape is: %s", + ShapeUtil::HumanString(*operand_shapes[0]).c_str(), + ShapeUtil::HumanString(*operand_shapes[1]).c_str()); + } + return ShapeUtil::MakeTupleShape( + {*operand_shapes[0], *operand_shapes[1]}); + } + return InvalidArgument("Unexpected number of operands for sort"); + } default: return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode).c_str()); @@ -1723,11 +1779,83 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeTupleShape(operand_shape_values); } +/* static */ StatusOr ShapeInference::InferAllToAllShape( + const Shape& shape, int64 split_dimension, int64 concat_dimension, + int64 split_count) { + TF_RET_CHECK(split_count > 0); + if (split_dimension >= ShapeUtil::Rank(shape) || split_dimension < 0) { + return InvalidArgument( + "AllToAll split_dimension %lld is out-of-bounds in shape %s.", + split_dimension, ShapeUtil::HumanString(shape).c_str()); + } + if (concat_dimension >= ShapeUtil::Rank(shape) || concat_dimension < 0) { + return InvalidArgument( + "AllToAll concat_dimension %lld is out-of-bounds in shape %s.", + concat_dimension, ShapeUtil::HumanString(shape).c_str()); + } + if (shape.dimensions(split_dimension) % split_count != 0) { + return InvalidArgument( + "AllToAll split dimension size %lld must be dividable by split_count " + "%lld.", + shape.dimensions(split_dimension), split_count); + } + std::vector new_dimensions(shape.dimensions().begin(), + shape.dimensions().end()); + new_dimensions[split_dimension] /= split_count; + new_dimensions[concat_dimension] *= split_count; + return ShapeUtil::MakeShape(shape.element_type(), new_dimensions); +} + +/* static */ StatusOr ShapeInference::InferAllToAllTupleShape( + tensorflow::gtl::ArraySlice operand_shapes) { + // An Alltoall HLO instruction receives N operands (with the same shape) and + // returns a tuple that contains N array shapes. + TF_RET_CHECK(!operand_shapes.empty()); + for (int i = 0; i < operand_shapes.size(); i++) { + if (!ShapeUtil::Equal(*operand_shapes[0], *operand_shapes[i])) { + return InvalidArgument( + "HLO all-to-all has operands with different shapes: the 0th " + "operand shape %s, but the %dth operand has shape %s.", + ShapeUtil::HumanString(*operand_shapes[0]).c_str(), i, + ShapeUtil::HumanString(*operand_shapes[i]).c_str()); + } + } + + return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes); +} + /* static */ StatusOr ShapeInference::InferReduceShape( - const Shape& arg, const Shape& init_value, + tensorflow::gtl::ArraySlice arg_shapes, tensorflow::gtl::ArraySlice dimensions_to_reduce, const ProgramShape& to_apply) { - // Check that the dimension to reduce are in-bounds for the given shape. + if (arg_shapes.empty()) { + return InvalidArgument("Reduce must have at least 2 arguments, has 0"); + } + if (arg_shapes.size() % 2) { + return InvalidArgument( + "Reduce must have an even number of arguments, has %lu", + arg_shapes.size()); + } + int64 num_reduced_args = arg_shapes.size() / 2; + + tensorflow::gtl::ArraySlice reduced_args(arg_shapes, 0, + num_reduced_args); + // Check that all of the reduced tensors have the same dimensions. The element + // types may be different. + for (int64 i = 1; i < num_reduced_args; ++i) { + if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) { + return InvalidArgument( + "All reduced tensors must have the sime dimension. Tensor 0 has " + "shape %s, Tensor %lld has shape %s", + ShapeUtil::HumanString(*reduced_args[0]).c_str(), i, + ShapeUtil::HumanString(*reduced_args[i]).c_str()); + } + } + + // Check that the dimensions to reduce are in-bounds for the given shape. + // We've already verified all reduced tensors have the same dimensions, so it + // doesn't matter which one we choose. + const Shape& arg = *reduced_args[0]; for (int64 dimension : dimensions_to_reduce) { if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) { return InvalidArgument( @@ -1735,8 +1863,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(arg).c_str()); } } - TF_RETURN_IF_ERROR( - VerifyReducerShape(to_apply, init_value, arg.element_type())); + + tensorflow::gtl::ArraySlice init_values( + arg_shapes, num_reduced_args, arg_shapes.size()); + std::vector element_types; + for (const Shape* arg : reduced_args) { + element_types.push_back(arg->element_type()); + } + TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply, init_values, element_types, + num_reduced_args)); std::set dimensions_to_reduce_set(dimensions_to_reduce.begin(), dimensions_to_reduce.end()); @@ -1747,15 +1882,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } - return ShapeUtil::MakeShape(to_apply.result().element_type(), new_dimensions); + if (ShapeUtil::IsScalar(to_apply.result())) { + return ShapeUtil::MakeShape(to_apply.result().element_type(), + new_dimensions); + } else { + std::vector result_subshapes; + for (const Shape& subshape : to_apply.result().tuple_shapes()) { + result_subshapes.push_back( + ShapeUtil::MakeShape(subshape.element_type(), new_dimensions)); + } + return ShapeUtil::MakeTupleShape(result_subshapes); + } } /* static */ StatusOr ShapeInference::InferReduceWindowShape( const Shape& operand_shape, const Shape& init_value_shape, const Window& window, const ProgramShape& to_apply_shape) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window")); - TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_value_shape, - operand_shape.element_type())); + TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape}, + {operand_shape.element_type()}, + /*inputs=*/1)); return InferWindowOutputShape(operand_shape, window, init_value_shape.element_type(), /*allow_negative_padding=*/false); @@ -1800,8 +1946,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } // Check if the scatter function has a proper shape as a reduction. - TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, init_value_shape, - source_shape.element_type())); + TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, {&init_value_shape}, + {source_shape.element_type()}, + /*inputs=*/1)); // Check if the result shape of window operation matches the source shape. TF_ASSIGN_OR_RETURN(const Shape& window_result_shape, @@ -2259,15 +2406,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, // broadcast from all operands, not just the predicate. /* static */ StatusOr ShapeInference::InferSelectShape( const Shape& pred, const Shape& on_true, const Shape& on_false) { - bool compatible; - if (ShapeUtil::IsTuple(on_true)) { - // Select only defines the top-level buffer, so if it's a tuple, the two - // input must match exactly. - compatible = ShapeUtil::Compatible(on_true, on_false); - } else { - compatible = ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false); - } - if (!compatible) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) { return InvalidArgument( "Operands to select must be the same shape; got %s and %s.", ShapeUtil::HumanString(on_true).c_str(), @@ -2279,7 +2418,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(pred).c_str()); } if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) || - ShapeUtil::Rank(pred) == 0) { + ShapeUtil::IsScalar(pred)) { // By this stage we know that pred's element type is PRED. Therefore, this // check restricts pred to be a PRED scalar, or a PRED array with the same // dimensions as on_true and on_false. @@ -2293,6 +2432,29 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } +/* static */ StatusOr ShapeInference::InferTupleSelectShape( + const Shape& pred, const Shape& on_true, const Shape& on_false) { + // Select only defines the top-level buffer, so if it's a tuple, the two + // input must match exactly. + if (!ShapeUtil::Compatible(on_true, on_false)) { + return InvalidArgument( + "Operands to tuple-select must be the same shape; got %s and %s.", + ShapeUtil::HumanString(on_true).c_str(), + ShapeUtil::HumanString(on_false).c_str()); + } + if (pred.element_type() != PRED) { + return InvalidArgument( + "TupleSelect's pred operand must have PRED element type; got %s.", + ShapeUtil::HumanString(pred).c_str()); + } + if (!ShapeUtil::IsScalar(pred)) { + return InvalidArgument( + "TupleSelect operation with non-scalar predicate: %s.", + ShapeUtil::HumanString(pred).c_str()); + } + return on_true; +} + /* static */ StatusOr ShapeInference::InferCallShape( tensorflow::gtl::ArraySlice arg_shapes, const ProgramShape& to_apply) { @@ -2532,4 +2694,194 @@ static Status ValidateGatherDimensionNumbers( return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds); } +namespace { + +Status ValidateScatterDimensionNumbers( + const Shape& operand_shape, + tensorflow::gtl::ArraySlice scatter_indices_shape, + const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { + // Validate update_window_dims in ScatterDimensionNumbers. + if (!c_is_sorted(dim_numbers.update_window_dims())) { + return InvalidArgument( + "update_window_dims in scatter op must be sorted; got: %s.", + Join(dim_numbers.update_window_dims(), ", ").c_str()); + } + if (c_adjacent_find(dim_numbers.update_window_dims()) != + dim_numbers.update_window_dims().end()) { + return InvalidArgument( + "update_window_dims in scatter op must not repeat; got: %s.", + Join(dim_numbers.update_window_dims(), ", ").c_str()); + } + const int64 updates_rank = ShapeUtil::Rank(updates_shape); + for (int64 window_dim : dim_numbers.update_window_dims()) { + if (window_dim < 0 || window_dim >= updates_rank) { + return InvalidArgument( + "Invalid update_window_dims set in scatter op; valid range is [0, " + "%lld). got: %lld.", + updates_rank, window_dim); + } + } + + // Validate inserted_window_dims in ScatterDimensionNumbers. + if (!c_is_sorted(dim_numbers.inserted_window_dims())) { + return InvalidArgument( + "inserted_window_dims in scatter op must be sorted; got: %s.", + Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + } + if (c_adjacent_find(dim_numbers.inserted_window_dims()) != + dim_numbers.inserted_window_dims().end()) { + return InvalidArgument( + "inserted_window_dims in scatter op must not repeat; got: %s.", + Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + } + for (int64 inserted_dim : dim_numbers.inserted_window_dims()) { + if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) { + return InvalidArgument( + "Invalid inserted_window_dims set in scatter op; valid range is [0, " + "%d), got: %lld.", + operand_shape.dimensions_size(), inserted_dim); + } + } + + // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers. + if (dim_numbers.scatter_dims_to_operand_dims_size() != + scatter_indices_shape[dim_numbers.index_vector_dim()]) { + return InvalidArgument( + "Scatter op has %d elements in scatter_dims_to_operand_dims and the " + "bound of dimension index_vector_dim=%lld of scatter_indices is %lld. " + "These two numbers must be equal.", + dim_numbers.scatter_dims_to_operand_dims_size(), + dim_numbers.index_vector_dim(), + scatter_indices_shape[dim_numbers.index_vector_dim()]); + } + for (int i = 0; i < dim_numbers.scatter_dims_to_operand_dims_size(); ++i) { + int64 scatter_dim_to_operand_dim = + dim_numbers.scatter_dims_to_operand_dims(i); + if (scatter_dim_to_operand_dim < 0 || + scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) { + return InvalidArgument( + "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), " + "got: %d->%lld.", + operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim); + } + } + std::vector sorted_scatter_dims_to_operand_dims( + dim_numbers.scatter_dims_to_operand_dims().begin(), + dim_numbers.scatter_dims_to_operand_dims().end()); + c_sort(sorted_scatter_dims_to_operand_dims); + if (c_adjacent_find(sorted_scatter_dims_to_operand_dims) != + sorted_scatter_dims_to_operand_dims.end()) { + return InvalidArgument( + "Repeated dimensions not allowed in scatter_dims_to_operand_dims; " + "got: %s.", + Join(dim_numbers.scatter_dims_to_operand_dims(), ", ").c_str()); + } + + return Status::OK(); +} + +} // namespace + +/*static*/ StatusOr ShapeInference::InferScatterShape( + const Shape& operand_shape, const Shape& scatter_indices_shape, + const Shape& updates_shape, const ProgramShape& to_apply_shape, + const ScatterDimensionNumbers& scatter_dim_numbers) { + TF_RETURN_IF_ERROR( + ExpectArray(operand_shape, "operand tensor of scatter op")); + TF_RETURN_IF_ERROR( + ExpectArray(scatter_indices_shape, "scatter indices of scatter op")); + TF_RETURN_IF_ERROR(ExpectArray(updates_shape, "updates of scatter op")); + + if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) { + return InvalidArgument( + "Scatter indices parameter must be an integral tensor; got %s.", + ShapeUtil::HumanString(scatter_indices_shape).c_str()); + } + + if (scatter_indices_shape.dimensions_size() < + scatter_dim_numbers.index_vector_dim() || + scatter_dim_numbers.index_vector_dim() < 0) { + return InvalidArgument( + "Scatter index leaf dimension must be within [0, rank(scatter_indices)" + " + 1). rank(scatter_indices) is %d and scatter index leaf dimension " + "is %lld.", + scatter_indices_shape.dimensions_size(), + scatter_dim_numbers.index_vector_dim()); + } + + // Check if the update computation has a proper shape as a reduction. + const Shape init_value_shape = + ShapeUtil::MakeShape(operand_shape.element_type(), {}); + TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape}, + {updates_shape.element_type()}, + /*inputs=*/1)); + + std::vector expanded_scatter_indices_shape = + ArraySliceToVector(AsInt64Slice(scatter_indices_shape.dimensions())); + if (expanded_scatter_indices_shape.size() == + scatter_dim_numbers.index_vector_dim()) { + expanded_scatter_indices_shape.push_back(1); + } + + int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 + + scatter_dim_numbers.update_window_dims_size(); + if (ShapeUtil::Rank(updates_shape) != expected_updates_rank) { + return InvalidArgument("Updates tensor must be of rank %lld; got %lld.", + expected_updates_rank, + ShapeUtil::Rank(updates_shape)); + } + + TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers( + operand_shape, expanded_scatter_indices_shape, updates_shape, + scatter_dim_numbers)); + + int64 inserted_dims_seen = 0; + std::vector max_update_window_bounds; + for (int i = 0; i < operand_shape.dimensions_size(); ++i) { + if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() && + scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) { + ++inserted_dims_seen; + } else { + max_update_window_bounds.push_back(operand_shape.dimensions(i)); + } + } + for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) { + auto update_window_dim = scatter_dim_numbers.update_window_dims(i); + if (updates_shape.dimensions(update_window_dim) > + max_update_window_bounds[i]) { + return InvalidArgument( + "Bounds of the window dimensions of updates must not exceed the " + "bounds of the corresponding dimensions of operand. For dimension " + "%lld, updates bound is %lld, operand bound is %lld.", + update_window_dim, updates_shape.dimensions(update_window_dim), + max_update_window_bounds[i]); + } + } + + int64 scatter_dims_seen = 0; + for (int64 i = 0; i < ShapeUtil::Rank(updates_shape); ++i) { + bool is_update_window_dim = + c_binary_search(scatter_dim_numbers.update_window_dims(), i); + if (is_update_window_dim) { + continue; + } + if (scatter_dims_seen == scatter_dim_numbers.index_vector_dim()) { + ++scatter_dims_seen; + } + if (updates_shape.dimensions(i) != + expanded_scatter_indices_shape[scatter_dims_seen]) { + return InvalidArgument( + "Bounds of the scatter dimensions of updates must be same as the " + "bounds of the corresponding dimensions of scatter indices. For " + "scatter dimension %lld, updates bound is %lld, scatter_indices " + "bound is %lld.", + i, updates_shape.dimensions(i), + expanded_scatter_indices_shape[scatter_dims_seen]); + } + ++scatter_dims_seen; + } + + return operand_shape; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index ad34a2aa184e786a9825193d23f106f8a950758a..c185b0a1bd79e23e0d76daad50fb4a9708a743dd 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -119,11 +119,22 @@ class ShapeInference { const Shape& in, FftType fft_type, tensorflow::gtl::ArraySlice fft_length); - // Infers the shape produced a cross replica sum with the given operand + // Infers the shape produced by a cross replica sum with the given operand // shapes. static StatusOr InferCrossReplicaSumShape( tensorflow::gtl::ArraySlice operand_shapes); + // Infers final shape of an Alltoall operation that is created by the xla + // builder. + static StatusOr InferAllToAllShape(const Shape& shape, + int64 split_dimension, + int64 concat_dimension, + int64 split_count); + + // Infers the shape of an HLO all-to-all instruction. + static StatusOr InferAllToAllTupleShape( + tensorflow::gtl::ArraySlice operand_shapes); + // Infers the shape produced by applying the given reduction computation // shape to the given input operand shape. // @@ -131,7 +142,7 @@ class ShapeInference { // index as the leading parameter, and the program shape should match // accordingly (or an error will result). static StatusOr InferReduceShape( - const Shape& arg, const Shape& init_value, + tensorflow::gtl::ArraySlice arg_shapes, tensorflow::gtl::ArraySlice dimensions_to_reduce, const ProgramShape& to_apply); @@ -268,6 +279,14 @@ class ShapeInference { const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice window_bounds); + // Helper that validates the given input shape, scatter indices shape, updates + // shape, and scatter dimension numbers that constitute a scatter operation, + // and returns the result shape of the scatter operation. + static StatusOr InferScatterShape( + const Shape& operand_shape, const Shape& scatter_indices_shape, + const Shape& updates_shape, const ProgramShape& to_apply_shape, + const ScatterDimensionNumbers& scatter_dim_numbers); + private: // Helper that infers the shape produced by performing an element-wise binary // operation with the given LHS and RHS shapes. @@ -286,6 +305,10 @@ class ShapeInference { static StatusOr InferSelectShape(const Shape& pred, const Shape& on_true, const Shape& on_false); + // Helper for inferring the shape of TupleSelect ops. + static StatusOr InferTupleSelectShape(const Shape& pred, + const Shape& on_true, + const Shape& on_false); // Helper for inferring shapes of binary operations which use degenerate // dimension broadcasting (a dimension of size 1 in one operand is broadcast diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index bafe14d6f45f851924c37908d4c93bbff2dac459..a73fa181cdd13dc7fabcdc367ae117e19bdc3e5f 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -62,7 +63,7 @@ class ReduceShapeInferenceTest : public ShapeInferenceTest { tensorflow::gtl::ArraySlice dimensions_to_reduce) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); auto inferred_status = ShapeInference::InferReduceShape( - arg, f32_, dimensions_to_reduce, to_apply); + {&arg, &f32_}, dimensions_to_reduce, to_apply); EXPECT_IS_OK(inferred_status.status()); EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape, inferred_status.ValueOrDie())); @@ -702,11 +703,99 @@ TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) { /*dimensions_to_reduce=*/{0, 1, 2}); } +TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_IS_OK(inferred_status.status()); + EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTupleShape({f32_, s32_}), + inferred_status.ValueOrDie())); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = + ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_, f32_, s32_}, + ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("must take 4 parameters, but takes 6 parameter(s)")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT( + inferred_status.status().error_message(), + HasSubstr( + "parameter shape differs from the result shape: s32[] vs f32[]")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) { + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape({}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("must have at least 2 arguments, has 0")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = + ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, f32_); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT( + inferred_status.status().error_message(), + HasSubstr("must produce a tuple with 2 elements, but produces a scalar")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput2) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT( + inferred_status.status().error_message(), + HasSubstr("must produce a tuple with 2 elements, but has 3 elements")); +} + +TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) { + Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3}); + ProgramShape to_apply = ShapeUtil::MakeProgramShape( + {s32_, s32_, s32_, s32_}, ShapeUtil::MakeTupleShape({s32_, s32_})); + auto inferred_status = ShapeInference::InferReduceShape( + {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply); + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("accumulator shape at index 0 differs from the " + "init_value shape: s32[] vs f32[]")); +} + TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); auto inferred_status = ShapeInference::InferReduceShape( - ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{3, 4}, - to_apply); + {&arg_shape, &f32_}, + /*dimensions_to_reduce=*/{3, 4}, to_apply); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), HasSubstr("out-of-bounds dimension")); @@ -714,8 +803,9 @@ TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) { TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_); + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); auto inferred_status = - ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_, + ShapeInference::InferReduceShape({&arg_shape, &f32_}, /*dimensions_to_reduce=*/{0}, to_apply); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), @@ -724,12 +814,13 @@ TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_); + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); auto inferred_status = - ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_, + ShapeInference::InferReduceShape({&arg_shape, &f32_}, /*dimensions_to_reduce=*/{0}, to_apply); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), - HasSubstr("first parameter shape differs")); + HasSubstr("0-th parameter shape differs")); } TEST_F(ShapeInferenceTest, InferSliceShapeRank2) { @@ -1523,7 +1614,19 @@ TEST_F(ShapeInferenceTest, BadSlice) { << statusor.status(); } -class GatherShapeInferenceTest : public ShapeInferenceTest { +TEST_F(ShapeInferenceTest, BadSort) { + auto keys = ShapeUtil::MakeShape(F32, {4}); + auto values = ShapeUtil::MakeShape(F32, {5}); + StatusOr statusor = + ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values}); + ASSERT_FALSE(statusor.ok()); + + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("dimensions must match")) + << statusor.status(); +} + +class ScatterGatherShapeInferenceTest : public ShapeInferenceTest { protected: const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {}); const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5}); @@ -1540,59 +1643,63 @@ class GatherShapeInferenceTest : public ShapeInferenceTest { ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_}); + const ProgramShape to_apply_ = + ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); }; -TEST_F(GatherShapeInferenceTest, TensorFlowGather) { - TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, - ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_, - HloInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, - /*index_vector_dim=*/1), - /*window_bounds=*/{64, 1})); +// Shape inference tests for Gather. + +TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) { + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + matrix_64_48_, s64_vector_32_, + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1), + /*window_bounds=*/{64, 1})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32}))) << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { - TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, - ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_, - HloInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{1}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, - /*index_vector_dim=*/1), - /*window_bounds=*/{1, 48})); +TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) { + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + matrix_64_48_, s64_vector_32_, + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{1}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/1), + /*window_bounds=*/{1, 48})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48}))) << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { - TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, - ShapeInference::InferGatherShape(matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, - HloInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, - /*index_vector_dim=*/4), - /*window_bounds=*/{1, 48})); +TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) { + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4), + /*window_bounds=*/{1, 48})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}))) << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { +TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1604,12 +1711,12 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { +TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1622,12 +1729,12 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { +TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1640,13 +1747,13 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) { +TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) { // This is equivalent to a dynamic slice. TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{0, 1, 2, 3, 4}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1658,13 +1765,13 @@ TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { +TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) { // The gather indices "tensor" is a scalar S here that's used to slice out // [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result. TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_scalar_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{0, 1, 2, 3}, /*elided_window_dims=*/{0}, /*gather_dims_to_operand_dims=*/{0}, @@ -1676,13 +1783,14 @@ TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { +TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) { StatusOr statusor = ShapeInference::InferGatherShape( tuple_shape_, s64_vector_32_, - HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, - /*index_vector_dim=*/1), + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1690,13 +1798,14 @@ TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { << statusor.status(); } -TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { +TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) { StatusOr statusor = ShapeInference::InferGatherShape( s64_vector_32_, tuple_shape_, - HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, - /*index_vector_dim=*/0), + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1704,13 +1813,14 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { << statusor.status(); } -TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { +TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) { StatusOr statusor = ShapeInference::InferGatherShape( s64_vector_32_, vector_32_, - HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, - /*index_vector_dim=*/0), + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1718,11 +1828,11 @@ TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingWindowIndices) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 8, 7}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1735,11 +1845,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowIndices) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 7}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1752,11 +1862,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexOutOfBounds) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 99, 100, 101}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1768,11 +1878,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 9}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1784,11 +1894,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingElidedWindowDims) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{4}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1802,11 +1912,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{0, 1, 2, 3, 19}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1819,11 +1929,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{0, 1, 2, 3, 3}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1837,11 +1947,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingGatherToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}, @@ -1856,11 +1966,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}, @@ -1874,11 +1984,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedGatherToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}, @@ -1892,11 +2002,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingElidedWindowDims) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{2, 1}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1908,10 +2018,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { +TEST_F(ScatterGatherShapeInferenceTest, + InvalidGatherDimNumbers_WindowBoundsTooLarge) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7}, /*elided_window_dims=*/{2}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1924,11 +2035,11 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1942,11 +2053,11 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7}, /*elided_window_dims=*/{1}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1959,10 +2070,10 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { +TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1976,5 +2087,575 @@ TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { << statusor.status(); } +// Shape inference tests for Scatter. + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdates) { + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, + ShapeUtil::MakeShape(F32, {64, 32}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdatesV2) { + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, + ShapeUtil::MakeShape(F32, {32, 48}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/1))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdates) { + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, + ShapeUtil::MakeShape(F32, {10, 32}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdatesV2) { + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, + ShapeUtil::MakeShape(F32, {32, 8}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/1))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {65, 32}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Bounds of the window dimensions of updates must not exceed " + "the bounds of the corresponding dimensions of operand.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {32, 49}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Bounds of the window dimensions of updates must not exceed " + "the bounds of the corresponding dimensions of operand.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + TfScatterWithUpdatesNotMatchingIndices) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {64, 31}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Bounds of the scatter dimensions of updates must be same as the " + "bounds of the corresponding dimensions of scatter indices.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + TfScatterWithUpdatesNotMatchingIndicesV2) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {31, 48}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Bounds of the scatter dimensions of updates must be same as the " + "bounds of the corresponding dimensions of scatter indices.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdates) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdatesV2) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 64}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdates) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 10}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdatesV2) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 12}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithUpdatesBiggerThanInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 65}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Bounds of the window dimensions of updates must not exceed " + "the bounds of the corresponding dimensions of operand.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + TfScatterNdWithUpdatesNotMatchingIndices) { + StatusOr statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {9, 9, 8, 7, 64}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Bounds of the scatter dimensions of updates must be same as the " + "bounds of the corresponding dimensions of scatter indices.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfBatchDynamicUpdateSlice) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 8}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDim) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 8}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2))); + + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDimV2) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 8}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0))); + + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, NoUpdateScatterDims) { + // This is equivalent to a dynamic update slice. + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0, 1, 2, 3, 4}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0))); + + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, ScalarScatterIndices) { + // The scalar indices "tensor" is a scalar S here that's used to update a + // [30,29,28,27] shaped tensor within the operand at position S. + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_scalar_, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0, 1, 2, 3}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/0))); + + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedTensorInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + tuple_shape_, s64_vector_32_, s64_vector_32_, to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected array argument for operand")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + ScatterWithTupleShapedScatterIndicesInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + s64_vector_32_, tuple_shape_, s64_vector_32_, to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected array argument for scatter indices")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedUpdatesInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + s64_vector_32_, s64_vector_32_, tuple_shape_, to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected array argument for updates")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, FloatingPointScatterIndicesInput) { + StatusOr statusor = ShapeInference::InferScatterShape( + s64_vector_32_, vector_32_, s64_vector_32_, to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Scatter indices parameter must be an integral tensor")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsScatterIndicesLeafDim) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/10)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Scatter index leaf dimension must be within [0, " + "rank(scatter_indices) + 1)")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdates) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 50}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Updates tensor must be of rank 7; got 8.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdateComputation) { + const ProgramShape invalid_update_computation = + ShapeUtil::MakeProgramShape({f32_}, f32_); + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), + invalid_update_computation, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Reduction function must take 2 parameters, but takes 1")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_NonAscendingUpdateWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 8, 7}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("update_window_dims in scatter op must be sorted")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_RepeatedUpdateWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 7}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("update_window_dims in scatter op must not repeat")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_OutOfBoundsUpdateWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 9}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid update_window_dims set in scatter op; valid " + "range is [0, 9)")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_NonAscendingInsertedWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{2, 1}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("inserted_window_dims in scatter op must be sorted")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_RepeatedInsertedWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 1}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("inserted_window_dims in scatter op must not repeat")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_OutOfBoundsInsertedWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 5}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid inserted_window_dims set in scatter op; valid " + "range is [0, 5)")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_MismatchingScatterDimsToOperandDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Scatter op has 4 elements in scatter_dims_to_operand_dims and " + "the bound of dimension index_vector_dim=4 of scatter_indices " + "is 5. These two numbers must be equal")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_OutOfBoundsScatterDimsToOperandDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 10}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid scatter_dims_to_operand_dims mapping; domain " + "is [0, 5), got: 4->10")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_RepeatedValuesInScatterDimsToOperandDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 2, 3}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Repeated dimensions not allowed in scatter_dims_to_operand_dims")) + << statusor.status(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc new file mode 100644 index 0000000000000000000000000000000000000000..c0582c6a2d3a05e2ed5aead5faac54e536d350cd --- /dev/null +++ b/tensorflow/compiler/xla/service/stream_pool.cc @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/stream_pool.h" + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) { + std::unique_ptr stream; + { + tensorflow::mutex_lock lock(mu_); + if (!streams_.empty()) { + // Re-use an existing stream from the pool. + stream = std::move(streams_.back()); + streams_.pop_back(); + VLOG(1) << stream->DebugStreamPointers() + << " StreamPool reusing existing stream"; + } + } + + if (!stream) { + // Create a new stream. + stream = MakeUnique(executor); + stream->Init(); + VLOG(1) << stream->DebugStreamPointers() + << " StreamPool created new stream"; + } + + // Return the stream wrapped in Ptr, which has our special deleter semantics. + PtrDeleter deleter = {this}; + return Ptr(stream.release(), deleter); +} + +void StreamPool::ReturnStream(se::Stream* stream) { + if (stream->ok()) { + VLOG(1) << stream->DebugStreamPointers() + << " StreamPool returning ok stream"; + tensorflow::mutex_lock lock(mu_); + streams_.emplace_back(stream); + } else { + // If the stream has encountered any errors, all subsequent operations on it + // will fail. So just delete the stream, and rely on new streams to be + // created in the future. + VLOG(1) << stream->DebugStreamPointers() + << " StreamPool deleting !ok stream"; + delete stream; + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/stream_pool.h b/tensorflow/compiler/xla/service/stream_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..7221d323a61593ac4b203a81b6046d81a5beaaf0 --- /dev/null +++ b/tensorflow/compiler/xla/service/stream_pool.h @@ -0,0 +1,64 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_STREAM_POOL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_STREAM_POOL_H_ + +#include +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// Pool of stream_executor::Streams, which are created as needed and +// destroyed when the pool is destroyed. +class StreamPool { + public: + struct PtrDeleter { + void operator()(se::Stream* stream) { pool->ReturnStream(stream); } + StreamPool* pool; + }; + + // Stream pointer type returned by BorrowStream, which returns the + // stream to the pool on destruction. + using Ptr = std::unique_ptr; + + StreamPool() {} + + // Returns a pointer to a stream in the pool, creating a new stream + // if none are available in the pool. The returned smart pointer + // returns the stream to the pool on destruction. + // + // This method is thread-safe. + Ptr BorrowStream(se::StreamExecutor* executor); + + private: + // Puts a pointer to a stream back into the pool, leaving it free + // for future use. Streams that have previously encountered errors + // are deleted, and not returned to the pool. + // + // This method is thread-safe. + void ReturnStream(se::Stream* stream); + + tensorflow::mutex mu_; + std::vector> streams_ GUARDED_BY(mu_); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_STREAM_POOL_H_ diff --git a/tensorflow/compiler/xla/service/stream_pool_test.cc b/tensorflow/compiler/xla/service/stream_pool_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..aaf5c37b0d250f78cb57639255ac9b59e1b462f7 --- /dev/null +++ b/tensorflow/compiler/xla/service/stream_pool_test.cc @@ -0,0 +1,136 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/stream_pool.h" + +#include + +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace { + +class StreamPoolTest : public ::testing::Test { + protected: + std::unique_ptr NewStreamExecutor() { + se::Platform* platform = + se::MultiPlatformManager::PlatformWithName("Host").ConsumeValueOrDie(); + se::StreamExecutorConfig config(/*ordinal=*/0); + return platform->GetUncachedExecutor(config).ConsumeValueOrDie(); + } +}; + +TEST_F(StreamPoolTest, EmptyPool) { StreamPool pool; } + +TEST_F(StreamPoolTest, OneStreamPool) { + std::unique_ptr executor = NewStreamExecutor(); + StreamPool pool; + + // Borrow and return a stream. + StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); + se::Stream* stream1_ptr = stream1.get(); + EXPECT_TRUE(stream1->ok()); + stream1 = nullptr; + + // Borrow and return another stream. + StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); + se::Stream* stream2_ptr = stream2.get(); + EXPECT_TRUE(stream2->ok()); + stream2 = nullptr; + + // The underlying streams should be the same, since stream1 was the + // only stream available in the pool when stream2 was borrowed. + EXPECT_EQ(stream1_ptr, stream2_ptr); +} + +TEST_F(StreamPoolTest, TwoStreamPool) { + std::unique_ptr executor = NewStreamExecutor(); + StreamPool pool; + + // Borrow two streams. + StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); + se::Stream* stream1_ptr = stream1.get(); + EXPECT_TRUE(stream1->ok()); + StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); + se::Stream* stream2_ptr = stream2.get(); + EXPECT_TRUE(stream2->ok()); + + // The underlying streams should be different, since we haven't + // returned either of them yet. + EXPECT_NE(stream1_ptr, stream2_ptr); + + // Return stream1 and borrow stream3. + stream1 = nullptr; + StreamPool::Ptr stream3 = pool.BorrowStream(executor.get()); + se::Stream* stream3_ptr = stream3.get(); + EXPECT_TRUE(stream3->ok()); + + // stream1 and stream3 should be the same. + EXPECT_EQ(stream1_ptr, stream3_ptr); + EXPECT_NE(stream2_ptr, stream3_ptr); + + // Return stream2, and borrow stream4. + stream2 = nullptr; + StreamPool::Ptr stream4 = pool.BorrowStream(executor.get()); + se::Stream* stream4_ptr = stream4.get(); + EXPECT_TRUE(stream4->ok()); + + // Stream2 and stream4 should be the same. + EXPECT_EQ(stream2_ptr, stream4_ptr); + EXPECT_NE(stream3_ptr, stream4_ptr); +} + +TEST_F(StreamPoolTest, BadStreamDiscarded) { + std::unique_ptr executor = NewStreamExecutor(); + StreamPool pool; + + // Borrow a stream. + StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); + EXPECT_TRUE(stream1->ok()); + + // Force an error on the stream; here we call a method that requires + // DNN support, which we know the Host platform doesn't support. + stream1->ThenDepthConcatenate({}, {}, nullptr); + EXPECT_FALSE(stream1->ok()); + + // Return stream1 and borrow stream2. + stream1 = nullptr; + StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); + se::Stream* stream2_ptr = stream2.get(); + EXPECT_TRUE(stream2->ok()); + + // The underlying streams should be different. They would have been + // the same, but since we forced an error on stream1, it cannot be + // put back into the pool. Sadly we can't just check: + // EXPECT_NE(stream1_ptr, stream2_ptr); + // + // The above should hold logically, but it may fail if the new + // stream instance allocated for stream2 happens to reside in the + // same memory address as stream1, which has been deleted. + // + // The check that stream2->ok() serves as a good-enough check. + + // Return stream2 and borrow stream3. The previous error on stream1 + // has no effect on these streams, and they are the same. + stream2 = nullptr; + StreamPool::Ptr stream3 = pool.BorrowStream(executor.get()); + se::Stream* stream3_ptr = stream3.get(); + EXPECT_TRUE(stream3->ok()); + EXPECT_EQ(stream2_ptr, stream3_ptr); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 4c5038a009ba5da4172129980014913f3f4418f4..32d368a90429ec026120bdf033957617eeaba23e 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -43,14 +43,39 @@ TransferManager::GetPlatformTransferManagers() { StatusOr> TransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer) { StatusOr> ret; + + se::Stream* substream = stream->GetOrCreateSubStream(); + substream->ThenWaitFor(stream); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + + tensorflow::Notification n; + Status s; + Literal literal(device_buffer.on_host_shape()); + TransferLiteralFromDevice(substream, device_buffer, literal, + [&](Status status) { + s = status; + n.Notify(); + }); + n.WaitForNotification(); + if (!s.ok()) { + return s; + } + return MakeUnique(std::move(literal)); +} + +Status TransferManager::TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + const MutableBorrowingLiteral& literal) { se::Stream* substream = stream->GetOrCreateSubStream(); auto cleanup = tensorflow::gtl::MakeCleanup( [&]() { stream->ReturnSubStream(substream); }); + Status ret; tensorflow::Notification n; - TransferLiteralFromDevice(substream, device_buffer, - [&](StatusOr> arg) { - ret = std::move(arg); + TransferLiteralFromDevice(substream, device_buffer, literal, + [&](Status status) { + ret = status; n.Notify(); }); n.WaitForNotification(); @@ -64,6 +89,7 @@ Status TransferManager::TransferLiteralToDevice( // Use a substream so that if we are called from a HostCallback we don't // deadlock. se::Stream* substream = stream->GetOrCreateSubStream(); + substream->ThenWaitFor(stream); auto cleanup = tensorflow::gtl::MakeCleanup( [&]() { stream->ReturnSubStream(substream); }); TF_RETURN_IF_ERROR( @@ -74,22 +100,27 @@ Status TransferManager::TransferLiteralToDevice( StatusOr> TransferManager::TransferArrayFromDevice( se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source) { + StatusOr> ret; // Implement the synchronous version by waiting on the asynchronous version. // Use a substream so that if we are called from a HostCallback we don't // deadlock. - StatusOr> ret; se::Stream* substream = stream->GetOrCreateSubStream(); auto cleanup = tensorflow::gtl::MakeCleanup( [&]() { stream->ReturnSubStream(substream); }); tensorflow::Notification n; - TransferArrayFromDevice(substream, shape, source, - [&](StatusOr> arg) { - ret = std::move(arg); + Literal literal(shape); + Status s; + TransferArrayFromDevice(substream, shape, source, literal, + [&](Status status) { + s = status; n.Notify(); }); n.WaitForNotification(); - return ret; + if (!s.ok()) { + return s; + } + return MakeUnique(std::move(literal)); } Status TransferManager::TransferArrayToDevice( @@ -128,7 +159,7 @@ Status TransferManager::TransferArrayToDeviceAsync( void TransferManager::TransferArrayFromDevice( se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source, - std::function>)> done) { + const MutableBorrowingLiteral& literal, std::function done) { if (!ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) { auto error = StrCat("Shape ", ShapeUtil::HumanString(shape), " has a differently shaped representation on-device: ", @@ -145,7 +176,8 @@ void TransferManager::TransferArrayFromDevice( stream->parent()->platform(), stream->parent()->device_ordinal()); shaped_buffer.set_buffer(source, /*index=*/{}); - return TransferLiteralFromDevice(stream, shaped_buffer, std::move(done)); + return TransferLiteralFromDevice(stream, shaped_buffer, literal, + std::move(done)); } /* static */ void TransferManager::RegisterTransferManager( diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index e384359642a8fe09e0b8516e342a56259912922a..475a2e5c141d66fa689fb402da1ee81fb4ab80f7 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -59,6 +59,9 @@ class TransferManager { // This function should be avoided in favor of the asynchronous version below. virtual StatusOr> TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer); + virtual Status TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + const MutableBorrowingLiteral& literal); // Begins transferring a literal containing the data held in the given // ShapedBuffer using the provided executor. @@ -69,9 +72,10 @@ class TransferManager { // // device_buffer is copied by reference and must live at least until done() is // invoked. - virtual void TransferLiteralFromDevice( - se::Stream* stream, const ShapedBuffer& device_buffer, - std::function>)> done) = 0; + virtual void TransferLiteralFromDevice(se::Stream* stream, + const ShapedBuffer& device_buffer, + MutableBorrowingLiteral literal, + std::function done) = 0; // Transfers the given literal into the previously allocated device memory // represented by the given ShapedBuffer using the given executor. The shape @@ -101,10 +105,10 @@ class TransferManager { // transfer an array at a known address. Status TransferArrayToDevice(se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest); - void TransferArrayFromDevice( - se::Stream* stream, const Shape& shape, - const se::DeviceMemoryBase& source, - std::function>)> done); + void TransferArrayFromDevice(se::Stream* stream, const Shape& shape, + const se::DeviceMemoryBase& source, + const MutableBorrowingLiteral& literal, + std::function done); Status TransferArrayToDeviceAsync(se::Stream* stream, const LiteralSlice& literal, @@ -120,9 +124,9 @@ class TransferManager { // Transfers the given literal from the Outfeed interface of the device, // using the given executor. - virtual Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, - const Shape& literal_shape, - Literal* literal) = 0; + virtual Status TransferLiteralFromOutfeed( + se::StreamExecutor* executor, const Shape& literal_shape, + MutableBorrowingLiteral literal) = 0; // Resets the devices associated with this transfer manager. virtual Status ResetDevices( @@ -167,16 +171,6 @@ class TransferManager { const se::Platform* platform); protected: - // Transfer a memory block of the given size from 'source' buffer to the - // Infeed interface of the device using the given executor. - // - // size is the size to transfer from source in bytes. - // - // source is the source data that must be in the target-dependent layout that - // the Infeed HLO used in the computation expects. - virtual Status TransferBufferToInfeed(se::StreamExecutor* executor, - int64 size, const void* source) = 0; - // Transfer a memory block of the given size from the device source into the // 'destination' buffer. // diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index cccb8f2fbb0266bbf1f40b09170938a1e5d3e78d..58f767e913fbc0023e0c45a4f0e82ecefeeef2d6 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -160,11 +160,11 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { auto builder = HloComputation::Builder("entry"); // (1.0 + 2.0) * (2.0 - 3.0) HloInstruction* const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); HloInstruction* const2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); HloInstruction* const3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( const1->shape(), HloOpcode::kAdd, const1, const2)); HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index d1e174464759dbc2c0d84c4ddac27cb21635e131..0447807a41b8b32ee297e1ca94393da8c687c5e6 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -232,8 +232,7 @@ Status TuplePointsToAnalysis::HandleGetTupleElement( // Copy the points-to set (and tuple sources) at index {element_index} of the // operand to the points-to set for this GetTupleElement instruction. points_to_set.ForEachMutableElement( - [&, this](const ShapeIndex& target_index, - PointsToSet::BufferList* points_to) { + [&](const ShapeIndex& target_index, PointsToSet::BufferList* points_to) { // Construct an index into the operand by prepending element_index to // the index for the GetTupleElement instruction's points-to set. ShapeIndex src_index; @@ -292,22 +291,29 @@ Status TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) { } Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { - // RecvDone aliases its input (Recv) tuple element {0} to its output. + // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its + // output. The other indices ({} and {1}) define their own buffers. PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done); + points_to_set.AddPointedToBuffer( + logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{}), + /*index=*/{}); + points_to_set.AddPointedToBuffer( + logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{1}), + /*index=*/{1}); + const PointsToSet& operand_points_to_set = GetPointsToSet(recv_done->operand(0)); - // Recursively copy the points to set of the operand tuple {0}. + // Recursively copy the points to set of the operand tuple {0} to the output + // element {0}. points_to_set.ForEachMutableElement( - [this, &points_to_set, &operand_points_to_set]( + [&points_to_set, &operand_points_to_set]( const ShapeIndex& index, PointsToSet::BufferList* buffers) { - ShapeIndex src_index({0}); - for (auto element : index) { - src_index.push_back(element); + if (index.empty() || index[0] != 0) { + return; } - *buffers = operand_points_to_set.element(src_index); - for (auto& tuple_source : - operand_points_to_set.tuple_sources(src_index)) { + *buffers = operand_points_to_set.element(index); + for (auto& tuple_source : operand_points_to_set.tuple_sources(index)) { points_to_set.add_tuple_source(index, tuple_source); } }); @@ -315,7 +321,7 @@ Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { } Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) { - // Send creates a tuple of {aliased operand, U32 context}. + // Send creates a tuple of {aliased operand, U32 context, token}. PointsToSet& points_to_set = CreateEmptyPointsToSet(send); // Creates the points to set for the tuple and its element at {1}. @@ -328,6 +334,10 @@ Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) { context_buffer->push_back( &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1}))); + auto token_buffer = points_to_set.mutable_element(ShapeIndex({2})); + token_buffer->push_back( + &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({2}))); + // Recursively copy the points to set of the operand to output tuple {0}. const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0)); operand_points_to_set.ForEachElement( @@ -388,7 +398,7 @@ Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) { return Status::OK(); } -Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select) { +Status TuplePointsToAnalysis::HandleTupleSelect(HloInstruction* tuple_select) { // Select allocates a new buffer and then shallow copies the on_true or // on_false buffer into this new buffer. Which side is chosen cannot be // determined statically so conservatively set the points-to set to the union @@ -396,9 +406,9 @@ Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select) { // // First create a copy of the on_true points-to set (and tuple sources), then // add in elements of the on_false points-to set (tuple sources). - auto on_true = select->operand(1); - auto on_false = select->operand(2); - PointsToSet& points_to_set = CreateCopiedPointsToSet(select, on_true); + auto on_true = tuple_select->operand(1); + auto on_false = tuple_select->operand(2); + PointsToSet& points_to_set = CreateCopiedPointsToSet(tuple_select, on_true); const PointsToSet& false_points_to_set = *PerInst(on_false)->points_to_set; points_to_set.ForEachMutableElement( [&](const ShapeIndex& index, PointsToSet::BufferList* buffers) { @@ -416,7 +426,7 @@ Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select) { // respective element in the points-to set should contain only itself. points_to_set.mutable_element({})->clear(); points_to_set.AddPointedToBuffer( - logical_buffer_analysis_->GetBuffer(select, /*index=*/{}), + logical_buffer_analysis_->GetBuffer(tuple_select, /*index=*/{}), /*index=*/{}); return Status::OK(); } @@ -506,7 +516,7 @@ Status TuplePointsToAnalysis::GatherBuffersDefinedByInstruction( const HloInstruction* instruction, TuplePointsToAnalysis::BufferDefinitionVector* buffers) { GetPointsToSet(instruction) - .ForEachElement([this, buffers, instruction]( + .ForEachElement([buffers, instruction]( const ShapeIndex& index, const PointsToSet::BufferList& source_buffers) { // Add buffers which 'instruction' is the source of. @@ -536,7 +546,7 @@ PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet( PointsToSet& dst_points_to_set = CreateEmptyPointsToSet(instruction); const PointsToSet& src_points_to_set = GetPointsToSet(src); dst_points_to_set.ForEachMutableElement( - [this, &dst_points_to_set, &src_points_to_set]( + [&dst_points_to_set, &src_points_to_set]( const ShapeIndex& index, PointsToSet::BufferList* buffers) { *buffers = src_points_to_set.element(index); for (auto& tuple_source : src_points_to_set.tuple_sources(index)) { @@ -707,6 +717,7 @@ bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt( // root at operand 0 or 1. Or... // (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index // 0. +// (5) The 'user' of 'operand' is Sort, and it is the only user. // // (2) and (3) can only be determined if points-to analysis is available. bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( @@ -772,6 +783,21 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( std::vector operand_indices = user->OperandIndices(operand); return operand_indices.size() == 1 && operand_indices[0] == 0; } + if (user->opcode() == HloOpcode::kSort) { + // Only valid if there are no other users. + if (operand->users().size() != 1) { + return false; + } + // If we only sort keys, the output of sort is not a tuple, so we can always + // share the buffer. + if (user->operand_count() == 1) { + return true; + } + CHECK(!user_index.empty()); + // Only share with the right tuple element buffer. + std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && user_index[0] == operand_indices[0]; + } if (user->opcode() == HloOpcode::kCall) { // TODO(b/62548313): Remove when buffer assignment is module scoped and // does not assign buffers to calls. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index c0d82414806d9a6ff57aec59d077f444137fec9a..686bb053288fbd6a46ca50a2c65c739354fd2678 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -253,7 +253,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; - Status HandleSelect(HloInstruction* select) override; + Status HandleTupleSelect(HloInstruction* tuple_select) override; string ToString() const; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 5734f284071944bc22011405898cf86f33dc48d7..10d382e8abc92145c1804cbf18bbed714fa34571 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -124,9 +124,9 @@ class TuplePointsToAnalysisTest : public HloTestBase { TEST_F(TuplePointsToAnalysisTest, SimpleTuple) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -177,14 +177,14 @@ TEST_F(TuplePointsToAnalysisTest, NestedTuple) { // tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto inner_tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({inner_tuple, constant3})); @@ -238,14 +238,14 @@ TEST_F(TuplePointsToAnalysisTest, GetTupleElement) { // tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto inner_tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({inner_tuple, constant3})); @@ -270,7 +270,7 @@ TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) { // Create a tuple which contains duplicate elements. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant, constant, constant})); @@ -291,9 +291,9 @@ TEST_F(TuplePointsToAnalysisTest, TupleCopy) { // the same. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto copy = builder.AddInstruction( @@ -317,9 +317,10 @@ TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) { // Send forwards its operand to the output tuple at {0}. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); auto send = builder.AddInstruction( - HloInstruction::CreateSend(constant, /*channel_id=*/0)); + HloInstruction::CreateSend(constant, token, /*channel_id=*/0)); auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); BuildModuleAndRunAnalysis(builder.Build()); @@ -342,8 +343,9 @@ TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) { TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) { // RecvDone forwards its operand tuple element at {0} to the output. auto builder = HloComputation::Builder(TestName()); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); auto recv = builder.AddInstruction(HloInstruction::CreateRecv( - ShapeUtil::MakeShape(F32, {1, 2, 3}), /*channel_id=*/0)); + ShapeUtil::MakeShape(F32, {1, 2, 3}), token, /*channel_id=*/0)); auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); BuildModuleAndRunAnalysis(builder.Build()); @@ -355,7 +357,7 @@ TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) { ExpectHasTopLevelBuffers( points_to_analysis_->GetPointsToSet(recv).element({}), {recv}); - ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {}}}); + ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {0}}}); } TEST_F(TuplePointsToAnalysisTest, TupleSelect) { @@ -363,18 +365,18 @@ TEST_F(TuplePointsToAnalysisTest, TupleSelect) { // set containing the union of both sides. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto tuple2 = builder.AddInstruction( HloInstruction::CreateTuple({constant2, constant2})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); BuildModuleAndRunAnalysis(builder.Build()); @@ -401,9 +403,9 @@ TEST_F(TuplePointsToAnalysisTest, SelectTupleParameters) { auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, tuple_shape, "param1")); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, param0, param1)); + tuple_shape, HloOpcode::kTupleSelect, pred, param0, param1)); auto copy = builder.AddInstruction( HloInstruction::CreateUnary(tuple_shape, HloOpcode::kCopy, select)); @@ -441,18 +443,18 @@ TEST_F(TuplePointsToAnalysisTest, UnambiguousTupleSelect) { // Select from two identical tuples. The result should not be ambiguous. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto tuple2 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); BuildModuleAndRunAnalysis(builder.Build()); @@ -472,9 +474,9 @@ TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) { // the right values. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto inner_tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto inner_tuple2 = builder.AddInstruction( @@ -486,9 +488,9 @@ TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) { builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple2})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); + tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2)); BuildModuleAndRunAnalysis(builder.Build()); @@ -519,9 +521,9 @@ TEST_F(TuplePointsToAnalysisTest, TupleWithBitcast) { // have the operand of the bitcast in its points-to set. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( constant2->shape(), HloOpcode::kBitcast, constant2)); auto tuple = @@ -555,9 +557,10 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) { // Construct a tuple constant and kCopy it. Verify the points-to set of the // copy correctly correctly points into the nested elements of the constant. auto builder = HloComputation::Builder(TestName()); - auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::MakeTuple({Literal::CreateR2({{1.0}, {2.0}}).get(), - Literal::CreateR1({2.0, 42}).get()}))); + auto tuple_constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), + LiteralUtil::CreateR1({2.0, 42}).get()}))); auto copy = builder.AddInstruction(HloInstruction::CreateUnary( tuple_constant->shape(), HloOpcode::kCopy, tuple_constant)); @@ -577,9 +580,9 @@ TEST_F(TuplePointsToAnalysisTest, BufferAliases) { // times. Verify buffer alias sets. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto inner_tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto tuple = builder.AddInstruction( @@ -618,7 +621,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { auto tuple_element1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(update_shape, tuple_param0, 1)); auto ones = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); + LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f}))); // Create 'update' = Add(GetTupleElement(tuple_param0, 1), ones) auto update = builder.AddInstruction(HloInstruction::CreateBinary( update_shape, HloOpcode::kAdd, tuple_element1, ones)); @@ -866,9 +869,9 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -960,9 +963,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -1009,14 +1012,56 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { points_to_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); } +TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + auto sort = + builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { + auto builder = HloComputation::Builder(TestName()); + + Shape keys_shape = ShapeUtil::MakeShape(F32, {8}); + Shape values_shape = ShapeUtil::MakeShape(F32, {8}); + auto keys = builder.AddInstruction( + HloInstruction::CreateParameter(0, keys_shape, "keys")); + auto values = builder.AddInstruction( + HloInstruction::CreateParameter(1, values_shape, "values")); + auto sort = builder.AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // The buffer for the keys can be shared with the first tuple entry. + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {0})); + // The buffer for the values can be shared with the second tuple entry. + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(values, {}, + sort, {1})); + // Verify that the buffers are not shared with the "wrong" tuple entry. + EXPECT_FALSE( + points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {1})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(values, {}, + sort, {0})); +} + TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { auto builder = HloComputation::Builder(TestName()); Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto a = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + LiteralUtil::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); auto b = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); @@ -1025,7 +1070,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto add_operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -1047,7 +1092,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -1055,7 +1100,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { HloInstruction::CreateReverse(data_shape, operand, {0, 1})); auto two = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); @@ -1073,7 +1118,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - auto make_cond = [this, &data_shape]() { + auto make_cond = [&data_shape]() { auto builder = HloComputation::Builder(TestName() + ".Cond"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); @@ -1082,7 +1127,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { return builder.Build(); }; - auto make_body = [this, &data_shape]() { + auto make_body = [&data_shape]() { auto builder = HloComputation::Builder(TestName() + ".Body"); auto data = builder.AddInstruction( HloInstruction::CreateParameter(0, data_shape, "data")); @@ -1120,7 +1165,7 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { auto sub_param = sub_builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "sub_param")); auto one = sub_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto ones = sub_builder.AddInstruction( HloInstruction::CreateBroadcast(shape, one, {1})); auto add = sub_builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc index d3635eae81ec7017f9bf6a69250d10716309c9ec..39b693872da6bd985d95c2abc9519662c838a3f5 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc new file mode 100644 index 0000000000000000000000000000000000000000..af2cb6dc2a3f4a004351acc62796e0daf46719c2 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -0,0 +1,238 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { + +using tensorflow::gtl::nullopt; +using tensorflow::gtl::optional; + +// Finds and returns the non-constant operand in instr. +// +// CHECK-fails if instr doesn't have exactly one unique non-constant operand. +static const HloInstruction* NonConstantOperand(const HloInstruction* instr) { + const HloInstruction* result = nullptr; + for (const HloInstruction* operand : instr->operands()) { + if (!operand->IsConstant()) { + if (result != nullptr) { + CHECK_EQ(result, operand); + } + result = operand; + } + } + CHECK_NE(result, nullptr); + return result; +} + +// If all of instr's operands are either constants or have the form +// get-tuple-element(gte_operand, N) +// for the same value N, returns N. Otherwise, returns nullopt. +static optional GetGTEOperandIndex(const HloInstruction* instr, + const HloInstruction* gte_operand) { + VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", " + << gte_operand->ToString() << ")"; + optional tuple_idx; + for (const HloInstruction* operand : instr->operands()) { + if (operand->IsConstant()) { + continue; + } + // Look through copies. + // TODO(b/68830972): We wouldn't need this if for loop matching on the GPU + // would run before copy insertion. + if (operand->opcode() == HloOpcode::kCopy) { + operand = operand->operand(0); + } + if (operand->opcode() != HloOpcode::kGetTupleElement) { + VLOG(2) << "instr uses something other than gte(gte_operand): " + << operand->ToString(); + return nullopt; + } + if (operand->operand(0) != gte_operand) { + VLOG(2) << "instr has gte whose operand is not gte_operand: " + << operand->ToString(); + return nullopt; + } + if (tuple_idx && tuple_idx != operand->tuple_index()) { + VLOG(2) << "instr has operands with conflicting gte indices, " + << *tuple_idx << " vs " << operand->tuple_index(); + return nullopt; + } + + tuple_idx = operand->tuple_index(); + } + return tuple_idx; +} + +// Tries to get the tuple index of the induction variable of a while loop. +// +// Checks that the loop condition and root both plumb the induction variable +// through the same tuple index, and that they both apply exactly one op to the +// induction variable before deciding whether to do another loop iteration (in +// the loop condition's case) or packing the induction variable into the result +// tuple (in the loop body's case). +// +// Specifically, checks that the loop condition has structure +// +// root = op(constants, get-tuple-elem(param0, N), constants) +// +// and the loop body has the structure +// +// inc = op(constants, get-tuple-elem(param0, N), constants) +// root = tuple(..., inc, ...) // inc is N'th operand of tuple(). +// +// If so, returns N. Otherwise, returns nullopt. +static optional GetLoopInductionVarTupleIdx( + const HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + VLOG(2) << "Finding induction variable for loop " + << while_op->ToShortString(); + + // The while_cond computation should have the form + // + // while_cond_root = + // op(constants, get-tuple-elem(while_cond_param, N), constants). + // + // If it does, set indvar_tuple_idx to N. + auto* while_cond = while_op->while_condition(); + auto* while_cond_root = while_cond->root_instruction(); + auto* while_cond_param = while_cond->parameter_instruction(0); + optional indvar_tuple_idx = + GetGTEOperandIndex(while_cond_root, while_cond_param); + if (!indvar_tuple_idx) { + VLOG(2) << "Induction variable not found in loop condition: " + << while_cond->root_instruction()->ToString(); + return nullopt; + } + + // The while_body computation should have the form + // + // while_body_inc = + // op(constants, get-tuple-elem(while_body_param, N), constants) + // while_body_root = tuple(..., while_body_inc, ...) + // + // where while_body_inc is operand N of while_body_root. + auto* while_body = while_op->while_body(); + auto* while_body_root = while_body->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While body's root is not a tuple instruction: " + << while_body_root->ToString(); + return nullopt; + } + + auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx); + auto* while_body_param = while_body->parameter_instruction(0); + optional while_body_indvar_tuple_idx = + GetGTEOperandIndex(while_body_inc, while_body_param); + if (!while_body_indvar_tuple_idx) { + VLOG(2) + << "Induction variable not found in while body increment instruction: " + << while_body_inc->ToString(); + return nullopt; + } + if (while_body_indvar_tuple_idx != indvar_tuple_idx) { + VLOG(2) << "Tuple index of induction variable does not match between loop " + "condition (" + << *indvar_tuple_idx << ") and while body (" + << *while_body_indvar_tuple_idx << ")"; + return nullopt; + } + + // Finally, check that the while loop's initial value is a tuple with enough + // elements. + auto* while_init = while_op->operand(0); + if (while_init->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While init expected to be a tuple: " << while_init->ToString(); + return nullopt; + } + + VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx; + return indvar_tuple_idx; +} + +optional ComputeWhileLoopTripCount(HloInstruction* while_op, + int64 max_value_returned) { + VLOG(2) << "Getting trip count for loop " << while_op->ToString(); + + // The loop's induction variable is found at + // + // get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx), + // + // where comp is while_op->while_body() or while_op->while_condition(). + optional indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op); + if (!indvar_tuple_idx) { + return nullopt; + } + + // Now that we know the index of the induction variable, we can we can try to + // compute how many times the loop executes. Start by computing the induction + // variable's initial value. + HloEvaluator evaluator(/*max_loop_iterations=*/0); + auto* while_init = while_op->mutable_operand(0); + auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); + StatusOr> indvar_init_result = + evaluator.Evaluate(indvar_init); + if (!indvar_init_result.ok()) { + VLOG(2) << "Couldn't evaluate induction variable init: " + << indvar_init_result.status(); + return nullopt; + } + + auto* while_body = while_op->while_body(); + auto* while_body_indvar_update = + while_body->root_instruction()->operand(*indvar_tuple_idx); + auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); + + // The initial value of the induction variable. + std::unique_ptr indvar_iter_val = + std::move(indvar_init_result).ValueOrDie(); + for (int64 trip_count = 0; trip_count != max_value_returned + 1; + ++trip_count) { + auto* while_cond = while_op->while_condition(); + auto* while_cond_root = while_cond->root_instruction(); + auto* while_cond_indvar = NonConstantOperand(while_cond_root); + StatusOr> result = + evaluator.EvaluateWithSubstitutions( + while_cond_root, {{while_cond_indvar, indvar_iter_val.get()}}); + if (!result.ok()) { + VLOG(2) << "Couldn't evaluate while cond: " << result.status(); + return nullopt; + } + if (result.ValueOrDie()->data() == + tensorflow::gtl::ArraySlice{false}) { + VLOG(2) << "Loop has static trip count of " << trip_count; + return trip_count; + } + + // Calculate the value of the induction variable after one iteration of the + // loop, and check whether the while condition is true with this new value. + StatusOr> indvar_next_result = + evaluator.EvaluateWithSubstitutions( + while_body_indvar_update, + {{while_body_indvar, indvar_iter_val.get()}}); + if (!indvar_next_result.ok()) { + VLOG(2) << "Couldn't evaluate induction variable update: " + << indvar_next_result.status(); + return nullopt; + } + indvar_iter_val = std::move(indvar_next_result).ValueOrDie(); + } + + VLOG(2) << "Loop has unknown trip count."; + return nullopt; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.h b/tensorflow/compiler/xla/service/while_loop_analysis.h new file mode 100644 index 0000000000000000000000000000000000000000..bf59813e8c405a8709446bf8457729348ceae4ec --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_analysis.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/lib/gtl/optional.h" + +namespace xla { + +// Returns the precise trip count of the loop if it's statically known, +// nullopt otherwise. max_value_returned limits the number of steps that are +// evaluated while trying to brute force a loop trip count, trip counts larger +// than max_value_returned result in nullopt. +tensorflow::gtl::optional ComputeWhileLoopTripCount( + HloInstruction *while_op, int64 max_value_returned = 128); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc index 10fc4958fae06414dbe7a3a0a798cb5c6e0f35c2..62af45128ad2fb7bf886bef78ec3ab42529a181e 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc @@ -61,6 +61,12 @@ StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileBody( WhileUtil::GetInvariantGTEsForWhileBody(*while_body)) { int64 index = invariant_gte->tuple_index(); const HloInstruction& invariant_value = *init_value.operand(index); + + // Should have at least one user that's not while_body_root. + if (invariant_gte->user_count() <= 1) { + continue; + } + if (invariant_value.opcode() == HloOpcode::kConstant) { auto* constant_instr = while_body->AddInstruction(invariant_value.Clone(/*suffix=*/".sunk")); diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index 393e75803888d8a642881c4d525b170d1e1180ba..266039d2ff8ef4befba0d1023ac1914737207d4f 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -196,5 +196,50 @@ ENTRY entry { op::GetTupleElement(op::Parameter(0)), op::GetTupleElement(op::Parameter(0)))); } + +TEST_F(WhileLoopConstantSinkingTest, DontCreateDeadConstant) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[2],f32[2]) parameter(0) + p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0 + p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1 + + outfeed = token[] outfeed(p_body.0) + ROOT root = (f32[2],f32[2],f32[2]) tuple(p_body.0, p_body.1, p_body.1) +} + +condition { + p_cond = (f32[2],f32[2]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] constant({1, 2}) + const_1 = f32[2] constant({2, 1}) + while_init = (f32[2],f32[2]) tuple(const_0, const_1) + ROOT while = (f32[2],f32[2],f32[2]) while(while_init), condition=condition, + body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::GetTupleElement(), op::GetTupleElement(), + op::GetTupleElement())); + for (const HloInstruction* inst : while_body->instructions()) { + if (inst->opcode() == HloOpcode::kConstant) { + EXPECT_GT(inst->user_count(), 0); + } + } +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 23519e445ea8a5f578a54708f38059feef3280c0..32e69c335b713c438bd7fcb2053709b0624f58ed 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -53,7 +53,7 @@ HloComputation* WhileLoopInvariantCodeMotionTest::MakeAlwaysTrueComputation( builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "param")); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); return module->AddEmbeddedComputation(builder.Build()); } @@ -125,7 +125,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) { builder.AddInstruction(HloInstruction::CreateUnary( scalar_s32, HloOpcode::kNegate, mul_result)); HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(4))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); HloInstruction* sub_result = builder.AddInstruction(HloInstruction::CreateBinary( scalar_s32, HloOpcode::kSubtract, negate_result, constant)); @@ -273,7 +273,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { HloComputation::Builder builder(TestName()); auto* scalar_param = builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_s32, "param")); - auto* token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto* token = builder.AddInstruction(HloInstruction::CreateToken()); auto* init_value = builder.AddInstruction( HloInstruction::CreateTuple({scalar_param, scalar_param, token})); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( @@ -323,7 +323,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { HloComputation::Builder builder(TestName()); auto* scalar_param = builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_s32, "param")); - auto* token = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto* token = builder.AddInstruction(HloInstruction::CreateToken()); auto* init_value = builder.AddInstruction( HloInstruction::CreateTuple({scalar_param, scalar_param, token})); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index ec05a74e286c89dd8db5ae07580e461938d7c087..dd8697e680c56165f87c365a721eda2de1ebc085 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/call_inliner.h" -#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -26,23 +26,6 @@ namespace xla { using tensorflow::gtl::nullopt; using tensorflow::gtl::optional; -// Finds and returns the non-constant operand in instr. -// -// CHECK-fails if instr doesn't have exactly one unique non-constant operand. -static const HloInstruction* NonConstantOperand(const HloInstruction* instr) { - const HloInstruction* result = nullptr; - for (const HloInstruction* operand : instr->operands()) { - if (!operand->IsConstant()) { - if (result != nullptr) { - CHECK_EQ(result, operand); - } - result = operand; - } - } - CHECK_NE(result, nullptr); - return result; -} - // Determines whether the given instruction is a send/recv node, or has a // subcomputation which contains a send/recv node. static bool IsOrContainsSendOrRecv(const HloInstruction* instr); @@ -72,211 +55,6 @@ static bool IsOrContainsSendOrRecv(const HloInstruction* instr) { return false; } -// If all of instr's operands are either constants or have the form -// get-tuple-element(gte_operand, N) -// for the same value N, returns N. Otherwise, returns nullopt. -static optional GetGTEOperandIndex(const HloInstruction* instr, - const HloInstruction* gte_operand) { - VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", " - << gte_operand->ToString() << ")"; - optional tuple_idx; - for (const HloInstruction* operand : instr->operands()) { - if (operand->IsConstant()) { - continue; - } - if (operand->opcode() != HloOpcode::kGetTupleElement) { - VLOG(2) << "instr uses something other than gte(gte_operand): " - << operand->ToString(); - return nullopt; - } - if (operand->operand(0) != gte_operand) { - VLOG(2) << "instr has gte whose operand is not gte_operand: " - << operand->ToString(); - return nullopt; - } - if (tuple_idx && tuple_idx != operand->tuple_index()) { - VLOG(2) << "instr has operands with conflicting gte indices, " - << *tuple_idx << " vs " << operand->tuple_index(); - return nullopt; - } - - tuple_idx = operand->tuple_index(); - } - return tuple_idx; -} - -// Tries to get the tuple index of the induction variable of a while loop. -// -// Checks that the loop condition and root both plumb the induction variable -// through the same tuple index, and that they both apply exactly one op to the -// induction variable before deciding whether to do another loop iteration (in -// the loop condition's case) or packing the induction variable into the result -// tuple (in the loop body's case). -// -// Specifically, checks that the loop condition has structure -// -// root = op(constants, get-tuple-elem(param0, N), constants) -// -// and the loop body has the structure -// -// inc = op(constants, get-tuple-elem(param0, N), constants) -// root = tuple(..., inc, ...) // inc is N'th operand of tuple(). -// -// If so, returns N. Otherwise, returns nullopt. -static optional GetLoopInductionVarTupleIdx( - const HloInstruction* while_op) { - CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); - VLOG(2) << "Finding induction variable for loop " - << while_op->ToShortString(); - - // The while_cond computation should have the form - // - // while_cond_root = - // op(constants, get-tuple-elem(while_cond_param, N), constants). - // - // If it does, set indvar_tuple_idx to N. - auto* while_cond = while_op->while_condition(); - auto* while_cond_root = while_cond->root_instruction(); - auto* while_cond_param = while_cond->parameter_instruction(0); - optional indvar_tuple_idx = - GetGTEOperandIndex(while_cond_root, while_cond_param); - if (!indvar_tuple_idx) { - VLOG(2) << "Induction variable not found in loop condition: " - << while_cond->root_instruction()->ToString(); - return nullopt; - } - - // The while_body computation should have the form - // - // while_body_inc = - // op(constants, get-tuple-elem(while_body_param, N), constants) - // while_body_root = tuple(..., while_body_inc, ...) - // - // where while_body_inc is operand N of while_body_root. - auto* while_body = while_op->while_body(); - auto* while_body_root = while_body->root_instruction(); - if (while_body_root->opcode() != HloOpcode::kTuple) { - VLOG(2) << "While body's root is not a tuple instruction: " - << while_body_root->ToString(); - return nullopt; - } - - auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx); - auto* while_body_param = while_body->parameter_instruction(0); - optional while_body_indvar_tuple_idx = - GetGTEOperandIndex(while_body_inc, while_body_param); - if (!while_body_indvar_tuple_idx) { - VLOG(2) - << "Induction variable not found in while body increment instruction: " - << while_body_inc->ToString(); - return nullopt; - } - if (while_body_indvar_tuple_idx != indvar_tuple_idx) { - VLOG(2) << "Tuple index of induction variable does not match between loop " - "condition (" - << *indvar_tuple_idx << ") and while body (" - << *while_body_indvar_tuple_idx << ")"; - return nullopt; - } - - // Finally, check that the while loop's initial value is a tuple with enough - // elements. - auto* while_init = while_op->operand(0); - if (while_init->opcode() != HloOpcode::kTuple) { - VLOG(2) << "While init expected to be a tuple: " << while_init->ToString(); - return nullopt; - } - - VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx; - return indvar_tuple_idx; -} - -// Tries to determine the number of times the given loop executes. Currently -// simply returns 0, 1, or "can't tell" (nullopt). -static optional GetLoopTripCount(HloInstruction* while_op) { - CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); - VLOG(2) << "Getting trip count for loop " << while_op->ToString(); - - // The loop's induction variable is found at - // - // get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx), - // - // where comp is while_op->while_body() or while_op->while_condition(). - optional indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op); - if (!indvar_tuple_idx) { - return nullopt; - } - - VLOG(2) << "Induction variable is at index " << *indvar_tuple_idx - << " in input tuple."; - - // Now that we know the index of the induction variable, we can we can try to - // compute how many times the loop executes. Start by computing the induction - // variable's initial value. - HloEvaluator evaluator(/*max_loop_iterations=*/0); - auto* while_init = while_op->mutable_operand(0); - auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); - StatusOr> indvar_init_result = - evaluator.Evaluate(indvar_init); - if (!indvar_init_result.ok()) { - VLOG(2) << "Couldn't evaluate induction variable init: " - << indvar_init_result.status(); - return nullopt; - } - - // Evaluates the while loop's condition, returning either "true" (continue - // looping), "false" (stop looping), or nullopt (can't evaluate). - auto evaluate_while_cond = [&](const Literal& indvar) -> optional { - auto* while_cond = while_op->while_condition(); - auto* while_cond_root = while_cond->root_instruction(); - auto* while_cond_indvar = NonConstantOperand(while_cond_root); - StatusOr> result = - evaluator.EvaluateWithSubstitutions(while_cond_root, - {{while_cond_indvar, &indvar}}); - if (!result.ok()) { - VLOG(2) << "Couldn't evaluate while cond: " << result.status(); - return nullopt; - } - return result.ValueOrDie()->data() == - tensorflow::gtl::ArraySlice{true}; - }; - - // The initial value of the induction variable. - const Literal& indvar_iter0_val = *indvar_init_result.ValueOrDie(); - - // Evaluate whether the while condition is true when seeded with - // indvar_iter0_val. - optional while_cond_iter0_val = evaluate_while_cond(indvar_iter0_val); - if (while_cond_iter0_val == false) { - VLOG(2) << "Loop has static trip count of 0."; - return 0; - } - - // Calculate the value of the induction variable after one iteration of the - // loop, and check whether the while condition is true with this new value. - auto* while_body = while_op->while_body(); - auto* while_body_indvar_update = - while_body->root_instruction()->operand(*indvar_tuple_idx); - auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); - StatusOr> indvar_iter1_result = - evaluator.EvaluateWithSubstitutions( - while_body_indvar_update, {{while_body_indvar, &indvar_iter0_val}}); - if (!indvar_iter1_result.ok()) { - VLOG(2) << "Couldn't evaluate induction variable update: " - << indvar_iter1_result.status(); - return nullopt; - } - const Literal& indvar_iter1_val = *indvar_iter1_result.ValueOrDie(); - optional while_cond_iter1_val = evaluate_while_cond(indvar_iter1_val); - if (while_cond_iter1_val == false) { - VLOG(2) << "Determined that loop has static trip count of 1."; - return 1; - } - - VLOG(2) << "Loop has unknown trip count >= 1."; - return nullopt; -} - // Tries to remove elements in a while loop's tuple that aren't used within the // loop. // @@ -577,7 +355,9 @@ static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { } // Remove while loops with static trip count of 0. - optional trip_count = GetLoopTripCount(while_op); + optional trip_count = + ComputeWhileLoopTripCount(while_op, + /*max_value_returned=*/1); if (trip_count && *trip_count == 0) { // The loop never executes, so the value of the loop is the value of its // "init" operand. diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 0536c99b671ff37d67bb0fc7f9ab0b806d15f016..2e1571943e537f772ee7dcd95c80ba540445b76e 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -157,7 +157,7 @@ TEST_F(WhileLoopSimplifierTest, auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* true_op = while_op->while_body()->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); TF_ASSERT_OK(true_op->AddControlDependencyTo( while_op->while_body()->root_instruction())); ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); @@ -175,9 +175,11 @@ TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) { auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); + auto* token = while_body->AddInstruction(HloInstruction::CreateToken()); auto* send = while_body->AddInstruction(HloInstruction::CreateSend( while_body->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))), + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))), + token, /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateSendDone(send)); EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); @@ -190,8 +192,9 @@ TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) { auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); + auto* token = while_body->AddInstruction(HloInstruction::CreateToken()); auto* recv = while_body->AddInstruction( - HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), + HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateRecvDone(recv)); EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); @@ -208,7 +211,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithInfeedNotSimplified) { auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); - auto token = while_body->AddInstruction(HloInstruction::CreateAfterAll({})); + auto token = while_body->AddInstruction(HloInstruction::CreateToken()); while_body->AddInstruction(HloInstruction::CreateInfeed( ShapeUtil::MakeShape(F32, {1}), token, "config")); EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index 473eab2ea84eb8faf745cbe299bc80bcc1b62a35..1ef17b9d7d2e769aadf39f8a70f78200b88e9d2c 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/tuple_util.h" @@ -38,7 +39,7 @@ static StatusOr WidenWhileCondition( // the root instruction later. We later change the root instruction to // something more appropriate. builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); return narrow_condition->parent()->AddEmbeddedComputation(builder.Build()); }(); @@ -154,7 +155,7 @@ MakeCountedLoopConditionComputation(const Shape& loop_state_shape, {&loop_state_shape}, scalar_pred, "while_cond")); HloInstruction* trip_count_constant = cond_computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(trip_count))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(trip_count))); HloInstruction* param = cond_computation->parameter_instruction(0); TF_ASSIGN_OR_RETURN(HloInstruction * indvar, @@ -175,7 +176,7 @@ static StatusOr> MakeCountedLoopBodyComputation( CreateComputationWithSignature( {&loop_state_shape}, loop_state_shape, "while_body")); HloInstruction* one = body_computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); HloInstruction* param = body_computation->parameter_instruction(0); TF_ASSIGN_OR_RETURN(HloInstruction * indvar, MakeGetTupleElementHlo(param, 0)); @@ -203,7 +204,7 @@ static StatusOr MakeInitTupleFromInitValues( std::vector init_values_with_indvar; init_values_with_indvar.reserve(init_values.size() + 1); HloInstruction* zero = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); init_values_with_indvar.push_back(zero); c_copy(init_values, std::back_inserter(init_values_with_indvar)); return computation->AddInstruction( diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc index 44b0ec5cd4c1d406467007fcc530e919d602c438..83d696fe0915086c3c98b6d7cbdaeaeb4d9d0bdb 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -32,7 +32,8 @@ StatusOr ZeroSizedHloElimination::Run(HloModule* module) { for (HloComputation* comp : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { if (instruction->HasSideEffect() || - !ShapeUtil::IsArray(instruction->shape())) { + !ShapeUtil::IsArray(instruction->shape()) || + instruction->opcode() == HloOpcode::kConstant) { continue; } if (comp->IsRemovable(instruction) && diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc index f5331280ee9f252aa5717baab88f2c203be5c372..b9ef18892d7aa859f6b0b505db4c004e4f5c5066 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -67,7 +67,16 @@ TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateParameter) { } TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateSideEffects) { - builder_.AddInstruction(HloInstruction::CreateSend(zero_sized_param_, 0)); + auto token = builder_.AddInstruction(HloInstruction::CreateToken()); + builder_.AddInstruction( + HloInstruction::CreateSend(zero_sized_param_, token, 0)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunZeroSizedElimination()); + EXPECT_FALSE(changed); +} + +TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateConstant) { + builder_.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); TF_ASSERT_OK_AND_ASSIGN(bool changed, RunZeroSizedElimination()); EXPECT_FALSE(changed); } diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index 7ee366b27a82bdbcb7a63a57ea80194db8ca7df4..caad31d6ce7ce35fa362ec364b0d7f1d95973715 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -67,6 +67,14 @@ void ShapeLayout::ResetLayout(const Layout& layout) { TF_CHECK_OK(ShapeUtil::ValidateShape(shape_)); } +void ShapeLayout::ResetLayout(const Layout& layout, + ShapeIndexView shape_index) { + CHECK(ShapeUtil::IsTuple(shape_)); + *ShapeUtil::GetMutableSubshape(&shape_, shape_index)->mutable_layout() = + layout; + TF_CHECK_OK(ShapeUtil::ValidateShape(shape_)); +} + bool ShapeLayout::operator==(const ShapeLayout& other) const { return ShapeUtil::Equal(shape_, other.shape_); } diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h index 36806da599cc9b27286e67c128bb7f496f29c105..214cf98854938414c23c5031f4114016140ae9a7 100644 --- a/tensorflow/compiler/xla/shape_layout.h +++ b/tensorflow/compiler/xla/shape_layout.h @@ -72,6 +72,10 @@ class ShapeLayout { // tuple. void ResetLayout(const Layout& layout); + // Resets the layout on the shape at the provided ShapeIndex to the provided + // layout. Shape must be a tuple. + void ResetLayout(const Layout& layout, ShapeIndexView shape_index); + // Returns a string representation of this object. string ToString() const { return ShapeUtil::HumanStringWithLayout(shape_); } diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 4aacc87b78e2c271829cdf397cd69bfb490125b8..c74dd648addd70633edc2ec10a60879a00942716 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -44,10 +44,6 @@ struct ShapeTreeNode { // Data corresponding to this node. std::pair data; - // Children of this node, as indices into the container's nodes_ array. - std::vector children; - - // Tells whether this is a leaf node. bool is_leaf = true; explicit ShapeTreeNode(ShapeIndex index) @@ -56,6 +52,20 @@ struct ShapeTreeNode { : data(std::move(index), std::move(data)) {} }; +// Internal representation of an index table entry. +struct IndexTableEntry { + // Index of the node in the ShapeTreeNode vector. + uint32 index; + // Index of the first child in a IndexTableEntry vector. In the index + // table all children entries for a given node will be placed next to each + // other. This allows us to use a single field to index them. + uint32 children_start; +#ifndef NDEBUG + // Number of children, used for bounds checking. + uint32 children_count; +#endif +}; + } // namespace internal template @@ -84,6 +94,7 @@ template class ShapeTree { public: using Node = internal::ShapeTreeNode; + using Index = internal::IndexTableEntry; // Default constructor creates a tree with a nil shape (i.e. an empty tuple). ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {} @@ -267,11 +278,12 @@ class ShapeTree { private: // Initialize node->children based on 'shape'. All children are assigned the // the given 'init_value'. - void InitChildren(const Shape& shape, const T& init_value, Node* node); + void InitChildren(const Shape& shape, const T& init_value, Node* node, + Index* index); // Initialize node->children based on 'shape'. All children have // default-constructed data values. - void InitChildren(const Shape& shape, Node* node); + void InitChildren(const Shape& shape, Node* node, Index* index); // Returns the number of subshapes, including interior nodes, in shape. int64 CountSubshapes(const Shape& shape); @@ -291,6 +303,9 @@ class ShapeTree { // The nodes in this shape tree. std::vector nodes_; + // Index table for node lookups. + std::vector index_table_; + // If we own our Shape, this field contains it, and shape_ is a pointer into // here. Otherwise if we don't own our shape, this is nullptr. std::shared_ptr shape_storage_; @@ -373,36 +388,74 @@ int64 ShapeTree::CountSubshapes(const Shape& shape) { template void ShapeTree::InitChildren(const Shape& shape, const T& init_value, - Node* node) { + Node* node, Index* index) { if (ShapeUtil::IsTuple(shape)) { const int64 size = ShapeUtil::TupleElementCount(shape); - node->children.reserve(size); +#ifndef NDEBUG + index->children_count = size; +#endif node->is_leaf = false; ShapeIndex shape_index = node->data.first; shape_index.push_back(0); + + // At the end of the index_table, reserve a continuous space to hold the + // children of current node. In order to enforce the invariant that all + // children of a given node are placed together, we need to do the + // reservation before we recurse into any of its children. + int64 children_start_position = index_table_.size(); + index_table_.resize(index_table_.size() + size); + for (int i = 0; i < size; ++i) { shape_index[shape_index.size() - 1] = i; - node->children.push_back(nodes_.size()); + index_table_[children_start_position + i].index = nodes_.size(); + // The first child of the node in the index table is placed at the end of + // the table. + index_table_[children_start_position + i].children_start = + index_table_.size(); nodes_.emplace_back(shape_index, init_value); - InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back()); + InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back(), + &index_table_[children_start_position + i]); } + } else { +#ifndef NDEBUG + index->children_count = 0; +#endif } } template -void ShapeTree::InitChildren(const Shape& shape, Node* node) { +void ShapeTree::InitChildren(const Shape& shape, Node* node, Index* index) { if (ShapeUtil::IsTuple(shape)) { const int64 size = ShapeUtil::TupleElementCount(shape); - node->children.reserve(size); +#ifndef NDEBUG + index->children_count = size; +#endif node->is_leaf = false; ShapeIndex shape_index = node->data.first; shape_index.push_back(0); + + // At the end of the index_table, reserve a continuous space to hold the + // children of current node. In order to enforce the invariant that all + // children of a given node are placed together, we need to do the + // reservation before we recurse into any of its children. + int64 children_start_position = index_table_.size(); + index_table_.resize(index_table_.size() + size); + for (int i = 0; i < size; ++i) { shape_index[shape_index.size() - 1] = i; - node->children.push_back(nodes_.size()); + index_table_[children_start_position + i].index = nodes_.size(); + // The first child of the node in the index table is placed at the end of + // the table. + index_table_[children_start_position + i].children_start = + index_table_.size(); nodes_.emplace_back(shape_index); - InitChildren(shape.tuple_shapes(i), &nodes_.back()); + InitChildren(shape.tuple_shapes(i), &nodes_.back(), + &index_table_[children_start_position + i]); } + } else { +#ifndef NDEBUG + index->children_count = 0; +#endif } } @@ -413,24 +466,36 @@ ShapeTree::ShapeTree(Shape shape) // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. LayoutUtil::ClearLayout(shape_storage_.get()); - nodes_.reserve(CountSubshapes(*shape_)); + const int64 count = CountSubshapes(*shape_); + nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}); - InitChildren(*shape_, &nodes_[0]); + + index_table_.reserve(count); + index_table_.emplace_back(Index{0, 1}); + InitChildren(*shape_, &nodes_[0], &index_table_[0]); } template ShapeTree::ShapeTree(const Shape* shape) : shape_(shape) { - nodes_.reserve(CountSubshapes(*shape_)); + const int64 count = CountSubshapes(*shape_); + nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}); - InitChildren(*shape_, &nodes_[0]); + + index_table_.reserve(count); + index_table_.emplace_back(Index{0, 1}); + InitChildren(*shape_, &nodes_[0], &index_table_[0]); } template ShapeTree::ShapeTree(const std::shared_ptr& shape) : shape_storage_(shape), shape_(shape_storage_.get()) { - nodes_.reserve(CountSubshapes(*shape_)); + const int64 count = CountSubshapes(*shape_); + nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}); - InitChildren(*shape_, &nodes_[0]); + + index_table_.reserve(count); + index_table_.emplace_back(Index{0, 1}); + InitChildren(*shape_, &nodes_[0], &index_table_[0]); } template @@ -440,26 +505,38 @@ ShapeTree::ShapeTree(Shape shape, const T& init_value) // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. LayoutUtil::ClearLayout(shape_storage_.get()); - nodes_.reserve(CountSubshapes(*shape_)); + const int64 count = CountSubshapes(*shape_); + nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}, init_value); - InitChildren(*shape_, init_value, &nodes_[0]); + + index_table_.reserve(count); + index_table_.emplace_back(Index{0, 1}); + InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]); } template ShapeTree::ShapeTree(const Shape* shape, const T& init_value) : shape_(shape) { - nodes_.reserve(CountSubshapes(*shape_)); + const int64 count = CountSubshapes(*shape_); + nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}, init_value); - InitChildren(*shape_, init_value, &nodes_[0]); + + index_table_.reserve(count); + index_table_.emplace_back(Index{0, 1}); + InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]); } template ShapeTree::ShapeTree(const std::shared_ptr& shape, const T& init_value) : shape_storage_(shape), shape_(shape_storage_.get()) { - nodes_.reserve(CountSubshapes(*shape_)); + const int64 count = CountSubshapes(*shape_); + nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}, init_value); - InitChildren(*shape_, init_value, &nodes_[0]); + + index_table_.reserve(count); + index_table_.emplace_back(Index{0, 1}); + InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]); } template @@ -474,13 +551,16 @@ T* ShapeTree::mutable_element(ShapeIndexView index) { template internal::ShapeTreeNode* ShapeTree::Lookup(ShapeIndexView index) { - Node* node = &nodes_[0]; + Index* iter = &index_table_[0]; for (const int64 i : index) { CHECK_GE(i, 0); - CHECK_LT(i, node->children.size()); - node = &nodes_[node->children[i]]; +#ifndef NDEBUG + CHECK_LT(i, iter->children_count); +#endif + iter = &index_table_[iter->children_start + i]; } - return node; + + return &nodes_[iter->index]; } template diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index 51de82e95746281ed6e587b545dc933b48ce1ad4..c4c958be4a18f23b8e34f9e619e447c6bf4334b5 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -172,7 +172,7 @@ TEST_F(ShapeTreeTest, TupleShape) { // Write zero to all data elements. shape_tree.ForEachMutableElement( - [&sum](const ShapeIndex& /*index*/, int* data) { *data = 0; }); + [](const ShapeIndex& /*index*/, int* data) { *data = 0; }); EXPECT_EQ(0, shape_tree.element({})); EXPECT_EQ(0, shape_tree.element({0})); EXPECT_EQ(0, shape_tree.element({1})); @@ -227,14 +227,16 @@ TEST_F(ShapeTreeTest, NestedTupleShape) { TEST_F(ShapeTreeTest, InvalidIndexingTuple) { ShapeTree shape_tree{tuple_shape_}; - +#ifndef NDEBUG EXPECT_DEATH(shape_tree.element({4}), ""); +#endif } TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) { ShapeTree shape_tree{nested_tuple_shape_}; - +#ifndef NDEBUG EXPECT_DEATH(shape_tree.element({0, 0}), ""); +#endif } TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) { @@ -602,12 +604,15 @@ void BM_Iterate(int iters, int depth, int fan_out) { } } -BENCHMARK(BM_Construct)->ArgPair(2, 8); -BENCHMARK(BM_ConstructUnowned)->ArgPair(2, 8); -BENCHMARK(BM_Copy)->ArgPair(2, 8); -BENCHMARK(BM_Move)->ArgPair(2, 8); -BENCHMARK(BM_ForEach)->ArgPair(2, 8); -BENCHMARK(BM_Iterate)->ArgPair(2, 8); +#define BENCHMARK_WITH_ARGS(name) \ + BENCHMARK(name)->ArgPair(2, 8)->ArgPair(1, 1000) + +BENCHMARK_WITH_ARGS(BM_Construct); +BENCHMARK_WITH_ARGS(BM_ConstructUnowned); +BENCHMARK_WITH_ARGS(BM_Copy); +BENCHMARK_WITH_ARGS(BM_Move); +BENCHMARK_WITH_ARGS(BM_ForEach); +BENCHMARK_WITH_ARGS(BM_Iterate); } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 2166c34358fa62815c3fb32f28392f9036e25158..34869cc5078699603c006387161fddd4fee4a9f8 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -46,28 +46,14 @@ namespace xla { using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; -string ShapeIndex::ToString() const { - return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}"); -} +string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); } string ShapeIndexView::ToString() const { - return StrCat("{", - tensorflow::str_util::Join( - tensorflow::gtl::make_range(begin_, end_), ","), - "}"); + return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}"); } bool ShapeIndexView::operator==(const ShapeIndexView& other) const { - if (size() != other.size()) { - return false; - } - for (auto it = begin(), other_it = other.begin(); it != end(); - ++it, ++other_it) { - if (*it != *other_it) { - return false; - } - } - return true; + return indices_ == other.indices_; } bool ShapeIndexView::operator!=(const ShapeIndexView& other) const { @@ -610,8 +596,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { }; auto comma_list_to_int64s = - [&s, - string_to_int64](const string& input) -> StatusOr> { + [string_to_int64](const string& input) -> StatusOr> { std::vector results; for (const string& piece : tensorflow::str_util::Split(input, ',')) { TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece)); @@ -696,7 +681,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { CompatibleIgnoringElementType); } else { // Opaque, token, etc types are vacuously compatible. - return true; + return lhs.element_type() == rhs.element_type(); } } @@ -711,7 +696,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { CompatibleIgnoringFpPrecision); } else { // Opaque, token, etc types are vacuously compatible. - return true; + return lhs.element_type() == rhs.element_type(); } } @@ -806,7 +791,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { if (LayoutUtil::IsSparseArray(shape)) { allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout()); } else { - CHECK(LayoutUtil::IsDenseArray(shape)); + CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString(); tensorflow::gtl::ArraySlice padded_dimensions = LayoutUtil::PaddedDimensions(shape); if (!padded_dimensions.empty()) { @@ -891,44 +876,62 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { /* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) { VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape); - auto invalid_argument = - InvalidArgument("Shape %s size may overflow int64.", - ShapeUtil::HumanString(shape).c_str()); + if (!IsArray(shape)) { return Status::OK(); } - int64 shape_size; - if (LayoutUtil::IsSparseArray(shape)) { - shape_size = LayoutUtil::MaxSparseElements(shape.layout()); - if (shape_size < 0) { - return invalid_argument; - } - shape_size = MultiplyWithoutOverflow(shape_size, ShapeUtil::Rank(shape)); - if (shape_size < 0) { - return invalid_argument; + + int64 shape_size = [&shape]() { + if (LayoutUtil::IsSparseArray(shape)) { + int64 max_sparse_elements = LayoutUtil::MaxSparseElements(shape.layout()); + if (max_sparse_elements < 0) { + return max_sparse_elements; + } + int64 sparse_elements_size = MultiplyWithoutOverflow( + max_sparse_elements, ByteSizeOfPrimitiveType(shape.element_type())); + if (sparse_elements_size < 0) { + return sparse_elements_size; + } + int64 sparse_indices_size = + MultiplyWithoutOverflow(max_sparse_elements, ShapeUtil::Rank(shape)); + if (sparse_indices_size < 0) { + return sparse_indices_size; + } + sparse_indices_size = + MultiplyWithoutOverflow(sparse_indices_size, sizeof(int64)); + if (sparse_indices_size < 0) { + return sparse_indices_size; + } + // At this point, both sparse_indices_size and sparse_elements_size are + // non-negative, so we can easily check if adding them wraps. + if (static_cast(sparse_elements_size) + + static_cast(sparse_indices_size) > + INT64_MAX) { + return static_cast(-1); + } } - shape_size = MultiplyWithoutOverflow(shape_size, sizeof(int64)); - if (shape_size < 0) { - return invalid_argument; + + // This is intentionally unconditional: even if the shape is sparse, we want + // to verify the densified version has a reasonable size. + int64 dense_shape_size = 1; + if (shape.dimensions().empty()) { + return dense_shape_size; } - } - // This is intentionally unconditional: even if the shape is sparse, we want - // to verify the densified version has a reasonable size. - if (shape.dimensions().empty()) { - return Status::OK(); - } - shape_size = 1; - for (int64 dim : shape.dimensions()) { - shape_size = MultiplyWithoutOverflow(shape_size, dim); - if (shape_size < 0) { - return invalid_argument; + for (int64 dim : shape.dimensions()) { + dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, dim); + if (dense_shape_size < 0) { + return dense_shape_size; + } } - } - shape_size = MultiplyWithoutOverflow( - shape_size, ByteSizeOfPrimitiveType(shape.element_type())); + dense_shape_size = MultiplyWithoutOverflow( + dense_shape_size, ByteSizeOfPrimitiveType(shape.element_type())); + return dense_shape_size; + }(); + if (shape_size < 0) { - return invalid_argument; + return InvalidArgument("Shape %s size may overflow int64.", + ShapeUtil::HumanString(shape).c_str()); } VLOG(3) << "Shape size is valid: " << shape_size; @@ -1119,12 +1122,41 @@ Status ForEachMutableSubshapeHelper( for (auto dim : Permute(permutation, shape.dimensions())) { new_shape.add_dimensions(dim); } + + // If `shape` has a layout, by contract we choose a new layout such that the + // transpose defined by this permutation is a bitcast. + // + // Some formalism helps to understand the correct way to do this. We're going + // to do algebra in the group of permutations of the dimensions of `shape`. + // + // Since the order of `shape`'s dimensions is not permuted relative to itself, + // `shape`'s list of dimensions is isomorphic to the identity I. + // + // Let `shape`'s layout be L. A layout is a permutation which maps a + // minor-to-major physical layout to the order of a shape's logical dims. + // Therefore inverse of a layout maps from logical to physical dims, and so + // the physical layout of I is simply L'.I = L', where L' is the inverse of L. + // + // Let the argument `permutation` be P. This is a permutation over `shape`'s + // dimensions, so our return value will be a shape with dims P.I = P. Our + // goal is to construct a layout permutation L* that we can apply to P such + // that that the physical dimension ordering of the returned shape is the same + // as that of the original shape, namely L'. + // + // Our returned shape has dims P and layout L*, so its in-memory layout is + // L*'.P. Setting this equal to L' and solving for L*, we get: + // + // L*'.P = L' => + // L*' = L'P' => + // L* = P.L + // if (shape.has_layout()) { CHECK(LayoutUtil::IsDenseArray(shape)); Layout* new_layout = new_shape.mutable_layout(); new_layout->set_format(DENSE); new_layout->clear_minor_to_major(); - for (auto index : Permute(permutation, shape.layout().minor_to_major())) { + for (auto index : ComposePermutations( + permutation, AsInt64Slice(shape.layout().minor_to_major()))) { new_layout->add_minor_to_major(index); } if (shape.layout().padded_dimensions_size() > 0) { @@ -1134,6 +1166,13 @@ Status ForEachMutableSubshapeHelper( new_layout->add_padded_dimensions(dim); } } + // The permutation accepted by TransposeIsBitcast is the inverse of the + // permutation here. + CHECK(TransposeIsBitcast(shape, new_shape, InversePermutation(permutation))) + << "shape=" << HumanStringWithLayout(shape) + << ", new_shape=" << HumanStringWithLayout(new_shape) + << ", permutation={" << tensorflow::str_util::Join(permutation, ",") + << "}"; } return new_shape; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 5ae04451d32bd733dce55c4a56f5ebc1882d9fbd..d6f17fc965d24bbbbd083b8dd0ec11a59e49ed4e 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" @@ -73,10 +74,12 @@ class ShapeIndex { // push_front is O(n^2), but shapes don't usually have a ton of dimensions. void push_front(int64 value) { indices_.insert(indices_.begin(), value); } - std::vector::const_iterator begin() const { return indices_.begin(); } - std::vector::const_iterator end() const { return indices_.end(); } - std::vector::iterator begin() { return indices_.begin(); } - std::vector::iterator end() { return indices_.end(); } + using container_type = tensorflow::gtl::InlinedVector; + + container_type::const_iterator begin() const { return indices_.begin(); } + container_type::const_iterator end() const { return indices_.end(); } + container_type::iterator begin() { return indices_.begin(); } + container_type::iterator end() { return indices_.end(); } const int64* data() const { return indices_.data(); } @@ -97,7 +100,7 @@ class ShapeIndex { string ToString() const; private: - std::vector indices_; + container_type indices_; }; // A view into a ShapeIndex as above, with the cheap/easy ability to consume the @@ -110,31 +113,33 @@ class ShapeIndex { class ShapeIndexView { public: ShapeIndexView(const ShapeIndex& shape_index, int64 offset = 0) - : ShapeIndexView(shape_index.data() + offset, - shape_index.data() + shape_index.size()) { + : indices_(shape_index.data() + offset, shape_index.size() - offset) { CHECK_LE(offset, shape_index.size()); } - ShapeIndexView(std::initializer_list indices) - : ShapeIndexView(indices.begin(), indices.end()) {} + ShapeIndexView(std::initializer_list indices) : indices_(indices) {} ShapeIndexView(const ShapeIndexView& other) = default; using iterator = const int64*; - iterator begin() const { return begin_; } - iterator end() const { return end_; } - int64 size() const { return std::distance(begin_, end_); } - bool empty() const { return begin_ == end_; } + iterator begin() const { return indices_.begin(); } + iterator end() const { return indices_.end(); } + int64 size() const { return indices_.size(); } + bool empty() const { return indices_.empty(); } int64 front() const { CHECK(!empty()); - return *begin_; + return indices_.front(); } ShapeIndexView ConsumeFront() const { - CHECK(!empty()); - auto new_begin = begin_; - ++new_begin; - return ShapeIndexView(new_begin, end_); + ShapeIndexView result = *this; + result.indices_.pop_front(); + return result; + } + ShapeIndexView ConsumeBack() const { + ShapeIndexView result = *this; + result.indices_.pop_back(); + return result; } - ShapeIndex ToShapeIndex() const { return ShapeIndex(begin_, end_); } + ShapeIndex ToShapeIndex() const { return ShapeIndex(begin(), end()); } bool operator==(const ShapeIndexView& other) const; bool operator!=(const ShapeIndexView& other) const; @@ -142,10 +147,7 @@ class ShapeIndexView { string ToString() const; private: - ShapeIndexView(iterator begin, iterator end) : begin_(begin), end_(end) {} - - iterator begin_; - iterator end_; + tensorflow::gtl::ArraySlice indices_; }; std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index); @@ -530,7 +532,13 @@ class ShapeUtil { static bool HasDegenerateDimensions(const Shape& shape); // Permutes the dimensions by the given permutation, so - // return_value.dimensions[permutation[i]] = argument.dimensions[i] + // return_value.dimensions[permutation[i]] = argument.dimensions[i]. + // + // Postcondition: For any valid permutation, + // + // !HasLayout(shape) || + // TransposeIsBitcast(shape, PermuteDimensions(permutation, shape), + // InversePermutation(permutation)). static Shape PermuteDimensions(tensorflow::gtl::ArraySlice permutation, const Shape& shape); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index b6f30af381dd8d24ff28fdf7f729d6cb3df46ec9..e5dd62ae9a3dd9b961a7ae03a99c19220dbd43e7 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" +#include #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" @@ -22,12 +23,23 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { using ::testing::ElementsAre; +TEST(ShapeUtilTest, ShapeIndexViewTest) { + ShapeIndex index = {1, 2, 3, 4}; + ShapeIndexView index_view(index, 1); + EXPECT_EQ(3, index_view.size()); + EXPECT_EQ(ShapeIndexView({2, 3, 4}), index_view); + EXPECT_EQ(ShapeIndexView({3, 4}), index_view.ConsumeFront()); + EXPECT_EQ(ShapeIndexView({2, 3}), index_view.ConsumeBack()); +} + TEST(ShapeUtilTest, GetDimensionHelperCanNegativeIndex) { Shape matrix = ShapeUtil::MakeShape(F32, {2, 3}); EXPECT_EQ(3, ShapeUtil::GetDimension(matrix, -1)); @@ -322,6 +334,17 @@ TEST(ShapeUtilTest, IncompatibleScalarVsTuple) { EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape2, shape1)); } +TEST(ShapeUtilTest, OpaqueVsArray) { + Shape shape1 = ShapeUtil::MakeShape(F32, {5, 7}); + Shape shape2 = ShapeUtil::MakeOpaqueShape(); + EXPECT_FALSE(ShapeUtil::Compatible(shape1, shape2)); + EXPECT_FALSE(ShapeUtil::Compatible(shape2, shape1)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape2, shape1)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(shape1, shape2)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(shape2, shape1)); +} + TEST(ShapeUtilTest, CompareShapesWithPaddedDimensionsMismatch) { Shape shape1 = ShapeUtil::MakeShape(F32, {20, 30}); shape1.mutable_layout()->add_padded_dimensions(10); @@ -821,6 +844,28 @@ TEST(ShapeUtilTest, HasDegenerateDimensions) { ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 0, 5}))); } +TEST(ShapeUtilTest, PermuteDimensionsLayout) { + std::vector layout(3); + std::iota(layout.begin(), layout.end(), 0); + do { + Shape s = ShapeUtil::MakeShapeWithLayout(F32, {10, 100, 1000}, layout); + SCOPED_TRACE(tensorflow::strings::StrCat("s=", ShapeUtil::HumanString(s))); + + std::vector permutation(3); + std::iota(permutation.begin(), permutation.end(), 0); + do { + SCOPED_TRACE(tensorflow::strings::StrCat( + "permutation=", tensorflow::str_util::Join(permutation, ","))); + + // TransposeIsBitcast takes the inverse of the permutation that + // PermuteDimensions takes. + EXPECT_TRUE(ShapeUtil::TransposeIsBitcast( + s, ShapeUtil::PermuteDimensions(permutation, s), + InversePermutation(permutation))); + } while (std::next_permutation(permutation.begin(), permutation.end())); + } while (std::next_permutation(layout.begin(), layout.end())); +} + TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast( ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}), diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 5a45e2e610218d8401ae9600db7041311b607bdd..42d52aee780e2aade0f2ed3597e653567b8da49b 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -65,6 +65,7 @@ cc_library( srcs = ["test_utils.cc"], hdrs = ["test_utils.h"], deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", @@ -88,6 +89,7 @@ cc_library( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:error_spec", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_comparison", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:test", @@ -152,8 +154,8 @@ tf_cc_binary( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", @@ -179,6 +181,7 @@ cc_library( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -189,8 +192,8 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -209,6 +212,7 @@ cc_library( deps = [ ":codegen_test_base", ":filecheck", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:test", @@ -258,7 +262,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", @@ -286,8 +290,8 @@ xla_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -302,7 +306,7 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -310,8 +314,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -330,8 +334,8 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -345,16 +349,16 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -372,9 +376,10 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -391,8 +396,8 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -406,7 +411,7 @@ xla_test( tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -415,9 +420,9 @@ xla_test( "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -435,14 +440,14 @@ xla_test( tags = ["optonly"], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -460,9 +465,9 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -479,8 +484,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -497,8 +502,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -515,9 +520,9 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -531,6 +536,7 @@ xla_test( srcs = ["scalar_computations_test.cc"], shard_count = 32, deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -538,8 +544,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -557,8 +563,8 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -573,7 +579,7 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -581,8 +587,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -599,7 +605,7 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -607,8 +613,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -633,7 +639,7 @@ xla_test( deps = [ ":client_library_test_base", ":literal_test_util", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -645,7 +651,7 @@ xla_test( tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -653,7 +659,7 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -676,8 +682,8 @@ xla_test( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -697,7 +703,7 @@ xla_test( "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -720,8 +726,8 @@ xla_test( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -744,8 +750,8 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -764,11 +770,12 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -780,7 +787,7 @@ xla_test( CONVOLUTION_TEST_DEPS = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -789,7 +796,7 @@ CONVOLUTION_TEST_DEPS = [ "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -827,13 +834,13 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -856,8 +863,8 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -874,7 +881,7 @@ xla_test( ":test_utils", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -885,9 +892,10 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -906,7 +914,7 @@ xla_test( ":test_utils", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -917,9 +925,9 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -939,12 +947,12 @@ xla_test( ], deps = [ ":test_utils", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -965,7 +973,7 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -984,8 +992,8 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1006,7 +1014,7 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", @@ -1030,14 +1038,15 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1057,9 +1066,9 @@ xla_test( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1078,6 +1087,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", @@ -1087,9 +1097,9 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1114,9 +1124,9 @@ xla_test_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1148,16 +1158,16 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1175,10 +1185,10 @@ xla_test( deps = [ ":client_library_test_base", "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1227,12 +1237,13 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1245,11 +1256,12 @@ xla_test( name = "custom_call_test", srcs = ["custom_call_test.cc"], deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/compiler/xla/tests:client_library_test_base", @@ -1272,8 +1284,8 @@ xla_test( "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1290,12 +1302,13 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1315,8 +1328,8 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1333,8 +1346,8 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1350,8 +1363,8 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1367,7 +1380,7 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -1375,8 +1388,8 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1390,14 +1403,14 @@ xla_test( name = "prng_test", srcs = ["prng_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -1415,6 +1428,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", @@ -1425,8 +1439,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1445,7 +1459,7 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1468,9 +1482,9 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1494,8 +1508,8 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1514,8 +1528,8 @@ xla_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1529,15 +1543,15 @@ xla_test( name = "cross_replica_sum_test", srcs = ["cross_replica_sum_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1559,7 +1573,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1573,15 +1587,15 @@ xla_test( name = "compilation_cache_test", srcs = ["compilation_cache_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1599,8 +1613,8 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1613,7 +1627,7 @@ xla_test( name = "compute_constant_test", srcs = ["compute_constant_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -1622,8 +1636,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1643,8 +1657,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1660,8 +1674,8 @@ xla_test( deps = [ ":client_library_test_base", "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -1674,8 +1688,8 @@ xla_test( deps = [ ":client_library_test_base", "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -1688,15 +1702,15 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1713,7 +1727,7 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -1730,6 +1744,7 @@ tf_cc_test( srcs = ["llvm_compiler_test.cc"], tags = ["requires-gpu-sm35"], deps = [ + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:cpu_plugin", @@ -1750,7 +1765,7 @@ xla_test( name = "round_trip_packed_literal_test", srcs = ["round_trip_packed_literal_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:packed_literal_reader", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -1773,15 +1788,16 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", @@ -1800,14 +1816,14 @@ xla_test( srcs = ["multioutput_fusion_test.cc"], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:platform_util", @@ -1840,11 +1856,11 @@ xla_test( name = "local_client_allocation_test", srcs = ["local_client_allocation_test.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1863,7 +1879,7 @@ xla_test( shard_count = 30, tags = ["optonly"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -1871,8 +1887,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", @@ -1888,6 +1904,16 @@ xla_test( ], ) +xla_test( + name = "outfeed_in_nested_computation_test", + srcs = ["outfeed_in_nested_computation_test.cc"], + deps = [ + "//tensorflow/compiler/xla/tests:local_client_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "hlo_metadata_test", srcs = [ @@ -1897,7 +1923,7 @@ tf_cc_test( ":local_client_test_base", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/core:test_main", @@ -1909,7 +1935,7 @@ xla_test( srcs = ["round_trip_transfer_test.cc"], deps = [ "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", @@ -1930,7 +1956,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1939,8 +1965,8 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1953,7 +1979,7 @@ xla_test( name = "deep_graph_test", srcs = ["deep_graph_test.cc"], deps = [ - "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -1978,7 +2004,7 @@ xla_test( ":literal_test_util", ":local_client_test_base", ":xla_internal_test_main", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -1986,6 +2012,7 @@ xla_test( "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:generic_transfer_manager", "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", @@ -2038,9 +2065,30 @@ xla_test( ":local_client_test_base", ":test_utils", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], ) + +xla_test( + name = "iota_test", + srcs = ["iota_test.cc"], + blacklisted_backends = [ + "cpu", + "gpu", + ], + tags = [ + "enable_for_xla_interpreter", + ], + deps = [ + ":client_library_test_base", + ":literal_test_util", + ":xla_internal_test_main", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 3bdf98544affca11fd825e28d20f4903188fe920..74f2e36f826cd82ce4015df857f3de67950beaeb 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -24,9 +24,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -225,7 +225,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { 0x8000000000000000LL, 0x8000000000000000LL, 1}; - std::unique_ptr lhs_literal = Literal::CreateR1({lhs}); + std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); std::unique_ptr lhs_data = client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); @@ -239,7 +239,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { 0, 1, 0x8000000000000000LL}; - std::unique_ptr rhs_literal = Literal::CreateR1({rhs}); + std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); std::unique_ptr rhs_data = client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); @@ -265,7 +265,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { 1, 0, -1}; - std::unique_ptr lhs_literal = Literal::CreateR1({lhs}); + std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); std::unique_ptr lhs_data = client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); @@ -278,7 +278,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { 0x7FFFFFFFFFFFFFFLL, 0x7FFFFFFFFFFFFFFFLL, 0x7FFFFFFFFFFFFFFFLL}; - std::unique_ptr rhs_literal = Literal::CreateR1({rhs}); + std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); std::unique_ptr rhs_data = client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); @@ -303,13 +303,13 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { b_values.push_back(2 * i / static_cast(count + 2)); } - std::unique_ptr a_literal = Literal::CreateR1({a_values}); + std::unique_ptr a_literal = LiteralUtil::CreateR1({a_values}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); auto a_constant = ConstantR1(&builder, a_values); auto a_param = Parameter(&builder, 0, a_literal->shape(), "a_param"); - std::unique_ptr b_literal = Literal::CreateR1({b_values}); + std::unique_ptr b_literal = LiteralUtil::CreateR1({b_values}); std::unique_ptr b_data = client_->TransferToServer(*b_literal).ConsumeValueOrDie(); auto b_constant = Parameter(&builder, 1, a_literal->shape(), "b_param"); @@ -1426,7 +1426,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { std::vector values = {1.0f, 2.0f, 3.2f, -4.0f}; std::vector exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr param_literal = Literal::CreateR1(values); + std::unique_ptr param_literal = LiteralUtil::CreateR1(values); std::unique_ptr param_data = client_->TransferToServer(*param_literal).ConsumeValueOrDie(); @@ -1454,10 +1454,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); @@ -1479,10 +1479,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); @@ -1504,10 +1504,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); @@ -1529,10 +1529,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); @@ -1555,15 +1555,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = Literal::CreateR1(values2); + std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = client_->TransferToServer(*literal2).ConsumeValueOrDie(); auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); @@ -1587,15 +1587,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = Literal::CreateR1(values2); + std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = client_->TransferToServer(*literal2).ConsumeValueOrDie(); @@ -1620,15 +1620,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = Literal::CreateR1(values2); + std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = client_->TransferToServer(*literal2).ConsumeValueOrDie(); @@ -1654,19 +1654,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) { std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; std::vector values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f}; - std::unique_ptr literal0 = Literal::CreateR1(values0); + std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1(values1); + std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(*literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = Literal::CreateR1(values2); + std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = client_->TransferToServer(*literal2).ConsumeValueOrDie(); - std::unique_ptr literal3 = Literal::CreateR1(values3); + std::unique_ptr literal3 = LiteralUtil::CreateR1(values3); std::unique_ptr data3 = client_->TransferToServer(*literal3).ConsumeValueOrDie(); @@ -2101,12 +2101,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - Literal::CreateR1({7.2f, 2.3f, 3.4f, 5.6f}); + LiteralUtil::CreateR1({7.2f, 2.3f, 3.4f, 5.6f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -2123,12 +2123,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR3FromArray3D(Array3D(0, 7, 0)); + LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - Literal::CreateR3FromArray3D(Array3D(0, 7, 0)); + LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -2145,7 +2145,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); + LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -2201,7 +2201,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { // the input tensor is large enough to exercise the vectorized tanh // implementation on XLA CPU. XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateR1( + auto input_literal = LiteralUtil::CreateR1( {1.02, -0.32, 0.85, 0.90, 1.23, -0.91, -0.49, 0.80, -0.67, 0.16, -0.07, 0.39, -0.41, 0.04, 1.36, 1.25, 0.41, 0.65, -1.08, 0.32, -1.45, -0.77, -1.09, 0.91, -1.03, -0.30, -1.11, -1.17, 1.50, -0.85, @@ -2243,7 +2243,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { // Just to help make sense of the scales here -- exp(89) saturates float32 and // exp(-10) is smaller than our error spec. - std::unique_ptr input_literal = Literal::CreateR1( + std::unique_ptr input_literal = LiteralUtil::CreateR1( {1.02, -0.32, 0.85, 0.9, 1.23, -0.91, -0.49, 0.8, -1.31, -1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05, -195.6, -194.5, -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5, -17.4, @@ -2277,7 +2277,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { // implementation on XLA CPU. XlaBuilder builder(TestName()); - std::unique_ptr input_literal = Literal::CreateR1( + std::unique_ptr input_literal = LiteralUtil::CreateR1( {-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198, -167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9, 198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04, @@ -2469,9 +2469,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { auto cmp_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{0}); Tuple(&builder, {cmp_dim_0, cmp_dim_1}); - auto expected = Literal::MakeTuple( - {Literal::CreateR2({{true, true}, {true, false}}).get(), - Literal::CreateR2({{true, false}, {false, false}}).get()}); + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{true, true}, {true, false}}).get(), + LiteralUtil::CreateR2({{true, false}, {false, false}}).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -2825,8 +2825,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { std::iota(r1.begin(), r1.end(), 1.0); XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR4FromArray4DWithLayout( - r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); + std::unique_ptr a_literal = + LiteralUtil::CreateR4FromArray4DWithLayout( + r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); auto a = ConstantLiteral(&builder, *a_literal); auto b = ConstantR1(&builder, r1); Add(a, b, {1}); @@ -2887,8 +2888,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) { // broadcast. XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { XlaBuilder builder(TestName()); - auto x_literal = Literal::CreateR1({1, 2, 3}); - auto y_literal = Literal::CreateR1({4, 5}); + auto x_literal = LiteralUtil::CreateR1({1, 2, 3}); + auto y_literal = LiteralUtil::CreateR1({4, 5}); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc index 8d15b7841bc7298cd6865d8689cc496c0459e4b9..caeb0bf49a0dde9eeac02037b2ea04fd024d100c 100644 --- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc +++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc index 8c227df7f04e79ccc332062d0889d282c0f5e40f..af0b8522394a0c591e6c42ad12db8853ef66243c 100644 --- a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc +++ b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index d9d7ba1362a6975465971f4bc29da4d541e2f821..24b17b71007a1872462bed1f6b86ae1a5bb9922c 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -20,10 +20,11 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -62,7 +63,7 @@ class BatchNormalizationTest {5.0f, 4.4f}, // p2 }); input_array_.FillWithPZ(pz); - input_literal_ = std::move(*Literal::CreateR4FromArray4D(input_array_)); + input_literal_ = std::move(*LiteralUtil::CreateR4FromArray4D(input_array_)); CHECK_EQ(kSamples, input_array_.planes()); CHECK_EQ(kZ, input_array_.depth()); CHECK_EQ(kY, input_array_.height()); @@ -118,7 +119,7 @@ XLA_TEST_P(BatchNormalizationTest, SubtractInZ) { XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) { XlaBuilder builder("square_tesseract_elementwise"); auto x = ConstantLiteral(&builder, input_literal_); - SquareF32(x); + Square(x); using tensorflow::MathUtil; @@ -150,7 +151,7 @@ XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) { auto activation_deviations = Sub(input_activations, set_means, /*broadcast_dimensions=*/{1}); XlaComputation add = CreateScalarAddComputation(F32, &builder); - auto dev_squares = SquareF32(activation_deviations); + auto dev_squares = Square(activation_deviations); Reduce(dev_squares, ConstantR0(&builder, 0.0f), add, {0, 2, 3}); std::vector expected = {18, 0.06}; @@ -160,7 +161,7 @@ XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) { XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) { XlaBuilder builder("variance_to_stddev"); auto variance = ConstantR1(&builder, {6.f, .02f}); - SqrtF32(variance); + Sqrt(variance); std::vector expected = {2.44948974f, 0.14142136f}; ComputeAndCompareR1(&builder, expected, {}, error_spec_); @@ -195,20 +196,20 @@ XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) { auto epsilon2 = ConstantR1(&builder, {kEpsilon, kEpsilon}); auto activation_deviations = Sub(input_activations, set_means, /*broadcast_dimensions=*/{1}); - auto dev_squares = SquareF32(activation_deviations); + auto dev_squares = Square(activation_deviations); auto sum_of_squares = CheckShape(&builder, Reduce(dev_squares, ConstantR0(&builder, 0.0f), add, /*dimensions_to_reduce=*/{0, 2, 3}), TwoElementVectorF32); auto variance = Div(sum_of_squares, count); - auto standard_deviation = SqrtF32(variance); + auto standard_deviation = Sqrt(variance); auto standard_deviation_above_epsilon = CheckShape(&builder, Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2})); auto gt_eps = Select(standard_deviation_above_epsilon, standard_deviation, epsilon2); - auto normalization_factors = ReciprocalF32(gt_eps); + auto normalization_factors = Reciprocal(gt_eps); auto normalized_input_activations = Mul(activation_deviations, normalization_factors, /*broadcast_dimensions=*/{1}); @@ -241,12 +242,12 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = Literal::MakeTuple( - {Literal::CreateR4({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, - {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}) + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR4({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, + {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}) .get(), - Literal::CreateR1({4, 5}).get(), - Literal::CreateR1({5, 5}).get()}); + LiteralUtil::CreateR1({4, 5}).get(), + LiteralUtil::CreateR1({5, 5}).get()}); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); } @@ -266,12 +267,12 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = Literal::MakeTuple( - {Literal::CreateR4({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, - {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}) + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR4({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, + {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}) .get(), - Literal::CreateR1({4, 5}).get(), - Literal::CreateR1({5, 5}).get()}); + LiteralUtil::CreateR1({4, 5}).get(), + LiteralUtil::CreateR1({5, 5}).get()}); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); } @@ -297,11 +298,11 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { BatchNormTraining(h0, h1, h2, /*epsilon=*/1, kFeatureIndex); - auto expected = Literal::MakeTuple( - {Literal::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)) + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)) .get(), - Literal::CreateR1(std::vector(260, 1.0f)).get(), - Literal::CreateR1(std::vector(260, 0.0f)).get()}); + LiteralUtil::CreateR1(std::vector(260, 1.0f)).get(), + LiteralUtil::CreateR1(std::vector(260, 0.0f)).get()}); ComputeAndCompareTuple(&builder, *expected, {operand.get(), scale.get(), offset.get()}, @@ -330,11 +331,12 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) { BatchNormTraining(h0, h1, h2, /*epsilon=*/-100, kFeatureIndex); - auto expected = Literal::MakeTuple( - {Literal::CreateR3FromArray3D({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR3FromArray3D( + {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) .get(), - Literal::CreateR1(std::vector(1, 15.0f)).get(), - Literal::CreateR1(std::vector(1, 125.0f)).get()}); + LiteralUtil::CreateR1(std::vector(1, 15.0f)).get(), + LiteralUtil::CreateR1(std::vector(1, 125.0f)).get()}); ComputeAndCompareTuple(&builder, *expected, {operand.get(), scale.get(), offset.get()}, @@ -361,12 +363,12 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) { BatchNormGrad(operand, scale, mean, var, grad_output, /*epsilon=*/0.0, kFeatureIndex); - auto expected = Literal::MakeTuple( - {Literal::CreateR4({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}}, - {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}) + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR4({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}}, + {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}) .get(), - Literal::CreateR1({0, 0}).get(), - Literal::CreateR1({16, 20}).get()}); + LiteralUtil::CreateR1({0, 0}).get(), + LiteralUtil::CreateR1({16, 20}).get()}); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); } @@ -512,11 +514,12 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D, scale4D, offset4D, epsilon); - auto expected_normalized = Literal::CreateR4FromArray4D(normalized); + auto expected_normalized = + LiteralUtil::CreateR4FromArray4D(normalized); - auto offset_literal = Literal::CreateR1(offset); - auto scale_literal = Literal::CreateR1(scale); - auto input_literal = Literal::CreateR4FromArray4D(input_array); + auto offset_literal = LiteralUtil::CreateR1(offset); + auto scale_literal = LiteralUtil::CreateR1(scale); + auto input_literal = LiteralUtil::CreateR4FromArray4D(input_array); auto input_activations = Parameter(&builder, 0, input_literal->shape(), "input"); @@ -525,9 +528,9 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { auto offset_activations = Parameter(&builder, 2, offset_literal->shape(), "scale"); - auto expected = Literal::MakeTuple({expected_normalized.get(), - Literal::CreateR1(mean).get(), - Literal::CreateR1(var).get()}); + auto expected = LiteralUtil::MakeTuple( + {expected_normalized.get(), LiteralUtil::CreateR1(mean).get(), + LiteralUtil::CreateR1(var).get()}); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -612,11 +615,11 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) { auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D, scale4D, offset4D, epsilon); - auto offset_literal = Literal::CreateR1(offset); - auto scale_literal = Literal::CreateR1(scale); - auto mean_literal = Literal::CreateR1(mean); - auto var_literal = Literal::CreateR1(var); - auto input_literal = Literal::CreateR4FromArray4D(input_array); + auto offset_literal = LiteralUtil::CreateR1(offset); + auto scale_literal = LiteralUtil::CreateR1(scale); + auto mean_literal = LiteralUtil::CreateR1(mean); + auto var_literal = LiteralUtil::CreateR1(var); + auto input_literal = LiteralUtil::CreateR4FromArray4D(input_array); auto input_activations = Parameter(&builder, 0, input_literal->shape(), "input"); @@ -730,7 +733,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { var4D, [epsilon](float a) { return a + epsilon; }); auto rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D( - var_add_epsilon, [epsilon](float a) { return 1 / std::sqrt(a); }); + var_add_epsilon, [](float a) { return 1 / std::sqrt(a); }); auto grad_output_times_var = *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon, @@ -799,14 +802,14 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { }); auto expected_grad_activation = - Literal::CreateR4FromArray4D(grad_activation); + LiteralUtil::CreateR4FromArray4D(grad_activation); - auto input_literal = Literal::CreateR4FromArray4D(input_array); - auto scale_literal = Literal::CreateR1(scale); - auto mean_literal = Literal::CreateR1(mean); - auto var_literal = Literal::CreateR1(var); + auto input_literal = LiteralUtil::CreateR4FromArray4D(input_array); + auto scale_literal = LiteralUtil::CreateR1(scale); + auto mean_literal = LiteralUtil::CreateR1(mean); + auto var_literal = LiteralUtil::CreateR1(var); auto grad_output_literal = - Literal::CreateR4FromArray4D(grad_output_array); + LiteralUtil::CreateR4FromArray4D(grad_output_array); auto input_parameter = Parameter(&builder, 0, input_literal->shape(), "input"); @@ -832,9 +835,9 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { grad_output_parameter, epsilon, feature_index); auto expected = - Literal::MakeTuple({expected_grad_activation.get(), - Literal::CreateR1(grad_scale).get(), - Literal::CreateR1(grad_offset).get()}); + LiteralUtil::MakeTuple({expected_grad_activation.get(), + LiteralUtil::CreateR1(grad_scale).get(), + LiteralUtil::CreateR1(grad_offset).get()}); // Run all HLO passes during this test. In particular, ClientLibraryTestBase // disables constant folding, but we want it enabled for our zero-sized tensor diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index f40d03bea79de2a78814a0ad9f6cae6098d1449b..6c20f654fe3df6a28e9633cd832c11b487894bad 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -95,18 +95,18 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = Literal::MakeTuple( - {Literal::CreateR4( + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR4( {{{{static_cast(-1.6875f)}, {static_cast(-2.04f)}}, {{static_cast(0.105f)}, {static_cast(0.66f)}}}, {{{static_cast(1.89f)}, {static_cast(3.35f)}}, {{static_cast(3.7f)}, {static_cast(6.04f)}}}}) .get(), - Literal::CreateR1( + LiteralUtil::CreateR1( {static_cast(4), static_cast(5)}) .get(), - Literal::CreateR1( + LiteralUtil::CreateR1( {static_cast(5), static_cast(5)}) .get()}); @@ -139,17 +139,17 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) { BatchNormGrad(operand, scale, mean, var, grad_output, /*epsilon=*/0.0, kFeatureIndex); - auto expected = Literal::MakeTuple( - {Literal::CreateR4( + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR4( {{{{static_cast(-3.f)}, {static_cast(-3.f)}}, {{static_cast(-1.f)}, {static_cast(-1.f)}}}, {{{static_cast(1.f)}, {static_cast(1.f)}}, {{static_cast(3.f)}, {static_cast(3.f)}}}}) .get(), - Literal::CreateR1( + LiteralUtil::CreateR1( {static_cast(0), static_cast(0)}) .get(), - Literal::CreateR1( + LiteralUtil::CreateR1( {static_cast(16), static_cast(20)}) .get()}); diff --git a/tensorflow/compiler/xla/tests/binop_scaling_test.cc b/tensorflow/compiler/xla/tests/binop_scaling_test.cc index 20cb989751ad69e2f3cf97c87c43293951f599ab..0d7a3aa46a9c12c19d954c11ae3a2cccbed886ef 100644 --- a/tensorflow/compiler/xla/tests/binop_scaling_test.cc +++ b/tensorflow/compiler/xla/tests/binop_scaling_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" diff --git a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc index d531e8fa82e47f7bcd278f10da2c205e44db0ac1..c6b5108fe9e5bcf843982676d822f1942359da71 100644 --- a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc +++ b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 5fdd1018a41413aa6f4e08a0c02a40bf17f1f882..1d28e85b16596b0ec2717138fb2081878203e8b2 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -20,7 +20,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -58,7 +59,7 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { Array3D* r3_array, float start, float end, int seed) { *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r3_array->FillRandom(start, end, seed); - auto r3_data = Literal::CreateR3FromArray3D(*r3_array)->Relayout( + auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout( LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r3_global_data = client_->TransferToServer(*r3_data).ConsumeValueOrDie(); @@ -71,7 +72,7 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { Array2D* r2_array, float start, float end, int seed) { *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r2_array->FillRandom(start, end, seed); - auto r2_data = Literal::CreateR2FromArray2D(*r2_array)->Relayout( + auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout( LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r2_global_data = client_->TransferToServer(*r2_data).ConsumeValueOrDie(); @@ -156,6 +157,86 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } +XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsUsual) { + XlaBuilder b(TestName()); + BroadcastInDim(ConstantR1(&b, {1, 2}), + ShapeUtil::MakeShape(F32, {2, 2}), {1}); + + Array2D expected(2, 2); + expected(0, 0) = 1; + expected(0, 1) = 2; + expected(1, 0) = 1; + expected(1, 1) = 2; + + ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsTranspose) { + XlaBuilder b(TestName()); + BroadcastInDim(ConstantR1(&b, {1, 2}), + ShapeUtil::MakeShape(F32, {2, 2}), {0}); + + Array2D expected(2, 2); + expected(0, 0) = 1; + expected(0, 1) = 1; + expected(1, 0) = 2; + expected(1, 1) = 2; + + ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDims) { + XlaBuilder b(TestName()); + BroadcastInDim(ConstantR2(&b, {{1.0, 5.0}, {2.0, 6.0}}), + ShapeUtil::MakeShape(F32, {2, 2, 2}), {0, 1}); + + Array3D expected(2, 2, 2); + expected(0, 0, 0) = 1.0; + expected(1, 0, 0) = 2.0; + expected(0, 0, 1) = 1.0; + expected(1, 0, 1) = 2.0; + expected(0, 1, 0) = 5.0; + expected(1, 1, 0) = 6.0; + expected(1, 1, 1) = 6.0; + expected(0, 1, 1) = 5.0; + + ComputeAndCompareR3(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDimsNotPossibleWithBroadCast) { + XlaBuilder b(TestName()); + BroadcastInDim(ConstantR2(&b, {{1.0, 5.0}, {2.0, 6.0}}), + ShapeUtil::MakeShape(F32, {2, 2, 2}), {0, 2}); + + Array3D expected(2, 2, 2); + expected(0, 0, 0) = 1.0; + expected(1, 0, 0) = 2.0; + expected(0, 0, 1) = 5.0; + expected(1, 0, 1) = 6.0; + expected(0, 1, 0) = 1.0; + expected(1, 1, 0) = 2.0; + expected(1, 1, 1) = 6.0; + expected(0, 1, 1) = 5.0; + + ComputeAndCompareR3(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsNotPossibleWithBroadCast) { + XlaBuilder b(TestName()); + BroadcastInDim(ConstantR1(&b, {1, 2}), + ShapeUtil::MakeShape(F32, {3, 2}), {1}); + + Array2D expected(3, 2); + expected(0, 0) = 1; + expected(0, 1) = 2; + expected(1, 0) = 1; + expected(1, 1) = 2; + expected(2, 0) = 1; + expected(2, 1) = 2; + + ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); +} + // Tests implicit broadcasting of PREDs. XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { XlaBuilder b(TestName()); @@ -210,13 +291,13 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { XlaBuilder b(TestName()); Add(ConstantR2(&b, {{1.0, 5.0}}), - ConstantLiteral(&b, *Literal::CreateR3( + ConstantLiteral(&b, *LiteralUtil::CreateR3( {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); auto expected = - Literal::CreateR3({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, - {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); + LiteralUtil::CreateR3({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, + {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -285,7 +366,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { } } } - auto expected = Literal::CreateR3FromArray3D(expected_array); + auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); ComputeAndCompareLiteral( &builder, *expected, {r3_implicit_global_data.get(), r3_global_data.get()}, @@ -310,7 +391,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { Add(r3h, r1h); auto expected = - Literal::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); + LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()}, ErrorSpec(0.0001)); @@ -318,39 +399,40 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *Literal::CreateR3({{{1, 2}}})); + auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}}})); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = - Literal::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); + LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *Literal::CreateR3({{{1}, {2}}})); + auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1}, {2}}})); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = - Literal::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); + LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *Literal::CreateR3({{{1, 2}, {3, 4}}})); + auto r1 = + ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}})); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = - Literal::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); + LiteralUtil::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -358,40 +440,40 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { XlaBuilder b(TestName()); auto r1 = - ConstantLiteral(&b, *Literal::CreateR3({{{1, 2}}, {{3, 4}}})); + ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}})); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = - Literal::CreateR3({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); + LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { XlaBuilder b(TestName()); - auto r1 = - ConstantLiteral(&b, *Literal::CreateR3({{{1}, {2}}, {{3}, {4}}})); + auto r1 = ConstantLiteral( + &b, *LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}})); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = - Literal::CreateR3({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); + LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *Literal::CreateR3({{{1}}})); + auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1}}})); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = - Literal::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); + LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -532,7 +614,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { *v = ApplyOpToFloats(spec.op2, tmp, v3); }); - auto expected = Literal::CreateR2FromArray2D(expected_array); + auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); ComputeAndCompareLiteral( &builder, *expected, {r2_implicit_global_data1.get(), r2_global_data.get(), @@ -546,22 +628,24 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances, XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *Literal::CreateR2({{1, 2}})); - auto r2 = ConstantLiteral(&b, *Literal::CreateR2({{1, 2}, {3, 4}})); + auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}})); + auto r2 = + ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}, {3, 4}})); Add(r2, r1); - auto expected = Literal::CreateR2({{2, 4}, {4, 6}}); + auto expected = LiteralUtil::CreateR2({{2, 4}, {4, 6}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *Literal::CreateR2({{1}, {2}})); - auto r2 = ConstantLiteral(&b, *Literal::CreateR2({{1, 2}, {3, 4}})); + auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2({{1}, {2}})); + auto r2 = + ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}, {3, 4}})); Add(r2, r1); - auto expected = Literal::CreateR2({{2, 3}, {5, 6}}); + auto expected = LiteralUtil::CreateR2({{2, 3}, {5, 6}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -570,11 +654,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1, {0}); - auto expected = - Literal::CreateR3({{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); + auto expected = LiteralUtil::CreateR3( + {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -583,11 +667,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r1, r3, {1}); - auto expected = - Literal::CreateR3({{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); + auto expected = LiteralUtil::CreateR3( + {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -596,11 +680,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r1, r3, {2}); - auto expected = - Literal::CreateR3({{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); + auto expected = LiteralUtil::CreateR3( + {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -611,7 +695,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { auto r1_1 = ConstantR1(&b, {100, 200}); auto r1_2 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); for (int i = 0; i < 3; ++i) { r3 = Add(r1_0, r3, {0}); r3 = Add(r3, r1_1, {1}); @@ -619,7 +703,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { } r3 = Mul(r3, ConstantR0(&b, -2)); - auto expected = Literal::CreateR3( + auto expected = LiteralUtil::CreateR3( {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}}, {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}}); @@ -640,7 +724,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { } r3 = Mul(r3, ConstantR0(&b, -1)); - auto expected = Literal::CreateR3( + auto expected = LiteralUtil::CreateR3( {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}}, {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}}); @@ -653,7 +737,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { XlaBuilder b(TestName()); Add(ConstantR2(&b, {{1.0, 5.0}, {1.0, 5.0}}), - ConstantLiteral(&b, *Literal::CreateR3( + ConstantLiteral(&b, *LiteralUtil::CreateR3( {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 51b9f0d3e330e73f5d110f0a62f824179d5c7cf7..c7b94b5bbaaa512ad36056f9e68a87cc706c24b1 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -37,7 +37,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { // Test degenerate case of broadcasting a scalar into a scalar. auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {}), input, {})); @@ -46,14 +46,14 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR0(42.0), *result, - error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR0(42.0), + *result, error_spec_)); } XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 2}), input, {})); @@ -63,14 +63,14 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, + *LiteralUtil::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, error_spec_)); } XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); // Broadcast vector in both dimension 0 and dimension 1. Join them in a tuple // to enable testing of the results. @@ -86,18 +86,18 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), + *LiteralUtil::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), LiteralSlice(*result, {0}), error_spec_)); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), + *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), LiteralSlice(*result, {1}), error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1})); @@ -106,9 +106,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE( - LiteralTestUtil::Near(*Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + *LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), *result, + error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { @@ -116,7 +116,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { // the dimensions, ie transpose. auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0})); @@ -125,15 +125,15 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE( - LiteralTestUtil::Near(*Literal::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + *LiteralUtil::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), *result, + error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2})); @@ -143,15 +143,15 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), + *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1.0, 2.0}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1.0, 2.0}))); // Broadcast vector in dimension 1. builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -166,8 +166,9 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { Array2D pz({{1, 2}, {1, 2}}); expected.FillWithPZ(pz); - EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); + EXPECT_TRUE( + LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), + *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { @@ -176,7 +177,7 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { int64 r1_size = input_data.size(); std::iota(input_data.begin(), input_data.end(), 0.0f); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1(input_data))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1(input_data))); // Broadcast vector in dimension 3. builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -196,8 +197,9 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { } expected.FillWithYX(yx); - EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); + EXPECT_TRUE( + LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), + *result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { @@ -207,7 +209,7 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { std::vector r1_array(64, 42.0); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1(r1_array))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1(r1_array))); // Broadcast vector in dimension 1. builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -218,14 +220,14 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR4FromArray4D(r4_array), + EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(r4_array), *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {})); @@ -238,15 +240,16 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { Array4D expected(64, 64, 3, 3); expected.Fill(1.0f); - EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); + EXPECT_TRUE( + LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), + *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { auto builder = HloComputation::Builder(TestName()); Array2D to_broadcast({{1.0f, 2.0f}, {3.0f, 4.0f}}); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2FromArray2D(to_broadcast))); + LiteralUtil::CreateR2FromArray2D(to_broadcast))); // Broadcast vector in dimensions 2 and 3. builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -260,8 +263,9 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { Array4D expected(3, 3, 2, 2); expected.FillWithYX(to_broadcast); - EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); + EXPECT_TRUE( + LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), + *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { @@ -280,7 +284,7 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { } } auto input = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR3FromArray3D(input_vals))); + LiteralUtil::CreateR3FromArray3D(input_vals))); // Broadcast vector in dimensions 2 and 3. builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -291,8 +295,9 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); + EXPECT_TRUE( + LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), + *result, error_spec_)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index bc64a19ce22072152216a7c150fbd16480d261fb..b1d18210eaafdfec0920c0cccaa0dfdbd6de5609 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -16,8 +16,9 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -76,7 +77,8 @@ class CallOpTest : public ClientLibraryTestBase { XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR0F32IdentityComputation(); - auto constant = ConstantLiteral(&builder, *Literal::CreateR0(42.0)); + auto constant = + ConstantLiteral(&builder, *LiteralUtil::CreateR0(42.0)); Call(&builder, callee, {constant}); ComputeAndCompareR0(&builder, 42.0, {}, ErrorSpec(0.01f)); @@ -85,8 +87,8 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S0F32AdditionComputation(); - auto x = ConstantLiteral(&builder, *Literal::CreateR1({})); - auto y = ConstantLiteral(&builder, *Literal::CreateR1({})); + auto x = ConstantLiteral(&builder, *LiteralUtil::CreateR1({})); + auto y = ConstantLiteral(&builder, *LiteralUtil::CreateR1({})); Call(&builder, callee, {x, y}); ComputeAndCompareR1(&builder, {}, {}, ErrorSpec(0.01f)); @@ -95,8 +97,10 @@ XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S2F32AdditionComputation(); - auto x = ConstantLiteral(&builder, *Literal::CreateR1({1.0f, 2.0f})); - auto y = ConstantLiteral(&builder, *Literal::CreateR1({2.0f, 3.0f})); + auto x = + ConstantLiteral(&builder, *LiteralUtil::CreateR1({1.0f, 2.0f})); + auto y = + ConstantLiteral(&builder, *LiteralUtil::CreateR1({2.0f, 3.0f})); Call(&builder, callee, {x, y}); ComputeAndCompareR1(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f)); @@ -129,15 +133,15 @@ XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr start, - client_->TransferToServer(*Literal::CreateR0(1.0f))); + client_->TransferToServer(*LiteralUtil::CreateR0(1.0f))); ComputeAndCompareR0(&builder3, 10.0f, {start.get()}, ErrorSpec(0.0f)); } XLA_TEST_F(CallOpTest, CallR0F32Tuple) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR0F32TupleComputation(); - auto elem = Literal::CreateR0(42.0); - auto tuple = Literal::MakeTuple({elem.get()}); + auto elem = LiteralUtil::CreateR0(42.0); + auto tuple = LiteralUtil::MakeTuple({elem.get()}); Call(&builder, callee, {ConstantLiteral(&builder, *elem)}); ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f)); diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index 1ad57c075b22c7730ffd8d1beeab60c9d5dc7458..a4eb57fc7b9abd460a7d158d0dc629eba88018cd 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -36,7 +36,7 @@ class CheckExecutionArityTest : public ClientLibraryTestBase {}; TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { XlaBuilder builder("add_two_params"); - auto param_literal = Literal::CreateR1({1.1f, 2.2f}); + auto param_literal = LiteralUtil::CreateR1({1.1f, 2.2f}); auto p0 = Parameter(&builder, 0, param_literal->shape(), "param0"); auto p1 = Parameter(&builder, 1, param_literal->shape(), "param1"); @@ -85,12 +85,12 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { ASSERT_IS_OK(computation_status.status()); auto computation = computation_status.ConsumeValueOrDie(); - auto f32_literal = Literal::CreateR0(1.1f); + auto f32_literal = LiteralUtil::CreateR0(1.1f); auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie(); - auto f32_4_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); + auto f32_4_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); auto f32_4_data = client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie(); - auto u8_4_literal = Literal::CreateR1U8("hola"); + auto u8_4_literal = LiteralUtil::CreateR1U8("hola"); auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie(); // Match diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index dafd6ebabbe6edafc1c926677b3ea00e775be010..59d917054be2ebe3a25f902f51972a682a5231b6 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -157,7 +157,7 @@ string ClientLibraryTestBase::ExecuteToString( void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, const tensorflow::core::Bitmap& expected, tensorflow::gtl::ArraySlice arguments) { - std::unique_ptr expected_literal = Literal::CreateR1(expected); + std::unique_ptr expected_literal = LiteralUtil::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -273,10 +273,16 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), arguments_passed_in.end()); + + // Transfer and use elements of arguments_, if the AddParam() API was used. + std::vector> owning_arguments; if (!arguments_.empty()) { CHECK(arguments.empty()); for (const auto& argument : arguments_) { - arguments.push_back(argument.get()); + owning_arguments.push_back( + client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)) + .ValueOrDie()); + arguments.push_back(owning_arguments.back().get()); } } @@ -295,7 +301,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( std::unique_ptr converted_expected; Shape layout_shape; if (use_bfloat16_) { - converted_expected = Literal::ConvertF32ToBF16(expected); + converted_expected = LiteralUtil::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; @@ -331,10 +337,16 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( ErrorSpec error, const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), arguments_passed_in.end()); + + // Transfer and use elements of arguments_, if the AddParam() API was used. + std::vector> owning_arguments; if (!arguments_.empty()) { CHECK(arguments.empty()); for (const auto& argument : arguments_) { - arguments.push_back(argument.get()); + owning_arguments.push_back( + client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)) + .ValueOrDie()); + arguments.push_back(owning_arguments.back().get()); } } @@ -347,7 +359,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( std::unique_ptr converted_expected; Shape layout_shape; if (use_bfloat16_) { - converted_expected = Literal::ConvertF32ToBF16(expected); + converted_expected = LiteralUtil::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; @@ -389,7 +401,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( auto actual = actual_status.ConsumeValueOrDie(); // Turn the expected value into a literal. - std::unique_ptr expected_literal = Literal::CreateR1U8(expected); + std::unique_ptr expected_literal = LiteralUtil::CreateR1U8(expected); VLOG(1) << "expected: " << expected_literal->ToString(); VLOG(1) << "actual: " << actual->ToString(); @@ -454,6 +466,14 @@ ClientLibraryTestBase::ComputeValueAndReference( // function. std::vector> argument_data; std::vector> ref_argument_data; + + // Use `arguments_` if the AddParam() API was used. Otherwise, use + // plain `arguments`. + if (!arguments_.empty()) { + CHECK_EQ(arguments.size(), 0); + arguments = arguments_; + } + for (const auto& arg : arguments) { TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg.Clone())); TF_ASSIGN_OR_RETURN(auto ref_data, ref_client_->TransferToServer(arg)); @@ -552,16 +572,16 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, XlaBuilder* builder) { - XlaOp data_handle; - arguments_.push_back(CreateParameterAndTransferLiteral( - arguments_.size(), argument, "", builder, &data_handle)); - return data_handle; + arguments_.push_back(argument.Clone()); + return Parameter(builder, /*parameter_number=*/arguments_.size() - 1, + MaybeConvertShapeToBfloat16(argument.shape()), ""); } XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder) { - return ConstantLiteral( - builder, use_bfloat16_ ? *Literal::ConvertF32ToBF16(literal) : literal); + return ConstantLiteral(builder, use_bfloat16_ + ? *LiteralUtil::ConvertF32ToBF16(literal) + : literal); } std::unique_ptr @@ -574,22 +594,39 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, nullptr, builder, data_handle); } +Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) { + if (!use_bfloat16_) { + return shape; + } + Shape new_shape = shape; + ShapeUtil::ForEachMutableSubshape(&new_shape, + [](Shape* subshape, const ShapeIndex&) { + if (subshape->element_type() == F32) { + subshape->set_element_type(BF16); + } + }); + return new_shape; +} + +Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16( + const Literal& literal) { + if (use_bfloat16_) { + return std::move(*LiteralUtil::ConvertF32ToBF16(literal)); + } + return literal.Clone(); +} + std::unique_ptr ClientLibraryTestBase::CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, const DeviceHandle* device_handle, XlaBuilder* builder, XlaOp* data_handle) { - const Literal* param_literal = &literal; - std::unique_ptr converted_literal; - if (use_bfloat16_) { - converted_literal = Literal::ConvertF32ToBF16(literal); - param_literal = converted_literal.get(); - } + Literal param_literal = MaybeConvertLiteralToBfloat16(literal); std::unique_ptr data = - client_->TransferToServer(*param_literal, device_handle) + client_->TransferToServer(param_literal, device_handle) .ConsumeValueOrDie(); *data_handle = - Parameter(builder, parameter_number, param_literal->shape(), name); + Parameter(builder, parameter_number, param_literal.shape(), name); return data; } diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 37862fa9cb941ab9161064f2a58db38d08f83169..4a6e8a31241d39db21935576d57f0acb17caef11 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -26,8 +26,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -284,7 +285,7 @@ class ClientLibraryTestBase : public ::testing::Test { template XlaOp AddParam(const Array& argument, XlaBuilder* builder) { - return AddParam(*Literal::CreateFromArray(argument), builder); + return AddParam(*LiteralUtil::CreateFromArray(argument), builder); } // Creates a constant instruction with the given literal. When the @@ -299,13 +300,14 @@ class ClientLibraryTestBase : public ::testing::Test { template XlaOp CreateConstantFromArray(const Array& array, XlaBuilder* builder) { - return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder); + return CreateConstantFromLiteral(*LiteralUtil::CreateFromArray(array), + builder); } // Same as CreateConstantFromArray, but for scalars. template XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) { - return CreateConstantFromLiteral(*Literal::CreateR0(value), + return CreateConstantFromLiteral(*LiteralUtil::CreateR0(value), builder); } @@ -373,6 +375,13 @@ class ClientLibraryTestBase : public ::testing::Test { // The float type used in this test, BF16 or F32 according to use_bfloat16. PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; } + // Executes the computation and calculates the expected reference value using + // the reference client. Returns two literals in the order of (expected, + // actual). + StatusOr, std::unique_ptr>> + ComputeValueAndReference(XlaBuilder* builder, + tensorflow::gtl::ArraySlice arguments); + Client* client_; Client* ref_client_; // To compute reference result. ExecutionOptions execution_options_; @@ -390,19 +399,16 @@ class ClientLibraryTestBase : public ::testing::Test { const string& error_message)>& verify_output, const Shape* output_with_layout = nullptr); - // Executes the computation and calculates the expected reference value using - // the reference client. Returns two literals in the order of (expected, - // actual). - StatusOr, std::unique_ptr>> - ComputeValueAndReference(XlaBuilder* builder, - tensorflow::gtl::ArraySlice arguments); + // Converts an f32 shape/literal to bf16 if use_bfloat16_ is true. + Literal MaybeConvertLiteralToBfloat16(const Literal& literal); + Shape MaybeConvertShapeToBfloat16(const Shape& shape); // Whether to run tests with all float-type input/output converted to // bfloat16. bool use_bfloat16_ = false; // Arguments to be passed to the computation when it runs. - std::vector> arguments_; + std::vector arguments_; }; template @@ -410,7 +416,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( XlaBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - Literal::CreateR0(expected); + LiteralUtil::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -426,7 +432,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - Literal::CreateR0(expected); + LiteralUtil::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -436,7 +442,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - Literal::CreateR1(expected); + LiteralUtil::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -452,7 +458,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - Literal::CreateR1(expected); + LiteralUtil::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -462,7 +468,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( XlaBuilder* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - Literal::CreateR2FromArray2D(expected); + LiteralUtil::CreateR2FromArray2D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -478,7 +484,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - Literal::CreateR2FromArray2D(expected); + LiteralUtil::CreateR2FromArray2D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -488,7 +494,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( XlaBuilder* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - Literal::CreateR3FromArray3D(expected); + LiteralUtil::CreateR3FromArray3D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -504,7 +510,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - Literal::CreateR3FromArray3D(expected); + LiteralUtil::CreateR3FromArray3D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -514,7 +520,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( XlaBuilder* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - Literal::CreateR4FromArray4D(expected); + LiteralUtil::CreateR4FromArray4D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -530,7 +536,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - Literal::CreateR4FromArray4D(expected); + LiteralUtil::CreateR4FromArray4D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -539,9 +545,9 @@ template std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = Literal::CreateR0(value); + std::unique_ptr literal = LiteralUtil::CreateR0(value); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = Literal::ConvertF32ToBF16(*literal); + literal = LiteralUtil::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -553,9 +559,9 @@ template std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = Literal::CreateR1(values); + std::unique_ptr literal = LiteralUtil::CreateR1(values); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = Literal::ConvertF32ToBF16(*literal); + literal = LiteralUtil::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -567,9 +573,9 @@ template std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = Literal::CreateR2FromArray2D(array_2d); + std::unique_ptr literal = LiteralUtil::CreateR2FromArray2D(array_2d); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = Literal::ConvertF32ToBF16(*literal); + literal = LiteralUtil::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -581,9 +587,9 @@ template std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = Literal::CreateR3FromArray3D(array_3d); + std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(array_3d); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = Literal::ConvertF32ToBF16(*literal); + literal = LiteralUtil::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 831b863998f1cab31d37aa4474be45d8531075ac..c898dacf489db97223e2918414daf5de88bece64 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -56,7 +56,7 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { client_->Execute(computation, {}, &execution_options)); std::unique_ptr expected_literal = - Literal::CreateR2WithLayout( + LiteralUtil::CreateR2WithLayout( {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout)); TF_ASSERT_OK_AND_ASSIGN( @@ -112,9 +112,9 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg; Shape shape = ShapeUtil::MakeShape(S32, {2, 2}); - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr const_arg, - client_->TransferToServer(*Literal::CreateR2({{5, 6}, {7, 8}}))); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr const_arg, + client_->TransferToServer( + *LiteralUtil::CreateR2({{5, 6}, {7, 8}}))); XlaBuilder b(TestName() + ".add"); Add(Parameter(&b, 0, shape, "param_0"), @@ -136,7 +136,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { TF_ASSERT_OK_AND_ASSIGN(auto results, client_->ExecuteParallel(computation_instances)); - auto expected_result = Literal::CreateR2({{6, 8}, {10, 12}}); + auto expected_result = LiteralUtil::CreateR2({{6, 8}, {10, 12}}); TF_ASSERT_OK_AND_ASSIGN( auto result_literal, diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index eb211dd8ff376fb0da03b3e68be1d849970d96fd..7c52c9fbbb57f9291ea9f0966e2efa715819fb67 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -50,7 +50,7 @@ class CompilationCacheTest : public ClientLibraryTestBase { &execution_profile) .ConsumeValueOrDie(); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR0(expected_result), *result, error_spec_)); + *LiteralUtil::CreateR0(expected_result), *result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } @@ -67,7 +67,7 @@ class CompilationCacheTest : public ClientLibraryTestBase { std::unique_ptr result = client_->Transfer(*data_handle).ConsumeValueOrDie(); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR2(expected_result), *result, error_spec_)); + *LiteralUtil::CreateR2(expected_result), *result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } @@ -89,13 +89,13 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) { XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledWithDifferentParameters) { std::unique_ptr data_42 = - client_->TransferToServer(*Literal::CreateR0(42.0f)) + client_->TransferToServer(*LiteralUtil::CreateR0(42.0f)) .ConsumeValueOrDie(); std::unique_ptr data_123 = - client_->TransferToServer(*Literal::CreateR0(123.0f)) + client_->TransferToServer(*LiteralUtil::CreateR0(123.0f)) .ConsumeValueOrDie(); std::unique_ptr data_456 = - client_->TransferToServer(*Literal::CreateR0(456.0f)) + client_->TransferToServer(*LiteralUtil::CreateR0(456.0f)) .ConsumeValueOrDie(); XlaBuilder builder(TestName()); @@ -143,12 +143,12 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) { // layouts. Use these arrays as parameters to a simple computation. If the // layout of the array changes then computation should be recompiled (cache // miss). - auto rowmaj_array = Literal::CreateR2WithLayout( + auto rowmaj_array = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0})); auto rowmaj_handle = client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie(); - auto colmaj_array = Literal::CreateR2WithLayout( + auto colmaj_array = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})); auto colmaj_handle = client_->TransferToServer(*colmaj_array).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index ba22530f1cfee56337f862c25122d399dbf0f1e4..5a06d061f0d83fff547502495ff8ab13fb421b70 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -99,7 +99,7 @@ TEST_F(ComputeConstantTest, ScalarInt32Literal) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = b.ConstantR0(42); + auto computation = ConstantR0(&b, 42); EXPECT_TRUE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); @@ -113,7 +113,7 @@ TEST_F(ComputeConstantTest, ScalarFloatAdd) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto computation = - b.Add(b.ConstantR0(42.5f), b.ConstantR0(1.5f)); + Add(ConstantR0(&b, 42.5f), ConstantR0(&b, 1.5f)); EXPECT_TRUE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); @@ -127,8 +127,8 @@ TEST_F(ComputeConstantTest, ScalarRng) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto computation = - b.RngUniform(b.ConstantR0(1.1f), b.ConstantR0(2.1f), - ShapeUtil::MakeShape(F32, {})); + RngUniform(ConstantR0(&b, 1.1f), ConstantR0(&b, 2.1f), + ShapeUtil::MakeShape(F32, {})); EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); @@ -141,7 +141,7 @@ TEST_F(ComputeConstantTest, DirectParamMissing) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"); + auto computation = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param"); EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); @@ -156,8 +156,8 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto computation = - b.Add(b.ConstantR0(1.0f), - b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); + Add(ConstantR0(&b, 1.0f), + Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param")); EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); @@ -174,18 +174,18 @@ TEST_F(ComputeConstantTest, UnrelatedParam) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0"); + auto param_a = Parameter(&b, 10, ShapeUtil::MakeShape(F32, {}), "param0"); auto constant_4 = - b.Add(b.ConstantR0(2.5f), b.ConstantR0(1.5f)); - auto not_constant_a = b.Add(constant_4, param_a); + Add(ConstantR0(&b, 2.5f), ConstantR0(&b, 1.5f)); + auto not_constant_a = Add(constant_4, param_a); - auto param_b = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "param1"); + auto param_b = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "param1"); auto constant_9 = - b.Mul(b.ConstantR0(2.0f), b.ConstantR0(4.5f)); - auto not_constant_b = b.Add(param_b, constant_9); + Mul(ConstantR0(&b, 2.0f), ConstantR0(&b, 4.5f)); + auto not_constant_b = Add(param_b, constant_9); - auto constant_13 = b.Add(constant_4, constant_9); - b.Add(not_constant_b, b.Add(constant_13, not_constant_a)); + auto constant_13 = Add(constant_4, constant_9); + Add(not_constant_b, Add(constant_13, not_constant_a)); EXPECT_TRUE(IsConstant(constant_13, &b)); @@ -201,13 +201,13 @@ TEST_F(ComputeConstantTest, NonScalarAdd) { XlaBuilder b(TestName()); auto computation = - b.Add(b.ConstantR1({1, 2}), b.ConstantR1({3, 4})); + Add(ConstantR1(&b, {1, 2}), ConstantR1(&b, {3, 4})); EXPECT_TRUE(IsConstant(computation, &b)); TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); std::unique_ptr expected_literal = - Literal::CreateR1({4, 6}); + LiteralUtil::CreateR1({4, 6}); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } @@ -216,12 +216,12 @@ TEST_F(ComputeConstantTest, IntegerDivide) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = b.Div(b.ConstantR0(15), b.ConstantR0(3)); + auto computation = Div(ConstantR0(&b, 15), ConstantR0(&b, 3)); EXPECT_TRUE(IsConstant(computation, &b)); TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); - std::unique_ptr expected_literal = Literal::CreateR0(5); + std::unique_ptr expected_literal = LiteralUtil::CreateR0(5); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } @@ -237,13 +237,13 @@ XLA_TEST_F(ComputeConstantTest, Layout) { TF_ASSERT_OK_AND_ASSIGN( auto computed, ComputeConstantLiteral( client, - b.Add(b.ConstantR2({{1, 2}, {3, 4}}), - b.ConstantR2({{10, 20}, {30, 40}})), + Add(ConstantR2(&b, {{1, 2}, {3, 4}}), + ConstantR2(&b, {{10, 20}, {30, 40}})), &b, &layout_proto)); std::unique_ptr expected_literal = - Literal::CreateR2WithLayout({{11, 22}, {33, 44}}, - LayoutUtil::MakeLayout(layout)); + LiteralUtil::CreateR2WithLayout( + {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout)); ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( expected_literal->shape(), computed->shape())); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index 1161b560b7b0756556911812666c6f4fe9179f72..be017477d84eb9faf5aa79dcdf54d6b6aaf6fd8e 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -534,8 +534,8 @@ TEST_P(ConcatR2BinaryTest, DoIt) { // concat XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); - auto x_literal = Literal::CreateR0(2.f); - auto y_literal = Literal::CreateR0(3.f); + auto x_literal = LiteralUtil::CreateR0(2.f); + auto y_literal = LiteralUtil::CreateR0(3.f); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); @@ -556,9 +556,9 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { // produces the correct result in rank 1. XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); - auto x_literal = Literal::CreateR1({2.0f, 3.0f, 5.0f, 6.0f}); - auto y_literal = Literal::CreateR0(1.5f); - auto z_literal = Literal::CreateR0(5.5f); + auto x_literal = LiteralUtil::CreateR1({2.0f, 3.0f, 5.0f, 6.0f}); + auto y_literal = LiteralUtil::CreateR0(1.5f); + auto z_literal = LiteralUtil::CreateR0(5.5f); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); @@ -584,9 +584,9 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); Array3D x3d(3, 5, 7, 3.14f); - auto x_literal = Literal::CreateR3FromArray3D(x3d); - auto y_literal = Literal::CreateR0(1.5f); - auto z_literal = Literal::CreateR0(5.5f); + auto x_literal = LiteralUtil::CreateR3FromArray3D(x3d); + auto y_literal = LiteralUtil::CreateR0(1.5f); + auto z_literal = LiteralUtil::CreateR0(5.5f); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index ee3c83039bfc13f6ad78111d92ba0f8387a3ade3..b27c1044baf2c0002f166c53a81e4361c60d012a 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -172,88 +172,95 @@ class ConditionalOpTest : public ClientLibraryTestBase { // Test true and false computations that do not take any parameters. XLA_TEST_F(ConditionalOpTest, Parameters0) { XlaBuilder builder(TestName()); - auto pred = ConstantR0(&builder, true); + XlaOp pred; + auto pred_arg = CreateR0Parameter(true, 0, "pred", &builder, &pred); auto operands = Tuple(&builder, {}); auto true_computation = CreateR0ConstantComputation(56.0f); auto false_computation = CreateR0ConstantComputation(12.0f); Conditional(pred, operands, true_computation, operands, false_computation); - ComputeAndCompareR0(&builder, 56.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 56.0f, {pred_arg.get()}, error_spec_); } // Test true and false computations that take in 1 parameter. XLA_TEST_F(ConditionalOpTest, Parameters1) { XlaBuilder builder(TestName()); - auto pred = ConstantR0(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); auto operand1 = ConstantR0(&builder, 56.0f); auto operand2 = ConstantR0(&builder, 12.0f); auto identity = CreateR0IdentityComputation(); Conditional(pred, operand1, identity, operand2, identity); - ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test conditional with two different computations in the true and false cases // that take in different arguments. XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { XlaBuilder builder(TestName()); - auto pred = ConstantR0(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); auto operand1 = ConstantR0(&builder, 56.4f); auto operand2 = ConstantR0(&builder, 12.6f); Conditional(pred, operand1, CreateR0CeilComputation(), operand2, CreateR0FloorComputation()); - ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test conditional with two different computations in the true and false cases // that take in the same arguments. XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) { XlaBuilder builder(TestName()); - auto pred = ConstantR0(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); auto operand = ConstantR0(&builder, 12.6f); Conditional(pred, operand, CreateR0CeilComputation(), operand, CreateR0FloorComputation()); - ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test conditional with the same computation in the true and false cases but // take in different arguments. XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) { XlaBuilder builder(TestName()); - auto pred = ConstantR0(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); auto operand1 = ConstantR0(&builder, 56.4f); auto operand2 = ConstantR0(&builder, 12.6f); auto floor = CreateR0FloorComputation(); Conditional(pred, operand1, floor, operand2, floor); - ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test conditional with the same computation in the true and false cases that // take in the same arguments. XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) { XlaBuilder builder(TestName()); - auto pred = ConstantR0(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); auto operand = ConstantR0(&builder, 12.6f); auto floor = CreateR0FloorComputation(); Conditional(pred, operand, floor, operand, floor); - ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test conditional with different instances of the same computation in the true // and false cases. XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) { XlaBuilder builder(TestName()); - auto pred = ConstantR0(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); auto operand1 = ConstantR0(&builder, 56.4f); auto operand2 = ConstantR0(&builder, 12.6f); Conditional(pred, operand1, CreateR0FloorComputation(), operand2, CreateR0FloorComputation()); - ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test the case when a call invokes a computation that contains a conditional. @@ -268,75 +275,83 @@ XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { auto inner_builder_result = inner_builder.Build(); XlaBuilder builder(TestName()); - auto pred = ConstantR0(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); auto operand1 = ConstantR0(&builder, 56.4f); auto operand2 = ConstantR0(&builder, 12.6f); Call(&builder, inner_builder_result.ConsumeValueOrDie(), {pred, operand1, operand2}); - ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test true and false computations that take in 2 parameters and predicate is // true. XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) { XlaBuilder builder(TestName()); - auto pred = ConstantR0(&builder, true); + XlaOp pred; + auto pred_arg = CreateR0Parameter(true, 0, "pred", &builder, &pred); auto operand1 = ConstantR0(&builder, 56.0f); auto operand2 = ConstantR0(&builder, 12.0f); auto operands = Tuple(&builder, {operand1, operand2}); Conditional(pred, operands, CreateR0TupleAddComputation(), operands, CreateR0TupleSubComputation()); - ComputeAndCompareR0(&builder, 68.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 68.0f, {pred_arg.get()}, error_spec_); } // Test true and false computations that take in 2 parameters and predicate is // false. XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) { XlaBuilder builder(TestName()); - auto pred = ConstantR0(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); auto operand1 = ConstantR0(&builder, 56.0f); auto operand2 = ConstantR0(&builder, 12.0f); auto operands = Tuple(&builder, {operand1, operand2}); Conditional(pred, operands, CreateR0TupleAddComputation(), operands, CreateR0TupleSubComputation()); - ComputeAndCompareR0(&builder, 44.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 44.0f, {pred_arg.get()}, error_spec_); } // Test true and false computations that take in 2 array parameters and // predicate is true. XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { XlaBuilder builder(TestName()); - auto pred = ConstantR0(&builder, true); + XlaOp pred; + auto pred_arg = CreateR0Parameter(true, 0, "pred", &builder, &pred); auto operand1 = ConstantR1(&builder, {24.0f, 56.0f}); auto operand2 = ConstantR1(&builder, {10.0f, 11.0f}); auto operands = Tuple(&builder, {operand1, operand2}); Conditional(pred, operands, CreateR1TupleAddComputation(), operands, CreateR1TupleSubComputation()); - ComputeAndCompareR1(&builder, {34.0f, 67.0f}, {}, error_spec_); + ComputeAndCompareR1(&builder, {34.0f, 67.0f}, {pred_arg.get()}, + error_spec_); } // Test true and false computations that take in 2 array parameters and // predicate is false. XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) { XlaBuilder builder(TestName()); - auto pred = ConstantR0(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); auto operand1 = ConstantR1(&builder, {24.0f, 56.0f}); auto operand2 = ConstantR1(&builder, {10.0f, 11.0f}); auto operands = Tuple(&builder, {operand1, operand2}); Conditional(pred, operands, CreateR1TupleAddComputation(), operands, CreateR1TupleSubComputation()); - ComputeAndCompareR1(&builder, {14.0f, 45.0f}, {}, error_spec_); + ComputeAndCompareR1(&builder, {14.0f, 45.0f}, {pred_arg.get()}, + error_spec_); } // Test true and false computations that return a tuple of scalars. XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { XlaBuilder builder(TestName()); - auto pred = ConstantR0(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); auto operands = Tuple(&builder, {ConstantR0(&builder, 12.2f), ConstantR0(&builder, 25.6f)}); Conditional(pred, operands, CreateR0TupleCeilComputation(), operands, @@ -344,15 +359,16 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { ComputeAndCompareTuple( &builder, - *Literal::MakeTuple({Literal::CreateR0(12.0f).get(), - Literal::CreateR0(25.0f).get()}), - {}, error_spec_); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(12.0f).get(), + LiteralUtil::CreateR0(25.0f).get()}), + {pred_arg.get()}, error_spec_); } // Test true and false computations that return a tuple of arrays. XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { XlaBuilder builder(TestName()); - auto pred = ConstantR0(&builder, true); + XlaOp pred; + auto pred_arg = CreateR0Parameter(true, 0, "pred", &builder, &pred); auto operands = Tuple(&builder, {ConstantR1(&builder, {12.2f, 15.8f}), ConstantR1(&builder, {25.6f, 29.2f})}); @@ -361,9 +377,10 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { ComputeAndCompareTuple( &builder, - *Literal::MakeTuple({Literal::CreateR1({13.0f, 16.0f}).get(), - Literal::CreateR1({26.0f, 30.0f}).get()}), - {}, error_spec_); + *LiteralUtil::MakeTuple( + {LiteralUtil::CreateR1({13.0f, 16.0f}).get(), + LiteralUtil::CreateR1({26.0f, 30.0f}).get()}), + {pred_arg.get()}, error_spec_); } // Test true and false computations that return a tuple of a predicate, a @@ -392,17 +409,19 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { EXPECT_IS_OK(false_builder_result.status()); XlaBuilder builder(TestName()); - auto pred = ConstantR0(&builder, true); + XlaOp pred; + auto pred_arg = CreateR0Parameter(true, 0, "pred", &builder, &pred); auto operands = Tuple(&builder, {}); Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands, false_builder_result.ConsumeValueOrDie()); ComputeAndCompareTuple( &builder, - *Literal::MakeTuple({Literal::CreateR0(true).get(), - Literal::CreateR0(12.2f).get(), - Literal::CreateR1({12.8f, 14.6f}).get()}), - {}, error_spec_); + *LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(true).get(), + LiteralUtil::CreateR0(12.2f).get(), + LiteralUtil::CreateR1({12.8f, 14.6f}).get()}), + {pred_arg.get()}, error_spec_); } // Test true and false computations that return a nested tuple. @@ -436,21 +455,24 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { EXPECT_IS_OK(false_builder_result.status()); XlaBuilder builder(TestName()); - auto pred = ConstantR0(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); auto operands = Tuple(&builder, {}); Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands, false_builder_result.ConsumeValueOrDie()); ComputeAndCompareTuple( &builder, - *Literal::MakeTuple( - {Literal::MakeTuple({Literal::CreateR0(46.6f).get(), - Literal::CreateR1({54.4f, 58.4f}).get()}) + *LiteralUtil::MakeTuple( + {LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(46.6f).get(), + LiteralUtil::CreateR1({54.4f, 58.4f}).get()}) .get(), - Literal::MakeTuple({Literal::CreateR1({62.1f, 67.4f}).get(), - Literal::CreateR0(9.3f).get()}) + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR1({62.1f, 67.4f}).get(), + LiteralUtil::CreateR0(9.3f).get()}) .get()}), - {}, error_spec_); + {pred_arg.get()}, error_spec_); } // Test conditional that takes in scalar operands in the form of external @@ -511,8 +533,9 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) { EXPECT_IS_OK(inner_builder_result.status()); XlaBuilder builder(TestName()); - auto pred1 = ConstantR0(&builder, true); - auto pred2 = ConstantR0(&builder, false); + XlaOp pred1, pred2; + auto pred1_arg = CreateR0Parameter(true, 0, "pred1", &builder, &pred1); + auto pred2_arg = CreateR0Parameter(false, 1, "pred2", &builder, &pred2); auto operand1 = ConstantR0(&builder, 1.1f); auto operand2 = ConstantR0(&builder, 12.2f); auto operand3 = ConstantR0(&builder, 43.3f); @@ -520,7 +543,8 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) { Conditional(pred1, tuple_operand, inner_builder_result.ConsumeValueOrDie(), operand3, CreateR0IdentityComputation()); - ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, + {pred1_arg.get(), pred2_arg.get()}, error_spec_); } XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { @@ -539,13 +563,14 @@ XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { EXPECT_IS_OK(inner_builder_result.status()); XlaBuilder builder(TestName()); - auto pred2 = ConstantR0(&builder, false); + XlaOp pred; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); auto operand1 = ConstantR0(&builder, 1.1f); auto operand2 = ConstantR0(&builder, 12.2f); - auto tuple_operand = Tuple(&builder, {pred2, operand1, operand2}); + auto tuple_operand = Tuple(&builder, {pred, operand1, operand2}); Call(&builder, inner_builder_result.ConsumeValueOrDie(), {tuple_operand}); - ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); + ComputeAndCompareR0(&builder, 12.0f, {pred_arg.get()}, error_spec_); } // Test a mismatch in the shape of the true operand and true computation. @@ -600,16 +625,17 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { auto test_swap = [&](float a, float b) { XlaBuilder builder(TestName()); - auto x = ConstantR0(&builder, a); - auto y = ConstantR0(&builder, b); + XlaOp x, y; + auto x_arg = CreateR0Parameter(a, 0, "x", &builder, &x); + auto y_arg = CreateR0Parameter(b, 1, "y", &builder, &y); auto tuple_operand = Tuple(&builder, {x, y}); Call(&builder, main, {tuple_operand}); ComputeAndCompareTuple( &builder, - *Literal::MakeTuple({Literal::CreateR0(a).get(), - Literal::CreateR0(b).get()}), - {}, error_spec_); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(a).get(), + LiteralUtil::CreateR0(b).get()}), + {x_arg.get(), y_arg.get()}, error_spec_); }; test_swap(3.11f, 9.4f); diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index cc5d3b11767457444d4c199943e689f082d5b199..49375748319ad5fe40db507a034ec4b07adb7e84 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -110,8 +110,8 @@ TEST_F(ConstantsTest, Small_2x2) { TEST_F(ConstantsTest, Empty_3x0x2) { XlaBuilder builder(TestName()); - ConstantLiteral( - &builder, *Literal::CreateR3FromArray3D(Array3D(3, 0, 2))); + ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D( + Array3D(3, 0, 2))); ComputeAndCompareR3(&builder, Array3D(3, 0, 2), {}); } @@ -126,7 +126,7 @@ TEST_F(ConstantsTest, Small_2x2x2) { {{5.f, 6.f}, // y0 {7.f, 8.f}}, // y1 }); - ConstantLiteral(&builder, *Literal::CreateR3FromArray3D(array3d)); + ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D(array3d)); ComputeAndCompareR3(&builder, array3d, {}); } @@ -141,7 +141,7 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { }); input_array.FillWithPZ(pz); std::unique_ptr input_literal = - Literal::CreateR4FromArray4D(input_array); + LiteralUtil::CreateR4FromArray4D(input_array); { XlaBuilder builder(TestName()); @@ -159,22 +159,23 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { // TODO(b/29263943): Support tuple constants. TEST_F(ConstantsTest, DISABLED_TupleConstant) { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, *Literal::MakeTuple( - {Literal::CreateR2({{1.0}, {2.0}}).get(), - Literal::CreateR1({2.0, 42}).get()})); + ConstantLiteral(&builder, + *LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), + LiteralUtil::CreateR1({2.0, 42}).get()})); std::unique_ptr result = ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie(); - LiteralTestUtil::ExpectR2Near( - {{1.0}, {2.0}}, LiteralSlice(*result, {0}), error_spec_); - LiteralTestUtil::ExpectR1Near( - {2.0, 42.0}, LiteralSlice(*result, {1}), error_spec_); + LiteralTestUtil::ExpectR2Near({{1.0}, {2.0}}, + LiteralSlice(*result, {0}), error_spec_); + LiteralTestUtil::ExpectR1Near({2.0, 42.0}, LiteralSlice(*result, {1}), + error_spec_); } TEST_F(ConstantsTest, Token) { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, *Literal::CreateToken()); + ConstantLiteral(&builder, *LiteralUtil::CreateToken()); // TODO(b/80000000): tokens cannot be returned from computations. Tuple(&builder, {}); TF_ASSERT_OK(Execute(&builder, {}).status()); diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 292942a49e2f0c4b077dc71c9d0e730909689e3a..1adc68cc4839dcd7d89741ec016f27bc9047c9a5 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -13,13 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -52,13 +53,67 @@ TEST_F(ConvertTest, ConvertR1S32ToR1S32) { ComputeAndCompareR1(&builder, expected, {}); } +TEST_F(ConvertTest, ConvertR1S32ToR1U32) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {42, 64}); + ConvertElementType(a, U32); + + std::vector expected = {42, 64}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1S32ToR1PRED) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {42, 0, -64}); + ConvertElementType(a, PRED); + + std::array expected = {true, false, true}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1U32ToR1U32) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {42, 64}); + ConvertElementType(a, U32); + + std::vector expected = {42, 64}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1U32ToR1S32) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {42, 64}); + ConvertElementType(a, S32); + + std::vector expected = {42, 64}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1U32ToR1PRED) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {42, 0, 64}); + ConvertElementType(a, PRED); + + std::array expected = {true, false, true}; + ComputeAndCompareR1(&builder, expected, {}); +} + TEST_F(ConvertTest, ConvertR1F32ToR1F32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {42.0f, 64.0f}); ConvertElementType(a, F32); std::vector expected = {42.0f, 64.0f}; - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1F32ToR1PRED) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {42.0f, 0.0f, 64.0f}); + ConvertElementType(a, PRED); + + std::array expected = {true, false, true}; + ComputeAndCompareR1(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1S32ToR1F32) { @@ -67,7 +122,7 @@ TEST_F(ConvertTest, ConvertR1S32ToR1F32) { ConvertElementType(a, F32); std::vector expected = {42.0f, 64.0f}; - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareR1(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1PREDToR1S32) { @@ -79,6 +134,15 @@ TEST_F(ConvertTest, ConvertR1PREDToR1S32) { ComputeAndCompareR1(&builder, expected, {}); } +TEST_F(ConvertTest, ConvertR1PREDToR1U32) { + XlaBuilder builder(TestName()); + auto a = ConstantR1(&builder, {true, false, true}); + ConvertElementType(a, U32); + + std::vector expected = {1, 0, 1}; + ComputeAndCompareR1(&builder, expected, {}); +} + TEST_F(ConvertTest, ConvertR1PREDToR1F32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {true, false, true}); @@ -94,7 +158,7 @@ XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) { ConvertElementType(a, F32); std::vector expected = {}; - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareR1(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1F32ToR1S32) { @@ -145,7 +209,7 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) { static_cast(0x8000008000000000LL), static_cast(0x8000010000000000LL), }; - std::unique_ptr arg_literal = Literal::CreateR1({arg}); + std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); @@ -164,7 +228,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) { std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000000, 0x80000001, 0x80000002, 0x80000003, 0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF}; - std::unique_ptr arg_literal = Literal::CreateR1({arg}); + std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); @@ -182,7 +246,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { XlaBuilder builder(TestName()); std::vector arg{0.0f, 1.0f, 16777216.0f, 16777218.0f, 2147483647.0f, 4294967040.0f}; - std::unique_ptr arg_literal = Literal::CreateR1({arg}); + std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); @@ -199,7 +263,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF}; - std::unique_ptr arg_literal = Literal::CreateR1({arg}); + std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); @@ -216,7 +280,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, -1, -0x1000}; - std::unique_ptr arg_literal = Literal::CreateR1({arg}); + std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); @@ -253,7 +317,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { 9223370937343148032.f, -9223371487098961920.f, -9223370937343148032.f}; - std::unique_ptr arg_literal = Literal::CreateR1({arg}); + std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); @@ -391,7 +455,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, - client_->TransferToServer(*Literal::CreateR1(input))); + client_->TransferToServer(*LiteralUtil::CreateR1(input))); XlaBuilder builder(TestName()); ConvertElementType( @@ -411,7 +475,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, - client_->TransferToServer(*Literal::CreateR1(input))); + client_->TransferToServer(*LiteralUtil::CreateR1(input))); XlaBuilder builder(TestName()); ConvertElementType( diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 7605ebf4c0eacd7f44e867e23dbc27c6c1bc3e93..7b6bbc4f571af2e11306f95c24e243e78e0f4f4e 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -93,7 +93,8 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, auto weight_array = MakeUnique>(4, 3, 1, 1); weight_array->FillWithMultiples(0.2); auto weight_data = - client_->TransferToServer(*Literal::CreateR4FromArray4D(*weight_array)) + client_ + ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array)) .ConsumeValueOrDie(); XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 0f6d54d042dd6af6d82e1eea93a66c2e9be53639..5ed8122e0073bde77bb2507a0ddd89c4365627c9 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -23,9 +23,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -123,8 +123,8 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest { })); ComputeAndCompare(&builder, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, + {std::move(*LiteralUtil::CreateFromArray(input_data)), + std::move(*LiteralUtil::CreateFromArray(filter_data))}, error_spec_); } }; @@ -157,8 +157,8 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest { {7.0f, 8.0f}, })); ComputeAndCompare(&builder, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, + {std::move(*LiteralUtil::CreateFromArray(input_data)), + std::move(*LiteralUtil::CreateFromArray(filter_data))}, error_spec_); } }; @@ -192,8 +192,8 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest { })); ComputeAndCompare(&builder, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, + {std::move(*LiteralUtil::CreateFromArray(input_data)), + std::move(*LiteralUtil::CreateFromArray(filter_data))}, error_spec_); } }; @@ -224,8 +224,8 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest { {{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}})); // clang-format on ComputeAndCompare(&builder, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, + {std::move(*LiteralUtil::CreateFromArray(input_data)), + std::move(*LiteralUtil::CreateFromArray(filter_data))}, error_spec_); } }; @@ -249,10 +249,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { Array3D expected({{{510, 610, 710, 810}}}); auto input_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -284,10 +284,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest { Array3D expected({{{570.0f, 670.0f, 770.0f}}}); auto input_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -319,10 +319,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) { Array3D expected({{{190, 320, 230, 380, 270, 440, 310, 500}}}); auto input_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -350,10 +350,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) { Array3D expected({{{510, 0, 610, 0, 710, 0, 810}}}); auto input_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -386,10 +386,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest { {{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}}); auto input_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) + client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -434,15 +434,15 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota(input_elems.begin(), input_elems.end(), 1.0f); - auto input_r1 = Literal::CreateR1(input_elems); + auto input_r1 = LiteralUtil::CreateR1(input_elems); auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota(filter_elems.begin(), filter_elems.end(), 1.0f); - auto filter_r1 = Literal::CreateR1(filter_elems); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); - auto expected_r1 = Literal::CreateR1( + auto expected_r1 = LiteralUtil::CreateR1( {19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446, 38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470}); auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie(); @@ -497,15 +497,15 @@ class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota_int_init_value(input_elems, 1); - auto input_r1 = Literal::CreateR1(input_elems); + auto input_r1 = LiteralUtil::CreateR1(input_elems); auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota_int_init_value(filter_elems, 1); - auto filter_r1 = Literal::CreateR1(filter_elems); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); - auto expected_r1 = Literal::CreateR1( + auto expected_r1 = LiteralUtil::CreateR1( {static_cast(92115), static_cast(93150), static_cast(94185)}); auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie(); @@ -561,8 +561,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, expected_result.Fill(0); ComputeAndCompare(&builder, - {std::move(*Literal::CreateFromArray(param0)), - std::move(*Literal::CreateFromArray(param1))}, + {std::move(*LiteralUtil::CreateFromArray(param0)), + std::move(*LiteralUtil::CreateFromArray(param1))}, error_spec_); } @@ -617,18 +617,18 @@ class Convolve1D1WindowTestBase std::vector input_elems(ShapeUtil::ElementsIn(input_shape), static_cast(1.0f)); - auto input_r1 = Literal::CreateR1(input_elems); + auto input_r1 = LiteralUtil::CreateR1(input_elems); auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), static_cast(1.0f)); - auto filter_r1 = Literal::CreateR1(filter_elems); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); std::vector expect_elems(batch * output_feature * num_windows, static_cast(window_size * input_feature)); - auto expected_r1 = Literal::CreateR1(expect_elems); + auto expected_r1 = LiteralUtil::CreateR1(expect_elems); auto expected_r3 = expected_r1->Reshape({batch, num_windows, output_feature}) .ConsumeValueOrDie(); @@ -737,8 +737,8 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { })); ComputeAndCompare(&builder, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, + {std::move(*LiteralUtil::CreateFromArray(input_data)), + std::move(*LiteralUtil::CreateFromArray(filter_data))}, error_spec_); } @@ -761,8 +761,8 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) { filter_data.FillIota(10); ComputeAndCompare(&builder, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}); + {std::move(*LiteralUtil::CreateFromArray(input_data)), + std::move(*LiteralUtil::CreateFromArray(filter_data))}); } } // namespace diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index c31d033bb0f0e52d40251c4d7b64d52f42d29dc6..6784c16715da72d337edf70fa51db42c59404136 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -27,8 +27,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -1333,17 +1333,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) { XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { XlaBuilder builder(TestName()); - auto gradients_flat = Literal::CreateR1({1}); + auto gradients_flat = LiteralUtil::CreateR1({1}); auto gradients_literal = gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); auto gradients = ConstantLiteral(&builder, *gradients_literal); - auto weights_flat = Literal::CreateR1({1, 10, 100}); + auto weights_flat = LiteralUtil::CreateR1({1, 10, 100}); auto weights_literal = weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); auto weights = ConstantLiteral(&builder, *weights_literal); - auto expected_flat = Literal::CreateR1({10}); + auto expected_flat = LiteralUtil::CreateR1({10}); auto expected_literal = expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); @@ -1357,17 +1357,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { XlaBuilder builder(TestName()); - auto activations_flat = Literal::CreateR1({1, 2, 3, 4}); + auto activations_flat = LiteralUtil::CreateR1({1, 2, 3, 4}); auto activations_literal = activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie(); auto activations = ConstantLiteral(&builder, *activations_literal); - auto gradients_flat = Literal::CreateR1({100, 10, 1}); + auto gradients_flat = LiteralUtil::CreateR1({100, 10, 1}); auto gradients_literal = gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); auto gradients = ConstantLiteral(&builder, *gradients_literal); - auto expected_flat = Literal::CreateR1({13, 24, 130}); + auto expected_flat = LiteralUtil::CreateR1({13, 24, 130}); auto expected_literal = expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index fef42885e516fa8c8f87756d7a953fe5f37a630f..5ef273e5a26ea8a16db864974c9bfa2c296cbce8 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -58,37 +58,38 @@ class CopyOpTest : public HloTestBase { }; XLA_TEST_F(CopyOpTest, CopyR0Bool) { - TestCopyOp(*Literal::CreateR0(true)); + TestCopyOp(*LiteralUtil::CreateR0(true)); } XLA_TEST_F(CopyOpTest, CopyR1S0U32) { - TestCopyOp(*Literal::CreateR1({})); + TestCopyOp(*LiteralUtil::CreateR1({})); } XLA_TEST_F(CopyOpTest, CopyR1S3U32) { - TestCopyOp(*Literal::CreateR1({1, 2, 3})); + TestCopyOp(*LiteralUtil::CreateR1({1, 2, 3})); } XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) { - TestCopyOp(*Literal::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + TestCopyOp( + *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } XLA_TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) { - TestCopyOp(*Literal::CreateR4( + TestCopyOp(*LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } XLA_TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) { - TestCopyOp(*Literal::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); + TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); } XLA_TEST_F(CopyOpTest, CopyParameterScalar) { auto builder = HloComputation::Builder(TestName()); // Copy literal to device to use as parameter. - auto literal = Literal::CreateR0(42.0); + auto literal = LiteralUtil::CreateR0(42.0); Shape shape = literal->shape(); auto param0 = builder.AddInstruction( @@ -109,7 +110,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) { auto builder = HloComputation::Builder(TestName()); - auto literal = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -131,7 +132,7 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { HloComputation::Builder builder(TestName()); std::unique_ptr literal = - Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); // Reverse the minor-to-major order of the literal. Layout* literal_layout = literal->mutable_shape_do_not_use()->mutable_layout(); @@ -168,7 +169,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = Literal::CreateR3FromArray3D(a); + std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -202,7 +203,7 @@ void CopyOpTest::TestCopyConstantLayoutR4( HloComputation::Builder builder(TestName()); - std::unique_ptr literal = Literal::CreateR4FromArray4D(a); + std::unique_ptr literal = LiteralUtil::CreateR4FromArray4D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc index b151187c4b8f01c5b46ccadf27d2e22a7c902e98..d12a4e7fcd7813775a81677bcaa07af60ff9b477 100644 --- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc +++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -45,7 +45,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); - auto literal = Literal::CreateR1({1, 2, 3}); + auto literal = LiteralUtil::CreateR1({1, 2, 3}); EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()})); } @@ -66,10 +66,10 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); - auto literal0 = Literal::CreateR1({1, 2, 3}); - auto literal1 = Literal::CreateR1({10, 20}); + auto literal0 = LiteralUtil::CreateR1({1, 2, 3}); + auto literal1 = LiteralUtil::CreateR1({10, 20}); EXPECT_EQ( - *Literal::MakeTuple({literal0.get(), literal1.get()}), + *LiteralUtil::MakeTuple({literal0.get(), literal1.get()}), *ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()})); } @@ -93,9 +93,9 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) { })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); - auto literal0 = Literal::CreateR1({1, 2, 3}); - auto literal1 = Literal::CreateR1({10, 20}); - EXPECT_EQ(*Literal::MakeTuple({literal0.get(), literal1.get()}), + auto literal0 = LiteralUtil::CreateR1({1, 2, 3}); + auto literal1 = LiteralUtil::CreateR1({10, 20}); + EXPECT_EQ(*LiteralUtil::MakeTuple({literal0.get(), literal1.get()}), *ExecuteAndTransfer(std::move(module), {literal0.get()})); } diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index d1516a28b0bb3857d9aee0922a252e25a8f9d2d5..13c777835eb2d2519d39205cdc96f0aac4850c7d 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" @@ -74,7 +74,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateCustomCall(r0f32_, {constant}, "R0F32Add2")); @@ -95,7 +95,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { array(1, 1) = 4.0f; auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2FromArray2D(array))); + HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(array))); builder.AddInstruction( HloInstruction::CreateCustomCall(r0f32_, {constant}, "R2F32ReduceSum")); @@ -111,7 +111,7 @@ XLA_TEST_F(CustomCallTest, auto b = HloComputation::Builder(TestName()); auto input = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2FromArray2D( + HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D( Array2D{{1.0f, 2.0f}, {3.0f, 4.0f}}))); auto incremented = b.AddInstruction(HloInstruction::CreateCustomCall( ShapeUtil::MakeShape(F32, {1, 2, 2}), {input}, "Add1ToValues")); diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index d4b3aac85bff283515088f6e61c9d2bad11f60d3..5f234f36a8543ad408fb3430b27844beb16a54b5 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index acba67491d25007ab774530fd7ca236a4363b6f0..2db6503afab748d7b778e26b2f9350ac64c7778b 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -171,7 +171,7 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) { XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({3.14f, -100.25f}); + LiteralUtil::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0"); diff --git a/tensorflow/compiler/xla/tests/deep_graph_test.cc b/tensorflow/compiler/xla/tests/deep_graph_test.cc index 810947ab01b69b10b6ae60c551bd7aba10a6313d..3f3e8ab712fea14be9e4a7015effdf8ce518309b 100644 --- a/tensorflow/compiler/xla/tests/deep_graph_test.cc +++ b/tensorflow/compiler/xla/tests/deep_graph_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" namespace xla { diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 33d79aebb18b353ae26ca4b7006d4ce8356c2c1f..0e9e92ed996fbb34826d19b670c7c4920a1aad13 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -67,15 +67,16 @@ XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { XlaOp param; auto param_data = CreateParameterAndTransferLiteral( 0, - *Literal::MakeTuple({Literal::CreateR2({{1, 2}, {3, 4}}).get(), - Literal::CreateR2({{5, 6}, {7, 8}}).get()}), + *LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1, 2}, {3, 4}}).get(), + LiteralUtil::CreateR2({{5, 6}, {7, 8}}).get()}), "arg0", &builder, ¶m); auto lhs = GetTupleElement(param, 0); auto rhs = GetTupleElement(param, 1); Dot(lhs, rhs); ComputeAndCompareLiteral(&builder, - *Literal::CreateR2({{19, 22}, {43, 50}}), + *LiteralUtil::CreateR2({{19, 22}, {43, 50}}), {param_data.get()}); } @@ -110,7 +111,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) { this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, OneElementVectorDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto lhs = ConstantR1(&builder, {static_cast(2.0f)}); @@ -136,7 +137,7 @@ std::vector MinorToMajorForIsRowMajor(bool row_major) { return {row_major ? 1 : 0, row_major ? 0 : 1}; } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x0) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto lhs = ConstantR2FromArray2D(&builder, Array2D(0, 2)); @@ -147,7 +148,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) { this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x3) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto lhs = ConstantR2FromArray2D(&builder, Array2D(0, 2)); @@ -159,7 +160,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) { this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_3x2_2x0) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto lhs = ConstantR2FromArray2D( @@ -171,7 +172,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) { this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_2x0_0x2) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto lhs = ConstantR2FromArray2D(&builder, Array2D(2, 0)); @@ -182,7 +183,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) { &builder, Array2D(2, 2, static_cast(0.0f)), {}, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) { using T = TypeParam; XlaBuilder builder(this->TestName()); auto param0 = @@ -194,11 +195,11 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) { auto lhs_handle = this->client_ - ->TransferToServer(*Literal::CreateR2FromArray2D( + ->TransferToServer(*LiteralUtil::CreateR2FromArray2D( {{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}})) .ConsumeValueOrDie(); auto rhs_handle = this->client_ - ->TransferToServer(*Literal::CreateR2FromArray2D( + ->TransferToServer(*LiteralUtil::CreateR2FromArray2D( {{1.0f}, {2.0f}, {3.0f}, {4.0f}})) .ConsumeValueOrDie(); @@ -217,14 +218,14 @@ class SquareMatrixDot : public DotOperationTest { void TestImpl(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*Literal::CreateFromArrayWithLayout( + ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 2.0f}, {3.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*Literal::CreateFromArrayWithLayout( + ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 6.0f}, {7.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(rhs_row_major)))) @@ -286,9 +287,10 @@ void ParametricDotTest::TestImpl() { std::unique_ptr> dot_lhs_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.k); - std::unique_ptr dot_lhs_lit = Literal::CreateR2FromArray2DWithLayout( - *dot_lhs_data, LayoutUtil::MakeLayout( - MinorToMajorForIsRowMajor(param.dot_lhs_row_major))); + std::unique_ptr dot_lhs_lit = + LiteralUtil::CreateR2FromArray2DWithLayout( + *dot_lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor( + param.dot_lhs_row_major))); std::unique_ptr dot_lhs_handle = client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie(); @@ -297,7 +299,7 @@ void ParametricDotTest::TestImpl() { Layout rhs_layout = LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.dot_rhs_row_major)); std::unique_ptr dot_rhs_lit = - Literal::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout); + LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout); std::unique_ptr dot_rhs_handle = client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie(); @@ -307,7 +309,7 @@ void ParametricDotTest::TestImpl() { if (param.has_addend) { addend_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.n); - addend_lit = Literal::CreateR2FromArray2DWithLayout( + addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout( *addend_data, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.addend_row_major))); addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie(); @@ -476,14 +478,14 @@ class NonsquareMatrixDot : public DotOperationTest { void TestImpl(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*Literal::CreateFromArrayWithLayout( + ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*Literal::CreateFromArrayWithLayout( + ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(rhs_row_major)))) @@ -510,12 +512,12 @@ XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); } XLA_TEST_F(DotOperationTest, MatrixVectorC64) { auto lhs_handle = client_ - ->TransferToServer(*Literal::CreateR2WithLayout( + ->TransferToServer(*LiteralUtil::CreateR2WithLayout( {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0}))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*Literal::CreateR2WithLayout( + ->TransferToServer(*LiteralUtil::CreateR2WithLayout( {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))) .ConsumeValueOrDie(); @@ -531,7 +533,7 @@ XLA_TEST_F(DotOperationTest, MatrixVectorC64) { &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, ConcurrentMatMult) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ConcurrentMatMult) { using T = TypeParam; XlaBuilder builder(this->TestName()); @@ -583,7 +585,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); auto x_data = this->client_ - ->TransferToServer(*Literal::CreateR4FromArray4D( + ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( {{{{1000.0f, 100.0f}, {10.0f, 1.0f}}, {{2000.0f, 200.0f}, {20.0f, 2.0f}}}, {{{3000.0f, 300.0f}, {30.0f, 3.0f}}, @@ -591,7 +593,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { .ConsumeValueOrDie(); auto y_data = this->client_ - ->TransferToServer(*Literal::CreateR4FromArray4D( + ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, {{{11.0f, 22.0f}, {33.0f, 44.0f}}, {{55.0f, 66.0f}, {77.0f, 88.0f}}}})) @@ -610,7 +612,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { {x_data.get(), y_data.get()}, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) { using T = TypeParam; XlaBuilder builder(this->TestName()); @@ -629,13 +631,13 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { auto x_data = this->client_ - ->TransferToServer(*Literal::CreateR3FromArray3D( + ->TransferToServer(*LiteralUtil::CreateR3FromArray3D( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}})) .ConsumeValueOrDie(); auto y_data = this->client_ - ->TransferToServer(*Literal::CreateR3FromArray3D( + ->TransferToServer(*LiteralUtil::CreateR3FromArray3D( {{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}})) .ConsumeValueOrDie(); @@ -646,7 +648,49 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { {x_data.get(), y_data.get()}, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) { + using T = TypeParam; + + XlaBuilder builder(this->TestName()); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), + "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), + "y"); + + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(3); + dnums.add_rhs_contracting_dimensions(2); + dnums.add_lhs_batch_dimensions(0); + dnums.add_lhs_batch_dimensions(1); + dnums.add_rhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(1); + + DotGeneral(x, y, dnums); + + auto x_data = + this->client_ + ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + {{{9.0f, 10.0f}, {11.0f, 12.0f}}, + {{13.0f, 14.0f}, {15.0f, 16.0f}}}})) + .ConsumeValueOrDie(); + + auto y_data = + this->client_ + ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + {{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}, + {{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}})) + .ConsumeValueOrDie(); + + this->template ComputeAndCompareR4( + &builder, + /*expected=*/ + {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + {{{10.0f, 9.0f}, {12.0f, 11.0f}}, {{14.0f, 13.0f}, {16.0f, 15.0f}}}}, + {x_data.get(), y_data.get()}, this->error_spec_); +} + +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) { using T = TypeParam; for (bool transpose_lhs : {false, true}) { for (bool transpose_rhs : {false, true}) { @@ -664,15 +708,17 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { } auto lhs_handle = this->client_ - ->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( - *lhs, LayoutUtil::MakeLayout( - MinorToMajorForIsRowMajor(row_major)))) + ->TransferToServer( + *LiteralUtil::CreateR2FromArray2DWithLayout( + *lhs, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); auto rhs_handle = this->client_ - ->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( - *rhs, LayoutUtil::MakeLayout( - MinorToMajorForIsRowMajor(row_major)))) + ->TransferToServer( + *LiteralUtil::CreateR2FromArray2DWithLayout( + *rhs, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); XlaBuilder builder(this->TestName()); @@ -704,7 +750,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { } } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, DotOfConcatOptimizationWithConstLHS) { using T = TypeParam; auto prim_type = primitive_util::NativeToPrimitiveType(); @@ -733,15 +779,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_0_value_array))); + *LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_1_value_array))); + *LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_2_value_array))); + *LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); Array2D expected({{53.0f, 74.0f}, {45.0f, 66.0f}}); this->template ComputeAndCompareR2( @@ -750,7 +796,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, this->error_spec_); } -XLA_TYPED_TEST(DotOperationTest_F16F32F64, +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, DotOfConcatOptimizationWithConstRHS) { using T = TypeParam; std::unique_ptr> constant_rhs_array( @@ -782,15 +828,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_0_value_array))); + *LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_1_value_array))); + *LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, this->client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_2_value_array))); + *LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); Array2D expected({{38.0f, 36.0f}, {93.0f, 91.0f}}); this->template ComputeAndCompareR2( @@ -853,10 +899,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) { ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstRHSReverseMM)))) { + + DotOfGatherOptimizationWithConstRHSReverseMM) { std::unique_ptr> constant_lhs_array( new Array2D({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, @@ -883,10 +928,7 @@ XLA_TEST_F(DotOperationTest, ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstLHSReverseMM)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSReverseMM) { std::unique_ptr> constant_lhs_array( new Array2D({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, @@ -913,10 +955,7 @@ XLA_TEST_F(DotOperationTest, ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstRHSRows)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSRows) { std::unique_ptr> constant_lhs_array( new Array2D({{1.0, 2.0}, {3.0, 4.0}, @@ -948,10 +987,7 @@ XLA_TEST_F(DotOperationTest, ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstLHSRows)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSRows) { std::unique_ptr> constant_lhs_array( new Array2D({{1.0, 2.0}, {3.0, 4.0}, @@ -983,10 +1019,7 @@ XLA_TEST_F(DotOperationTest, ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstRHSCols)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSCols) { std::unique_ptr> constant_lhs_array(new Array2D( {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); std::unique_ptr> constant_rhs_array( @@ -1010,10 +1043,7 @@ XLA_TEST_F(DotOperationTest, ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -// TODO (b/69062148) Enable when Dot implements general contracting dimensions. -XLA_TEST_F(DotOperationTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( - DotOfGatherOptimizationWithConstLHSCols)))) { +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSCols) { std::unique_ptr> constant_lhs_array(new Array2D( {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); std::unique_ptr> constant_rhs_array( @@ -1036,5 +1066,28 @@ XLA_TEST_F(DotOperationTest, Array2D expected({{168.0}, {168.0}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } + +XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) { + XlaBuilder builder(TestName()); + + Array2D lhs_array({{1.0f, 2.0f}, {3.0f, 4.0f}}); + auto lhs_constant = ConstantR2FromArray2D(&builder, lhs_array); + + Array2D rhs_array({{5.0f, 6.0f}, {7.0f, 8.0f}}); + auto rhs_constant = ConstantR2FromArray2D(&builder, rhs_array); + + Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + DotGeneral(lhs_constant, rhs_constant, dot_dnums); + + Array2D expected({ + {26.f, 30.f}, + {38.f, 44.f}, + }); + + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index f3c258a4d4c446c465320ac16ef7c72e299a51a8..7f6f203a1ba48e0053f799c58bbbeae87aef1f7f 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" @@ -124,11 +124,11 @@ class DynamicSliceTest : public ClientLibraryTestBase { // vector is special so that it cannot be an ArraySlice, which // is what the code below wants. So instead we do this. Literal input_values = - std::move(*Literal::CreateR1(input_values_int) + std::move(*LiteralUtil::CreateR1(input_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal expected_values = - std::move(*Literal::CreateR1(expected_values_int) + std::move(*LiteralUtil::CreateR1(expected_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); @@ -150,11 +150,11 @@ class DynamicSliceTest : public ClientLibraryTestBase { const std::vector& slice_sizes, const Array2D& expected_values_int) { Literal input_values = - std::move(*Literal::CreateR2FromArray2D(input_values_int) + std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal expected_values = - std::move(*Literal::CreateR2FromArray2D(expected_values_int) + std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); @@ -176,11 +176,11 @@ class DynamicSliceTest : public ClientLibraryTestBase { const std::vector& slice_sizes, const Array3D& expected_values_int) { Literal input_values = - std::move(*Literal::CreateR3FromArray3D(input_values_int) + std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal expected_values = - std::move(*Literal::CreateR3FromArray3D(expected_values_int) + std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); @@ -202,18 +202,28 @@ XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R1OOB) { TestR1OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1(); } XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1(); } +XLA_TEST_F(DynamicSliceTest, UInt32R1OOB) { + RunR1({0, 1, 2, 3, 4}, {2147483648u}, {2}, {3, 4}); +} XLA_TEST_F(DynamicSliceTest, Int32R2BF16) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R2OOB) { TestR2OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2(); } +XLA_TEST_F(DynamicSliceTest, UInt32R2OOB) { + RunR2({{0, 1}, {2, 3}}, {2147483648u, 0}, {1, 1}, {{2}}); +} XLA_TEST_F(DynamicSliceTest, Int32R3BF16) { TestR3(); } XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3(); } XLA_TEST_F(DynamicSliceTest, Int32R3OOB) { TestR3OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3(); } XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3(); } +XLA_TEST_F(DynamicSliceTest, UInt32R3OOB) { + RunR3({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}, + {2147483648u, 0, 2147483648u}, {1, 1, 1}, {{{5}}}); +} XLA_TEST_F(DynamicSliceTest, Int32R1Pred) { // Slice at dimension start. @@ -349,15 +359,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { void RunR0(int input_value_int, int update_value_int, const std::vector slice_starts, int expected_value_int) { Literal input_value = - std::move(*Literal::CreateR0(input_value_int) + std::move(*LiteralUtil::CreateR0(input_value_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal update_value = - std::move(*Literal::CreateR0(update_value_int) + std::move(*LiteralUtil::CreateR0(update_value_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal expected_value = - std::move(*Literal::CreateR0(expected_value_int) + std::move(*LiteralUtil::CreateR0(expected_value_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); @@ -380,15 +390,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, tensorflow::gtl::ArraySlice expected_values_int) { Literal input_values = - std::move(*Literal::CreateR1(input_values_int) + std::move(*LiteralUtil::CreateR1(input_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal update_values = - std::move(*Literal::CreateR1(update_values_int) + std::move(*LiteralUtil::CreateR1(update_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal expected_values = - std::move(*Literal::CreateR1(expected_values_int) + std::move(*LiteralUtil::CreateR1(expected_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); @@ -411,15 +421,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, const Array2D& expected_values_int) { Literal input_values = - std::move(*Literal::CreateR2FromArray2D(input_values_int) + std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal update_values = - std::move(*Literal::CreateR2FromArray2D(update_values_int) + std::move(*LiteralUtil::CreateR2FromArray2D(update_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal expected_values = - std::move(*Literal::CreateR2FromArray2D(expected_values_int) + std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); @@ -442,15 +452,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, const Array3D& expected_values_int) { Literal input_values = - std::move(*Literal::CreateR3FromArray3D(input_values_int) + std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal update_values = - std::move(*Literal::CreateR3FromArray3D(update_values_int) + std::move(*LiteralUtil::CreateR3FromArray3D(update_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); Literal expected_values = - std::move(*Literal::CreateR3FromArray3D(expected_values_int) + std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); @@ -520,7 +530,7 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { template void DumpArray(const string& name, const Array3D values) { std::unique_ptr literal = - Literal::CreateR3FromArray3D(values); + LiteralUtil::CreateR3FromArray3D(values); LOG(INFO) << name << ":" << literal->ToString(); } }; @@ -530,21 +540,32 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R0) { TestR0(); } XLA_TEST_F(DynamicUpdateSliceTest, Int64R0) { TestR0(); } XLA_TEST_F(DynamicUpdateSliceTest, UInt64R0) { TestR0(); } -// TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. XLA_TEST_F(DynamicUpdateSliceTest, Int32R1BF16) { TestR1(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1(); } XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1(); } XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1(); } +XLA_TEST_F(DynamicUpdateSliceTest, UInt32R1OOB) { + RunR1({0, 1, 2, 3, 4}, {5, 6}, {2147483648u}, {0, 1, 2, 5, 6}); +} XLA_TEST_F(DynamicUpdateSliceTest, Int32R2BF16) { TestR2(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) { TestR2(); } XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) { TestR2(); } XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) { TestR2(); } +XLA_TEST_F(DynamicUpdateSliceTest, UInt32R2OOB) { + RunR2({{0, 1}, {2, 3}}, {{4}}, {2147483648u, 0}, + {{0, 1}, {4, 3}}); +} XLA_TEST_F(DynamicUpdateSliceTest, Int32R3BF16) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3(); } +XLA_TEST_F(DynamicUpdateSliceTest, UInt32R3OOB) { + RunR3({{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}}, {{{8}}}, + {2147483648u, 0, 2147483648u}, + {{{0, 1}, {2, 3}}, {{4, 8}, {6, 7}}}); +} XLA_TEST_F(DynamicUpdateSliceTest, Int32OOBBF16) { TestOOB(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32OOB) { TestOOB(); } @@ -695,7 +716,7 @@ void BM_DynamicSlice(int num_iters) { XlaBuilder builder("DynamicSlice"); // Create input as a constant: shape [1, 2, 3, 4] - auto input_literal = Literal::CreateR4( + auto input_literal = LiteralUtil::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); auto input = ConstantLiteral(&builder, *input_literal); @@ -715,7 +736,7 @@ void BM_DynamicSlice(int num_iters) { start_indices_shape, &allocator, /*device_ordinal=*/0) .ConsumeValueOrDie(); - auto start_indices_literal = Literal::CreateR1({0, 1, 2, 3}); + auto start_indices_literal = LiteralUtil::CreateR1({0, 1, 2, 3}); auto stream = client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc index ddc6a7db18760bf951023f0a684d78739f3e869d..5116e60ca63ef5f94b25b15e6616086fb9e44bbb 100644 --- a/tensorflow/compiler/xla/tests/execution_profile_test.cc +++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/platform/test.h" @@ -31,7 +31,7 @@ XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr input, client_->TransferToServer( - *Literal::CreateR2F32Linspace(1e0, 1e5, 256, 256))); + *LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256))); XlaBuilder b(TestName() + ".add"); Dot(Parameter(&b, 0, shape, "param_0"), Parameter(&b, 1, shape, "param_1")); diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index 74cf8b213e0a03394c84008e7a2919e1a5bf1af2..bf1de02ba9dbd97db9ee31484402fe9b92385219 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -39,7 +39,7 @@ class ExhaustiveF32ElementwiseOpTest XlaBuilder builder(TestName()); std::unique_ptr input_literal = - Literal::CreateFromDimensions(F32, {input_size}); + LiteralUtil::CreateFromDimensions(F32, {input_size}); for (int64 i = begin; i < end; i++) { if (i >= known_incorrect_range.first && i < known_incorrect_range.second) { diff --git a/tensorflow/compiler/xla/tests/filecheck.cc b/tensorflow/compiler/xla/tests/filecheck.cc index 93d1c921c4a138cda55ed7338b8e3aa82518d114..dcb469087e0064d17ce3b04fdeaf0b6136069a55 100644 --- a/tensorflow/compiler/xla/tests/filecheck.cc +++ b/tensorflow/compiler/xla/tests/filecheck.cc @@ -76,6 +76,11 @@ StatusOr RunFileCheck(const string& input, const string& pattern) { XLA_LOG_LINES(tensorflow::WARNING, input); LOG(WARNING) << "FileCheck pattern was:"; XLA_LOG_LINES(tensorflow::WARNING, pattern); + } else if (!standard_error.empty()) { + LOG(INFO) << "FileCheck stderr:"; + XLA_LOG_LINES(tensorflow::INFO, standard_error); + LOG(INFO) << "FileCheck input was:"; + XLA_LOG_LINES(tensorflow::INFO, input); } return succeeded; } diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc index 30dc639f117b9871238f0bf1628502cf8bef2e0c..39cc6c5927f1d416e31f689487efc10c20371abe 100644 --- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc +++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/fmax_test.cc b/tensorflow/compiler/xla/tests/fmax_test.cc index 0254ae1baaa864b38c3b217a5c2026d34b7f7d12..c5bbbe778df15d63a2586bd6291a7a33fc82aa52 100644 --- a/tensorflow/compiler/xla/tests/fmax_test.cc +++ b/tensorflow/compiler/xla/tests/fmax_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index ab470f16a32c2363e88a11a9f7d564dcf2981f42..792be0d3fcd55621b9f8cdf0fdc28f7bb49294d1 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -25,14 +25,15 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -89,7 +90,7 @@ class FusionTest : public HloTestBase { HloInstruction* hlos[4]; for (int i = 0; i < Arity; ++i) { hlos[i + 1] = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2FromArray2D(operand_data[i]))); + LiteralUtil::CreateR2FromArray2D(operand_data[i]))); } auto answer_shape = ShapeUtil::MakeShape(prim_type, {test_width, test_height}); @@ -115,7 +116,7 @@ class FusionTest : public HloTestBase { ArraySlice(hlos, 0, Arity + 1), HloInstruction::FusionKind::kLoop); - auto expected = Literal::CreateR2FromArray2D(answer_data); + auto expected = LiteralUtil::CreateR2FromArray2D(answer_data); auto actual = ExecuteAndTransfer(std::move(hlo_module), {}); if (primitive_util::IsFloatingPointType(prim_type)) { EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4))); @@ -186,27 +187,28 @@ XLA_TEST_F(FusionTest, Test) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0}, {2.0}, {3.0}}))); + LiteralUtil::CreateR2({{1.0}, {2.0}, {3.0}}))); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{-1.0}, {-1.0}, {-1.0}}))); + LiteralUtil::CreateR2({{-1.0}, {-1.0}, {-1.0}}))); auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {3, 1}), HloOpcode::kAdd, const0, const1)); auto reshape3 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {1, 3}), add2, {1, 0})); auto const4 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.62, 2.72, 3.14}}))); + LiteralUtil::CreateR2({{1.62, 2.72, 3.14}}))); auto concat5 = builder.AddInstruction(HloInstruction::CreateConcatenate( ShapeUtil::MakeShape(F32, {2, 3}), {reshape3, const4}, 0)); auto const6 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}))); + LiteralUtil::CreateR2({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}))); auto negate7 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kNegate, const6)); auto add8 = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kAdd, concat5, negate7)); auto const9 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}}))); - auto const10 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{true, false, true}, {false, true, false}}))); + LiteralUtil::CreateR2({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}}))); + auto const10 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2( + {{true, false, true}, {false, true, false}}))); auto select11 = builder.AddInstruction( HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kSelect, const10, add8, const9)); @@ -222,7 +224,7 @@ XLA_TEST_F(FusionTest, Test) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR2({{0.5}, {2.72}}), + *LiteralUtil::CreateR2({{0.5}, {2.72}}), *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } @@ -233,11 +235,11 @@ XLA_TEST_F(FusionTest, Parameter) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 2.0, 3.0}}))); + LiteralUtil::CreateR2({{1.0, 2.0, 3.0}}))); auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kCopy, const0)); auto const2 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{-2.0, -2.0, -2.0}}))); + LiteralUtil::CreateR2({{-2.0, -2.0, -2.0}}))); // add3 = copy1 + const2 = const0 + const2 = {1,2,3} + {-2,-2,-2} = {-1,0,+1} auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kAdd, copy1, const2)); @@ -248,7 +250,7 @@ XLA_TEST_F(FusionTest, Parameter) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR2({{-1.0, 0.0, 1.0}}), + *LiteralUtil::CreateR2({{-1.0, 0.0, 1.0}}), *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } @@ -269,7 +271,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) { auto hlo_module = CreateNewModule(); auto two = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); auto x = builder.AddInstruction(HloInstruction::CreateBroadcast(shape, two, {})); auto y = builder.AddInstruction( @@ -292,9 +294,9 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.0, 2.0, 3.0}))); + LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto const_array = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}))); + LiteralUtil::CreateR2({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}))); auto broadcast = builder.AddInstruction( HloInstruction::CreateBroadcast(const_array->shape(), const_vector, {1})); // add2 = broadcast(const_vector) + const_array @@ -308,7 +310,7 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *Literal::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), + *LiteralUtil::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } @@ -316,14 +318,14 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto single_element_array = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{5}}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR2({{5}}))); auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {}), single_element_array)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR0(5), + LiteralTestUtil::Equal(*LiteralUtil::CreateR0(5), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -331,14 +333,14 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); + LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 2, 3}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), + *LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -346,14 +348,14 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR3({{{1, 2, 3}, {4, 5, 6}}}))); + LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 2}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}), + *LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -361,14 +363,14 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR3({{{7}}}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR3({{{7}}}))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR0(7), + LiteralTestUtil::Equal(*LiteralUtil::CreateR0(7), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -376,14 +378,14 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(7))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR3({{{7}}}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR3({{{7}}}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -391,14 +393,14 @@ XLA_TEST_F(FusionTest, Reshape__) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(7))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR0(7), + LiteralTestUtil::Equal(*LiteralUtil::CreateR0(7), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -406,14 +408,14 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 3}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), + *LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -421,14 +423,14 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {3, 2}), const0, {1, 0})); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR2({{1, 4}, {2, 5}, {3, 6}}), + *LiteralUtil::CreateR2({{1, 4}, {2, 5}, {3, 6}}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -436,14 +438,14 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {3, 3}), const0, {1, 0})); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), + *LiteralUtil::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -451,7 +453,7 @@ XLA_TEST_F(FusionTest, Reverse) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( ShapeUtil::MakeShape(S32, {3}), const0, {0})); hlo_module->AddEntryComputation(builder.Build()) @@ -459,7 +461,7 @@ XLA_TEST_F(FusionTest, Reverse) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR1({3, 2, 1}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR1({3, 2, 1}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -467,7 +469,7 @@ XLA_TEST_F(FusionTest, ReverseNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( ShapeUtil::MakeShape(S32, {3}), const0, {0})); auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -477,7 +479,7 @@ XLA_TEST_F(FusionTest, ReverseNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR1({-3, -2, -1}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-3, -2, -1}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -485,7 +487,7 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(S32, {2}), const0, {})); auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -495,15 +497,15 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR1({-1, -1}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-1, -1}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, SliceNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); - auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4}))); auto slice1 = builder.AddInstruction(HloInstruction::CreateSlice( ShapeUtil::MakeShape(S32, {2}), const0, {0}, {4}, {2})); auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -513,17 +515,17 @@ XLA_TEST_F(FusionTest, SliceNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR1({-1, -3}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-1, -3}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DynamicSliceNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); - auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1}))); auto dynamic_slice2 = builder.AddInstruction(HloInstruction::CreateDynamicSlice( ShapeUtil::MakeShape(S32, {2}), const0, const1, {2})); @@ -535,15 +537,15 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR1({-2, -3}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-2, -3}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReshapeNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); - auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 3, 4}))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {2, 2}), const0)); auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -552,16 +554,16 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1}, HloInstruction::FusionKind::kLoop); - EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{-1, -2}, {-3, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{-1, -2}, {-3, -4}}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, TransposeNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2}, {3, 4}}))); + LiteralUtil::CreateR2({{1, 2}, {3, 4}}))); auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {2, 2}), const0, {1, 0})); auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -570,9 +572,9 @@ XLA_TEST_F(FusionTest, TransposeNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1}, HloInstruction::FusionKind::kLoop); - EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR2({{-1, -3}, {-2, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR2({{-1, -3}, {-2, -4}}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } std::unique_ptr MakeReduceTestComputation() { @@ -590,10 +592,10 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { auto hlo_module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 4, 8}))); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 4, 8}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce( ShapeUtil::MakeShape(S32, {}), const0, const1, {0}, hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); @@ -602,7 +604,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR0(15), + LiteralTestUtil::Equal(*LiteralUtil::CreateR0(15), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -610,10 +612,10 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { auto hlo_module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 4, 8}))); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2, 4, 8}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce( ShapeUtil::MakeShape(S32, {}), const0, const1, {0}, hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); @@ -624,7 +626,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR0(-15), + LiteralTestUtil::Equal(*LiteralUtil::CreateR0(-15), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -632,9 +634,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}}))); + LiteralUtil::CreateR2({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); Window window; ASSERT_TRUE( tensorflow::protobuf::TextFormat::ParseFromString("dimensions:{\n" @@ -674,7 +676,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::CreateR2({{462, 2145}, {24871, 62491}}), + *LiteralUtil::CreateR2({{462, 2145}, {24871, 62491}}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -686,9 +688,9 @@ XLA_TEST_F(FusionTest, SharedConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0)); auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -710,7 +712,7 @@ XLA_TEST_F(FusionTest, SharedConstant) { EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6); EXPECT_TRUE( - LiteralTestUtil::Equal(*Literal::CreateR1({8}), + LiteralTestUtil::Equal(*LiteralUtil::CreateR1({8}), *ExecuteAndTransfer(std::move(hlo_module), {}))); } @@ -764,6 +766,79 @@ XLA_TEST_F(FusionTest, Clamp2D) { TestElementwise2D(HloOpcode::kClamp); } +// TODO(b/73903144): Enable on interpreter once interpreter supports bitcast. +XLA_TEST_F(FusionTest, DISABLED_ON_INTERPRETER(FusionWithLayout)) { + const string hlo_text = R"( +HloModule Cluster + +fusion_c { + fusion.arg = f32[2,2]{1,0} parameter(0) + bitcast.0 = f32[2,2,1]{2,1,0} bitcast(fusion.arg) + tanh.0 = f32[2,2,1]{0,2,1} tanh(bitcast.0) + ROOT bitcast.2 = f32[2,2,1]{1,2,0} bitcast(tanh.0) +} + +ENTRY main { + arg = f32[2,2]{1,0} parameter(0) + ROOT fusion = f32[2,2,1]{1,2,0} fusion(arg), kind=kLoop, calls=fusion_c +} +)"; + + std::unique_ptr operand = + LiteralUtil::CreateR2({{0., 0.}, {1., 0.}}); + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text, config)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + test_runner_.Execute(std::move(module), {operand.get()}, + /*run_hlo_passes=*/false)); + EXPECT_TRUE(LiteralTestUtil::Equal( + *LiteralUtil::CreateR3({{{0.}, {0.76159415595}}, {{0.}, {0.}}}), + *result)); +} + +class FusionClientLibraryTest : public ClientLibraryTestBase {}; + +XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) { + // On the GPU backend, it's possible to have too many transposes within one + // fusion, causing the kernel to run out shared memory and thus not compile. + // We want to check that doesn't happen. + // + // To do this, we create a computation that computes + // + // P0 + P0*P1*P1 + P0*P2*P2 ... + // + // where even parameters have layout 1 and odd parameters have layout 2. + // + // Our goal is to tempt the backend into creating one giant multi-output + // fusion for the whole computation, including the transposes. Currently + // multi-output fusion only fuses fusions, so each of the terms in the sum + // needs to be a fusion itself, thus the contortions above. + constexpr int kNumParams = 25; + XlaBuilder b("ManyLayoutTransformations"); + + // This test produces values that overflow int32, which is UB, so use uint32, + // where overflow is OK. + Array2D arr(32, 32); + arr.FillUnique(); + std::unique_ptr l1 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout( + LayoutUtil::MakeLayout({0, 1})); + + std::unique_ptr l2 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout( + LayoutUtil::MakeLayout({1, 0})); + + XlaOp p0 = AddParam(*l1, &b); + XlaOp sum = p0; + for (int i = 1; i < kNumParams; ++i) { + auto pN = AddParam((i % 2 == 0 ? *l1 : *l2), &b); + sum = sum + p0 * pN * pN; + } + + ComputeAndCompare(&b, {}); +} + void BM_ParallelFusion(int num_iters) { // Simple element-wise computation to benchmark parallel task partitioning. tensorflow::testing::StopTiming(); @@ -804,19 +879,19 @@ void BM_ParallelFusion(int num_iters) { // Transfer literals to device. auto param0_literal = - Literal::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1); + LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1); ScopedShapedBuffer buffer0 = client->LiteralToShapedBuffer(*param0_literal, device_ordinal) .ConsumeValueOrDie(); auto param1_literal = - Literal::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1); + LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1); ScopedShapedBuffer buffer1 = client->LiteralToShapedBuffer(*param1_literal, device_ordinal) .ConsumeValueOrDie(); auto param2_literal = - Literal::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1); + LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1); ScopedShapedBuffer buffer2 = client->LiteralToShapedBuffer(*param2_literal, device_ordinal) .ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index b8404826b161b9edbbd260d73c175cce935ace91..b77bece85ad1b2192b04330af9e60d3a424b59f4 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -22,9 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -// NB! TODO(b/74360564): These tests do not test out of bounds behavior since -// that hasn't been specced yet. - namespace xla { namespace { @@ -63,8 +60,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({0, 2}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -84,8 +82,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({0, 2}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -105,9 +104,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 2}, {2, 1}}); + LiteralUtil::CreateR2({{0, 2}, {2, 1}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -127,9 +126,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -149,9 +148,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -171,11 +170,11 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -195,11 +194,11 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -219,8 +218,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({1, 1}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -240,9 +240,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR2({{2, 1}, {1, 1}}); + LiteralUtil::CreateR2({{2, 1}, {1, 1}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -261,18 +261,15 @@ ENTRY main { window_bounds={1, 0} } )"; - std::unique_ptr operand = Literal::CreateR2({{}, {}, {}}); - std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({0, 2}); RunTest(hlo_text, operand.get(), gather_indices.get()); } XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) { // Out of bounds indices must not crash, and the indices in range should // produce the same values across all backends. - // - // TODO(b/74360564): Once we have a well defined semantics for OOB accesses, - // we should get rid of the mask and check that backends produce the same - // value for OOB indices too. const string hlo_text = R"( HloModule BatchDynamicSlice @@ -286,29 +283,45 @@ ENTRY main { gather_dims_to_operand_dims={0,1}, index_vector_dim=1, window_bounds={1,1} - gather_reshaped = s32[6]{0} reshape(gather) - in_bounds_mask = s32[6]{0} parameter(2) - ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask) + ROOT result = s32[6]{0} reshape(gather) } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR2( + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); - std::unique_ptr in_bounds_mask = - Literal::CreateR1({0, 1, 1, 0, 0, 1}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) { + // Out of bounds indices must not crash, and the indices in range should + // produce the same values across all backends. - RunTest(hlo_text, - {operand.get(), gather_indices.get(), in_bounds_mask.get()}); + const string hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = u32[6,2]{1,0} parameter(1) + gather = s32[6,1,1]{2,1,0} gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1} + ROOT result = s32[6]{0} reshape(gather) +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = LiteralUtil::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); } XLA_TEST_F(GatherOperationTest, NegativeIndex) { // Negative indices must not crash, and the indices in range should produce // the same values across all backends. - // - // TODO(b/74360564): Once we have a well defined semantics for negative - // accesses, we should get rid of the mask and check that backends produce the - // same value for negative indices too. const string hlo_text = R"( HloModule BatchDynamicSlice @@ -322,20 +335,40 @@ ENTRY main { gather_dims_to_operand_dims={0,1}, index_vector_dim=1, window_bounds={1,1} - gather_reshaped = s32[6]{0} reshape(gather) - in_bounds_mask = s32[6]{0} parameter(2) - ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask) + ROOT result = s32[6]{0} reshape(gather) } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR2( + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = LiteralUtil::CreateR2( {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - std::unique_ptr in_bounds_mask = - Literal::CreateR1({0, 1, 1, 0, 0, 1}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) { + // Negative indices must not crash, and the indices in range should produce + // the same values across all backends. - RunTest(hlo_text, - {operand.get(), gather_indices.get(), in_bounds_mask.get()}); + const string hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = u32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + gather = u32[6,1,1]{2,1,0} gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1} + ROOT result = u32[6]{0} reshape(gather) +} +)"; + std::unique_ptr operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = LiteralUtil::CreateR2( + {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); } XLA_TEST_F(GatherOperationTest, OneScalarIndex) { @@ -353,9 +386,9 @@ ENTRY main { window_bounds={1,3,2} } )"; - std::unique_ptr operand = Literal::CreateR3( + std::unique_ptr operand = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - std::unique_ptr gather_indices = Literal::CreateR0(1); + std::unique_ptr gather_indices = LiteralUtil::CreateR0(1); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -374,8 +407,8 @@ ENTRY main { window_bounds={1} } )"; - std::unique_ptr operand = Literal::CreateR1({1, 2, 3, 4}); - std::unique_ptr gather_indices = Literal::CreateR0(1); + std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3, 4}); + std::unique_ptr gather_indices = LiteralUtil::CreateR0(1); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -395,8 +428,8 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = LiteralUtil::CreateR1({}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -419,8 +452,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({0, 2}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -443,9 +477,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 2}, {2, 1}}); + LiteralUtil::CreateR2({{0, 2}, {2, 1}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -468,9 +502,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -493,11 +527,11 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -521,11 +555,11 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // - {{-4, 4}, {-5, 5}, {-6, 6}}, // - {{-7, 7}, {-8, 8}, {-9, 9}}}); + LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = - Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralUtil::CreateR2({{0, 0}, {1, 0}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -548,8 +582,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + LiteralUtil::CreateR1({1, 1}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -572,9 +607,9 @@ ENTRY main { } )"; std::unique_ptr operand = - Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = - Literal::CreateR2({{2, 1}, {1, 1}}); + LiteralUtil::CreateR2({{2, 1}, {1, 1}}); RunTest(hlo_text, operand.get(), gather_indices.get()); } @@ -609,12 +644,13 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { Gather(operand, indices, dim_numbers, {1, 3}); std::vector expected = {}; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr operand_arg, - client_->TransferToServer(*Literal::CreateR2( - {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr operand_arg, + client_->TransferToServer( + *LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr indices_arg, - client_->TransferToServer(*Literal::CreateR1({0, 2}))); + client_->TransferToServer(*LiteralUtil::CreateR1({0, 2}))); TF_ASSERT_OK_AND_ASSIGN(std::vector devices, client_->GetDeviceHandles(1)); xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions(); diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc index 76bf47845ca045b4eede9a3b47ae5c2ce93ce577..51450314b611b49c643fb6fd5b0c0d2e7205a2d2 100644 --- a/tensorflow/compiler/xla/tests/half_test.cc +++ b/tensorflow/compiler/xla/tests/half_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -37,8 +37,7 @@ class HalfTestBase : public ClientLibraryTestBase { static const int kNumElements = 4; }; -using UnaryBuildFuncTy = - std::function; +using UnaryBuildFuncTy = std::function; struct UnaryOpTestParam { std::function compute_func; @@ -49,7 +48,8 @@ class UnaryOpTest : public HalfTestBase, public ::testing::WithParamInterface {}; XLA_TEST_P(UnaryOpTest, Ops) { - std::vector x({half(1.4), half(-2.3), half(3.2), half(-4.1)}); + std::vector x({half(1.4), half(-2.3), half(3.2), half(-4.1), half(9.0), + half(42.0), half(-9.0), half(-100.0)}); XlaBuilder builder(TestName()); XlaOp x_opnd; auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", @@ -62,7 +62,7 @@ XLA_TEST_P(UnaryOpTest, Ops) { } UnaryBuildFuncTy build_func = GetParam().build_func; - build_func(&builder, x_opnd); + build_func(x_opnd); ComputeAndCompareR1(&builder, expected, {x_data.get()}, error_spec_); } @@ -79,18 +79,17 @@ half round_imp(half value) { INSTANTIATE_TEST_CASE_P( half, UnaryOpTest, ::testing::Values( - UnaryOpTestParam{[](half x) { return abs(x); }, &XlaBuilder::Abs}, - UnaryOpTestParam{[](half x) { return round_imp(x); }, - &XlaBuilder::Round}, - UnaryOpTestParam{[](half x) { return ceil(x); }, &XlaBuilder::Ceil}, - UnaryOpTestParam{[](half x) { return cos(x); }, &XlaBuilder::Cos}, - UnaryOpTestParam{[](half x) { return exp(x); }, &XlaBuilder::Exp}, - UnaryOpTestParam{[](half x) { return floor(x); }, &XlaBuilder::Floor}, - UnaryOpTestParam{[](half x) { return log(x); }, &XlaBuilder::Log}, - UnaryOpTestParam{[](half x) { return -x; }, &XlaBuilder::Neg}, - UnaryOpTestParam{[](half x) { return sign_imp(x); }, &XlaBuilder::Sign}, - UnaryOpTestParam{[](half x) { return sin(x); }, &XlaBuilder::Sin}, - UnaryOpTestParam{[](half x) { return tanh(x); }, &XlaBuilder::Tanh} + UnaryOpTestParam{[](half x) { return abs(x); }, &Abs}, + UnaryOpTestParam{[](half x) { return round_imp(x); }, &Round}, + UnaryOpTestParam{[](half x) { return ceil(x); }, &Ceil}, + UnaryOpTestParam{[](half x) { return cos(x); }, &Cos}, + UnaryOpTestParam{[](half x) { return exp(x); }, &Exp}, + UnaryOpTestParam{[](half x) { return floor(x); }, &Floor}, + UnaryOpTestParam{[](half x) { return log(x); }, &Log}, + UnaryOpTestParam{[](half x) { return -x; }, &Neg}, + UnaryOpTestParam{[](half x) { return sign_imp(x); }, &Sign}, + UnaryOpTestParam{[](half x) { return sin(x); }, &Sin}, + UnaryOpTestParam{[](half x) { return tanh(x); }, &Tanh} )); @@ -118,19 +117,18 @@ XLA_TEST_P(UnaryPredTest, Ops) { } UnaryBuildFuncTy build_func = GetParam().build_func; - build_func(&builder, x_opnd); + build_func(x_opnd); ComputeAndCompareR1(&builder, expected, {x_data.get()}); } INSTANTIATE_TEST_CASE_P(half, UnaryPredTest, ::testing::Values(UnaryPredTestParam{ - [](half x) { return isfinite(x); }, - &XlaBuilder::IsFinite})); + [](half x) { return isfinite(x); }, &IsFinite})); -using BinaryBuildFuncTy = std::function)>; +using BinaryBuildFuncTy = + std::function)>; struct BinaryOpTestParam { std::function compute_func; @@ -159,7 +157,7 @@ XLA_TEST_P(BinaryOpTest, Ops) { } BinaryBuildFuncTy build_func = GetParam().build_func; - build_func(&builder, x_opnd, y_opnd, {}); + build_func(x_opnd, y_opnd, {}); ComputeAndCompareR1(&builder, expected, {x_data.get(), y_data.get()}, error_spec_); @@ -173,22 +171,15 @@ half atan2_imp(half x, half y) { INSTANTIATE_TEST_CASE_P( half, BinaryOpTest, ::testing::Values( - BinaryOpTestParam{[](half x, half y) { return x + y; }, - &XlaBuilder::Add}, + BinaryOpTestParam{[](half x, half y) { return x + y; }, &Add}, BinaryOpTestParam{[](half x, half y) { return atan2_imp(x, y); }, - &XlaBuilder::Atan2}, - BinaryOpTestParam{[](half x, half y) { return x / y; }, - &XlaBuilder::Div}, - BinaryOpTestParam{[](half x, half y) { return max(x, y); }, - &XlaBuilder::Max}, - BinaryOpTestParam{[](half x, half y) { return min(x, y); }, - &XlaBuilder::Min}, - BinaryOpTestParam{[](half x, half y) { return x * y; }, - &XlaBuilder::Mul}, - BinaryOpTestParam{[](half x, half y) { return pow(x, y); }, - &XlaBuilder::Pow}, - BinaryOpTestParam{[](half x, half y) { return x - y; }, - &XlaBuilder::Sub} + &Atan2}, + BinaryOpTestParam{[](half x, half y) { return x / y; }, &Div}, + BinaryOpTestParam{[](half x, half y) { return max(x, y); }, &Max}, + BinaryOpTestParam{[](half x, half y) { return min(x, y); }, &Min}, + BinaryOpTestParam{[](half x, half y) { return x * y; }, &Mul}, + BinaryOpTestParam{[](half x, half y) { return pow(x, y); }, &Pow}, + BinaryOpTestParam{[](half x, half y) { return x - y; }, &Sub} )); @@ -221,27 +212,22 @@ XLA_TEST_P(BinaryPredTest, Ops) { } BinaryBuildFuncTy build_func = GetParam().build_func; - build_func(&builder, x_opnd, y_opnd, {}); + build_func(x_opnd, y_opnd, {}); ComputeAndCompareR1(&builder, expected, {x_data.get(), y_data.get()}); } INSTANTIATE_TEST_CASE_P( half, BinaryPredTest, - ::testing::Values(BinaryPredTestParam{[](half x, half y) { return x == y; }, - &XlaBuilder::Eq}, - BinaryPredTestParam{[](half x, half y) { return x != y; }, - &XlaBuilder::Ne}, - BinaryPredTestParam{[](half x, half y) { return x >= y; }, - &XlaBuilder::Ge}, - BinaryPredTestParam{[](half x, half y) { return x > y; }, - &XlaBuilder::Gt}, - BinaryPredTestParam{[](half x, half y) { return x <= y; }, - &XlaBuilder::Le}, - BinaryPredTestParam{[](half x, half y) { return x < y; }, - &XlaBuilder::Lt} - - )); + ::testing::Values( + BinaryPredTestParam{[](half x, half y) { return x == y; }, &Eq}, + BinaryPredTestParam{[](half x, half y) { return x != y; }, &Ne}, + BinaryPredTestParam{[](half x, half y) { return x >= y; }, &Ge}, + BinaryPredTestParam{[](half x, half y) { return x > y; }, &Gt}, + BinaryPredTestParam{[](half x, half y) { return x <= y; }, &Le}, + BinaryPredTestParam{[](half x, half y) { return x < y; }, &Lt} + + )); } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc index 4d82442f7e3630c115eff1f17544e2b892c5e7eb..5511190caf95544e2ac48d91c0a138db06a2544c 100644 --- a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc +++ b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 242cc5db11ff2bdf69209df7537216573d8afbf3..0dce1b22a331e20054ba026f23e3284d7dd0e88a 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -233,6 +233,29 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( reference_preprocessor); } +::testing::AssertionResult HloTestBase::Run(const StringPiece hlo_string) { + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_or_status.status().ToString(); + } + const auto& fake_arguments = + MakeFakeArguments(module_or_status.ValueOrDie().get()) + .ConsumeValueOrDie(); + std::vector fake_argument_ptrs; + c_transform( + fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const std::unique_ptr& literal) { return literal.get(); }); + return test_runner_ + .Execute(std::move(module_or_status.ValueOrDie()), + fake_argument_ptrs, /*run_hlo_passes=*/true) + .ok() + ? ::testing::AssertionSuccess() + : ::testing::AssertionFailure(); +} + ::testing::AssertionResult HloTestBase::RunAndCompareFromFile( const string& filename, const tensorflow::gtl::optional& error, const std::function& reference_preprocessor) { @@ -276,9 +299,10 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( HloComputation* HloTestBase::FindComputation(HloModule* module, tensorflow::StringPiece name) { - auto it = c_find_if(module->computations(), + auto computations = module->computations(); + auto it = c_find_if(computations, [&](HloComputation* c) { return c->name() == name; }); - if (it == module->computations().end()) { + if (it == computations.end()) { return nullptr; } return *it; @@ -287,9 +311,10 @@ HloComputation* HloTestBase::FindComputation(HloModule* module, HloInstruction* HloTestBase::FindInstruction(HloModule* module, tensorflow::StringPiece name) { for (const HloComputation* c : module->computations()) { - auto it = c_find_if(c->instructions(), + auto instructions = c->instructions(); + auto it = c_find_if(instructions, [&](HloInstruction* i) { return i->name() == name; }); - if (it != c->instructions().end()) { + if (it != instructions.end()) { return *it; } } diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 9009d67cea6840235d63724ef76d777c8f693d33..bb55e562ad5e8fc32771d1ac23579db17a7e819d 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -166,6 +166,8 @@ class HloTestBase : public ::testing::Test { const tensorflow::gtl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; + ::testing::AssertionResult Run(const tensorflow::StringPiece hlo_string) + TF_MUST_USE_RESULT; ::testing::AssertionResult RunAndCompareFromFile( const string& filename, const tensorflow::gtl::optional& error, const std::function& reference_preprocessor = nullptr) @@ -200,6 +202,13 @@ class HloTestBase : public ::testing::Test { ->ResetLayout(layout); } + void ForceResultLayout(HloModule* module, const Layout& layout, + ShapeIndexView shape_index) { + module->mutable_entry_computation_layout() + ->mutable_result_layout() + ->ResetLayout(layout, shape_index); + } + // Convenience method to clear the layout of the computation result in // 'module'. void ForceClearResultLayout(HloModule* module) { diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..17ac95ae0198d98490b25f7f2edd32d1e0495803 --- /dev/null +++ b/tensorflow/compiler/xla/tests/iota_test.cc @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { +namespace { + +class IotaTest : public ClientLibraryTestBase { + public: + explicit IotaTest(se::Platform* platform = nullptr) + : ClientLibraryTestBase(platform) {} + template + std::vector GetExpected(const int64 num_elements) { + std::vector result(num_elements); + std::iota(result.begin(), result.end(), 0); + return result; + } +}; + +XLA_TEST_F(IotaTest, SimpleR1) { + for (int num_elements = 1; num_elements < 10000001; num_elements *= 10) { + { + XlaBuilder builder(TestName() + "_f32"); + IotaGen(&builder, F32, num_elements); + ComputeAndCompareR1(&builder, GetExpected(num_elements), {}, + ErrorSpec{0.0001}); + } + { + XlaBuilder builder(TestName() + "_u32"); + IotaGen(&builder, U32, num_elements); + ComputeAndCompareR1(&builder, GetExpected(num_elements), + {}); + } + { + XlaBuilder builder(TestName() + "_s32"); + IotaGen(&builder, S32, num_elements); + ComputeAndCompareR1(&builder, GetExpected(num_elements), + {}); + } + } +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index d1b8a6cf0b2552f1b7d95a2560d502da14ddc39a..31a099c15f1f20457c90de97054f68a31eb49011 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/error_spec.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -154,20 +155,20 @@ class LiteralTestUtil { template /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*Literal::CreateR0(expected), actual)); + EXPECT_TRUE(Equal(*LiteralUtil::CreateR0(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR1Equal( tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*Literal::CreateR1(expected), actual)); + EXPECT_TRUE(Equal(*LiteralUtil::CreateR1(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2Equal( std::initializer_list> expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*Literal::CreateR2(expected), actual)); + EXPECT_TRUE(Equal(*LiteralUtil::CreateR2(expected), actual)); } template @@ -175,46 +176,46 @@ template std::initializer_list>> expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*Literal::CreateR3(expected), actual)); + EXPECT_TRUE(Equal(*LiteralUtil::CreateR3(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2EqualArray2D( const Array2D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*Literal::CreateR2FromArray2D(expected), actual)); + EXPECT_TRUE(Equal(*LiteralUtil::CreateR2FromArray2D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR3EqualArray3D( const Array3D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*Literal::CreateR3FromArray3D(expected), actual)); + EXPECT_TRUE(Equal(*LiteralUtil::CreateR3FromArray3D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR4EqualArray4D( const Array4D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*Literal::CreateR4FromArray4D(expected), actual)); + EXPECT_TRUE(Equal(*LiteralUtil::CreateR4FromArray4D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR0(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR0(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR1Near( tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR1(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR1(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2Near( std::initializer_list> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR2(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR2(expected), actual, error)); } template @@ -222,7 +223,7 @@ template std::initializer_list>> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR3(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR3(expected), actual, error)); } template @@ -231,28 +232,28 @@ template std::initializer_list>>> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR4(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR4(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2NearArray2D( const Array2D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR2FromArray2D(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR2FromArray2D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR3NearArray3D( const Array3D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR3FromArray3D(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR3FromArray3D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR4NearArray4D( const Array4D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*Literal::CreateR4FromArray4D(expected), actual, error)); + EXPECT_TRUE(Near(*LiteralUtil::CreateR4FromArray4D(expected), actual, error)); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index bbac7285aefbb1f028fad152e4b7fe6af01e9f6d..f297b2b847f570d26e71ddcd8e34bc626f982e1f 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -31,8 +31,9 @@ namespace xla { namespace { TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) { - std::unique_ptr literal = Literal::MakeTuple({ - Literal::CreateR0(42).get(), Literal::CreateR0(64).get(), + std::unique_ptr literal = LiteralUtil::MakeTuple({ + LiteralUtil::CreateR0(42).get(), + LiteralUtil::CreateR0(64).get(), }); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal)); } @@ -42,11 +43,13 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { // un-fail an assertion failure. The CHECK-failure is death, so we can make a // death assertion. auto unequal_things_are_equal = [] { - std::unique_ptr lhs = Literal::MakeTuple({ - Literal::CreateR0(42).get(), Literal::CreateR0(64).get(), + std::unique_ptr lhs = LiteralUtil::MakeTuple({ + LiteralUtil::CreateR0(42).get(), + LiteralUtil::CreateR0(64).get(), }); - std::unique_ptr rhs = Literal::MakeTuple({ - Literal::CreateR0(64).get(), Literal::CreateR0(42).get(), + std::unique_ptr rhs = LiteralUtil::MakeTuple({ + LiteralUtil::CreateR0(64).get(), + LiteralUtil::CreateR0(42).get(), }); CHECK(LiteralTestUtil::Equal(*lhs, *rhs)) << "LHS and RHS are unequal"; }; @@ -55,8 +58,8 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { auto dummy_lambda = [] { - auto two = Literal::CreateR0(2); - auto four = Literal::CreateR0(4); + auto two = LiteralUtil::CreateR0(2); + auto four = LiteralUtil::CreateR0(4); ErrorSpec error(0.001); CHECK(LiteralTestUtil::Near(*two, *four, error)) << "two is not near four"; }; @@ -98,8 +101,8 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { } TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { - auto expected = Literal::CreateR1({1, 2, 3}); - auto actual = Literal::CreateR1({4, 5, 6}); + auto expected = LiteralUtil::CreateR1({1, 2, 3}); + auto actual = LiteralUtil::CreateR1({4, 5, 6}); ::testing::AssertionResult result = LiteralTestUtil::Equal(*expected, *actual); EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}")); @@ -107,25 +110,26 @@ TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { } TEST(LiteralTestUtilTest, NearComparatorR1) { - auto a = - Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); - auto b = - Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); + auto a = LiteralUtil::CreateR1( + {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); + auto b = LiteralUtil::CreateR1( + {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); } TEST(LiteralTestUtilTest, NearComparatorR1Nan) { - auto a = - Literal::CreateR1({0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); - auto b = - Literal::CreateR1({0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); + auto a = LiteralUtil::CreateR1( + {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); + auto b = LiteralUtil::CreateR1( + {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); } TEST(LiteralTestUtil, NearComparatorDifferentLengths) { - auto a = - Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); - auto b = Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7}); + auto a = LiteralUtil::CreateR1( + {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); + auto b = + LiteralUtil::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7}); EXPECT_FALSE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); EXPECT_FALSE(LiteralTestUtil::Near(*b, *a, ErrorSpec{0.0001})); } diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index 082bc34136e004795ce300c66591758f47c665fe..e719da54d45d3e6eb3f3e14d3fa3076db2081e04 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -14,9 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/llvm_compiler.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" +#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -64,7 +65,7 @@ class LLVMCompilerTest : public ::testing::Test { // Create HLO module, and run the compiler. auto builder = HloComputation::Builder(TestName()); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto hlo_module = CreateNewModule(); hlo_module->AddEntryComputation(builder.Build()); @@ -86,7 +87,7 @@ class LLVMCompilerTest : public ::testing::Test { void TestMultiModuleCompilation(LLVMCompiler *compiler) { HloComputation::Builder builder(TestName()); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); std::unique_ptr hlo_module = CreateNewModule(); hlo_module->AddEntryComputation(builder.Build()); @@ -144,7 +145,7 @@ TEST_F(CpuCompilerTest, HooksTest) { } TEST_F(GpuCompilerTest, HooksTest) { - gpu::GpuCompiler compiler; + gpu::NVPTXCompiler compiler; TestCompilerHooks(&compiler); } @@ -154,7 +155,7 @@ TEST_F(CpuCompilerTest, MultiModuleCompilation) { } TEST_F(GpuCompilerTest, MultModuleCompilation) { - gpu::GpuCompiler compiler; + gpu::NVPTXCompiler compiler; TestMultiModuleCompilation(&compiler); } } // namespace diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc index 2c45f19c090d2690878430363bf0d20252b2f3df..6fc11150978931f980349799372872f9fb68f292 100644 --- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc +++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/tests/filecheck.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,28 +26,28 @@ limitations under the License. namespace xla { -void LLVMIRGenTestBase::SetIrHook(bool match_optimized_ir) { +void LlvmIrGenTestBase::SetIrHook(bool match_optimized_ir) { auto llvm_compiler = GetLLVMCompiler(); using std::placeholders::_1; // Add the IR inspection hook to the LLVM compiler. if (match_optimized_ir) { llvm_compiler->SetPostOptimizationHook( - std::bind(&LLVMIRGenTestBase::IrHook, this, _1)); + std::bind(&LlvmIrGenTestBase::IrHook, this, _1)); } else { llvm_compiler->SetPreOptimizationHook( - std::bind(&LLVMIRGenTestBase::IrHook, this, _1)); + std::bind(&LlvmIrGenTestBase::IrHook, this, _1)); } } -void LLVMIRGenTestBase::ResetIrHook() { +void LlvmIrGenTestBase::ResetIrHook() { auto llvm_compiler = GetLLVMCompiler(); llvm_compiler->RemovePreOptimizationHook(); llvm_compiler->RemovePostOptimizationHook(); } -void LLVMIRGenTestBase::CompileAndVerifyIr( +void LlvmIrGenTestBase::CompileAndVerifyIr( std::unique_ptr hlo_module, const string& pattern, bool match_optimized_ir) { SetIrHook(match_optimized_ir); @@ -58,7 +59,17 @@ void LLVMIRGenTestBase::CompileAndVerifyIr( EXPECT_TRUE(filecheck_result.ValueOrDie()); } -void LLVMIRGenTestBase::CompileAheadOfTimeAndVerifyIr( +void LlvmIrGenTestBase::CompileAndVerifyIr(const string& hlo_text, + const string& expected_llvm_ir, + bool match_optimized_ir) { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text, config)); + CompileAndVerifyIr(std::move(module), expected_llvm_ir, match_optimized_ir); +} + +void LlvmIrGenTestBase::CompileAheadOfTimeAndVerifyIr( std::unique_ptr hlo_module, const AotCompilationOptions& options, const string& pattern, bool match_optimized_ir) { SetIrHook(match_optimized_ir); @@ -71,11 +82,11 @@ void LLVMIRGenTestBase::CompileAheadOfTimeAndVerifyIr( EXPECT_TRUE(filecheck_result.ValueOrDie()); } -LLVMCompiler* LLVMIRGenTestBase::GetLLVMCompiler() { +LLVMCompiler* LlvmIrGenTestBase::GetLLVMCompiler() { return static_cast(backend().compiler()); } -Status LLVMIRGenTestBase::IrHook(const llvm::Module& module) { +Status LlvmIrGenTestBase::IrHook(const llvm::Module& module) { ir_ = llvm_ir::DumpModuleToString(module); return Status::OK(); } diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h index 74cbb5f5df662992046a5b0f9a31e52879f375ad..018f9546afc3e408686a9ac75a74320a05b27182 100644 --- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h +++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h @@ -24,7 +24,7 @@ limitations under the License. namespace xla { // Tests that verify IR emitted by the CPU/GPU backend is as expected. -class LLVMIRGenTestBase : public CodegenTestBase { +class LlvmIrGenTestBase : public CodegenTestBase { protected: // Compiles the given HLO module to LLVM IR and verifies the IR matches the // given pattern. `pattern` is in the FileCheck pattern matching syntax @@ -38,6 +38,12 @@ class LLVMIRGenTestBase : public CodegenTestBase { void CompileAndVerifyIr(std::unique_ptr hlo_module, const string& pattern, bool match_optimized_ir); + // A thin wrapper around CompileAndVerifyIr that parses `hlo_text` to create + // an HLO module. + void CompileAndVerifyIr(const string& hlo_text, + const string& expected_llvm_ir, + bool match_optimized_ir = false); + // Compiles the given HLO module to LLVM IR and verifies the IR matches the // given pattern. `pattern` is in the FileCheck pattern matching syntax // (http://llvm.org/docs/CommandGuide/FileCheck.html). diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc index 9191be9fd905ab2e0c661042b042c8233d39e4a1..e2cd5bcc5a95f692dcf4a43d717252bfe876aa81 100644 --- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" @@ -45,7 +45,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) { TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform()); auto x_array = - LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); int64 allocation_count_before = allocator_->allocation_count(); diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test.cc b/tensorflow/compiler/xla/tests/local_client_aot_test.cc index 47cab796041e9669affaebd7866d0d80100730f1..115448c908ac9e7f0b01772ce348d23bf4d838ed 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test.cc @@ -42,13 +42,12 @@ extern "C" void SumStructElements(float* out, void** parameters) { TEST_F(LocalClientAotTest, Constant) { xla::ExecutableRunOptions run_options; OpaqueData opaque_data{100, 20, 3}; - void* parameters[] = {&opaque_data}; float out = 0; - void* temporary_buffers[] = {nullptr, &out}; - SumAndDouble(&out, &run_options, parameters, temporary_buffers); + void* temporary_buffers[] = {&opaque_data, &out}; + SumAndDouble(&out, &run_options, nullptr, temporary_buffers); EXPECT_EQ(out, 246.0f); opaque_data = {1, 2, 3}; - SumAndDouble(&out, &run_options, parameters, temporary_buffers); + SumAndDouble(&out, &run_options, nullptr, temporary_buffers); EXPECT_EQ(out, 12.0f); } diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index 70612e7c49d2815096cc54fd6ae796148249b4db..e310966d8b062f2baac00a17dd42cd449595d0d2 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -21,8 +21,8 @@ limitations under the License. #include "llvm/ADT/Triple.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/types.h" @@ -92,9 +92,10 @@ int main(int argc, char** argv) { // It's lame to hard-code the buffer assignments, but we need // local_client_aot_test.cc to be able to easily invoke the function. CHECK_EQ(result->result_buffer_index(), 1); - CHECK_EQ(result->buffer_sizes().size(), 2); - CHECK_EQ(result->buffer_sizes()[0], -1); // param buffer + CHECK_EQ(result->buffer_sizes().size(), 3); + CHECK_EQ(result->buffer_sizes()[0], -2); // param buffer CHECK_EQ(result->buffer_sizes()[1], sizeof(float)); // result buffer + CHECK_EQ(result->buffer_sizes()[2], -1); // const buffer if (triple.isOSBinFormatELF()) { // Check the ELF magic. CHECK_EQ(result->object_file_data()[0], 0x7F); diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 2c6393794ef1b1558f5e651b5cb7bfa2afa961de..1a823cf189b310c62c735419936544ea99fcfbaf 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/platform_util.h" @@ -68,7 +68,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddScalars) { auto y = ConstantR0(&builder, 123.0f); Add(x, y); - auto x_value = LiteralToShapedBuffer(*Literal::CreateR0(42.0f)); + auto x_value = LiteralToShapedBuffer(*LiteralUtil::CreateR0(42.0f)); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_value}); LiteralTestUtil::ExpectR0Near(165.f, *ShapedBufferToLiteral(result), @@ -81,7 +81,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) { auto y = ConstantR1(&builder, {}); Add(x, y); - auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({})); + auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR1({})); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); LiteralTestUtil::ExpectR1Near({}, *ShapedBufferToLiteral(result), @@ -95,7 +95,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectors) { Add(x, y); auto x_array = - LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); LiteralTestUtil::ExpectR1Near( @@ -109,7 +109,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { Add(x, y); auto x_array = - LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ExecutionProfile profile; ScopedShapedBuffer result = ExecuteLocallyOrDie( builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions(), @@ -128,13 +128,13 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { auto computation = builder.Build().ConsumeValueOrDie(); // Create x as a col-major array. - auto x_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( + auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}))); EXPECT_TRUE(LayoutUtil::Equal(x_array.on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); // Create y as a row-major array. - auto y_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( + auto y_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout( {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0}))); EXPECT_TRUE(LayoutUtil::Equal(y_array.on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); @@ -161,9 +161,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *Literal::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *Literal::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); // Run with col-major result layout. ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie( @@ -198,9 +198,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *Literal::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *Literal::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); @@ -226,9 +226,9 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *Literal::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *Literal::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); @@ -255,7 +255,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { Tuple(&builder, {x, y}); auto array = LiteralToShapedBuffer( - *Literal::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); ExecutableBuildOptions options = DefaultExecutableBuildOptions(); Shape shape_with_layout = ShapeUtil::MakeTupleShape( @@ -298,12 +298,12 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { Tuple(&builder, {array_sum, vector_diff}); auto computation = builder.Build().ConsumeValueOrDie(); - auto x_literal = Literal::MakeTuple( - {Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - Literal::CreateR1({42.0, 75.0, 123.0}).get()}); - auto y_literal = Literal::MakeTuple( - {Literal::CreateR1({2.0, 4.0, 6.0}).get(), - Literal::CreateR2({{55.0, 44.0}, {33.0, 22.0}}).get()}); + auto x_literal = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), + LiteralUtil::CreateR1({42.0, 75.0, 123.0}).get()}); + auto y_literal = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR1({2.0, 4.0, 6.0}).get(), + LiteralUtil::CreateR2({{55.0, 44.0}, {33.0, 22.0}}).get()}); auto x_buffer = LiteralToShapedBuffer(*x_literal); auto y_buffer = LiteralToShapedBuffer(*y_literal); @@ -344,12 +344,12 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { Tuple(&builder, {negate_array, vector_sum}); auto computation = builder.Build().ConsumeValueOrDie(); - auto arg_literal = Literal::MakeTuple( - {Literal::MakeTuple( - {Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - Literal::CreateR1({42.0, 75.0, 123.0}).get()}) + auto arg_literal = LiteralUtil::MakeTuple( + {LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), + LiteralUtil::CreateR1({42.0, 75.0, 123.0}).get()}) .get(), - Literal::CreateR1({222.0, -2.0, 10.0}).get()}); + LiteralUtil::CreateR1({222.0, -2.0, 10.0}).get()}); auto arg_buffer = LiteralToShapedBuffer(*arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); @@ -377,9 +377,9 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { Tuple(&builder, {Neg(element_0), Add(element_1, element_1)}); auto computation = builder.Build().ConsumeValueOrDie(); - auto arg_literal = Literal::MakeTuple( - {Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - Literal::CreateR2({{11.0, 3.0}, {4.0, 5.0}}).get()}); + auto arg_literal = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), + LiteralUtil::CreateR2({{11.0, 3.0}, {4.0, 5.0}}).get()}); auto arg_buffer = LiteralToShapedBuffer(*arg_literal); ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer}); @@ -429,10 +429,10 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { // -tuple_index}. std::vector> arg_elements; for (int i = 0; i < kElementCount; ++i) { - arg_elements.push_back(Literal::CreateR1({1.0f * i, -1.0f * i})); + arg_elements.push_back(LiteralUtil::CreateR1({1.0f * i, -1.0f * i})); } std::unique_ptr arg_literal = - Literal::MakeTupleOwned(std::move(arg_elements)); + LiteralUtil::MakeTupleOwned(std::move(arg_elements)); auto arg_buffer = LiteralToShapedBuffer(*arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); @@ -480,12 +480,13 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { for (int i = 0; i < kFanout; ++i) { std::vector> inner_tuple_elements; for (int j = 0; j < kFanout; ++j) { - inner_tuple_elements.push_back(Literal::CreateR0(i + j)); + inner_tuple_elements.push_back(LiteralUtil::CreateR0(i + j)); } outer_tuple_elements.push_back( - Literal::MakeTupleOwned(std::move(inner_tuple_elements))); + LiteralUtil::MakeTupleOwned(std::move(inner_tuple_elements))); } - auto arg_literal = Literal::MakeTupleOwned(std::move(outer_tuple_elements)); + auto arg_literal = + LiteralUtil::MakeTupleOwned(std::move(outer_tuple_elements)); auto arg_buffer = LiteralToShapedBuffer(*arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); @@ -524,11 +525,11 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { auto computation = builder.Build().ConsumeValueOrDie(); // Construct the argument to pass to the computation. - std::unique_ptr arg_literal = Literal::CreateR0(123.0); + std::unique_ptr arg_literal = LiteralUtil::CreateR0(123.0); for (int i = 0; i < kTupleDepth; ++i) { std::vector> arg_vector; arg_vector.push_back(std::move(arg_literal)); - arg_literal = Literal::MakeTupleOwned(std::move(arg_vector)); + arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_vector)); } auto arg_buffer = LiteralToShapedBuffer(*arg_literal); @@ -551,7 +552,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { Add(x, y); auto x_array = - LiteralToShapedBuffer(*Literal::CreateR1({1.0f, 2.0f, 3.0f})); + LiteralToShapedBuffer(*LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f})); auto execute_status = ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); @@ -567,7 +568,7 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { Neg(x); auto x_array = LiteralToShapedBuffer( - *Literal::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); + *LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); @@ -584,7 +585,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) { Neg(x); auto x_array = LiteralToShapedBuffer( - *Literal::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); + *LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = ExecuteLocally( builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions().set_result_layout( @@ -767,7 +768,7 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { executable_status.ConsumeValueOrDie(); auto x_array = - LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ScopedShapedBuffer result = executable->Run({&x_array}, DefaultExecutableRunOptions()) .ConsumeValueOrDie(); @@ -795,29 +796,29 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { }; // Array shapes. - test_to_device_and_back(*Literal::CreateR0(42.0)); - test_to_device_and_back(*Literal::CreateR0(true)); - test_to_device_and_back(*Literal::CreateR1({1.0, 42.0, 744.4})); + test_to_device_and_back(*LiteralUtil::CreateR0(42.0)); + test_to_device_and_back(*LiteralUtil::CreateR0(true)); + test_to_device_and_back(*LiteralUtil::CreateR1({1.0, 42.0, 744.4})); test_to_device_and_back( - *Literal::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); - test_to_device_and_back(*Literal::CreateR2({{2, 1}, {4444, 56}})); + *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); + test_to_device_and_back(*LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); // Null shape (empty tuple). - test_to_device_and_back(*Literal::MakeTuple({})); + test_to_device_and_back(*LiteralUtil::MakeTuple({})); // Non-nested tuples. test_to_device_and_back( - *Literal::MakeTuple({Literal::CreateR0(12223.0).get()})); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(12223.0).get()})); test_to_device_and_back( - *Literal::MakeTuple({Literal::CreateR1({1.0, -42.0}).get(), - Literal::CreateR0(123456.0).get()})); + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1.0, -42.0}).get(), + LiteralUtil::CreateR0(123456.0).get()})); // Nested tuple. - test_to_device_and_back(*Literal::MakeTuple( - {Literal::MakeTuple({Literal::CreateR1({1.0, -42.0}).get(), - Literal::CreateR0(123456.0).get()}) + test_to_device_and_back(*LiteralUtil::MakeTuple( + {LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1.0, -42.0}).get(), + LiteralUtil::CreateR0(123456.0).get()}) .get(), - Literal::CreateR0(false).get()})); + LiteralUtil::CreateR0(false).get()})); } XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { @@ -835,13 +836,13 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { }; test_to_device_and_back( - *Literal::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); - test_to_device_and_back(*Literal::CreateR2({{2, 1}, {4444, 56}})); + *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); + test_to_device_and_back(*LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); test_to_device_and_back( - *Literal::CreateR2({{20000000000ULL, 1}, {4444, 56}})); - test_to_device_and_back( - *Literal::MakeTuple({Literal::CreateR1({1.0, -42.0}).get(), - Literal::CreateR0(123456789000LL).get()})); + *LiteralUtil::CreateR2({{20000000000ULL, 1}, {4444, 56}})); + test_to_device_and_back(*LiteralUtil::MakeTuple( + {LiteralUtil::CreateR1({1.0, -42.0}).get(), + LiteralUtil::CreateR0(123456789000LL).get()})); } XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { @@ -860,7 +861,7 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { })); ASSERT_IS_OK(local_client_->TransferToInfeedLocal( - *Literal::CreateR1({-5.0, 123.0, 42.0}), + *LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), local_client_->default_device_ordinal())); // Join the thread. @@ -869,9 +870,7 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, *result); } -// TODO(b/34359662): Support infeed/outfeed on GPU and CPU parallel. -// 2017-10-18. -XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_GPU(InfeedOutfeedTest)) { +XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) { XlaBuilder builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {3}); auto in = Infeed(&builder, shape); @@ -885,7 +884,7 @@ XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_GPU(InfeedOutfeedTest)) { [&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); })); ASSERT_IS_OK(local_client_->TransferToInfeedLocal( - *Literal::CreateR1({-5.0, 123.0, 42.0}), + *LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), local_client_->default_device_ordinal())); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, @@ -920,7 +919,7 @@ void BM_LocalClientOverhead(int num_iters) { transfer_manager ->AllocateScopedShapedBuffer(shape, &allocator, /*device_ordinal=*/0) .ConsumeValueOrDie(); - auto literal = Literal::CreateR2({{0, 0, 0}, {0, 0, 0}}); + auto literal = LiteralUtil::CreateR2({{0, 0, 0}, {0, 0, 0}}); auto stream = client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(stream.get(), *literal, diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index c31ba0e713a45d18b60bfdb9a47545cf34220333..eaddf756dbc913dd9668cd22228fbd18c2c33309 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -20,6 +20,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index 258226523d830b40ecaa761df95988dc90f5ca47..b4477e9a6b23363ee3a1380f9f98f4b8226f6920 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/platform_util.h" diff --git a/tensorflow/compiler/xla/tests/log_test.cc b/tensorflow/compiler/xla/tests/log_test.cc index cdf70ee4185be2ecd9dcb2d21fbd98c2ab6cc0ad..2d622242e657ce032a17f7b26c94227d343e2a38 100644 --- a/tensorflow/compiler/xla/tests/log_test.cc +++ b/tensorflow/compiler/xla/tests/log_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 1b3bc9d5040e1382f534e00ea2679ebbd48ceb59..0732e195d44d738b264361e43d38259c26a4116e 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -169,7 +169,7 @@ class MapTest : public ClientLibraryTestBase { TEST_F(MapTest, MapEachElemPlusOneR0) { // Applies lambda (x) (+ x 1)) to an input scalar. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR0(42.0); + std::unique_ptr param0_literal = LiteralUtil::CreateR0(42.0); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -183,7 +183,7 @@ TEST_F(MapTest, MapEachElemPlusOneR0) { XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR1({}); + std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -198,7 +198,7 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4. XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -212,7 +212,7 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { TEST_F(MapTest, MapEachF32ElementToS32Constant) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -225,7 +225,7 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) { TEST_F(MapTest, MapEachF32ElementToU32Constant) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -239,7 +239,7 @@ TEST_F(MapTest, MapEachElemLongerChainR1) { // Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector. XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); + LiteralUtil::CreateR1({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -255,7 +255,7 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then // maps (lambda (x) (* x 2)) on the result. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR1({}); + std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -272,7 +272,7 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { // maps (lambda (x) (* x 2)) on the result. XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -287,7 +287,7 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { TEST_F(MapTest, MapEachElemPlusOneR2) { // Maps (lambda (x) (+ x 1)) onto an input R2F32 vector. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR2( + std::unique_ptr param0_literal = LiteralUtil::CreateR2( {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -343,11 +343,11 @@ TEST_F(MapTest, MapBinaryAdder) { // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors. XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - Literal::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); + LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -365,12 +365,12 @@ TEST_F(MapTest, MapBinaryAdder) { // for Map that used to fail in shape inference (b/28989438). XLA_TEST_F(MapTest, AddWithMixedLayouts) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR2WithLayout( + std::unique_ptr param0_literal = LiteralUtil::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0})); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = Literal::CreateR2WithLayout( + std::unique_ptr param1_literal = LiteralUtil::CreateR2WithLayout( {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1})); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -392,12 +392,12 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) { XLA_TEST_F(MapTest, AddR3_3x0x2) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR3FromArray3D(Array3D(3, 0, 2)); + LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - Literal::CreateR3FromArray3D(Array3D(3, 0, 2)); + LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -414,15 +414,15 @@ TEST_F(MapTest, MapTernaryAdder) { // Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors. XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - Literal::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); + LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); std::unique_ptr param2_literal = - Literal::CreateR1({-10.0f, -100.0f, -900.0f, -400.0f}); + LiteralUtil::CreateR1({-10.0f, -100.0f, -900.0f, -400.0f}); std::unique_ptr param2_data = client_->TransferToServer(*param2_literal).ConsumeValueOrDie(); @@ -476,11 +476,11 @@ TEST_F(MapTest, MapOperantionWithBuildError) { auto error_add = sub_builder->BuildAndNoteError(); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - Literal::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); + LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -513,8 +513,8 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) { Pow(x, y); auto power = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = Literal::CreateR0(2.0f); - std::unique_ptr param1_literal = Literal::CreateR0(5.0f); + std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); + std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -540,8 +540,8 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { Sub(y, x); // note that this is y - x, not x - y auto sub_opposite = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = Literal::CreateR0(2.0f); - std::unique_ptr param1_literal = Literal::CreateR0(5.0f); + std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); + std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -565,7 +565,7 @@ TEST_F(MapTestWithFullOpt, MapSquare) { Mul(x, x); auto square = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = Literal::CreateR0(10.0f); + std::unique_ptr param0_literal = LiteralUtil::CreateR0(10.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 17b1807f44a457786906afc15d8d410f6cf2d4cd..da8c42d465340f2af3d6acd2c3676b69512f193f 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -63,8 +63,8 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) { Exp(data); std::unique_ptr expected = - Literal::CreateR2FromArray2D({{2.71828f, 1.00000f}, // row 0 - {0.36788f, 1.64872f}}); // row 1 + LiteralUtil::CreateR2FromArray2D({{2.71828f, 1.00000f}, // row 0 + {0.36788f, 1.64872f}}); // row 1 this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); } @@ -92,8 +92,8 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { Map(&builder, {data}, add_half, {0, 1}); std::unique_ptr expected = - Literal::CreateR2FromArray2D({{1.5f, 0.5f}, // row 0 - {-0.5f, 1.0f}}); // row 1 + LiteralUtil::CreateR2FromArray2D({{1.5f, 0.5f}, // row 0 + {-0.5f, 1.0f}}); // row 1 this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); } @@ -111,8 +111,8 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { Max(lhs, rhs); std::unique_ptr expected = - Literal::CreateR2FromArray2D({{7.0f, 6.0f}, // row 0 - {3.0f, -4.0f}}); // row 1 + LiteralUtil::CreateR2FromArray2D({{7.0f, 6.0f}, // row 0 + {3.0f, -4.0f}}); // row 1 this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6)); } @@ -200,12 +200,14 @@ class MatOpsDotAddTest TF_ASSERT_OK_AND_ASSIGN( auto lhs_handle, - client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( - lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + client_->TransferToServer( + *LiteralUtil::CreateR2FromArray2DWithLayout( + lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); TF_ASSERT_OK_AND_ASSIGN( auto rhs_handle, - client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( - rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + client_->TransferToServer( + *LiteralUtil::CreateR2FromArray2DWithLayout( + rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); XlaBuilder builder(TestName()); auto lhs_arg = Parameter(&builder, 0, lhs_shape, "lhs"); diff --git a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc index e576f000ef23e761d6fa818457eec2144d4bcb00..955dbef6dcd28421fb351c6ee064ac53eda1fd08 100644 --- a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc +++ b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 6597748c8d1f45391799dbe384a5afc0284de2dd..eb06b115daa96bccd73de30bb7fa30733a6fd947 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -60,7 +60,7 @@ class MultiOutputFusionTest : public HloTestBase { const Shape elem_shape2 = ShapeUtil::MakeShape(F32, {size, size}); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(8.0f))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(8.0f))); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, elem_shape0, "0")); @@ -105,8 +105,9 @@ class MultiOutputFusionTest : public HloTestBase { Literal expect(ShapeUtil::MakeShape(F32, {size, size})); expect.PopulateWithValue(size * 1.5f * 3.5f); - auto actual = ExecuteAndTransfer( - std::move(hlo_module), {Literal::CreateR0(-9.0f).get(), &arg1}); + auto actual = + ExecuteAndTransfer(std::move(hlo_module), + {LiteralUtil::CreateR0(-9.0f).get(), &arg1}); EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); } @@ -165,7 +166,8 @@ class MultiOutputFusionTest : public HloTestBase { Literal input1(ShapeUtil::MakeShape(F64, {size})); input1.PopulateWithValue(1.); - Literal expect = std::move(*Literal::CreateR1({size * 1.5f * 3.5f})); + Literal expect = + std::move(*LiteralUtil::CreateR1({size * 1.5f * 3.5f})); auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1}); EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); } @@ -198,16 +200,16 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::MakeTupleOwned( - Literal::MakeTupleOwned( - Literal::MakeTupleOwned(Literal::CreateR0(42)), - Literal::CreateR0(1.0)), - Literal::MakeTupleOwned(Literal::CreateR0(3.0), - Literal::CreateR0(4))); + auto param = LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), + LiteralUtil::CreateR0(1.0)), + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(3.0), + LiteralUtil::CreateR0(4))); std::unique_ptr result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned(Literal::CreateR0(42)), *result)); + *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), *result)); } XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { @@ -232,7 +234,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR1({1.0, 2.0, 3.0, -1.0}); + auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0, -1.0}); std::unique_ptr result = ExecuteNoHloPasses(std::move(module), {param.get()}); LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0, 1.0}, *result); @@ -265,7 +267,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR1({1.0, 2.0, 3.0}); + auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); std::unique_ptr result = ExecuteNoHloPasses(std::move(module), {param.get()}); LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0}, *result); @@ -308,12 +310,14 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto param = + LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); std::unique_ptr result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned(Literal::CreateR2({{3, 7}, {11, 15}}), - Literal::CreateR2({{5, 16}, {36, 64}})), + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR2({{3, 7}, {11, 15}}), + LiteralUtil::CreateR2({{5, 16}, {36, 64}})), *result)); } @@ -338,12 +342,14 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto param = + LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); std::unique_ptr result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned(Literal::CreateR2({{6, 8}, {10, 12}}), - Literal::CreateR2({{25, 36}, {49, 64}})), + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR2({{6, 8}, {10, 12}}), + LiteralUtil::CreateR2({{25, 36}, {49, 64}})), *result)); } @@ -369,13 +375,14 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto param = + LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); std::unique_ptr result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned(Literal::CreateR1({14, 22}), - Literal::CreateR1({36, 64}), - Literal::CreateR1({66, 138})), + *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({14, 22}), + LiteralUtil::CreateR1({36, 64}), + LiteralUtil::CreateR1({66, 138})), *result)); } @@ -401,14 +408,15 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto param = + LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); std::unique_ptr result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned( - Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}), - Literal::CreateR2({{3, 7}, {11, 15}}), - Literal::CreateR2({{5, 16}, {36, 64}})), + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}), + LiteralUtil::CreateR2({{3, 7}, {11, 15}}), + LiteralUtil::CreateR2({{5, 16}, {36, 64}})), *result)); } @@ -434,14 +442,16 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto param = + LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); std::unique_ptr result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned( - Literal::CreateR2({{6, 8}, {10, 12}}), - Literal::CreateR3({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), - Literal::CreateR2({{25, 36}, {49, 64}})), + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR2({{6, 8}, {10, 12}}), + LiteralUtil::CreateR3( + {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), + LiteralUtil::CreateR2({{25, 36}, {49, 64}})), *result)); } @@ -468,14 +478,16 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto param = + LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); std::unique_ptr result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned( - Literal::CreateR1({14, 22}), - Literal::CreateR3({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), - Literal::CreateR3( + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR1({14, 22}), + LiteralUtil::CreateR3( + {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), + LiteralUtil::CreateR3( {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})), *result)); } @@ -502,15 +514,16 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - auto init1 = Literal::CreateR0(5); - auto init2 = Literal::CreateR0(6); + auto param = + LiteralUtil::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto init1 = LiteralUtil::CreateR0(5); + auto init2 = LiteralUtil::CreateR0(6); std::unique_ptr result = ExecuteNoHloPasses( std::move(module), {param.get(), init1.get(), init2.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned( - Literal::CreateR2({{167, 172}, {176, 180}}), - Literal::CreateR2({{6, 6}, {6, 8}})), + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR2({{167, 172}, {176, 180}}), + LiteralUtil::CreateR2({{6, 6}, {6, 8}})), *result)); } @@ -537,19 +550,20 @@ XLA_TEST_F(MultiOutputFusionTest, auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::CreateR3( + auto param = LiteralUtil::CreateR3( {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}}, {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}}); std::unique_ptr result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned( - Literal::CreateR2({{3, 7}, {11, 15}}), - Literal::CreateR2({{5, 16}, {36, 64}}), - Literal::CreateR3({{{Eigen::half(1), Eigen::half(2)}, - {Eigen::half(3), Eigen::half(4)}}, - {{Eigen::half(5), Eigen::half(6)}, - {Eigen::half(7), Eigen::half(8)}}})), + *LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR2({{3, 7}, {11, 15}}), + LiteralUtil::CreateR2({{5, 16}, {36, 64}}), + LiteralUtil::CreateR3( + {{{Eigen::half(1), Eigen::half(2)}, + {Eigen::half(3), Eigen::half(4)}}, + {{Eigen::half(5), Eigen::half(6)}, + {Eigen::half(7), Eigen::half(8)}}})), *result)); } diff --git a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0a0426adcbc1b5b89be0841fa2c4204e2b65abf4 --- /dev/null +++ b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc @@ -0,0 +1,169 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/local_client_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +// Tests that ensure outfeed instructions that are contained in nested +// computations in non-root positions are executed. + +class OutfeedInNestedComputationTest : public LocalClientTestBase {}; + +XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) { + XlaBuilder b(TestName()); + + Shape state_tuple_array_shape = ShapeUtil::MakeShape(xla::S32, {10, 5}); + Shape int_shape = ShapeUtil::MakeShape(xla::S32, {}); + Shape state_tuple_shape = + ShapeUtil::MakeTupleShape({int_shape, state_tuple_array_shape}); + Shape xfeed_shape = ShapeUtil::MakeShape(xla::S32, {2}); + + XlaOp some_buffer = Broadcast(ConstantR0(&b, 0), {10, 5}); + XlaOp num_iter = Infeed(&b, int_shape); + XlaOp init_tuple = Tuple(&b, {num_iter, some_buffer}); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation loop_cond, [&] { + // Condition: iteration variable > 0 + XlaBuilder cond_builder("loop_condition"); + XlaOp state_tuple = Parameter(&cond_builder, 0, state_tuple_shape, "state"); + XlaOp loop_counter = GetTupleElement(state_tuple, 0); + Outfeed(loop_counter, int_shape, ""); + Gt(loop_counter, ConstantR0(&cond_builder, 0)); + return cond_builder.Build(); + }()); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation loop_body, [&] { + XlaBuilder body_builder("loop_body"); + XlaOp state_tuple = Parameter(&body_builder, 0, state_tuple_shape, "state"); + XlaOp loop_counter = GetTupleElement(state_tuple, 0); + XlaOp buffer_inside = GetTupleElement(state_tuple, 1); + + // Read some stuff from Infeed. + XlaOp some_input = Infeed(&body_builder, xfeed_shape); + XlaOp sum = Add(some_input, Broadcast(loop_counter, {2})); + Outfeed(sum, xfeed_shape, ""); + + XlaOp iter_left = Sub(loop_counter, ConstantR0(&body_builder, 1)); + + Tuple(&body_builder, {iter_left, buffer_inside}); + return body_builder.Build(); + }()); + + // Build loop. + XlaOp result_tuple = While(loop_cond, loop_body, init_tuple); + GetTupleElement(result_tuple, 0); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build()); + + std::unique_ptr comp_result; + std::unique_ptr thread( + tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions(), "execute_thread", [&] { + comp_result = local_client_->ExecuteAndTransfer(computation, {}) + .ConsumeValueOrDie(); + })); + + VLOG(1) << "Transferring trip count to computation"; + // Transfer number of iterations to Infeed. + TF_ASSERT_OK( + local_client_->TransferToInfeed(*LiteralUtil::CreateR0(1))); + + // Pick up value from outfeed + { + VLOG(1) << "Reading from condition outfeed"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + local_client_->TransferFromOutfeed(&int_shape)); + EXPECT_EQ(r->Get({}), 1); + } + + VLOG(1) << "Writing data to infeed"; + // Transfer some stuff to Infeed for use inside of loop. + TF_ASSERT_OK(local_client_->TransferToInfeed( + *LiteralUtil::CreateR1({10, 20}))); + + // Pick up value from outfeed + { + VLOG(1) << "Reading from body outfeed"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + local_client_->TransferFromOutfeed(&xfeed_shape)); + EXPECT_EQ(r->Get({0}), 11); + EXPECT_EQ(r->Get({1}), 21); + } + + { + VLOG(1) << "Reading from condition outfeed"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + local_client_->TransferFromOutfeed(&int_shape)); + EXPECT_EQ(r->Get({}), 0); + } + + // Joins the thread + thread.reset(); + + EXPECT_EQ(comp_result->Get({}), 0); +} + +XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) { + XlaBuilder b(TestName()); + + Shape condition_shape = ShapeUtil::MakeShape(xla::PRED, {}); + Shape result_shape = ShapeUtil::MakeShape(xla::PRED, {}); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation true_computation, [&] { + XlaBuilder inner_builder("true_computation"); + XlaOp param = Parameter(&inner_builder, 0, result_shape, "param"); + Outfeed(param, result_shape, ""); + Or(param, param); + return inner_builder.Build(); + }()); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation false_computation, [&] { + XlaBuilder inner_builder("false_computation"); + Parameter(&inner_builder, 0, result_shape, "param"); + return inner_builder.Build(); + }()); + + XlaOp pred = Infeed(&b, condition_shape); + Conditional(/*predicate=*/pred, /*true_operand=*/pred, + /*true_computation=*/true_computation, /*false_operand=*/pred, + /*false_computation=*/false_computation); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build()); + + std::unique_ptr comp_result; + std::unique_ptr thread( + tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions(), "execute_thread", [&] { + comp_result = local_client_->ExecuteAndTransfer(computation, {}) + .ConsumeValueOrDie(); + })); + + TF_ASSERT_OK( + local_client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + local_client_->TransferFromOutfeed(&result_shape)); + + EXPECT_EQ(r->Get({}), true); + + // Join the thread + thread.reset(); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index 2e5081bbcb64ea9416c5a9731dba43891ecceedf..ca21b0b2ba590a6daadf2c8d3d9ad213514b0f0f 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -93,8 +93,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(0); - Pad(AddParam(*Literal::CreateR1({}), &b), - AddParam(*Literal::CreateR0(0.1), &b), padding_config); + Pad(AddParam(*LiteralUtil::CreateR1({}), &b), + AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, {}, {}, DefaultErrorSpec()); } @@ -108,8 +108,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) { dimension->set_edge_padding_high(4); dimension->set_interior_padding(7); - Pad(AddParam(*Literal::CreateR1({}), &b), - AddParam(*Literal::CreateR0(0.1), &b), padding_config); + Pad(AddParam(*LiteralUtil::CreateR1({}), &b), + AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, std::vector(5, 0.1), {}, DefaultErrorSpec()); } @@ -123,8 +123,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(1); - Pad(AddParam(*Literal::CreateR1({1, 2, 3}), &b), - AddParam(*Literal::CreateR0(0.1), &b), padding_config); + Pad(AddParam(*LiteralUtil::CreateR1({1, 2, 3}), &b), + AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); std::vector expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3}); ComputeAndCompareR1(&b, expected, {}, DefaultErrorSpec()); } @@ -132,7 +132,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) { XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) { XlaBuilder b(TestName()); Pad(AddParam(Array4D(2, 0, 3, 2), &b), - AddParam(*Literal::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); + AddParam(*LiteralUtil::CreateR0(1.5), &b), + r4_padding_on_dim0_dim1_); ComputeAndCompareR4(&b, Array4D(5, 2, 3, 2, 1.5f), {}, DefaultErrorSpec()); } @@ -147,7 +148,7 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { }); input->FillWithYX(input_xy); - Pad(AddParam(*input, &b), AddParam(*Literal::CreateR0(1.5), &b), + Pad(AddParam(*input, &b), AddParam(*LiteralUtil::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(2, 3, 3, 2); @@ -166,7 +167,8 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) { const float pad_value = 1.5f; Array4D input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); - Pad(AddParam(input, &b), AddParam(*Literal::CreateR0(pad_value), &b), + Pad(AddParam(input, &b), + AddParam(*LiteralUtil::CreateR0(pad_value), &b), r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(8, 5, 1, 1); @@ -205,11 +207,11 @@ TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) { const float pad_value = -5.123f; Array4D input_array(1, 1, 2, 3, {1, 2, 3, 4, 5, 6}); - auto input = Literal::CreateR4FromArray4D(input_array); + auto input = LiteralUtil::CreateR4FromArray4D(input_array); input = input->Relayout(layout); - Pad(AddParam(*input, &b), AddParam(*Literal::CreateR0(pad_value), &b), - padding_config); + Pad(AddParam(*input, &b), + AddParam(*LiteralUtil::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 1, 5, 8); expected_array.Fill(pad_value); @@ -251,11 +253,11 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { input_array(0, 0, 0, 0) = 1.0f; input_array(0, 24, 6, 6) = 2.0f; input_array(0, 17, 2, 5) = 3.0f; - auto input = Literal::CreateR4FromArray4D(input_array); + auto input = LiteralUtil::CreateR4FromArray4D(input_array); input = input->Relayout(layout); - Pad(AddParam(*input, &b), AddParam(*Literal::CreateR0(pad_value), &b), - padding_config); + Pad(AddParam(*input, &b), + AddParam(*LiteralUtil::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 25, 17, 11); expected_array.Fill(pad_value); @@ -329,7 +331,7 @@ XLA_TEST_P(PadTestFloat, Large2DPad) { padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 + 100 * dim); } - Pad(input, AddParam(*Literal::CreateR0(0.0f), &b), padding_config); + Pad(input, AddParam(*LiteralUtil::CreateR0(0.0f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); @@ -351,7 +353,8 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) { padding_config.mutable_dimensions(1)->set_edge_padding_low(6); padding_config.mutable_dimensions(1)->set_edge_padding_high(4); padding_config.mutable_dimensions(1)->set_interior_padding(2); - Pad(input, AddParam(*Literal::CreateR0(3.14f), &b), padding_config); + Pad(input, AddParam(*LiteralUtil::CreateR0(3.14f), &b), + padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); @@ -376,7 +379,8 @@ XLA_TEST_P(PadTestFloat, High2DPad) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); + Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -403,7 +407,8 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); + Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -430,7 +435,8 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding[dim]); } - Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); + Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -446,12 +452,13 @@ XLA_TEST_P(PadTestFloat, ReducePad) { XlaComputation add = CreateScalarAddComputation(FloatType(), &b); auto reduce = - Reduce(input, AddParam(*Literal::CreateR0(0.0), &b), add, {0}); + Reduce(input, AddParam(*LiteralUtil::CreateR0(0.0), &b), add, {0}); PaddingConfig padding_config = MakeNoPaddingConfig(3); padding_config.mutable_dimensions(0)->set_edge_padding_low(1); padding_config.mutable_dimensions(0)->set_edge_padding_high(1); - Pad(reduce, AddParam(*Literal::CreateR0(0.0f), &b), padding_config); + Pad(reduce, AddParam(*LiteralUtil::CreateR0(0.0f), &b), + padding_config); Array3D expected({{{0.0, 0.0}, {0.0, 0.0}}, {{2.0, 2.0}, {2.0, 2.0}}, diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index 2620063aa492902a705690d28d8124d16184d635..f6c762e7a4bee91a26c4c2e033c3717fef6d91d0 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -21,10 +21,10 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -42,7 +42,8 @@ class ParamsTest : public ClientLibraryTestBase {}; XLA_TEST_F(ParamsTest, ConstantR0F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR0(3.14159f); + std::unique_ptr param0_literal = + LiteralUtil::CreateR0(3.14159f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -54,7 +55,7 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR1({}); + std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -67,7 +68,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({3.14f, -100.25f}); + LiteralUtil::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -80,7 +81,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { XLA_TEST_F(ParamsTest, ConstantR1U8Param) { XlaBuilder builder(TestName()); string str("hello world"); - std::unique_ptr param0_literal = Literal::CreateR1U8(str); + std::unique_ptr param0_literal = LiteralUtil::CreateR1U8(str); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -94,7 +95,7 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) { XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR2FromArray2D(Array2D(3, 0)); + LiteralUtil::CreateR2FromArray2D(Array2D(3, 0)); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -106,7 +107,7 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XLA_TEST_F(ParamsTest, ConstantR2F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR2( + std::unique_ptr param0_literal = LiteralUtil::CreateR2( {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -122,12 +123,12 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) { XLA_TEST_F(ParamsTest, TwoParameters) { XlaBuilder builder(TestName()); - std::unique_ptr literal0 = Literal::CreateR1({1, 2}); + std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); auto param0 = Parameter(&builder, 0, literal0->shape(), "param0"); - std::unique_ptr literal1 = Literal::CreateR1({10, 20}); + std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); auto param1 = Parameter(&builder, 1, literal1->shape(), "param1"); @@ -153,7 +154,7 @@ XLA_TEST_F(ParamsTest, TwoParameters) { XLA_TEST_F(ParamsTest, MissingParameter) { // Test that an error is returned when a computation with an incomplete set of // parameters (parameter numbers not contiguous from 0) is executed. - std::unique_ptr literal = Literal::CreateR0(3.14159f); + std::unique_ptr literal = LiteralUtil::CreateR0(3.14159f); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -167,12 +168,12 @@ XLA_TEST_F(ParamsTest, MissingParameter) { XLA_TEST_F(ParamsTest, UnusedParameter) { XlaBuilder builder(TestName()); - std::unique_ptr literal0 = Literal::CreateR1({1, 2}); + std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); Parameter(&builder, 0, literal0->shape(), "param0"); - std::unique_ptr literal1 = Literal::CreateR1({10, 20}); + std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); Parameter(&builder, 1, literal1->shape(), "param1"); @@ -187,11 +188,12 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { // unused expression. XlaBuilder builder(TestName()); - std::unique_ptr literal0 = Literal::CreateR1({1, 2}); + std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = Literal::CreateR1({10, 20, 30}); + std::unique_ptr literal1 = + LiteralUtil::CreateR1({10, 20, 30}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); @@ -231,7 +233,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { std::vector sum_value = {{entry0, entry1}}; sum_value.resize(size); - std::unique_ptr literal = Literal::CreateR1(sum_value); + std::unique_ptr literal = LiteralUtil::CreateR1(sum_value); param_data_owner.push_back( client_->TransferToServer(*literal).ConsumeValueOrDie()); XlaOp param = Parameter(&builder, i, literal->shape(), "param"); @@ -266,7 +268,7 @@ XLA_TEST_F(ParamsTest, constexpr int kParamCount = 3000; for (int i = 0; i < kParamCount; ++i) { target += i; - std::unique_ptr literal = Literal::CreateR0(i); + std::unique_ptr literal = LiteralUtil::CreateR0(i); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); XlaOp param = Parameter(&builder, i, literal->shape(), "param"); @@ -298,7 +300,7 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( std::vector params; for (int i = 0; i < kParamCount; ++i) { target += i; - std::unique_ptr literal = Literal::CreateR1({i, i}); + std::unique_ptr literal = LiteralUtil::CreateR1({i, i}); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); XlaOp param = Parameter(&builder, i, literal->shape(), "param"); @@ -322,10 +324,10 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( std::vector> elements; std::vector ptrs; for (int i = 0; i < kParamCount; ++i) { - elements.push_back(Literal::CreateR1({target + i, target + i})); + elements.push_back(LiteralUtil::CreateR1({target + i, target + i})); ptrs.push_back(elements.back().get()); } - ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data); + ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data); } // Test large number of parameters flowing into a while-loop. @@ -354,7 +356,7 @@ XLA_TEST_F(ParamsTest, std::vector params; std::vector parameter_shapes; for (int i = 0; i < kParamCount; ++i) { - std::unique_ptr literal = Literal::CreateR1({i, i}); + std::unique_ptr literal = LiteralUtil::CreateR1({i, i}); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); XlaOp param = Parameter(&builder, i, literal->shape(), "param"); @@ -364,7 +366,7 @@ XLA_TEST_F(ParamsTest, // Add bool parameter for the loop condition. Use a parameter HLO instead of a // constant because DCE may eliminate the while-body otherwise. - std::unique_ptr bool_literal = Literal::CreateR0(false); + std::unique_ptr bool_literal = LiteralUtil::CreateR0(false); param_data_owner.push_back( std::move(client_->TransferToServer(*bool_literal)).ValueOrDie()); XlaOp bool_param = @@ -421,10 +423,10 @@ XLA_TEST_F(ParamsTest, std::vector> elements; std::vector ptrs; for (int i = 0; i < kParamCount; ++i) { - elements.push_back(Literal::CreateR1({i, i})); + elements.push_back(LiteralUtil::CreateR1({i, i})); ptrs.push_back(elements.back().get()); } - ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data); + ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data); } #endif @@ -441,9 +443,9 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { std::unique_ptr data = client_ - ->TransferToServer(*Literal::MakeTuple({ - Literal::CreateR1({1, 2, 3}).get(), - Literal::CreateR1({4, 5, 6}).get(), + ->TransferToServer(*LiteralUtil::MakeTuple({ + LiteralUtil::CreateR1({1, 2, 3}).get(), + LiteralUtil::CreateR1({4, 5, 6}).get(), })) .ConsumeValueOrDie(); @@ -455,7 +457,7 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { // Verifies that passing a 2x2 with {0, 1} layout returns the same value back // when (transferred to the server and) passed through a parameter. XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { - std::unique_ptr literal = Literal::CreateR2WithLayout( + std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); XlaBuilder builder(TestName()); Parameter(&builder, 0, literal->shape(), "input"); @@ -467,7 +469,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { // As above, but for {1, 0} layout. XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { - std::unique_ptr literal = Literal::CreateR2WithLayout( + std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( {{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0})); XlaBuilder builder(TestName()); Parameter(&builder, 0, literal->shape(), "input"); @@ -478,7 +480,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { } XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { - std::unique_ptr literal = Literal::CreateR2({ + std::unique_ptr literal = LiteralUtil::CreateR2({ {1, 3}, {2, 4}, }); diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index 6154ce671c37b450c7b1171190775c6b1ed8fbf6..2fc7f816b56db6f57ca835d1847476b6d622ce5e 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -29,14 +29,14 @@ namespace { class PredTest : public ClientLibraryTestBase { protected: - void TestCompare( - bool lhs, bool rhs, bool expected, - XlaOp (XlaBuilder::*op)(const xla::XlaOp&, const xla::XlaOp&, - tensorflow::gtl::ArraySlice)) { + void TestCompare(bool lhs, bool rhs, bool expected, + std::function)> + op) { XlaBuilder builder(TestName()); XlaOp lhs_op = ConstantR0(&builder, lhs); XlaOp rhs_op = ConstantR0(&builder, rhs); - (builder.*op)(lhs_op, rhs_op, {}); + op(lhs_op, rhs_op, {}); ComputeAndCompareR0(&builder, expected, {}); } }; @@ -54,27 +54,27 @@ TEST_F(PredTest, ConstantR0PredFalse) { } TEST_F(PredTest, ConstantR0PredCompareEq) { - TestCompare(true, false, false, &XlaBuilder::Eq); + TestCompare(true, false, false, &Eq); } TEST_F(PredTest, ConstantR0PredCompareNe) { - TestCompare(true, false, true, &XlaBuilder::Ne); + TestCompare(true, false, true, &Ne); } TEST_F(PredTest, ConstantR0PredCompareLe) { - TestCompare(true, false, false, &XlaBuilder::Le); + TestCompare(true, false, false, &Le); } TEST_F(PredTest, ConstantR0PredCompareLt) { - TestCompare(true, false, false, &XlaBuilder::Lt); + TestCompare(true, false, false, &Lt); } TEST_F(PredTest, ConstantR0PredCompareGe) { - TestCompare(true, false, true, &XlaBuilder::Ge); + TestCompare(true, false, true, &Ge); } TEST_F(PredTest, ConstantR0PredCompareGt) { - TestCompare(true, false, true, &XlaBuilder::Gt); + TestCompare(true, false, true, &Gt); } TEST_F(PredTest, ConstantR1Pred) { diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 8e163e885d0d6315341c213577a3beb0180b679a..326e13b3867f2f804e882e00e35850d0189ad8d7 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -177,12 +177,12 @@ XLA_TEST_F(PrngTest, Uniformity108) { EXPECT_LT(UniformChiSquared(108, 256), 132.144); } XLA_TEST_F(PrngTest, Uniformity256) { - EXPECT_LT(UniformChiSquared(256, 256), 293.248); + EXPECT_LT(UniformChiSquared(256, 512), 293.248); } XLA_TEST_F(PrngTest, MapUsingRng) { // Build a x -> (x + U[0,1)) computation. - auto build_sum_rng = [this](XlaBuilder& builder) { + auto build_sum_rng = [](XlaBuilder& builder) { auto b = builder.CreateSubBuilder("sum_with_rng"); auto x = Parameter(b.get(), 0, ShapeUtil::MakeShape(F32, {}), "input"); Add(x, @@ -193,7 +193,7 @@ XLA_TEST_F(PrngTest, MapUsingRng) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR1({2.2f, 5.3f, 4.4f, 5.5f}); + LiteralUtil::CreateR1({2.2f, 5.3f, 4.4f, 5.5f}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr param0_data, client_->TransferToServer(*param0_literal)); diff --git a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc index 526a38e8d1dbed9cdd4a31bfbec49bc5c6bb174b..fab2a65de109c670a6854c0fc1118162acf3d312 100644 --- a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc +++ b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc index 9052b188ed09a715b6ad7c3a40dc853d02cdd70c..a080dd1732bde21712cf47b4b57538cf4040f30e 100644 --- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc @@ -95,21 +95,21 @@ XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_GPU(Reduce)) { *reduce_input_shape->mutable_layout() = LayoutUtil::MakeLayout(reduce_layout.input_minor_to_major); - std::unique_ptr reduce_input = - Literal::CreateR4({{ /*i0=0*/ - {/*i1=0*/ - {-0.246092796, -0.179497838, -0.161181688}, - {-0.151643038, -0.240213156, -0.198156}}, - {/*i1=1*/ - {-0.14222312, -0.162200093, -0.193907976}, - {-0.239411, -0.198166847, -0.172471642}}}, - { /*i0=1*/ - {/*i1=0*/ - {-0.22965157, -0.218723893, -0.129257083}, - {-0.188762426, -0.16123569, -0.181166649}}, - {/*i1=1*/ - {-0.241772294, -0.245131493, -0.160247207}, - {-0.179881215, -0.23383224, -0.121976733}}}}); + std::unique_ptr reduce_input = LiteralUtil::CreateR4( + {{ /*i0=0*/ + {/*i1=0*/ + {-0.246092796, -0.179497838, -0.161181688}, + {-0.151643038, -0.240213156, -0.198156}}, + {/*i1=1*/ + {-0.14222312, -0.162200093, -0.193907976}, + {-0.239411, -0.198166847, -0.172471642}}}, + { /*i0=1*/ + {/*i1=0*/ + {-0.22965157, -0.218723893, -0.129257083}, + {-0.188762426, -0.16123569, -0.181166649}}, + {/*i1=1*/ + {-0.241772294, -0.245131493, -0.160247207}, + {-0.179881215, -0.23383224, -0.121976733}}}}); EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); } diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index 4c1aa121067eed465c6128ea7a34e0284f7af43e..531648fe3eb8e3941c5e3c012847ee68c616590f 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -22,9 +22,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -230,7 +230,8 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR1({input_values}); + std::unique_ptr a_literal = + LiteralUtil::CreateR1({input_values}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); auto a = Parameter(&builder, 0, a_literal->shape(), "a"); @@ -253,7 +254,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionBeforeFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR1({1.00001}); + std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); auto a = Parameter(&builder, 0, a_literal->shape(), "a"); @@ -282,7 +283,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR1({1.00001}); + std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); auto a = Parameter(&builder, 0, a_literal->shape(), "a"); @@ -308,7 +309,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR1({1.00001}); + std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); auto a = Parameter(&builder, 0, a_literal->shape(), "a"); @@ -332,7 +333,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedFusionContains)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR1({1.00001}); + std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); auto a = Parameter(&builder, 0, a_literal->shape(), "a"); @@ -357,7 +358,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR1({1.00001}); + std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); auto a = Parameter(&builder, 0, a_literal->shape(), "a"); diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index c9f57cbb16729627a5e9ad3d49438295a286989e..2065271a7f686c52c88df80b0efe8f2e1542d198 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -37,8 +37,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -67,12 +67,12 @@ class ReduceTest : public ClientLibraryTestBase { ReduceTest() { // Implementation note: laid out z >> y >> x by default. // clang-format off - literal_2d_ = Literal::CreateR2({ + literal_2d_ = LiteralUtil::CreateR2({ // x0 x1 x2 { 1.f, 2.f, 3.f}, // y0 { 4.f, 5.f, 6.f}, // y1 }); - literal_3d_ = Literal::CreateR3Projected({ + literal_3d_ = LiteralUtil::CreateR3Projected({ // x0 x1 x2 { 1.f, 2.f, 3.f}, // y0 { 4.f, 5.f, 6.f}, // y1 @@ -101,7 +101,7 @@ class ReduceTest : public ClientLibraryTestBase { } } std::unique_ptr input_literal = - Literal::CreateR1(AsSlice(input_data)); + LiteralUtil::CreateR1(AsSlice(input_data)); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -125,15 +125,15 @@ class ReduceTest : public ClientLibraryTestBase { XlaComputation reduce; if (and_reduce) { init_value = ConstantR0(&builder, true); - reduce = CreateScalarAndComputation(&builder); + reduce = CreateScalarAndComputation(PRED, &builder); } else { init_value = ConstantR0(&builder, false); - reduce = CreateScalarOrComputation(&builder); + reduce = CreateScalarOrComputation(PRED, &builder); } Reduce(pred_values, init_value, reduce, /*dimensions_to_reduce=*/{0}); - std::unique_ptr input_literal = Literal::CreateR1(input_data); + std::unique_ptr input_literal = LiteralUtil::CreateR1(input_data); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -163,10 +163,10 @@ class ReduceTest : public ClientLibraryTestBase { XlaComputation reduce_op; if (and_reduce) { init_value = ConstantR0(&builder, true); - reduce_op = CreateScalarAndComputation(&builder); + reduce_op = CreateScalarAndComputation(PRED, &builder); } else { init_value = ConstantR0(&builder, false); - reduce_op = CreateScalarOrComputation(&builder); + reduce_op = CreateScalarOrComputation(PRED, &builder); } Reduce(input_pred, init_value, reduce_op, @@ -175,7 +175,7 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(0, 1); std::unique_ptr input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = @@ -209,7 +209,7 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = @@ -237,7 +237,7 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = @@ -295,7 +295,7 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillUnique(initial_value); std::unique_ptr input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = @@ -450,7 +450,7 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -482,7 +482,7 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - Literal::CreateR2FromArray2D(input_data); + LiteralUtil::CreateR2FromArray2D(input_data); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -531,7 +531,7 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { Array3D input_data(rows, 2, cols / 2); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - Literal::CreateR3FromArray3D(input_data); + LiteralUtil::CreateR3FromArray3D(input_data); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -594,7 +594,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) { auto max = CreateScalarMaxComputation(F32, &builder); Array2D input(300, 250); input.FillRandom(214.0f); - auto input_literal = Literal::CreateR2FromArray2D(input); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input); Reduce(ConstantLiteral(&builder, *input_literal), ConstantR0(&builder, FLT_MIN), max, {0, 1}); auto input_max = FLT_MIN; @@ -609,7 +609,7 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) { auto min = CreateScalarMinComputation(F32, &builder); Array2D input(150, 130); input.FillRandom(214.0f); - auto input_literal = Literal::CreateR2FromArray2D(input); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input); Reduce(ConstantLiteral(&builder, *input_literal), ConstantR0(&builder, FLT_MAX), min, {0, 1}); @@ -623,7 +623,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) { XlaBuilder builder(TestName()); Array2D input({{1}, {2}}); auto min = CreateScalarMinComputation(U32, &builder); - auto input_literal = Literal::CreateR2FromArray2D(input); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input); auto initial_value = ConstantR0(&builder, std::numeric_limits::max()); @@ -635,7 +635,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) { XlaBuilder builder(TestName()); Array2D input({{1}, {2}}); auto max = CreateScalarMaxComputation(U32, &builder); - auto input_literal = Literal::CreateR2FromArray2D(input); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input); auto initial_value = ConstantR0(&builder, std::numeric_limits::min()); @@ -798,13 +798,17 @@ XLA_TEST_F(ReduceTest, VectorizedReduce_Min) { XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanAnd) { RunVectorizedReduceTestForType( - static_cast(CreateScalarAndComputation), + static_cast([](XlaBuilder* builder) { + return CreateScalarAndComputation(PRED, builder); + }), [](bool a, bool b) { return a && b; }, true); } XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanOr) { RunVectorizedReduceTestForType( - static_cast(CreateScalarOrComputation), + static_cast([](XlaBuilder* builder) { + return CreateScalarOrComputation(PRED, builder); + }), [](bool a, bool b) { return a || b; }, false); } @@ -818,7 +822,7 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) { // input_array.FillRandom(3.14f, 0.05); input_array.Fill(1.0f); - auto input_literal = Literal::CreateR3FromArray3D(input_array); + auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array); input_literal = input_literal->Relayout(LayoutUtil::MakeLayout(GetParam().layout)); std::unique_ptr input_data = @@ -872,7 +876,8 @@ XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) { auto a = ConstantR0(&builder, 2.0f); auto a2 = Abs(a); - std::unique_ptr b_literal = Literal::CreateR1({1.0f, 4.0f}); + std::unique_ptr b_literal = + LiteralUtil::CreateR1({1.0f, 4.0f}); std::unique_ptr b_data = client_->TransferToServer(*b_literal).ConsumeValueOrDie(); auto b = Parameter(&builder, 0, b_literal->shape(), "b"); @@ -900,7 +905,7 @@ class ReduceInitializerTest : public ReduceTest { auto init = ConstantR0(&builder, initializer); std::vector input_arr(num_elems, std::numeric_limits::lowest()); - auto input_literal = Literal::CreateR1(input_arr); + auto input_literal = LiteralUtil::CreateR1(input_arr); auto input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); Reduce(Parameter(&builder, 0, input_literal->shape(), "input"), init, @@ -950,10 +955,11 @@ XLA_TEST_F(ReduceTest, ReduceIdentity) { float operand[] = {42.0f}; float init = 58.5f; float expected = 42.0f; - std::unique_ptr input_literal = Literal::CreateR1(operand); + std::unique_ptr input_literal = + LiteralUtil::CreateR1(operand); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - std::unique_ptr input_literal2 = Literal::CreateR0(init); + std::unique_ptr input_literal2 = LiteralUtil::CreateR0(init); std::unique_ptr input_global_data2 = client_->TransferToServer(*input_literal2).ConsumeValueOrDie(); ComputeAndCompareR0( @@ -961,5 +967,32 @@ XLA_TEST_F(ReduceTest, ReduceIdentity) { ErrorSpec(0.0001)); } +XLA_TEST_F(ReduceTest, AndReduceU64) { + XlaBuilder builder(TestName()); + Array2D initializer = {{0x123456789ABCDEF0LL, 0x3BCDEF12A4567890LL}, + {0XFFFFFFFFFFFFFFD6LL, 101}, + {1, 0XFFFFFFFFFFFFFFFFLL}}; + auto reducer = CreateScalarAndComputation(U64, &builder); + auto m = ConstantR2FromArray2D(&builder, initializer); + Reduce(m, ConstantR0(&builder, 0xFFFFFFFFFFFFFFFFLL), reducer, {1}); + + std::vector expected = {0x1204461080145890LL, 68, 1}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(ReduceTest, OrReduceU64) { + XlaBuilder builder(TestName()); + Array2D initializer = {{0x123456789ABCDEF0LL, 0x3BCDEF12A4567890LL}, + {0xFFFFFFFFFFFFFFD6LL, 101}, + {1, 0xCAFEBEEFABABABABLL}}; + auto reducer = CreateScalarOrComputation(U64, &builder); + auto m = ConstantR2FromArray2D(&builder, initializer); + Reduce(m, ConstantR0(&builder, 0), reducer, {1}); + + std::vector expected = {0X3BFDFF7ABEFEFEF0LL, 0XFFFFFFFFFFFFFFF7LL, + 0xCAFEBEEFABABABABLL}; + ComputeAndCompareR1(&builder, expected, {}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 741974480c6a862a7794aa6257f131a5893e963d..1bd6fdab31d6c3516339bdb98459ffe3bbdef1d1 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -24,8 +24,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -70,8 +70,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - auto init = - CreateConstantFromLiteral(*Literal::CreateR0(0.0f), &builder_); + auto init = CreateConstantFromLiteral(*LiteralUtil::CreateR0(0.0f), + &builder_); ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_), window_dimensions, window_strides, padding); @@ -81,7 +81,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - auto init = CreateConstantFromLiteral(Literal::MinValue(F32), &builder_); + auto init = + CreateConstantFromLiteral(LiteralUtil::MinValue(F32), &builder_); ReduceWindow(input, init, CreateScalarMaxComputation(FloatType(), &builder_), window_dimensions, window_strides, padding); @@ -91,7 +92,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - auto init = CreateConstantFromLiteral(Literal::MaxValue(F32), &builder_); + auto init = + CreateConstantFromLiteral(LiteralUtil::MaxValue(F32), &builder_); ReduceWindow(input, init, CreateScalarMinComputation(FloatType(), &builder_), window_dimensions, window_strides, padding); @@ -102,9 +104,9 @@ class ReduceWindowTest : public ::testing::WithParamInterface, TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { const auto input = CreateConstantFromLiteral( - *Literal::CreateR1({1, 1, 1, 1}), &builder_); + *LiteralUtil::CreateR1({1, 1, 1, 1}), &builder_); const auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0(0), &builder_); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(0), &builder_); TF_ASSERT_OK(builder_.first_error()); ReduceWindow(input, init_value, CreateScalarAddComputation(FloatType(), &builder_), @@ -119,32 +121,32 @@ TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { // Regression test for b/68964348. TEST_P(ReduceWindowTest, R0ReduceWindow) { const auto input = - CreateConstantFromLiteral(*Literal::CreateR0(42.0), &builder_); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(42.0), &builder_); const auto init = - CreateConstantFromLiteral(*Literal::CreateR0(1.0), &builder_); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(1.0), &builder_); ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_), /*window_dimensions=*/{}, /*window_strides=*/{}, Padding::kSame); - ComputeAndCompareLiteral(&builder_, *Literal::CreateR0(43.0), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR0(43.0), {}, ErrorSpec(0.00001)); } TEST_P(ReduceWindowTest, Min3In5Stride2) { const auto input = CreateConstantFromLiteral( - *Literal::CreateR1({10000, 1000, 100, 10, 1}), &builder_); + *LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, {3}, {2}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *Literal::CreateR1({100, 1}), {}, - ErrorSpec(0.00001)); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({100, 1}), + {}, ErrorSpec(0.00001)); } TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) { const auto input = CreateConstantFromLiteral( - *Literal::CreateR1({10000, 1000, 100, 10, 1}), &builder_); + *LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1}, Padding::kSame); ComputeAndCompareLiteral(&builder_, - *Literal::CreateR1({1000, 100, 10, 1, 1}), {}, - ErrorSpec(0.00001)); + *LiteralUtil::CreateR1({1000, 100, 10, 1, 1}), + {}, ErrorSpec(0.00001)); } XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) { @@ -156,7 +158,7 @@ XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -171,7 +173,7 @@ TEST_P(ReduceWindowTest, NonSquareSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -185,7 +187,7 @@ TEST_P(ReduceWindowTest, MiddleDimsSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1}, {1, 2, 2, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -202,7 +204,7 @@ TEST_P(ReduceWindowTest, Along2ndMinorDim) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -224,8 +226,8 @@ TEST_P(ReduceWindowTest, AmongMajor2Dims) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), + {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { @@ -247,8 +249,8 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), + {}, DefaultErrorSpec()); } // Tests the super windowing logic w.r.t handling prime number of windows in a @@ -272,8 +274,8 @@ TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), + {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { @@ -289,8 +291,8 @@ TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), + {}, DefaultErrorSpec()); } // Tests a reduction function that is not a simple add/min/max/etc. @@ -308,12 +310,12 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { auto lhs = Parameter(b.get(), 0, scalar, "lhs"); auto rhs = Parameter(b.get(), 1, scalar, "rhs"); Min(Add(lhs, rhs), - CreateConstantFromLiteral(*Literal::CreateR0(8.0f), b.get())); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(8.0f), b.get())); XlaComputation reduce_fn = b->BuildAndNoteError(); ReduceWindow( input, - CreateConstantFromLiteral(*Literal::CreateR0(0.0f), &builder_), + CreateConstantFromLiteral(*LiteralUtil::CreateR0(0.0f), &builder_), reduce_fn, /*window_dimensions=*/{1, 1, 2, 1}, /*window_strides=*/{1, 1, 1, 1}, padding); @@ -327,15 +329,15 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { /*window=*/{1, 1, 2, 1}, /*stride=*/{1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*expected), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*expected), + {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, R4UnitWindow) { Array4D input_array(13, 12, 8, 15); input_array.FillRandom(2.f, 2.f); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( @@ -347,7 +349,7 @@ TEST_P(ReduceWindowTest, R4UnitWindow) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -376,7 +378,7 @@ XLA_TEST_P(ReduceWindowTest, R6Add) { auto shape = ShapeUtil::MakeShape(F32, input_dims); std::unique_ptr arg_literal = - Literal::CreateFullWithDescendingLayout(input_dims, 1.0f); + LiteralUtil::CreateFullWithDescendingLayout(input_dims, 1.0f); const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); @@ -385,7 +387,7 @@ XLA_TEST_P(ReduceWindowTest, R6Add) { std::vector output_dims = {8, 8, 6, 6, 8, 8}; std::unique_ptr expected = - Literal::CreateFullWithDescendingLayout(output_dims, 9.0f); + LiteralUtil::CreateFullWithDescendingLayout(output_dims, 9.0f); ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } @@ -394,7 +396,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { Array4D input_array(2, 1, 27, 119); input_array.FillRandom(2.0f); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( @@ -408,7 +410,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -416,7 +418,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { Array4D input_array(3, 2, 4, 64); input_array.FillRandom(2.0f); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( @@ -430,7 +432,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -438,7 +440,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { Array4D input_array(1, 3, 12, 200); input_array.FillRandom(2.0f); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( @@ -452,7 +454,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -473,18 +475,18 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), + {}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) { std::vector input_vector(128 * 9, 1); const auto input = CreateConstantFromLiteral( - *Literal::CreateR1(input_vector), &builder_); + *LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {32}, {128}, Padding::kValid); ComputeAndCompareLiteral( &builder_, - *Literal::CreateR1({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, + *LiteralUtil::CreateR1({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, DefaultErrorSpec()); } @@ -499,9 +501,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; const auto input = CreateConstantFromLiteral( - *Literal::CreateR1(input_vector), &builder_); + *LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {128}, {128}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *Literal::CreateR1({1088}), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({1088}), {}, DefaultErrorSpec()); } @@ -516,9 +518,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128) { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; const auto input = CreateConstantFromLiteral( - *Literal::CreateR1(input_vector), &builder_); + *LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {128}, {1}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *Literal::CreateR1({1088}), {}, + ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({1088}), {}, DefaultErrorSpec()); } @@ -535,14 +537,15 @@ TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd( input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, + *LiteralUtil::CreateFromArray(*res), {}, + DefaultErrorSpec()); } TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { Array2D input_array(6, 4, 1.0f); XlaOp input = Broadcast( - CreateConstantFromLiteral(Literal::One(F32), &builder_), {6, 4}); + CreateConstantFromLiteral(LiteralUtil::One(F32), &builder_), {6, 4}); Padding padding = Padding::kSame; ReduceWindowAdd(input, {4, 2}, {3, 3}, padding); @@ -550,8 +553,9 @@ TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3}, padding); - ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, + *LiteralUtil::CreateFromArray(*res), {}, + DefaultErrorSpec()); } INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest, @@ -609,7 +613,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, param.base_bounds[2], param.base_bounds[3]); input.FillIota(1); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", @@ -621,7 +625,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, } auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); CHECK(param.reducer == kAdd || param.reducer == kMax); auto computation = param.reducer == kAdd ? CreateScalarAddComputation(FloatType(), &b) @@ -647,7 +651,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, /*stride=*/param.strides, /*padding=*/padding); std::unique_ptr expected_literal = - Literal::CreateFromArray(*expected); + LiteralUtil::CreateFromArray(*expected); const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout( input_literal->shape().element_type(), AsInt64Slice(expected_literal->shape().dimensions()), param.layout); @@ -959,14 +963,14 @@ TEST_P(R3ReduceWindowTest, Add) { Array3D input(param.base_bounds[0], param.base_bounds[1], param.base_bounds[2], 1.0f); std::unique_ptr input_literal = - Literal::CreateR3FromArray3DWithLayout( + LiteralUtil::CreateR3FromArray3DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", &b, ¶meter); auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); ReduceWindow(/*operand=*/parameter, /*init_value=*/init_value, /*computation=*/CreateScalarAddComputation(FloatType(), &b), @@ -977,7 +981,7 @@ TEST_P(R3ReduceWindowTest, Add) { /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/param.padding); - ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), + ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected), {input_arg.get()}, DefaultErrorSpec()); } @@ -1093,7 +1097,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, const float kInitValue = 0.0f; Array2D input(param.base_bounds[0], param.base_bounds[1], 1.0f); std::unique_ptr input_literal = - Literal::CreateR2FromArray2DWithLayout( + LiteralUtil::CreateR2FromArray2DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; @@ -1107,7 +1111,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, @@ -1123,7 +1127,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), + ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected), {input_arg.get()}, DefaultErrorSpec()); } }; @@ -1292,7 +1296,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { std::vector input_vector(param.base_bounds[0]); std::iota(std::begin(input_vector), std::end(input_vector), 0); std::unique_ptr input_literal = - Literal::CreateR1(tensorflow::gtl::ArraySlice(input_vector)); + LiteralUtil::CreateR1(tensorflow::gtl::ArraySlice(input_vector)); XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", &b, ¶meter); @@ -1304,7 +1308,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); auto init_value = - CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, @@ -1323,7 +1327,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareLiteral(&b, *Literal::CreateR1(*expected), + ComputeAndCompareLiteral(&b, *LiteralUtil::CreateR1(*expected), {input_arg.get()}, DefaultErrorSpec()); } diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index bebd814fa8b863428750dc12a93d1ef5ad7e6685..d8914513819415368a628eab1f482f9644dd46b1 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -91,10 +91,10 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { // Run it. std::unique_ptr x_data = - client_->TransferToServer(*Literal::CreateR0(2)) + client_->TransferToServer(*LiteralUtil::CreateR0(2)) .ConsumeValueOrDie(); std::unique_ptr y_data = - client_->TransferToServer(*Literal::CreateR0(3)) + client_->TransferToServer(*LiteralUtil::CreateR0(3)) .ConsumeValueOrDie(); std::unique_ptr literal = client_ diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc index 5812fe442b25da1b7e34494d00fe8025d29b2802..368f5583c9ce3773e57b858ff7606f679346529a 100644 --- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc @@ -22,9 +22,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index d3d6c3c7d703161e433740acbbd58d51ba1434af..382d1b1ae741285dcd1f7761edb82a5c333887af 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -55,39 +55,39 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) { XlaBuilder builder(TestName()); Array2D input_array(1, 1); input_array.Fill(1.0f); - auto input_literal = Literal::CreateR2FromArray2D(input_array); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); - auto expected_literal = Literal::CreateR1({1.0f}); + auto expected_literal = LiteralUtil::CreateR1({1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateR1({1.0f}); + auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{}); - auto expected_literal = Literal::CreateR1({1.0f}); + auto expected_literal = LiteralUtil::CreateR1({1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateR1({1.0f}); + auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0}); - auto expected_literal = Literal::CreateR1({1.0f}); + auto expected_literal = LiteralUtil::CreateR1({1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -97,7 +97,7 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { XlaBuilder builder(TestName()); Array2D input_array(1, 1); input_array.Fill(1.0f); - auto input_literal = Literal::CreateR2FromArray2D(input_array); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); @@ -105,7 +105,7 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { /*new_sizes=*/{}); auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie(); - auto expected_literal = Literal::CreateR0(1.0f); + auto expected_literal = LiteralUtil::CreateR0(1.0f); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -113,14 +113,14 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = Literal::CreateR0(1.0f); + std::unique_ptr param0_literal = LiteralUtil::CreateR0(1.0f); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", &builder, ¶meter); auto a = Neg(parameter); Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); - auto expected_literal = Literal::CreateR1({-1.0f}); + auto expected_literal = LiteralUtil::CreateR1({-1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -128,12 +128,12 @@ XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { XLA_TEST_P(ReshapeTest, Trivial0x3) { XlaBuilder builder(TestName()); Array2D input_array(0, 3); - auto input_literal = Literal::CreateR2FromArray2D(input_array); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); - auto expected_literal = Literal::CreateR1({}); + auto expected_literal = LiteralUtil::CreateR1({}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -142,12 +142,12 @@ XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) { XlaBuilder builder(TestName()); std::unique_ptr param0_literal = - Literal::CreateR2FromArray2D(Array2D(0, 3)); + LiteralUtil::CreateR2FromArray2D(Array2D(0, 3)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); - auto expected_literal = Literal::CreateR1({}); + auto expected_literal = LiteralUtil::CreateR1({}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -155,12 +155,12 @@ XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) { XLA_TEST_P(ReshapeTest, Trivial3x0) { XlaBuilder builder(TestName()); Array2D input_array(3, 0); - auto input_literal = Literal::CreateR2FromArray2D(input_array); + auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); - auto expected_literal = Literal::CreateR1({}); + auto expected_literal = LiteralUtil::CreateR1({}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -168,12 +168,12 @@ XLA_TEST_P(ReshapeTest, Trivial3x0) { // Collapses a 2-dimensional row vector to 1 dimension. XLA_TEST_P(ReshapeTest, Trivial1x3) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateR2({{1.0f, 2.0f, 3.0f}}); + auto input_literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); - auto expected_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f}); + auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -181,12 +181,12 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) { // Collapses a 2-dimensional column vector to 1 dimension. XLA_TEST_P(ReshapeTest, Trivial3x1) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateR2({{1.0f}, {2.0f}, {3.0f}}); + auto input_literal = LiteralUtil::CreateR2({{1.0f}, {2.0f}, {3.0f}}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); - auto expected_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f}); + auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -194,13 +194,13 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) { // Splits an empty vector into an empty matrix. XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateR1({}); + auto input_literal = LiteralUtil::CreateR1({}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0}, /*new_sizes=*/{2, 0}); - auto expected_literal = Literal::CreateR2({{}, {}}); + auto expected_literal = LiteralUtil::CreateR2({{}, {}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -209,14 +209,14 @@ XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) { XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { XlaBuilder builder(TestName()); auto input_literal = - Literal::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0}, /*new_sizes=*/{2, 3}); auto expected_literal = - Literal::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); + LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -224,13 +224,13 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { // Transposes a 2x0 array to a 0x2 array. XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(Array2D(0, 2)); + auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 2)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 0}); - auto expected_literal = Literal::CreateR2({{}, {}}); + auto expected_literal = LiteralUtil::CreateR2({{}, {}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -239,7 +239,7 @@ XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) { XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { XlaBuilder builder(TestName()); auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); - auto input_literal = Literal::CreateFromArray(*simple); + auto input_literal = LiteralUtil::CreateFromArray(*simple); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); @@ -247,7 +247,7 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { /*new_sizes=*/{3, 1}); auto expected = ReferenceUtil::TransposeArray2D(*simple); - auto expected_literal = Literal::CreateFromArray(*expected); + auto expected_literal = LiteralUtil::CreateFromArray(*expected); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -256,7 +256,7 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { XLA_TEST_P(ReshapeTest, TransposeAsReshape) { XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto input_literal = Literal::CreateFromArray(*a4x3); + auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); @@ -264,7 +264,7 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) { /*new_sizes=*/{3, 4}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); - auto expected_literal = Literal::CreateFromArray(*expected); + auto expected_literal = LiteralUtil::CreateFromArray(*expected); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -272,12 +272,12 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) { // Transposes a 0x4 array with XlaBuilder::Transpose. XLA_TEST_P(ReshapeTest, Transpose0x4) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(Array2D(0, 4)); + auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 4)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Transpose(parameter, {1, 0}); - auto expected_literal = Literal::CreateR2({{}, {}, {}, {}}); + auto expected_literal = LiteralUtil::CreateR2({{}, {}, {}, {}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -286,14 +286,14 @@ XLA_TEST_P(ReshapeTest, Transpose0x4) { XLA_TEST_P(ReshapeTest, Transpose4x3) { XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto input_literal = Literal::CreateFromArray(*a4x3); + auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Transpose(parameter, {1, 0}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); - auto expected_literal = Literal::CreateFromArray(*expected); + auto expected_literal = LiteralUtil::CreateFromArray(*expected); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -302,26 +302,27 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) { // rearrangement of the originals (split), but no reordering (no shuffle). XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(Array2D(6, 0)); + auto input_literal = LiteralUtil::CreateFromArray(Array2D(6, 0)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 3, 0, 0}); - auto expected_literal = Literal::CreateFromArray(Array4D(2, 3, 0, 0)); + auto expected_literal = + LiteralUtil::CreateFromArray(Array4D(2, 3, 0, 0)); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(Array4D(2, 3, 4, 0)); + auto input_literal = LiteralUtil::CreateFromArray(Array4D(2, 3, 4, 0)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{24, 0}); - auto expected_literal = Literal::CreateFromArray(Array2D(24, 0)); + auto expected_literal = LiteralUtil::CreateFromArray(Array2D(24, 0)); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -331,7 +332,7 @@ XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) { XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto input_literal = Literal::CreateFromArray(*a4x3); + auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); @@ -339,20 +340,20 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { /*new_sizes=*/{2, 6}); auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); - auto expected_literal = Literal::CreateFromArray(*expected); + auto expected_literal = LiteralUtil::CreateFromArray(*expected); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(Array2D(0, 6)); + auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 6)); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 0}); - auto expected_literal = Literal::CreateFromArray(Array2D(3, 0)); + auto expected_literal = LiteralUtil::CreateFromArray(Array2D(3, 0)); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -362,7 +363,7 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto input_literal = Literal::CreateFromArray(*a4x3); + auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); @@ -370,7 +371,7 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { /*new_sizes=*/{2, 6}); Array2D expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}}); - auto expected_literal = Literal::CreateFromArray(expected); + auto expected_literal = LiteralUtil::CreateFromArray(expected); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -388,13 +389,13 @@ static Array3D ArrayForDocR3Tests() { XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, /*new_sizes=*/{24}); - auto expected_literal = Literal::CreateR1( + auto expected_literal = LiteralUtil::CreateR1( {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, @@ -403,33 +404,33 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, /*new_sizes=*/{8, 3}); - auto expected_literal = Literal::CreateR2({{10, 11, 12}, - {15, 16, 17}, - {20, 21, 22}, - {25, 26, 27}, - {30, 31, 32}, - {35, 36, 37}, - {40, 41, 42}, - {45, 46, 47}}); + auto expected_literal = LiteralUtil::CreateR2({{10, 11, 12}, + {15, 16, 17}, + {20, 21, 22}, + {25, 26, 27}, + {30, 31, 32}, + {35, 36, 37}, + {40, 41, 42}, + {45, 46, 47}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{24}); - auto expected_literal = Literal::CreateR1( + auto expected_literal = LiteralUtil::CreateR1( {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, @@ -438,33 +439,33 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{8, 3}); - auto expected_literal = Literal::CreateR2({{10, 20, 30}, - {40, 11, 21}, - {31, 41, 12}, - {22, 32, 42}, - {15, 25, 35}, - {45, 16, 26}, - {36, 46, 17}, - {27, 37, 47}}); + auto expected_literal = LiteralUtil::CreateR2({{10, 20, 30}, + {40, 11, 21}, + {31, 41, 12}, + {22, 32, 42}, + {15, 25, 35}, + {45, 16, 26}, + {36, 46, 17}, + {27, 37, 47}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { XlaBuilder builder(TestName()); - auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{2, 6, 2}); - auto expected_literal = Literal::CreateR3( + auto expected_literal = LiteralUtil::CreateR3( {{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}}, {{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, @@ -491,12 +492,12 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { Array4D t2x2x2x3(2, 2, 2, 3); auto filler2x3 = MakeLinspaceArray2D(1.0f, 6.0f, 2, 3); t2x2x2x3.FillWithYX(*filler2x3); - auto input_literal = Literal::CreateFromArray(t2x2x2x3); + auto input_literal = LiteralUtil::CreateFromArray(t2x2x2x3); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); - auto expected_literal = Literal::CreateR2( + auto expected_literal = LiteralUtil::CreateR2( {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); @@ -516,7 +517,7 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { t(1, 0, 0, 1) = 5; t(1, 0, 1, 0) = 6; t(1, 0, 1, 1) = 7; - auto input_literal = Literal::CreateFromArray(t); + auto input_literal = LiteralUtil::CreateFromArray(t); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); @@ -524,7 +525,7 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { /*new_sizes=*/{2, 4}); auto expected_literal = - Literal::CreateR2({{0, 1, 2, 3}, {4, 5, 6, 7}}); + LiteralUtil::CreateR2({{0, 1, 2, 3}, {4, 5, 6, 7}}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, zero_error_spec_); } @@ -545,7 +546,7 @@ XLA_TEST_P(ReshapeTest, ToScalar) { &b, ¶meter); Reshape(parameter, dimensions, {}); - auto expected_literal = Literal::CreateR0(83.0f); + auto expected_literal = LiteralUtil::CreateR0(83.0f); ComputeAndCompareLiteral(&b, *expected_literal, {input.get()}, zero_error_spec_); } @@ -553,7 +554,7 @@ XLA_TEST_P(ReshapeTest, ToScalar) { XLA_TEST_P(ReshapeTest, BadDimensions) { XlaBuilder b(TestName()); - auto input_literal = Literal::CreateR1({1.0f}); + auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, ¶meter); @@ -565,7 +566,7 @@ XLA_TEST_P(ReshapeTest, BadDimensions) { XLA_TEST_P(ReshapeTest, BadNewSizes) { XlaBuilder b(TestName()); - auto input_literal = Literal::CreateR1({1.0f, 2.0f}); + auto input_literal = LiteralUtil::CreateR1({1.0f, 2.0f}); XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, ¶meter); @@ -577,7 +578,8 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) { XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { XlaBuilder builder(TestName()); // clang-format off - auto input_literal = Literal::CreateR4FromArray4DWithLayout(Array4D{ + auto input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + Array4D{ { { {0, 1}, @@ -622,16 +624,16 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { ->ExecuteAndTransfer(computation, {input.get()}, &execution_options) .ConsumeValueOrDie(); std::unique_ptr expected = - Literal::CreateR2FromArray2D(expected_array); + LiteralUtil::CreateR2FromArray2D(expected_array); if (use_bfloat16()) { - expected = Literal::ConvertF32ToBF16(*expected); + expected = LiteralUtil::ConvertF32ToBF16(*expected); } EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); } XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { XlaBuilder builder(TestName()); - std::unique_ptr input_literal = Literal::CreateR2({ + std::unique_ptr input_literal = LiteralUtil::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, @@ -642,7 +644,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off - auto expected_literal = Literal::CreateR4({ + auto expected_literal = LiteralUtil::CreateR4({ {{{0, 1, 2, 3}}, {{4, 5, 6, 7}}}, {{{100, 101, 102, 103}}, @@ -658,7 +660,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { // Tests R2->R4 reshape with the reshape dimensions {1, 0}. XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { XlaBuilder builder(TestName()); - std::unique_ptr input_literal = Literal::CreateR2({ + std::unique_ptr input_literal = LiteralUtil::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, @@ -669,7 +671,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off - auto expected_literal = Literal::CreateR4({ + auto expected_literal = LiteralUtil::CreateR4({ {{{0, 100, 200, 1}}, {{101, 201, 2, 102}}}, {{{202, 3, 103, 203}}, @@ -691,7 +693,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( @@ -699,7 +701,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); std::unique_ptr expected = - Literal::ReshapeSlice({2, 1}, {1, 0}, *input_literal); + LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, *input_literal); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, zero_error_spec_); } @@ -713,7 +715,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( @@ -721,7 +723,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); std::unique_ptr expected = - Literal::ReshapeSlice({4, 2}, {1, 0}, *input_literal); + LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, *input_literal); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, zero_error_spec_); } @@ -736,7 +738,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( @@ -749,7 +751,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { expected_array(indices[0], indices[2] * 30 + indices[1] * 3 + indices[3]) = *cell; }); - auto expected = Literal::CreateR2FromArray2D(expected_array); + auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, zero_error_spec_); } @@ -763,7 +765,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( @@ -785,7 +787,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { // Since the reshape is a no-op, verify that it does not change the underlying // data. if (use_bfloat16()) { - auto expected = Literal::ConvertF32ToBF16(*input_literal); + auto expected = LiteralUtil::ConvertF32ToBF16(*input_literal); EXPECT_EQ(expected->data(), output_literal->data()); } else { EXPECT_EQ(input_literal->data(), output_literal->data()); @@ -794,7 +796,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { XlaBuilder builder(TestName()); - auto literal_1x2x3x4 = Literal::CreateR4( + auto literal_1x2x3x4 = LiteralUtil::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); @@ -808,7 +810,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { } XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { - auto literal_1x2x3x4 = Literal::CreateR4( + auto literal_1x2x3x4 = LiteralUtil::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); @@ -820,7 +822,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { /*new_sizes=*/{2, 4, 3, 1}); // clang-format off - auto expected_2x4x3x1 = Literal::CreateR4( + auto expected_2x4x3x1 = LiteralUtil::CreateR4( {{{{1}, {5}, {9}}, {{2}, {6}, {10}}, {{3}, {7}, {11}}, @@ -844,7 +846,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; @@ -854,7 +856,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -873,7 +875,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; @@ -883,7 +885,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -902,7 +904,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; @@ -912,7 +914,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -932,7 +934,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; @@ -942,7 +944,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -961,7 +963,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - Literal::CreateR4FromArray4DWithLayout( + LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({0, 1, 2, 3})); XlaBuilder builder(TestName()); XlaOp parameter; @@ -971,7 +973,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - Literal::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal) + LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal) ->Relayout(input_literal->shape().layout()); // Specify the requested output shape explicitly to ensure that this reshape diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index 662bc42224851ac19c690129f525953e6d410a55..41e49b4003236d55d85592315652a0ddefd5c485 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -82,7 +82,7 @@ TEST_P(FloatReverseTest, Reverses) { std::vector input_vector( ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims))); std::iota(input_vector.begin(), input_vector.end(), 0.0); - auto r1_literal = Literal::CreateR1(input_vector); + auto r1_literal = LiteralUtil::CreateR1(input_vector); auto input_literal = r1_literal->Reshape(spec.input_dims).ConsumeValueOrDie(); XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc index 7cfca781acda15879075f4386c2096e537877aac..a620fe19085d98c8b6642b25b159d6c2308bdae2 100644 --- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/packed_literal_reader.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc index f334a8c1318a59bbfdd27dd1a63ed162600089ce..a8193c2eac05ba4f0df339909f3e82a28ac35253 100644 --- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -46,61 +46,62 @@ class RoundTripTransferTest : public ClientLibraryTestBase { }; TEST_F(RoundTripTransferTest, R0S32) { - RoundTripTest(*Literal::CreateR0(42)); + RoundTripTest(*LiteralUtil::CreateR0(42)); } TEST_F(RoundTripTransferTest, R0F32) { - RoundTripTest(*Literal::CreateR0(42.0)); + RoundTripTest(*LiteralUtil::CreateR0(42.0)); } TEST_F(RoundTripTransferTest, R1F32_Len0) { - RoundTripTest(*Literal::CreateR1({})); + RoundTripTest(*LiteralUtil::CreateR1({})); } TEST_F(RoundTripTransferTest, R1F32_Len2) { - RoundTripTest(*Literal::CreateR1({42.0, 64.0})); + RoundTripTest(*LiteralUtil::CreateR1({42.0, 64.0})); } TEST_F(RoundTripTransferTest, R1F32_Len256) { std::vector values(256); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*Literal::CreateR1(values)); + RoundTripTest(*LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1024) { std::vector values(1024); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*Literal::CreateR1(values)); + RoundTripTest(*LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1025) { std::vector values(1025); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*Literal::CreateR1(values)); + RoundTripTest(*LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len4096) { std::vector values(4096); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*Literal::CreateR1(values)); + RoundTripTest(*LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R2F32_Len10x0) { - RoundTripTest(*Literal::CreateR2FromArray2D(Array2D(10, 0))); + RoundTripTest( + *LiteralUtil::CreateR2FromArray2D(Array2D(10, 0))); } TEST_F(RoundTripTransferTest, R2F32_Len2x2) { - RoundTripTest(*Literal::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); + RoundTripTest(*LiteralUtil::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); } TEST_F(RoundTripTransferTest, R3F32) { RoundTripTest( - *Literal::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); + *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); } TEST_F(RoundTripTransferTest, R4F32) { - RoundTripTest(*Literal::CreateR4({{ + RoundTripTest(*LiteralUtil::CreateR4({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, @@ -108,33 +109,36 @@ TEST_F(RoundTripTransferTest, R4F32) { } TEST_F(RoundTripTransferTest, EmptyTuple) { - RoundTripTest(*Literal::MakeTuple({})); + RoundTripTest(*LiteralUtil::MakeTuple({})); } TEST_F(RoundTripTransferTest, TupleOfR1F32) { - RoundTripTest(*Literal::MakeTuple({Literal::CreateR1({1, 2}).get(), - Literal::CreateR1({3, 4}).get()})); + RoundTripTest( + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2}).get(), + LiteralUtil::CreateR1({3, 4}).get()})); } TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) { - RoundTripTest(*Literal::MakeTuple({Literal::CreateR1({}).get(), - Literal::CreateR1({3, 4}).get()})); + RoundTripTest( + *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({}).get(), + LiteralUtil::CreateR1({3, 4}).get()})); } TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) { - RoundTripTest(*Literal::MakeTuple({Literal::CreateR0(1.0).get(), - Literal::CreateR1({2, 3}).get()})); + RoundTripTest( + *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(1.0).get(), + LiteralUtil::CreateR1({2, 3}).get()})); } // Below two tests are added to identify the cost of large data transfers. TEST_F(RoundTripTransferTest, R2F32_Large) { - RoundTripTest(*Literal::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); + RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); } TEST_F(RoundTripTransferTest, R4F32_Large) { Array4D array4d(2, 2, 256, 256); array4d.FillWithMultiples(1.0f); - RoundTripTest(*Literal::CreateR4FromArray4D(array4d)); + RoundTripTest(*LiteralUtil::CreateR4FromArray4D(array4d)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index d0ebb108aec99660c757493fa23b185baf35874a..e42c71eb284deb2e50d6ea4b47fa707e4bc14ffc 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -19,8 +19,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -44,25 +45,26 @@ class ScalarComputationsTest : public ClientLibraryTestBase { protected: // A template for building and running a binary comparison test. template - void TestCompare( - NativeT lhs, NativeT rhs, bool expected, - XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&, - tensorflow::gtl::ArraySlice)) { + void TestCompare(NativeT lhs, NativeT rhs, bool expected, + std::function)> + op) { XlaBuilder builder(TestName()); XlaOp lhs_op = ConstantR0(&builder, lhs); XlaOp rhs_op = ConstantR0(&builder, rhs); - (builder.*op)(lhs_op, rhs_op, {}); + op(lhs_op, rhs_op, {}); ComputeAndCompareR0(&builder, expected, {}); } template void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected, - XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&, - tensorflow::gtl::ArraySlice)) { + std::function)> + op) { XlaBuilder builder(TestName()); XlaOp lhs_op = ConstantR0(&builder, lhs); XlaOp rhs_op = ConstantR0(&builder, rhs); - (builder.*op)(lhs_op, rhs_op, {}); + op(lhs_op, rhs_op, {}); ComputeAndCompareR0(&builder, expected, {}); } }; @@ -161,7 +163,7 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { ConvertElementType(a, F32); int64 value = 3LL << 35; - std::unique_ptr a_literal = Literal::CreateR0(value); + std::unique_ptr a_literal = LiteralUtil::CreateR0(value); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); ComputeAndCompareR0(&builder, static_cast(value), @@ -225,9 +227,9 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = Literal::CreateR0(2.1f); - std::unique_ptr b_literal = Literal::CreateR0(5.5f); - std::unique_ptr c_literal = Literal::CreateR0(0.5f); + std::unique_ptr a_literal = LiteralUtil::CreateR0(2.1f); + std::unique_ptr b_literal = LiteralUtil::CreateR0(5.5f); + std::unique_ptr c_literal = LiteralUtil::CreateR0(0.5f); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); @@ -374,8 +376,8 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { for (uint32 divisor : vals) { if (divisor != 0) { for (uint32 dividend : vals) { - auto dividend_literal = Literal::CreateR0(dividend); - auto divisor_literal = Literal::CreateR0(divisor); + auto dividend_literal = LiteralUtil::CreateR0(dividend); + auto divisor_literal = LiteralUtil::CreateR0(divisor); TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, client_->TransferToServer(*dividend_literal)); TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, @@ -386,7 +388,8 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { {dividend_data.get(), divisor_data.get()}, &execution_options_) .ConsumeValueOrDie(); - auto expected_literal = Literal::CreateR0(dividend / divisor); + auto expected_literal = + LiteralUtil::CreateR0(dividend / divisor); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } } @@ -415,8 +418,8 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { for (uint32 divisor : vals) { if (divisor != 0) { for (uint32 dividend : vals) { - auto dividend_literal = Literal::CreateR0(dividend); - auto divisor_literal = Literal::CreateR0(divisor); + auto dividend_literal = LiteralUtil::CreateR0(dividend); + auto divisor_literal = LiteralUtil::CreateR0(divisor); TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, client_->TransferToServer(*dividend_literal)); TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, @@ -427,7 +430,8 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { {dividend_data.get(), divisor_data.get()}, &execution_options_) .ConsumeValueOrDie(); - auto expected_literal = Literal::CreateR0(dividend % divisor); + auto expected_literal = + LiteralUtil::CreateR0(dividend % divisor); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } } @@ -439,7 +443,7 @@ XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x"); Rem(x, ConstantR0(&builder, 80000)); - std::unique_ptr literal = Literal::CreateR0(87919); + std::unique_ptr literal = LiteralUtil::CreateR0(87919); TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*literal)); ComputeAndCompareR0(&builder, 7919, {input_data.get()}); } @@ -583,117 +587,116 @@ XLA_TEST_F(ScalarComputationsTest, CompareGtScalar) { // S32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqS32Greater) { - TestCompare(2, 1, false, &XlaBuilder::Eq); + TestCompare(2, 1, false, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareEqS32Equal) { - TestCompare(3, 3, true, &XlaBuilder::Eq); + TestCompare(3, 3, true, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeS32) { - TestCompare(2, 1, true, &XlaBuilder::Ne); + TestCompare(2, 1, true, &Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeS32) { - TestCompare(2, 1, true, &XlaBuilder::Ge); + TestCompare(2, 1, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtS32) { - TestCompare(1, 5, false, &XlaBuilder::Gt); + TestCompare(1, 5, false, &Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeS32) { - TestCompare(2, 1, false, &XlaBuilder::Le); + TestCompare(2, 1, false, &Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtS32) { - TestCompare(9, 7, false, &XlaBuilder::Lt); + TestCompare(9, 7, false, &Lt); TestCompare(std::numeric_limits::min(), - std::numeric_limits::max(), true, &XlaBuilder::Lt); + std::numeric_limits::max(), true, &Lt); } // U32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqU32False) { - TestCompare(2, 1, false, &XlaBuilder::Eq); + TestCompare(2, 1, false, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeU32) { - TestCompare(2, 1, true, &XlaBuilder::Ne); + TestCompare(2, 1, true, &Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeU32Greater) { - TestCompare(2, 1, true, &XlaBuilder::Ge); + TestCompare(2, 1, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeU32Equal) { - TestCompare(3, 3, true, &XlaBuilder::Ge); + TestCompare(3, 3, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtU32) { - TestCompare(1, 5, false, &XlaBuilder::Gt); - TestCompare(5, 5, false, &XlaBuilder::Gt); - TestCompare(5, 1, true, &XlaBuilder::Gt); + TestCompare(1, 5, false, &Gt); + TestCompare(5, 5, false, &Gt); + TestCompare(5, 1, true, &Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeU32) { - TestCompare(2, 1, false, &XlaBuilder::Le); + TestCompare(2, 1, false, &Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtU32) { - TestCompare(9, 7, false, &XlaBuilder::Lt); - TestCompare(0, std::numeric_limits::max(), true, - &XlaBuilder::Lt); + TestCompare(9, 7, false, &Lt); + TestCompare(0, std::numeric_limits::max(), true, &Lt); } // F32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqF32False) { - TestCompare(2.0, 1.3, false, &XlaBuilder::Eq); + TestCompare(2.0, 1.3, false, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeF32) { - TestCompare(2.0, 1.3, true, &XlaBuilder::Ne); + TestCompare(2.0, 1.3, true, &Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32Greater) { - TestCompare(2.0, 1.9, true, &XlaBuilder::Ge); + TestCompare(2.0, 1.9, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32Equal) { - TestCompare(3.5, 3.5, true, &XlaBuilder::Ge); + TestCompare(3.5, 3.5, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtF32) { - TestCompare(1.0, 5.2, false, &XlaBuilder::Gt); + TestCompare(1.0, 5.2, false, &Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeF32) { - TestCompare(2.0, 1.2, false, &XlaBuilder::Le); + TestCompare(2.0, 1.2, false, &Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32) { - TestCompare(9.0, 7.2, false, &XlaBuilder::Lt); + TestCompare(9.0, 7.2, false, &Lt); } // F32 comparisons with exceptional values. The test names encode the // left/right operands at the end, and use Minf and Mzero for -inf and -0.0. XLA_TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) { - TestCompare(-INFINITY, -0.0, true, &XlaBuilder::Lt); + TestCompare(-INFINITY, -0.0, true, &Lt); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) { // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. - TestCompare(-0.0, 0.0, false, &XlaBuilder::Lt); + TestCompare(-0.0, 0.0, false, &Lt); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) { - TestCompare(0.0, INFINITY, true, &XlaBuilder::Lt); + TestCompare(0.0, INFINITY, true, &Lt); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) { - TestCompare(-INFINITY, -0.0, false, &XlaBuilder::Ge); + TestCompare(-INFINITY, -0.0, false, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) { // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. - TestCompare(-0.0, 0.0, true, &XlaBuilder::Ge); + TestCompare(-0.0, 0.0, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) { - TestCompare(0.0, INFINITY, false, &XlaBuilder::Ge); + TestCompare(0.0, INFINITY, false, &Ge); } XLA_TEST_F(ScalarComputationsTest, ExpScalar) { @@ -813,65 +816,65 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) { } XLA_TEST_F(ScalarComputationsTest, MinS32Above) { - TestMinMax(10, 3, 3, &XlaBuilder::Min); + TestMinMax(10, 3, 3, &Min); } XLA_TEST_F(ScalarComputationsTest, MinS32Below) { - TestMinMax(-100, 3, -100, &XlaBuilder::Min); + TestMinMax(-100, 3, -100, &Min); } XLA_TEST_F(ScalarComputationsTest, MaxS32Above) { - TestMinMax(10, 3, 10, &XlaBuilder::Max); + TestMinMax(10, 3, 10, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxS32Below) { - TestMinMax(-100, 3, 3, &XlaBuilder::Max); + TestMinMax(-100, 3, 3, &Max); } XLA_TEST_F(ScalarComputationsTest, MinU32Above) { const uint32 large = std::numeric_limits::max(); - TestMinMax(large, 3, 3, &XlaBuilder::Min); + TestMinMax(large, 3, 3, &Min); } XLA_TEST_F(ScalarComputationsTest, MinU32Below) { - TestMinMax(0, 5, 0, &XlaBuilder::Min); + TestMinMax(0, 5, 0, &Min); } XLA_TEST_F(ScalarComputationsTest, MaxU32Above) { const uint32 large = std::numeric_limits::max(); - TestMinMax(large, 3, large, &XlaBuilder::Max); + TestMinMax(large, 3, large, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxU32Below) { - TestMinMax(0, 5, 5, &XlaBuilder::Max); + TestMinMax(0, 5, 5, &Max); } XLA_TEST_F(ScalarComputationsTest, MinF32Above) { - TestMinMax(10.1f, 3.1f, 3.1f, &XlaBuilder::Min); + TestMinMax(10.1f, 3.1f, 3.1f, &Min); } XLA_TEST_F(ScalarComputationsTest, MinF32Below) { - TestMinMax(-100.1f, 3.1f, -100.1f, &XlaBuilder::Min); + TestMinMax(-100.1f, 3.1f, -100.1f, &Min); } XLA_TEST_F(ScalarComputationsTest, MinPropagatesNan) { SetFastMathDisabled(true); - TestMinMax(NAN, 3.1f, NAN, &XlaBuilder::Min); - TestMinMax(-3.1f, NAN, NAN, &XlaBuilder::Min); + TestMinMax(NAN, 3.1f, NAN, &Min); + TestMinMax(-3.1f, NAN, NAN, &Min); } XLA_TEST_F(ScalarComputationsTest, MaxF32Above) { - TestMinMax(10.1f, 3.1f, 10.1f, &XlaBuilder::Max); + TestMinMax(10.1f, 3.1f, 10.1f, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxF32Below) { - TestMinMax(-100.1f, 3.1f, 3.1f, &XlaBuilder::Max); + TestMinMax(-100.1f, 3.1f, 3.1f, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxPropagatesNan) { SetFastMathDisabled(true); - TestMinMax(NAN, 3.1f, NAN, &XlaBuilder::Max); - TestMinMax(-3.1f, NAN, NAN, &XlaBuilder::Max); + TestMinMax(NAN, 3.1f, NAN, &Max); + TestMinMax(-3.1f, NAN, NAN, &Max); } XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) { @@ -897,18 +900,6 @@ XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) { ComputeAndCompareR0(&b, 10, {}); } -XLA_TEST_F(ScalarComputationsTest, SqrtF320) { - XlaBuilder builder(TestName()); - Literal zero_literal = Literal::Zero(PrimitiveType::F32); - - std::unique_ptr zero_data = - client_->TransferToServer(zero_literal).ConsumeValueOrDie(); - - XlaOp zero = Parameter(&builder, 0, zero_literal.shape(), "zero"); - SqrtF32(zero); - - ComputeAndCompareR0(&builder, 0.0f, {zero_data.get()}, error_spec_); -} XLA_TEST_F(ScalarComputationsTest, RoundScalar) { XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index 0a173fbbbd5cb5e5005728331561008b8b29af26..e3d4f98dd7432d1dce7e697586e8b17105dc82e7 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -22,10 +22,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc index 59409ab26e1c19a8271318c18e19caa7b8ddc3b7..1c01402798658877889527a5dd02d5c74787ff99 100644 --- a/tensorflow/compiler/xla/tests/select_test.cc +++ b/tensorflow/compiler/xla/tests/select_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index 3e5c01d6d47cc3f3b7d46ce300fe26c5ec9e63fa..b8ad6668f80a3002eff3cc458997966ee67c8d4b 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -170,7 +170,7 @@ XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) { values.FillRandom(3.14f); auto expected = ReferenceUtil::Slice4D(values, {{0, 0, 0, 0}}, {{2, 4, 6, 8}}, /*strides=*/{{1, 1, 2, 1}}); - auto expected_literal = Literal::CreateR4FromArray4DWithLayout( + auto expected_literal = LiteralUtil::CreateR4FromArray4DWithLayout( *expected, LayoutUtil::MakeLayout({0, 1, 2, 3})); XlaBuilder builder(TestName()); auto original = ConstantR4FromArray4D(&builder, values); @@ -197,7 +197,7 @@ class SliceR1Test : public ClientLibraryTestBase, // vector. tensorflow::gtl::InlinedVector input(spec.input_dim0); std::iota(input.begin(), input.end(), NativeT()); - auto literal = Literal::CreateR1(input); + auto literal = LiteralUtil::CreateR1(input); XlaBuilder builder(TestName()); auto original = Parameter(&builder, 0, literal->shape(), "p0"); @@ -344,7 +344,11 @@ INSTANTIATE_TEST_CASE_P( R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 2}, R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 8}, R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 7}, - R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 125} + R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 125}, + R1Spec{16 * 1024 * 1024, 0, 16 * 1024 * 1024, 4097}, + R1Spec{16 * 1024 * 1024, 0, 16 * 1024 * 1024, 4093}, + R1Spec{16 * 1024 * 1024, 12 * 1024 + 17, 16 * 1024 * 1024 - 231, 4097}, + R1Spec{16 * 1024 * 1024, 12 * 1024 + 17, 16 * 1024 * 1024 - 231, 4093} ), SliceR1TestDataToString ); @@ -368,7 +372,7 @@ XLA_TEST_P(SliceR2Test, DoIt) { const R2Spec& spec = GetParam(); Array2D input(spec.input_dim0, spec.input_dim1); input.FillUnique(); - auto literal = Literal::CreateR2FromArray2DWithLayout( + auto literal = LiteralUtil::CreateR2FromArray2DWithLayout( input, LayoutUtil::MakeLayout(spec.layout)); XlaBuilder builder(TestName()); @@ -463,7 +467,7 @@ class SliceR4Test : public ClientLibraryTestBase, auto expected = ReferenceUtil::Slice4D( values, spec.slice_starts, spec.slice_limits, spec.slice_strides); XlaBuilder builder(TestName()); - auto literal = Literal::CreateR4FromArray4DWithLayout( + auto literal = LiteralUtil::CreateR4FromArray4DWithLayout( values, LayoutUtil::MakeLayout(spec.input_layout)); auto parameter = Parameter(&builder, 0, literal->shape(), "p0"); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 000535a982fb08af69e7b317501f82ba7f402fb9..2647937013222ccfdae98b0c1d141f461020b5c9 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" @@ -110,7 +111,7 @@ StatusOr> MakeFakeLiteralInternal( MakeFakeLiteralInternal(element_shape, engine)); elements.push_back(std::move(element)); } - return Literal::MakeTupleOwned(std::move(elements)); + return LiteralUtil::MakeTupleOwned(std::move(elements)); } if (engine == nullptr) { return Literal::CreateFromShape(shape); @@ -161,6 +162,9 @@ StatusOr> MakeFakeLiteralInternal( })); break; } + // Token requires no data. + case TOKEN: + break; default: return Unimplemented("Unsupported type for fake literal generation: %s", ShapeUtil::HumanString(shape).c_str()); @@ -217,7 +221,7 @@ std::unique_ptr MakeRandomNonwrappingSliceIndex( start_indices[i] = generator(*engine); } } - return Literal::CreateR1(start_indices); + return LiteralUtil::CreateR1(start_indices); } // Use dataflow analysis on each parameter to see if there are uses that would @@ -315,9 +319,9 @@ StatusOr> CreateLiteralForConstrainedUses( } else if (needs_constant != nullptr) { switch (constant_type) { case ConstantType::kZero: - return Literal::Zero(param.shape().element_type()).CloneToUnique(); + return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique(); case ConstantType::kOne: - return Literal::One(param.shape().element_type()).CloneToUnique(); + return LiteralUtil::One(param.shape().element_type()).CloneToUnique(); case ConstantType::kUnknown: // We want the identity element for the computation, but we don't really // know what it is - so any value we generate will be just as wrong. diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index a8689f64981569ceb7c8a712f8ece00c99e8cf2d..e59f215a9a3ace80d7a23e1bbc40970c7a63ea0d 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index e8f2fb44d8976a1ab4150a0876397b8f3f8fc440..a2f0338e25977d7c76dbc48b3afc649b77ba4ee2 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -15,7 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -53,5 +54,23 @@ XLA_TEST_F(TestUtilsTest, UnusedParam) { TF_ASSERT_OK(MakeFakeArguments(&module).status()); } +XLA_TEST_F(TestUtilsTest, Token) { + auto module = ParseHloString( + R"(HloModule outfeed_module + + ENTRY InfeedToOutfeed { + token = token[] parameter(0) + infeed = ((u32[3]{0}, pred[]), token[]) infeed(token) + infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0 + outfeed = token[] outfeed(infeed.data, token) + ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token) + infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0 + infeed.1.token = token[] get-tuple-element(infeed.1), index=1 + outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token) + })") + .ValueOrDie(); + TF_ASSERT_OK(MakeFakeArguments(module.get()).status()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index e9008fa48aa7d0158bd2221791be23c128859098..2bdbd08309a81b201fc224110805549f7fb5bb55 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -31,21 +31,21 @@ class TokenHloTest : public HloTestBase {}; XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { std::unique_ptr module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction(HloInstruction::CreateAfterAll({})); + builder.AddInstruction(HloInstruction::CreateToken()); module->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, Execute(std::move(module), {})); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *Literal::CreateToken())); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken())); } XLA_TEST_F(TokenHloTest, TokenTree) { std::unique_ptr module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - auto token0 = builder.AddInstruction(HloInstruction::CreateAfterAll({})); - auto token1 = builder.AddInstruction(HloInstruction::CreateAfterAll({})); - auto token2 = builder.AddInstruction(HloInstruction::CreateAfterAll({})); + auto token0 = builder.AddInstruction(HloInstruction::CreateToken()); + auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); + auto token2 = builder.AddInstruction(HloInstruction::CreateToken()); builder.AddInstruction( HloInstruction::CreateAfterAll({token0, token0, token1, token2})); @@ -53,7 +53,7 @@ XLA_TEST_F(TokenHloTest, TokenTree) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, Execute(std::move(module), {})); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *Literal::CreateToken())); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken())); } XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { @@ -64,7 +64,7 @@ XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { builder.AddInstruction( HloInstruction::CreateParameter(1, ShapeUtil::MakeTokenShape(), "p1")); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); module->AddEntryComputation(builder.Build()); Status status = HloVerifier().Run(module.get()).status(); @@ -98,7 +98,7 @@ XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); builder.AddInstruction(HloInstruction::CreateAfterAll({param})); builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(123))); + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); module->AddEntryComputation(builder.Build()); Status status = HloVerifier().Run(module.get()).status(); @@ -184,7 +184,7 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr module, HloRunner::CreateModuleFromString(module_string, debug_options)); - auto arg = Literal::CreateR0(true); + auto arg = LiteralUtil::CreateR0(true); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, Execute(std::move(module), {arg.get()})); EXPECT_EQ(42, result->Get({})); @@ -195,7 +195,7 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr module, HloRunner::CreateModuleFromString(module_string, debug_options)); - auto arg = Literal::CreateR0(false); + auto arg = LiteralUtil::CreateR0(false); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, Execute(std::move(module), {arg.get()})); EXPECT_EQ(7, result->Get({})); diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index 86babb58c9d4515935a5904e04e8fea1074a2812..125513ddfd16cb4e742e7d589e22b721307621ee 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -18,10 +18,11 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -60,7 +61,7 @@ class TransferManagerTest : public LocalClientTestBase { } protected: - Backend::StreamPtr stream_ptr_; + StreamPool::Ptr stream_ptr_; se::Stream* stream_; private: @@ -68,7 +69,7 @@ class TransferManagerTest : public LocalClientTestBase { }; XLA_TEST_F(TransferManagerTest, TransferR0U32) { - std::unique_ptr literal = Literal::CreateR0(42); + std::unique_ptr literal = LiteralUtil::CreateR0(42); const Shape& shape = literal->shape(); auto device_buffer = AllocateDeviceBuffer(shape); @@ -84,7 +85,7 @@ XLA_TEST_F(TransferManagerTest, TransferR0U32) { XLA_TEST_F(TransferManagerTest, TransferR1F32) { std::unique_ptr literal = - Literal::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); + LiteralUtil::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); const Shape& shape = literal->shape(); auto device_buffer = AllocateDeviceBuffer(shape); @@ -102,7 +103,7 @@ XLA_TEST_F(TransferManagerTest, TransferR1F32) { XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) { std::vector test_vector(1024 * 1024); std::iota(test_vector.begin(), test_vector.end(), 0); - std::unique_ptr literal = Literal::CreateR1(test_vector); + std::unique_ptr literal = LiteralUtil::CreateR1(test_vector); const Shape& shape = literal->shape(); auto device_buffer = AllocateDeviceBuffer(shape); @@ -118,7 +119,7 @@ XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) { XLA_TEST_F(TransferManagerTest, TransferR1U8) { const char* test_string = "0123456789abcdef"; - std::unique_ptr literal = Literal::CreateR1U8(test_string); + std::unique_ptr literal = LiteralUtil::CreateR1U8(test_string); const Shape& shape = literal->shape(); auto device_buffer = AllocateDeviceBuffer(shape); @@ -134,7 +135,7 @@ XLA_TEST_F(TransferManagerTest, TransferR1U8) { XLA_TEST_F(TransferManagerTest, TransferR2F32) { std::unique_ptr literal = - Literal::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); + LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); const Shape& shape = literal->shape(); auto device_buffer = AllocateDeviceBuffer(shape); @@ -151,7 +152,7 @@ XLA_TEST_F(TransferManagerTest, TransferR2F32) { XLA_TEST_F(TransferManagerTest, TransferR2F32AndChangeLayoutTransferringToDevice) { - std::unique_ptr literal = Literal::CreateR2WithLayout( + std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, LayoutUtil::MakeLayout({0, 1})); const Shape ondevice_shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}); @@ -172,10 +173,10 @@ XLA_TEST_F(TransferManagerTest, } XLA_TEST_F(TransferManagerTest, TransferTuple) { - std::unique_ptr literal = Literal::MakeTuple( - {Literal::CreateR0(123.0f).get(), - Literal::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - Literal::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}); + std::unique_ptr literal = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(123.0f).get(), + LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}); auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. @@ -189,7 +190,7 @@ XLA_TEST_F(TransferManagerTest, TransferTuple) { } XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { - std::unique_ptr literal = Literal::MakeTuple({}); + std::unique_ptr literal = LiteralUtil::MakeTuple({}); auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. @@ -203,13 +204,13 @@ XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { } XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { - std::unique_ptr literal = Literal::MakeTuple( - {Literal::CreateR0(123.0f).get(), - Literal::MakeTuple( - {Literal::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - Literal::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) + std::unique_ptr literal = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(123.0f).get(), + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) .get(), - Literal::CreateR1({-10.0f, 123.0f}).get()}); + LiteralUtil::CreateR1({-10.0f, 123.0f}).get()}); auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. @@ -223,7 +224,7 @@ XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { } XLA_TEST_F(TransferManagerTest, TransferComplexValue) { - std::unique_ptr literal = Literal::CreateR1( + std::unique_ptr literal = LiteralUtil::CreateR1( {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}); auto device_buffer = AllocateDeviceBuffer(literal->shape()); @@ -238,12 +239,12 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValue) { } XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { - std::unique_ptr literal = Literal::MakeTuple( - {Literal::CreateR1( + std::unique_ptr literal = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR1( {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}) .get(), - Literal::CreateR1({1, 2, 3, 4, 5, 6}).get(), - Literal::CreateR0(complex64(0.3f, -0.4f)).get()}); + LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6}).get(), + LiteralUtil::CreateR0(complex64(0.3f, -0.4f)).get()}); auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. @@ -265,25 +266,25 @@ XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*Literal::CreateToken(), *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(*LiteralUtil::CreateToken(), *result)); } XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) { const int64 kIterationCount = 5000; - std::unique_ptr literal1 = Literal::MakeTuple( - {Literal::CreateR0(123.0f).get(), - Literal::MakeTuple( - {Literal::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - Literal::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) + std::unique_ptr literal1 = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(123.0f).get(), + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) .get(), - Literal::CreateR1({-10.0f, 123.0f}).get()}); - std::unique_ptr literal2 = Literal::MakeTuple( - {Literal::CreateR0(456.0f).get(), - Literal::MakeTuple( - {Literal::CreateR2({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(), - Literal::CreateR1({44.0f, -11.0f, 3333333.3f}).get()}) + LiteralUtil::CreateR1({-10.0f, 123.0f}).get()}); + std::unique_ptr literal2 = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(456.0f).get(), + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(), + LiteralUtil::CreateR1({44.0f, -11.0f, 3333333.3f}).get()}) .get(), - Literal::CreateR1({-98.0f, 153.0f}).get()}); + LiteralUtil::CreateR1({-98.0f, 153.0f}).get()}); auto device_buffer1 = AllocateDeviceBuffer(literal1->shape()); auto device_buffer2 = AllocateDeviceBuffer(literal2->shape()); @@ -325,10 +326,10 @@ class TransferDeviceToHostBenchmark : public TransferManagerTest { std::vector> tuple_elements; for (int i = 0; i < num_tuple_elements; ++i) { tuple_elements.push_back( - Literal::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); + LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); } std::unique_ptr literal = - Literal::MakeTupleOwned(std::move(tuple_elements)); + LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); auto device_buffer = AllocateDeviceBuffer(literal->shape()); TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, device_buffer)); @@ -357,10 +358,10 @@ class TransferHostToDeviceBenchmark : public TransferManagerTest { std::vector> tuple_elements; for (int i = 0; i < num_tuple_elements; ++i) { tuple_elements.push_back( - Literal::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); + LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); } std::unique_ptr literal = - Literal::MakeTupleOwned(std::move(tuple_elements)); + LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); auto device_buffer = AllocateDeviceBuffer(literal->shape()); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc index 6ebb4324f8d20ed9f8886d92b0513441685ed19b..fbe9d1b64aa0c06d65b547c45cfa981800d40ff3 100644 --- a/tensorflow/compiler/xla/tests/transpose_test.cc +++ b/tensorflow/compiler/xla/tests/transpose_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index ec11508891d13f8032a1ebec388c756cf6d752c7..97bbf80aff80e995ea5cdd3e5d8807ee4d380067 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -49,10 +50,10 @@ XLA_TEST_F(TupleTest, TupleConstant) { {1.1f, 2.2f, 3.5f}, // row 0 {4.8f, 5.0f, 6.7f}, // row 1 }; - auto value = - Literal::MakeTuple({Literal::CreateR0(constant_scalar).get(), - Literal::CreateR1(constant_vector).get(), - Literal::CreateR2(constant_matrix).get()}); + auto value = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(constant_scalar).get(), + LiteralUtil::CreateR1(constant_vector).get(), + LiteralUtil::CreateR2(constant_matrix).get()}); ConstantLiteral(&builder, *value); ComputeAndCompareTuple(&builder, *value, {}, error_spec_); @@ -64,9 +65,9 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) { const float constant_scalar1 = 7.3f; const float constant_scalar2 = 1.2f; - auto value = - Literal::MakeTuple({Literal::CreateR0(constant_scalar1).get(), - Literal::CreateR0(constant_scalar2).get()}); + auto value = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(constant_scalar1).get(), + LiteralUtil::CreateR0(constant_scalar2).get()}); ConstantLiteral(&builder, *value); ComputeAndCompareTuple(&builder, *value, {}, error_spec_); @@ -86,10 +87,10 @@ XLA_TEST_F(TupleTest, TupleCreate) { ConstantR1(&builder, constant_vector), ConstantR2(&builder, constant_matrix)}); - auto expected = - Literal::MakeTuple({Literal::CreateR0(constant_scalar).get(), - Literal::CreateR1(constant_vector).get(), - Literal::CreateR2(constant_matrix).get()}); + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0(constant_scalar).get(), + LiteralUtil::CreateR1(constant_vector).get(), + LiteralUtil::CreateR2(constant_matrix).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -100,8 +101,9 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { Tuple(&builder, {ConstantR0(&builder, 7.0), ConstantR1(&builder, {})}); - auto expected = Literal::MakeTuple({Literal::CreateR0(7.0).get(), - Literal::CreateR1({}).get()}); + auto expected = + LiteralUtil::MakeTuple({LiteralUtil::CreateR0(7.0).get(), + LiteralUtil::CreateR1({}).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -109,7 +111,7 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { XLA_TEST_F(TupleTest, EmptyTupleCreate) { XlaBuilder builder(TestName()); Tuple(&builder, {}); - auto expected = Literal::MakeTuple({}); + auto expected = LiteralUtil::MakeTuple({}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -193,9 +195,9 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { ConstantR2(&builder, constant_matrix)}); Tuple(&builder, {GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)}); - auto expected = - Literal::MakeTuple({Literal::CreateR2(constant_matrix).get(), - Literal::CreateR1(constant_vector).get()}); + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2(constant_matrix).get(), + LiteralUtil::CreateR1(constant_vector).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -216,8 +218,8 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { auto v2_v1 = Tuple(&b, {v2_gt, v1_gt}); // {true, false} Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); auto expected = - Literal::MakeTuple({Literal::CreateR0(direction).get(), - Literal::CreateR0(!direction).get()}); + LiteralUtil::MakeTuple({LiteralUtil::CreateR0(direction).get(), + LiteralUtil::CreateR0(!direction).get()}); ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()}, error_spec_); @@ -284,8 +286,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) { ConstantR1(&builder, vec1)}); Select(ConstantR0(&builder, false), tuple12, tuple21); - auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), - Literal::CreateR1(vec1).get()}); + auto expected = + LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), + LiteralUtil::CreateR1(vec1).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -328,8 +331,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) { ConstantR1(&builder, vec1)}); Select(ConstantR0(&builder, true), tuple12, tuple21); - auto expected = Literal::MakeTuple({Literal::CreateR1(vec1).get(), - Literal::CreateR1(vec2).get()}); + auto expected = + LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec1).get(), + LiteralUtil::CreateR1(vec2).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -403,8 +407,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) { Select(ConstantR0(&builder, false), tuple12, tuple21); - auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), - Literal::CreateR1(vec1).get()}); + auto expected = + LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), + LiteralUtil::CreateR1(vec1).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -414,13 +419,13 @@ XLA_TEST_F(TupleTest, NestedTuples) { ConstantR0(&builder, 42.0)}); Tuple(&builder, {inner_tuple, ConstantR1(&builder, {22.0, 44.0})}); - auto expected_v1 = Literal::CreateR1({1.0, 2.0}); - auto expected_s = Literal::CreateR0(42.0); + auto expected_v1 = LiteralUtil::CreateR1({1.0, 2.0}); + auto expected_s = LiteralUtil::CreateR0(42.0); auto expected_inner_tuple = - Literal::MakeTuple({expected_v1.get(), expected_s.get()}); - auto expected_v2 = Literal::CreateR1({22.0, 44.0}); + LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()}); + auto expected_v2 = LiteralUtil::CreateR1({22.0, 44.0}); auto expected = - Literal::MakeTuple({expected_inner_tuple.get(), expected_v2.get()}); + LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -440,14 +445,14 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { std::unique_ptr data = client_ - ->TransferToServer(*Literal::MakeTuple({ - Literal::MakeTuple( + ->TransferToServer(*LiteralUtil::MakeTuple({ + LiteralUtil::MakeTuple( { - Literal::CreateR1({1.0, 2.0, 3.0}).get(), - Literal::CreateR1({4.0, 5.0, 6.0}).get(), + LiteralUtil::CreateR1({1.0, 2.0, 3.0}).get(), + LiteralUtil::CreateR1({4.0, 5.0, 6.0}).get(), }) .get(), - Literal::CreateR1({7.0, 8.0, 9.0}).get(), + LiteralUtil::CreateR1({7.0, 8.0, 9.0}).get(), })) .ConsumeValueOrDie(); @@ -478,11 +483,12 @@ XLA_TEST_F(TupleTest, ComplexTuples) { std::unique_ptr arg0 = client_ - ->TransferToServer(*Literal::MakeTuple( - {Literal::CreateR0({1, 2}).get(), - Literal::MakeTuple( - {Literal::CreateR1({{10, 20}, {30, 40}}).get(), - Literal::CreateR2( + ->TransferToServer(*LiteralUtil::MakeTuple( + {LiteralUtil::CreateR0({1, 2}).get(), + LiteralUtil::MakeTuple( + {LiteralUtil::CreateR1({{10, 20}, {30, 40}}) + .get(), + LiteralUtil::CreateR2( {{{100, 200}, {300, 400}}, {{1000, 2000}, {3000, 4000}}, {{10000, 20000}, {30000, 40000}}}) @@ -491,11 +497,13 @@ XLA_TEST_F(TupleTest, ComplexTuples) { .ConsumeValueOrDie(); std::unique_ptr arg1 = client_ - ->TransferToServer(*Literal::CreateR1({{1, 2}, {1, -2}})) + ->TransferToServer( + *LiteralUtil::CreateR1({{1, 2}, {1, -2}})) .ConsumeValueOrDie(); - auto sum = Literal::CreateR2({{{111, 222}, {331, 442}}, - {{1011, 2022}, {3031, 4042}}, - {{10011, 20022}, {30031, 40042}}}); + auto sum = + LiteralUtil::CreateR2({{{111, 222}, {331, 442}}, + {{1011, 2022}, {3031, 4042}}, + {{10011, 20022}, {30031, 40042}}}); auto prod = MakeUnique(sum->shape()); ASSERT_TRUE(prod->Populate( [&sum](tensorflow::gtl::ArraySlice indexes) { @@ -505,9 +513,9 @@ XLA_TEST_F(TupleTest, ComplexTuples) { : complex64(1, -2)); }) .ok()); - auto expected = - Literal::MakeTuple({Literal::MakeTuple({prod.get(), sum.get()}).get(), - Literal::CreateR0({123, 456}).get()}); + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::MakeTuple({prod.get(), sum.get()}).get(), + LiteralUtil::CreateR0({123, 456}).get()}); ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()}, error_spec_); } @@ -530,12 +538,59 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { auto module = HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); - auto param = Literal::MakeTupleOwned(Literal::CreateR1({1, 2, 3})); + auto param = + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({1, 2, 3})); auto result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *Literal::MakeTupleOwned(Literal::CreateR2({{1, 2, 3}})), + *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2({{1, 2, 3}})), *result)); } +// Disabled on interpreter due to lack of outfeed. +XLA_TEST_F(TupleHloTest, + DISABLED_ON_INTERPRETER(NonAmbiguousTopLevelAllocation)) { + const char* testcase = R"( + HloModule tuple + + ENTRY main { + a = f32[2] parameter(0) + b = f32[2] parameter(1) + c = f32[2] parameter(2) + d = f32[2] parameter(3) + cond = pred[] parameter(4) + + tup0 = (f32[2],f32[2]) tuple(a, b) + tup1 = (f32[2],f32[2]) tuple(c, d) + + s = (f32[2],f32[2]) tuple-select(cond, tup0, tup1) + gte = f32[2] get-tuple-element(s), index=0 + tuple = (f32[2]) tuple(gte) + token = token[] after-all() + ROOT outfeed = token[] outfeed(tuple, token) + } + )"; + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param0 = LiteralUtil::CreateR1({1, 2}); + auto param1 = LiteralUtil::CreateR1({2, 3}); + auto param4 = LiteralUtil::CreateR0(false); + // Put execution on a separate thread so we can block on outfeed. + std::unique_ptr thread( + tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions(), "execute_thread", [&] { + TF_EXPECT_OK(Execute(std::move(module), + {param0.get(), param1.get(), param1.get(), + param0.get(), param4.get()}) + .status()); + })); + auto expected = + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({2, 3})); + auto literal = Literal::CreateFromShape(expected->shape()); + TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( + backend().default_stream_executor(), expected->shape(), *literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *literal)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 929b1ca7fb93c545265bf85fec1ed7dc845405b2..20ae68ab74026936c43e5f525eb796eb402a19cb 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -101,7 +101,7 @@ void UnaryOpTest::AbsTestHelper() { Abs(arg); std::unique_ptr expected = - Literal::CreateR1({2, 25, 0, 0.5, inf(), inf()}); + LiteralUtil::CreateR1({2, 25, 0, 0.5, inf(), inf()}); ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); } @@ -113,7 +113,7 @@ void UnaryOpTest::SignTestHelper() { {{-2, 0}, {0, 25}, {0, 0}, {static_cast(-0.0), 0}, {-1, 1}}); Sign(arg); - std::unique_ptr expected = Literal::CreateR1( + std::unique_ptr expected = LiteralUtil::CreateR1( {{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}}); ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); } @@ -128,7 +128,7 @@ void UnaryOpTest::SignAbsTestHelper() { Sub(Mul(sign, ConvertElementType(abs, C64)), arg); std::unique_ptr expected = - Literal::CreateR1({0, 0, 0, 0}); + LiteralUtil::CreateR1({0, 0, 0, 0}); ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); } @@ -173,7 +173,7 @@ XLA_TEST_F(UnaryOpTest, SignTestR0) { Add(Add(sgnf0, sgnf), ConvertElementType(sgni, F32)), C64)); std::unique_ptr expected = - Literal::CreateR0({-2.6f, 0.8f}); + LiteralUtil::CreateR0({-2.6f, 0.8f}); ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); } diff --git a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc index ea3aba6df1d3fbd492a23b280309322b8524c0bf..ef1b1445bbe555da00db4446d59439b752735a80 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index c11df7cdf5a22568e80ce6e00fdbd862e6dcae9b..3848ec1684cdc9186e14ac0b60315b7520d127f3 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -135,46 +135,6 @@ XLA_TEST_F(VecOpsSimpleTest, NegateUint32Values) { ComputeAndCompareR1(&builder, expected, {}); } -XLA_TEST_F(VecOpsSimpleTest, SquareTenValues) { - XlaBuilder builder(TestName()); - auto x = ConstantR1( - &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - SquareF32(x); - - std::vector expected = {4.41, 6.76, 6.76, 16., 4.41, - 5.29, 25., 0.81, 5.76, 2.56}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); -} - -XLA_TEST_F(VecOpsSimpleTest, ReciprocalTenValues) { - XlaBuilder builder(TestName()); - auto x = ConstantR1( - &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - ReciprocalF32(x); - - std::vector expected = { - 0.47619048, -0.38461538, 0.38461538, -0.25, 0.47619048, - 0.43478261, -0.2, -1.11111111, -0.41666667, 0.625}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); -} - -XLA_TEST_F(VecOpsSimpleTest, SqrtZeroes) { - XlaBuilder builder(TestName()); - auto x = ConstantR1(&builder, {0.0, -0.0}); - SqrtF32(x); - - ComputeAndCompareR1(&builder, {0, 0}, {}, error_spec_); -} - -XLA_TEST_F(VecOpsSimpleTest, SqrtSixValues) { - XlaBuilder builder(TestName()); - auto x = ConstantR1(&builder, {16.0, 1.0, 1024.0, 0.16, 0.2, 12345}); - SqrtF32(x); - - std::vector expected = {4, 1, 32, 0.4, 0.4472, 111.1080}; - ComputeAndCompareR1(&builder, expected, {}, error_spec_); -} - XLA_TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index bbd67cd8d7c433550deefc38ce28b2b732d354aa..1bdf1867b9330b715b0ba4aca71d56307883c775 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -20,9 +20,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -347,8 +347,8 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { // the sum will increase by 1.0. It will first be >15.5 when the elements // have all reached 2.0. auto expected_data = - Literal::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}); - auto expected = Literal::MakeTuple({expected_data.get()}); + LiteralUtil::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}); + auto expected = LiteralUtil::MakeTuple({expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -397,12 +397,13 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0(N); - auto expected_w1 = Literal::CreateR1({1.0f, 1.0f, 1.0f}); - auto expected_w2 = Literal::CreateR1({2.0f, 2.0f, 2.0f}); - auto expected_w3 = Literal::CreateR1({3.0f, 3.0f, 3.0f}); - auto expected = Literal::MakeTuple({expected_counter.get(), expected_w2.get(), - expected_w3.get(), expected_w1.get()}); + auto expected_counter = LiteralUtil::CreateR0(N); + auto expected_w1 = LiteralUtil::CreateR1({1.0f, 1.0f, 1.0f}); + auto expected_w2 = LiteralUtil::CreateR1({2.0f, 2.0f, 2.0f}); + auto expected_w3 = LiteralUtil::CreateR1({3.0f, 3.0f, 3.0f}); + auto expected = + LiteralUtil::MakeTuple({expected_counter.get(), expected_w2.get(), + expected_w3.get(), expected_w1.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -506,11 +507,11 @@ TEST_F(WhileTest, WhileWithTupleResult) { << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0(5); - auto expected_data = Literal::CreateR1( + auto expected_counter = LiteralUtil::CreateR0(5); + auto expected_data = LiteralUtil::CreateR1( {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f}); auto expected = - Literal::MakeTuple({expected_counter.get(), expected_data.get()}); + LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -554,10 +555,10 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0(5); - auto expected_predicate = Literal::CreateR0(true); - auto expected = - Literal::MakeTuple({expected_counter.get(), expected_predicate.get()}); + auto expected_counter = LiteralUtil::CreateR0(5); + auto expected_predicate = LiteralUtil::CreateR0(true); + auto expected = LiteralUtil::MakeTuple( + {expected_counter.get(), expected_predicate.get()}); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0)); } @@ -599,10 +600,10 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0(5); - auto expected_data = Literal::CreateR0(7); + auto expected_counter = LiteralUtil::CreateR0(5); + auto expected_data = LiteralUtil::CreateR0(7); auto expected = - Literal::MakeTuple({expected_counter.get(), expected_data.get()}); + LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -882,11 +883,11 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { << ShapeUtil::HumanString( builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = Literal::CreateR0(5); - auto expected_data = Literal::CreateR1( + auto expected_counter = LiteralUtil::CreateR0(5); + auto expected_data = LiteralUtil::CreateR1( {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f}); auto expected = - Literal::MakeTuple({expected_counter.get(), expected_data.get()}); + LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -974,12 +975,12 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); While(cond_computation, body_computation, t); - auto expected_element = Literal::CreateR1({1, 1}); + auto expected_element = LiteralUtil::CreateR1({1, 1}); auto expected = - Literal::MakeTuple({expected_element.get(), expected_element.get()}); + LiteralUtil::MakeTuple({expected_element.get(), expected_element.get()}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*Literal::CreateR1({42, 42}))); + client_->TransferToServer(*LiteralUtil::CreateR1({42, 42}))); ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1004,7 +1005,7 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*Literal::CreateR1({42, 42}))); + client_->TransferToServer(*LiteralUtil::CreateR1({42, 42}))); ComputeAndCompareR1(&outer, {1.0f, 1.0f}, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1030,7 +1031,7 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*Literal::CreateR0(42))); + client_->TransferToServer(*LiteralUtil::CreateR0(42))); ComputeAndCompareR0(&outer, 43.0f, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1069,11 +1070,11 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*Literal::CreateR0(1))); + client_->TransferToServer(*LiteralUtil::CreateR0(1))); - auto add1 = Literal::CreateR0(15); - auto add2 = Literal::CreateR0(16); - auto expected = Literal::MakeTuple({add1.get(), add2.get()}); + auto add1 = LiteralUtil::CreateR0(15); + auto add2 = LiteralUtil::CreateR0(16); + auto expected = LiteralUtil::MakeTuple({add1.get(), add2.get()}); ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1226,15 +1227,44 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { auto while_instruction = While(condition, body, init); GetTupleElement(while_instruction, 3); - TF_ASSERT_OK_AND_ASSIGN(auto param_value, - client_->TransferToServer(*Literal::CreateR2( - {{1.0, 2.0}, {-1.0, -2.0}}))); + TF_ASSERT_OK_AND_ASSIGN( + auto param_value, client_->TransferToServer(*LiteralUtil::CreateR2( + {{1.0, 2.0}, {-1.0, -2.0}}))); ComputeAndCompareR2( &builder, {{-0.76159416, -0.96402758}, {0.76159416, 0.96402758}}, {param_value.get()}, ErrorSpec(4e-5)); } +TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) { + auto while_shape = ShapeUtil::MakeShape(S32, {}); + + XlaComputation condition; + { + XlaBuilder builder("condition"); + Parameter(&builder, 0, while_shape, "state"); + Infeed(&builder, ShapeUtil::MakeShape(PRED, {})); + TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); + } + + XlaComputation body; + { + XlaBuilder builder("body"); + auto indvar = Parameter(&builder, 0, while_shape, "state"); + Add(indvar, ConstantR0(&builder, 1)); + TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); + } + + XlaBuilder builder(TestName()); + While(condition, body, ConstantR0(&builder, 0)); + + TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); + TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); + TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(false))); + + ComputeAndCompareR0(&builder, 2, {}); +} + void BM_WhileLoop(int num_iters) { // Benchmark a simple kernel to measure while loop overheads. tensorflow::testing::StopTiming(); diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index c0616809f9f060e3447e62d387535a5acffe1075..11f3efb1f34ad23ebdcbb65c90aa5fb7a6adeae5 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -18,10 +18,11 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -79,10 +80,12 @@ struct ParsedProfileOutputLine { Status ParseOneProfileOutputLine( const string& line, bool expect_hlo, - gtl::FlatMap* parsed_results) { + gtl::FlatMap* parsed_results, + tensorflow::gtl::ArraySlice opcodes_to_ignore = + {}) { string separator = "[^:]*:: +"; - string match_percentage = "\\d+\\.\\d\\d%"; - string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)"; + string match_percentage = R"(\d+\.\d*% +\d+Σ)"; + string match_cycles = R"((\d+) cycles +\( *()" + match_percentage + R"()\))"; string match_usecs = "([0-9.]+) usec"; string match_flops = "([^ ]*)"; string match_trops = "([^ ]*)"; @@ -113,7 +116,9 @@ Status ParseOneProfileOutputLine( ", Regexp: ", regexp_pattern); } - InsertOrDie(parsed_results, parsed_line.opcode, parsed_line); + if (!c_linear_search(opcodes_to_ignore, parsed_line.opcode)) { + InsertOrDie(parsed_results, parsed_line.opcode, parsed_line); + } return Status::OK(); } @@ -129,7 +134,7 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, DeviceMemoryAllocator* allocator = backend->memory_allocator(); auto* transfer_manager = backend->transfer_manager(); TF_ASSERT_OK_AND_ASSIGN( - Backend::StreamPtr stream_ptr, + StreamPool::Ptr stream_ptr, backend->BorrowStream(backend->default_device_ordinal())); TF_ASSERT_OK_AND_ASSIGN( @@ -220,7 +225,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { MaybeFind(parsed_profile_lines, "tanh")); EXPECT_GT(total_profile.cycles, 0); - EXPECT_EQ(total_profile.cycles_percentage, "100.00%"); + EXPECT_EQ(total_profile.cycles_percentage, "100.% 100Σ"); EXPECT_TRUE(HasFlops(total_profile)); EXPECT_TRUE(HasTrops(total_profile)); @@ -240,9 +245,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { EXPECT_TRUE(HasTrops(tanh_profile)); } -// TODO(b/71544591): The GPU backend does not record cycles spent in on Hlo -// instructions "interior" to while nodes. -XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) { +XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { const int64 size = 256; Shape matrix_shape = ShapeUtil::MakeShape(F32, {size, size}); Shape while_result_shape = @@ -269,7 +272,7 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) { auto matrix = GetTupleElement(state, 1); auto next_iteration = Add(GetTupleElement(state, 0), ConstantR0(&builder, 1)); - Tuple(&builder, {next_iteration, Add(matrix, matrix)}); + Tuple(&builder, {next_iteration, Mul(matrix, matrix)}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } @@ -291,36 +294,50 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) { tensorflow::str_util::Split(profile_output, '\n'); auto while_body_profile_start = - std::find_if(profile_output_lines.begin(), profile_output_lines.end(), + c_find_if(profile_output_lines, [](tensorflow::StringPiece s) { + return tensorflow::str_util::StartsWith(s, + "Execution profile for body"); + }); + + ASSERT_NE(while_body_profile_start, profile_output_lines.cend()); + + auto while_body_profile_end = + std::find_if(while_body_profile_start, profile_output_lines.end(), [](tensorflow::StringPiece s) { return tensorflow::str_util::StartsWith( - s, "Execution profile for body"); + s, "********** microseconds report **********"); }); - ASSERT_NE(while_body_profile_start, profile_output_lines.end()); + // We emit a blank line before the "********** microseconds report **********" + // line. + while_body_profile_end--; - gtl::FlatMap parsed_profile_lines; + ASSERT_NE(while_body_profile_end, profile_output_lines.end()); - TF_ASSERT_OK( - ParseOneProfileOutputLine(*std::next(while_body_profile_start, 1), - /*expect_hlo=*/false, &parsed_profile_lines)); + gtl::FlatMap parsed_profile_lines; - TF_ASSERT_OK( - ParseOneProfileOutputLine(*std::next(while_body_profile_start, 2), - /*expect_hlo=*/true, &parsed_profile_lines)); + for (auto while_body_profile_i = while_body_profile_start + 1; + while_body_profile_i != while_body_profile_end; while_body_profile_i++) { + // There are multiple "get-tuple-element" instructions in the while body so + // we ignore them -- we don't want parsed_profile_lines to be a multi-map. + TF_ASSERT_OK(ParseOneProfileOutputLine( + *while_body_profile_i, + /*expect_hlo=*/while_body_profile_i != (while_body_profile_start + 1), + &parsed_profile_lines, {"get-tuple-element"})); + } TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_while_body_profile, MaybeFind(parsed_profile_lines, "[total]")); - TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine dot_profile, - MaybeFind(parsed_profile_lines, "add")); + TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine multiply_profile, + MaybeFind(parsed_profile_lines, "multiply")); EXPECT_GT(total_while_body_profile.cycles, 0); EXPECT_EQ(total_while_body_profile.opcode, "[total]"); - EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.00%"); + EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.% 100Σ"); - EXPECT_GT(total_while_body_profile.cycles, dot_profile.cycles); - EXPECT_NE(dot_profile.cycles_percentage, "0.00%"); - EXPECT_NE(dot_profile.cycles_percentage, "100.00%"); + EXPECT_GT(total_while_body_profile.cycles, multiply_profile.cycles); + EXPECT_NE(multiply_profile.cycles_percentage, "0.00%"); + EXPECT_NE(multiply_profile.cycles_percentage, "100.00%"); } } // namespace } // namespace xla @@ -337,8 +354,11 @@ static std::pair AddXlaHloProfileFlag(int argc, char** argv) { new_argv[argc] = strdup("--xla_hlo_profile"); // Fusion can change the Hlo instructions that show up in the final Hlo - // executable, so block it here. - new_argv[argc + 1] = strdup("--xla_disable_hlo_passes=fusion"); + // executable, so block it here. Also block the WhileLoopInvariantCodeMotion + // pass, otherwise a while loop is transformed and we could not match the + // original name in the ProfileWhileComputation test. + new_argv[argc + 1] = strdup( + "--xla_disable_hlo_passes=fusion,while-loop-invariant-code-motion"); return {argc + 2, new_argv}; } diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 56702feab9a4e8d00df3a165ab994aef2d42d830..897123d7606db60abc1105b03beb3f23ab249579 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h index e45e5291c9b10803f5e5008b72c7dd0116a0dea0..708e8c80d8b5c09454eb64d4e12df51a5b7ea628 100644 --- a/tensorflow/compiler/xla/text_literal_reader.h +++ b/tensorflow/compiler/xla/text_literal_reader.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc index 23070b663870a2b78b38663e09a32fcb28d9c2dc..92f9b4f9f0efa2dc08287bdcbefc88f879164308 100644 --- a/tensorflow/compiler/xla/text_literal_reader_test.cc +++ b/tensorflow/compiler/xla/text_literal_reader_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index 373c0d2d8d8ab05dec11e51f265d41b91e7920bf..24e0784741a4c9779b0adb7a7740c3d6e2fb033a 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h index 0a1235b5e04675da0f412bafab6c4ecf04367787..159ac1b7e1b6f9c07dac795fb640cd0b2d284bcb 100644 --- a/tensorflow/compiler/xla/text_literal_writer.h +++ b/tensorflow/compiler/xla/text_literal_writer.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ #define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc index 70cf2fb1b8a1b4f2ecfdaeaef3a00ddc974e2652..4ea02faffcd52065b05c0444202bd1a3d9d87ee6 100644 --- a/tensorflow/compiler/xla/text_literal_writer_test.cc +++ b/tensorflow/compiler/xla/text_literal_writer_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -30,8 +31,9 @@ namespace xla { namespace { TEST(TextLiteralWriterTest, WritesFloatLiteral) { - auto literal = Literal::CreateR2({ - {3.14, 2.17}, {1.23, 4.56}, + auto literal = LiteralUtil::CreateR2({ + {3.14, 2.17}, + {1.23, 4.56}, }); string path = tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/whatever"); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index e4a052c8f1c0009619c3a94606f6384d04006e4e..40d28a57bfddd3403cad8252df985b746362631f 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -37,6 +37,7 @@ cc_library( "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_proto", @@ -74,7 +75,7 @@ cc_library( srcs = ["replay_computation.cc"], deps = [ "//tensorflow/compiler/xla:execution_options_util", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -84,7 +85,9 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:testing", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service/gpu:infeed_manager", @@ -123,7 +126,7 @@ tf_cc_binary( name = "show_literal", srcs = ["show_literal.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", @@ -145,7 +148,7 @@ tf_cc_binary( name = "show_text_literal", srcs = ["show_text_literal.cc"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:text_literal_reader", "//tensorflow/compiler/xla:types", @@ -164,6 +167,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", @@ -181,6 +185,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_proto", @@ -198,6 +203,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_graph_dumper", diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc index befb55453777dce30af89bcaad2ffe1647097576..f20dcef382b86d27d7c176ae7e4132ad1db7b901 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index cfb8f37487d6499b803438a135be54524fcf17d2..f0af0580c1fbca455c6ed5f87f82971faee50a06 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 5dd5150be339846d0775880931f615b92c5b08d8..f03e1b1f965af761c101555fd0275bc0425b9cf0 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index a5dce20456c6a2402f425ebb3d575d1bb625f839..dc5c106d02cb679f3e6f5b2bea40bbb42f8bd1cc 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 3a7917cf3043de8a77f189f011bdeb3e8d2ddf3c..be4cf4318b33f41fc611ea90a1a02198e23b84e4 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -30,6 +30,9 @@ limitations under the License. // The output format is: // // file_path: computation_name :: type:literal_str +// +// Note: If you pass multiple modules, they will be compiled in parallel but run +// in series. #include #include @@ -42,8 +45,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/testing.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" @@ -74,6 +79,18 @@ struct Options { int num_runs = 1; }; +std::unique_ptr CompileExecutable(const HloSnapshot& module, + LocalClient* client) { + XlaComputation computation(module.hlo().hlo_module()); + std::vector argument_layouts; + for (const auto& param : computation.proto().program_shape().parameters()) { + argument_layouts.push_back(¶m); + } + return client + ->Compile(computation, argument_layouts, ExecutableBuildOptions()) + .ValueOrDie(); +} + // Invokes the given computation passing arbitrary data for every (unbound) // parameter if use_fake_data, Otherwise use recorded data if available. // @@ -84,6 +101,7 @@ struct Options { // If neither generate_fake_infeed is true nor a fake_infeed_shape is provided, // no infeed is performed. StatusOr ReplayComputation(const HloSnapshot& module, + LocalExecutable* executable, LocalClient* client, const Options& opts) { XlaComputation computation(module.hlo().hlo_module()); @@ -166,34 +184,34 @@ StatusOr ReplayComputation(const HloSnapshot& module, }); } - std::vector argument_layouts; - for (const auto& param : computation.proto().program_shape().parameters()) { - argument_layouts.push_back(¶m); - } - std::unique_ptr executable = - client->Compile(computation, argument_layouts, ExecutableBuildOptions()) - .ValueOrDie(); - - // Do not attmept to run the executable, if num_runs is less than 1. + // Do not attempt to run the executable if num_runs is less than 1. if (opts.num_runs < 1) { return Cancelled("Cancelled after compilation since --num_runs < 1."); } // Run the computation num_runs times, and return the result from the last // execution. + const bool xla_hlo_profile = + legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile(); StreamExecutorMemoryAllocator allocator( client->platform(), {client->platform()->ExecutorForDevice(0).ValueOrDie()}); tensorflow::gtl::optional result; for (int i = 0; i < opts.num_runs; ++i) { + // If xla_hlo_profile is enabled, print a noisy message before the last run, + // making it easier to separate this profile from the others in the logspam. + if (xla_hlo_profile && i == opts.num_runs - 1) { + LOG(INFO) << "\n\n***** Final run below ******"; + } ExecutionProfile profile; ExecutableRunOptions run_options; run_options.set_execution_profile(&profile); run_options.set_allocator(&allocator); TF_ASSIGN_OR_RETURN(result, executable->Run(argument_ptrs, run_options)); - LOG(INFO) << "Execution took " - << static_cast(profile.compute_time_ns()) / 1e9 << "s"; + LOG(INFO) << "Done executing in " + << static_cast(profile.compute_time_ns()) / 1e9 + << "s: " << module.hlo().hlo_module().name(); } TF_ASSIGN_OR_RETURN(std::unique_ptr result_literal, @@ -234,15 +252,39 @@ StatusOr ParseInputFile(const string& filename, int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { LocalClient* client = ClientLibrary::LocalClientOrDie(); int exit_status = EXIT_SUCCESS; + + std::vector snapshots; for (char* arg : args) { StatusOr maybe_snapshot = ParseInputFile(arg, opts); - if (!maybe_snapshot.ok()) { - continue; + if (maybe_snapshot.ok()) { + snapshots.push_back(std::move(maybe_snapshot).ValueOrDie()); } - HloSnapshot snapshot = std::move(maybe_snapshot).ValueOrDie(); - StatusOr result_status = ReplayComputation(snapshot, client, opts); + } + + // Compile all the modules in parallel. + LOG(INFO) << "Compiling " << snapshots.size() << " modules in parallel."; + std::vector> executables; + { + // ThreadPool CHECK-fails if we give it 0 threads. + tensorflow::thread::ThreadPool thread_pool( + tensorflow::Env::Default(), tensorflow::ThreadOptions(), + "compile_modules", std::max(size_t{1}, snapshots.size()), + /*low_latency_hint=*/false); + executables.resize(snapshots.size()); + for (int64 i = 0; i < snapshots.size(); ++i) { + thread_pool.Schedule([&snapshots, &executables, client, i] { + executables[i] = CompileExecutable(snapshots[i], client); + }); + } + } + LOG(INFO) << "Done compiling; now running the modules."; + + for (int64 i = 0; i < executables.size(); ++i) { + LocalExecutable* executable = executables[i].get(); + StatusOr result_status = + ReplayComputation(snapshots[i], executable, client, opts); if (!result_status.ok()) { - fprintf(stderr, "%s: error: %s\n", arg, + fprintf(stderr, "%s: error: %s\n", args[i], result_status.status().ToString().c_str()); exit_status = EXIT_FAILURE; continue; @@ -250,10 +292,11 @@ int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { if (opts.print_result) { Literal result = std::move(result_status).ValueOrDie(); - fprintf(stdout, "%s: %s :: %s:%s\n", arg, - snapshot.hlo().hlo_module().name().c_str(), + fprintf(stdout, "%s: %s :: %s:%s\n", args[i], + executable->executable()->module().name().c_str(), ShapeUtil::HumanString(result.shape()).c_str(), result.ToString().c_str()); + auto& snapshot = snapshots[i]; if (snapshot.has_result()) { std::unique_ptr literal = Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc index fe8e72ba32bb4493b2751cfdfeb977f271092f9c..51909190a3ef20c3df78d08796e88bdbb650609d 100644 --- a/tensorflow/compiler/xla/tools/show_literal.cc +++ b/tensorflow/compiler/xla/tools/show_literal.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/xla/tools/show_text_literal.cc b/tensorflow/compiler/xla/tools/show_text_literal.cc index 8525873e913185554d18df8c8c3584bfcdcdcabe..48c837481181f6ad8f864569fd62e0e23fa02ecd 100644 --- a/tensorflow/compiler/xla/tools/show_text_literal.cc +++ b/tensorflow/compiler/xla/tools/show_text_literal.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/text_literal_reader.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 6041fae1595dacb309008857f1c758ee96a646bb..5ae099a4622bb7116c7a17f93060b699ead6e3a6 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -500,17 +500,17 @@ bool c_is_sorted(const C& c, Compare&& comp) { } template -auto c_adjacent_find(const C& c) -> decltype(std::begin(c)) { +auto c_adjacent_find(C& c) -> decltype(std::begin(c)) { return std::adjacent_find(std::begin(c), std::end(c)); } template -auto c_find_if(const C& c, Pred&& pred) -> decltype(std::begin(c)) { +auto c_find_if(C& c, Pred&& pred) -> decltype(std::begin(c)) { return std::find_if(std::begin(c), std::end(c), std::forward(pred)); } template -auto c_find(const C& c, Value&& value) -> decltype(std::begin(c)) { +auto c_find(C& c, Value&& value) -> decltype(std::begin(c)) { return std::find(std::begin(c), std::end(c), std::forward(value)); } @@ -534,6 +534,13 @@ c_count_if(const C& c, Pred&& pred) { return std::count_if(std::begin(c), std::end(c), std::forward(pred)); } +// Determines whether `value` is present in `c`. +template +bool c_linear_search(const C& c, T&& value) { + auto last = std::end(c); + return std::find(std::begin(c), last, std::forward(value)) != last; +} + template int64 FindIndex(const C& c, Value&& value) { auto it = c_find(c, std::forward(value)); @@ -555,6 +562,11 @@ void EraseAt(C* c, int64 index) { c->erase(c->begin() + index); } +template +std::vector ArraySliceToVector(tensorflow::gtl::ArraySlice slice) { + return std::vector(slice.begin(), slice.end()); +} + template std::vector InlinedVectorToVector( const tensorflow::gtl::InlinedVector& inlined_vector) { diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 6f07e4606bef015214f2c564515c8258a906205b..10c0adc6707f01fcee87303a6e2ec5c570601309 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -293,6 +293,7 @@ message ComputationStatsResponse { } message CreateChannelHandleRequest { + ChannelHandle.ChannelType channel_type = 1; } message CreateChannelHandleResponse { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index c7472173a705b7a6e1bee2f5221f23db0a77991d..4c35e93d38450b8263290da8e327d1f2126c1532 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -308,6 +308,22 @@ message DeviceHandle { // Send instructions will be blocked until the data is transferred. message ChannelHandle { int64 handle = 1; + enum ChannelType { + // Invalid primitive type to serve as default. + CHANNEL_TYPE_INVALID = 0; + + // A channel for sending data between devices. + DEVICE_TO_DEVICE = 1; + + // A channel for sending data from the device to the host. Can only be used + // with a Send operation. + DEVICE_TO_HOST = 2; + + // A channel for sending data from the host to the device. Can only be used + // with a Recv operation. + HOST_TO_DEVICE = 3; + } + ChannelType type = 2; } // DeviceAssignmentProto is a serialized form of DeviceAssignment class, which @@ -431,6 +447,20 @@ message GatherDimensionNumbers { int64 index_vector_dim = 4; } +// Describes the dimension numbers for a scatter operation. +// +// All the fields are similar to the corresponding fields in +// GatherDimensionNumbers. Differences are noted below. +message ScatterDimensionNumbers { + // The set of dimensions in the updates shape that are window dimensions. + repeated int64 update_window_dims = 1; + // The set of window dimensions that must be inserted into the updates shape. + repeated int64 inserted_window_dims = 2; + + repeated int64 scatter_dims_to_operand_dims = 3; + int64 index_vector_dim = 4; +} + message ConvolutionDimensionNumbers { // The number of the dimension that represents batch in the input. int64 input_batch_dimension = 7; @@ -531,3 +561,11 @@ message OpSharding { // to. repeated OpSharding tuple_shardings = 5; } + +// Describes the replica groups in a cross replica op (e.g., all-reduce and +// all-to-all). +message ReplicaGroup { + // The ids of the replicas that belongs to the same group. The ordering of the + // ids matters in some op (e.g., all-to-all). + repeated int64 replica_ids = 1; +} diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 2d7916c8b134ac034dcd3192cb70fa27fa742d6f..cc34db995e2ad653c6acce10d451de62ae8b264b 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -7,7 +7,6 @@ package(default_visibility = ["//tensorflow:__subpackages__"]) load("//third_party/mpi:mpi.bzl", "if_mpi") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") load("//tensorflow:tensorflow.bzl", "if_not_windows") load("//tensorflow:tensorflow.bzl", "if_not_windows_cuda") @@ -27,8 +26,6 @@ py_library( "//tensorflow/contrib/bayesflow:bayesflow_py", "//tensorflow/contrib/boosted_trees:init_py", "//tensorflow/contrib/checkpoint/python:checkpoint", - "//tensorflow/contrib/cloud:cloud_py", - "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/coder:coder_py", "//tensorflow/contrib/compiler:compiler_py", @@ -110,23 +107,33 @@ py_library( "//tensorflow/contrib/tfprof", "//tensorflow/contrib/timeseries", "//tensorflow/contrib/tpu", - "//tensorflow/contrib/tpu:tpu_py", "//tensorflow/contrib/training:training_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:util", "//tensorflow/python/estimator:estimator_py", - ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_tensorrt([ - "//tensorflow/contrib/tensorrt:init_py", - ]) + select({ + ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + select({ "//tensorflow:with_kafka_support_windows_override": [], "//tensorflow:with_kafka_support": [ "//tensorflow/contrib/kafka", ], "//conditions:default": [], + }) + select({ + "//tensorflow:with_aws_support_windows_override": [], + "//tensorflow:with_aws_support": [ + "//tensorflow/contrib/kinesis", + ], + "//conditions:default": [], }) + if_not_windows_cuda([ "//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols ]) + if_not_windows([ + "//tensorflow/contrib/bigtable", # depends on bigtable + "//tensorflow/contrib/cloud:cloud_py", # doesn't compile on Windows "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", + # TODO(aaroey): tensorrt dependency has to appear before tflite so the + # build can resolve its flatbuffers symbols within the tensorrt library. + # This is an issue with the tensorrt static library and will be fixed by + # the next tensorrt release, so fix the order here after that. + "//tensorflow/contrib/tensorrt:init_py", # doesn't compile on windows "//tensorflow/contrib/lite/python:lite", # unix dependency, need to fix code ]), ) @@ -156,6 +163,12 @@ cc_library( "//tensorflow/contrib/kafka:dataset_kernels", ], "//conditions:default": [], + }) + select({ + "//tensorflow:with_aws_support_windows_override": [], + "//tensorflow:with_aws_support": [ + "//tensorflow/contrib/kinesis:dataset_kernels", + ], + "//conditions:default": [], }), ) @@ -185,5 +198,11 @@ cc_library( "//tensorflow/contrib/kafka:dataset_ops_op_lib", ], "//conditions:default": [], + }) + select({ + "//tensorflow:with_aws_support_windows_override": [], + "//tensorflow:with_aws_support": [ + "//tensorflow/contrib/kinesis:dataset_ops_op_lib", + ], + "//conditions:default": [], }), ) diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 9aad772f0acd941d50d6ba238d345616195a6939..e18ea8df4df719a7317333cf9038ce7facf8d6ac 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -22,10 +22,12 @@ from __future__ import print_function import os # Add projects here, they will show up under tf.contrib. +from tensorflow.contrib import autograph from tensorflow.contrib import batching from tensorflow.contrib import bayesflow from tensorflow.contrib import checkpoint -from tensorflow.contrib import cloud +if os.name != "nt": + from tensorflow.contrib import cloud from tensorflow.contrib import cluster_resolver from tensorflow.contrib import coder from tensorflow.contrib import compiler diff --git a/tensorflow/contrib/android/cmake/src/main/AndroidManifest.xml b/tensorflow/contrib/android/cmake/src/main/AndroidManifest.xml index bced47e046db889366bf88e563d086a8c367431a..c17110a78be49f70ef108be79a624d87ad9ed28d 100644 --- a/tensorflow/contrib/android/cmake/src/main/AndroidManifest.xml +++ b/tensorflow/contrib/android/cmake/src/main/AndroidManifest.xml @@ -1,6 +1,10 @@ + + diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md index 7e26f4711851138c1834f881621ebfa227a85821..cc54da4daa9a5bb4e64145963ffec63021d08876 100644 --- a/tensorflow/contrib/autograph/README.md +++ b/tensorflow/contrib/autograph/README.md @@ -1,10 +1,10 @@ # AutoGraph -IMPORTANT: AutoGraph is alpha software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)). +IMPORTANT: AutoGraph is beta software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)). AutoGraph is a Python to TensorFlow compiler. -With AutoGraph, you can write [Eager style](https://www.tensorflow.org/guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops. +With AutoGraph, you can write [Eager style](https://www.tensorflow.org/guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops. [Please see this file for which parts of the Python language we currently support](LIMITATIONS.md). For example, this Python function: @@ -68,12 +68,21 @@ Then import the `autograph` module from `tf.contrib`: from tensorflow.contrib import autograph as ag ``` -### Interactive demo notebooks +### Related links -For more extensive examples, check out these interactive notebooks: +Articles: - * [RNN trained using Keras and Estimators](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb) + * [TensorFlow blog post](https://medium.com/tensorflow/autograph-converts-python-into-tensorflow-graphs-b2a871f87ec7) + +Interactive notebooks: + + * [Quick guide](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/guide/autograph.ipynb) + * [RNN trained using Keras and Estimators](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb) * [Demo from the TF Dev Summit 2018](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb) + * [Basic control flow speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_collatz_speed_test.ipynb) + * [MNIST training speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_mnist_speed_test.ipynb) + * [Basic algorithm samples](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/algorithms.ipynb) + * [Introductory workshop support notebook](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb) ## Using with annotations diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py index 361cf2d77c7e46912d5bff5881df2ffa897c5179..26e7a4a4d38e264486c981e6fc4c547bcc53b302 100644 --- a/tensorflow/contrib/autograph/__init__.py +++ b/tensorflow/contrib/autograph/__init__.py @@ -22,17 +22,21 @@ from __future__ import division from __future__ import print_function # TODO(mdan): Bring only the relevant symbols to the top level. -from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph import operators +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.core.errors import GraphConstructionError +from tensorflow.contrib.autograph.core.errors import TfRuntimeError +from tensorflow.contrib.autograph.core.errors import improved_errors +from tensorflow.contrib.autograph.impl.api import RunMode from tensorflow.contrib.autograph.impl.api import convert from tensorflow.contrib.autograph.impl.api import converted_call from tensorflow.contrib.autograph.impl.api import do_not_convert -from tensorflow.contrib.autograph.impl.api import RunMode from tensorflow.contrib.autograph.impl.api import to_code from tensorflow.contrib.autograph.impl.api import to_graph from tensorflow.contrib.autograph.lang.directives import set_element_type from tensorflow.contrib.autograph.lang.directives import set_loop_options from tensorflow.contrib.autograph.lang.special_functions import stack +from tensorflow.contrib.autograph.lang.special_functions import tensor_list from tensorflow.contrib.autograph.pyct.transformer import AutographParseError from tensorflow.python.util.all_util import remove_undocumented @@ -46,10 +50,15 @@ _allowed_symbols = [ 'to_graph', # Overloaded operators 'operators', + # Errors + 'improved_errors', + 'GraphConstructionError', + 'TfRuntimeError', # Python language "extensions" 'set_element_type', 'set_loop_options', 'stack', + 'tensor_list', # Exceptions 'AutographParseError', # Utilities: to be removed diff --git a/tensorflow/contrib/autograph/converters/BUILD b/tensorflow/contrib/autograph/converters/BUILD index b2e2e27673dafe290cef40a9fe0a834bfe1ea61f..2d2ab7040a8bb76f9538f201f75a2e4dcba0f511 100644 --- a/tensorflow/contrib/autograph/converters/BUILD +++ b/tensorflow/contrib/autograph/converters/BUILD @@ -21,16 +21,18 @@ py_library( "break_statements.py", "builtin_functions.py", "call_trees.py", + "conditional_expressions.py", "continue_statements.py", "control_flow.py", "decorators.py", - "ifexp.py", - "list_comprehension.py", + "directives.py", + "error_handlers.py", + "list_comprehensions.py", "lists.py", "logical_expressions.py", "name_scopes.py", + "return_statements.py", "side_effect_guards.py", - "single_return.py", "slices.py", ], srcs_version = "PY2AND3", @@ -94,6 +96,17 @@ py_test( ], ) +py_test( + name = "conditional_expressions_test", + srcs = ["conditional_expressions_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "continue_statements_test", srcs = ["continue_statements_test.py"], @@ -131,6 +144,18 @@ py_test( ], ) +py_test( + name = "directives_test", + srcs = ["directives_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", + "//tensorflow/contrib/autograph/lang", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "name_scopes_test", srcs = ["name_scopes_test.py"], @@ -143,8 +168,8 @@ py_test( ) py_test( - name = "list_comprehension_test", - srcs = ["list_comprehension_test.py"], + name = "list_comprehensions_test", + srcs = ["list_comprehensions_test.py"], srcs_version = "PY2AND3", deps = [ ":converters", @@ -179,11 +204,7 @@ py_test( name = "side_effect_guards_test", srcs = ["side_effect_guards_test.py"], srcs_version = "PY2AND3", - tags = [ - # TODO(mdan): Fix. - "flaky", - "notap", - ], + tags = ["notsan"], deps = [ ":converters", "//tensorflow/contrib/autograph/core:test_lib", @@ -192,8 +213,8 @@ py_test( ) py_test( - name = "single_return_test", - srcs = ["single_return_test.py"], + name = "return_statements_test", + srcs = ["return_statements_test.py"], srcs_version = "PY2AND3", deps = [ ":converters", @@ -204,8 +225,8 @@ py_test( ) py_test( - name = "ifexp_test", - srcs = ["ifexp_test.py"], + name = "error_handlers_test", + srcs = ["error_handlers_test.py"], srcs_version = "PY2AND3", deps = [ ":converters", diff --git a/tensorflow/contrib/autograph/converters/__init__.py b/tensorflow/contrib/autograph/converters/__init__.py index e4e8eda42f655e204310eaa9defdd5c90bf06e15..6325ac78dc3a08d14c1abf5e0f1ae60258639162 100644 --- a/tensorflow/contrib/autograph/converters/__init__.py +++ b/tensorflow/contrib/autograph/converters/__init__.py @@ -18,5 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# TODO(mdan): Define a base transformer class that can recognize skip_processing -# TODO(mdan): All converters are incomplete, especially those that change blocks +# Naming conventions: +# * each converter should specialize on a single idiom; be consistent with +# the Python reference for naming +# * all converters inherit core.converter.Base +# * module names describe the idiom that the converter covers, plural +# * the converter class is named consistent with the module, singular and +# includes the word Transformer +# +# Example: +# +# lists.py +# class ListTransformer(converter.Base) diff --git a/tensorflow/contrib/autograph/converters/asserts.py b/tensorflow/contrib/autograph/converters/asserts.py index e664a403a5fb800e7d0dddfa5695330927aaf4e0..af2f20f267d5cc64a6e9507a08c44f7e52245c28 100644 --- a/tensorflow/contrib/autograph/converters/asserts.py +++ b/tensorflow/contrib/autograph/converters/asserts.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Converts Assert statements to their corresponding TF calls.""" +"""Converts assert statements to their corresponding TF calls.""" from __future__ import absolute_import from __future__ import division @@ -24,8 +24,8 @@ from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import templates -class AssertsTransformer(converter.Base): - """Transforms Print nodes to Call so they can be handled as functions.""" +class AssertTransformer(converter.Base): + """Transforms Assert nodes to Call so they can be handled as functions.""" def visit_Assert(self, node): self.generic_visit(node) @@ -46,4 +46,4 @@ class AssertsTransformer(converter.Base): def transform(node, ctx): - return AssertsTransformer(ctx).visit(node) + return AssertTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/asserts_test.py b/tensorflow/contrib/autograph/converters/asserts_test.py index 2cd0e626bc4552bd40bc94b890fdcc7efcafb3f3..38faba45df6746d56933a1647594af133b671628 100644 --- a/tensorflow/contrib/autograph/converters/asserts_test.py +++ b/tensorflow/contrib/autograph/converters/asserts_test.py @@ -32,10 +32,10 @@ class AssertsTest(converter_testing.TestCase): def test_fn(a): assert a > 0 - node = self.parse_and_analyze(test_fn, {}) - node = asserts.transform(node, self.ctx) + node, ctx = self.prepare(test_fn, {}) + node = asserts.transform(node, ctx) - self.assertTrue(isinstance(node.body[0].body[0].value, gast.Call)) + self.assertTrue(isinstance(node.body[0].value, gast.Call)) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py index a990e359a2a25a57ee2a4f8a866350633f3b9ea8..180779670d91abd7d395bda0b63f592967c5015b 100644 --- a/tensorflow/contrib/autograph/converters/break_statements.py +++ b/tensorflow/contrib/autograph/converters/break_statements.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Canonicalizes break statements by de-sugaring into a control boolean.""" +"""Lowers break statements to conditionals.""" from __future__ import absolute_import from __future__ import division @@ -24,20 +24,25 @@ from tensorflow.contrib.autograph.pyct import templates from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno -# Tags for local state. -BREAK_USED = 'break_used' -CONTROL_VAR_NAME = 'control_var_name' +class _Break(object): + def __init__(self): + self.used = False + self.control_var_name = None -class BreakStatementTransformer(converter.Base): + def __repr__(self): + return 'used: %s, var: %s' % (self.used, self.control_var_name) + + +class BreakTransformer(converter.Base): """Canonicalizes break statements into additional conditionals.""" def visit_Break(self, node): - self.set_local(BREAK_USED, True) - var_name = self.get_local(CONTROL_VAR_NAME) + self.state[_Break].used = True + var_name = self.state[_Break].control_var_name # TODO(mdan): This will fail when expanded inside a top-level else block. template = """ - var_name = True + var_name = tf.constant(True) continue """ return templates.replace(template, var_name=var_name) @@ -57,12 +62,12 @@ class BreakStatementTransformer(converter.Base): block=block) return node - def _track_body(self, nodes, break_var): - self.enter_local_scope() - self.set_local(CONTROL_VAR_NAME, break_var) + def _process_body(self, nodes, break_var): + self.state[_Break].enter() + self.state[_Break].control_var_name = break_var nodes = self.visit_block(nodes) - break_used = self.get_local(BREAK_USED, False) - self.exit_local_scope() + break_used = self.state[_Break].used + self.state[_Break].exit() return nodes, break_used def visit_While(self, node): @@ -70,7 +75,7 @@ class BreakStatementTransformer(converter.Base): break_var = self.ctx.namer.new_symbol('break_', scope.referenced) node.test = self.visit(node.test) - node.body, break_used = self._track_body(node.body, break_var) + node.body, break_used = self._process_body(node.body, break_var) # A break in the else clause applies to the containing scope. node.orelse = self.visit_block(node.orelse) @@ -80,7 +85,7 @@ class BreakStatementTransformer(converter.Base): guarded_orelse = self._guard_if_present(node.orelse, break_var) template = """ - var_name = False + var_name = tf.constant(False) while test and not var_name: body else: @@ -101,7 +106,7 @@ class BreakStatementTransformer(converter.Base): node.target = self.visit(node.target) node.iter = self.visit(node.iter) - node.body, break_used = self._track_body(node.body, break_var) + node.body, break_used = self._process_body(node.body, break_var) # A break in the else clause applies to the containing scope. node.orelse = self.visit_block(node.orelse) @@ -117,7 +122,7 @@ class BreakStatementTransformer(converter.Base): # the control variable is marked as used. # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name) template = """ - var_name = False + var_name = tf.constant(False) for target in iter_: (var_name,) body @@ -138,4 +143,4 @@ class BreakStatementTransformer(converter.Base): def transform(node, ctx): - return BreakStatementTransformer(ctx).visit(node) + return BreakTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/break_statements_test.py b/tensorflow/contrib/autograph/converters/break_statements_test.py index dcff1c54c2f9300d58d217517e108d634ae85fb4..fcae7d68c0f90817e001b45fa86ca6be08456027 100644 --- a/tensorflow/contrib/autograph/converters/break_statements_test.py +++ b/tensorflow/contrib/autograph/converters/break_statements_test.py @@ -20,12 +20,19 @@ from __future__ import print_function from tensorflow.contrib.autograph.converters import break_statements from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.python.eager import context as tfe_ctx +from tensorflow.python.framework import constant_op from tensorflow.python.platform import test class BreakCanonicalizationTest(converter_testing.TestCase): - def test_basic_while(self): + def assertTransformedEquivalent(self, test_fn, *inputs): + with self.converted(test_fn, break_statements, {}, + constant_op.constant) as result: + self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) + + def test_while_loop(self): def test_fn(x): v = [] @@ -36,15 +43,12 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - node = self.parse_and_analyze(test_fn, {}) - node = break_statements.transform(node, self.ctx) - - with self.compiled(node) as result: - self.assertEqual([], result.test_fn(0)) - self.assertEqual([], result.test_fn(1)) - self.assertEqual([3], result.test_fn(4)) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 1) + self.assertTransformedEquivalent(test_fn, 4) - def test_basic_for(self): + def test_for_loop(self): def test_fn(a): v = [] @@ -55,18 +59,13 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - node = self.parse_and_analyze(test_fn, {}) - node = break_statements.transform(node, self.ctx) - - with self.compiled(node) as result: + with self.converted(test_fn, break_statements, {}, + constant_op.constant) as result: # The break is incompletely canonicalized. The loop will not interrupt, # but the section following the break will be skipped. - self.assertEqual([], result.test_fn([])) - self.assertEqual([3, 3], result.test_fn([4, 4])) - self.assertEqual([3], result.test_fn([4, 5])) self.assertEqual([3], result.test_fn([5, 4])) - def test_deeply_nested(self): + def test_nested(self): def test_fn(x): v = [] @@ -83,13 +82,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v, u, w - node = self.parse_and_analyze(test_fn, {}) - node = break_statements.transform(node, self.ctx) - - with self.compiled(node) as result: - self.assertEqual(([], [], []), result.test_fn(0)) - self.assertEqual(([2, 1], [2], [0]), result.test_fn(3)) - self.assertEqual(([10, 9, 8, 7], [10, 8], [6]), result.test_fn(11)) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 11) def test_nested_loops(self): @@ -109,16 +105,13 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v, u - node = self.parse_and_analyze(test_fn, {}) - node = break_statements.transform(node, self.ctx) - - with self.compiled(node) as result: - self.assertEqual(([], []), result.test_fn(0)) - self.assertEqual(([1], []), result.test_fn(2)) - self.assertEqual(([2, 1], [1]), result.test_fn(3)) - self.assertEqual(([4, 3, 2, 1], [3, 1]), result.test_fn(5)) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 5) - def test_loop_else(self): + def test_loop_orelse(self): def test_fn(x): v = [] @@ -134,13 +127,10 @@ class BreakCanonicalizationTest(converter_testing.TestCase): v.append(x) return v, u - node = self.parse_and_analyze(test_fn, {}) - node = break_statements.transform(node, self.ctx) - - with self.compiled(node) as result: - self.assertEqual(([], []), result.test_fn(0)) - self.assertEqual(([], [1]), result.test_fn(2)) - self.assertEqual(([2], [1]), result.test_fn(3)) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, 3) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py index e9000e518ce14f9e0ea486d5b3e374439b8c78ca..d5c3e2c250cc1ee0205fd1941040bf70de4a149a 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import sys - import six from tensorflow.contrib.autograph.converters import builtin_functions @@ -36,55 +34,39 @@ class BuiltinFunctionsTest(converter_testing.TestCase): def test_fn(a): return len(a) - node = self.parse_and_analyze(test_fn, {'len': len}) - node = builtin_functions.transform(node, self.ctx) - - with self.compiled(node, array_ops.shape) as result: + with self.converted(test_fn, builtin_functions, {'len': len}, + array_ops.shape) as result: with self.test_session() as sess: - self.assertEqual(3, - sess.run( - result.test_fn(constant_op.constant([0, 0, 0])))) - - self.assertEqual(3, result.test_fn([0, 0, 0])) + ops = result.test_fn(constant_op.constant([0, 0, 0])) + self.assertEqual(sess.run(ops), 3) def test_print(self): - def test_fn(a): - print(a) + if six.PY2: + return - node = self.parse_and_analyze(test_fn, {'print': print}) - node = builtin_functions.transform(node, self.ctx) + def test_fn(a): + return print(a) - with self.compiled(node) as result: + with self.converted(test_fn, builtin_functions, {'print': print}) as result: with self.test_session() as sess: - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer - result.test_fn(constant_op.constant('a')) - sess.run(sess.graph.get_operations()) - self.assertEqual(out_capturer.getvalue(), 'a\n') - finally: - sys.stdout = sys.__stdout__ + with self.assertPrints('a\n'): + sess.run(result.test_fn('a')) - def test_print_with_op_multiple_values(self): + def test_print_multiple_values(self): - def test_fn(a, b, c): - print(a, b, c) + if six.PY2: + return - node = self.parse_and_analyze(test_fn, {'print': print}) - node = builtin_functions.transform(node, self.ctx) + def test_fn(a, b, c): + return print(a, b, c) - with self.compiled(node) as result: + with self.converted(test_fn, builtin_functions, {'print': print}) as result: with self.test_session() as sess: - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer - result.test_fn( - constant_op.constant('a'), constant_op.constant(1), [2, 3]) - sess.run(sess.graph.get_operations()) - self.assertEqual(out_capturer.getvalue(), 'a 1 [2, 3]\n') - finally: - sys.stdout = sys.__stdout__ + with self.assertPrints('a 1 [2, 3]\n'): + sess.run( + result.test_fn( + constant_op.constant('a'), constant_op.constant(1), [2, 3])) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/call_trees.py b/tensorflow/contrib/autograph/converters/call_trees.py index a36b3d77a9233daed864c616306b2ad27f582a38..2d1bed3367fa0b283200b775c5953da80c855367 100644 --- a/tensorflow/contrib/autograph/converters/call_trees.py +++ b/tensorflow/contrib/autograph/converters/call_trees.py @@ -238,7 +238,7 @@ class CallTreeTransformer(converter.Base): # Before we could convert all the time though, we'd need a reasonable # caching mechanism. template = """ - ag__.converted_call(func, True, False, {}, args) + ag__.converted_call(func, True, False, False, {}, args) """ call_expr = templates.replace(template, func=node.func, args=node.args) new_call = call_expr[0].value diff --git a/tensorflow/contrib/autograph/converters/call_trees_test.py b/tensorflow/contrib/autograph/converters/call_trees_test.py index 27d8281b856f505062ceacc8ad50c8cbc2ce6c81..8cdba659eee264717204cc6048bbe0b8bbfe245f 100644 --- a/tensorflow/contrib/autograph/converters/call_trees_test.py +++ b/tensorflow/contrib/autograph/converters/call_trees_test.py @@ -36,37 +36,34 @@ class CallTreesTest(converter_testing.TestCase): def test_fn_1(_): raise ValueError('This should not be called in the compiled version.') - def renamed_test_fn_1(a): + def other_test_fn_1(a): return a + 1 def test_fn_2(a): return test_fn_1(a) + 1 - node = self.parse_and_analyze(test_fn_2, {'test_fn_1': test_fn_1}) - node = call_trees.transform(node, self.ctx) + ns = {'test_fn_1': test_fn_1} + node, ctx = self.prepare(test_fn_2, ns) + node = call_trees.transform(node, ctx) - with self.compiled(node) as result: - # Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1 - # manually. - result.renamed_test_fn_1 = renamed_test_fn_1 - self.assertEquals(3, result.test_fn_2(1)) + with self.compiled(node, ns) as result: + new_name, _ = ctx.namer.compiled_function_name(('test_fn_1',)) + setattr(result, new_name, other_test_fn_1) + self.assertEquals(result.test_fn_2(1), 3) def test_dynamic_function(self): def test_fn_1(): - raise ValueError('This should be masked by the mock.') + raise ValueError('This should be masked by the mock in self.compiled.') def test_fn_2(f): return f() + 3 - node = self.parse_and_analyze(test_fn_2, {}) - node = call_trees.transform(node, self.ctx) - - with self.compiled(node) as result: + with self.converted(test_fn_2, call_trees, {}) as result: # 10 = 7 (from the mock) + 3 (from test_fn_2) self.assertEquals(10, result.test_fn_2(test_fn_1)) - def test_simple_methods(self): + def test_basic_method(self): class TestClass(object): @@ -76,49 +73,43 @@ class CallTreesTest(converter_testing.TestCase): def test_fn_2(self, a): return self.test_fn_1(a) + 1 - node = self.parse_and_analyze( - TestClass.test_fn_2, {'TestClass': TestClass}, + ns = {'TestClass': TestClass} + node, ctx = self.prepare( + TestClass.test_fn_2, + ns, namer=converter_testing.FakeNoRenameNamer(), arg_types={'self': (TestClass.__name__, TestClass)}) - node = call_trees.transform(node, self.ctx) + node = call_trees.transform(node, ctx) - with self.compiled(node) as result: + with self.compiled(node, ns) as result: tc = TestClass() self.assertEquals(3, result.test_fn_2(tc, 1)) - def test_py_func_wrap_no_retval(self): + def test_py_func_no_retval(self): def test_fn(a): setattr(a, 'foo', 'bar') - node = self.parse_and_analyze(test_fn, {'setattr': setattr}) - node = call_trees.transform(node, self.ctx) - - with self.compiled(node) as result: + with self.converted(test_fn, call_trees, {'setattr': setattr}) as result: with self.test_session() as sess: - # The function has no return value, so we do some tricks to grab the - # generated py_func node and ensure its effect only happens at graph - # execution. class Dummy(object): pass a = Dummy() result.test_fn(a) + py_func_op, = sess.graph.get_operations() self.assertFalse(hasattr(a, 'foo')) - sess.run(sess.graph.get_operations()[0]) + sess.run(py_func_op) self.assertEquals('bar', a.foo) - def test_py_func_wrap_known_function(self): + def test_py_func_known_function(self): def test_fn(): return np.random.binomial(2, 0.5) - node = self.parse_and_analyze(test_fn, {'np': np}) - node = call_trees.transform(node, self.ctx) - - with self.compiled(node, dtypes.int64) as result: - result.np = np + with self.converted(test_fn, call_trees, {'np': np}, + dtypes.int64) as result: with self.test_session() as sess: self.assertTrue(isinstance(result.test_fn(), ops.Tensor)) self.assertIn(sess.run(result.test_fn()), (0, 1, 2)) @@ -130,22 +121,17 @@ class CallTreesTest(converter_testing.TestCase): a = math_ops.add(a, constant_op.constant(1)) return a - node = self.parse_and_analyze( - test_fn, { - 'math_ops': math_ops, - 'constant_op': constant_op - }, + ns = {'math_ops': math_ops, 'constant_op': constant_op} + node, ctx = self.prepare( + test_fn, + ns, arg_types=set(((math_ops.__name__,), (constant_op.__name__,)))) - node = call_trees.transform(node, self.ctx) + node = call_trees.transform(node, ctx) - with self.compiled(node) as result: - result.math_ops = math_ops - result.constant_op = constant_op + with self.compiled(node, ns) as result: with self.test_session() as sess: - # Not renamed, because the converter doesn't rename the definition - # itself (the caller is responsible for that). result_tensor = result.test_fn(constant_op.constant(1)) - self.assertEquals(3, sess.run(result_tensor)) + self.assertEquals(sess.run(result_tensor), 3) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/conditional_expressions.py b/tensorflow/contrib/autograph/converters/conditional_expressions.py new file mode 100644 index 0000000000000000000000000000000000000000..63f649dfdf5f740ba66260a51175a0ec2b716ea3 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/conditional_expressions.py @@ -0,0 +1,129 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converts the ternary conditional operator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno + + +class _FunctionDefs(object): + + def __init__(self): + self.nodes = [] + + +class _Statement(object): + + def __init__(self): + self.scope = None + + +class ConditionalExpressionTransformer(converter.Base): + """Converts contitional expressions to functional form.""" + + def _postprocess_statement(self, node): + """Inserts any separate functions that node may use.""" + replacements = [] + for def_node in self.state[_FunctionDefs].nodes: + replacements.extend(def_node) + replacements.append(node) + node = replacements + # The corresponding enter is called by self.visit_block (see _process_block) + self.state[_FunctionDefs].exit() + return node, None + + def _create_branch(self, expr, name_stem): + scope = self.state[_Statement].scope + name = self.ctx.namer.new_symbol(name_stem, scope.referenced) + template = """ + def name(): + return expr, + """ + node = templates.replace(template, name=name, expr=expr) + self.state[_FunctionDefs].nodes.append(node) + return name + + def visit_IfExp(self, node): + if anno.hasanno(node.test, anno.Basic.QN): + name_root = anno.getanno(node.test, anno.Basic.QN).ssf() + else: + name_root = 'ifexp' + + true_fn_name = self._create_branch(node.body, '%s_true' % name_root) + false_fn_name = self._create_branch(node.orelse, '%s_false' % name_root) + + return templates.replace_as_expression( + 'ag__.utils.run_cond(test, true_fn_name, false_fn_name)', + test=node.test, + true_fn_name=true_fn_name, + false_fn_name=false_fn_name) + + def _process_block(self, scope, block): + self.state[_Statement].enter() + self.state[_Statement].scope = scope + block = self.visit_block( + block, + before_visit=self.state[_FunctionDefs].enter, + after_visit=self._postprocess_statement) + self.state[_Statement].exit() + return block + + def visit_FunctionDef(self, node): + node.args = self.generic_visit(node.args) + node.decorator_list = self.visit_block(node.decorator_list) + node.body = self._process_block( + anno.getanno(node, anno.Static.SCOPE), node.body) + return node + + def visit_For(self, node): + node.target = self.visit(node.target) + node.body = self._process_block( + anno.getanno(node, NodeAnno.BODY_SCOPE), node.body) + node.orelse = self._process_block( + anno.getanno(node, NodeAnno.ORELSE_SCOPE), node.orelse) + return node + + def visit_While(self, node): + node.test = self.visit(node.test) + node.body = self._process_block( + anno.getanno(node, NodeAnno.BODY_SCOPE), node.body) + node.orelse = self._process_block( + anno.getanno(node, NodeAnno.ORELSE_SCOPE), node.orelse) + return node + + def visit_If(self, node): + node.test = self.visit(node.test) + node.body = self._process_block( + anno.getanno(node, NodeAnno.BODY_SCOPE), node.body) + node.orelse = self._process_block( + anno.getanno(node, NodeAnno.ORELSE_SCOPE), node.orelse) + return node + + def visit_With(self, node): + node.items = self.visit_block(node.items) + node.body = self._process_block( + anno.getanno(node, NodeAnno.BODY_SCOPE), node.body) + return node + + +def transform(node, ctx): + node = ConditionalExpressionTransformer(ctx).visit(node) + return node diff --git a/tensorflow/contrib/autograph/converters/conditional_expressions_test.py b/tensorflow/contrib/autograph/converters/conditional_expressions_test.py new file mode 100644 index 0000000000000000000000000000000000000000..95a3108741800c5fe504690f92876fa63edd8651 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/conditional_expressions_test.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for conditional_expressions module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.converters import conditional_expressions +from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.python.platform import test + + +class ConditionalExpressionsTest(converter_testing.TestCase): + + def assertTransformedEquivalent(self, test_fn, *inputs): + ns = {} + with self.converted(test_fn, conditional_expressions, ns) as result: + self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) + + def test_basic(self): + + def test_fn(x): + return 1 if x else 0 + + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 3) + + def test_nested_orelse(self): + + def test_fn(x): + y = x * x if x > 0 else x if x else 1 + return y + + self.assertTransformedEquivalent(test_fn, -2) + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 2) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/converters/continue_statements.py b/tensorflow/contrib/autograph/converters/continue_statements.py index 958bde0a58764e705c35ab73ce879b2c11ce7cdc..0476e97c15e33dcfc09b3555cf8dc7ff3fd7ce19 100644 --- a/tensorflow/contrib/autograph/converters/continue_statements.py +++ b/tensorflow/contrib/autograph/converters/continue_statements.py @@ -37,7 +37,7 @@ class ContinueCanonicalizationTransformer(converter.Base): def visit_Continue(self, node): self.set_local(CONTINUE_USED, True) template = """ - var_name = True + var_name = tf.constant(True) """ return templates.replace( template, var_name=self.get_local(CONTROL_VAR_NAME)) @@ -92,7 +92,7 @@ class ContinueCanonicalizationTransformer(converter.Base): if self.get_local(CONTINUE_USED, False): template = """ - var_name = False + var_name = tf.constant(False) """ control_var_init = templates.replace(template, var_name=continue_var) nodes = control_var_init + nodes diff --git a/tensorflow/contrib/autograph/converters/continue_statements_test.py b/tensorflow/contrib/autograph/converters/continue_statements_test.py index 2ce1837972c50bbc4921487a290f5cb2f782b5f3..37c15211b4fe266e57879249fe7e060ded44dc1f 100644 --- a/tensorflow/contrib/autograph/converters/continue_statements_test.py +++ b/tensorflow/contrib/autograph/converters/continue_statements_test.py @@ -20,12 +20,19 @@ from __future__ import print_function from tensorflow.contrib.autograph.converters import continue_statements from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.python.eager import context as tfe_ctx +from tensorflow.python.framework import constant_op from tensorflow.python.platform import test class ContinueCanonicalizationTest(converter_testing.TestCase): - def test_basic_continue(self): + def assertTransformedEquivalent(self, test_fn, *inputs): + with self.converted(test_fn, continue_statements, {}, + constant_op.constant) as result: + self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) + + def test_basic(self): def test_fn(x): v = [] @@ -36,17 +43,13 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - node = self.parse_and_analyze(test_fn, {}) - node = continue_statements.transform(node, self.ctx) - - with self.compiled(node) as result: - self.assertEqual(test_fn(0), result.test_fn(0)) - self.assertEqual(test_fn(1), result.test_fn(1)) - self.assertEqual(test_fn(2), result.test_fn(2)) - self.assertEqual(test_fn(3), result.test_fn(3)) - self.assertEqual(test_fn(4), result.test_fn(4)) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 1) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 4) - def test_basic_continue_for_loop(self): + def test_for_loop(self): def test_fn(a): v = [] @@ -57,16 +60,13 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v - node = self.parse_and_analyze(test_fn, {}) - node = continue_statements.transform(node, self.ctx) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, []) + self.assertTransformedEquivalent(test_fn, [1]) + self.assertTransformedEquivalent(test_fn, [2]) + self.assertTransformedEquivalent(test_fn, [1, 2, 3]) - with self.compiled(node) as result: - self.assertEqual(test_fn([]), result.test_fn([])) - self.assertEqual(test_fn([1]), result.test_fn([1])) - self.assertEqual(test_fn([2]), result.test_fn([2])) - self.assertEqual(test_fn([1, 2, 3]), result.test_fn([1, 2, 3])) - - def test_continue_deeply_nested(self): + def test_nested(self): def test_fn(x): v = [] @@ -83,15 +83,11 @@ class ContinueCanonicalizationTest(converter_testing.TestCase): v.append(x) return v, u, w - node = self.parse_and_analyze(test_fn, {}) - node = continue_statements.transform(node, self.ctx) - - with self.compiled(node) as result: - self.assertEqual(test_fn(0), result.test_fn(0)) - self.assertEqual(test_fn(1), result.test_fn(1)) - self.assertEqual(test_fn(2), result.test_fn(2)) - self.assertEqual(test_fn(3), result.test_fn(3)) - self.assertEqual(test_fn(4), result.test_fn(4)) + with tfe_ctx.eager_mode(): + self.assertTransformedEquivalent(test_fn, 0) + self.assertTransformedEquivalent(test_fn, 1) + self.assertTransformedEquivalent(test_fn, 3) + self.assertTransformedEquivalent(test_fn, 4) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py index f4a87106279d5658ecaa90a577cbe741711ba22e..5a5a6ad63a777f463e80e061d4870f2ee7491c39 100644 --- a/tensorflow/contrib/autograph/converters/control_flow.py +++ b/tensorflow/contrib/autograph/converters/control_flow.py @@ -25,8 +25,7 @@ from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct.static_analysis import cfg -from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct.static_analysis import annos class SymbolNamer(object): @@ -47,6 +46,7 @@ class SymbolNamer(object): class ControlFlowTransformer(converter.Base): """Transforms control flow structures like loops an conditionals.""" + def _create_cond_branch(self, body_name, aliased_orig_names, aliased_new_names, body, returns): if aliased_orig_names: @@ -90,55 +90,51 @@ class ControlFlowTransformer(converter.Base): return templates.replace( template, test=test, body_name=body_name, orelse_name=orelse_name) - def visit_If(self, node): - self.generic_visit(node) + def _fmt_symbol_list(self, symbol_set): + if not symbol_set: + return 'no variables' + return ', '.join(map(str, symbol_set)) - body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE) - body_defs = body_scope.created | body_scope.modified - orelse_defs = orelse_scope.created | orelse_scope.modified - live = anno.getanno(node, 'live_out') - - # We'll need to check if we're closing over variables that are defined - # elsewhere in the function - # NOTE: we can only detect syntactic closure in the scope - # of the code passed in. If the AutoGraph'd function itself closes - # over other variables, this analysis won't take that into account. - defined = anno.getanno(node, 'defined_in') - - # We only need to return variables that are - # - modified by one or both branches - # - live (or has a live parent) at the end of the conditional - modified = [] - for def_ in body_defs | orelse_defs: - def_with_parents = set((def_,)) | def_.support_set - if live & def_with_parents: - modified.append(def_) - - # We need to check if live created variables are balanced - # in both branches - created = live & (body_scope.created | orelse_scope.created) - - # The if statement is illegal if there are variables that are created, - # that are also live, but both branches don't create them. - if created: - if created != (body_scope.created & live): - raise ValueError( - 'The main branch does not create all live symbols that the else ' - 'branch does.') - if created != (orelse_scope.created & live): - raise ValueError( - 'The else branch does not create all live symbols that the main ' - 'branch does.') - - # Alias the closure variables inside the conditional functions - # to avoid errors caused by the local variables created in the branch - # functions. + def visit_If(self, node): + node = self.generic_visit(node) + + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) + orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE) + defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN) + live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT) + + modified_in_cond = body_scope.modified | orelse_scope.modified + returned_from_cond = set() + for s in modified_in_cond: + if s in live_out: + returned_from_cond.add(s) + elif s.is_composite(): + # Special treatment for compound objects: if any of their owner entities + # are live, then they are outputs as well. + if any(owner in live_out for owner in s.owner_set): + returned_from_cond.add(s) + + need_alias_in_body = body_scope.modified & defined_in + need_alias_in_orelse = orelse_scope.modified & defined_in + + created_in_body = body_scope.modified & returned_from_cond - defined_in + created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in + + if created_in_body != created_in_orelse: + raise ValueError( + 'if statement may not initialize all variables: the true branch' + ' creates %s, while the false branch creates %s. Make sure all' + ' these variables are initialized either in both' + ' branches or before the if statement.' % + (self._fmt_symbol_list(created_in_body), + self._fmt_symbol_list(created_in_orelse))) + + # Alias the closure variables inside the conditional functions, to allow + # the functions access to the respective variables. # We will alias variables independently for body and orelse scope, # because different branches might write different variables. - aliased_body_orig_names = tuple(body_scope.modified - body_scope.created) - aliased_orelse_orig_names = tuple(orelse_scope.modified - - orelse_scope.created) + aliased_body_orig_names = tuple(need_alias_in_body) + aliased_orelse_orig_names = tuple(need_alias_in_orelse) aliased_body_new_names = tuple( self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced) for s in aliased_body_orig_names) @@ -153,58 +149,47 @@ class ControlFlowTransformer(converter.Base): node_body = ast_util.rename_symbols(node.body, alias_body_map) node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map) - if not modified: + returned_from_cond = tuple(returned_from_cond) + if returned_from_cond: + if len(returned_from_cond) == 1: + # TODO(mdan): Move this quirk into the operator implementation. + cond_results = returned_from_cond[0] + else: + cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None) + + returned_from_body = tuple( + alias_body_map[s] if s in need_alias_in_body else s + for s in returned_from_cond) + returned_from_orelse = tuple( + alias_orelse_map[s] if s in need_alias_in_orelse else s + for s in returned_from_cond) + + else: # When the cond would return no value, we leave the cond called without # results. That in turn should trigger the side effect guards. The # branch functions will return a dummy value that ensures cond # actually has some return value as well. - results = None - elif len(modified) == 1: - results = modified[0] - else: - results = gast.Tuple([s.ast() for s in modified], None) + cond_results = None + # TODO(mdan): This doesn't belong here; it's specific to the operator. + returned_from_body = templates.replace_as_expression('tf.constant(1)') + returned_from_orelse = templates.replace_as_expression('tf.constant(1)') body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced) orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced) - if modified: - - def build_returns(aliased_names, alias_map, scope): - """Builds list of return variables for a branch of a conditional.""" - returns = [] - for s in modified: - if s in aliased_names: - returns.append(alias_map[s]) - else: - if s not in scope.created | defined: - raise ValueError( - 'Attempting to return variable "%s" from the true branch of ' - 'a conditional, but it was not closed over, or created in ' - 'this branch.' % str(s)) - else: - returns.append(s) - return tuple(returns) - - body_returns = build_returns(aliased_body_orig_names, alias_body_map, - body_scope) - orelse_returns = build_returns(aliased_orelse_orig_names, - alias_orelse_map, orelse_scope) - - else: - body_returns = orelse_returns = templates.replace('tf.ones(())')[0].value body_def = self._create_cond_branch( body_name, - aliased_orig_names=tuple(aliased_body_orig_names), - aliased_new_names=tuple(aliased_body_new_names), + aliased_orig_names=aliased_body_orig_names, + aliased_new_names=aliased_body_new_names, body=node_body, - returns=body_returns) + returns=returned_from_body) orelse_def = self._create_cond_branch( orelse_name, - aliased_orig_names=tuple(aliased_orelse_orig_names), - aliased_new_names=tuple(aliased_orelse_new_names), + aliased_orig_names=aliased_orelse_orig_names, + aliased_new_names=aliased_orelse_new_names, body=node_orelse, - returns=orelse_returns) - cond_expr = self._create_cond_expr(results, node.test, body_name, + returns=returned_from_orelse) + cond_expr = self._create_cond_expr(cond_results, node.test, body_name, orelse_name) return body_def + orelse_def + cond_expr @@ -212,11 +197,11 @@ class ControlFlowTransformer(converter.Base): def visit_While(self, node): self.generic_visit(node) - body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) body_closure = body_scope.modified - body_scope.created all_referenced = body_scope.referenced - cond_scope = anno.getanno(node, NodeAnno.COND_SCOPE) + cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE) cond_closure = set() for s in cond_scope.referenced: for root in s.support_set: @@ -277,7 +262,7 @@ class ControlFlowTransformer(converter.Base): def visit_For(self, node): self.generic_visit(node) - body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) body_closure = body_scope.modified - body_scope.created all_referenced = body_scope.referenced @@ -331,7 +316,5 @@ class ControlFlowTransformer(converter.Base): def transform(node, ctx): - cfg.run_analyses(node, cfg.Liveness(ctx.info)) - cfg.run_analyses(node, cfg.Defined(ctx.info)) node = ControlFlowTransformer(ctx).visit(node) return node diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py index 735eb92a0dd06ee7fd621b92b1a8f894e09cee4a..ade35014263c3ae4ec14b40ee0f2507b70627d41 100644 --- a/tensorflow/contrib/autograph/converters/control_flow_test.py +++ b/tensorflow/contrib/autograph/converters/control_flow_test.py @@ -20,16 +20,23 @@ from __future__ import print_function from tensorflow.contrib.autograph.converters import control_flow from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import test class ControlFlowTest(converter_testing.TestCase): - def test_simple_while(self): + def assertTransformedResult(self, test_fn, inputs, expected): + if not isinstance(inputs, tuple): + inputs = (inputs,) + with self.converted(test_fn, control_flow, {}, + constant_op.constant) as result: + with self.test_session() as sess: + self.assertEqual(sess.run(result.test_fn(*inputs)), expected) + + def test_while_basic(self): def test_fn(n): i = 0 @@ -39,29 +46,18 @@ class ControlFlowTest(converter_testing.TestCase): i += 1 return s, i, n - node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, self.ctx) - - with self.compiled(node) as result: - with self.test_session() as sess: - self.assertEqual((10, 5, 5), - sess.run(result.test_fn(constant_op.constant(5)))) + self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 5, 5)) - def test_while_single_var(self): + def test_while_single_output(self): def test_fn(n): while n > 0: n -= 1 return n - node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, self.ctx) + self.assertTransformedResult(test_fn, constant_op.constant(5), 0) - with self.compiled(node) as result: - with self.test_session() as sess: - self.assertEqual(0, sess.run(result.test_fn(constant_op.constant(5)))) - - def test_simple_if(self): + def test_if_basic(self): def test_fn(n): a = 0 @@ -72,114 +68,85 @@ class ControlFlowTest(converter_testing.TestCase): b = 2 * n return a, b - node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, self.ctx) + self.assertTransformedResult(test_fn, constant_op.constant(1), (-1, 0)) + self.assertTransformedResult(test_fn, constant_op.constant(-1), (0, -2)) + + def test_if_complex_outputs(self): + + class TestClass(object): - with self.compiled(node) as result: + def __init__(self, a, b): + self.a = a + self.b = b + + def test_fn(n, obj): + obj.a = 0 + obj.b = 0 + if n > 0: + obj.a = -n + else: + obj.b = 2 * n + return obj + + with self.converted(test_fn, control_flow, {}) as result: with self.test_session() as sess: - self.assertEqual((-1, 0), - sess.run(result.test_fn(constant_op.constant(1)))) - self.assertEqual((0, -2), - sess.run(result.test_fn(constant_op.constant(-1)))) + res_obj = result.test_fn(constant_op.constant(1), TestClass(0, 0)) + self.assertEqual(sess.run((res_obj.a, res_obj.b)), (-1, 0)) + res_obj = result.test_fn(constant_op.constant(-1), TestClass(0, 0)) + self.assertEqual(sess.run((res_obj.a, res_obj.b)), (0, -2)) - def test_if_single_var(self): + def test_if_single_output(self): def test_fn(n): if n > 0: n = -n return n - node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, self.ctx) + self.assertTransformedResult(test_fn, constant_op.constant(1), -1) - with self.compiled(node) as result: - with self.test_session() as sess: - self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1)))) - - def test_imbalanced_aliasing(self): + def test_if_semi(self): def test_fn(n): if n > 0: n = 3 return n - node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, self.ctx) - - with self.compiled(node, control_flow_ops.cond) as result: - with self.test_session() as sess: - self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(2)))) - self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3)))) + self.assertTransformedResult(test_fn, constant_op.constant(2), 3) + self.assertTransformedResult(test_fn, constant_op.constant(-3), -3) - def test_ignore_unread_variable(self): + def test_if_local_var(self): def test_fn(n): - b = 3 # pylint: disable=unused-variable if n > 0: b = 4 + n = b + 1 return n - node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, self.ctx) + self.assertTransformedResult(test_fn, constant_op.constant(1), 5) + self.assertTransformedResult(test_fn, constant_op.constant(-1), -1) - with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result: - with self.test_session() as sess: - self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(3)))) - self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3)))) + def test_if_no_outputs(self): - def test_handle_temp_variable(self): + def test_fn(n): + if n > 0: + b = 4 # pylint:disable=unused-variable + return n - def test_fn_using_temp(x, y, w): - if x < y: - z = x + y - else: - w = 2 - tmp = w - z = x - tmp - return z, w + # Without side effect guards, the if statement will stage a cond, + # but that will be pruned at execution. + self.assertTransformedResult(test_fn, constant_op.constant(1), 1) + self.assertTransformedResult(test_fn, constant_op.constant(-1), -1) - node = self.parse_and_analyze(test_fn_using_temp, {}) - node = control_flow.transform(node, self.ctx) + def test_if_imbalanced_outputs(self): - with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result: - with self.test_session() as sess: - z, w = sess.run( - result.test_fn_using_temp( - constant_op.constant(-3), constant_op.constant(3), - constant_op.constant(3))) - self.assertEqual(0, z) - self.assertEqual(3, w) - z, w = sess.run( - result.test_fn_using_temp( - constant_op.constant(3), constant_op.constant(-3), - constant_op.constant(3))) - self.assertEqual(1, z) - self.assertEqual(2, w) - - def test_fn_ignoring_temp(x, y, w): - if x < y: - z = x + y - else: - w = 2 - tmp = w - z = x - tmp - return z + def test_fn(n): + if n > 0: + b = 4 + return b - node = self.parse_and_analyze(test_fn_ignoring_temp, {}) - node = control_flow.transform(node, self.ctx) - - with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result: - with self.test_session() as sess: - z = sess.run( - result.test_fn_ignoring_temp( - constant_op.constant(-3), constant_op.constant(3), - constant_op.constant(3))) - self.assertEqual(0, z) - z = sess.run( - result.test_fn_ignoring_temp( - constant_op.constant(3), constant_op.constant(-3), - constant_op.constant(3))) - self.assertEqual(1, z) + node, ctx = self.prepare(test_fn, {}) + with self.assertRaises(transformer.AutographParseError): + control_flow.transform(node, ctx) def test_simple_for(self): @@ -191,22 +158,11 @@ class ControlFlowTest(converter_testing.TestCase): s2 += e * e return s1, s2 - node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, self.ctx) + self.assertTransformedResult(test_fn, constant_op.constant([1, 3]), (4, 10)) + empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32) + self.assertTransformedResult(test_fn, empty_vector, (0, 0)) - with self.compiled(node) as result: - with self.test_session() as sess: - l = [1, 2, 3] - self.assertEqual( - test_fn(l), sess.run(result.test_fn(constant_op.constant(l)))) - l = [] - self.assertEqual( - test_fn(l), - sess.run( - result.test_fn( - constant_op.constant(l, shape=(0,), dtype=dtypes.int32)))) - - def test_for_single_var(self): + def test_for_single_output(self): def test_fn(l): s = 0 @@ -214,22 +170,11 @@ class ControlFlowTest(converter_testing.TestCase): s += e return s - node = self.parse_and_analyze(test_fn, {}) - node = control_flow.transform(node, self.ctx) + self.assertTransformedResult(test_fn, constant_op.constant([1, 3]), 4) + empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32) + self.assertTransformedResult(test_fn, empty_vector, 0) - with self.compiled(node) as result: - with self.test_session() as sess: - l = [1, 2, 3] - self.assertEqual( - test_fn(l), sess.run(result.test_fn(constant_op.constant(l)))) - l = [] - self.assertEqual( - test_fn(l), - sess.run( - result.test_fn( - constant_op.constant(l, shape=(0,), dtype=dtypes.int32)))) - - def test_for_with_iterated_expression(self): + def test_for_iterated_expression(self): eval_count = [0] @@ -243,14 +188,13 @@ class ControlFlowTest(converter_testing.TestCase): s += e return s - node = self.parse_and_analyze(test_fn, {'count_evals': count_evals}) - node = control_flow.transform(node, self.ctx) + ns = {'count_evals': count_evals} + node, ctx = self.prepare(test_fn, ns) + node = control_flow.transform(node, ctx) - with self.compiled(node) as result: - result.count_evals = count_evals - self.assertEqual(test_fn(5), result.test_fn(5)) - # count_evals ran twice, once for test_fn and another for result.test_fn - self.assertEqual(eval_count[0], 2) + with self.compiled(node, ns) as result: + self.assertEqual(result.test_fn(5), 10) + self.assertEqual(eval_count[0], 1) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/decorators_test.py b/tensorflow/contrib/autograph/converters/decorators_test.py index d41c7fde2474803a438100e7e00ce8e9f675de45..095abc5edc02de55cd0b28d9aa9f9c4e7cec13c3 100644 --- a/tensorflow/contrib/autograph/converters/decorators_test.py +++ b/tensorflow/contrib/autograph/converters/decorators_test.py @@ -61,13 +61,13 @@ class DecoratorsTest(converter_testing.TestCase): 'simple_decorator': simple_decorator, 'converter_testing': converter_testing, } - node = self.parse_and_analyze( + node, ctx = self.prepare( f, namespace, recursive=False, autograph_decorators=autograph_decorators) - node = decorators.transform(node, self.ctx) - import_line = '\n'.join(self.ctx.program.additional_imports) + node = decorators.transform(node, ctx) + import_line = '\n'.join(ctx.program.additional_imports) result, _ = compiler.ast_to_object(node, source_prefix=import_line) return getattr(result, f.__name__) @@ -76,11 +76,8 @@ class DecoratorsTest(converter_testing.TestCase): def test_fn(a): return a - node = self.parse_and_analyze(test_fn, {}) - node = decorators.transform(node, self.ctx) - result, _ = compiler.ast_to_object(node) - - self.assertEqual(1, result.test_fn(1)) + with self.converted(test_fn, decorators, {}) as result: + self.assertEqual(1, result.test_fn(1)) def test_function(self): @@ -124,7 +121,7 @@ class DecoratorsTest(converter_testing.TestCase): return b + 11 return inner_fn(a) - # Expected to fail because simple_decorator cannot be imported. + # Expected to fail because simple_decorator could not be imported. with self.assertRaises(transformer.AutographParseError): test_fn(1) diff --git a/tensorflow/contrib/autograph/converters/directives.py b/tensorflow/contrib/autograph/converters/directives.py new file mode 100644 index 0000000000000000000000000000000000000000..77f625bac792621c45799d1a220f99eb4b99f7af --- /dev/null +++ b/tensorflow/contrib/autograph/converters/directives.py @@ -0,0 +1,128 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Handles directives. + +This converter removes the directive functions from the code and moves the +information they specify into AST annotations. It is a specialized form of +static analysis, one that is specific to AutoGraph. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.lang import directives +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.python.util import tf_inspect + +ENCLOSING_LOOP = 'enclosing_loop' + + +def _map_args(call_node, function): + """Maps AST call nodes to the actual function's arguments. + + Args: + call_node: ast.Call + function: Callable[..., Any], the actual function matching call_node + Returns: + Dict[Text, ast.AST], mapping each of the function's argument names to + the respective AST node. + Raises: + ValueError: if the default arguments are not correctly set + """ + args = call_node.args + kwds = {kwd.arg: kwd.value for kwd in call_node.keywords} + call_args = tf_inspect.getcallargs(function, *args, **kwds) + + # Keyword arguments not specified in kwds will be mapped to their defaults, + # which are Python values. Since we don't currently have a way to transform + # those into AST references, we simply remove them. By convention, directives + # use UNSPECIFIED as default value for for optional arguments. No other + # defaults should be present. + unexpected_defaults = [] + for k in call_args: + if (k not in kwds + and call_args[k] not in args + and call_args[k] is not directives.UNSPECIFIED): + unexpected_defaults.append(k) + if unexpected_defaults: + raise ValueError('Unexpected keyword argument values, %s, for function %s' + % (zip(unexpected_defaults, + [call_args[k] for k in unexpected_defaults]), + function)) + return {k: v for k, v in call_args.items() if v is not directives.UNSPECIFIED} + + +class DirectivesTransformer(converter.Base): + """Parses compiler directives and converts them into AST annotations.""" + + def _process_symbol_directive(self, call_node, directive): + if len(call_node.args) < 1: + raise ValueError('"%s" requires a positional first argument' + ' as the target' % directive.__name__) + target = call_node.args[0] + defs = anno.getanno(target, anno.Static.ORIG_DEFINITIONS) + for def_ in defs: + def_.directives[directive] = _map_args(call_node, directive) + return call_node + + def _process_statement_directive(self, call_node, directive): + if self.local_scope_level < 1: + raise ValueError( + '"%s" must be used inside a statement' % directive.__name__) + target = self.get_local(ENCLOSING_LOOP) + node_anno = anno.getanno(target, converter.AgAnno.DIRECTIVES, {}) + node_anno[directive] = _map_args(call_node, directive) + anno.setanno(target, converter.AgAnno.DIRECTIVES, node_anno) + return call_node + + def visit_Expr(self, node): + if isinstance(node.value, gast.Call): + call_node = node.value + if anno.hasanno(call_node.func, 'live_val'): + live_val = anno.getanno(call_node.func, 'live_val') + + if live_val is directives.set_element_type: + call_node = self._process_symbol_directive(call_node, live_val) + elif live_val is directives.set_loop_options: + call_node = self._process_statement_directive(call_node, live_val) + else: + return self.generic_visit(node) + + return None # Directive calls are not output in the generated code. + return self.generic_visit(node) + + # TODO(mdan): This will be insufficient for other control flow. + # That means that if we ever have a directive that affects things other than + # loops, we'll need support for parallel scopes, or have multiple converters. + def _track_and_visit_loop(self, node): + self.enter_local_scope() + self.set_local(ENCLOSING_LOOP, node) + node = self.generic_visit(node) + self.exit_local_scope() + return node + + def visit_While(self, node): + return self._track_and_visit_loop(node) + + def visit_For(self, node): + return self._track_and_visit_loop(node) + + +def transform(node, ctx): + return DirectivesTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/directives_test.py b/tensorflow/contrib/autograph/converters/directives_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a2d083b891314d2f8f3fa61b46edc347ca8e24eb --- /dev/null +++ b/tensorflow/contrib/autograph/converters/directives_test.py @@ -0,0 +1,95 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for directives module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.converters import directives as directives_converter +from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.contrib.autograph.core.converter import AgAnno +from tensorflow.contrib.autograph.lang import directives +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.python.platform import test + + +class DirectivesTest(converter_testing.TestCase): + + def test_local_target(self): + + def test_fn(): + l = [] + string_var = 0 + directives.set_element_type(l, 'a', string_var) + + node, ctx = self.prepare(test_fn, {'directives': directives}) + node = directives_converter.transform(node, ctx) + + def_, = anno.getanno(node.body[0].targets[0], + anno.Static.DEFINITIONS) + d = def_.directives[directives.set_element_type] + self.assertEqual(d['dtype'].s, 'a') + self.assertEqual(d['shape'].id, 'string_var') + + def test_argument_target(self): + + def test_fn(a): + directives.set_element_type(a, 1, shape=2) + + node, ctx = self.prepare(test_fn, {'directives': directives}) + node = directives_converter.transform(node, ctx) + + def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS) + d = def_.directives[directives.set_element_type] + self.assertEqual(d['dtype'].n, 1) + self.assertEqual(d['shape'].n, 2) + + def test_loop_target(self): + + def test_fn(): + a = True + while True: + directives.set_loop_options(parallel_iterations=10, back_prop=a) + + node, ctx = self.prepare(test_fn, {'directives': directives}) + node = directives_converter.transform(node, ctx) + + d = anno.getanno(node.body[1], AgAnno.DIRECTIVES) + d = d[directives.set_loop_options] + self.assertEqual(d['parallel_iterations'].n, 10) + self.assertEqual(d['back_prop'].id, 'a') + self.assertNotIn('swap_memory', d) + + def test_invalid_default(self): + + def invalid_directive(valid_arg, invalid_default=object()): + del valid_arg + del invalid_default + return + + def call_invalid_directive(): + invalid_directive(1) + + node, _ = parser.parse_entity(call_invalid_directive) + # Find the call to the invalid directive + node = node.body[0].body[0].value + with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'): + directives_converter._map_args(node, invalid_directive) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/converters/error_handlers.py b/tensorflow/contrib/autograph/converters/error_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..193682139438c1d0133b17165d7f7fb84e2eaaac --- /dev/null +++ b/tensorflow/contrib/autograph/converters/error_handlers.py @@ -0,0 +1,53 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wraps function bodies with a try/except to rewrite error tracebacks. + +Only adds try/except wrappers to functions that have the anno.Basic.ORIGIN +annotation because these are the functions originally written by the user. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates + + +class ErrorRewritingTransformer(converter.Base): + """Possibly wraps the body of a function in a try/except. + + Only wraps functions that were originally defined by the user, detected by + checking for the anno.Basic.ORIGIN annotation. + """ + + def visit_FunctionDef(self, node): + node = self.generic_visit(node) + + if (anno.hasanno(node, anno.Basic.ORIGIN) and + len(self.enclosing_entities) <= 1): + template = """ + try: + body + except: + ag__.rewrite_graph_construction_error(ag_source_map__) + """ + node.body = templates.replace(template, body=node.body) + return node + + +def transform(node, ctx): + return ErrorRewritingTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/error_handlers_test.py b/tensorflow/contrib/autograph/converters/error_handlers_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5d61b220afa0fcf9a9e619bbd78f83a5076c473a --- /dev/null +++ b/tensorflow/contrib/autograph/converters/error_handlers_test.py @@ -0,0 +1,59 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for error_handlers module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.converters import error_handlers +from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.contrib.autograph.core import errors +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import origin_info +from tensorflow.python.platform import test + + +class ErrorHandlersTest(converter_testing.TestCase): + + def test_basic(self): + + def test_fn(): + raise ValueError() + + node, ctx = self.prepare(test_fn, {}) + anno.setanno( + node, anno.Basic.ORIGIN, + origin_info.OriginInfo(None, 'test_function_name', 'test_code', + 'test_comment')) + node = error_handlers.transform(node, ctx) + with self.compiled(node, {}) as result: + with self.assertRaises(errors.GraphConstructionError): + # Here we just assert that the handler works. Its correctness is + # verified by errors_test.py. + result.test_fn() + + def test_no_origin_annotation(self): + + def test_fn(): + raise ValueError() + + with self.converted(test_fn, error_handlers, {}) as result: + with self.assertRaises(ValueError): + result.test_fn() + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/converters/ifexp.py b/tensorflow/contrib/autograph/converters/ifexp.py deleted file mode 100644 index e996138498ab2b7efa76671d8cc67fd4c6a9d9b8..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/converters/ifexp.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Canonicalizes the ternary conditional operator.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.autograph.core import converter -from tensorflow.contrib.autograph.pyct import templates - - -class IfExp(converter.Base): - """Canonicalizes all IfExp nodes into plain conditionals.""" - - def visit_IfExp(self, node): - template = """ - ag__.utils.run_cond(test, lambda: (body,), lambda: (orelse,)) - """ - desugared_ifexp = templates.replace_as_expression( - template, test=node.test, body=node.body, orelse=node.orelse) - return desugared_ifexp - - -def transform(node, ctx): - """Desugar IfExp nodes into plain conditionals. - - Args: - node: ast.AST, the node to transform - ctx: converter.EntityContext - - Returns: - new_node: an AST with no IfExp nodes, only conditionals. - """ - - node = IfExp(ctx).visit(node) - return node diff --git a/tensorflow/contrib/autograph/converters/ifexp_test.py b/tensorflow/contrib/autograph/converters/ifexp_test.py deleted file mode 100644 index cdd5a2f591edc1138df1c165577ed375131ddf09..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/converters/ifexp_test.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for ifexp module.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.autograph import utils -from tensorflow.contrib.autograph.converters import ifexp -from tensorflow.contrib.autograph.core import converter_testing -from tensorflow.python.platform import test - - -class IfExpTest(converter_testing.TestCase): - - def compiled_fn(self, test_fn, *args): - node = self.parse_and_analyze(test_fn, {}) - node = ifexp.transform(node, self.ctx) - module = self.compiled(node, *args) - return module - - def test_simple(self): - - def test_fn(x): - return 1 if x else 0 - - with self.compiled_fn(test_fn) as result: - result.autograph_util = utils - for x in [0, 1]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_fn(self): - - def f(x): - return 3 * x - - def test_fn(x): - y = f(x * x if x > 0 else x) - return y - - with self.compiled_fn(test_fn) as result: - result.autograph_util = utils - result.f = f - for x in [-2, 2]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_exp(self): - - def test_fn(x): - return x * x if x > 0 else x - - with self.compiled_fn(test_fn) as result: - result.autograph_util = utils - for x in [-2, 2]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_nested(self): - - def test_fn(x): - return x * x if x > 0 else x if x else 1 - - with self.compiled_fn(test_fn) as result: - result.autograph_util = utils - for x in [-2, 0, 2]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_in_cond(self): - - def test_fn(x): - if x > 0: - return x * x if x < 5 else x * x * x - return -x - - with self.compiled_fn(test_fn) as result: - result.autograph_util = utils - for x in [-2, 2, 5]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_assign_in_cond(self): - - def test_fn(x): - if x > 0: - x = -x if x < 5 else x - return x - - with self.compiled_fn(test_fn) as result: - result.autograph_util = utils - for x in [-2, 2, 5]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/autograph/converters/list_comprehension.py b/tensorflow/contrib/autograph/converters/list_comprehension.py deleted file mode 100644 index c4a13ee822ab84706df83256d9e9684c3f7dacba..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/converters/list_comprehension.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Canonicalizing list comprehensions into for and if statements. - -e.g. -result = [x * x for x in xs] - -becomes - -result = [] -for x in xs: - elt = x * x - result.append(elt) -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import gast - -from tensorflow.contrib.autograph.core import converter -from tensorflow.contrib.autograph.pyct import parser -from tensorflow.contrib.autograph.pyct import templates - - -class ListCompCanonicalizationTransformer(converter.Base): - """NodeTransformer to canonicalize list comprehensions.""" - - def make_update_list_node(self, list_, elt): - return templates.replace('list_.append(elt)', list_=list_, elt=elt)[0] - - def instantiate_list_node(self): - return parser.parse_str('[]').body[0].value - - def visit_Assign(self, node): - if not isinstance(node.value, gast.ListComp): - return node - if len(node.targets) > 1: - raise ValueError('Only support single assignment.') - return self.canonicalize_listcomp(node.targets[0], node.value) - - def canonicalize_listcomp(self, result_node, list_comp_node): - - make_list = templates.replace( - 'list_ = create_list', - list_=result_node, - create_list=self.instantiate_list_node()) - loop_body = self.make_update_list_node(result_node, list_comp_node.elt) - - for gen in reversed(list_comp_node.generators): - for gen_if in reversed(gen.ifs): - loop_body = templates.replace( - 'if test: loop_body', test=gen_if, loop_body=loop_body) - loop_body = templates.replace( - 'for target in iter_: loop_body', - iter_=gen.iter, - target=gen.target, - loop_body=loop_body) - - return make_list + loop_body - - -def transform(node, ctx): - return ListCompCanonicalizationTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/list_comprehensions.py b/tensorflow/contrib/autograph/converters/list_comprehensions.py new file mode 100644 index 0000000000000000000000000000000000000000..ecf4628816201a0a6ef4ca14b0f351d818d905b3 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/list_comprehensions.py @@ -0,0 +1,82 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Lowers list comprehensions into for and if statements. + +Example: + + result = [x * x for x in xs] + +becomes + + result = [] + for x in xs: + elt = x * x + result.append(elt) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.pyct import templates + + +# TODO(mdan): This should covert directly to operator calls. + + +class ListCompTransformer(converter.Base): + """Lowers list comprehensions into standard control flow.""" + + def visit_Assign(self, node): + if not isinstance(node.value, gast.ListComp): + return self.generic_visit(node) + if len(node.targets) > 1: + raise NotImplementedError('multiple assignments') + + target, = node.targets + list_comp_node = node.value + + template = """ + target = [] + """ + initialization = templates.replace(template, target=target) + + template = """ + target.append(elt) + """ + body = templates.replace(template, target=target, elt=list_comp_node.elt) + + for gen in reversed(list_comp_node.generators): + for gen_if in reversed(gen.ifs): + template = """ + if test: + body + """ + body = templates.replace(template, test=gen_if, body=body) + template = """ + for target in iter_: + body + """ + body = templates.replace( + template, iter_=gen.iter, target=gen.target, body=body) + + return initialization + body + + +def transform(node, ctx): + return ListCompTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/list_comprehension_test.py b/tensorflow/contrib/autograph/converters/list_comprehensions_test.py similarity index 59% rename from tensorflow/contrib/autograph/converters/list_comprehension_test.py rename to tensorflow/contrib/autograph/converters/list_comprehensions_test.py index 2bbee93412ce3174a14f3d60af9435dcf3b82cc6..59b5ce9ca052bd1f2201285bef90f398b35e536c 100644 --- a/tensorflow/contrib/autograph/converters/list_comprehension_test.py +++ b/tensorflow/contrib/autograph/converters/list_comprehensions_test.py @@ -12,33 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for list_comprehension module.""" +"""Tests for list_comprehensions module.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.converters import list_comprehension +from tensorflow.contrib.autograph.converters import list_comprehensions from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.platform import test class ListCompTest(converter_testing.TestCase): + def assertTransformedEquivalent(self, test_fn, *inputs): + with self.converted(test_fn, list_comprehensions, {}) as result: + self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) + def test_basic(self): def test_fn(l): s = [e * e for e in l] return s - node = self.parse_and_analyze(test_fn, {}) - node = list_comprehension.transform(node, self.ctx) - - with self.compiled(node) as result: - l = [1, 2, 3] - self.assertEqual(test_fn(l), result.test_fn(l)) - l = [] - self.assertEqual(test_fn(l), result.test_fn(l)) + self.assertTransformedEquivalent(test_fn, []) + self.assertTransformedEquivalent(test_fn, [1, 2, 3]) def test_multiple_generators(self): @@ -46,29 +44,17 @@ class ListCompTest(converter_testing.TestCase): s = [e * e for sublist in l for e in sublist] return s - node = self.parse_and_analyze(test_fn, {}) - node = list_comprehension.transform(node, self.ctx) + self.assertTransformedEquivalent(test_fn, []) + self.assertTransformedEquivalent(test_fn, [[1], [2], [3]]) - with self.compiled(node) as result: - l = [[1], [2], [3]] - self.assertEqual(test_fn(l), result.test_fn(l)) - l = [] - self.assertEqual(test_fn(l), result.test_fn(l)) - - def test_conds(self): + def test_cond(self): def test_fn(l): s = [e * e for e in l if e > 1] return s - node = self.parse_and_analyze(test_fn, {}) - node = list_comprehension.transform(node, self.ctx) - - with self.compiled(node) as result: - l = [1, 2, 3] - self.assertEqual(test_fn(l), result.test_fn(l)) - l = [] - self.assertEqual(test_fn(l), result.test_fn(l)) + self.assertTransformedEquivalent(test_fn, []) + self.assertTransformedEquivalent(test_fn, [1, 2, 3]) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/lists.py b/tensorflow/contrib/autograph/converters/lists.py index d77a04479826779b8aa859d70f2f7ff51138f841..a02fc827b8bd92b36549599b5433118fcd9a28cf 100644 --- a/tensorflow/contrib/autograph/converters/lists.py +++ b/tensorflow/contrib/autograph/converters/lists.py @@ -33,6 +33,7 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.lang import directives from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates @@ -88,12 +89,12 @@ class ListTransformer(converter.Base): scope = anno.getanno(node, NodeAnno.ARGS_SCOPE) target_node = node.func.value - # Attempt to use a related name if can get one. Otherwise use something + # Attempt to use a related name if one exists. Otherwise use something # generic. if anno.hasanno(target_node, anno.Basic.QN): target_name = anno.getanno(target_node, anno.Basic.QN).ssf() else: - target_name = 'list' + target_name = 'list_' pop_var_name = self.ctx.namer.new_symbol(target_name, scope.referenced) pop_uses = self.get_local(POP_USES, []) @@ -104,9 +105,10 @@ class ListTransformer(converter.Base): def _replace_stack_call(self, node): assert len(node.args) == 1 - dtype = anno.getanno( + dtype = self.get_definition_directive( node.args[0], - 'element_type', + directives.set_element_type, + 'dtype', default=templates.replace_as_expression('None')) template = """ ag__.list_stack( @@ -134,7 +136,10 @@ class ListTransformer(converter.Base): node = self._replace_append_call(node) elif func_name == 'pop' and (len(node.args) <= 1): node = self._replace_pop_call(node) - elif func_name == 'stack' and (len(node.args) == 1): + elif (func_name == 'stack' and (len(node.args) == 1) and + (not node.keywords or node.keywords[0].arg == 'strict')): + # This avoids false positives with keyword args. + # TODO(mdan): handle kwargs properly. node = self._replace_stack_call(node) return node @@ -146,15 +151,22 @@ class ListTransformer(converter.Base): pop_element = original_call_node.args[0] else: pop_element = parser.parse_expression('None') + # The call will be something like "target.pop()", and the dtype is hooked to # target, hence the func.value. - dtype = anno.getanno( + # TODO(mdan): For lists of lists, this won't work. + # The reason why it won't work is because it's unclear how to annotate + # the list as a "list of lists with a certain element type" when using + # operations like `l.pop().pop()`. + dtype = self.get_definition_directive( original_call_node.func.value, - 'element_type', + directives.set_element_type, + 'dtype', default=templates.replace_as_expression('None')) - shape = anno.getanno( + shape = self.get_definition_directive( original_call_node.func.value, - 'element_shape', + directives.set_element_type, + 'shape', default=templates.replace_as_expression('None')) template = """ diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/contrib/autograph/converters/lists_test.py index ea04097b28deedd705164bd95ab62dba3e3c7834..996e99ee61b3713a03ff167b892101fca35eaeac 100644 --- a/tensorflow/contrib/autograph/converters/lists_test.py +++ b/tensorflow/contrib/autograph/converters/lists_test.py @@ -18,9 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph.converters import lists from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.contrib.autograph.lang import directives +from tensorflow.contrib.autograph.lang import special_functions +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -28,6 +31,9 @@ from tensorflow.python.ops import list_ops from tensorflow.python.platform import test +tf = None # Will be replaced by a mock. + + class ListTest(converter_testing.TestCase): def test_empty_list(self): @@ -35,10 +41,7 @@ class ListTest(converter_testing.TestCase): def test_fn(): return [] - node = self.parse_and_analyze(test_fn, {}) - node = lists.transform(node, self.ctx) - - with self.compiled(node) as result: + with self.converted(test_fn, lists, {}) as result: tl = result.test_fn() # Empty tensor lists cannot be evaluated or stacked. self.assertTrue(isinstance(tl, ops.Tensor)) @@ -49,27 +52,19 @@ class ListTest(converter_testing.TestCase): def test_fn(): return [1, 2, 3] - node = self.parse_and_analyze(test_fn, {}) - node = lists.transform(node, self.ctx) - - with self.compiled(node) as result: - with self.test_session() as sess: - tl = result.test_fn() - r = list_ops.tensor_list_stack(tl, dtypes.int32) - self.assertAllEqual(sess.run(r), [1, 2, 3]) + with self.converted(test_fn, lists, {}) as result: + self.assertAllEqual(result.test_fn(), [1, 2, 3]) def test_list_append(self): def test_fn(): - l = [1] + l = special_functions.tensor_list([1]) l.append(2) l.append(3) return l - node = self.parse_and_analyze(test_fn, {}) - node = lists.transform(node, self.ctx) - - with self.compiled(node) as result: + ns = {'special_functions': special_functions} + with self.converted(test_fn, lists, ns) as result: with self.test_session() as sess: tl = result.test_fn() r = list_ops.tensor_list_stack(tl, dtypes.int32) @@ -78,24 +73,21 @@ class ListTest(converter_testing.TestCase): def test_list_pop(self): def test_fn(): - l = [1, 2, 3] - utils.set_element_type(l, dtypes.int32, ()) + l = special_functions.tensor_list([1, 2, 3]) s = l.pop() return s, l - node = self.parse_and_analyze( - test_fn, - { - 'utils': utils, - 'dtypes': dtypes - }, - include_type_analysis=True, - ) - node = lists.transform(node, self.ctx) - - with self.compiled(node) as result: - result.utils = utils - result.dtypes = dtypes + ns = {'special_functions': special_functions} + node, ctx = self.prepare(test_fn, ns) + def_, = anno.getanno(node.body[0].targets[0], + anno.Static.ORIG_DEFINITIONS) + def_.directives[directives.set_element_type] = { + 'dtype': parser.parse_expression('tf.int32'), + 'shape': parser.parse_expression('()'), + } + node = lists.transform(node, ctx) + + with self.compiled(node, ns, dtypes.int32) as result: with self.test_session() as sess: ts, tl = result.test_fn() r = list_ops.tensor_list_stack(tl, dtypes.int32) @@ -108,10 +100,7 @@ class ListTest(converter_testing.TestCase): s = l.pop().pop() return s - node = self.parse_and_analyze(test_fn, {}) - node = lists.transform(node, self.ctx) - - with self.compiled(node) as result: + with self.converted(test_fn, lists, {}) as result: test_input = [1, 2, [1, 2, 3]] # TODO(mdan): Pass a list of lists of tensor when we fully support that. # For now, we just pass a regular Python list of lists just to verify that @@ -120,29 +109,24 @@ class ListTest(converter_testing.TestCase): def test_list_stack(self): - tf = None # Will be replaced with a mock. - def test_fn(): l = [1, 2, 3] - utils.set_element_type(l, dtypes.int32) return tf.stack(l) - node = self.parse_and_analyze( - test_fn, - { - 'utils': utils, - 'dtypes': dtypes - }, - include_type_analysis=True, - ) - node = lists.transform(node, self.ctx) - - with self.compiled(node, array_ops.stack, dtypes.int32) as result: - result.utils = utils - result.dtypes = dtypes + node, ctx = self.prepare(test_fn, {}) + def_, = anno.getanno(node.body[0].targets[0], + anno.Static.ORIG_DEFINITIONS) + def_.directives[directives.set_element_type] = { + 'dtype': parser.parse_expression('tf.int32') + } + node = lists.transform(node, ctx) + + with self.compiled(node, {}, array_ops.stack, dtypes.int32) as result: with self.test_session() as sess: self.assertAllEqual(sess.run(result.test_fn()), [1, 2, 3]) + # TODO(mdan): Add a test with tf.stack with axis kwarg. + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/converters/logical_expressions_test.py b/tensorflow/contrib/autograph/converters/logical_expressions_test.py index 48186024a9da7b41fa7ff9a8ab18f3477ba09c8f..ca07de5e8a1f870391ecbe41bf1341dc52c25347 100644 --- a/tensorflow/contrib/autograph/converters/logical_expressions_test.py +++ b/tensorflow/contrib/autograph/converters/logical_expressions_test.py @@ -31,10 +31,8 @@ class GradientsFunctionTest(converter_testing.TestCase): def test_fn(a, b): return a == b - node = self.parse_and_analyze(test_fn, {}) - node = logical_expressions.transform(node, self.ctx) - - with self.compiled(node, math_ops.equal) as result: + with self.converted(test_fn, logical_expressions, {}, + math_ops.equal) as result: with self.test_session() as sess: self.assertTrue(sess.run(result.test_fn(1, 1))) self.assertFalse(sess.run(result.test_fn(1, 2))) @@ -44,11 +42,8 @@ class GradientsFunctionTest(converter_testing.TestCase): def test_fn(a, b, c): return (a or b) and (a or b or c) - node = self.parse_and_analyze(test_fn, {}) - node = logical_expressions.transform(node, self.ctx) - - with self.compiled(node, math_ops.logical_or, - math_ops.logical_and) as result: + with self.converted(test_fn, logical_expressions, {}, math_ops.logical_or, + math_ops.logical_and) as result: with self.test_session() as sess: self.assertTrue(sess.run(result.test_fn(True, False, True))) diff --git a/tensorflow/contrib/autograph/converters/name_scopes_test.py b/tensorflow/contrib/autograph/converters/name_scopes_test.py index 444d0bcd469f35689d078debe3622f930dbac723..a329b0db70e2c6559fa5cf36694cf808fa28a6cb 100644 --- a/tensorflow/contrib/autograph/converters/name_scopes_test.py +++ b/tensorflow/contrib/autograph/converters/name_scopes_test.py @@ -31,17 +31,13 @@ class FunctionNameScopeTransformer(converter_testing.TestCase): def test_fn(l): """This should stay here.""" - a = 5 + a = 1 l += a return l - node = self.parse_and_analyze(test_fn, {}) - node = name_scopes.transform(node, self.ctx) - - with self.compiled(node, ops.name_scope) as result: + with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result: result_op = result.test_fn(constant_op.constant(1)) self.assertIn('test_fn/', result_op.op.name) - self.assertEqual('This should stay here.', result.test_fn.__doc__) def test_long_docstring(self): @@ -54,13 +50,12 @@ class FunctionNameScopeTransformer(converter_testing.TestCase): Returns: l """ - return l - - node = self.parse_and_analyze(test_fn, {}) - node = name_scopes.transform(node, self.ctx) + return l + 1 - with self.compiled(node, ops.name_scope) as result: - self.assertIn('Multi-line', result.test_fn.__doc__) + with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result: + result_op = result.test_fn(constant_op.constant(1)) + self.assertIn('test_fn/', result_op.op.name) + self.assertIn('Multi-line docstring.', result.test_fn.__doc__) self.assertIn('Returns:', result.test_fn.__doc__) def test_nested_functions(self): @@ -68,21 +63,16 @@ class FunctionNameScopeTransformer(converter_testing.TestCase): def test_fn(l): def inner_fn(i): - return i ** 2 - - l += 4 - return inner_fn(l) + return i + 1 - node = self.parse_and_analyze(test_fn, {}) - node = name_scopes.transform(node, self.ctx) + l += 1 + return l, inner_fn(l) - with self.compiled(node, ops.name_scope) as result: - result_op = result.test_fn(constant_op.constant(1)) - first_result_input_name = result_op.op.inputs[0].name - second_result_input_name = result_op.op.inputs[1].name - self.assertIn('test_fn/', first_result_input_name) - self.assertNotIn('inner_fn', first_result_input_name) - self.assertIn('test_fn/inner_fn/', second_result_input_name) + with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result: + first, second = result.test_fn(constant_op.constant(1)) + self.assertIn('test_fn/', first.op.name) + self.assertNotIn('inner_fn', first.op.name) + self.assertIn('test_fn/inner_fn/', second.op.name) def test_method(self): @@ -91,48 +81,20 @@ class FunctionNameScopeTransformer(converter_testing.TestCase): def test_fn(self, l): def inner_fn(i): - return i ** 2 - - l += 4 - return inner_fn(l) + return i + 1 - # Note that 'TestClass' was needed in the namespace here. - node = self.parse_and_analyze( - TestClass, {'TestClass': TestClass}, owner_type=TestClass) - node = name_scopes.transform(node, self.ctx) + l += 1 + return l, inner_fn(l) - with self.compiled(node, ops.name_scope) as result: - result_op = result.TestClass().test_fn(constant_op.constant(1)) - first_result_input_name = result_op.op.inputs[0].name - second_result_input_name = result_op.op.inputs[1].name - self.assertIn('TestClass/test_fn/', first_result_input_name) - self.assertNotIn('inner_fn', first_result_input_name) - self.assertIn('TestClass/test_fn/inner_fn/', second_result_input_name) + ns = {'TestClass': TestClass} + node, ctx = self.prepare(TestClass, ns, owner_type=TestClass) + node = name_scopes.transform(node, ctx) - def test_operator(self): - - class TestClass(object): - - def __call__(self, l): - - def inner_fn(i): - return i ** 2 - - l += 4 - return inner_fn(l) - - # Note that 'TestClass' was needed in the namespace here. - node = self.parse_and_analyze( - TestClass.__call__, {'TestClass': TestClass}, owner_type=TestClass) - node = name_scopes.transform(node, self.ctx) - - with self.compiled(node, ops.name_scope) as result: - result_op = result.__call__(TestClass(), constant_op.constant(1)) - first_result_input_name = result_op.op.inputs[0].name - second_result_input_name = result_op.op.inputs[1].name - self.assertIn('call__/', first_result_input_name) - self.assertNotIn('inner_fn', first_result_input_name) - self.assertIn('call__/inner_fn/', second_result_input_name) + with self.compiled(node, {}, ops.name_scope) as result: + first, second = result.TestClass().test_fn(constant_op.constant(1)) + self.assertIn('TestClass/test_fn/', first.op.name) + self.assertNotIn('inner_fn', first.op.name) + self.assertIn('TestClass/test_fn/inner_fn/', second.op.name) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/single_return.py b/tensorflow/contrib/autograph/converters/return_statements.py similarity index 100% rename from tensorflow/contrib/autograph/converters/single_return.py rename to tensorflow/contrib/autograph/converters/return_statements.py diff --git a/tensorflow/contrib/autograph/converters/return_statements_test.py b/tensorflow/contrib/autograph/converters/return_statements_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3c7c8c8a2586c6716e78960ee964ff3b0735fa47 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/return_statements_test.py @@ -0,0 +1,167 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for return_statements module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.converters import return_statements +from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.python.framework import ops +from tensorflow.python.platform import test + + +class SingleReturnTest(converter_testing.TestCase): + + def assertTransformedEquivalent(self, test_fn, *inputs): + ns = {'ops': ops} + with self.converted(test_fn, return_statements, ns) as result: + self.assertEqual(test_fn(*inputs), result.test_fn(*inputs)) + + def test_straightline(self): + + def test_fn(x): + return x * x + + self.assertTransformedEquivalent(test_fn, 2) + + def test_conditional(self): + + def test_fn(x): + if x > 0: + return x + else: + return x * x + + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, -2) + + def test_missing_orelse(self): + + def test_fn(x): + if x > 0: + return x + + node, ctx = self.prepare(test_fn, {}) + with self.assertRaises(ValueError): + return_statements.transform(node, ctx) + + def test_missing_orelse_recovrable(self): + + def test_fn(x): + if x > 0: + return x + return x * x + + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, -2) + + def test_missing_branch_return_recoverable(self): + + def test_fn(x): + if x < 0: + x *= x + else: + return x + return x + + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, -2) + + def test_conditional_nested(self): + + def test_fn(x): + if x > 0: + if x < 5: + return x + else: + return x * x + else: + return x * x * x + + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, -2) + self.assertTransformedEquivalent(test_fn, 5) + + def test_context_manager(self): + + def test_fn(x): + with ops.name_scope(''): + return x * x + + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, -2) + + def test_context_manager_in_conditional(self): + + def test_fn(x): + if x > 0: + with ops.name_scope(''): + return x * x + else: + return x + + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, -2) + + def text_conditional_in_context_manager(self): + + def test_fn(x): + with ops.name_scope(''): + if x > 0: + return x * x + else: + return x + + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, -2) + + def test_no_return(self): + + def test_fn(x): + x *= x + + self.assertTransformedEquivalent(test_fn, 2) + + def test_nested_functions(self): + + def test_fn(x): + + def inner_fn(y): + if y > 0: + return y * y + else: + return y + + return inner_fn(x) + + self.assertTransformedEquivalent(test_fn, 2) + self.assertTransformedEquivalent(test_fn, -2) + + def test_loop(self): + + def test_fn(x): + for _ in range(10): + return x + return x + + node, ctx = self.prepare(test_fn, {}) + with self.assertRaises(ValueError): + return_statements.transform(node, ctx) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py index a7ad8efed4c88e15ce9dc14cb02e5e035602013d..bee512abbc2e115d69bc9a5d53b6c54d428cc73a 100644 --- a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py +++ b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py @@ -25,140 +25,138 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variables +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test +tf = None # Will be replaced by a mock. + + class SideEffectGuardsTest(converter_testing.TestCase): def test_side_effect_on_return_only_variable(self): - tf = None - def test_fn(a): tf.assign(a, a + 1) return a - node = self.parse_and_analyze(test_fn, {}) - node = side_effect_guards.transform(node, self.ctx) + node, ctx = self.prepare(test_fn, {}) + node = side_effect_guards.transform(node, ctx) - with self.compiled(node, state_ops.assign) as result: - self.assertEqual(len(node.body[0].body), 1) + self.assertEqual(len(node.body), 1) + + with self.compiled(node, {}, state_ops.assign) as result: with self.test_session() as sess: - v = variables.Variable(2) + v = variable_scope.get_variable('test', initializer=2) sess.run(v.initializer) - # NOTE: We don't expect the assignment to execute in this case, because - # variables cannot be reliably guarded. - self.assertEqual(2, sess.run(result.test_fn(v))) + sess.run(result.test_fn(v)) + # TODO(mdan): Add support for this use case. + # Right now the variable `a` is not conditioned on the `assign` because + # there's no way to add control dependencies to a variable object. + self.assertEqual(2, sess.run(v)) def test_side_effect_on_used_variable(self): - tf = None - def test_fn(a): tf.assign(a, a + 1) return a + 1 - node = self.parse_and_analyze(test_fn, {}) - node = side_effect_guards.transform(node, self.ctx) + node, ctx = self.prepare(test_fn, {}) + node = side_effect_guards.transform(node, ctx) - with self.compiled(node, state_ops.assign) as result: - self.assertEqual(len(node.body[0].body), 1) + self.assertEqual(len(node.body), 1) + + with self.compiled(node, {}, state_ops.assign) as result: with self.test_session() as sess: - v = variables.Variable(2) + v = variable_scope.get_variable('test', initializer=2) sess.run(v.initializer) - # NOTE: Unlike test_side_effect_on_return_only_variable, the variable - # was used in the local scope and so we could catch the assign's side - # effect. - self.assertEqual(4, sess.run(result.test_fn(v))) + sess.run(result.test_fn(v)) + # TODO(mdan): Ensure the result of test_fn(v) is also deterministic. + # Right now it's 3 or 4 based on whether the read is synchronized. + self.assertEqual(3, sess.run(v)) def test_side_effect_on_tensor(self): - tf = None - def test_fn(a): tf.Assert(a > 0, ['expected in throw']) return a - node = self.parse_and_analyze(test_fn, {}) - node = side_effect_guards.transform(node, self.ctx) + node, ctx = self.prepare(test_fn, {}) + node = side_effect_guards.transform(node, ctx) - with self.compiled(node, control_flow_ops.Assert) as result: - self.assertEqual(len(node.body[0].body), 1) + self.assertEqual(len(node.body), 1) + + with self.compiled(node, {}, control_flow_ops.Assert) as result: with self.test_session() as sess: - # NOTE: In this case we can also capture the side effect because the - # argument is a tensor ans we can wrap it inside an identity. with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'expected in throw'): sess.run(result.test_fn(constant_op.constant(-1))) def test_multiline_block(self): - tf = None - def test_fn(a): - tf.assign(a, a + 1) + tf.assign_add(a, 1) b = a + 1 - tf.assign(a, b + 1) - c = b + 1 - d = c + 1 - return d + tf.assign_add(a, 1) + b += 1 + return b - node = self.parse_and_analyze(test_fn, {}) - node = side_effect_guards.transform(node, self.ctx) + node, ctx = self.prepare(test_fn, {}) + node = side_effect_guards.transform(node, ctx) - with self.compiled(node, state_ops.assign) as result: - self.assertEqual(len(node.body[0].body), 1) + self.assertEqual(len(node.body), 1) + + with self.compiled(node, {}, state_ops.assign_add) as result: with self.test_session() as sess: - v = variables.Variable(2) + v = variable_scope.get_variable('test', initializer=2) sess.run(v.initializer) - self.assertEqual(6, sess.run(result.test_fn(v))) + sess.run(result.test_fn(v)) + # TODO(mdan): Ensure the result of test_fn(v) is also deterministic. + self.assertEqual(4, sess.run(v)) def test_multiline_nested_block(self): - tf = None - def test_fn(a): with tf.name_scope('foo'): tf.assign(a, a + 1) b = a + 1 - c = b + 1 - d = c + 1 - return d + return b - node = self.parse_and_analyze(test_fn, {}) - node = side_effect_guards.transform(node, self.ctx) + node, ctx = self.prepare(test_fn, {}) + node = side_effect_guards.transform(node, ctx) - with self.compiled(node, state_ops.assign, ops.name_scope) as result: - self.assertEqual(len(node.body[0].body[0].body), 1) + self.assertEqual(len(node.body[0].body), 1) + + with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result: with self.test_session() as sess: - v = variables.Variable(2) + v = variable_scope.get_variable('test', initializer=2) sess.run(v.initializer) - self.assertEqual(6, sess.run(result.test_fn(v))) + sess.run(result.test_fn(v)) + # TODO(mdan): Ensure the result of test_fn(v) is also deterministic. + self.assertEqual(3, sess.run(v)) def test_multiline_block_unsafe(self): - tf = None - def test_fn(a): tf.assign(a, a + 1) b = a + 1 - tf.assign(a, a + 1) + tf.assign_add(a, 1) c = b + 1 - d = c + 1 - return d + return c + + node, ctx = self.prepare(test_fn, {}) + node = side_effect_guards.transform(node, ctx) - node = self.parse_and_analyze(test_fn, {}) - node = side_effect_guards.transform(node, self.ctx) + self.assertEqual(len(node.body), 1) - with self.compiled(node, state_ops.assign) as result: - self.assertEqual(len(node.body[0].body), 1) + with self.compiled(node, {}, state_ops.assign, + state_ops.assign_add) as result: with self.test_session() as sess: - v = variables.Variable(2) + v = variable_scope.get_variable('test', initializer=2) sess.run(v.initializer) - # NOTE: This intentionally highlights the flakiness. The test should be - # tightened down once that is solved. - self.assertTrue(sess.run(result.test_fn(v)) in (6, 7)) + sess.run(result.test_fn(v)) + # TODO(mdan): Ensure the result of test_fn(v) is also deterministic. + self.assertEqual(4, sess.run(v)) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/single_return_test.py b/tensorflow/contrib/autograph/converters/single_return_test.py deleted file mode 100644 index 1f0de4310e370235a4a7bfeaa61bd519a81aff47..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/converters/single_return_test.py +++ /dev/null @@ -1,189 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for single_return module.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.autograph.converters import single_return -from tensorflow.contrib.autograph.core import converter_testing -from tensorflow.python.framework.ops import name_scope -from tensorflow.python.platform import test - - -class SingleReturnTest(converter_testing.TestCase): - - def compiled_fn(self, test_fn, *args): - node = self.parse_and_analyze(test_fn, {}) - node = single_return.transform(node, self.ctx) - module = self.compiled(node, *args) - return module - - def test_noop(self): - # Noop - def test_fn(x): - return x - - with self.compiled_fn(test_fn) as result: - self.assertEqual(test_fn(2.0), result.test_fn(2.0)) - - def test_return_expression(self): - # ANF - def test_fn(x): - return x * x - - with self.compiled_fn(test_fn) as result: - x = 2 - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_merge(self): - # Simple merge - def test_fn(x): - if x > 0: - return x - else: - return x * x - - with self.compiled_fn(test_fn) as result: - for x in [-2, 2]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_orphan_branch(self): - - def test_fn(x): - if x > 0: - return x - - with self.assertRaises(ValueError): - self.compiled_fn(test_fn) - - def test_lift_body_into_false_branch(self): - - def test_fn(x): - if x > 0: - return x - return x * x - - with self.compiled_fn(test_fn) as result: - for x in [-2, 2]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_lift_body_into_true_branch(self): - - def test_fn(x): - if x < 0: - x *= x - else: - # TODO(alexbw): linter bug here that requires us suppress this warning. - return x # pylint: disable=undefined-loop-variable - return x - - with self.compiled_fn(test_fn) as result: - for x in [-2, 2]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_nested_if(self): - - def test_fn(x): - if x > 0: - if x < 5: - return x - else: - return x * x - else: - return x * x * x - - with self.compiled_fn(test_fn) as result: - for x in [-2, 2, 5]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_context_manager(self): - - def test_fn(x): - - with name_scope(''): - return x * x - - with self.compiled_fn(test_fn) as result: - result.name_scope = name_scope - for x in [-2, 2]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_context_manager_in_conditional(self): - - def test_fn(x): - if x > 0: - with name_scope(''): - return x * x - else: - return x - - with self.compiled_fn(test_fn, name_scope) as result: - result.name_scope = name_scope - for x in [-2, 2]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def text_conditional_in_context_manager(self): - - def test_fn(x): - with name_scope(''): - if x > 0: - return x * x - else: - return x - - with self.compiled_fn(test_fn) as result: - result.name_scope = name_scope - for x in [-2, 2]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_no_return(self): - - def test_fn(x): - x *= x - - with self.compiled_fn(test_fn) as result: - self.assertEqual(test_fn(2), result.test_fn(2)) - - def test_nested_functiondefs(self): - - def test_fn(x): - - def inner_fn(y): - if y > 0: - return y * y - else: - return y - - return inner_fn(x) - - with self.compiled_fn(test_fn) as result: - for x in [-2, 2]: - self.assertEqual(test_fn(x), result.test_fn(x)) - - def test_loop(self): - - def test_fn(x): - for _ in range(10): - return x - return x - - with self.assertRaises(ValueError): - self.compiled_fn(test_fn) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/autograph/converters/slices.py b/tensorflow/contrib/autograph/converters/slices.py index 3f5fc57125a8b65faf1e3a377d7984ff05b3245c..c527f98613a2ffebf35141d4dac85e972a89c93b 100644 --- a/tensorflow/contrib/autograph/converters/slices.py +++ b/tensorflow/contrib/autograph/converters/slices.py @@ -21,7 +21,7 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.core import converter -from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.lang import directives from tensorflow.contrib.autograph.pyct import templates @@ -36,12 +36,14 @@ class SliceTransformer(converter.Base): def _process_single_assignment(self, target, value): if not isinstance(target, gast.Subscript): return None + if not isinstance(target.slice, gast.Index): + return None template = """ target = ag__.set_item(target, key, item) """ return templates.replace( - template, target=target.value, key=target.slice, item=value) + template, target=target.value, key=target.slice.value, item=value) def visit_Assign(self, node): node = self.generic_visit(node) @@ -56,17 +58,17 @@ class SliceTransformer(converter.Base): def visit_Subscript(self, node): node = self.generic_visit(node) if not isinstance(node.slice, gast.Index): - # TODO(mdan): It might make more sense to wave them through. - raise NotImplementedError('non-index slice') + return node if not isinstance(node.ctx, gast.Load): # Index writes are handled at a higher level, one at which the rvalue is # also available. return node - dtype = anno.getanno( + dtype = self.get_definition_directive( node.value, - 'element_type', + directives.set_element_type, + 'dtype', default=templates.replace_as_expression('None')) template = """ @@ -76,7 +78,7 @@ class SliceTransformer(converter.Base): opts=ag__.GetItemOpts(element_dtype=dtype)) """ return templates.replace_as_expression( - template, target=node.value, key=node.slice, dtype=dtype) + template, target=node.value, key=node.slice.value, dtype=dtype) def transform(node, ctx): diff --git a/tensorflow/contrib/autograph/converters/slices_test.py b/tensorflow/contrib/autograph/converters/slices_test.py index df9a4c8bab66f24374605b45bc90bc2730431323..c822d53a4a2810755fd6841af85544dd8fc76a5e 100644 --- a/tensorflow/contrib/autograph/converters/slices_test.py +++ b/tensorflow/contrib/autograph/converters/slices_test.py @@ -18,9 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph.converters import slices from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.contrib.autograph.lang import directives +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import list_ops @@ -32,28 +35,42 @@ class SliceTest(converter_testing.TestCase): def test_index_access(self): def test_fn(l): - utils.set_element_type(l, dtypes.int32) return l[1] - node = self.parse_and_analyze( - test_fn, - { - 'utils': utils, - 'dtypes': dtypes - }, - include_type_analysis=True, - ) - node = slices.transform(node, self.ctx) - - with self.compiled(node, dtypes.int32) as result: - result.utils = utils - result.dtypes = dtypes + node, ctx = self.prepare(test_fn, {}) + def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS) + def_.directives[directives.set_element_type] = { + 'dtype': parser.parse_expression('tf.int32') + } + node = slices.transform(node, ctx) + + with self.compiled(node, {}, dtypes.int32) as result: with self.test_session() as sess: tl = list_ops.tensor_list_from_tensor( [1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32)) y = result.test_fn(tl) self.assertEqual(2, sess.run(y)) + def test_index_access_multiple_definitions(self): + + def test_fn(l): + if l: + l = [] + return l[1] + + node, ctx = self.prepare(test_fn, {}) + def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS) + def_.directives[directives.set_element_type] = { + 'dtype': parser.parse_expression('tf.int32') + } + def_, = anno.getanno(node.body[0].body[0].targets[0], + anno.Static.DEFINITIONS) + def_.directives[directives.set_element_type] = { + 'dtype': parser.parse_expression('tf.float32') + } + with self.assertRaises(transformer.AutographParseError): + slices.transform(node, ctx) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/core/BUILD b/tensorflow/contrib/autograph/core/BUILD index 833f9dced81bd651244d281322c830bb1c88b259..1873045a921f8af6068d8fccca6a5625b2aedcf8 100644 --- a/tensorflow/contrib/autograph/core/BUILD +++ b/tensorflow/contrib/autograph/core/BUILD @@ -19,6 +19,7 @@ py_library( srcs = [ "config.py", "converter.py", + "errors.py", "naming.py", ], srcs_version = "PY2AND3", @@ -30,6 +31,31 @@ py_library( ], ) +py_test( + name = "errors_test", + srcs = ["errors_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":core", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + ], +) + +py_test( + name = "naming_test", + srcs = ["naming_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":core", + "//tensorflow/python:client_testlib", + ], +) + py_library( name = "test_lib", srcs = [ @@ -47,13 +73,3 @@ py_library( "@six_archive//:six", ], ) - -py_test( - name = "naming_test", - srcs = ["naming_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":core", - "//tensorflow/python:client_testlib", - ], -) diff --git a/tensorflow/contrib/autograph/core/converter.py b/tensorflow/contrib/autograph/core/converter.py index 54e6aa0f3bbb9059e044861362407cb5050240b4..83a80c1f52123c325782a67c651e892163af83b3 100644 --- a/tensorflow/contrib/autograph/core/converter.py +++ b/tensorflow/contrib/autograph/core/converter.py @@ -64,15 +64,29 @@ from __future__ import division from __future__ import print_function import collections +from enum import Enum + from tensorflow.contrib.autograph.core import config from tensorflow.contrib.autograph.core import naming +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import cfg +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import qual_names from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import liveness +from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions +from tensorflow.contrib.autograph.pyct.static_analysis import type_info # TODO(mdan): These contexts can be refactored into first class objects. # For example, we could define Program and Entity abstractions that hold on # to the actual entity and have conversion methods. +# TODO(mdan): Add a test specific to this converter. + class ProgramContext(object): """ProgramContext keeps track of converting function hierarchies. @@ -197,6 +211,46 @@ class Base(transformer.Base): self._used = False self._ast_depth = 0 + def get_definition_directive(self, node, directive, arg, default): + """Returns the unique directive for a symbol, or a default if none exist. + + See lang/directives.py for details on directives. + + Args: + node: ast.AST + directive: Callable[..., Any] + arg: str + default: Any + + Raises: + ValueError: if conflicting annotations have been found + """ + defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ()) + if not defs: + return default + + # TODO(mdan): Simplify this. + arg_values = [] + for def_ in defs: + if (directive not in def_.directives or + arg not in def_.directives[directive]): + continue + arg_value = def_.directives[directive][arg] + for prev_value in arg_values: + if not ast_util.matches(arg_value, prev_value): + qn = anno.getanno(node, anno.Basic.QN) + raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' % + (qn, directive.__name__, arg, + compiler.ast_to_source(arg_value).strip(), + compiler.ast_to_source(prev_value).strip())) + arg_values.append(arg_value) + + if not arg_values: + return default + + arg_value, = arg_values + return arg_value + def visit(self, node): if not self._ast_depth: if self._used: @@ -208,3 +262,69 @@ class Base(transformer.Base): return super(Base, self).visit(node) finally: self._ast_depth -= 1 + + +class AnnotatedDef(reaching_definitions.Definition): + + def __init__(self): + super(AnnotatedDef, self).__init__() + self.directives = {} + + +class AgAnno(Enum): + """Annotation labels specific to AutoGraph. See anno.py.""" + + DIRECTIVES = 'User directives associated with the annotated statement.' + + def __repr__(self): + return self.name + + +def standard_analysis(node, context, is_initial=False): + """Performs a complete static analysis of the given code. + + Args: + node: ast.AST + context: converter.EntityContext + is_initial: bool, whether this is the initial analysis done on the input + source code + + Returns: + ast.AST, same as node, with the static analysis annotations added + """ + # TODO(mdan): Clear static analysis here. + # TODO(mdan): Consider not running all analyses every time. + # TODO(mdan): Don't return a node because it's modified by reference. + graphs = cfg.build(node) + node = qual_names.resolve(node) + node = activity.resolve(node, context.info, None) + node = reaching_definitions.resolve(node, context.info, graphs, AnnotatedDef) + node = liveness.resolve(node, context.info, graphs) + node = live_values.resolve(node, context.info, config.PYTHON_LITERALS) + node = type_info.resolve(node, context.info) + # This second call allows resolving first-order class attributes. + node = live_values.resolve(node, context.info, config.PYTHON_LITERALS) + if is_initial: + anno.dup( + node, + { + anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS, + }, + ) + return node + + +def apply_(node, context, converter_module): + """Applies a converter to an AST. + + Args: + node: ast.AST + context: converter.EntityContext + converter_module: converter.Base + + Returns: + ast.AST, the result of applying converter to node + """ + node = standard_analysis(node, context) + node = converter_module.transform(node, context) + return node diff --git a/tensorflow/contrib/autograph/core/converter_testing.py b/tensorflow/contrib/autograph/core/converter_testing.py index 0e46aacc1216d2dbd9d34ad0e72ca8251094bddc..5ee2c3fffd7474cb8ca28349385a9d543e92a72d 100644 --- a/tensorflow/contrib/autograph/core/converter_testing.py +++ b/tensorflow/contrib/autograph/core/converter_testing.py @@ -20,19 +20,19 @@ from __future__ import print_function import contextlib import imp +import sys + +import six from tensorflow.contrib.autograph import operators from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph.core import config from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.core import errors from tensorflow.contrib.autograph.pyct import compiler from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import pretty_printer -from tensorflow.contrib.autograph.pyct import qual_names from tensorflow.contrib.autograph.pyct import transformer -from tensorflow.contrib.autograph.pyct.static_analysis import activity -from tensorflow.contrib.autograph.pyct.static_analysis import live_values -from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.platform import test @@ -74,7 +74,17 @@ class TestCase(test.TestCase): """Base class for unit tests in this module. Contains relevant utilities.""" @contextlib.contextmanager - def compiled(self, node, *symbols): + def assertPrints(self, expected_result): + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + yield + self.assertEqual(out_capturer.getvalue(), expected_result) + finally: + sys.stdout = sys.__stdout__ + + @contextlib.contextmanager + def compiled(self, node, namespace, *symbols): source = None self.dynamic_calls = [] @@ -84,12 +94,17 @@ class TestCase(test.TestCase): return 7 try: - result, source = compiler.ast_to_object(node) + result, source = compiler.ast_to_object(node, include_source_map=True) + result.tf = self.make_fake_mod('fake_tf', *symbols) fake_ag = self.make_fake_mod('fake_ag', converted_call) fake_ag.__dict__.update(operators.__dict__) fake_ag.__dict__['utils'] = utils + fake_ag.__dict__['rewrite_graph_construction_error'] = ( + errors.rewrite_graph_construction_error) result.__dict__['ag__'] = fake_ag + for k, v in namespace.items(): + result.__dict__[k] = v yield result except Exception: # pylint:disable=broad-except if source is None: @@ -98,6 +113,13 @@ class TestCase(test.TestCase): print('Offending compiled code:\n%s' % source) raise + @contextlib.contextmanager + def converted(self, entity, converter_module, namespace, *tf_symbols): + node, ctx = self.prepare(entity, namespace) + node = converter_module.transform(node, ctx) + with self.compiled(node, namespace, *tf_symbols) as result: + yield result + def make_fake_mod(self, name, *symbols): fake_mod = imp.new_module(name) for s in symbols: @@ -114,17 +136,16 @@ class TestCase(test.TestCase): for k, v in ns.items(): setattr(module, k, v) - def parse_and_analyze(self, - test_fn, - namespace, - namer=None, - arg_types=None, - include_type_analysis=True, - owner_type=None, - recursive=True, - autograph_decorators=()): + def prepare(self, + test_fn, + namespace, + namer=None, + arg_types=None, + owner_type=None, + recursive=True, + autograph_decorators=()): node, source = parser.parse_entity(test_fn) - + node = node.body[0] if namer is None: namer = FakeNamer() program_ctx = converter.ProgramContext( @@ -141,12 +162,5 @@ class TestCase(test.TestCase): arg_types=arg_types, owner_type=owner_type) ctx = converter.EntityContext(namer, entity_info, program_ctx) - - node = qual_names.resolve(node) - node = activity.resolve(node, entity_info) - node = live_values.resolve(node, entity_info, {}) - if include_type_analysis: - node = type_info.resolve(node, entity_info) - node = live_values.resolve(node, entity_info, {}) - self.ctx = ctx - return node + node = converter.standard_analysis(node, ctx, is_initial=True) + return node, ctx diff --git a/tensorflow/contrib/autograph/core/errors.py b/tensorflow/contrib/autograph/core/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..5a57d57e7d4c6461f05030b72cc9bfe1b33210db --- /dev/null +++ b/tensorflow/contrib/autograph/core/errors.py @@ -0,0 +1,258 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Error rewriting logic. + +Contains the functions responsible for rewriting tracebacks of errors raised +in AutoGraph (AG) code to refer to user written code, so that errors only refer +to the original user code. + +When 'user code' is used in comments it refers to the original source code that +the user wrote and is converting using AutoGraph. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import logging +import sys +import traceback + +from tensorflow.contrib.autograph.pyct import origin_info +from tensorflow.python.framework import errors_impl + +# TODO(mdan): Add a superclass common to all errors. + + +class GraphConstructionError(Exception): + """Error for graph construction errors from AutoGraph generated code.""" + + def __init__(self, original_error, custom_traceback): + self.original_error = original_error + self.custom_traceback = custom_traceback + super(GraphConstructionError, self).__init__() + + def __str__(self): + traceback_str = ''.join(traceback.format_list(self.custom_traceback)) + return ('Traceback (most recent call last):\n' + traceback_str + '\n' + str( + self.original_error) + '\n') + + +class TfRuntimeError(Exception): + """Error wrapper for runtime errors raised by AutoGraph generated code.""" + + def __init__(self, op_name, op_message, custom_traceback): + self.op_name = op_name + self.op_message = op_message + self.custom_traceback = custom_traceback + super(TfRuntimeError, self).__init__() + + def __str__(self): + message = '%s\n\nCaused by op %r, defined at:\n' % (self.op_message, + self.op_name) + return message + ''.join(traceback.format_list(self.custom_traceback)) + + +def _rewrite_tb(source_map, tb): + """Rewrites code references in a traceback. + + Args: + source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping + locations to their origin + tb: List[Tuple[Text, Text, Text, Text]], consistent with + traceback.extract_tb. + Returns: + List[Tuple[Text, Text, Text, Text]], the rewritten traceback + """ + new_tb = [] + for frame in tb: + filename, lineno, _, _ = frame + loc = origin_info.LineLocation(filename, lineno) + origin = source_map.get(loc) + if origin is not None: + new_tb.append(origin.as_frame()) + else: + new_tb.append(frame) + return new_tb + + +# TODO(mdan): rename to raise_* +def rewrite_graph_construction_error(source_map): + """Rewrites errors raised by non-AG APIs inside AG generated code. + + This is called from the except handler inside an AutoGraph generated function + (that is, during exception handling). Only rewrites the frames corresponding + to the function that this is called from, so each function is responsible + to call this to have its own frames rewritten. + + This function always raises an error. + + Args: + source_map: Dict[origin_info.Location, origin_info.OriginInfo], the source + map belonging to the calling function + + Raises: + GraphConstructionError: The rewritten underlying error. + Exception: The underlying error, if it could not be rewritten. + """ + error_info = sys.exc_info() + _, original_error, e_traceback = error_info + assert original_error is not None + try: + current_traceback = _cut_traceback_loops(source_map, + traceback.extract_tb(e_traceback)) + if isinstance(original_error, GraphConstructionError): + # TODO(mdan): This is incomplete. + # The error might have bubbled through a non-converted function. + previous_traceback = original_error.custom_traceback + cleaned_traceback = [current_traceback[0]] + previous_traceback + else: + cleaned_traceback = current_traceback + + cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback) + + if isinstance(original_error, GraphConstructionError): + original_error.custom_traceback = cleaned_traceback + new_error = original_error + else: + new_error = GraphConstructionError(original_error, cleaned_traceback) + except Exception: + logging.exception('Error while rewriting AutoGraph error:') + # TODO(mdan): Should reraise here, removing the top frame as well. + raise original_error + else: + raise new_error + finally: + # Addresses warning https://docs.python.org/2/library/sys.html#sys.exc_info. + del e_traceback + + +def _cut_traceback_loops(source_map, original_traceback): + """Check for cases where we leave a user method and re-enter it. + + This is done by looking at the function names when the filenames are from any + files the user code is in. If we find a case where we return to a user method + after leaving it then we cut out the frames in between because we assume this + means these in between frames are from internal AutoGraph code that shouldn't + be included. + + An example of this is: + + File "file1.py", line 57, in my_func + ... + File "control_flow_ops.py", line 231, in cond + ... + File "control_flow_ops.py", line 1039, in inner_cond + ... + File "file1.py", line 68, in my_func + ... + + Where we would remove the control_flow_ops.py frames because we re-enter + my_func in file1.py. + + The source map keys are (file_path, line_number) so get the set of all user + file_paths. + + Args: + source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping + locations to their origin + original_traceback: List[Tuple[Text, Text, Text, Text]], consistent with + traceback.extract_tb. + + Returns: + List[Tuple[Text, Text, Text, Text]], the traceback with any loops removed. + """ + all_user_files = set(loc.filename for loc in source_map) + cleaned_traceback = [] + last_user_frame_index = None + last_user_user_file_path = None + # TODO(mdan): Simplify this logic. + for fi, frame in enumerate(original_traceback): + frame_file_path, lineno, _, _ = frame + src_map_key = origin_info.LineLocation(frame_file_path, lineno) + if frame_file_path in all_user_files: + if src_map_key in source_map: + if (last_user_frame_index is not None and + last_user_user_file_path == frame_file_path): + cleaned_traceback = cleaned_traceback[:last_user_frame_index] + last_user_frame_index = fi + last_user_user_file_path = frame_file_path + cleaned_traceback.append(frame) + return cleaned_traceback + + +# TODO(mdan): This should be consistent with rewrite_graph_construction_error +# Both should either raise or return. +def rewrite_tf_runtime_error(error, source_map): + """Rewrites TensorFlow runtime errors raised by ops created in AG code. + + Args: + error: tf.OpError + source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo] + + Returns: + TfRuntimeError, the rewritten underlying error. + """ + try: + cleaned_traceback = _cut_traceback_loops(source_map, error.op.traceback) + # cleaned_traceback = error.op.traceback + cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback) + + op_name = error.op.name + op_message = error.message + rewritten_error = TfRuntimeError(op_name, op_message, cleaned_traceback) + return rewritten_error + except Exception: # pylint: disable=broad-except + logging.exception('Error while rewriting AutoGraph error:') + return error + + +# TODO(znado): Add arg to enable different levels of error rewriting. +@contextlib.contextmanager +def improved_errors(converted_function): + """Context manager that rewrites runtime errors. + + This context manager will rewrite runtime errors so that their traceback + is relative to the original code before conversion. + + Use with the output of to_graph, and wrap the execution of respective ops. + Example: + + converted_my_func = ag.to_graph(my_func) + ops = converted_my_func(...) + + with ag.improved_errors(converted_my_func): + sess.run(ops) + + Args: + converted_function: Callable[..., Any], the output of a to_graph call + + Yields: + None + + Raises: + TfRuntimeError: if any OpError originates in the converted code, it will + be wrapped into a TfRuntimeError + ValueError: If converted_function is not generated by AutoGraph + """ + if (getattr(converted_function, 'ag_source_map', None) is None or + not isinstance(converted_function.ag_source_map, dict)): + raise ValueError( + 'converted_function must be the result of an autograph.to_graph call') + try: + yield + except errors_impl.OpError as e: + raise rewrite_tf_runtime_error(e, converted_function.ag_source_map) diff --git a/tensorflow/contrib/autograph/core/errors_test.py b/tensorflow/contrib/autograph/core/errors_test.py new file mode 100644 index 0000000000000000000000000000000000000000..404c1f5456f9654724d068e3007fe9ced15cbf07 --- /dev/null +++ b/tensorflow/contrib/autograph/core/errors_test.py @@ -0,0 +1,105 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for errors module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.core import errors +from tensorflow.contrib.autograph.pyct import origin_info +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors as tf_errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test +from tensorflow.python.util import tf_inspect + + +def zero_div(): + x = array_ops.constant(10, dtype=dtypes.int32) + return x // 0 + + +def zero_div_caller(): + return zero_div() + + +class RuntimeErrorsTest(test.TestCase): + + def fake_origin(self, function, line_offset): + _, lineno = tf_inspect.getsourcelines(function) + filename = tf_inspect.getsourcefile(function) + lineno += line_offset + loc = origin_info.LineLocation(filename, lineno) + origin = origin_info.OriginInfo(loc, 'test_function_name', 'test_code', + 'test_comment') + return loc, origin + + def test_improved_errors_basic(self): + loc, origin = self.fake_origin(zero_div, 2) + zero_div_caller.ag_source_map = {loc: origin} + + ops = zero_div_caller() + with self.assertRaises(errors.TfRuntimeError) as cm: + with errors.improved_errors(zero_div_caller): + with self.test_session() as sess: + sess.run(ops) + + for frame in cm.exception.custom_traceback: + _, _, function_name, _ = frame + self.assertNotEqual('zero_div', function_name) + self.assertIn(origin.as_frame(), set(cm.exception.custom_traceback)) + + def test_improved_errors_no_matching_lineno(self): + loc, origin = self.fake_origin(zero_div, -1) + zero_div_caller.ag_source_map = {loc: origin} + + ops = zero_div_caller() + with self.assertRaises(errors.TfRuntimeError) as cm: + with errors.improved_errors(zero_div_caller): + with self.test_session() as sess: + sess.run(ops) + + all_function_names = set() + for frame in cm.exception.custom_traceback: + _, _, function_name, _ = frame + all_function_names.add(function_name) + self.assertNotEqual('test_function_name', function_name) + self.assertIn('zero_div', all_function_names) + + def test_improved_errors_failures(self): + loc, _ = self.fake_origin(zero_div, 2) + zero_div_caller.ag_source_map = {loc: 'bogus object'} + + ops = zero_div_caller() + with self.assertRaises(tf_errors.InvalidArgumentError): + with errors.improved_errors(zero_div_caller): + with self.test_session() as sess: + sess.run(ops) + + def test_improved_errors_validation(self): + with self.assertRaisesRegexp( + ValueError, + 'converted_function must be the result of an autograph.to_graph call'): + errors.improved_errors(zero_div).__enter__() + with self.assertRaisesRegexp( + ValueError, + 'converted_function must be the result of an autograph.to_graph call'): + zero_div_caller.ag_source_map = 'not a dict' + errors.improved_errors(zero_div_caller).__enter__() + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/examples/integration_tests/BUILD b/tensorflow/contrib/autograph/examples/integration_tests/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..6c281485b4a3c4d09292a4d7af16330cdc44edd4 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/integration_tests/BUILD @@ -0,0 +1,54 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_test( + name = "errors_test", + srcs = [ + "errors_test.py", + ], + srcs_version = "PY2AND3", + tags = ["no_windows"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "keras_test", + srcs = [ + "keras_test.py", + ], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "list_literals_test", + srcs = [ + "list_literals_test.py", + ], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py b/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b9159942bcf8837b97dfac000d8fb34d15a314 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py @@ -0,0 +1,162 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Error traceback rewriting integration tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib import autograph as ag +from tensorflow.python.util import tf_inspect + + +class ErrorsTest(tf.test.TestCase): + + def test_graph_construction_error_rewriting_call_tree(self): + + def innermost(x): + if x > 0: + return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32) + return tf.zeros((2, 3)) + + def inner_caller(): + return innermost(1.0) + + def caller(): + return inner_caller() + + with self.assertRaises(ag.GraphConstructionError) as error: + graph = ag.to_graph(caller) + graph() + expected = error.exception + custom_traceback = expected.custom_traceback + found_correct_filename = False + num_innermost_names = 0 + num_inner_caller_names = 0 + num_caller_names = 0 + ag_output_filename = tf_inspect.getsourcefile(graph) + for frame in custom_traceback: + filename, _, fn_name, _ = frame + self.assertFalse('control_flow_ops.py' in filename) + self.assertFalse(ag_output_filename in filename) + found_correct_filename |= __file__ in filename + self.assertNotEqual('tf__test_fn', fn_name) + num_innermost_names += int('innermost' == fn_name) + self.assertNotEqual('tf__inner_caller', fn_name) + num_inner_caller_names += int('inner_caller' == fn_name) + self.assertNotEqual('tf__caller', fn_name) + num_caller_names += int('caller' == fn_name) + self.assertTrue(found_correct_filename) + self.assertEqual(num_innermost_names, 1) + self.assertEqual(num_inner_caller_names, 1) + self.assertEqual(num_caller_names, 1) + + def test_graph_construction_error_rewriting_class(self): + + class TestClass(object): + + def test_fn(self): + return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32) + + def inner_caller(self): + return self.test_fn() + + def caller(self): + return self.inner_caller() + + # Note we expect a TypeError here because the traceback will not be + # rewritten for classes. + with self.assertRaises(TypeError): + graph = ag.to_graph(TestClass) + graph().caller() + + def test_runtime_error_rewriting(self): + + def g(x, s): + while tf.reduce_sum(x) > s: + x //= 0 + return x + + def test_fn(x): + return g(x, 10) + + compiled_fn = ag.to_graph(test_fn) + + with self.assertRaises(ag.TfRuntimeError) as error: + with self.test_session() as sess: + x = compiled_fn(tf.constant([4, 8])) + with ag.improved_errors(compiled_fn): + sess.run(x) + expected = error.exception + custom_traceback = expected.custom_traceback + found_correct_filename = False + num_test_fn_frames = 0 + num_g_frames = 0 + ag_output_filename = tf_inspect.getsourcefile(compiled_fn) + for frame in custom_traceback: + filename, _, fn_name, source_code = frame + self.assertFalse(ag_output_filename in filename) + self.assertFalse('control_flow_ops.py' in filename) + self.assertFalse('ag__.' in fn_name) + self.assertFalse('tf__g' in fn_name) + self.assertFalse('tf__test_fn' in fn_name) + found_correct_filename |= __file__ in filename + num_test_fn_frames += int('test_fn' == fn_name and + 'return g(x, 10)' in source_code) + # This makes sure that the code is correctly rewritten from "x_1 //= 0" to + # "x //= 0". + num_g_frames += int('g' == fn_name and 'x //= 0' in source_code) + self.assertTrue(found_correct_filename) + self.assertEqual(num_test_fn_frames, 1) + self.assertEqual(num_g_frames, 1) + + def test_runtime_error_rewriting_nested(self): + + def test_fn(x): + + def g(y): + return y**2 // 0 + + s = 0 + for xi in x: + s += g(xi) + return s + + compiled_fn = ag.to_graph(test_fn) + + # TODO(b/111408261): Nested functions currently do not rewrite correctly, + # when they do we should change this test to check for the same traceback + # properties as the other tests. This should throw a runtime error with a + # frame with "g" as the function name but because we don't yet add + # try/except blocks to inner functions the name is "tf__g". + with self.assertRaises(ag.TfRuntimeError) as error: + with self.test_session() as sess: + x = compiled_fn(tf.constant([4, 8])) + with ag.improved_errors(compiled_fn): + sess.run(x) + expected = error.exception + custom_traceback = expected.custom_traceback + num_tf_g_frames = 0 + for frame in custom_traceback: + _, _, fn_name, _ = frame + self.assertNotEqual('g', fn_name) + num_tf_g_frames += int('tf__g' == fn_name) + self.assertEqual(num_tf_g_frames, 1) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py b/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7e7ef5a3e2bbf6a15936eb181c9c4112f8b820e6 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py @@ -0,0 +1,103 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Keras integration tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib import autograph + + +class MinimalKeras(tf.keras.Model): + + def call(self, x): + return x * 3 + + +class ModelWithStaticConditional(object): + + def __init__(self, initial): + self.initial = initial + if self.initial: + self.h = 15 + + @autograph.convert() + def call(self): + x = 10 + if self.initial: + x += self.h + return x + + +class BasicBlock(tf.keras.Model): + + def __init__(self): + super(BasicBlock, self).__init__() + self.conv1 = tf.keras.layers.Conv2D(8, 3) + self.pool = tf.keras.layers.GlobalAveragePooling2D() + self.dense = tf.keras.layers.Dense(3) + + def call(self, x): + x = self.conv1(x) + x = self.pool(x) + x = self.dense(x) + return x + + +class CompoundModel(tf.keras.Model): + + def __init__(self): + super(CompoundModel, self).__init__() + self.block = BasicBlock() + + @autograph.convert(recursive=True) + def call(self, x): + x = self.block(x) # pylint: disable=not-callable + return x + + +class KerasTest(tf.test.TestCase): + + def test_basic(self): + MinimalKeras() + + def test_conditional_attributes_False(self): + model = ModelWithStaticConditional(False) + self.assertEqual(model.call(), 10) + + def test_conditional_attributes_True(self): + model = ModelWithStaticConditional(True) + self.assertEqual(model.call(), 25) + + def test_recursive_true(self): + with self.assertRaisesRegexp(NotImplementedError, + 'Object conversion is not yet supported.'): + with tf.Graph().as_default(): + model = CompoundModel() + model.build(tf.TensorShape((None, 10, 10, 1))) + init = tf.global_variables_initializer() + + with tf.Session() as sess: + sess.run(init) + sample_input = tf.random_uniform((1, 10, 10, 1)) + output = model(sample_input) # pylint: disable=not-callable + self.assertEqual(sess.run(output).shape, (1, 3)) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/autograph/utils/type_hints.py b/tensorflow/contrib/autograph/examples/integration_tests/list_literals_test.py similarity index 54% rename from tensorflow/contrib/autograph/utils/type_hints.py rename to tensorflow/contrib/autograph/examples/integration_tests/list_literals_test.py index aeb9e545610460afbe364dfcfc7a54b9aede29fe..680b6dbaf07fc10e11dfa1e9d3a075624024c103 100644 --- a/tensorflow/contrib/autograph/utils/type_hints.py +++ b/tensorflow/contrib/autograph/examples/integration_tests/list_literals_test.py @@ -12,30 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""No-op utilities that provide static type hints. - -These are used when the data type is not known at creation, for instance in the -case of empty lists. -""" +"""Tests of functions that use list literals.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +import tensorflow as tf + +from tensorflow.contrib import autograph as ag + + +def list_used_as_tuple(): + return tf.constant([1, 2, 3]) + + +class ListLiteralsTest(tf.test.TestCase): + + def test_basic(self): + converted = ag.to_graph(list_used_as_tuple) + result = converted() -def set_element_type(entity, dtype, shape=None): - """Indicates that the entity is expected hold items of specified type. + with self.test_session() as sess: + self.assertAllEqual(sess.run(result), [1, 2, 3]) - This function is a no-op. Its presence merely marks the data type of its - argument. The staged TensorFlow ops will reflect and assert this data type. - Args: - entity: A Tensor or TensorArray. - dtype: TensorFlow dtype value to assert for entity. - shape: Optional shape to assert for entity. - Returns: - The value of entity, unchanged. - """ - del dtype - del shape - return entity +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_collatz_speed_test.ipynb b/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_collatz_speed_test.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..c10a5741f640be5ab7d2604dd32f2f4d6ddf1a22 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_collatz_speed_test.ipynb @@ -0,0 +1,299 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "aQkTGc-d8I1k" + }, + "source": [ + "This notebook runs a basic speed test for a simple algorithm that implements the process described in Collatz Conjecture.\n", + "\n", + "https://en.wikipedia.org/wiki/Collatz_conjecture" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "x5ChBlH09jk_" + }, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "X-QAUpWdPxUh" + }, + "outputs": [], + "source": [ + "!pip install -U -q tf-nightly" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "wiKQu3w05eCa" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "from matplotlib import pyplot as plt\n", + "import tensorflow as tf\n", + "from tensorflow.contrib import autograph as ag\n", + "from tensorflow.python.eager import context" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "_cRFTcwT9mnn" + }, + "source": [ + "### Plotting helpers" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "ww7rc0GQ9pMu" + }, + "outputs": [], + "source": [ + "def plot_results(counts, times, title):\n", + " plt.plot(counts, np.array(times) * 1000., 'o')\n", + " plt.ylabel('Time (milliseconds)')\n", + " plt.xlabel('Collatz counter')\n", + " plt.title(title)\n", + " plt.ylim(0, 30)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ESZGw9s9-Y5_" + }, + "source": [ + "### Collatz function definition" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "qeunWm9m-dT7" + }, + "outputs": [], + "source": [ + "def collatz(a):\n", + " count = 0\n", + " while a \u003e 1.1:\n", + " if a % 2 \u003c 0.1:\n", + " a //= 2\n", + " else:\n", + " a = 3 * a + 1\n", + " count += 1\n", + " return count\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "nnFmPDvScsDo" + }, + "source": [ + "# AutoGraph" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 301 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 9153, + "status": "ok", + "timestamp": 1531757473651, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "6fU4vlxYcsDe", + "outputId": "11b50f28-aced-4506-a743-4b749e9645c3" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYkAAAEcCAYAAAAydkhNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XtcVGXCB/DfGRBUQA0ZURQvyIspm1reQkxNSPICgoqW\npWZu1vbmjZJV3Jc+axappVLu7guV25rU5g3wlq3iBd1wXHSN3hXy9ZaCgoOIIKAzMOf9g5dZkTkz\nB5i7v+9fzJlzzjzPHD2/Oc/znOcIoiiKICIiMkBh6wIQEZH9YkgQEZEkhgQREUliSBARkSSGBBER\nSWJIEBGRJIYEkQNYsWIFkpOTbV0MegQxJMipzJ49G8OHD4dWq5W9zeOPP45r164163O2bt2KqKgo\nDB48GKNGjcKcOXOwf//+5haXyO4xJMhpFBUV4fTp0xAEAYcPH5a9nSAIzfqc9957D1999RVWrFiB\nU6dO4fjx41iyZAmOHz8uuQ3vWSVHxZAgp5GRkYHBgwdj6tSpSE9P1y+fPXs2duzYoX+dnp6OWbNm\nAQBefvlliKKIqKgoPPXUU/juu+8AANu2bcP48eMxYsQIvPnmm7h58yYA4PLly/jmm2+wYcMGhISE\nwM3NDYIg4KmnnkJSUlKjz9ywYQNefPFFDB48GIWFhdi1axcmTpyIp556Cs899xy+/fZb/fqnTp3C\nmDFjkJKSgqeffhphYWHYs2dPo/rduXMHr7/+Op566inMnDmz2Vc/RC3BkCCnkZmZiaioKEyePBkn\nTpxAWVmZ5LoNVw9bt24FAOzevRtnzpzBhAkTkJOTg/Xr1+OTTz7BiRMn4Ofnh7i4OACASqVCt27d\nMGDAAJPl2bNnD1avXo0zZ86gW7du6Ny5M1JTU3HmzBkkJSUhKSkJ+fn5+vVLS0tRXl6O48eP48MP\nP0RiYiKuXLmif3/fvn1YuHAhcnNz4e/vj40bN7bkayJqFoYEOYXc3Fxcv34dEyZMQHBwMHr27Nnk\nl7hce/fuxfTp0/H444+jTZs2iIuLw9mzZ3H9+nXcvn0bSqWy0fpjxozBsGHDMHDgQNy4cUO/PCYm\nBn379oVCoYCrqyvGjBmDHj16AACGDh2K0NBQ5Obm6tcXBAFLlixBmzZtMGzYMIwZM0Z/ZQMA48eP\nx69+9SsoFApERkY2ChgiS2FIkFPIzMzEqFGj0LFjRwDApEmTkJGR0aJ93bx5E35+fvrX7du3R6dO\nnVBSUoJOnTrpm54aHDt2DCdPnoRWq23U99C1a9cm682cORMjRozAsGHDkJ2djdu3b+vf79ChA9zd\n3fWv/fz8Gn2Wj4+P/u927dqhurq6RfUjag5XWxeAqLXu37+P7777DjqdDqNGjQIAaDQaVFZWoqCg\nAO3bt8e9e/f066vVaqP769KlC65fv65/XV1djfLycvj6+qJTp05YvXo1/vWvfyE4OLjRdg93Tj/Y\nIa7RaLB48WKsW7cOYWFhUCgU+M///M9G21RUVODevXto27YtAODGjRsICgpq5rdBZF68kiCHd/Dg\nQbi4uOC7775DZmYmMjMz8d1332Ho0KHIzMxE//798be//Q337t3DL7/8gp07dzba3sfHp1En8OTJ\nk7Fr1y4UFBRAo9Fg/fr1GDRoEPz8/NCnTx/MnDkTcXFx+OGHH3D//n3odDqcOXPG6CgprVYLrVaL\nxx57DAqFAseOHcPf//73RuuIoohPPvkEWq0Wubm5OHr0KCZMmGDeL4uomXglQQ4vIyMD06ZNg6+v\nb6Pls2bNwvvvv4+9e/fip59+QmhoKPr164fIyEjk5OTo11u4cCHi4+Oh0WiwatUqPP/881i8eDEW\nLlyIiooKPPnkk1i/fr1+/cTERGzduhVJSUm4du0avLy80Lt3b2zcuFHfTPVwYHh4eGDlypVYvHgx\ntFotnn32WYSFhTVaR6lUomPHjnjmmWfQvn17rFq1Cr179zbzt0XUPIIlHzqk0Wjw0ksvQavVoq6u\nDhEREXjrrbdQWFiIuLg43LlzB8HBwVi7di1cXZlX9Og6deoU4uPjcfToUVsXhagRizY3ubm5YcuW\nLcjIyEBGRgays7Px448/4qOPPsK8efPw/fffw8vLq9EYdiIish8W75No164dgPqritraWgiCAJVK\nhYiICAD1wwQPHjxo6WIQEVELWDwkdDodoqOjERoaitDQUPj7+6NDhw5QKOo/umvXrk2GFBI9aoYP\nH86mJrJLFg8JhUKhb2rKy8vDxYsXm6zT3LlziIjIOqw2BNbT0xPDhg3Djz/+iIqKCuh0OgBAcXEx\nunTpYnJ7TpBGRGR9Fh1SVFZWhjZt2sDLywv37t1DTk4OFixYgBEjRuDAgQOYOHEi0tPTmwwFNEQQ\nBKjVlZYsrk0plV6sn4Ny5roBrJ+jUyq9WrW9RUNCrVZj+fLl0Ol00Ol0mDhxIsaMGYOAgADExcUh\nOTkZ/fv3x/Tp0y1ZDCIiaiGL3idhbs6e9qyfY3LmugGsn6Nr7ZUEp+UgIiJJDAkiIpLEkCAiIkkM\nCSIiksSQICIiSQwJIiKSxJAgIiJJDAkiIpLEkCAiIkkMCSIiksSQICIiSQwJIiKSxJAgIiJJDAki\nIpLEkCAiIkkMCSIiksSQICIiSQwJIiKSxJAgIiJJDAkiIpLEkCAiIkkMCSIiksSQICIiSQwJIiKS\nxJAgIiJJDAkiIpLEkCAiIkkMCSIikuRqyZ0XFxcjPj4epaWlcHFxwYwZMzB79mxs2rQJ27ZtQ+fO\nnQEAS5cuxejRoy1ZFCIiagGLhoSLiwtWrFiB/v37o6qqClOnTsXIkSMBAPPmzcO8efMs+fFERNRK\nFg0JpVIJpVIJAPDw8EDfvn1x8+ZNAIAoipb8aCIiMgOr9UkUFhaioKAAAwcOBACkpaVhypQpWLly\nJSorK61VDCIiagarhERVVRUWLVqEhIQEeHh4YNasWTh06BAyMzPh4+ODpKQkaxSDiIiaSRAt3O5T\nW1uL119/HaNHj8bcuXObvF9UVIQ33ngDe/bssWQxiIioBSzaJwEACQkJCAwMbBQQarVa31dx8OBB\nBAUFydqXWu28zVJKpRfr56CcuW4A6+folEqvVm1v0ZA4ffo09uzZg6CgIERHR0MQBCxduhR79+5F\nfn4+FAoFunfvjlWrVlmyGERE1EIWDYkhQ4YgPz+/yXLeE0FE5Bh4xzUREUliSBARkSSGBBERSWJI\nEBGRJIYEERFJYkgQEZEkhgQREUliSBARkSSGBBERSWJIEBGRJIYEERFJYkgQEZEkhgQREUliSBAR\nkSSGBBERSWJIEBGRJIYEERFJYkgQEZEk2Y8vvXfvHtRqNdzd3dGlSxdLlomIiOyE0ZDQ6XTIyMjA\n9u3bUVBQAE9PT2g0Gri6uiI8PByvvPIK+vTpY62yEhGRlRkNiRdffBGDBw/GihUrEBwcDBcXFwDA\nrVu3cPz4cSQmJuKFF17ApEmTrFJYIiKyLkEURVHqzbKyMnh7exvdgZx1zEWtrrTK59iCUunF+jko\nZ64bwPo5OqXSq1XbG+24NnTyv3XrFs6ePWt0HSIicg6yRjfNmjULlZWVqKioQHR0NFauXIk1a9ZY\numxERGRjskKiuroaXl5eOHLkCCIjI7Fnzx6cOHHC0mUjIiIbkxUSGo0GAKBSqTBy5EgoFAp9JzYR\nETkvWSExfPhwREREIDc3F8OHD0dFRQUUCt6HR0Tk7GTdTPfuu++ioKAA/v7+cHNzw927d7F69WpL\nl42IiGzMaEhcuHBB/3ebNm1QXFysf+3m5ma5UhERkV0wGhILFiyAIAgQRRE3btyAp6cnAODu3bvo\n1q0bDh8+bHTnxcXFiI+PR2lpKVxcXBAbG4s5c+bgzp07WLp0KYqKitCjRw9s3LgRXl6tG8tLRETm\nZzQkGkJg9erVGDJkCCZMmAAAOHDgAM6dO2dy5y4uLlixYgX69++PqqoqTJ06FaGhodi1axdCQkLw\n2muvITU1FSkpKXjnnXfMUB0iIjInWb3PeXl5+oAAgOeffx4nT540uZ1SqUT//v0BAB4eHujbty9K\nSkqQlZWFmJgYAEBMTAwOHTrUkrITEZGFyQqJmpoa5Obm6l/n5uaipqamWR9UWFiIgoICDBo0CLdu\n3YKPjw+A+iC5fft2s/ZFRETWIXt0U1xcHNq1awcAuH//Pj7++GPZH1JVVYVFixYhISEBHh4eEASh\nRYVt7Rwk9o71c1zOXDeA9XuUyQqJoUOH4tChQ7h8+TJEUURAQIDs0U21tbVYtGgRpkyZgvDwcABA\n586dUVpaCh8fH6jVatnzPzn7JFysn2Ny5roBrJ+js+gEfw+qq6uDm5sbXF1dcfXq1UbDY41JSEhA\nYGAg5s6dq182btw47Nq1CwCQnp6OsLCwZhabiIisQdaVRFpaGj766CN06tRJ31QkCAKysrKMbnf6\n9Gns2bMHQUFBiI6OhiAIWLp0KV577TUsWbIEO3fuhJ+fH5KTk1tfEyIiMjtZIbF582bs3bsX3bt3\nb9bOhwwZgvz8fIPvffnll83aFxERWZ+s5ialUtnsgCAiIscn60pi5MiRWLt2LSZNmgR3d3f98sDA\nQIsVjIiIbE9WSGRkZACov9O6gZw+CSIicmyyQsLUHE1EROScZIUEUD8jrEqlAgA8/fTT6Nu3r8UK\nRURE9kFWx3VGRgZeeeUV5OfnIz8/H/PmzcPu3bstXTYiIrIx2UNg09PToVQqAQBqtRrz589HVFSU\nRQtHRES2JfuO64aAePhvIiJyXrJComfPnvjkk09QUlKCmzdvYtOmTfD397d02YiIyMZkhcTvf/97\nXL58GVFRUYiKisKlS5ewatUqS5eNiIhsTFafROfOnbFhwwZLl4WIiOyMrCuJ1NRUlJeX61/fvn0b\nn3/+ucUKRURE9kFWSOzbtw+dOnXSv37sscewd+9eixWKiIjsg6yQEEWxybK6ujqzF4aIiOyLrJDo\n3bs3/vznP0MUReh0OmzevBk9e/a0dNmIiMjGZIXEypUrceTIEQwcOBCDBw/GsWPHkJiYaOmyERGR\njcka3eTr64stW7aguroaANC+fXuLFoqIiOyD7D6J7du3449//CPat2+PwsJCnDlzxtJlIyIiG5MV\nEklJSTh58iQOHToEAPDw8MAHH3xg0YIREZHtyQoJlUqFjz76CG3btgVQPwT2/v37Fi0YERHZnqyQ\ncHd3hyAI+tc6nc5iBSIiIvshq+M6KCgIu3fvhiiKKCwsRGpqKoYMGWLpshERkY3JupJYvnw5Tp06\nBbVajdjYWNTV1WHZsmWWLhsREdmYrCsJT09PrF692tJlISIiOyPrSmL//v24e/cuACA5ORnz58/H\n//zP/1i0YEREZHuyQuJPf/oTPD09kZeXhxMnTiA6OppXFkREjwBZIeHqWt8q9fe//x2xsbGIjIzk\nEFgiokeArJAQBAG7d+/Gvn37EBISAgDQarUWLRgREdmerJD43e9+hwMHDiA2Nhb+/v64cuUKRowY\nYXK7hIQEjBw5EpGRkfplmzZtwujRoxETE4OYmBhkZ2e3vPRERGRRgmjoYRFmkpubCw8PD8THx2PP\nnj0A6kPCw8MD8+bNa/b+1OpKcxfRbiiVXqyfg3LmugGsn6NTKr1atb3RIbB/+ctfMHfuXKxdu9bg\n+/Hx8UZ3PnToUBQVFTVZbsFcIiIiMzIaEu7u7gDMPzV4WloaMjMz8atf/QrLly+Hl1frko6IiCzD\nos1NAFBUVIQ33nhD39xUVlaGxx57DIIgYMOGDVCr1ZxRlojIThm9kkhLSzO68UsvvdTsD/T29tb/\nPWPGDLzxxhuyt3X2dkPWzzE5c90A1s/RWbRPwhx3VT98oaJWq6FUKgEABw8eRFBQUKs/g4iILMNo\nSCQlJbVq52+//TZUKhXKy8sxduxYLFy4ECqVCvn5+VAoFOjevTtWrVrVqs8gIiLLMRoSx44dM7rx\nmDFjjL7/8ccfN1k2bdo0GcUiIiJ7YDQkPv/8c8n3BEEwGRJEROTYjIbEV199Za1yEBGRHTIaEteu\nXYO/vz8uXLhg8P3AwECLFIqIiOyD0ZBYvXo1UlJSsGDBgibvCYKArKwsixWMiIhsz2hIpKSkAAAO\nHz5slcIQEZF9kfX4UgCoqalBcXEx6urq9MvY3ERE5NxkhcSWLVuwYcMGdOzYEQpF/ezibG4iInJ+\nskLiL3/5Cw4cOABfX19Ll4eIiOyIrIcOde3alQFBRPQIknUlsXDhQqxcuRJjxozRTx8OmL7jmoiI\nHJuskDhy5AiOHDmCK1euNOqTYEgQETk3WSFx8OBBHD58GG3btrV0eYiIyI7I6pPw9/eHq6vs0bJE\nROQkZJ35e/Xqhblz5yI8PBxubm765S156BARETkOWSGh1WrRs2dPnD9/3tLlISIiOyIrJFr78CEi\nInJMRvskTD2+VKPR4OLFi2YtEBER2Q+TE/zV1NRg8uTJGDRoEHx8fHD//n1cvnwZx48fx7Fjx7B8\n+XL07dvXWuUlIiIrMhoSn376KfLy8vDtt9/iD3/4A4qLi9GuXTsEBQUhPDwcaWlp8PT0tFZZiYjI\nykz2SQwcOBADBw60RlmIiMjOyLpPgoiIHk0MCSIiksSQICIiSQwJIiKSJCskbt26hXfeeUc/DUdB\nQQG++eYbixaMiIhsT1ZI/O53v8OQIUNQUVEBAAgICMDXX39t0YIREZHtyQqJkpISvPjii3BxcQEA\nuLm56Z8rQUREzkvWmf7hacIrKiogiqJFCkRERPZDVkiMHz8eiYmJqKqqwq5du/Dqq69i2rRpJrdL\nSEjAyJEjERkZqV92584dvPrqq4iIiMD8+fNRWVnZ8tITEZFFyQqJX//61xg6dCiCg4Nx7NgxzJ49\nG3PnzjW53dSpU/HFF180WpaamoqQkBB8//33GDFiBFJSUlpWciIisjjZj5uLiopCVFRUs3Y+dOhQ\nFBUVNVqWlZWFrVu3AgBiYmIwe/ZsvPPOO83aLxERWYeskLh16xa2bt2Kq1evora2Vr88OTm52R9Y\nVlYGHx8fAIBSqcTt27ebvQ8iIrIOWSHx5ptvYsCAAQgJCdGPcLIFpdLLZp9tDayf43LmugGs36NM\nVkjU1NTg3XffNcsHdu7cGaWlpfDx8YFarYa3t7fsbdVq5+3kViq9WD8H5cx1A1g/R9faAJTVcT1o\n0CD8/PPPLfqAh4fKjhs3Drt27QIApKenIywsrEX7JSIiy5N1JfHCCy/g5ZdfRteuXeHu7q5fvmPH\nDqPbvf3221CpVCgvL8fYsWOxcOFCLFiwAIsXL8bOnTvh5+fXon4NIiKyDlkhsWzZMrzxxhsYMGBA\ns/okPv74Y4PLv/zyS9n7ICIi25EVEu7u7pg/f76ly0JERHZGVp/EM888g+zsbEuXhYiI7IysK4lt\n27YhNTUVHh4ecHNzgyiKEAQBOTk5li4fERHZkKyQ2Llzp6XLQUREdkhWSHTv3t3S5SAiIjtkNCSW\nLVuGdevWYdq0aRAEocn7pobAEhGRYzMaEg0zvf72t7+1SmGIiMi+GA2Jr7/+Gh988AGGDx9urfIQ\nEZEdMToENj8/31rlICIiO8QHVRMRkSSjzU3nz59HSEhIk+W8T4KI6NFgNCR69+6N1NRUa5WFiIjs\njNGQcHNz4z0SRESPMKN9Em3atLFWOYiIyA4ZDYlt27ZZqxxERGSHOLqJiIgkMSSIiEgSQ4KIiCQx\nJIiISBJDgoiIJDEkiIhIEkOCiIgkMSSIiEgSQ4KIiCTJesY1EdGjTnWuBPtyruB6aTX8fNpjUkhv\njBjga+tiWRxDgojIBNW5EqTs/pf+daG6Sv/a2YOCIUFEVuHIv8T35VyRWP6Lw9ShpRgSRGRxjv5L\n/HpptcHlN25VWbkk1mezkBg3bhw8PT2hUCjg6uqKHTt22KooRGRhjv5L3M+nPQrVTQOhW2cPG5TG\numwWEoIg4KuvvkLHjh1tVQQishJH/yU+KaR3oyuhfy/v1ei1IzepSbFZSIiiCJ1OZ6uPJ3I69nyC\nsuYvcUt8Dw3b78v5BTduVaFbZw9MCunVaL+O3qQmxaZXEvPnz4cgCJg5cyZmzJhhq6IQOTx7P0HJ\n/SXeWpb8HkYM8DW6D0dvUpNis5D461//CqVSibKyMsybNw8BAQEYOnSorYpDZHXm/MVr7ycoOb/E\nzcGW34OjN6lJsVlIKJVKAIC3tzeee+45/PTTTyZDQqn0skbRbIb1c1zNrVv2PwsN/uLt0KEtRj/Z\no9F627P+F1dLKtHT1wuxYf/R6P0G129Jn6DM8b2bYx+Tx3hh8pjAZm0jt/4NWvo9mKN+Pbt64cqN\niibL/X29HPrfvk1CoqamBjqdDh4eHqiursaJEyfw1ltvmdxOra60QulsQ6n0Yv0cVEvq9s33BRLL\nf0b/HvWDOR5uOrlyowLrtp5GRcW9Jr+K/TpLt/m39nu31bFrTv0btOR7MFf9Iob5G2xSixjmb9N/\n+60NKJuERGlpKd566y0IgoC6ujpERkZi1KhRtigKkU3IaZpoTtOJsTZ/e+7QNqYlTUfW6vswxFpN\natZmk5Dw9/dHZmamLT6ayCKaeyI2NdpHda7E4PuA4TZuqRMUALvu0DbGVJAa+85tdaI21bntiHjH\nNVErSfUvANInYlO//A2916Cjp5vB5YZOUIlfqAyuay8d2sYYC1JTo5jsvW6OhFOF0yNLda4EiV+o\n8Os1R5D4hQqqcyUt2s/2rP81uHxfzi+S24wY4IvXo4LRQ+kJF4WAHkpPvB4VjBEDfCWbWRqUVdyX\nXVZHHnEzKaS3xPJeRpuiyLx4JUGPJHOOp79aYrhT0tSJWOoXr9SJ/UFyrwQceToJY01Hn+05Z3Ab\nRwg/R8OQoEeSOcfT9/Q1PPSxpSdiqRP7gwrVd/HrNUdM9n/YsiPXHKSC1JHDz9GwuYkeSVK/1otK\n7za7CSo27D8MLm/piViqmeVhOlHUXwFJldNYs5YjM9YURebFKwl6JEn9EhVF6JfLbYIa/WQPVFTc\nM9uImobtth+5gLLK+7K2MXYF5IwdubYexfQoYUhQqznSOPyGshaVym+7bugMNVZHS5yI5QYE0LQt\n3pGOSUs5Y/jZI4YEtZjqXEmTX7v2NA7/4RNlv56PIet0YZP1BAF4zNNd8qRcVHrX6vcaSPWZtHFR\nQFvXdPbkB9vi7X2yP3IsDAlqEVNj+c05Dr8lv4oNnSilOoMf83JHWYX0r3ZXheETsyXvNZDqM6mV\nmF7/wbZ4e5/sjxwLQ4JaxNRYfnMNRZTzq9hQiJgq34OMBQQgfWK25HBLqT6T7j6e/3+fgHRbvCPf\nG0H2hyHhpCzdJm1qLL+5hiJuP3LB8PKjFzBigK9kiAiCWT4er0cFY1/OFbMPtzR1fIwNXTXVFs/h\noWRODAknZI02aVNj+Zs7FFHqpCnVT9Dw61/qikGqicgQby/D/RE9lJ7678uc9xrIOT6tGb3j6PdG\nkH1hSDgha7RJS52IvDu4I3ZsYLM+x9hJ09R2zW27DxvSAz9fLTc6CV6DhvfMPdxS7vFp6egdDg8l\nc2JIOCFrtEmb80Rk7KQp9SsfqD+xu0g0K8lpuzf0eVLrmnO4pbWOD0OBzIEh4YRMzZ7Zmr4KS/R1\nSJ00C9V3YaproU40vFxO2/2DrHlSZZ8BORKGhBOSagrq17NTq/oqpJqFUnf/C92VHvrASDt4Htln\ni6CtE+EiAO3atkH1/Vr4dTYcKsb6NyQyoIk2LgroRNEhmlbYZ0COhCHhhKSaglrTV6E6V4LN+wzP\nvCni34Fx4qcb+NflMv17dSJwt0YLQDqUpE6azaETRXwW/2yr9mEt7DMgR8KQcFKGmk9aOr2yqRvn\nHvRgQEgx1EELGO48lsvRmmrYZ0COgiHhhKT6DVraFt6cG9PkkHr8ptT9CA28vdwBwfDNb2yqIbIM\nThXuZBp+9Reqq5pMJd3S6ZXlPASnOaRCydQU2bHPBuKjN0OdcuprInvFKwknY6zfYdX84fq/m9MW\nLuchOA2C+3ibbHKSCqUH2+qLSu/CVaFAnU4Hv/8fzvrgjWYMBSLrYEg4GVNj8FtygpXqWDZ0Y9q/\nRzddh7ZOBxdBQLu2rqi5XysrlBgARPaFIWFhDz6/QCEIqNPVD+r09nJH7LPNuzNZDkuMwW/uaJyX\nngvCS88FNVqmVHpBrTb8LGgisl8MiYcYegbBz1dvm7x5zFBnMdB4xE6d+O9R/2WV9y0yx7+lxuDz\nFz7Ro4kh8QBTzyCQGucvdZOZt5e7yc809xz/HINPRObEkHiA3KGeD5/YpbaT8/hJS8zxz1/9RGQu\nHAL7ALlDPR8+sbdmiKij3QRGRI8WhsQD/Hzay1rv4RO71HZympt4ExgR2TObhUR2djaef/55RERE\nIDU1tUX7UJ0rQeIXKvx6zREkfqGC6lxJq8pk6mauf6/X66HXhreLfTZQf+OXQgBcFP+e09Tby503\ngRGR3bNJn4ROp8N7772HL7/8El26dMH06dMRFhaGvn37Sm4zZdnuRrOIWuLpa4Y6ffv17GTwXgBT\n2z188xcRkSOySUjk5eWhV69e6N69OwBg0qRJyMrKMhoSOp3YKAgs9fS11jwNjGFARM7GJs1NJSUl\n6Natm/61r68vbt68KXv7fTm/WOXpXkREjzqbhIQoyn2UjGE3blVJdhZztBARkfnYpLmpa9euuH79\nuv51SUkJunTpInt7f18vxIb9B9ZtPd3kvRcj+kGp9DJLOa3NUcstlzPXz5nrBrB+jzJBbO3P+hao\nq6vD888/jy+//BJKpRKxsbFYv3690T4JIiKyPptcSbi4uOC//uu/8Oqrr0IURUyfPp0BQURkh2xy\nJUFERI6Bd1wTEZEkhgQREUliSBARkSS7DwlzzPFkb8aNG4eoqChER0dj+vTpAIA7d+7g1VdfRURE\nBObPn4/KSsd5iltCQgJGjhyJyMhI/TJj9Vm9ejXGjx+PKVOmID8/3xZFbhZD9du0aRNGjx6NmJgY\nxMTEIDs7W/9eSkoKxo8fjwkTJuDEiRO2KLJsxcXFmDNnDiZOnIjIyEhs2bIFgPMcv4fr99VXXwFw\nnuOn0WiMkXxQAAAKiUlEQVQQGxuL6OhoREZGYtOmTQCAwsJCzJgxAxEREYiLi0Ntba1+/aVLl2L8\n+PGYOXNmo1sRJIl2rK6uTgwPDxcLCwtFjUYjRkVFiRcuXLB1sVpt3LhxYnl5eaNla9euFVNTU0VR\nFMWUlBRx3bp1tihai/zjH/8Qz507J06ePFm/TKo+R48eFV977TVRFEXx7NmzYmxsrPUL3EyG6vfp\np5+KmzdvbrLuhQsXxClTpoharVa8du2aGB4eLup0OmsWt1lu3rwpnjt3ThRFUbx79644fvx48cKF\nC05z/KTq5yzHTxRFsbq6WhRFUaytrRVjY2PFs2fPiosXLxb3798viqIoJiYmit98840oiqKYlpYm\nvvvuu6IoiuK+ffvEJUuWmNy/XV9JPDjHU5s2bfRzPDk6URSh0+kaLcvKykJMTAwAICYmBocOHbJF\n0Vpk6NCh6NChQ6NlD9en4bhlZWUhOjoaADBo0CBUVlaitLTUugVuJkP1AwzPHJCVlYWJEyfC1dUV\nPXr0QK9evZCXl2eNYraIUqlE//79AQAeHh7o27cvSkpKnOb4GapfwxRAznD8AKBdu3YA6q8Samtr\nIQgCVCoVIiIiADQ+nzx4XCMiIpCTk2Ny/3YdEq2d48leCYKA+fPnY9q0adi+fTsA4NatW/Dx8QFQ\n/w/79u3btixiq5WVlTWqT1lZGQDg5s2b6Nq1q349X19flJS0bop3W0lLS8OUKVOwcuVKfXOMoX+z\njlK/wsJCFBQUYNCgQU3+PTrD8Wuo38CBAwE4z/HT6XSIjo5GaGgoQkND4e/vjw4dOkChqD+9d+3a\nVV+HB4+fi4sLOnTogPLycqP7t+uQMJT0zuCvf/0rdu3ahc8++wxpaWnIzc2FIAimN3QCho6pI9Z9\n1qxZOHToEDIzM+Hj44MPP/wQgOPWr6qqCosWLUJCQgI8PDwky+ws9XOm46dQKJCRkYHs7Gzk5eXh\n4sWLTdZpqMPD9RNF0WT97DokWjvHk71SKpUAAG9vb4SHhyMvLw+dO3fWX7ar1Wp4e3vbsoitJlUf\nX19fFBcX69crLi52yGPq7e2t/881Y8YMfZNE165dcePGDf16jlC/2tpaLFq0CFOmTEF4eDgA5zp+\nhurnTMevgaenJ4YNG4Yff/wRFRUV+ibtB+vw4PGrq6vD3bt30bFjR6P7teuQeOKJJ3D16lUUFRVB\no9Fg3759CAsLs3WxWqWmpgZVVfXTmVdXV+PEiRMICgrCuHHjsGvXLgBAenq6w9Xz4V8oUvUJCwtD\nRkYGAODs2bPo0KGDvlnDnj1cP7Varf/74MGDCAoKAlBf7/3790Oj0eDatWu4evWqvnnDXiUkJCAw\nMBBz587VL3Om42eofs5y/MrKyvRNZffu3UNOTg4CAwMxYsQIHDhwAEDj4zdu3Dikp6cDAA4cOICn\nn37a5GfY/bQc2dnZeP/99/VzPC1YsMDWRWqVa9eu4a233oIgCKirq0NkZCQWLFiA8vJyLFmyBDdu\n3ICfnx+Sk5MNdpbao7fffhsqlQrl5eXw8fHBwoULER4ejsWLFxusz6pVq3D8+HG0a9cOSUlJCA4O\ntnENjDNUP5VKhfz8fCgUCnTv3h2rVq3SnyxTUlKwY8cOuLq6YuXKlRg1apSNayDt9OnTePnllxEU\nFARBECAIApYuXYqBAwdK/nt0pOMnVb+9e/c6xfH7+eefsXz5cuh0Ouh0OkycOBG/+c1vcO3aNcTF\nxaGiogL9+/fHunXr0KZNG2g0Gixbtgz5+fno1KkT1q9fjx49ehj9DLsPCSIish27bm4iIiLbYkgQ\nEZEkhgQREUliSBARkSSGBBERSWJIEBGRJIYE2b3a2lokJycjIiICkZGRmDRpEtasWYO6ujqj261Y\nsQJpaWkA6qeGXrt2rcnPOnToEH766SezlNsSioqKsG3bNlsXgx4hDAmye8uXL8fFixeRkZGBPXv2\nYPfu3QgICIBGozH7Z2VlZdn1rJ+FhYX49ttvW7StqVAlMsTV1gUgMuaXX35BVlaW/g5foH72ytjY\nWAD1M2CuW7dO/3CYUaNGIT4+3uikZefPn8fvf/971NTUQKPRYMaMGZgzZw5OnDiBw4cPIycnBzt2\n7MArr7yCwsJCHDx4EIIgQKPR4NKlS/jHP/4BT0/PRvv85z//iXXr1qGqqgqCICA+Ph4jR45EXl4e\nPvjgA9TU1KBdu3ZYuXIlnnjiCZw6dQpr1qzBzp07AaDR61OnTuGDDz7AwIEDcfbsWSgUCqxfvx4B\nAQF47733UFRUhJiYGPTs2RPJycm4dOkSkpKSUF5eDq1Wizlz5mDq1KkAgMcffxzLli3D0aNHMWzY\nMCxatMjsx4icnFmeekFkIfv37xejo6Ml3//666/FefPmibW1taJWqxXnzp2rf8DK8uXLxa1bt4qi\nWP+QoDVr1oiiKIpVVVWiRqPR/z1x4kTx4sWLTbZ52LJly8QPP/ywyfLy8nIxNDRUPHv2rCiKoqjT\n6cSKigpRo9GIY8eOFXNyckRRFMUffvhBHDt2rKjVakWVSiVOmzZNv48HX6tUKjE4OFjMz88XRVEU\n//SnP4nvvPNOk/VEsf5BMzExMeKlS5dEUax/sE5ERIT+db9+/cTPP/9c8vsjMoVXEmTXRBOzxuTk\n5CAmJgYuLi4AgKlTp+LQoUN44YUXJLepqanBu+++i4KCAigUCqjVahQUFCAgIEBym40bN6Kmpga/\n/e1vm7x39uxZBAYGYtCgQQDqp2X28vLC+fPn4ebmpp9ELSQkBG5ubrh8+bLJevfp0wePP/44gPqH\n+xw9etTgeleuXMGlS5cQFxen/660Wi0uXryIPn36AID+IUFELcGQILsWHByMK1euoLKyEl5eXk3e\nFw3Mh29qfvz169dDqVRi7dq1+gdAGevf2LlzJ06ePKl//rOhMshd3lBeFxeXRk8nvH//fqP13N3d\n9X+7uLjon1FsaH/e3t76mT0fJggC2rdvb/A9IjnYcU12rVevXhg3bhwSExP1U6zX1dVhy5YtqKmp\nwciRI5Geno7a2lpotVpkZGQgNDTU6D4rKyvRrVs3CIKA8+fPIzc3V/+eh4cH7t69q3/9ww8/4LPP\nPsMf//hHuLm5Gdzfk08+iQsXLuDHH38EUN9PUlFRgYCAAGi1Wpw6dQoAcPLkSdTW1qJ3797o0aMH\nCgsLUVlZCVEUsW/fPlnfh6enp35qaKD+iqNt27bIzMzUL7t06ZL+uzJ1JUZkCq8kyO6tWbMGn376\nKaZOnQo3NzeIoojRo0fDzc0NM2fOxNWrV/XP7X3mmWf0ndpSfvOb3yA+Ph67d+9Gz549MWzYMP17\nU6ZMwYoVK3DgwAG88sor2LlzJ2pqajB//nz9VUBaWlqjX+cdO3bEpk2bkJSUhOrqari4uCA+Ph4h\nISH45JNPsHr1an3H9aeffgpXV1f4+vpi3rx5iImJgb+/P5544glcuHDB5HfRr18/9OnTB5GRkQgI\nCEBycjL++7//G++//z42b96Muro6+Pj4YOPGjQDs/6lqZP84VTgREUlicxMREUliSBARkSSGBBER\nSWJIEBGRJIYEERFJYkgQEZEkhgQREUliSBARkaT/AzLfG+oMx+5pAAAAAElFTkSuQmCC\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7fc3b259add0\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "counts = []\n", + "times = []\n", + "for n in np.logspace(0, 7, 50):\n", + "\n", + " with tf.Graph().as_default():\n", + " tf_collatz = ag.to_graph(collatz)\n", + " count = tf_collatz(tf.constant(n, dtype=tf.float32))\n", + " with tf.Session() as sess:\n", + " count_value = sess.run(count)\n", + "\n", + " res = %timeit -n10 -r1 -o -q sess.run(count)\n", + " counts.append(count_value)\n", + " times.append(res.best)\n", + " \n", + "plot_results(counts, times, 'AutoGraph')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "RRENYzLRF_f3" + }, + "source": [ + "# Eager" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 301 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 5003, + "status": "ok", + "timestamp": 1531757478713, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "dhDf8LLdF_f-", + "outputId": "3de0a5a5-7a11-4b41-8ab0-e4e21ce8d59b" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYkAAAEcCAYAAAAydkhNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XtYVWW+B/Dv2hshBdSQHaighhwas7Qeb6GFDjIyI3LZ\nGphdJLLMzqSlKaPQsTPm5KhZkZ7moKOnManGK17wsUfIS87QNj2jnEnIg5cQEtyAyDWBvdf5g4d9\nBPbaLGCvfeP7+QvW2mvt38tGvq71vut9BVEURRAREZmhsncBRETkuBgSREQkiSFBRESSGBJERCSJ\nIUFERJIYEkREJIkhQUREktzsXQCRvYWHh6OiogJqtRqiKEIQBMyePRtvv/22vUsjsjuGBBGA9PR0\nPPHEE3Z5b4PBALVabZf3JuoMbzcRATA38cCNGzeQmJiISZMmITQ0FMuXL0dtba1p//fffw+tVotx\n48bhjTfewNKlS5GWlmbaf+LECcTFxWHChAmYN28efvjhB9O+8PBwbNu2DTExMXj88cdhNBqVbSBR\nNzEkiCSIoohFixbhb3/7G44ePYqysjJs3rwZANDU1ITFixdjzpw5OHv2LGbNmoXjx4+bjv3++++R\nmpqKd999F2fPnsXcuXPx2muvoampyfSao0ePYtu2bTh37hxUKv5TJMfE30wiAL/97W8xceJETJgw\nARMnTsSePXswbNgwhIaGws3NDffffz8SExPx3XffAQAuXLgAg8GA559/Hmq1Gr/61a8wZswY0/n2\n7NmDZ555Bo8++igEQUBcXBzc3d1x8eJF02vmz58PPz8/uLu727y9RHKxT4IIwCeffNKhT6KyshJr\n167FuXPnUF9fD4PBgIEDBwIA9Ho9/Pz82rx+8ODBpq9/+uknHDx4ELt27QLQclXS3NyMW7dumV7j\n7++vVHOIrIYhQQTzfRKbNm2CIAg4cuQI+vfvj+zsbKxduxYAoNFoUFZW1ub1N2/exLBhwwC0BMCi\nRYvw6quvKl88kYJ4u4lIQl1dHTw9PeHl5YWysjJs377dtO+xxx6DWq1GRkYGDAYDsrOzkZeXZ9qf\nkJCAL7/80rStvr4ep06dQn19vc3bQdQTvJIgAvDaa69BpVKZnpOYMmUK3nzzTaxYsQLjx4/H8OHD\nERsbi08//RQA0KdPH2zevBmpqanYtGkTwsLCEB4ebupfeOSRR/Duu+9izZo1KCoqgoeHB8aNG4cJ\nEyYAAARBsFdTibpEUHLRocbGRjz33HNoamqCwWBAZGQkXn/9dRQXF2PZsmW4c+cORo8ejQ0bNsDN\njXlFzi0hIQHz5s2DVqu1dylEVqPo7SZ3d3fs3LkTmZmZyMzMxOnTp3Hx4kW8//77SEpKwldffQVv\nb2/s3btXyTKIFPHdd9+hvLwcBoMBBw4cwOXLl/HUU0/Zuywiq1K8T6Jv374AWq4qmpubIQgCdDod\nIiMjAQBarbbN+HIiZ3Ht2jXExsZi/Pjx+PTTT/Hxxx/D19fX3mURWZXi93iMRiNmz56NoqIiPPfc\ncwgMDET//v1NDw/5+/u3GRZI5CwSEhKQkJBg7zKIFKX4lYRKpTLdasrLy8OVK1c6vIadeEREjslm\nQ2C9vLwwYcIEXLx4EdXV1aa5akpLS/HAAw90eryC/etERCRB0dtNlZWV6NOnD7y9vfHzzz8jNzcX\nCxcuxKRJk3Ds2DHMnDkTBw4cwPTp0zs9lyAI0OtrlCzXrjQab7bPSbly2wC2z9lpNN49Ol7RkNDr\n9Vi5ciWMRiOMRiNmzpyJqVOnIigoCMuWLUNaWhpGjRqFp59+WskyiIiomxR9TsLaXD3t2T7n5Mpt\nA9g+Z9fTKwlOy0FERJIYEkREJIkhQUREkhgSREQkiSFBRESSGBJERCSJIUFERJIYEkREJIkhQURE\nkhgSREQkiSFBRESSGBJERCSJIUFERJIYEkREJIkhQUREkhgSREQkiSFBRESSGBJERCSJIUFERJIY\nEkREJIkhQUREkhgSREQkiSFBRESSGBJERCSJIUFERJIYEkREJIkhQUREkhgSREQkyU3Jk5eWliI5\nORnl5eVQq9VISEjACy+8gC1btmD37t0YNGgQAGDp0qUICwtTshQiIuoGRUNCrVZj1apVGDVqFOrq\n6jB79mxMnjwZAJCUlISkpCQl356IiHpI0ZDQaDTQaDQAAE9PT4wcORK3bt0CAIiiqORbExGRFdis\nT6K4uBgFBQUYM2YMACAjIwOxsbFITU1FTU2NrcogIqIusElI1NXVYcmSJUhJSYGnpyeeffZZZGdn\n4+DBg/D19cW6detsUQYREXWRICp836e5uRmvvvoqwsLCkJiY2GF/SUkJFi1ahMOHDytZBhERdYOi\nfRIAkJKSguDg4DYBodfrTX0Vx48fR0hIiKxz6fWue1tKo/Fm+5yUK7cNYPucnUbj3aPjFQ2J8+fP\n4/DhwwgJCUFcXBwEQcDSpUtx5MgR5OfnQ6VSYejQoVizZo2SZRARUTcpGhLjxo1Dfn5+h+18JoKI\nyDnwiWsiIpLEkCAiIkkMCSIiksSQICIiSQwJIiKSxJAgIiJJDAkiIpLEkCAiIkkMCSIiksSQICIi\nSQwJIiKSxJAgIiJJDAkiIpLEkCAiIkkMCSIiksSQICIiSQwJIiKSxJAgIiJJspcv/fnnn6HX6+Hh\n4YEHHnhAyZqIiMhBWAwJo9GIzMxM7NmzBwUFBfDy8kJjYyPc3NwQERGBF198EQ8++KCtaiUiIhuz\nGBLz5s3DY489hlWrVmH06NFQq9UAgIqKCnzzzTdYvXo1nnnmGURFRdmkWCIisi1BFEVRamdlZSV8\nfHwsnkDOa6xFr6+xyfvYg0bjzfY5KVduG8D2OTuNxrtHx1vsuDb3x7+iogIXLlyw+BoiInINskY3\nPfvss6ipqUF1dTXi4uKQmpqK9evXK10bERHZmayQqK+vh7e3N06cOIHo6GgcPnwYZ86cUbo2IiKy\nM1kh0djYCADQ6XSYPHkyVCqVqRObiIhcl6yQmDhxIiIjI3Hu3DlMnDgR1dXVUKn4HB4RkauT9TDd\nO++8g4KCAgQGBsLd3R21tbVYu3at0rUREZGdWQyJwsJC09d9+vRBaWmp6Xt3d3flqiIiIodgMSQW\nLlwIQRAgiiJu3rwJLy8vAEBtbS0GDx6Mr7/+2uLJS0tLkZycjPLycqjVasTHx2P+/Pm4c+cOli5d\nipKSEgQEBOCjjz6Ct3fPxvISEZH1WQyJ1hBYu3Ytxo0bh9/85jcAgGPHjuHSpUudnlytVmPVqlUY\nNWoU6urqMHv2bEyZMgX79+9HaGgoXnnlFWzduhXp6elYvny5FZpDRETWJKv3OS8vzxQQAPDrX/8a\n3377bafHaTQajBo1CgDg6emJkSNHoqysDDk5OdBqtQAArVaL7Ozs7tROREQKkxUSDQ0NOHfunOn7\nc+fOoaGhoUtvVFxcjIKCAowdOxYVFRXw9fUF0BIkt2/f7tK5iIjINmSPblq2bBn69u0LALh79y42\nbdok+03q6uqwZMkSpKSkwNPTE4IgdKvYns5B4ujYPuflym0D2D65Tv+jGHty/hdFZTUY5ueN+On/\ngrDHA6xybnuRFRLjx49HdnY2rl27BlEUERQUJHt0U3NzM5YsWYLY2FhEREQAAAYNGoTy8nL4+vpC\nr9fLnv/J1SfhYvuckyu3DWD75NJdKkP6oe9N31+/WY2Nu86juvpnTHrYr8fn7y5FJ/i7l8FggLu7\nO9zc3FBUVNRmeKwlKSkpCA4ORmJiomlbeHg49u/fDwA4cOAApk+f3sWyiYgcS1budYntP9q0DmuT\ndSWRkZGB999/HwMHDjTdKhIEATk5ORaPO3/+PA4fPoyQkBDExcVBEAQsXboUr7zyCt58803s27cP\nQ4YMQVpaWs9bQkRkRz+V15vdfrOizsaVWJeskNixYweOHDmCoUOHdunk48aNQ35+vtl9n376aZfO\nRUTkyIb49kOxvmMgDB7kaYdqrEfW7SaNRtPlgCAi6k2iQkdIbB9u20KsTNaVxOTJk7FhwwZERUXB\nw8PDtD04OFixwoiIHInuUhmycq/jp/J6DPHth6jQEW06pFu/zsr9ETcr6jB4kCeiQofbtdPaGmSF\nRGZmJoCWJ61byemTICJyBe1HLhXr60zftw8KZw+F9mSFRGdzNBERuTJLI5dcLRTakxUSQMuMsDqd\nDgDwxBNPYOTIkYoVRUTUXZ3dFuoOVx25JIesjuvMzEy8+OKLyM/PR35+PpKSknDo0CGlayMi6pLW\n20LF+joYRdF0W0h3qaxH5x3i28/sdmcfuSSH7CGwBw4cgEajAQDo9XosWLAAMTExihZHRNQVSt0W\nigod0aZP4v+3O/fIJTlk325qDYj2XxMROQqlbgu56sglOWSFxLBhw/Dxxx9j7ty5EAQBu3fvRmBg\noNK1ERF1iZIPtLniyCU5ZPVJ/P73v8e1a9cQExODmJgYXL16FWvWrFG6NiKiLnHVB9rsSdaVxKBB\ng/Dhhx8qXQsRUY/05ttCSpEVElu3bkVCQgIGDhwIALh9+zb27duHl19+WdHiiIi6qrfeFlKKrNtN\nWVlZpoAAgPvvvx9HjhxRrCgiInIMskJCFMUO2wwGg9WLISIixyIrJEaMGIH/+q//giiKMBqN2LFj\nB4YNG6Z0bUREZGeyQiI1NRUnTpzAmDFj8Nhjj+HUqVNYvXq10rUREZGdyeq49vPzw86dO1Ff3/Kg\nSr9+5h9RJyIi1yK7T2LPnj345JNP0K9fPxQXF+O///u/la6NiIjsTFZIrFu3Dt9++y2ys7MBAJ6e\nnnjvvfcULYyIiOxP1u0mnU6HzMxMaLVaAC1DYO/evatoYUREligxJTh1JCskPDw8IAiC6Xuj0ahY\nQUREnZG7Uhz1nKzbTSEhITh06BBEUURxcTH+/d//HePGjVO6NiIisyxNCU7WJSskVq5cibNnz0Kv\n1yM+Ph4GgwErVqxQujYiIrN680pxtibrdpOXlxfWrl2rdC1ERLIoOSU4tSXrSuLo0aOora0FAKSl\npWHBggX45z//qWhhRES6S2VYvV2Hl9efwOrtOtMypJwS3HZkhcSf/vQneHl5IS8vD2fOnEFcXByv\nLIhIUZbWq570sB9ejRmNAI0X1CoBARovvBozmp3WCpB1u8nNreVlf/vb3xAfH4/o6Gjs2LFD0cKI\nqHfrbL1qTgluG7KuJARBwKFDh5CVlYXQ0FAAQFNTk6KFEVHvxs5pxyArJN5++20cO3YM8fHxCAwM\nxPXr1zFp0qROj0tJScHkyZMRHR1t2rZlyxaEhYVBq9VCq9Xi9OnT3a+eiFzWEF/zc8Sxc9q2BNHc\nYhFWcu7cOXh6eiI5ORmHDx8G0BISnp6eSEpK6vL59Poaa5foMDQab7bPSbly2wD7ta/9A3OtrN33\n0Bs+v56w2Cfxl7/8BYmJidiwYYPZ/cnJyRZPPn78eJSUlHTYrmAuEZGD6uo0Glyv2jFYDAkPDw8A\n1p8aPCMjAwcPHsQjjzyClStXwtu7Z0lHRI6tu9NosHPa/hS93QQAJSUlWLRokel2U2VlJe6//34I\ngoAPP/wQer2eM8oSubjF75/A9ZvVHbaPGNwfm5f/0g4VkVwWryQyMjIsHvzcc891+Q19fHxMXyck\nJGDRokWyj3X1+4Zsn3Ny5bYB1mlfUan542+U1dj9Z9cbPr+esBgS1niquv2Fil6vh0ajAQAcP34c\nISEhPX4PInJsnEbDeVkMiXXr1vXo5G+99RZ0Oh2qqqowbdo0LF68GDqdDvn5+VCpVBg6dCjWrFnT\no/cgIscXFTrC7EglTqPh+CyGxKlTpywePHXqVIv7N23a1GHbnDlzZJRFRK6EI5Wcl8WQ+POf/yy5\nTxCETkOCiKgVRyo5J4sh8dlnn9mqDiIickAWQ+LGjRsIDAxEYWGh2f3BwcGKFEVERI7BYkisXbsW\n6enpWLhwYYd9giAgJydHscKIiMj+LIZEeno6AODrr7+2STFERORYZK0nAQANDQ0oLS2FwWAwbePt\nJiLH1NV5koikyAqJnTt34sMPP8SAAQOgUrXMLs7bTUSOqbvzJBGZIysk/vKXv+DYsWPw8+MvGJGj\n62xFN6KukLXokL+/PwOCyElIrehWrK/F6u066C6V2bgicmayriQWL16M1NRUTJ061TR9OND5E9dE\nJM3a/Qat5zNamNiZt56oq2SFxIkTJ3DixAlcv369TZ8EQ4Koe6zdbyC1ipsU3noiuWSFxPHjx/H1\n11/jvvvuU7oeol6hJ/0G5q5ApM4n5WZFxxlZicyRFRKBgYFwc5M9WpaIOiHVb9DZH2+pKxBB6Nr7\nc4pukkvWX/7hw4cjMTERERERcHd3N23vzqJDRNT5+gpS/RVSVwxuKhWaDMYO2328PVBZc7fDdk7R\nTXLJCommpiYMGzYMly9fVroeol7B0voKlvorpK5Amo0dAwIA4n/Z8sArp+im7pIVEj1dfIiI2rK0\nvsLq7Tqzx6Qf+h591AKMho77hvp6ISp0uGQYMBSouzpdvvSRRx6R3N/Y2IgbN25g5MiRVi+MyNVJ\nra8gdbUAAE0G88NbWwOBYUDW1ukEfw0NDZg1axbGjh0LX19f3L17F9euXcM333yDU6dOYeXKlQwJ\nIiuS6q+4Vx+1CkZR5O0jUpzFkNi8eTPy8vLw17/+Ff/xH/+B0tJS9O3bFyEhIYiIiEBGRga8vLxs\nVStRryDVX3EvoyhiW/IvbVQR9Wad9kmMGTMGY8aMsUUtRISW21B7ThSaHZXUikNYyVZkzd1ERLbV\nOipJCoewkq3wCTkiO+hs3qZJD/uhsOQOcs4Xdzh2+rgA9kGQzTAkiGxM7rxNz/0qBMFDB/AZB7Ir\nhgSRjXVl3iYOayV7k9UnUVFRgeXLl5um4SgoKMAXX3yhaGFErqq78zYR2YOskHj77bcxbtw4VFdX\nAwCCgoLw+eefK1oYkasa4tvP7HaOWCJHJCskysrKMG/ePKjVagCAu7u7aV0JIuqaqNAREts5Yokc\nj6w+ifbThFdXV0O0sPoVUW/TlVXmLM3bRORoZIXEjBkzsHr1atTV1WH//v34/PPPMWfOnE6PS0lJ\nwcmTJzFo0CAcPnwYAHDnzh0sXboUJSUlCAgIwEcffQRvb++etYLIjk7/o7jLq8yxQ5qchax7Ri+/\n/DLGjx+P0aNH49SpU3jhhReQmJjY6XGzZ8/G9u3b22zbunUrQkND8dVXX2HSpElIT0/vXuVEDmJP\nzv+a3Z6V+6ONKyGyPtlDYGNiYhATE9Olk48fPx4lJSVttuXk5GDXrl0AAK1WixdeeAHLly/v0nmJ\nHElRWY3Z7RytRK5AVkhUVFRg165dKCoqQnNzs2l7Wlpal9+wsrISvr6+AACNRoPbt293+RxE1tCV\nfgRLrx3m543rN6s7HMPRSuQKZIXEv/7rv+Lhhx9GaGioaYSTPWg0rt13wfbZjlQ/wtbD32O4f3/E\nT/8XhD0eYPG1/fvfh7DHAxA//V+wcdf5Du8xL/Ihh2pzT7hKO6S4evt6QlZINDQ04J133rHKGw4a\nNAjl5eXw9fWFXq+Hj4+P7GP1evOX9a5Ao/Fm+2zoi68KzG4XReD6zWps3HUe2w/+E/G/DJZ8QvqL\nr37AqIABCHs8ANXVP3cYrTQqYIBDtbm7HO2zs7be0L6ekBUSY8eOxQ8//ICHHnqoy2/QfqhseHg4\n9u/fj4ULF+LAgQOYPn16l89J1FOWVn9rVVlzF+mHvocgmN9/b58DRyuRq5IVEs888wyef/55+Pv7\nw8PDw7R97969Fo976623oNPpUFVVhWnTpmHx4sVYuHAh3njjDezbtw9DhgzpVr8GUU/JWf2tlZtK\nhSaDscN29jlQbyArJFasWIFFixbh4Ycf7lKfxKZNm8xu//TTT2Wfg8iaWjugS8rljzxqNnYMCIBP\nSFPvICskPDw8sGDBAqVrIeoWuaOU2k/RLddQXy9EhQ7nE9LUK8kKiaeeegqnT59GWFiY0vUQdYnc\ntRkA6Sm6AzQtIbDnZCEqqzsuGdoaCAwF6o1khcTu3buxdetWeHp6wt3dHaIoQhAE5ObmKl0fkUWW\n1mZo3d96hSF1i+lmRZ0pBFquSnjFQNRKVkjs27dP6TqIukVqlFJJeW2HKwwp93ZA84qBqC1ZITF0\n6FCl6yDqFqlRSlIjksxhBzSRNIshsWLFCmzcuBFz5syBYGaweGdDYImUcG9H9UAvd7OvkRqRJAgt\nHdG8nUQkj8WQaJ3p9Xe/+51NiiEyp30oVNb8f+dy69c+3h64U9do+sOflXvd7BXGUF8vrFkw0UaV\nEzk/iyHx+eef47333sPEifxHRfbRfvTSvQFxr3739cH7v53SZpu54a68tUTUNRZDIj8/31Z1EJm1\n50ShrNe1n5abq78RWYfs9SSIlNT+gbiHht2PH4puS145tGduigyOVCLqOYshcfnyZYSGhnbYzuck\nyJrMPRAnd16lVryNRKQMiyExYsQIbN261Va1UC8l9UCcHH3UKrwUNYpXDEQKsRgS7u7ufEaCFCdn\n2m4pDAgiZaks7ezTp4+t6qBebIhvP9mv7aNWQSW0zLf0asxoBgSRwixeSezevdtWdVAvFhU6Qtbs\nrAwFItvj6Cayu9Y//FKzsPr090D8tGAGBJEdMCTI6syt7zBrquV1djkLK5FjEsT2i1A7MFdfrNzZ\n2mcuDADzTzq3zJnkKbkgkDNzxs+uK9g+56bRWP4PWmd4JUHdIrXYj4+3h9nXi6LlBYGIyDFZHN1E\nJEXq2QY5T0i3LghERI6PIUHd0pNnG9rPs0REjou3m0iW9v0P/e5zQ21DU7fOZW6eJSJyTAwJ6pS5\n/oee4DxLRM6DIUGd6sncSq1UAjDE14tDWomcDEOCOtWd/of2K8XNmhrs0sMMiVwVQ4I6NcS3n+xb\nTHw6msi1MCSoU3LmVgrQ8FYSkStiSFCnTHMrnSg0+xwEJ94jcl12C4nw8HB4eXlBpVLBzc0Ne/fu\ntVcpJAPnViLqnewWEoIg4LPPPsOAAQPsVUKvZG6+pa78kee60US9i91CQhRFGI1Ge719ryQ13xLA\nuZSIyDy7XkksWLAAgiBg7ty5SEhIsFcpLkfqakHqeYes3B8ZEkRklt1C4ssvv4RGo0FlZSWSkpIQ\nFBSE8ePH26sclyF1tVBYckfyeQfOpUREUhxiPYktW7bA09MTSUlJ9i7F6S1+/wSu36w2u893YF+U\nVzV02D5icH9sXv5LpUsjIidklyuJhoYGGI1GeHp6or6+HmfOnMHrr7/e6XGu/MSutRY+KSqVPodU\nH1DkhEDFf7auvLCLK7cNYPucnVMuOlReXo7XX38dgiDAYDAgOjoaTz75pD1KcTmWno6+U9uIV2NG\ncwgrEclml5AIDAzEwYMH7fHWLs/S09GDB3lyCCsRdQkXHXIxkx72w/RxAWb3cYpuIuoqTsvhgp77\nVQiChw7gbSUi6jGGhIvibSUisgbebiIiIkkMCSIiksSQICIiSeyTsKGezsBKRGRrDAkbyTh+GTnn\ni03fcwZWInIGDAmF6S6VSa7oBnAGViJybAwJBbWfkdUczsBKRI6MIWEl5vobpNZvuNfgQZ5Kl0ZE\n1G0MCSuQWsNBEDo/llNlEJEj4xBYK5C6YnBTWf7xTh8XwP4IInJovJKwAqkV35ol1m/w8fZA/C+D\nGRBE5PAYElYgtYbDUF8vRIUO50R7ROS0GBKdkPMAnNQaDq2BwFAgImfFkLBAqkMaaPsAXOvXvGIg\nIlfDkGjn3isHtUS/s7kH4HjFQESuiCFxj/ZXDkaD+dfxATgi6i04BPYech5+A/gAHBH1HgyJe0gN\nZW2PD8ARUW/B2033kBrK2ketglEU2SFNRL0OQ+IeUkNZX4oaxWAgol6JIXEPDmUlImqLIdEOh7IS\nEf0/dlwTEZEkp76S4JrRRETKctqQkDtlBhERdZ/dbjedPn0av/71rxEZGYmtW7d2+XipB9+ycn/s\nWWFERGRil5AwGo149913sX37dhw5cgRZWVm4cuVKl84h9eAbp8wgIrIeu4REXl4ehg8fjqFDh6JP\nnz6IiopCTk6OxWNiVxzC6u066C6VAWh58M0cTplBRGQ9dgmJsrIyDB482PS9n58fbt26ZfEYo1E0\n9TvoLpUhKnSE2ddxygwiIuuxS0iIotij41un6n41ZjQCNF5QqwQEaLzwasxodloTEVmRXUY3+fv7\n46effjJ9X1ZWhgceeED28Tcr6qDReGPWVG/MmhqsRIl2odF427sERbly+1y5bQDb15vZJSQeffRR\nFBUVoaSkBBqNBllZWfjggw8sHnN4U6yNqiMiolZ2CQm1Wo1/+7d/w0svvQRRFPH0009j5MiR9iiF\niIgsEMSedhAQEZHL4txNREQkiSFBRESSGBJERCTJ4UOip3M8OaLw8HDExMQgLi4OTz/9NADgzp07\neOmllxAZGYkFCxagpqbGzlXKl5KSgsmTJyM6Otq0zVJ71q5dixkzZiA2Nhb5+fn2KLlLzLVvy5Yt\nCAsLg1arhVarxenTp0370tPTMWPGDPzmN7/BmTNn7FGybKWlpZg/fz5mzpyJ6Oho7Ny5E4DrfH7t\n2/fZZ58BcJ3Pr7GxEfHx8YiLi0N0dDS2bNkCACguLkZCQgIiIyOxbNkyNDc3m16/dOlSzJgxA3Pn\nzm3zKIIk0YEZDAYxIiJCLC4uFhsbG8WYmBixsLDQ3mX1WHh4uFhVVdVm24YNG8StW7eKoiiK6enp\n4saNG+1RWrd899134qVLl8RZs2aZtkm15+TJk+Irr7wiiqIoXrhwQYyPj7d9wV1krn2bN28Wd+zY\n0eG1hYWFYmxsrNjU1CTeuHFDjIiIEI1Goy3L7ZJbt26Jly5dEkVRFGtra8UZM2aIhYWFLvP5SbXP\nVT4/URTF+vp6URRFsbm5WYyPjxcvXLggvvHGG+LRo0dFURTF1atXi1988YUoiqKYkZEhvvPOO6Io\nimJWVpb45ptvdnp+h76S6M4cT85AFEUYjcY223JycqDVagEAWq0W2dnZ9iitW8aPH4/+/fu32da+\nPa2fW04mDIcbAAAJLUlEQVRODuLi4gAAY8eORU1NDcrLy21bcBeZax9gfuaAnJwczJw5E25ubggI\nCMDw4cORl5dnizK7RaPRYNSoUQAAT09PjBw5EmVlZS7z+ZlrX+sUQK7w+QFA3759AbRcJTQ3N0MQ\nBOh0OkRGRgJo+/fk3s81MjISubm5nZ7foUOiO3M8OQNBELBgwQLMmTMHe/bsAQBUVFTA19cXQMsv\n9u3bt+1ZYo9VVla2aU9lZSUA4NatW/D39ze9zs/PD2VlZXapsacyMjIQGxuL1NRU0+0Yc7+zztK+\n4uJiFBQUYOzYsR1+H13h82tt35gxYwC4zudnNBoRFxeHKVOmYMqUKQgMDET//v2hUrX8eff39ze1\n4d7PT61Wo3///qiqqrJ4focOCXNJ7wq+/PJL7N+/H9u2bUNGRgbOnTsHQRDsXZZNmPtMnbHtzz77\nLLKzs3Hw4EH4+vrij3/8IwDnbV9dXR2WLFmClJQUeHp6StbsKu1zpc9PpVIhMzMTp0+fRl5entll\nF1rb0L59oih22j6HDomezvHkqDQaDQDAx8cHERERyMvLw6BBg0yX7Xq9Hj4+PvYsscek2uPn54fS\n0lLT60pLS53yM/Xx8TH940pISDDdkvD398fNmzdNr3OG9jU3N2PJkiWIjY1FREQEANf6/My1z5U+\nv1ZeXl6YMGECLl68iOrqatMt7XvbcO/nZzAYUFtbiwEDBlg8r0OHxL1zPDU2NiIrKwvTp0+3d1k9\n0tDQgLq6loWR6uvrcebMGYSEhCA8PBz79+8HABw4cMDp2tn+fyhS7Zk+fToyMzMBABcuXED//v1N\ntzUcWfv26fV609fHjx9HSEgIgJZ2Hz16FI2Njbhx4waKiopMtzccVUpKCoKDg5GYmGja5kqfn7n2\nucrnV1lZabpV9vPPPyM3NxfBwcGYNGkSjh07BqDt5xceHo4DBw4AAI4dO4Ynnnii0/dw+Gk5Tp8+\njT/84Q+mOZ4WLlxo75J65MaNG3j99dchCAIMBgOio6OxcOFCVFVV4c0338TNmzcxZMgQpKWlme0s\ndURvvfUWdDodqqqq4Ovri8WLFyMiIgJvvPGG2fasWbMG33zzDfr27Yt169Zh9OjRdm6BZebap9Pp\nkJ+fD5VKhaFDh2LNmjWmP5bp6enYu3cv3NzckJqaiieffNLOLZB2/vx5PP/88wgJCYEgCBAEAUuX\nLsWYMWMkfx+d6fOTat+RI0dc4vP74YcfsHLlShiNRhiNRsycOROvvfYabty4gWXLlqG6uhqjRo3C\nxo0b0adPHzQ2NmLFihXIz8/HwIED8cEHHyAgIMDiezh8SBARkf049O0mIiKyL4YEERFJYkgQEZEk\nhgQREUliSBARkSSGBBERSWJIkMNrbm5GWloaIiMjER0djaioKKxfvx4Gg8HicatWrUJGRgaAlqmh\nN2zY0Ol7ZWdn43/+53+sUrcSSkpKsHv3bnuXQb0IQ4Ic3sqVK3HlyhVkZmbi8OHDOHToEIKCgtDY\n2Gj198rJyXHoWT+Li4vx17/+tVvHdhaqROa42bsAIkt+/PFH5OTkmJ7wBVpmr4yPjwfQMgPmxo0b\nTYvDPPnkk0hOTrY4adnly5fx+9//Hg0NDWhsbERCQgLmz5+PM2fO4Ouvv0Zubi727t2LF198EcXF\nxTh+/DgEQUBjYyOuXr2K7777Dl5eXm3O+Y9//AMbN25EXV0dBEFAcnIyJk+ejLy8PLz33ntoaGhA\n3759kZqaikcffRRnz57F+vXrsW/fPgBo8/3Zs2fx3nvvYcyYMbhw4QJUKhU++OADBAUF4d1330VJ\nSQm0Wi2GDRuGtLQ0XL16FevWrUNVVRWampowf/58zJ49GwDwi1/8AitWrMDJkycxYcIELFmyxOqf\nEbk4q6x6QaSQo0ePinFxcZL7P//8czEpKUlsbm4Wm5qaxMTERNMCKytXrhR37dolimLLIkHr168X\nRVEU6+rqxMbGRtPXM2fOFK9cudLhmPZWrFgh/vGPf+ywvaqqSpwyZYp44cIFURRF0Wg0itXV1WJj\nY6M4bdo0MTc3VxRFUfz73/8uTps2TWxqahJ1Op04Z84c0znu/V6n04mjR48W8/PzRVEUxT/96U/i\n8uXLO7xOFFsWmtFqteLVq1dFUWxZWCcyMtL0/UMPPST++c9/lvz5EXWGVxLk0MROZo3Jzc2FVquF\nWq0GAMyePRvZ2dl45plnJI9paGjAO++8g4KCAqhUKuj1ehQUFCAoKEjymI8++ggNDQ343e9+12Hf\nhQsXEBwcjLFjxwJomZbZ29sbly9fhru7u2kStdDQULi7u+PatWudtvvBBx/EL37xCwAti/ucPHnS\n7OuuX7+Oq1evYtmyZaafVVNTE65cuYIHH3wQAEyLBBF1B0OCHNro0aNx/fp11NTUwNvbu8N+0cx8\n+J3Nj//BBx9Ao9Fgw4YNpgWgLPVv7Nu3D99++61p/WdzNcjd3lqvWq1uszrh3bt327zOw8PD9LVa\nrTatUWzufD4+PqaZPdsTBAH9+vUzu49IDnZck0MbPnw4wsPDsXr1atMU6waDATt37kRDQwMmT56M\nAwcOoLm5GU1NTcjMzMSUKVMsnrOmpgaDBw+GIAi4fPkyzp07Z9rn6emJ2tpa0/d///vfsW3bNnzy\nySdwd3c3e77HH38chYWFuHjxIoCWfpLq6moEBQWhqakJZ8+eBQB8++23aG5uxogRIxAQEIDi4mLU\n1NRAFEVkZWXJ+nl4eXmZpoYGWq447rvvPhw8eNC07erVq6afVWdXYkSd4ZUEObz169dj8+bNmD17\nNtzd3SGKIsLCwuDu7o65c+eiqKjItG7vU089ZerUlvLaa68hOTkZhw4dwrBhwzBhwgTTvtjYWKxa\ntQrHjh3Diy++iH379qGhoQELFiwwXQVkZGS0+d/5gAEDsGXLFqxbtw719fVQq9VITk5GaGgoPv74\nY6xdu9bUcb1582a4ubnBz88PSUlJ0Gq1CAwMxKOPPorCwsJOfxYPPfQQHnzwQURHRyMoKAhpaWn4\nz//8T/zhD3/Ajh07YDAY4Ovri48++giA46+qRo6PU4UTEZEk3m4iIiJJDAkiIpLEkCAiIkkMCSIi\nksSQICIiSQwJIiKSxJAgIiJJDAkiIpL0f3zF2/hGE4QYAAAAAElFTkSuQmCC\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7fc3af690a50\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "with context.eager_mode():\n", + "\n", + " counts = []\n", + " times = [] \n", + " for n in np.logspace(0, 7, 50):\n", + "\n", + " n_tensor = tf.constant(n, dtype=tf.float32)\n", + " count = collatz(n_tensor)\n", + "\n", + " res = %timeit -n10 -r1 -o -q collatz(n_tensor)\n", + " times.append(res.best)\n", + " counts.append(count)\n", + " \n", + "plot_results(counts, times, 'Eager')\n" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "x5ChBlH09jk_", + "_cRFTcwT9mnn" + ], + "default_view": {}, + "last_runtime": { + "build_target": "", + "kind": "local" + }, + "name": "Autograph vs. Eager Collatz speed test", + "provenance": [ + { + "file_id": "0B8bm7KvwJklpMUQtbnVpYkdJUjRtOTRyWVVfSEhpRl9HYm5n", + "timestamp": 1531512047714 + } + ], + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_mnist_speed_test.ipynb b/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_mnist_speed_test.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..952ec091fb1883e4f17314efa8c458bfe7f01eda --- /dev/null +++ b/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_mnist_speed_test.ipynb @@ -0,0 +1,652 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "etTmZVFN8fYO" + }, + "source": [ + "This notebook runs a basic speed test for a short training loop of a neural network training on the MNIST dataset." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "eqOvRhOz8SWs" + }, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "nHY0tntRizGb" + }, + "outputs": [], + "source": [ + "!pip install -U -q tf-nightly" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "Pa2qpEmoVOGe" + }, + "outputs": [], + "source": [ + "import gzip\n", + "import os\n", + "import shutil\n", + "import time\n", + "\n", + "import numpy as np\n", + "import six\n", + "from six.moves import urllib\n", + "import tensorflow as tf\n", + "\n", + "from tensorflow.contrib import autograph as ag\n", + "from tensorflow.contrib.eager.python import tfe\n", + "from tensorflow.python.eager import context\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PZWxEJFM9A7b" + }, + "source": [ + "### Testing boilerplate" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "kfZk9EFZ5TeQ" + }, + "outputs": [], + "source": [ + "# Test-only parameters. Test checks successful completion not correctness. \n", + "burn_ins = 1\n", + "trials = 1\n", + "max_steps = 2\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "k0GKbZBJ9Gt9" + }, + "source": [ + "### Speed test configuration" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "gWXV8WHn43iZ" + }, + "outputs": [], + "source": [ + "#@test {\"skip\": true} \n", + "burn_ins = 3\n", + "trials = 10\n", + "max_steps = 500\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "kZV_3pGy8033" + }, + "source": [ + "### Data source setup" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "YfnHJbBOBKae" + }, + "outputs": [], + "source": [ + "def download(directory, filename):\n", + " filepath = os.path.join(directory, filename)\n", + " if tf.gfile.Exists(filepath):\n", + " return filepath\n", + " if not tf.gfile.Exists(directory):\n", + " tf.gfile.MakeDirs(directory)\n", + " url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'\n", + " zipped_filepath = filepath + '.gz'\n", + " print('Downloading %s to %s' % (url, zipped_filepath))\n", + " urllib.request.urlretrieve(url, zipped_filepath)\n", + " with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:\n", + " shutil.copyfileobj(f_in, f_out)\n", + " os.remove(zipped_filepath)\n", + " return filepath\n", + "\n", + "\n", + "def dataset(directory, images_file, labels_file):\n", + " images_file = download(directory, images_file)\n", + " labels_file = download(directory, labels_file)\n", + "\n", + " def decode_image(image):\n", + " # Normalize from [0, 255] to [0.0, 1.0]\n", + " image = tf.decode_raw(image, tf.uint8)\n", + " image = tf.cast(image, tf.float32)\n", + " image = tf.reshape(image, [784])\n", + " return image / 255.0\n", + "\n", + " def decode_label(label):\n", + " label = tf.decode_raw(label, tf.uint8)\n", + " label = tf.reshape(label, [])\n", + " return tf.to_int32(label)\n", + "\n", + " images = tf.data.FixedLengthRecordDataset(\n", + " images_file, 28 * 28, header_bytes=16).map(decode_image)\n", + " labels = tf.data.FixedLengthRecordDataset(\n", + " labels_file, 1, header_bytes=8).map(decode_label)\n", + " return tf.data.Dataset.zip((images, labels))\n", + "\n", + "\n", + "def mnist_train(directory):\n", + " return dataset(directory, 'train-images-idx3-ubyte',\n", + " 'train-labels-idx1-ubyte')\n", + "\n", + "def mnist_test(directory):\n", + " return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')\n", + "\n", + "def setup_mnist_data(is_training, hp, batch_size):\n", + " if is_training:\n", + " ds = mnist_train('/tmp/autograph_mnist_data')\n", + " ds = ds.cache()\n", + " ds = ds.shuffle(batch_size * 10)\n", + " else:\n", + " ds = mnist_test('/tmp/autograph_mnist_data')\n", + " ds = ds.cache()\n", + " ds = ds.repeat()\n", + " ds = ds.batch(batch_size)\n", + " return ds\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "qzkZyZcS9THu" + }, + "source": [ + "### Keras model definition" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "x_MU13boiok2" + }, + "outputs": [], + "source": [ + "def mlp_model(input_shape):\n", + " model = tf.keras.Sequential((\n", + " tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),\n", + " tf.keras.layers.Dense(100, activation='relu'),\n", + " tf.keras.layers.Dense(10, activation='softmax')))\n", + " model.build()\n", + " return model\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "DXt4GoTxtvn2" + }, + "source": [ + "# AutoGraph" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "W51sfbONiz_5" + }, + "outputs": [], + "source": [ + "def predict(m, x, y):\n", + " y_p = m(x)\n", + " losses = tf.keras.losses.categorical_crossentropy(y, y_p)\n", + " l = tf.reduce_mean(losses)\n", + " accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n", + " accuracy = tf.reduce_mean(accuracies)\n", + " return l, accuracy\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "CsAD0ajbi9iZ" + }, + "outputs": [], + "source": [ + "def fit(m, x, y, opt):\n", + " l, accuracy = predict(m, x, y)\n", + " opt.minimize(l)\n", + " return l, accuracy\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "RVw57HdTjPzi" + }, + "outputs": [], + "source": [ + "def get_next_batch(ds):\n", + " itr = ds.make_one_shot_iterator()\n", + " image, label = itr.get_next()\n", + " x = tf.to_float(tf.reshape(image, (-1, 28 * 28)))\n", + " y = tf.one_hot(tf.squeeze(label), 10)\n", + " return x, y\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "UUI0566FjZPx" + }, + "outputs": [], + "source": [ + "def train(train_ds, test_ds, hp):\n", + " m = mlp_model((28 * 28,))\n", + " opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n", + "\n", + " train_losses = []\n", + " test_losses = []\n", + " train_accuracies = []\n", + " test_accuracies = []\n", + " ag.set_element_type(train_losses, tf.float32)\n", + " ag.set_element_type(test_losses, tf.float32)\n", + " ag.set_element_type(train_accuracies, tf.float32)\n", + " ag.set_element_type(test_accuracies, tf.float32)\n", + "\n", + " i = tf.constant(0)\n", + " while i \u003c hp.max_steps:\n", + " train_x, train_y = get_next_batch(train_ds)\n", + " test_x, test_y = get_next_batch(test_ds)\n", + " step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)\n", + " step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n", + "\n", + " train_losses.append(step_train_loss)\n", + " test_losses.append(step_test_loss)\n", + " train_accuracies.append(step_train_accuracy)\n", + " test_accuracies.append(step_test_accuracy)\n", + "\n", + " i += 1\n", + " return (ag.stack(train_losses), ag.stack(test_losses),\n", + " ag.stack(train_accuracies), ag.stack(test_accuracies))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 215 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 12156, + "status": "ok", + "timestamp": 1531752050611, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "K1m8TwOKjdNd", + "outputId": "bd5746f2-bf91-44aa-9eff-38eb11ced33f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "('Duration:', 0.6226680278778076)\n", + "('Duration:', 0.6082069873809814)\n", + "('Duration:', 0.6223258972167969)\n", + "('Duration:', 0.6176440715789795)\n", + "('Duration:', 0.6309840679168701)\n", + "('Duration:', 0.6180410385131836)\n", + "('Duration:', 0.6219630241394043)\n", + "('Duration:', 0.6183009147644043)\n", + "('Duration:', 0.6176400184631348)\n", + "('Duration:', 0.6476900577545166)\n", + "('Mean duration:', 0.62254641056060789, '+/-', 0.0099792188690656976)\n" + ] + } + ], + "source": [ + "#@test {\"timeout\": 90}\n", + "with tf.Graph().as_default():\n", + " hp = tf.contrib.training.HParams(\n", + " learning_rate=0.05,\n", + " max_steps=max_steps,\n", + " )\n", + " train_ds = setup_mnist_data(True, hp, 500)\n", + " test_ds = setup_mnist_data(False, hp, 100)\n", + " tf_train = ag.to_graph(train)\n", + " losses = tf_train(train_ds, test_ds, hp)\n", + "\n", + " with tf.Session() as sess:\n", + " durations = []\n", + " for t in range(burn_ins + trials):\n", + " sess.run(tf.global_variables_initializer())\n", + "\n", + " start = time.time()\n", + " (train_losses, test_losses, train_accuracies,\n", + " test_accuracies) = sess.run(losses)\n", + "\n", + " if t \u003c burn_ins:\n", + " continue\n", + "\n", + " duration = time.time() - start\n", + " durations.append(duration)\n", + " print('Duration:', duration)\n", + "\n", + " print('Mean duration:', np.mean(durations), '+/-', np.std(durations))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "A06kdgtZtlce" + }, + "source": [ + "# Eager" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "hBKOKGrWty4e" + }, + "outputs": [], + "source": [ + "def predict(m, x, y):\n", + " y_p = m(x)\n", + " losses = tf.keras.losses.categorical_crossentropy(tf.cast(y, tf.float32), y_p)\n", + " l = tf.reduce_mean(losses)\n", + " accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n", + " accuracy = tf.reduce_mean(accuracies)\n", + " return l, accuracy\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "HCgTZ0MTt6vt" + }, + "outputs": [], + "source": [ + "def train(ds, hp):\n", + " m = mlp_model((28 * 28,))\n", + " opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n", + "\n", + " train_losses = []\n", + " test_losses = []\n", + " train_accuracies = []\n", + " test_accuracies = []\n", + "\n", + " i = 0\n", + " train_test_itr = tfe.Iterator(ds)\n", + " for (train_x, train_y), (test_x, test_y) in train_test_itr:\n", + " train_x = tf.to_float(tf.reshape(train_x, (-1, 28 * 28)))\n", + " train_y = tf.one_hot(tf.squeeze(train_y), 10)\n", + " test_x = tf.to_float(tf.reshape(test_x, (-1, 28 * 28)))\n", + " test_y = tf.one_hot(tf.squeeze(test_y), 10)\n", + "\n", + " if i \u003e hp.max_steps:\n", + " break\n", + "\n", + " with tf.GradientTape() as tape:\n", + " step_train_loss, step_train_accuracy = predict(m, train_x, train_y)\n", + " grad = tape.gradient(step_train_loss, m.variables)\n", + " opt.apply_gradients(zip(grad, m.variables))\n", + " step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n", + "\n", + " train_losses.append(step_train_loss)\n", + " test_losses.append(step_test_loss)\n", + " train_accuracies.append(step_train_accuracy)\n", + " test_accuracies.append(step_test_accuracy)\n", + "\n", + " i += 1\n", + " return train_losses, test_losses, train_accuracies, test_accuracies\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 215 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 52499, + "status": "ok", + "timestamp": 1531752103279, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "plv_yrn_t8Dy", + "outputId": "55d5ab3d-252d-48ba-8fb4-20ec3c3e6d00" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "('Duration:', 3.9973549842834473)\n", + "('Duration:', 4.018772125244141)\n", + "('Duration:', 3.9740989208221436)\n", + "('Duration:', 3.9922947883605957)\n", + "('Duration:', 3.9795801639556885)\n", + "('Duration:', 3.966722011566162)\n", + "('Duration:', 3.986541986465454)\n", + "('Duration:', 3.992305040359497)\n", + "('Duration:', 4.012261867523193)\n", + "('Duration:', 4.004716157913208)\n", + "('Mean duration:', 3.9924648046493529, '+/-', 0.015681688635624851)\n" + ] + } + ], + "source": [ + "#@test {\"timeout\": 90}\n", + "with context.eager_mode():\n", + " durations = []\n", + " for t in range(burn_ins + trials):\n", + " hp = tf.contrib.training.HParams(\n", + " learning_rate=0.05,\n", + " max_steps=max_steps,\n", + " )\n", + " train_ds = setup_mnist_data(True, hp, 500)\n", + " test_ds = setup_mnist_data(False, hp, 100)\n", + " ds = tf.data.Dataset.zip((train_ds, test_ds))\n", + " start = time.time()\n", + " (train_losses, test_losses, train_accuracies,\n", + " test_accuracies) = train(ds, hp)\n", + " \n", + " train_losses[-1].numpy()\n", + " test_losses[-1].numpy()\n", + " train_accuracies[-1].numpy()\n", + " test_accuracies[-1].numpy()\n", + "\n", + " if t \u003c burn_ins:\n", + " continue\n", + "\n", + " duration = time.time() - start\n", + " durations.append(duration)\n", + " print('Duration:', duration)\n", + "\n", + " print('Mean duration:', np.mean(durations), '+/-', np.std(durations))\n" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "eqOvRhOz8SWs", + "PZWxEJFM9A7b", + "kZV_3pGy8033" + ], + "default_view": {}, + "name": "Autograph vs. Eager MNIST speed test", + "provenance": [ + { + "file_id": "1tAQW5tHUgAc8M4-iwwJm6Xs6dV9nEqtD", + "timestamp": 1530297010607 + }, + { + "file_id": "18dCjshrmHiPTIe1CNsL8tnpdGkuXgpM9", + "timestamp": 1530289467317 + }, + { + "file_id": "1DcfimonWU11tmyivKBGVrbpAl3BIOaRG", + "timestamp": 1522272821237 + }, + { + "file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K", + "timestamp": 1522238054357 + }, + { + "file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ", + "timestamp": 1521743157199 + }, + { + "file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-", + "timestamp": 1520522344607 + } + ], + "version": "0.3.2", + "views": {} + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/autograph/examples/notebooks/algorithms.ipynb b/tensorflow/contrib/autograph/examples/notebooks/algorithms.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..bf824e2760e694ae3c00c9f08d9aa5d5522a9b84 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/notebooks/algorithms.ipynb @@ -0,0 +1,1512 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "b9R-4ezU3NH0" + }, + "source": [ + "## AutoGraph: examples of simple algorithms\n", + "\n", + "This notebook shows how you can use AutoGraph to compile simple algorithms and run them in TensorFlow.\n", + "\n", + "It requires the nightly build of TensorFlow, which is installed below." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "TuWj26KWz1fZ" + }, + "outputs": [], + "source": [ + "!pip install -U -q tf-nightly" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "3kudk1elq0Gh" + }, + "source": [ + "### Fibonacci numbers\n", + "\n", + "https://en.wikipedia.org/wiki/Fibonacci_number" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 197 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 7512, + "status": "ok", + "timestamp": 1532101577266, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "H7olFlMXqrHe", + "outputId": "472dbfe0-9449-4f93-e908-1a0785188a92" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 : 1\n", + "1 : 2\n", + "2 : 3\n", + "3 : 5\n", + "4 : 8\n", + "5 : 13\n", + "6 : 21\n", + "7 : 34\n", + "8 : 55\n", + "9 : 89\n" + ] + } + ], + "source": [ + "import tensorflow as tf\n", + "from tensorflow.contrib import autograph as ag\n", + "\n", + "\n", + "def fib(n):\n", + " f1 = 0\n", + " f2 = 1\n", + " for i in range(n):\n", + " tmp = f2\n", + " f2 = f2 + f1\n", + " f1 = tmp\n", + " print(i, ': ', f2)\n", + " return f2\n", + "\n", + "\n", + "with tf.Graph().as_default():\n", + " final_fib = ag.to_graph(fib)(tf.constant(10))\n", + " with tf.Session() as sess:\n", + " sess.run(final_fib)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "p8zZyj-tq4K3" + }, + "source": [ + "#### Generated code" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 541 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 103, + "status": "ok", + "timestamp": 1532101577412, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "UeWjK8rHq6Cj", + "outputId": "73ece895-12fb-489a-e52c-032945d7ed7a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "from __future__ import print_function\n", + "import tensorflow as tf\n", + "\n", + "def tf__fib(n):\n", + " try:\n", + " with tf.name_scope('fib'):\n", + " f1 = 0\n", + " f2 = 1\n", + "\n", + " def extra_test(f1_1, f2_1):\n", + " with tf.name_scope('extra_test'):\n", + " return True\n", + "\n", + " def loop_body(i, f1_1, f2_1):\n", + " with tf.name_scope('loop_body'):\n", + " tmp = f2_1\n", + " f2_1 = f2_1 + f1_1\n", + " f1_1 = tmp\n", + " with ag__.utils.control_dependency_on_returns(ag__.utils.\n", + " dynamic_print(i, ': ', f2_1)):\n", + " f2, i_1 = ag__.utils.alias_tensors(f2_1, i)\n", + " return f1_1, f2\n", + " f1, f2 = ag__.for_stmt(ag__.utils.dynamic_builtin(range, n),\n", + " extra_test, loop_body, (f1, f2))\n", + " return f2\n", + " except:\n", + " ag__.rewrite_graph_construction_error(ag_source_map__)\n", + "\n" + ] + } + ], + "source": [ + "print(ag.to_code(fib))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "eIfVy6ZTrFEH" + }, + "source": [ + "### Fizz Buzz\n", + "\n", + "https://en.wikipedia.org/wiki/Fizz_buzz" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 125 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 233, + "status": "ok", + "timestamp": 1532101577681, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "33CAheYsrEQ7", + "outputId": "82a493ee-15b5-419d-8c9c-5f4159090a05" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Buzz\n", + "11\n", + "Fizz\n", + "13\n", + "14\n", + "FizzBuzz\n" + ] + } + ], + "source": [ + "import tensorflow as tf\n", + "from tensorflow.contrib import autograph as ag\n", + "\n", + "def fizzbuzz(i, n):\n", + " while i \u003c n:\n", + " msg = ''\n", + " if i % 3 == 0:\n", + " msg += 'Fizz'\n", + " if i % 5 == 0:\n", + " msg += 'Buzz'\n", + " if msg == '':\n", + " msg = tf.as_string(i)\n", + " print(msg)\n", + " i += 1\n", + " return i\n", + "\n", + "with tf.Graph().as_default():\n", + " final_i = ag.to_graph(fizzbuzz)(tf.constant(10), tf.constant(16))\n", + " with tf.Session() as sess:\n", + " sess.run(final_i)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Lkq3DBGOv3fA" + }, + "source": [ + "#### Generated code" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 1081 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 289, + "status": "ok", + "timestamp": 1532101578003, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "bBhFIIaZrxvx", + "outputId": "d076a7ea-e643-4689-f90a-57f5d086dedc" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "from __future__ import print_function\n", + "import tensorflow as tf\n", + "\n", + "def tf__fizzbuzz(i, n):\n", + " try:\n", + " with tf.name_scope('fizzbuzz'):\n", + "\n", + " def loop_test(i_1):\n", + " with tf.name_scope('loop_test'):\n", + " return tf.less(i_1, n)\n", + "\n", + " def loop_body(i_1):\n", + " with tf.name_scope('loop_body'):\n", + " msg = ''\n", + "\n", + " def if_true():\n", + " with tf.name_scope('if_true'):\n", + " msg_1, = msg,\n", + " msg_1 += 'Fizz'\n", + " return msg_1,\n", + "\n", + " def if_false():\n", + " with tf.name_scope('if_false'):\n", + " return msg,\n", + " msg = ag__.utils.run_cond(tf.equal(i_1 % 3, 0), if_true, if_false)\n", + "\n", + " def if_true_1():\n", + " with tf.name_scope('if_true_1'):\n", + " msg_2, = msg,\n", + " msg_2 += 'Buzz'\n", + " return msg_2,\n", + "\n", + " def if_false_1():\n", + " with tf.name_scope('if_false_1'):\n", + " return msg,\n", + " msg = ag__.utils.run_cond(tf.equal(i_1 % 5, 0), if_true_1, if_false_1\n", + " )\n", + "\n", + " def if_true_2():\n", + " with tf.name_scope('if_true_2'):\n", + " msg_3, = msg,\n", + " msg_3 = tf.as_string(i_1)\n", + " return msg_3,\n", + "\n", + " def if_false_2():\n", + " with tf.name_scope('if_false_2'):\n", + " return msg,\n", + " msg = ag__.utils.run_cond(tf.equal(msg, ''), if_true_2, if_false_2)\n", + " with ag__.utils.control_dependency_on_returns(ag__.utils.\n", + " dynamic_print(msg)):\n", + " msg_4 = ag__.utils.alias_tensors(msg)\n", + " i_1 += 1\n", + " return i_1,\n", + " i = ag__.while_stmt(loop_test, loop_body, (i,), (tf, n, ag__, i))\n", + " return i\n", + " except:\n", + " ag__.rewrite_graph_construction_error(ag_source_map__)\n", + "\n" + ] + } + ], + "source": [ + "print(ag.to_code(fizzbuzz))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BNRtprSvwJgk" + }, + "source": [ + "### Conway's Game of Life\n", + "\n", + "https://en.wikipedia.org/wiki/Conway%27s_Game_of_Life" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "r8_0ioEuAI-a" + }, + "source": [ + "#### Testing boilerplate" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "7moIlf8VABkl" + }, + "outputs": [], + "source": [ + "NUM_STEPS = 1" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "QlEvfIQPAYF5" + }, + "source": [ + "#### Game of Life for AutoGraph" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "5pCK2qQSAAK4" + }, + "outputs": [], + "source": [ + "#@test {\"skip\": true} \n", + "NUM_STEPS = 100" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 308 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 14892, + "status": "ok", + "timestamp": 1532101593030, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "hC3qMqryPDHS", + "outputId": "8405c0e9-e518-41d6-f5bc-e78df6474169" + }, + "outputs": [ + { + "data": { + "text/html": [ + "\u003cvideo width=\"432.0\" height=\"288.0\" controls autoplay loop\u003e\n", + " \u003csource type=\"video/mp4\" src=\"data:video/mp4;base64,AAAAHGZ0eXBNNFYgAAACAGlzb21pc28yYXZjMQAAAAhmcmVlAACZUm1kYXQAAAKuBgX//6rcRem9\n", + "5tlIt5Ys2CDZI+7veDI2NCAtIGNvcmUgMTQ4IHIyNzk1IGFhYTlhYTggLSBILjI2NC9NUEVHLTQg\n", + "QVZDIGNvZGVjIC0gQ29weWxlZnQgMjAwMy0yMDE3IC0gaHR0cDovL3d3dy52aWRlb2xhbi5vcmcv\n", + "eDI2NC5odG1sIC0gb3B0aW9uczogY2FiYWM9MSByZWY9MyBkZWJsb2NrPTE6MDowIGFuYWx5c2U9\n", + "MHgzOjB4MTEzIG1lPWhleCBzdWJtZT03IHBzeT0xIHBzeV9yZD0xLjAwOjAuMDAgbWl4ZWRfcmVm\n", + "PTEgbWVfcmFuZ2U9MTYgY2hyb21hX21lPTEgdHJlbGxpcz0xIDh4OGRjdD0xIGNxbT0wIGRlYWR6\n", + "b25lPTIxLDExIGZhc3RfcHNraXA9MSBjaHJvbWFfcXBfb2Zmc2V0PS0yIHRocmVhZHM9OSBsb29r\n", + "YWhlYWRfdGhyZWFkcz0xIHNsaWNlZF90aHJlYWRzPTAgbnI9MCBkZWNpbWF0ZT0xIGludGVybGFj\n", + "ZWQ9MCBibHVyYXlfY29tcGF0PTAgY29uc3RyYWluZWRfaW50cmE9MCBiZnJhbWVzPTMgYl9weXJh\n", + "bWlkPTIgYl9hZGFwdD0xIGJfYmlhcz0wIGRpcmVjdD0xIHdlaWdodGI9MSBvcGVuX2dvcD0wIHdl\n", + "aWdodHA9MiBrZXlpbnQ9MjUwIGtleWludF9taW49MTAgc2NlbmVjdXQ9NDAgaW50cmFfcmVmcmVz\n", + "aD0wIHJjX2xvb2thaGVhZD00MCByYz1jcmYgbWJ0cmVlPTEgY3JmPTIzLjAgcWNvbXA9MC42MCBx\n", + "cG1pbj0wIHFwbWF4PTY5IHFwc3RlcD00IGlwX3JhdGlvPTEuNDAgYXE9MToxLjAwAIAAAAPQZYiE\n", + "ABH//veIHzLLafk613IR560urR9Q7kZxXqS9/iAAAAMAFpyZZ6/h5MpYA5/oqv4s2qPbYpW3jfK6\n", + "zQ6q7WMrNj7Hy8jZzmBpfHCwAAO1W4riBNsrapcCk+5V1W0XkkFULR4Qe+H3uGA2HgNW0zFAAUgt\n", + "W4tdpXv2OEg0Vuy5W5l/xGRmEGKDyeXyrM0S6q/1EKbad0x2mcHseUqNmeOGLy1N3b376XZKZcPY\n", + "IXC5F2332tNMj8CwOQiXM9PiCLyCVfZ3rQSkKBTZErkpS5kXUyoJG3FdIqLjRFKEapbUjcW64HIo\n", + "BeIbtRyWV9FyZfcTakx2KW3eB4ZI//MDykSe8CRgN76uBEqZFXwO63wmUREhHOb5AdaLV3xyGl/I\n", + "RV70rU/3t9t1aq5mFD3hy1aLTAV2U7nG072dyX87F7NgCxZHT2kFxu44fxf6gqVzE3PEbGr5fx9x\n", + "7TKXtmY53VP8UaeCd2HJiZ/sd165SutTnfiWvaLuCnmmXGF0AGqbj9S19kgOhTubZIJBydTTqQOV\n", + "YRlxbgKn2nzvunv9+NDG0/2ikyyp73W15QClmjyt8dUeynoN8CwtEQ59DdrAPZe4ARZTwWAfsRXw\n", + "1vcZ6Gr1nCNWllQw5IyZyxQtXrfc5p4wjPvGaltciG7d3FG1SGk6HDsZy5i/PsnkjRXLUvGbzYp2\n", + "2gs7ZSGfSJbEifctcMGeSqhOOYORKy6f/9omoieCVEEkniBXwWZ/eImb3nxF7SFIaBjgG2j9w5ut\n", + "BY6zSuQ5zRCdajzJ1loNO0havI8mp5yViAeAlLKYCxeK0Lha1FskL67W1YsARZVZ5EkhqAYEeTNI\n", + "M38Og48OXmj6QBN7c1b9uDUTacYEXO88ZQ1gCIREIMnm2Fgkir8pN4gtSeQ12sfOVz5x5KX7sa95\n", + "L4LyFQPDrFZcDBr4PWLeEEv8yzk0cYHE97GmAlA6WQ0HlWsS42cnXefvTPXnx4vcq8pbEo/slAuH\n", + "IBsrJEN1+aMCc9FNxwUPVbZVaWVjwLY0qh+mNWEaiNGRmacDXrYWw0NjqMPiLiFHacY5oGELRgym\n", + "S2mSo6zhsD1wKQ3EUQtwrjKPiDYc/HCqhkVwoWKUdI8xTS60kn4f5UqB0L77Yevh/wt7AnvQKQAq\n", + "QAEEevggRl1uigbOBTtscnYRnAj0edW4QExAzdo+RwLWXTzW/l3cBWTrh3ORzZQlxJ8jQTvPLB+f\n", + "bLazJZWFQQDcWhuhQ3gYcP1ruNwIroINRIr8px0UOgAhnk6CllxMN6gA5S0YPhFVFKd3n0AAAC9f\n", + "vYgISQAAAltBmiRsQR/+tSqC8p1IAOZemTPutEfx0mzK8zG8tdIxonBsDpoLZ+NnIOp4qK6idP1s\n", + "vbGvZz/zHM86Bg3q0yx2atmtgoo/Trt3YRy3se4HTjou+tCi7oJt2d7A8vEhVDu33JNJx+WCOgP0\n", + "03nVdg9lBs15v/0w7qMc3zqqJXCOy/Whl9aRhcaeOEWcD7uK6mCV8a6MpDJ959xBRfv2i/qFOFbL\n", + "Grs58WiGJcq4MQJI+rVWuFN50oiqBgiunfUrRmdviPYpNN11V9pwcOJwssWfIE3agnor/RC7vfLY\n", + "YoXzaJjtWLEL92OOaHLZT0j555xfb4FZcoJee+RXovB9IaoDdYRusngtBXPMUvnO+g2Z5Qdo9P8q\n", + "Zb8ItBAeHT8IBZAD/Z2nEA6qbxqOBSBtQNW6ZFYLtCTIoP/bLjCDHgtZk3cf+N1CpXs15pUIYWDW\n", + "elZtlTkM4w4EJlLdjLZyQPAeaBx/qoLmKyTKAEhm0hU8EcTq00f6fwkWgz2J6GTGtL/vJXgC8u4o\n", + "nTnf+Ou7sVJGVaouXxrzx+yGVHEcp/eV4gaFA95rInngQAOZWbA3558nK61JBPZl3NjEv5B9r9pg\n", + "2+SYY3wBAUeu2fgAB2+yYGw82pkoJJKpzYWORs6i1vn3GEgUTcwlYsdJcraYC5SnGvqSZhX7KM72\n", + "uE1e9bkpvpVyG/mkACn5R4jwX3xc2utCjjZgM101rirIF/7VfDtmJsSTDes+UVhbSr3SeMSI9ixJ\n", + "+fVuFZ5bnQPoRIfPc+Erw+K99JiGN+HE98/eq4pPlMY9oCfVPSdNyOAAAAFfQZ5CeId/AUuqOi5D\n", + "jlKfxuJGZZ1+rVyomjOIykvxtsjsuCiGtElbraCSFWcn3aIYWLrF3fPovVLcOnroBkiRMsdf5yJA\n", + "F87MQuoKeTaGOrxojCCCS64RiHrqNsE+7mfRRUDuB4sAEHFQHxBorgTukPSvrdFr5QDq+BhZj/6H\n", + "KN+IutwFWKX3ZX9pO3sI8My78TgRY5AA6FEcT91WcvnMypB/OWXzK6M8fYuhVVWipAZigjVOYhcF\n", + "9i6GweQFX9AV9EUQOp2qFbkrT5jceBRFLX6j4JUQ781/UGTekv1fcpCmzlpNpp8GdSeWxRL4gasp\n", + "F5uO5KW63rlhYccBo1cFwIN8txHNnwyQNiP00XC0PWDRZfaWSxsACRWrISow71IyUfcL7JNhjTII\n", + "rwDYATS0xZ9ep8siFC3JTxg1eNaroYfeI4tbkRHok47Vk+CUOQPuagVBtFMOOcy2OUbw8AWlAAAA\n", + "ugGeYXRDfwHM79ghzBo9nMnzfQPPIuvorxBb6AC8F4fYGD/t93kNSKNSEuhUXq9FKGtxnCkxN880\n", + "BPb/uTbjLTQVyPNuYlGl/gTlyLcVA/cDoLrl5TvaR/AcSLFE7C/t3kLx0STNibmdAf4TsHWKSblH\n", + "VWB4X7oQHrrDdhwIivRgUZf7f63j2XaGB+cbp5aHCCwJoovY51YTqsZZTz70FlSnypPHQBNzif7h\n", + "uvZkXhtEzpu9rYMo3YECkgAAAXIBnmNqQ38BDchAitLfY16mYQAQlVmv7062W8KLpIS1/zhS50Ib\n", + "b3ERigmkZKZMPaCsAi+zsLcku/gHGHnVZpuCZMFs72gmyuL4JFo6VjWcr5FtBvzIgD26rBNvP73P\n", + "nJjl3JImmFHiKjNez/gG3zTuYyCACuJCEYXyuEmzCM13hdCPHKg5GZtso0Z1qk6T1k2oiqF/3RIn\n", + "kyjRWuxBlHHmJ46TXULiUY14G+RAGoXI+u/G6muNclld2bq+6Zztuy+5ynaDWNNjuN1Ag9KUIx2F\n", + "XwNdepmp52/rOvISNPbMJ0U26OvqplXi+qHTbg8MLpUSIGCY8w9FZ5woLAENgvgu9M79yGlL20e7\n", + "ypJ4RMBqHYDpEz6Z+SSjXD8LsJ7VKlwo22A5Yukp1vTp6HHA35nV+PXK09DuRWKKdQUzmXVihF51\n", + "/+bB0PEFdoNxGdbbM7WveaCJN8XI7JgQWvw2nPlHX8M5QyPGSJ2HEexumoFrABvRAAAB70GaaEmo\n", + "QWiZTAgj//61KoCPNGHq/MxnjqmxxQAEHvTwibmyMZGX3ES9Abh1tMR+/DjR+6dnqRr/VxCl6gEP\n", + "wJ/5EYCYfGaGmQYsLOeM3v2SZjdvqQBwrwKk5A/63kFm8fc3QCLe93Mldv3KWXHdFT7/mudSntDc\n", + "vJwStG4jgi5LKlWdSrVaAxOmElsF+zWNzaCIQ1dOiZqi3JKj64hOeq1XIWyGvRvh6OLKBpB4rL6W\n", + "ugf7H/IPbSQuF5jWV7zL5LhxWiTiI+kAZTUMfO2YOLzmhCUSN9GAmNzgY4D2awYB4V4QTDjI7kdQ\n", + "tL+3Pmfl1HVilu7nC9CzQSvWIosiwv4btyHTL7IPT2gusybyNfW8QO133L6KbDhhXSDWUtcIFCgn\n", + "QUm36C9hvgGjorpKYr5VnErpJX6fRJm76fFYs8/nt763alyqdcSrqaTOLaf/72Wkkmlwbq3nLOIw\n", + "ADFDkkAPwzaM811K11iK/3HaYRT3nEhjJQFk5v4WBXwIVLAZeKdtC8YoGN9K6isN142fOG3s6fm4\n", + "J1nMtOEZHIwep8In4slLmHh39qBzhGZO3igiVpgz7u+JMBeFkVHe72vduBjIy+1dqvxL/TPics3s\n", + "+alwfTMNQKave1qW+5Uj8jZQTjcLAtKvzoako9VMIOfQUQAAAQpBnoZFESw7/wC9ZU4P+UeGsidW\n", + "4n5tFkXmtxppYvKQ+WGj/x3AAdl6+9c9x7N2b/yJykTvVggfpMnFUWtxla4sr1ouwANom+Uf4IBJ\n", + "/zXPovndpGdy98nJbZxFU4rrWpr8aI4YmRX65+IGTn756CZWwXKY5DyMgKnDcCtk0HEuoHgdGhh7\n", + "1PG8+nue+pE9pBHqiBNWAjPd90qfMtABmMShLoXtUObqYbqXhJvVjjFhKdPS03IF24fu9Z0ax15V\n", + "DnkiLmgyOCvJmcdIX70L2ZEECd/hxrSq9JUVjC41OX0F/ayI6GtkPMUuZ2xWkMFo5rqOAo7v0Zlk\n", + "ke/79TjeY13FNiowqcbhMwfDuwAAATIBnqV0Q38BDXNpg2t4nJdhAA5ru/5Co2KbB/AnQt7fa959\n", + "0crOQgtTxL36jtVyKPmfuQMYuWbJ/7bYTEV8sEjceHvN6B0CSEZzVCjaPLzOQJZMQpQ4K4WKPlGc\n", + "lnEwYAC9Dsejj7Fbk2RyCFiJinyU2HOscjUR6fW2jRsAFpVq/PtZDVPvesPG3AqooVaKHp9Ex+Da\n", + "AH0OvccSugyDKsRBAEiYR8645aXxbFSzraQsELDsIIr6HRN8F3lUNVBvzNO3mxBhq4th/kgZSjjJ\n", + "JZrYmg3UfIUO/jn4xs2XQ9Pa7Uy5K3JhuIQwAOUKDmAMC0p6fgz2on4ceyEcfiCGDPZpPyL3391F\n", + "dXID0ctPQ1a+Hk7UcAc9gSDL8CZKz59YyO0ACPjfAKV3Y2dbTAKdWBsUU0EAAAFEAZ6nakN/AItk\n", + "aaqbMCcBE0iEIDnEBfRZN0neHQxaz5DPSzK0ZSL640q0AA5jkP0YAYAumNCN0MxJYpWFoQ9r43H0\n", + "i9SZLdv1UbgpG3aX6KESZW7AgdlevaBngH/w8xYsqWx5t90zzi7x9VyRYpIAD+XTrxvgBoFILNCs\n", + "gd+zDA9uvbAPlLMwG/qFltlwvLokMt344erv3a/C/ySOwZHFzpakInpJ7MQHkmKi1KHZB5KrfqwF\n", + "FnglZJwWbe7LtVojTdwQnAksziDNlEWCkMQQJwziY1KYtlXMNX8mZ3MtYR1KNf/CNin7/ys9ZQyx\n", + "4Zlk//H5KDc/8O2+JaxH20CAaAABxgSxo+yJal1LnRHYfOQ1TygNueW/rPAA37g/6fLS7mbYKz7k\n", + "dsiSiy1mAV7n/qq81UHJPShQSXK+E4Y5XKuXEWG4AAAB8UGarEmoQWyZTAgj//61KoAW7kO9JCjl\n", + "XSE6nAngAJVxWWFl/YDS0gZ32xjwUFed4hmI6rj18z16nS3Mz1iMmFblrtaE4zGXS046COODiIwH\n", + "QG5lRmcBExMKlnynQruQtA8n/NitzdP/ysLrucGyp5nKV+XyJURULfxk4kwNp0a5TFlJ1fusOOJm\n", + "y0hvsvEg+d4Jz3anvWT6M9n5A84CGucNifV+WlN9gI9gs3qSoCZdU/gglcFYM5u8YchzhQFyMKxn\n", + "kpfWK2LU7aaZHt6xLbqjuv74523K9/dtrrsFq/LySiv1P9Wk6/6d5RC72z4cyaUq6hMMn4IWWRo0\n", + "zJIM1/lSYsWxt5/M1Mkv00Rt8OZvmLxuFfd1BIVlANlpgZ39RYhqqzU6v1HwaW0EudelFBGhr5mf\n", + "GaDE05Z8ywp5rN4Qq4D4GNAGD/qgEjtaDDf4ZBAD/TAHBwxfNjm2nPAdbbbIuWSkkv8NK6EMlKqH\n", + "mOktd+CB3P6Szd1+HPnUsyQ3659r3XLnoi0cvM4usfW+BgxqT0mgHSgn/F6ajdTNM+a8xJQnT036\n", + "7195r0uF5vwi7PIviCQ2E4Vs4Wx80/8tBDEJS4qOY1YJ5aNV1OV82fB3HOimLHd2vU/d4Cv7OBh8\n", + "k3gNFcjeBGh+3lQcDCLZrG1mAAAA3kGeykUVLDv/AGVBMHxAlJYGEpFnv2bb0ADrwvVKxe7+SIJI\n", + "g0dPJdL0s9Hd2mGX7rpdIiUH9ZgtnBO+m3uPNae/YtN3u2p0kkCez2KiPNqgSoEcHM+ePgq7afkq\n", + "0HHTSZl/+QbjsyfbI/0lv1mLAJUd3u7VZPPHSdXK3vwLfAwOe3Nid72slU892DijWVvanzM1IzDQ\n", + "XfN6x6GH2qfaLrHePrJTJxXC/RSxcAol7x2JJ5OA8VjN8jXu0yKirBiYqgcdFf9odG8j4bRmE2wD\n", + "MG0SKuGrJfd91b6B7hbRUwAAAPYBnul0Q38Ahz7YAbwPIqnkAA5sEIcKo2/sVUP0LEeFOLjKjaet\n", + "5YFAjDbL5BIdGqWouG/H8ozoec2ZpUbIZu0ELtG5yXc/5opSZlnqbOpqdTQkLs6gr9dv5GbFvVjS\n", + "Os1j9FIMQsdc8pttosNtygWB8gLxr65El6umAZE5CVU9Mc8Xxg/tenmTduGK9Cd7qRDiu1sLYR2f\n", + "or3KBMo8ebz5q5EmWucvREbYSziQIIycIwJg9OG+aH+ZUEQbjbfHfaiX7yoxGJGP78aNOHP7GvC+\n", + "JwM6DxnSyowUBAqkW8ckgrhet8gYYrt8MIe1MPJQB6sv8hHuAXkAAAFWAZ7rakN/AI9XvmYGr0rf\n", + "QEvrPPTQWEAA5ru3wBCXPJiC8OaE25OBvVl2wRXqp61wQU4HxGJCAxkSOz+G3Yzvg36uCK8bPZTq\n", + "avaOG/H9WxjsuwAl/bIYJdnyD151CiUZ34aErVIixKJ53oKrLeHr3xLgxuH+y3w5uH5lQRsL0Pmp\n", + "0jQItTBkKwlPywxFk55pROuYZWi/h/N19QaFlF7WPobUElLlr+nCH+pVt1nW9/YwVGz/cO8zwmWe\n", + "Fb0OnFji7CYSsi9ScC3a50GjUP7IpaY5NAHv33V57bkO/BD6dnreymTbSmQdcj7PAJkvz610fMqn\n", + "mDGTMB31oxAIE5eWeH7mBZouSgmtxEamul7sYaTPe7mP6FqNCz0h6wLot/zAFwx9/D2+XB0x8mmS\n", + "b086o+gqkoYoHQeQm2Sb3MU1Bz0KHDGo9jCmsBmecxs3oNHV4KaIoLKAAAABrEGa8EmoQWyZTAgj\n", + "//61KoAcdmk2P6doyaR4wEHxsIcmssCD5f+3/v8PGtlbWZ+A0oGGFPTAdgmU2TFbrRxlmwUCouNe\n", + "8freV7blHDodFImzwP3saA3AZT6NUl7vDGH/tw5n9y8rP4XGnhEXBHK+6jIhoAYc6G1CDX0mqczJ\n", + "7tbei5I0YSkDjza4rJSbAF6cRoJQH3s2Q+ggBQR0BfH6N3QlPVwd9YFvP6++J+XrbNU56Pxu6Wey\n", + "51asar4AaARXHregTXL4xn/VNt8Ppk2xD3/1jXAVXdqMlS0tYGM/TtrcuTC63Lx21RQtklG6k0xA\n", + "eWm6W0oL0KTvxuyegpC2ySp5v6zpSEYvzWR4IYirfT0RYU+jLtX0t4M/L/0k8xOLTHbouoUPD6DN\n", + "dYYLYlVX5noJzjCAVCiS21OCcIKqWD/YiU/+dTZpdFFNdHEa/MPvUEq7cJD7ANJ0YUweepq2Eqdh\n", + "57SC4Tpg6jyEnFgMaHQLSz1nJNh4lxM1TPouGZ9bmQdDr9WY+nwzRBa+ZLnaqBSYKWSKEs/TNtNZ\n", + "ev7d+EnJUf9G9CAmmiSDlRAvAAAAz0GfDkUVLDv/AGU2nAwHHyQlvUxuENDSO8vXFIAPilnMlQWb\n", + "nTHwb8wkIo6JKOaIP9blrrNXcWeeQDVprB1Bn//+nbSDHls1apJcUyMHUmojA58P91gutTiF40zp\n", + "fDaF096G01gcvpH5Za4+DfUvxQpt/wH5PntJzggww1tLhP1NyH5U2TTgrnA/BevK2aCa9xCuCVgA\n", + "JJZF4uqHE//COeWbJ6LIFJPoadxAxbrAcxPQQHMzEG5G5S3Yfd+YJBLrdO35JvVrsUTYO4AfvJeC\n", + "zwAAAe8Bny10Q38Aj03WPPyvISnWAC7KM5WfLH925SBeAKcvJaYOa5WZCzX9H5nU/7qAFTCgAnl3\n", + "rAoSnKk1337XDAnLfPYAAOSIcqQwF++e4HouwNVAWCEsVyl7Y6DnBaBT2mD1H8560KoMvm3kKNNC\n", + "oxFCc4BdAIXk45JUbGFNGYAjCbBbJInMjwa41HA404yKnJG7rNXdBctnsSL/36UoXvVx3J2tGX84\n", + "+FHk7e72CsAyB49ajd62idmFQji9Jj1GaiqtCIjWs5o6Mz8s5QfrvipNYYD0YZ7gBBGm4AEz17d8\n", + "isscgsp4QI2odbuEJDq1nfJbW6+1HGcN1XfDC1Xfa5IptM5UYHm5zIT4rSPBIDE6l8/NhVxlFP21\n", + "JPQ0DZxnZFvxIBznQbqkhaGZjMafgFoRzC9Nl17x+K6e75RlplRZtXaUIbjAUFBJIQPkoIrT6/O9\n", + "NtkAmnl8qqUC1RktW/RjiJqOyRTTITHqNKvKy/0gb88xEvvGPgzcSs2KpkbHJWmCGIlSWEkuqcCE\n", + "jBn3Y8XOQxMUxEYeLPJ/9s/F2fT5NAnko+RFlv75fWLekZZP2s17yJ5ccFGhZyrkGX6u7xXK7N8G\n", + "Qlz8qfOHvgMQrlB8p4j7qtnPgBPf8mcsM295CuAZxkK+sut074W+0hM24VMAAADaAZ8vakN/AI9G\n", + "UrhSy/Rrhc/LGXguupji5cAHC2DVoxU1gWUkKeMT366GcmuxH5O8lBZJeHl8r2KNT0EaVARyW7pN\n", + "L4uNsKKl/WAzLJ1OZWTQf4NaAfodQGO9KzZS0j6oGvr/urKiQwbP44Tv//glYQyyCFeq+8nnrHBj\n", + "aACu2w1otySh0DYMX412uY6EYcx3GtQaRpNPiKQniWdVV2KH48fVxDy0uLS0SmCZEAWLVNvtWqO+\n", + "q2OwCBr1m50s0i8eRTlSP9xoKtxWC4ZqL77eAW3kYEBJOAywYUAAAAH6QZs0SahBbJlMCCP//rUq\n", + "gBY3NzYDjVIwwAKbp/vtZn3NtK6t0V/4sA0MV4ijJVoTZ+e36T0E9eQ0LOyzsqR0ULZJUDRy41oM\n", + "RdsBwM4wyEJC67daWmuDEXKhZo862uqAH8A0QJ5u5RKBPFpngChYYJdWzP3onEWImG8Yryy/SXt0\n", + "jQ5te76AagLius72bzwZ4AZfLm/04ID6oXhPwqkf1cNsu4/kIt7oCOETiL+lzwHLEnEsdPSz3DxD\n", + "uLGkH8o6jHofDxEXcB6cOS43aUxGKPYPtHCj2gw6RzcRoX5lD5mwqtoCTxk6N8TxyipSUyNnbA2b\n", + "G5NuBUVLHTce3QKY3SdkbyH/wzdOpT3YHUE+FYQwMKCF6SMyMBxp2gI9k4yUZYljUiekF2XIFkfv\n", + "TFy1RUmikOycLKkTYTreTarsMD5JfjZ2FJWrroj/YX+uNeGtKNZl9Zyt+k8u4Htq1bPYEjCrLHds\n", + "qeIuFWmvxTYEQblStjDXmWfITtxy8KvOgn9iV+KlidrnVhlE7Dz30fuHXxxFZvIzhgU9uv6sSC7T\n", + "vZuGMsKGBGTYmSe0P9hLI2VyM/8GUWwG/AITiU4a7OVDjUNRPaiIEt8jt2oImPIY8qcrJ82CVd+P\n", + "mSjoppoeHUTHmeo+koGqjhwT7ueVHNT5VZ4yuGKEDdFfEIkAAAEMQZ9SRRUsO/8AYrbCELHs5dcg\n", + "AyOPuRHZUWtdXLx9XaNQixO/8Cc4Q2MgEa/wKETsHiR8C1XOv7rI3JB0rg46JfjEArbHaTHmANKo\n", + "+czcI/sIduYNFOE3TvObMh/KtGpZSdF+qnDDtY8zD+7RQUdzmkG5zeDj3u4Vq+f3qnKCwgbU+U0R\n", + "dQR9Q60wXqL03p/iYVxkI8jJqvkECuxT7efJI+5rmzyP1yn+WKY2EsjjB7bwwVfe6RxBmzR9Ed/9\n", + "CA95ILUJxNg4HsmCO2Ko+MqZAH3wMlG18kUm2ogL3cKIkVXogjofyKhbsSpKLpFFk71DzB6NrY/3\n", + "HfknWM2yn9yeQB/joufGEf/bvMAS8QAAAN4Bn3F0Q38Ado97WJWiqN4XS53kTA5YWsnJBdebpf+9\n", + "lcN5zPySAC6fH/XzBsBKbxdm4pTiPFVrmGXyhaRiB6dxtlwj8MyI40Do8AXHq41BAunk4K4PTgzR\n", + "rFycWqaL549wB2C5jNCLXlq6Tuytik3ijlMSkx9noeIG2Lc83eWkRkQieksQSO4xI1tzzkdqaNhG\n", + "ExZARu3MauZwrBopslb/ZLdR5ZS0G6p8o9DD5cphJjxJoSV/70/0Gr+woS8Zj0JpVvvpygE5bXQp\n", + "/YBCqjmq4uOCyt9SvCzPelUEwXEAAAGyAZ9zakN/AHZ6+HiwE6fxvgA5rqP9zmI+FShvhJS43N4N\n", + "sc5a7qq0DK7DHadXkQxf+APmeqLrIGM9X5aCQgeyxdoAlcQoyNsm6ol85w5z6JV8A3YntmCae+s8\n", + "+8/Yheg1ctJWrSharoeypUyemQeq9Rm5cIkSOS9Ej0hbIHyFhPQW6K3SawgMNVKQ0s1BpJvXDQSY\n", + "x3jIEdIgEtwe7zce/DjcO3RNN3g+SlPoM7cl0qJbM44NIDG9JGXcwVrY/YKNrpChX0yegP2ZHDI1\n", + "MzOs5eWP/2l5loJrLid2mK4Qhw6EGFrIadsV8rSjzgHRNuzJ4U3JdubidEobU0ehkU0P6MYRK/XM\n", + "58mVywGbsw6LPu56h1S4w3zHGYMd1zPKOsnCUhaRfrSZTxvjerNQ22prVPqBstk4JgHdnSScrwGw\n", + "eQcqvIw7gKhonPDKM4fJtO4n2EsI5Cd0iGMjmgPw/PU3FL8ZP3QbYLMwZ81Wd7BLLBDf+ngKiFIe\n", + "it4neyhhaE/a71b8TxeM/ZrgH9+D76dlgPI1ZJW6CCVyIs6Y5gK2plkcgRYa0MwWF+1A6zPtBEgA\n", + "LOAAAAIIQZt4SahBbJlMCCP//rUqgBY9we30eRuAA2kMf/9/gX2SHKs8Uq31+W7Vx4LugxILnhMT\n", + "6icG5WQzdpL8yjIXjBq99nVaYweUdJE3LrdOpsVxNJ3kODVBkposYOoRuOMi/SNhcjrJwShp6ljG\n", + "Qs7tSeRJSYDkvm+SI2ckjbManbEesw6wo2ZffuryaLuWkU9SNALC+2QbPJD4bFy7sTmB9+6VOdMm\n", + "rnLvYN4ZyAJz7OhQG85P+JnxdgXgvSv66sWBs05p3vOE+53H+HQCMTLVgvoYmHNTIYtZ5CIln4hA\n", + "GrjLg53unVVQTiYlSzZrRE2vmtsqac+v6CrcbtgC4HktflvPTsvgqWNHri9NWa+EuXgx/AgGkZVJ\n", + "r1n6gAd3jtjLtv6YvbPiBBo2AhBUxCbYyroAjcvjwUBtRjXTdDEvdYfItmTKA7W3+KvVi/PCtod6\n", + "/3gOoaA7zRdO+8+MHlGl/c2xzQhj2O1n8eJkOu+NcsBkpmxyosDi11EOEaiQ6vfnOvH9MSM+7D/v\n", + "k91SLlwv/nF+5eDPHSLZQIoFUjHjwVoSGCdOLqmIe6tsfTERCeAhC+1bhRhe0612KIL6izjolsR2\n", + "nUgrl1o39HqnKAVqQ/HguEezLTgmGW27Df2kp4E1wRl/EQgEcsMfBPga1ndY4uHPYq84ArNCWk+c\n", + "YwxlHAPVC3PK3Zp2kQAAAWFBn5ZFFSw7/wBXFVHDEfqz5TAg6AmqzzGCl9B1ICKhB+tKz4Y9Km1L\n", + "/vZyZ1OR5rO815FlrTgGoncUDKVNjpKrVerCm+HleHb1b4FhYQG8B61zGq10uLuoQHIyL4Cv2/mm\n", + "s5Mi7ZftErBt64oWYphUyh0Hmn9dYYheGFzLdE9gvqcAEGJDyLZq+nfiK0Px8pHIgaIfsEdSUYcC\n", + "8Otyxta0EKY+Dm2m8AtQ8jjuDmkSHm/uLhgf1uCnztOKFhkR+ydRCeR9tnIlTfiv3gJbsPT8swjP\n", + "0OUm6yT8LhwwCJU0AGI9hN0/kTkz+NeSHjSPaBx26MAfS2Y5NEtva844h4B/RttjqxMsNDiDrfB4\n", + "5xn/Cl/3XrcF40eivyUSC+FHzx3M4BoLQLOKf7iz8hKiUrqRGVkGToUMxkr5192x9xCjbuvLRMd8\n", + "9Pel4WIOhSi52xuSf1eEhC5VVAp4lHpZmHCbgAAAAaABn7V0Q38AdnTaV3jxqK844c19uepGJJSA\n", + "C7DQuTz6pWfCzxcMbX5JwHItpyM9y3YT46z61a7h5Lyukp+nSKoO0zQhT0EB/u6ILUCNvVbb/89X\n", + "7TVI5UN6EFwYYfi4uoFmqb+5Cd0J/+d2405yTsK/f6WH/T+vNB1DYWrW67ctgHOgMHAWDLG9mitl\n", + "16bXmPVSi2sWzpWYg3147nlnaD00aZHqQlrMPzYTLLFwWHOLNqCoWpNLMMEevc8AnQWeykk9VNTU\n", + "NXzAXhrKDXl1tLQTxZG7GX3K9cQyeUnjfH3rMBGDD2zCLGXrMfPVl9EJ/F5M49Rjn38sXUf2JvF8\n", + "D9r9tV1APCHN27+egfFIMDg9OhrQMtjAe3WEfpYS7pl5yHh7ZZ2CedEo/Wf/ygYTAQFI72AaUTrV\n", + "n47d9OSqAdYs7lkgV0864auRyPQeTKK1Sp3ADeIFS134VGBNG1VnrfyZuznYkI2r0FVkGFrAXpUu\n", + "ZJmyKqqILhJ1OTBM8C0VBV2QXBYa2aSn2jj9t40/wJJWc9IGAVR0vj/u+wFocjwf4QAAAZYBn7dq\n", + "Q38AeUc/pR5QUuADgu7/kKjYlIf8yn+MfKKvFMJ4eRJz/DRqteBIBJsZW3T3phi3NzuSw0zOvEhr\n", + "CHz7xEUteyaR+fa6YCBeiCtangbUerW/UGoCobzV/74XB/lXH53NcEw+6x9o3/ZgwG/7l4psK3P0\n", + "EqSwtCrcKAAv8Wi0Z88mFp3Sp19shMF41mqYa8pNsyefrruQONS60LHg/1GySbrTeTWW74lCDwnt\n", + "BGXpwghp/QF087PP7hxkE8lvu8APh5F1FTiOCBSvJFm6yFC/tz24gmveLoV4Rq/qtYWRE09VDCDH\n", + "yjftToPMsyi4DoCtXsPRk5Jxr9Mn6xDxGjfz8uMmOKJ15ejPi/Sx9cR1QrBsU9dhcYifdB+c0AMF\n", + "PolB3N4pBZAASP6m7EzaTer6yZ2sIKcQdlGt9xsZ0SHtS2313gpdJkLEVrHpO5/BTcfUTTcK1+bC\n", + "PwRYX+iIyInP1m6htprdy84ySZ5IaGCpRKFxMCf5w22wXyyon+dlMPKACguyEPTCCZQ2MqEuC+sa\n", + "uB/hAAABxUGbvEmoQWyZTAgj//61KoAXgR9s4tVmwJ9HTza3s57iAAoQf/wjqzjlXnP+29f12EfR\n", + "S7B+4I2epG2qM/uoQ7VlrfXFlhjyX/aTq0n55QXAKa2xUKolKsuMfmZFFc6+GP96b13JiSidvPgt\n", + "2SSGnq9Yw4MfceFmgOaZRcwoMnpdb0UpI73YdP+DfypKyrkDqKWcBc/BGhrH8+XdnpCNDXfg5rMl\n", + "b0uFlQ11yUxnDYOfRwLbdjJA6FYddawSEVorFtY7jkSQx+OUBUgWkKC9rhKB+uV/yqQsvbuFiyYV\n", + "MviBpsZgSSN0TOC5JedQ5H38ENVBLjXnWZD9PQyueLoT4qwtI+7lodFSnBG3zboWdj6P7XDbgKT/\n", + "zKkFObUjwhstiQtohzxd5AXhBH3DQqNv6mRzuMxFDcTEo5ut/0/1HrPGOF4R3sJ/eQT+YnYseqvc\n", + "0m5njpgI3qkLmn8efBB4q3zWGpHCxBwC84HKjuugMICuXfcJHKn0aWkn65aEjT8AdxDWE09InGyo\n", + "EM1wsU0JgJ/qq/6MdHWfQW6+bt5xWlpYJ4axi9wZc3Aoz+Rixn8UVM2e/bd31+W37ucz9udquxnL\n", + "2JdNUAAAARlBn9pFFSw7/wBZVXkLa/7xg9HEtDOpc+GkSv0gCD3x6eQNkROUaCyL6QH8m/0USPLW\n", + "nllgC+uXg2X8kUpaUiErsLvwKd9y+trtKwV7xlvkAn0JqEnToCvptE1Sb8eF86DTi2ywy7WE/imn\n", + "jNBYQny1cV38ScnZp/V3phWQAYBG3kUdNNuj/FyVB7DgbQbTLK48AO5nLYv8B3LvBNBfBJ+ym1yg\n", + "YJXKwjm8kt8xUjO2UGKeggZOs7YHWr5Fj8OX4jV/B3/cMzP+f6YyrayA/80F6f9vgrbTlhWdlFQ8\n", + "QtrHKjmrl874OSSPJYH5wfQfF/1NrQd6soxjmSWYI9/FqOPoy6ujUPxQvg1fUda+wK31Cv8gD96H\n", + "LPqpgQAAAXkBn/l0Q38AeBaU9hYCjxV6lA176iBcJKIHTfhwkqkAB+a0LmdvcgdK3vyEsSkCI+8U\n", + "up3OQ4OQId/B45+Mf5P4Fc2VsfnQAACxyzNkvgEEYwZk+TyOR6/VZmeFNYMrBdqc2NNBlh56ISK/\n", + "h5V9lagvsX7yv0p9Hk6RXo3uoMgKhKOv/QgBAqhUvAKDw4DS7G31tehd/myRMmCPxIJ79bZsQe2/\n", + "iq7Nquzc/VDpPXFZHPvOmiyfyrt6Fxc2jLHZJGpvacPTIeLJiSaBxgRTEKBr/xXaKQjc5nLhlwgc\n", + "HSz1WRlyOsXOkob3rY8KoGVETaaIvHEl7sVHsV3QN7iR2rIGzf6YHv+c3l8OW1b7tAMShtcCLifl\n", + "8k1OtS8Z5o7MNTObuLXIONSPGo1fC97qRzqHFEfMZntEMqsFjjWPM6JduvRiAv8p/h0kRdcTeRox\n", + "t4PEdFJikYgCJgtFa00LDpNvd6Vv6MImiivCAgL9L7zEaNCr8p/p5ZiDugAAAO8Bn/tqQ38AfAnX\n", + "r+Rl0wYAC9kEZglKr0YEZPxbFiynbDVLyUoB5/4mwbggJCKqWcWLXkOc702XkfuMANGy7OD7QUCV\n", + "nopFHkp77AuzGvvM2JQndhYVkdbX30/kmHQDID1DcpthKQBbzUjm7wgAOqbulxKDc1OUw1plN1OA\n", + "iXs8Ju+zQDtZelKPfekDEF5iPA8IQMn3LLocZ168PVHW73hdmgfMFTsqduJxZ1oiezDuUBPUKdNQ\n", + "1lGg5KUsS5A9iNuo+n1shJKCmk20FfXGeNEywAjYeaq4bao/dd8nZn//htlIayY083IymAgdHbKW\n", + "UQAAAW1Bm/5JqEFsmUwUTBH//rUqgBbB5O6qXkABRezeefAxp9PjwxeDBuTTFSUNk2voPSz0T3Lj\n", + "1K/LmQtEI6YkskJKgxvIXHGf8LHTV/h2Mg/qV3IQ4zvBygOQs98iZyR5jgV+hQ58R6xIcus/6y5a\n", + "HrkViRrv8Sk7So3LYWmfkLzyR6vcCKhF/sCJsY8RS8BK5OOGU2Ll4Qs1n4jPQwTLDELf8SF2+07z\n", + "zB5hexERnOHmWZ9THKXS8j6NXPrj2p32k0gvmlI4b/Of9evEX9mDBp5GtQHOvTswQ/VYUajAUXz4\n", + "5w6EHuB/k+FBz9pe+B69syJ2X5MYn7Qi9rKpCl2kZv4uAWXuNo7oIaU7hr6elcFz53tdL9AEjCAb\n", + "BlT3p448134hjvo9lj95CHF5teK1w+R310Gc3NQ0eeJcsiYD2EoVrHHjVDF/m8I8JtTUFdJ3xm+G\n", + "muADOcIpcqYbeqyKWwHmgvRze+DMQbkLo4AlgQAAAR4Bnh1qQ38AfBSmnoPKZzTuFWeZOcrkeWeU\n", + "yVIALsozlefbqRZf6f7w7fkPoFSkdlxkJJsnO6qzfbc/Kotbm2yeFrIQw5yspszQL8gAAvMHKSnw\n", + "f4CTQ2vfLY55MADj1baDD7LZtn0UK1Eh1HnwXobc+mdHd/JEl/a2Tszf/EZ9+J7oMl+BYsjWKwNY\n", + "vOv5flnnPLcex/hWFIF4n+hpBybvasl5hI9mV0CeAAyAclftj8N9n7hadcpM/TOVmHbSkJ3cr/k+\n", + "StSwI8gY9k3tmbMSZc42caMpFr6YdNCCIj52zmNBccPNFxW+UT/4qCqtX1gc2j7obKDaWzC1yj1A\n", + "td8/VAjqVn+FzuuEokhhvubRT3RCdxeWnBTCG0CxwC7gAAACMkGaAknhClJlMCCP//61KoAXgkIw\n", + "VJpvAgAqN7f+5rJJcY8tkjj7p4LozjswOy2dTydK33mOBGS+NojRzBOlwt3ro+/vdQIUTIVrXKwh\n", + "2SrHPCPJXQoCjJUPkRODCmqbZeBHsv1r7iIOZPpX66HYYhWgPLvPzAb/Nqu9nQqKoyphhNy32+S5\n", + "qAFvjRKLSjPAx7GoKGUNMbYduhsBsrvVTwhrV8uWAls2mxYggJzVuRUZSL9cSt+tjl44BXjlbo1a\n", + "I7ybNHG97GCzcbSNcg0RA+iqwDsdnrZCO0zsNdWK1qVmER0PsSf0dicSrZwIcxZWy6JbkwQn5TnO\n", + "kAah3wAs6pJvW+a5ZiJHl6sVlU3yCOlrECAESqWu0YR75WfiMXgesBOuXGGNsC3icmPYNzM93us1\n", + "7GQTI6RmmFHGo+B2yAB2YJiK1YN/T0ltUuXfFAvL4UdHgEXOVIqVj+S+YpITMKy740IvYQ5zuZPD\n", + "ahdXF7HIU7xE0W12w+6qkuyZwxUMXLXdgx6svudMor1GNfDCdymcKIidhuuXh7vdQrgbivH7usVC\n", + "zjMqgjGahkW1YlmytCooEIoULx5ux9DK360iAi4u/nAomESdiosanRfQ9jQdJSpo4rurLfeCLF1Z\n", + "XsQAQRTcezHlxp1tz3A3WsYMA9urPBB8pUlDdB63MfZDCBphVx/Ddv1AMvPXFEPu18oREsV3BdKx\n", + "e3lxLWWpytzF3zXttYGgBb90j9DgRGE1uaAWyEAAAAEiQZ4gRTRMO/8AWVV6uU/hFqUNYqrP23yu\n", + "FpB+ECoAQNVnJ92i7ZF1i7u1D6K4L4gxm2RaiGsRDmf2iYWEjO8yGHAqwpcDep1/+H221WMh98AE\n", + "VV9Ferf+hy0D7Zu5rX4Hp3s1TpcNcEBIKPHVSHIzaZKKfPXkqE/ga/eepp8Bzdc39OW6g91hVVvf\n", + "WJxrnf77rapWbmivuJFfeO9u+RRykk/agdEi5E/5a475KGQprA2yl390PNrCvoamPyXbETwtbYAQ\n", + "pF9uDZkHdN/NQ1P4rz+zQLJx21eQsP9WBLswpDFYg9BjPw+3VrVEzeid2j5wJBlq+56Hw+Ex6fI6\n", + "1O0GbWSAC5/5Zg+kGX0Yx7/We9PseMWGwXWIVwqI7oHPEnK6wUkAAADgAZ5fdEN/AHk02mburIzA\n", + "1V5U+8CauxZABexQ9zxvy3GIkNn2+19EyZqnRm0DMMsXP4ZwiY8vW/qdBTlATfbmIFDxCTzt76+L\n", + "X3WaNfG+rqTfzj6gLFFHl5IJDtQmIC9KAmTgQM0Lp8TEDdYJnPYGFybq0Xdyl74+130DteV0SYTD\n", + "hgB6230zJvCx8ZW04pZHmYvtJ1LZAxF3BAWKPXcstkh7/Er8zYdPblR7K6t0r3b/sIHpME53VRBk\n", + "ggj1uN/p+iN4KwToxjP8kZ1opB7xpkyOQpicygiGnwjU7EpZpywAAAF2AZ5BakN/AIdka2Wer/IA\n", + "EJVZr+9KNmiS7zXHA/5uJU6D0CbJOrsLPWcfwAUCZZjhlCsnAlgzrrGOONmuxU3En1TfTKb/7Pu5\n", + "1R8PfIYkV/dZFitvMyRPMvzwXX1OcxtjbhM+M0LCh6zNEWJFi2Pi95t8cspIknD4iXNUblA3oEFp\n", + "VGuXt+8S3Upf64YqAxWADhb5zxXL+O/gnWiyawM9fyRrYcExecMkEiv5MHRsJs8Euzdps1vwxzNA\n", + "Zu4bu6ic2K2ueNja78qXGaHz7xLoPIVJv/T4KAuseyOhznfFtKf0Ey0eSBVK9qutGGF83lfe5Wtv\n", + "xb73lHTKLAyiyJassoDHBSQLAcUPb4nB6xWNr9G9gWtqEIp4Or9tKJzZIZ1tnIKZFZGb0ELAlV2+\n", + "pKKDz5nW+syHi871Soc3HtgomT3Y1cp83yQG1GdKkcJPkU1uJVzsVPzbXbSU7/z2Q7cikc4seN2D\n", + "ryQ1l58HjUs0ikCXV/V/CDkAAAH6QZpGSahBaJlMCCP//rUqgBbmS0XBN5gNQAaCJTjyhVwVkMwl\n", + "GF6KXnd0XUyzqjFCJEv0D2xQiJu8if6sKo6qHl+BP/MZw8ss5OKq407INzCjWOsjf2HTKyC5fNLK\n", + "wiJv+PzieOozn64ZK7RRud2QUaDe0kuhk4uCClSYQBImrxmWeEf/X9zH3+ilYhfoZigVm0IoMiuu\n", + "YX1ERVdg0Ld9E6wxbYMiQAGJU1qeeTwc8vb3w3kiJheTA2PNXtrJ98RwtpnhN6QxMe1dw+aQWI7S\n", + "j0oQ9iNx73N93RuNVRxXj/57S9VltjA0RTZBjLvYS81QDA3fBgaNHNzOBZ7dztz/rTxxOpumjTTw\n", + "x9FgnvlMsjx7FYPKUcXD5quVKd8lwTlOiGVI7X1HEv3Hh4EvpYVt6azhUBI1qGunVb3X1lyMhWJ9\n", + "p3muqcicwInEt+BuHY92HoNXaaJJbbQmNX5s3QJbI28Pg4gc2gaUF4SQRcBgM8uwcYUzxEkBS06L\n", + "0moZm8bwMsLYCLj3fgXOyFudpfg6jkYPDeVK811WbzEz8Hcd42XVL0EwE3bwDc+i2I4+NERo6J6l\n", + "d4d7nOIvqUuorZnDPtlYcfSWgBqdP0tQHvFb4Sv9QUCBvXlH2IEiNzo/daaHVtbFRNZ3cag2HOiP\n", + "lMxyt8xYJMnG7di2JiwAAAD7QZ5kRREsO/8AVwwP3fRRACC0tQoY45xe6yfL8KMHlR1wbd4HcPUC\n", + "+4PcnqOzdoNv80ufRyOopFYryJahX+qWFUVKK+nDtdvegTv/PqvENcT8ykEwwQ7z2oNUdaMITYi5\n", + "4tC5YA9FaLSBorMGx3aocAbiF8065MBqyaTkiW7FtGRHVSPubGixAl7hiQRoBoEipfCxkE/EBoII\n", + "omSCNrFRyjd8oY66cDfZt+iBI44uLDeP6eHMEpBALsV0FY7iWjBLaYO1t2PsklOb93SAExoyIX1I\n", + "TiPXiUgrCYe7dgepAF31BCnOuxiIAPWKLDHZLhGOJBLqdemk1EZoKCEAAAE5AZ6DdEN/AIteG4cJ\n", + "hGXgWAAHNd3/IaNiUh/zKhTXYgf+UKkbUvWJoLo7whMXByWkvy3MotNcPaSHeaKS5vKy/hBJIgk5\n", + "CWcdsbd5QzFHyjOIZiaEAA1AziqRPTDRRVYKhcrm181rAlAdaYmvKZAOu92pmI39/PSQjhiMouSe\n", + "XVT3pg0s+/zN7WMQCHqTmey2TTctwD0YnAH9CK4EMAw1jPCCTXgop9epuL/iXjup2S+LS3pGE3iO\n", + "oIHon+1ERGRC2Vp3b2QAstSXzK/2zI+bVnxf0PhgKqa/NeuEaF2SBGZ/TyqGPDnQfJRorCp1s+mw\n", + "tm/3aVbjKRTXeSwl+OCfF6rMqjf/Zw8/4yrjLNmiyOgD8OWqATkM50NFqOShrrTCaHdcxgVW70ss\n", + "cCXKxvzAUCe+4nK4C3zP8QAAAWMBnoVqQ38Ai2Rc7ISR6q0L0pberS7nbElvP1eAuajd6ehFPCEk\n", + "va4007gA4DkP0YAYAumNCN0kma3A2DvFPa+NTDmrilkXNhiNVTFRLzynsy8rdgQPBH6k5DFr/4eZ\n", + "jmJjfYPWB5+2eEYYc9uJ5Ni70hsVFfV+T8zp+ZkLZnd2wv7AZ7A8baF9R5O9oQlCkoVPxkDHTrmt\n", + "rElQhX8Fi0yj2+BVP5O9UNPGQU0+M3KYUTg9yTBG2cCw6Drt49/5M/86NN03F5R9JS9KGOfJjIlA\n", + "koCavGpTFqq7OYU0RM3ilfXBmxvL5QoIK28Uvs71J3h/IvKmg4v/14n3/eoSpqNUCC77ty2SgAAi\n", + "rxQNIHz2GF/lpTynlwsORrYNT1lJMVud8AAQb+/SaHWQXmhJ+8cZTt8XuMgG/t/hdF6GqyG0A/Pn\n", + "hWRq+asN+zBaeyQUWZrjl8ry0h3WPkAZksFb/gV7ABWxAAAB/0GaikmoQWyZTAgj//61KoAWw9mB\n", + "34Nmlq4DQoTYIkneVdOFHxDDrFwsv7yxZXXwNkGuLMduj7QGT/7lr2bNfzApMJfo9/ffM5g789Cz\n", + "1Mn0zxePHMHBL6IHHRVXWyqDMhVLYnQ9xFtc1jml18If/8STBCOf+AZjMnARcFmX1IwLt/ziVSoN\n", + "e4GPKKZqfZWytoW7461OuaeZ9dvtxrCL+W45zobgR5vOrVM+Opl+w/eFlupHlgpQBWgJcPy8sZC4\n", + "/O9laiYA63xx6M701UUvGFsRI+RM6anXyjKc7TVrmZ/YQKRjqB6Mejs2G1mTDkBn7T2ZURI2vZ3u\n", + "VXRNsQnGYDxRUokS3YRHs9LEF/gxKSdLEEiHDqcoIHyS2FPM+cIJRSvB7sxIA3hgfN/O4qDK6VO+\n", + "t71oi1H0Bkz1ugONnVTpQr+WeMS5AtXXNBMXU+ycO0+R9eRe9BwSk0V6tHm/HJ45oIYvyWTj3yZa\n", + "JQ6q+o4isbf26PsTbuSAcvQoMnzEXJkqElGJ8Z3rZtdkIzQW0DDnXeNRbj2wQmuUNBknMsWOw2/t\n", + "fD8BErzYLXI65PwTY+6R5c6RWYzF9HNMLBaO1c6cI4yEu1DMKtZW5FrmVuc6hg7VnWxgAgOdFKFA\n", + "QvmmcrbHsqCH4rkez1y5GoMlxeOuW5WKa/JdcefAflYgakEAAAEQQZ6oRRUsO/8AZUEtmg0dqwLy\n", + "ubLYtABfXw0ri+bvSnwBqWW9hB3/jYP94x5LyZNY560IvuBe5T4EX3/71Gbqj7BS5SJLQ7X1JK0z\n", + "I9iR6McwRU2BDEhu+2JQm1RA2fBVxnzCyNr1JVnfyyuumlkNzE8n1UgnkIbS/FMxc8DghB7zqZzK\n", + "rkagW0hHwSjNf+LJf3DnbXyvnzmB1lcv8Z9QlsnPKDef2giSgbZeTNWRMfeu91kckRy0SSKkaYVK\n", + "KUUpf450Vl2TzPLRaNhk7Du1IJzIJRf9supxssXD9v31LAVibgyznyLU/cS57Vr8KEXG+WpKysV+\n", + "6iQmQ/hCoRg82drzuniAPltxm8MMUZwVMGAAAAEzAZ7HdEN/AHUKF3WsfCAA7NAZyuGlRySXJzA8\n", + "WtPYIqCp+udF6BaVoG3w794kSqeP3syNbVlr+uFhruNMOOzTsNGrbATFZMl9DU6mhIXZ1HEAskmI\n", + "VVSgXlz4sVX35JqYrDPP8r9Bsg/O9tAp7LnTMjWlqOdgOPhHpyqf/hmokPsCwqtKfsDhxP/tmX60\n", + "fhM4KsfvpygzK8jmUmY/GDBCISRQeW6U8uaq8guf+cvy+sP09JLJ4HsULhIsm6kyYO04HBdOFUDr\n", + "/8IzlOKX3w/FCxhimlJIduY8iySAFQmALOuag1Ry1Z3p7NpGIGhZp/q5hzsMAsH2jpHXQPdtFNFH\n", + "4VkqDlRDeGqieCr6gwu3hPQQfF9yauq4qf5R+bfPha9tZ3XjpRO4eqNaj2xEQrcb5cIJOAAAAUsB\n", + "nslqQ38Aj1e+ZhXsJE07lvgA5ryx/X3Tt1hQ2T/wP93u+Km2fQtCsS47kHT/v+BMMbdxEWzwYvcd\n", + "d3NYalS7o/aUthPBRfYGmx2hUIQijLOXN4leC3SONeoCputIRor3Lgsy985K8UL4nvf1+pFmRQg0\n", + "eJgJ9ubt7jVqU4S6enDDZ82+hYwxDWOROomkxsOv8nlizRgAHHE1n42Dq5sLIu8oVYp/4M1h4rCy\n", + "m7AmDrR9dbHlpV6pqPLshIJSKr7R6XCF5H/mgt+78ttEoS2XxbrmVQj6DQtTzcYF1gqzE9DaiXTc\n", + "rKcf1aBAFclenBiNHhbAMEE20Br4FIkr51a0ynzJocMgaUhstOH+7gKJGCsTPkykOiVzQeIGOfi6\n", + "AmLkbzIds0NOnV21ExFbxIFAMu1BymG8Kjwvo1cLb7372R2f+Qt5Z8LjmGrBAAABxUGazkmoQWyZ\n", + "TAgj//61KoAWP/AeMmkxh4qDG8hcZFMZjYIY//v8PGtlbWZ+A0oGGFPTAdgmU2TFbrR0QmwUCouN\n", + "e8fq+V7LhZ4IhSGjAEZXRALCc6lvXQaVk4Hy29vGup69bTfpCSIWWGXFW7WfQjL50GRbZZRZHQ2m\n", + "pjAJ2N9/bloCCNQEfrVxCeDkKfJqKlRpIdnOUaiQpsnEysqkLqMfxaCLAtiv1vFXcLPLizzlMPs7\n", + "NIiiAuhD4+CMokPsODEut5yq6fM1zRym2P9iids6rfyvN0EtWlvUXkAIdmS8HfE5DlX5rtipWZ2i\n", + "d9rb+tQcwCfWN6erokI6tARQJu2c+ZSF/sI7qofDkfNVCHii2Msza0cnJEbLkEfdF+gBET2KrdRv\n", + "E5mgO+6ICEAI6O/h7r7DxvTQ9Wxzo3mHNo6898yojVZYUAEyiEUBn5+alz6XfA0d5GcOXFRjv906\n", + "SVSt5h/ZyjXd+HmcrubYPlDuxhjCrkqyrKcbhfJHp/Mq+DI065H9OXdNO/+uDSHvPcKkibqiAVhI\n", + "DqTA+NZM5+PbtXMsqU6iKpSzqr3AN5mBITP84n9JoTkmCR2U/+5h8eajZc3UcAAAAOdBnuxFFSw7\n", + "/wBlSP3uCsGGoV8bqfG+TF6JTvUuRSAD4pZzJUFnxrFOJYnshFJtjPOw7rAcguf7FPJIlPqbN5qs\n", + "fqCPl7TU74m2w4/OJHMnDpS1+crxo620hZORUqqaN/UeMSuSm/KKx2/MSsIgkvOy0fYS1MAD67Fk\n", + "Z5FUhBYQOPZatG+Xc3Icj+kvLjp5v9fX+nJsaNN4CCl0quEK1R//8eZO87p6DKKxlnRfV62uCNE9\n", + "o2MWYwf9qwHYbtyqG6I4xWPTngQnrsOmiw1Sy0bIvHiKKw6nsCsKdLVPqCFU/q5rppy8Ah4AAAIT\n", + "AZ8LdEN/AI9CIO0JMMhrV/0AB0HLuqwUdobO4BdVbPV1Ioua5WZC0IWTaPE/7qAFTCgAnl3rAoSn\n", + "Kk1336t4zGyyPYAAOSIcqQwF8zee7dn7XFk1tvgy6W/qOMTmkEiEdwceoRsnhNmrNp/TK9OoMIUg\n", + "ShyIuwXG8nP6tDCpAEYSuvpzo5kchXf9jICMUEGqQZjLulIdzbNUEecLTDRk1r3gpdToPPcXdXTM\n", + "AElxf3acmkXSo1kx4tBmKJrXm4kNQ2oDIaqLOc1dGZ+ccoProxsI+jQiCldj17rGF1/E4alcIa3L\n", + "dIofRLGOPkev2msNj9eN+tELiQktxoUq9fKnDsRx9Nbc5IkysRYA/KsIu02gpfPyisLPQwjLSjpr\n", + "jTxnZViCfPC6UCMSLVKUvso8AB0eV8Q+lldoHmqd+EeBeeJOkPU3vuU/GQacMWsLnKmVt/65Nw0r\n", + "y1AnL9+YKkDmvNgpqgQANfZvj5NhddHche/p4la1cXWhY3W/jmtWxMTkOC4tX16bao5sNwcVWRvt\n", + "UHjkDIOIXB+3akBV5Lzaef6YjjT1MeUeFh/FB0tOMV3Bhvdw35krP/ItZ1RF5hRCk1oYqz0ykGZW\n", + "YkciBlvCsweWM2wXwX55h7SZHtxiKM3rO4Aff+TOWGbe8hXaapPE+4wKof+j5KoQ530gP62KsQIG\n", + "BV49pf0LYkAEd7yVzO9dhYYFAAAA+QGfDWpDfwCPWoxxjdaiaFtca/OwfG9dSAC6jYuqYuZmzKSC\n", + "kzbTtnf9idy9v7frgKuFjQymibohZCHRXBQdujo9Laqcw233I4Za+//Mdf06kxHe/IBTsCsxcSfV\n", + "ksVUEdqCe9dEwWwg//4Ee8Le2gLXqz21e4jiFyBOjP5GsM1hpupcfwZtr5Mo/ou28BY4QZExXJ0H\n", + "FzCqK0jKq6c//ut1tsd+kiOyZUVGRAFVkS8bi0vvjrj3zga9Zaa6Mt7yQii43DdcrobbVIWdc0QI\n", + "3+rsc8fgmOnJ+GJGdWYzpFLd5zMjS5ofw5IMBt0GmHVcG82Z6YQkqKJHzQAAAe9BmxJJqEFsmUwI\n", + "I//+tSqAFjc3NgONUfiwAKbp/vtZn3NtK6t0V/4sA0MV4unWIJlE1N72EjQeUPmvxOpceaVXIrAK\n", + "21oMRdsBwM4wyEJDPiji6fXmMlmmsCvOtr78Aj8gA+xKnVDFjoVlH7PPNvnMo0iZJruZeFy1B4T9\n", + "/2iVnlLy1r3LZhoykeyNXqaKEANWeqYl2HjpH92g+fHSONko5D2m4SRKJwFWFllUBg2RTQ3etVYS\n", + "PdQGNCLeaZwhH8zjnIe5Vuu46VBC79Le/PF0x5A18FileZQS8Adcvcamp8leUQ9dML537b7ARaSt\n", + "9Lyu3Sdke9BouNe3+hTyxzxAi1Setn//aNMjVtdKZIT0wLvPIMCsfe3gvhpNMtez9cWJYRUO4qU0\n", + "Dlg6h/pUIog+BzidDDvn6SZ9WUgEXhGZOFeOBYowQfwTGI3ac1V8O93aTpJwa/om7scQbOrwAjjK\n", + "gaYt9yqViBt3FWYRIoJJGYqmGJkf0tLvcymA+Hyayho8kg3J33tLzi7Gkd8xVzsn0AbjvoJ9u5le\n", + "OKsB4L1kcStddnytXouu9GStBCQSRLPeb+iGeZTwQ5uYY8D5fTAcb3C6Ob+B7IWRbbytzq93Kz0y\n", + "yYvbeUq1qJCNW3/zJeXeH+8yV69x5FRyM+55j6UAAAEdQZ8wRRUsO/8AYsUcQvOGOSSADI46r94B\n", + "/W+PEO3biH5wUahFid/4E5wZcJb1S+5KPsyD0qQEL2HibG5BPsDLysut2eDJfU6ijjP6zrYmNEWR\n", + "huQfgh9NsMVuoggiphkYt9ccXxVhYHn++9K8YAnkm28Kzp0jUWHgD2VeIoDjCfJPNnBqH+CERm3s\n", + "nubUQ9LmttVf/+MNJAJgtOFW5A6IBAcBpJtd5kPS+zJ8VxzguhOiD6Pf/zfgjMDUsehmT57QUanw\n", + "gbdNgBf1mSXZw3Czfs4swXmaj+42V39PQblTRJ5hVxxBfyBMHdtD+eP+pUlQP8pBAAnf3v75+Q0T\n", + "L19oeS5dx79IIwiodA3vtFf2KOiU2gODZqY3kJGizWNAAAAA3AGfT3RDfwB2j3tYlaKo3hdLneRM\n", + "Dlhayh8NourV4B4kYRi+kgAOdUf8hAGAI5XCPTeroAwXn8G2yGEphnv3FPeZqmLNmvgLgUkPciaQ\n", + "A3x0WVLvMk+lZn6cJdklOXHEnjNKsClw6wU0RbMDBk1zQUzYb/75rZ2h0N0KqL096XGATDutyhUZ\n", + "RVkyTgfbEgHdPAmzdroStgpcOUEN4xVVZX2E+XrryGs2/tIi+iUaglsBszkGSHUeEuoEpHc8PRHH\n", + "tDc+6s5rO2oABm+Gux/PUd+4yoXEBbF4DtdMIooAAAHGAZ9RakN/AHaNgkMVTymoPnXABzXUf7nM\n", + "R8KlDfCSlxubwbY5y13VVoGV2GO0t+vExf+APmeqLrIGM9X5aCQgGSaQJX4OQoECqyNRzFZQDLhW\n", + "KA4dfYJp7oYRPF8AMOzGYqm7AO7w7FtM2J0yD1XqM3LrKYS1dGZTAzMM0YXyhFuS7+8HWwRTCnl1\n", + "B1MtLMYaA8qvJY/AATH13D2takXBcx78I1sCsI+P57X6Q2Nh62/bggQuV3uhAAN0tyrIgbNQYVBH\n", + "gFwoUmXrxaEApAv0P2E40tM9SJDDcZe8DyE7ljCyxGjQA+gKJHzTkZCCQsmlxDg5It6wsdQ6cusN\n", + "DyWnlyoq3MMo7ugMYcm1YMEY73l36Y/R5wo4wUzuNvV2tJ3rSYBCfXsVjc5o1oA8OllKUpgpBG5u\n", + "9AavXOqCqjA07sUF9WlQ9JPrhiXa9bThYRp0lNBazKKlKwsBPK9zJ1/OayuptCCUOtFLyDYWpp2k\n", + "qNXWH8r0IpnJjxnQFcNmI3LKk+rH0vqX+48vd2BUqTcJ4rwX4e+V6oU1+lJyU8fmS4Kj/iQFUx5A\n", + "ntiGKLVWwqfkoYN2YexrEPVBTpKi81wf61aU8NAxYQAAAjdBm1ZJqEFsmUwII//+tSqAFj3B7fR5\n", + "G4ADaQx//3+BfZIcqzxSrotcVc8CLm7cBBc8JifUTg3KyGbsl0UtvUGR3t77PRffuzjjVfcKeiAp\n", + "EmDpLoqmMXTQU5wmHksjapt36fasfEiGyN1dOKyOI9nT0TFFL0pzQSss7Ux5GajOaQUF29zSIoeo\n", + "7hOusjWiFyZylISVuEBU8nCgDYn9P601XpFko2u3FAuYp/svCLJOzc9W7b14FY05eVZdhfmiv0Wm\n", + "d+i5ZPIv9mhB+8Cb50V0LQeFfsyfPeAABtfp/HIPaN+amWONE9vQ2YbC1JsqKljPbi6Vrd258gHB\n", + "PNyXvESqATfkK1Gnk0AWxo7XFr5y0Ce95pJr1n6gAd91M5RV5lL/XAgE7sYG4524aA+cXAa2XPdd\n", + "1BugfbN6YGWbktwAoVIXoUq7TnrmhBrw2FHa1aE9uMJerl9x/Rs847iKP+iuBUD2VIUOVa/G9Po0\n", + "ksPo1bHVIsITIKnrhXV1NabDgHAc5kIv+PJk6IroGA19oMw2I1d4rGiaYQZE9dmK1VRARJ9VXDBJ\n", + "Vlz3aoQhCyQZvwzvxWhVA1iU1RO1TWnJsppajNeO4Vg4/b+BSviIvrSwwqmjaRr8iuCpVTgz+ZJ6\n", + "95zLiSdnoIFqQJA1Hz4YR/KIOmAfhTTnHcdDelso1m8Bx2oHlzAOiYwR4NhSSRD6EhhCU2kXf5vn\n", + "vYdShk1Y3/pp+Wd9yZwIwTneJB0AoI0bbmfrtbbWj1oAAAFQQZ90RRUsO/8AVxVRwqizyog1fzvw\n", + "w3oFk0s5kH60rPhj0qbUv+9nJnU5H1hbksC+yivmpdt3FAylOp/Re8NoooEKQr4q7MX/kjNCB5zj\n", + "aCmG5E3TxVGWGCYMCsdEF1I+HuXX2a3wLCwf1iqCfznNMRG46GE6nIgxc91oY/zfMduLLCzyb8AQ\n", + "b20W2eRODsXd4+7XC1RndLreJ7Km543AdL1iUo99hYdoASXjyWRNv6wvJrmyFngIDlQOrLluZf/9\n", + "T8Y21pcggXpfTtvdj+B+3lZv29AFHkL2xGPZvyL4UyVUgb3U1DWd/iySeGzlK1IbRNu7obP1czi4\n", + "Rchm1nI/pS+cSuamJbhlQHIreF0u2/zcrSGkuOpbObSfAY//5j6RVfcQovw5wL1RQN0tcA1GtFxu\n", + "ZpovaLthGUkeOPh8iV5bEpupJR1R79Ew1sEkTDugAAABwQGfk3RDfwB2dNpntdq7wHtHkfExb8Mi\n", + "4AOIW+6weDVD4WeLhja/JOA5FtORnuW7CfHWfWrXcPJWyNJJfpx2maEKeggtR3RVEAdA1a1truYO\n", + "N3PBvt2C5hri51AyWveiUQtRNh8OhcT8b+NVPo5dLHlfN2wr8ZipKDuUP3k1md+EiPqVCrK5TuMQ\n", + "knvfHHEV8fXqrrFiHhWYrAGbSJdOrXgrQTN4JDv0LMwXs1Nl1nmEdfSgT5BF3DohYi4r2xGfiJcJ\n", + "KMZ1oPHaRBjgxhu40ZP5HqUG5rQWHD92UCH/Terh0cf4e0554mxHgDF9CBXD2Ey6LaV8LB9Jb9nA\n", + "f7tFFMQRIVaLiP+uig+B5OoeaCY5+GdEeHuY+ZE9jNToZ4yOUwNfysZaXJBrtfqEkQosI3EYRZQA\n", + "COu9BHjZjXsKjEmWe9Jj9yWusbXq4WMANyEJEPNSeDcqy2nLsc2OqSE4CgyCqy8blbRZqycUiZt/\n", + "3NpFflI5dk/7eeQ8Uo727U5FhceNm/3Tv/0N3CZNlPGV4f+3/HHJknpIjibzMw4AkTq3Lkxy1XZ+\n", + "FA9yAR3cZ0/eN1EscyudULe5dTvs1EvlYMWBAAABtgGflWpDfwB5Rz+lHWcxYALocP/IVGxKQ/5l\n", + "P8Y+UVeKYTw8iTn+GjVV8vbhgCZ5cI/70wvHdrfJYaZZyRIawh8+61+/vwo8HAkEyAQL0QVrU8Db\n", + "Z7+ORIRATWUQyS/LIyP8q4/O5rf7OuybqgrrJ5JQm3dvb5EYgnYLHCULt4xtpfvTsT5gEynxu9HL\n", + "Km20sO4q1oqcF4MPx2dj7xETa3veUfVJqfvwop/9NWsmPrdhY/wz7rinYt2HcWm7+ulSBZtWIRv3\n", + "yMRoNM+lyCvZDr0PaN2HfwYWOYr/NgyLM3qvI6TujkJkGWBIPuiFK/SHsSPx7iAMcrZ3CQvQC1rq\n", + "psLEx1Lx0vtWsdQAcjEYe6l7VHqUFbgcjcHAYPQIIgi8NauIxLhxUOQnkJo1mXO/e5w2N9AAHA22\n", + "RlXXsFU92TGe3GmYdLlI4OC3IklyabPhxs95veQzY6n0a2BnyANXxWrQG1vVVVAYgtb88NEdo6By\n", + "gCh1aEE1VpUTP0of4shaZpNk/2gd6T34r4uIClLqdADAAdaA4/epPc357p2Ro8OkrT9okATGaQDM\n", + "AYBiPC2kAQBkyn5ImAAAAdBBm5pJqEFsmUwII//+tSqAF4In0o7iUdIU6DQAMu59v/f4eNbK2my3\n", + "LFfU4bVvmOXvurgANJp+yhdNshfKZWyf1yiq02eNo25TtXkBg+c9UZquU5KtxkSr2wTyRJb5fWbg\n", + "+NL8Fosje7XYkSxYEiB3sVwPhHSvNWh2d4v6fN1lP9qvuUnfb1Bn+TdruqmJdM2vx9efbO5Th2CP\n", + "KiH3jeuRzoCzSIUG7cY38FVzT4nUIJdz+2KjjjJ0E7ZNKQ6lROaPqjFN4utrXaZfqGFX2nWmlL+h\n", + "PxS7plcEcSC1oWpbRWphWgodqD5c2VmFV0yO9NkxWYeDoEeaPVORAB/gqWAbIHdoZVHMBBV6fLyv\n", + "D3u5FppjGB4tzB+WC5jnXJKg0Sk3SkInESay6cwWUVJt/G4Tfg6wbMdEkCvCKlRosg/RTpp5P6wR\n", + "Z2iZfctuN2EQi36vtriULh4PVI/bw9ZXWlyhMpAYPlW3C1NvZrlJMNaSqGSSnh5cJMfrxHquXcAN\n", + "CTgojRhZ3tMe14Ny/HV3UfnpEJgrqxN8KZxlRpYS28Q96uqEu6NBBsBIIz0ei/Mg1x57c0aguL4j\n", + "dVBDXATm12Zi0uXfiRBRiIror0O2CDrlUQAAAPNBn7hFFSw7/wBgSQL3wIE2Tv5B6OJXPcoXMcSb\n", + "cE8qv/1v/uy5HaAJNUQCTSWlcVovOwe/GLZOdN2BNEgb1OlzNEinzyASzg3GuZ9zFeyJHe/zvxXW\n", + "qHgQlhmuH8QdE1M1s5tXy5mwAyoAiCrzupaN60ez6jWL/yRvGdGiPt3qJJLeMG60zAMKa7QhUJFJ\n", + "FMWUFrcLW6iQXx7VTZR7Qo0gz/aCe+BxT2h34J4bdpQTH59SHjOd2X4DMr2kpW5buE3EQBEKSUD8\n", + "yEiNy7MVRtsZHXt1V4Pb6TljTGXtC9pzGwEXtgadiRP8dhtDjxgpVN3IyoEAAAFOAZ/XdEN/AHkx\n", + "u7J3fsEfo6cXtbkNOd4swcOB3voAJyKHu0c0/MGiiYXv+2wca3XUwSOEG+s8df2rHPxj/J/Armyt\n", + "j86AAAWOWZsl8AgjGF9fWv1mQf9jrWNuA4APvfeLBFbZJZm7otp6Fc0DFqB0XCbEvLTkRU5ySc7e\n", + "Y4CD3ziWyxgWkLgxNxAV0V3rzOqUGhFxcTbBCJI75knYyulzgB9+SazwgLVSR2N8nND844Y7GLCN\n", + "0aeRWZgNIAWJkPPhP1VnSRo1jOpV+axgAXL8ExpNwIvLk+O8lekZ0/1o7sI+uJ46XyI2SuA6uJHd\n", + "bwUKNMI2qDKAM6f4kKlJLSQWqzXAi8hAQzI017i25Vpi5npQJ4TsJeyOHRvmO1wY5ZnIEZHyhgB4\n", + "IoLWrdA5opbAou9XxH6m1F6osqepeJLd97Dr7+5BqWzoHoOLhOxNwAAAAQ4Bn9lqQ38Ah1fDGltb\n", + "SoFNBABy4LNe514R+dnaDTYn5E46OmsRrJgYyAm1lSXdflAXI1+CFQXE0A4eKb0poyZSLaaXfRBJ\n", + "r/tA3jW8xYt/UxFDszVrqnPHP/Ny6pw3mJ+pwWr+YYAHxNaLyZj85nxRNPFMUkOr96iCB+MslYrg\n", + "cr/vUoZCrrFka9nw08yFJlyN4Ky9KHUYJOXDrBIiz8KQQaHFalCe3rENKk9raHLB9E2PdI37xydW\n", + "9R3Ktqa3KW5rMJCOoArO2/3trkkCh+/FDlbsei4VdbDQ32DjCaAkDFjCyuqOJNsi8nSI2KDSRFCB\n", + "83l81kCObhPemVMTlMBQzSDvOtDFUtuVwHtirD8AAAFqQZvcSahBbJlMFEwR//61KoAWweTusUEY\n", + "AFR7WLigAceU/KgvW9LBBRTRioW652v1Xpv5tYMFhkRmmlUca4/8lM9NJwOZFgbdLq3dhRjr1SQ+\n", + "iitgTnIKVe77qt/yWy3INzcVxffYfGucVy2ypyvLSUZVvVzu37Ufe4d1uKQAC1EE3Wwzkx7sEK4N\n", + "QwJyCdTZZnLiyrlEXcLAMbB36CvMtmCiaP8XPpa1U2RaJxnBB9qYeP0+JCORflaC8m/hyWfMppd0\n", + "XeCFuAYTEakC9vO4HVF02QH4GZZigg7j7bXnvstEtP5QgYZViZcOoAaQGKtWm3PCHoS8mKWfCUk8\n", + "ZLC6z2a10V0U2DavVH2m02W1Lc4/2WzrwUTHr66DOaP+urnPdabeHdXruv1HJ087InGSipJtxGko\n", + "4rppNbdlP4z6g2o/ksCKcSZ76uS1diKM/39wzVYDu1tkCD1lomve9NoQwUToKqCn30PDqMAAAAEr\n", + "AZ/7akN/AIdka2XuDkeawxOj/BZhZtP+kNbRABb4RmWT8vSOMSH2HVKuz5/n3pn38gQM6YQqY5bV\n", + "v8KsLMWKt//3BpX7BUiSjA/GsXEpiGachc2o+KqjjRfujy3SLc+TvzNfgePwT9w0Jj9Y8j6ORxA7\n", + "13x9/iM5Lx1s2OQQyRluiOYKxXDE9QjNulPCcMLJFKpvAfnZmzl0pzzHw/ANcBEDhABHQ9ftCkUs\n", + "Q4pQOQF20mJ1++bXoRcUz/lR79ACwohpzpGuaQCknCVhUL3lnnyQzloB0PAIRq1VnOd+y8D18t8/\n", + "IEva3L9FTrRi90eT/2pNxjMaqrOmFzrhjd2kmSd3YBlll+A3KrjDn/HtXx8SDjztM7Km7BEd2LVO\n", + "U1pVGn0+C8gCov9gxoEAAAIMQZvgSeEKUmUwII///rUqgBet471BV4xl2QAFRvb+6Uilj9hVaCt9\n", + "oXOXB19FM5G4bNDJAOl9w7HrxMOF2dPOUf977Rp9NoBObCR9cN42Ht77Y+l36qfp5SrWPFz3DG9k\n", + "Uks1s5yfRvMME5RxPYk9+qohbe5TR7z2WNWBJjaTvhnu4485WU3BaTyIbA4BRRdj0/JwsbCXRVZy\n", + "OMmFdXnFdxhNGZ5JMCQy+ip435WTv8KevLzG3OUTxX5d8x0gaiQZdaPwNC9GVrgmtqTc0z7He5Hx\n", + "p/UnXiE+WgHU095CwXga4AbeOtQbj0tjxKUoS9sAoJ5fyTlHv9FnU0ujgUuoA3Kj0ma5qF69zgnv\n", + "MTXEIqf8zuYuInk435YB6s5Aa1W77q49/ZLR70JdKU9F42nWnuaGIFvaX8JNp0NTGvA0s1VSOWIl\n", + "YVdpY6hSPbDqLYXO/LE7X1D3sWpexh+/kcA2B6pYDzx14bD7OD1f9pMDWxIrW6BpNH75M54gOMY1\n", + "SxoTsfh6KVoyFK4Yqd6lPKCLY4O17tm0vzqLEva8zNeuM7b2yHKwMHpqK8FV5yaEer9Zd+uSgIqd\n", + "eftECExc0GDPrda1mDLPyRR8iDjZRvRS/EElnceTaWiUEonB934ThxItQqnJINdKSyNdNwx44Jgq\n", + "H9/Zh55FLA3sdVDr+1aesKMfNmYnbwaje7GN0y0AAAENQZ4eRTRMO/8AYEUc98FD5/CYkGD6VZTK\n", + "7qaMD8JeD5Yvz1s+LaCSFWcn3aLtkXWLu76WBTjEp2boTz2lISGgYIiIhTqGBdSAvn4GaApcqQ2+\n", + "sy0LjwIg9aZXDdjP9AWFTV1H8wY3dWCf+Rn8X8p7dsAFRxXZ4015PG0t6STtIq5DOqARSPJ32oCq\n", + "OenP2L2rQhT0bU7kBXZqDOvuedMFko4K8dbR3EOKtstAjt1gHGNubjQIVeNhJsdrdMtXEY7juX3P\n", + "NuPteAILXrR8S3R5mIOtuZ+vWEUdS+Inr7FnZsbQiIv9i7KDzU2m3LJLNdjmArFBBLgFXYHDvQmL\n", + "9VT51Mb8gx1TyNar/CPWDggAAADyAZ49dEN/AInJdfYNr4ilmYSAMFB4GADpypoeWWXE3q20mGL8\n", + "wfGmH6ZgcbtTXJWZn5/uB2IPeQFG/rqNYZ/bmIUcKhccFRuPa9wOgu4Qnm9oi81y+ChWQK1KoKDK\n", + "TWWDeg/SDhV8w/q9dFY0rcekgnjPKbKFgzK+IO7hoMF7vhpMoVCqvwMtBaesBfF4bzxIufyftMba\n", + "VRaJWuZpM22/FtH8FxujQ6EjGNr9PHZg3rsxXbkYHRqZvH6RGypNdfKRL4serPMKtCeuCWEKaj1Z\n", + "h+pr+ULdNvwpLLHfA3OCu3Ql8v/sLDD/O1LVB9ug+l/wHpAAAAGVAZ4/akN/AInJdjcgUcZACEqh\n", + "GvWiTtr19IbQdv8WE1dBOa+lNipi00vM+C9W8F7IDH0aaS+KKFaekfOwUNG520lVemVKNYbjnPl7\n", + "LimE+s4N2NJ5SYT5+XRMb+vTvKCkG/By5wQO/WbZo9HorEm10+Tu4CVIj+2Ky5hDZl+kA6mkBK7E\n", + "3LwAW+4rGYiO9JH1BLFQj0ZOJq0ybrdVynOYOw8TudsCI+I3fiT5nmYCkIO1N7h++s67fASBLfgP\n", + "CYo7yLNwfifRM3ay+JhoRmwX5tGJ8l9w676Zo1wDaqZ0Q5guAYSxSJk2jHShR6LxlZmIVJnq7S00\n", + "iBOM0mxomzMhjpxeX6zqy/aA2SEREi4ulxZsEvlIWhLQ5YFv6LMkVEh9RITRQOsKGEls7Y4eSRWc\n", + "f23FGWOVxL2MZUmPGVh++Xygx19XCiXwoatt/s2T7zGfLkQ2IBiMKXoeDb7yiR4q+0v6UjACWT2H\n", + "kOIRMpG/B4KQPsfMRT0Rk3cAwV9dNnKm4XTlo9P9TmyT71B/Greq+KvhEBDxAAACJkGaJEmoQWiZ\n", + "TAgj//61KoAW5ktFwTkgtAAhBassVgP2a7WSOTniW7GlpUC5YARIimzpboyDKn/53KIxVBS+A0NS\n", + "3NuuWMzq53zfHvhoSdYO4dYooBUDN2VkLpVK3v3kQo1FoE02X3cyV2j6ziOTJORgWGzqU5k0XKJO\n", + "1VCPDS1gJclQYem5NlGAENmSiR9I8XvNQLGvpLGF/2+aU31xCZzIPp4tUxyLu/gVqq+6L5DezfDz\n", + "gPP3+vv4JFttE5Nyc7LysmCaQfUhi6zPymHmdLjs3bZdma4hV61UMMsGBNZfYf2GUkV1dVZ9kkfz\n", + "RyUYJPFdwjA5S++T8sc03o81MYXnXYkO9hGiG6RRLRRV2fPSgGhghnaqxRhYVQiuVS0ENIpjxqqc\n", + "KBEaAMs1VoaLKEOrNhZ8yB1VLLV9KSiM7/prkkNKRuNLp0WeTv2eHtXhIdAfhKb+ic7Pb48CqpOl\n", + "FnnbgphlxDaS1dplrA4VxMNzEL/27xNMQzhuRvnSDNb60j/kSJHw5x2JG6G/VwCoVAfFrZll45AB\n", + "Puajv4y9+7flMd/pR8Rg9UAn+cey+vNCcCbbn7FNSWq2hl9cymk4fwW6iqBgiFEQ7YZtyDoNCyYz\n", + "KAnW0gvHCg+5n6+qxC+xDS291Y4JfSW927ZZudU0tXxvupwcKf6fDXxz/bqsOMvxj6Y81+e6Dezh\n", + "B2/8nCpk1Qc7N5s0JoStEQ8+K2ir0vIXayhFQIgAAAEeQZ5CRREsO/8AZTZTJbuKD3PiQhYpzA/Q\n", + "3Iqsld8XUz3sHppFsAHZevvXPBLN2cIUd+YCbEEH6MplVFEcbuDDV0dnlBcrCNrbp3+CAOdBsr6h\n", + "0YfLGDPxHlFlUCi4qTS1o0TT2Jzkq8/O+TU7SSImG1EjEmOGpKvxjn7KxERq2Pbd/0y1sNHk5hiQ\n", + "eJwHwc7Z19aIrWes4h3UYQqHeU6kfCpUHVgnGubU2A0Xjg0UrouNSumFogz0StLk4fuhL5slF3Bb\n", + "3NpP7YhgiVLV0FNM21/pfbXvRQFzmliOaZuScgePqa02nvOdEHEpGVRPLCGL/tvzSkZqhXResmQg\n", + "1qZ/TxlvqjWYqPRThBIk2nP66jbd6NLagdWz1BtbrwB3TQAAAVkBnmF0Q38Ajz7dDL7wKLyRAA5r\n", + "u/5Co2KbB/AnQg3XvWeaImUuto8KuobiZ5Rpi0jf/+r5lFprj/mYxpQ5OwqjQqFG0eXwqi1D6M23\n", + "HLH/3LvgYXkbAAGr9uWkQaEU+TeJ38WNXodDC29t8Y0uYEpwNzyC6FqtgkCyDYDpd/nESpdVRRJh\n", + "15SV0TP88AKwZsT7yWH2r5gpJv8AhXnnWmKJ/WMwiS/2+Kf3ikj614P+BDohXhMYGO4GSZ19EkRI\n", + "RjwO1zoy3Umd4iOMuBBPzevAs74sU7IUdkUF24rNAstoyqnAUgY510L3SgPXbZmJYMv+tRpT7ZuM\n", + "oLxE5ACIQ+eHStmGZgh2P1nvrIaZRiBxoWZ1B+DDOtu5OZpc7LbajGP/oy8HbEFyJIcGXHGB5VXY\n", + "HnskMmabuu5xyFIJcVaqbGg3TlqrbBE29OX6xO7K38oavU/okVlIM+AAAAGEAZ5jakN/AIdXv9ZL\n", + "/wCpeCQF0zyG8897iu+TVNq8xXl3pE8eXm424VBKoADmOQ/RgBgC6Y0IzpqUKPVKwCZafdEIuhUv\n", + "zhgtxewRpr3F4VdMy9NUqqvPfGroLPxDW64Af18RtCEv8t7amX9ezvEWK8AgZjHjHXeVi2k8dp4r\n", + "TuMjdngEOGe6y0V0qXE0vJudyGSblaiStnW6rV0e34JxbdN3Qbajy6ozlLfOkq7Wqx1iLXxa4foY\n", + "IPBIjzxdye8gOjZW7bP0axd+wppVHkXrrvuxUf9dp18AanJIIFv6MCm6ujRO2wyu4ZfSbZp/KVFm\n", + "xvxpBAJyjKSdCoPxWylEDyms9NAmwAADmUiy6WUOIsiAC130X9MRKfeLHi3miJh/YDGeINuX+P+e\n", + "NWBXxp3RqAzo1eISPcPztmgXUHCSN2VRpnCOFQoF4yyryK4v7s2U4a7V5e2sVJBhb7kguiVFACK3\n", + "rbLSCnWI4OCs6u017nghnGW3Juq0rF80iqmo5QCt19S62wAAAkZBmmhJqEFsmUwII//+tSqAFu/w\n", + "HjJpMYeKfGxaFh4NwH9VzFzipiNnWLhZf3lim8qQP0NcWviT9hCfSjxxrnYEE59yPQn7u6+tCr/u\n", + "vn8/iyWB73TxWIDTyqwOWzo0R8Wj7McP4QWP8yE0svd//Wkug5+3cHmcpP/ONbeBn+TAQ0VzErlc\n", + "2hXFLnmGW7EB004qvGi/S7JfG21T+V5Sx9Nre0PuomioWltV0uJSYiMg18UwZktQhoyeO+qpPgky\n", + "U9/xX6NUrUyAfCz03v4wSV58lpzV7BxftApX8ZGWBx2zWQV/YeOCEWbmbHqvN18Jd5FxK1iHRqe+\n", + "nBGg6SyBQEQQfCMxCo37AXM212ulRN9X2fE3P9HkhvkaOxQZ5AElyFJ4BlaM9J8bcUgOX6NS6Cqb\n", + "n7IHMcCIPjAIJ36atWVr0EheDYyrwatT/sRxqfSoF0RgoVqtGqstMXZF7XACu2N9LDV5Ss0B+mSl\n", + "kJJqGxc50wazbtpofP341QOLrRCoQigLO2IFkJyqTpln4FgoWIMbx8x6cKkFmIESXv7mZEx6LOrL\n", + "ggZa/EdzllkBPCO/+zBjmey1Y55MrbMpoidNDpdQ6yZ4UDU0ai3HtghNjtrUaVDC+dCrSCASLB02\n", + "bO819PX27qwUTWW1MCrVhUzQkUkht4Xa4bdnUW7zTudPa++EPxUMVY36vPDJoCGilCgIXzTOV6S9\n", + "OVTh4+OA6S/XkcoA6ZjbQLERX5kZSQMoFJs4bPot93titzpDSKAhc1QMx6eKK6Ol2IEAAAEkQZ6G\n", + "RRUsO/8AZUEFdKFRxHYcrgnLV1IJewAc5dAL6/Pr5YWcZb4ejev9b/lpY1ea5Xk1AlTe44c3rPkF\n", + "DXI6yAdEC7kxPh5StAse03AARSF2nro+Dr5bfPJyYF/ERJ9NScPmUIVihvTCsyh5qmuoAH9P7eCu\n", + "Y8rdH1hF/pTSa+Z1tzZc8gwGtgV/YsMtlWLs3VbLWxt2KTDW5Y2b0HA6zgNn25rXu72r6iiN5aw7\n", + "sjFipq/8rjgHE9K0EK2Opn+0SPK2Rbo28aoNdC9V8VxW1CpMNxKjFOs8YmQmJE6Qtkw+Uo5mh3ic\n", + "7Ng6Xje5wAF7a8Iyr8DMIwvMZnnVp6ilQ1B/LSGEPncviRIHH8w83Grtt0CsL1L2isuyMboY11N9\n", + "lxQPpwAAAUABnqV0Q38Aiz6zZgMl5b2XXQAXQ9yHCqNv7FVD9CxHdTnw5pqRTLAoFiba5ss3lqXG\n", + "QCf4/o32jzmzNKjZDN2ghdo3OS7n/NFKTMs4yX0NTqaEhdnVRvrbcGvcKo0NYMgzE8UNwneueU22\n", + "1vpuKbOkae4P82iS9XSi8TlOPcF8mmD+n9qfVTXzL4r0M/s5xxZempvnxqhz38EgmSM/Zw7kEyiv\n", + "giyuP/YjNhFl3FVcOSLiQTCj+F0nLUE7lia+UkuO/YNBXwUKZKD8Add8BG6ZTC4bD/RSktc7uv8w\n", + "NB82AXgnpuELTB2xZFOLAYJncjo03/3uAK678Cl8cw8fzlbnSpp5eUkHacCUtAY9LPrz/OMf2bA9\n", + "vBE2eUwrxz/W0Sg0tjzkUrpnJSF+xYsA2fgRolT6A0NA++mVN8PJVhaGzQAAAX4BnqdqQ38Aj1eg\n", + "HO2BrhbSJp3bjAA7Lyx/X3Tt1hQ2T/wP93u+Km2fQtCsS47kHT/v6cxSu0EEWzwOVr17m7uMIt8s\n", + "rOS2NL0s+wNbNsQiUhFGWcubxLdtukca9QFTdaQjRXuW15l7gz2QnuVPe/r9SLMinrQ8TAT7c4JB\n", + "GrUpwbYY2wvPKUw4NOIKdjGz2TGxM02Yhqm+YQD7nu+MPeXg/5dBf+XeKfPK+RchTbfnRfx28pUm\n", + "+MUq+ynmpWVmmfO3TbD8gZCbZRUeK4LOH5lP3nvVvkbZlQVhN5vPlxxNouZsDfsmprxmWrHzH3vb\n", + "E+c7VsDA88L9wCH+ZmQGzxFjyOQ8cz4P9rsZSuU8vQS1h6fmk4XXUosrmweEGKJT/Sv5qb0OG8e9\n", + "voRxFaPrroiqkALWSnA5n4zcQMwfY/xXX1aR5rslt9ItB406qJIsbsrkl8pXUe2CwOVm9B72bhd1\n", + "lqsCRNktqyPMF/Ek4JsxscPvDjbSqbQZL+uT8zjgAAAB5EGarEmoQWyZTAgj//61KoAZQB+OVG5p\n", + "SZHABUb2//v8PGtlbWZ+A0oGGFPTAdgmU2TFbsuJ6mwUCouNe8f1I2ythN04JSJ5lx+ik6KpnC91\n", + "1FD3eD5Jit+kJIg5holbnldcijL50GRMV+Tt0L65TPBxqSAUdrQu+eLUTHPpJCL4CV5RJau8pEIv\n", + "uK3a7QA/UMQ/nrDjeZ6jqf1BF3JjbyaeIc5drvnYbR6lQ0gBIzp/QRU9xrHm8FESnIe42aooWDJ9\n", + "bVMccs59QBQd45WisW0MXV7NFtyepgfK7biPJN57MDsWL2A4LYHAXH6f6In3GVsSrYQ2HUKGlxpv\n", + "Yf/Xvk0pBnHsuIEsslXTjxwTTzuRb2YT7QCJp6yHiUVL67n8RfvHMNoHfUzP4rVgPSXcPL8FOP2d\n", + "F8GxovHNOmsOSUyc+t9OZXQFF+4FJNSN23FsgARohBEJ3c1u0ax3ACLYlwfCd3/U1mT29ftZkWMR\n", + "uj01t9v2AGHvgKM29X2Vs/ALzLNDd2OM9z+AC4TlcpgcRujIhnjHf17Je/8RMBqJCZtdfrFmz6AW\n", + "Z/aNIv/p/WX6adpvStFWxoDAnf+Tai9COS20TO4GHDviQkpMo6tbNTk4tiYWsmvBNq5u/aO08r2y\n", + "Bs1eH2kAAAD6QZ7KRRUsO/8AZUj9pUTz7rNMoHjJ4gSsLw2wABNFEVCVBZ8at73oa3C8UmeDMVba\n", + "M3uHP8p2EFDXTkl9EiChbxZZgpuvefKfc50lYhoTJ/7H62X0Z9NX2I7S32WT1XJeJtD32zfVBu3K\n", + "VmE+30x6+W2pKnyMM0ZejDKLq8WyIyi+9rC0QVVyU0N739nDCyt6aqRfMfSdljqTnwOmgDB5pHyK\n", + "U8Nf/BZxnIET5uBVX/VcS4bjmT9sCYYwmAz5vBy8cv5J53FYPh0/wF7kP2myhm8SfTnmNtpTej0y\n", + "JjLbrdGSBUAu+lwbCsr/YdOCYrxvvrklZP4j4s5VlQAAAgYBnul0Q38Aiz6zZf6skuDOogA4jl3V\n", + "YKO0NncAuqtob34dJ/eVmQtCFk2jxP+6gBUwoAJ5d6wKEpypNd+AlIf83kNIAAC8trXyGAv3zzzV\n", + "tAa7kzCHOXS39Rxic+qZEHcHH0Hx0iIZnH1UNeoS6dQYQqolDkQpOXG8nP6tDCpAEYSQsJzo5kch\n", + "Xf9jICMUCBjMQXeVS1i3FdA07mrKCBowVzEdee9WvqvXV7KuMTufiL0hA8BHvtD6VFvEZ6eiqgvN\n", + "8RNM5cYXQ2i+4Lx4R2QlAIN1NNxqM8GvSjSh/rgipqY8DwHJh8p9Jbu0Zs+w86pgxJN8m/cvWxRZ\n", + "yFAtI7sBhDbJnNXx83ll0o93YVJhxi0TxWXPf6PlHZeEyvr6QOF2VVafQjsZUg34P/p6tj3lkAer\n", + "aZouLIrbfbTrpoGdtXuXR2qC418s780GZsUBVTlvppC7dgGYqQzB5daoV61BoiIg6tQyG20Yk/Ib\n", + "TtwSJmeU5Eiu/zRo0bpbU2jgV79WVCB/SVzxsmoD1jJEhzN1FHxsbajOijl9Vp76GofsezNr+37n\n", + "UWWhPPzCk1rCLQgaI34ekcMUWq/vBK2WDe7wKACe/5M5UglN5Ct9Orsd3SfYPc0336usW56marFA\n", + "xW2XgVLc1GludnoFyQrT+oASHSl68jJc1j3I4WTIeU/p+eW8RtUF4AAAAR4BnutqQ38Ai1egJmdK\n", + "YqnGBlYUAF9obzNVJ+s4Wyt0Rq0YuZmzKSClvCu/741bUzMW9+2RqBxHf8xROd9WCD2DFO6m3iiG\n", + "ZOgLMC6WQsGlrWDKBATBQkW8M70y/ztO1ZzNQj1ow5FREW75+T8qWeYnaEkP0sDPfhS/8A++EHpT\n", + "ONUZpoNHugOpCj8EFvE/MnQhkWbqDB+V4zYJeD+V1h9PGTTPeM5Ykyq4ZMi+8E5Gka9dd2CFXMaQ\n", + "M99mRo+FOH0+y87A4U4JusoMgrnGwBHn7tNdR1Jgk+wKYqmIwBj2jGPnQFJXhHhE3ZkpIjaeakM2\n", + "8MH5c8xC359KRjK1nfiZHGSkxS98YPps7lGGiAJ2WdM/l0XaVpItX1VPHy/wAAACGUGa8EmoQWyZ\n", + "TAgj//61KoAWNzc2A41R+LAApun++OIZUz7EikV/szjfxvYPLx+f9K2/F/he8DHawkBMdV2wRLxA\n", + "t50GIuRUSWE/39Xo4nAQqkjDTJdufKMgNIx0erMAcY2QA5ejjVo1tlzncJOxCqGpuGwA+5/4IKyu\n", + "bmTzdPecTw0ZdpVPq5j/sb/uUTmyS5oriK2QJUn4uMhurpWU0pM90BFHxmx/55iJQnC/E4AiRjGv\n", + "TSfvy9eol7L6q3/AmWDGKQmta5h6TQecJSS7keMMTmFMkcgh+dQEUTFbphGIZpTz6vxfkWPPyqpQ\n", + "VmS0gectGBeLssajkGiu1ivhXeMUvGnpqjpc6XSD8FJ8sVdfwdsse9JozsVq/t5YFq5+AnEYcopl\n", + "mlIiLVwif6/glDa/FvPVZyUrYuYY9L3TA7eEHe1IcHWSOPxpnafEFBrVGoeZPrbfymiVcHOQ/3CX\n", + "aGrpVwdWrmOHr8jLuajUxWOW37ajHobcyT1hYWMxRTx80fZmsfvsrNw/Nztdx7LidHGE8jPZ4gQZ\n", + "DABlByR/bof6mTmjqkfbsR1PCXy4RDNnn9nCnaSnb8pCApsF6YsDTv0+UmVzx2ZPSdm2LhZIqOim\n", + "mhiXHWt+ZE1dnYkLwTdsgNYEeAUTjY5XG25CAykSMfKGwGWeeOwqKmLAqTmb7mCXXxxpy4+bbELo\n", + "RAxOLFOR7z+Rlt4VIVMH4QAAASRBnw5FFSw7/wBiyP2mEJvZyVx6ACpM7CM8ZBKHKR5j7ndOem+L\n", + "X5lQTliSlHrc19blDxI+BarmPxVVRFr/CorqLGvI+vHNUfF9L5rOth1seL+LchCRD6bYXJMlctoQ\n", + "KBnrSfN8OsFA3rCX0rxhgXIKgdEDuCNRYd4XCiw0AyO8VPwgQ3UKQOwN4T9AdwOVZht3xWSjlGSY\n", + "LTfR+DOcni9vpFUI/V99yTFNeriW/Ezi0Mmb4Xp+UrrTAn+/oqePQryHATZ97i1I4TzdZJ6ol421\n", + "ZZiGDIa6I2z+mz36WJISXYfn5PcaqZon5evy7wkHdXdLSXQuyy6RoW3UMK1kv4eYGMx6MEUBV881\n", + "1DxJ4Az2tfQhJ60iq3lK6xGARpoGTWiGA3pBAAABAwGfLXRDfwCHPtdry+v+2nyY2Sk+gF5YW5HN\n", + "XoAL6QRR4alJgXnPRJGLu1H/XzBsCOVwj2OHZ7/Befz18ioG7PdTUWTo/DFmzXwFwKSHq5MESJ/K\n", + "+czoaBaMU0SilMUvvgF9NaNkzEcYOJjCpUUkl+lvc9iWY7aNcNT0YkO2YuPLl1ZJa6XpXyzgvJfC\n", + "YABMMMlHP4hWdgac8C4JyYJle4OEiXwhanMhhDIkpZpmZqqPP6iXGzuSTb+0ZDMJHqoDGqJmkb8S\n", + "IJuvyZGNE4panvJTPVd9f7g4/aXxMPm3Cn3wfT3mTthI056NzanOEWKjM1qGy4olpTOi0cV3zUKu\n", + "VGl1k7sAAAHXAZ8vakN/AInJcXImIY9AsY+/nZAB2XUf7nMR8KlDfCSlxubwbY5yyAvaK6FdhjtI\n", + "iTEMX/gD5nqi6yBjPV+WgerMVdQiwmsTWCh4ZDRMTEvRNiTK06p6H4BM93iWfwAaKh8Gz9Gaukwy\n", + "InHLEZ0yD1XqM2twrrM9K/zMIWUOeN0Z6Qpdges4mCaPjYBUMA0KTxEuHmES85gUYlt0s0Ks9Nu+\n", + "2hfyb2t0rmyvRs70WgBBgYrdeTZMCwmoCbRHPK4oxsSlCang/p1gu/DmbjnwYRln/v7ufz7R3gdP\n", + "Fr7XrHKEZc+f98DBxQMF82PBbmDGtLAQXHwptz6g5mqHfaJhvvgj78jkqTGrQ4WXMBaKzHGNvGYe\n", + "XIR0bHtcMMQd0uz0UHs+NS8bhlZ93PGBn0DI4S7X4qFOiND2PCIg5ogjbfFqU4Kuh5oLH4L3vi2E\n", + "bzWP7DaofhwjMqjCqAvZAgznNJDsvnJzQxJ6Pqjj2ny04t1drdQRUisSLN+PcLenLQZbe401Xg2H\n", + "yhW845ouHrITGSqb9EOEeoN97gj42PjsdYRMVLRDVvCV2BOAqdLbEmICPHZnyy75qPsejK7duPuc\n", + "fJ9rEnjynB/HxYz7zf/RM6xyYbzIoc3AAAACEkGbNEmoQWyZTAgj//61KoAbj1lLPyvb6PAZgAh9\n", + "7f/9/gX2SHKs8Uq31kdycpXc3bf6XPCYn1E4Nyshm7SbxYTXwR3t77AgzFtBuE6fBgZeY48yXmAW\n", + "rqOr3iMlgArjVOjemrjz47grY/T9rKmhvhaqPi8pvZTzkzZCl+tV6nzXVbBFw15yZW9xk2z611V7\n", + "GITjv5GH4Oi/06B5IbjEMVKEcRpvt893HwIyUBXniM9I90uh0TBxOedvsxxE2iLZsr/m/GNXryb+\n", + "9as6btju6GU5FfXHAHKy97PxI2Rac5Rx/FoPiuKEecRx7EQrDfRmlggPPP63oMY4jkBeTzC7Drwp\n", + "8ik2Z4rhoAMWlcRPfXCI56oe4Jt09oRInuaD3ww9/jGDjhHIXGbNYM/s5UG1XuYLCqaLxESIyPG/\n", + "eNnETthXX/QZDvDCFX3YINANkqDvHlUQ+vcUvksaWF/g1aVcMu45c8BoP1coWBAVWVE6iyDMwfYl\n", + "RYTcnNfp26mpOfqiSJnYH+AFj0qGJttgeZBuJCzdV4F5EDreo0WWAiq/0jdXljJ+ZxDij/UazQOM\n", + "0ct15Q7rTOqLKy+lpOVa/koSWj06e8eyy0wY1FBSVaROGYbDgXze1QzYiVyP6+WTk1fjz+Do+J+/\n", + "TxVlHJsfUOz0tbPJ3R4cSjRVigTxPg9VAYynpzzMlIr0/pCOGd4XYyl3SGTwAAABOUGfUkUVLDv/\n", + "AGU2ltMhgssRVFnYDYHdfwUIOpARUIP1pWfDHpU2pf97OTOpyP7SrW+j72yMHgCy10/KQJvVenOE\n", + "eMrSHUfyq6lVIsdEDgl0M+/NXx5VMpg+IZB+I7xozsY2f0ARjiAjA8ZSqG32YEqaGwpGp+vfKL3P\n", + "hav1CfnyaUmopPCa0Y5ww/PZN4YINPOwE+Gg36kaKP/ME/B0d8v00CzvLXmI8pIa3TqrGIa7PF4X\n", + "8miGO6oXkRH45ag0gFdgkGj+BD1PvtIptIkuqTa5jzG/NewDN9cCfws/hjc474K6NoCTyr++7Tth\n", + "LSIM60DcVje0csuhEMwOmCNob99l/AJp/9hMVsVsEaxUNsWBZFMKnZoLJU/ljkNlTtF1zcUwJoZD\n", + "oLTT6FmWVzlFnyfjiJdVIqMAAYsAAAIPAZ9xdEN/AI8+s1VkrBucudR5tN1L4cUDsugAOgW+6weD\n", + "VD4WeLhja/JOA5FtORnuW7CfHWfWrXcPJlwit0rQdaNL8wYmpMOBxVMKErdopYTnWfb0EZST9ZFP\n", + "kGeAI5wBNyE7pmk7U/hz6/Uncd5yONsvInzdtLdlFGIUuwPsZsiC4nxcPKJ4ER73zqMcPC62dMwB\n", + "YeP2JTSzcWxmsY8AuUeSUMff3wugzCWo2dZWIqj8MEevc9dnI6e4RX4rfqOmeKfJ7QFxuPllAOzz\n", + "FkyERujhdmr2mdRExctZgI01tg+iF/NwBCqP+hQ0BZaq12BgDPwBcWyuj8PXGo/75aroqbic3atK\n", + "78lcQoP6TccBH3q4TpJbdFKZCXZFrS7Hh71ZQxzuADlZ8DDRzGHyvFJs8+7LX0Z3SVEeli/7hzNR\n", + "3en2BovQV52x/rwTox00ojUHS89/I6QK5rr9xZ5z1Evdog7ewBETCofR8FQPxE+2X576ofb9SYpa\n", + "RU+FFWJ4WPQBj/u1ljXdmoINHOgs90YcpGG37DHSgRaxKh3h9samVWdsr/7ZPH7Krx9nfE8zJoXc\n", + "5Frf0sUOO22BhUTf6MatKarbA54SuNAmIi3ejRZKQJ4XCjhpsLBrmw33yy9Nk6OT0LCi0ELysL29\n", + "OvbOK/J+/iRz4bP6v+/3ppYXG9MzSEeggmS96wm6yOsevJy9wrAAAAHWAZ9zakN/AIdXwVSZADwX\n", + "ZeAC6HD/yFRsSkP+ZT/GPlFXimE8PIk5/ho1VfL2NNL2pqViOd6YYnwc7ksNMs5IkNYQ+fdC2XMm\n", + "GpZcBQdS+anJcAkZpOHFxqdIo1pLhI3h3bcsWXXBd+BTXZhbA2JSmhm8EWBGqSBNaO0U3Qcdcea5\n", + "428f3xthr08dSK0oFN+HNErgBuKfL3JZNShDHaW66u0MaG1B/cF2Go8z1F6LGKUAmsy0D/C2CM25\n", + "q38c827dgYTnZjZnTFxlPuxm+JuWvYpOeWyy3J/wjV/USVL+4BKz61/Ccy+EH/JkQUqRmUOtvYei\n", + "XxTdexyug9nI6kyTGc2H3hy0C3uFxKKFKo9PfiwDCQWhQ1+vZIsII4FYexn+pQbkz5kmdlWKB5Lx\n", + "ONpNVggWvIuTYEFI34NTLTOf285YYkebB68ywIJ5f1uX/OXMZ5RxH3gjNZ8mKLNX9suvs06qOt/Q\n", + "e2ZfZ7Orgt/l3O7GLxwWvzugIsO88I1KhpZhgYDdYZ//1lVBcwG/tKVYjF1obqjtyFctY9LPGIag\n", + "318ehZmIvkhW9djj90e+pnWknudbQDv3Os17s3l7qFADdqSGqYyGaSU47a6O12HCRSwmepV1bewA\n", + "AAIrQZt4SahBbJlMCCH//qpVAC8LE+AX+ndLRI9AAL65x3/f4eNbK2tvWi3seP5qm31GHdf4edmk\n", + "0/ZKv9BuxjUGH/qoYxXDUlaWZFHb65x0lomfbckqRBtklU+1LGTmYtvnPAbKnUSAh/jTBATZpFND\n", + "l6V6ofQ5PTBcFjOWwgI6YqalXUkmqnN6g77O4xvodhM7XQWhsA44ADmvatn61wvReF9d9MqoCN9N\n", + "Twpkx2kbbrSoHJrSyqidCsv+e2gnLoWDEdLGn/42++dseweQBj40iKRQ7paDrpDRwTZVjGQJ+52c\n", + "gaUSUp5A/cAn4FgESmp/sZ0NpfD9/7ZAmCbSUfPUar6ndxZ3XG2DXWcNFu473rzFQZNpJnXg/Pfh\n", + "QCQDuu/iX2Vi2NjGs1QVI3BReUxvD8Z/YeLy6w0jDh9dcJGJdKoNjb9Epdy5r0lFeFb9L8AWhdEd\n", + "sGreMPdTiMRlq+JOqjdogseyQTcuDo5iesxIsb0dhY+P9VqSJtTxyPO42dn6TXPZDgt1vROlp+Ic\n", + "VTutbib7FY5U+jSckVQsLzLRwDuIoa+HpEcHjzuwHMaHrKVljgiPeRI3Afdpqx3nHgy0MFCOhGEr\n", + "Jkw+Dadh5qrWjCGOX2K5HPLV0E5qw7krTDhpWX8sTsYsIqvxr/V2EjIFiKwnheBvunmhlbHNUKTl\n", + "ykWRC9Afa8QE+vO8sLJHYNqVh5kOrsn0+NP1Mm4JPbYiahSDJa4o8TJzkXFBAAABAkGflkUVLDv/\n", + "AGBJAvfAgTZO/kHo4lc9yaSVZkgaxkXEQAgySaAqoJy8U1XmJXFaLzsHv4KqZnckX0gP1AYFUr5X\n", + "3Zof5zltHp7OQG87KhkyMuJLOz4diYjf3ctsH2KA3/S29L1hP4qjZ9kfgNEsjrH/nSlX3ikiiFcQ\n", + "/2mu5vwlzQMTIUj5/0pAslvbULpI2rwxcgfjtpeW3qe/Q0sCZXyJ3L7VhEaeyKZo/ALUAi114xdn\n", + "Gao6fyKpZhWohGCsI53i8XO3Y7Dq+aD4ONx4A265BL770fTZiNNw+oM7dwTK1vcPMdOTVjz4fi6j\n", + "bCMBPzMCGM7CsAz7OQTIKiUTlOi8YAAAAakBn7V0Q38AeTG7snd+wR+ioRwfka+slSBm7w4HiigA\n", + "mYoe7RzT8waKJhe/5/xyHdk2lI4Qb6yur2vWdYx/k/gVzZWx+dAAALHLM2W5kE06MD+/WY8W9vMg\n", + "jgsWx+NCob+sUo3r0m3kC7Z6vE5pa/kp8NVK1XizBU/gSaY6/S/NP+nzZeAUHhvnb6LPnQnTmhI7\n", + "+CLAa1UiK6P+lwPbKP0S0Q5RWiopmhls/AKTmwxXB+WRWyrrFglLMCCi/H7yBlZCPn3f1nUi1WXW\n", + "txmtCNftDVTPLfu3fbw+YSszpG0LQoe/d+Hn14JtNEXcVveVKgdRtrJ2SZSzkDZoD5uTokEopKbG\n", + "geSmsxJSe6mDenK/tstnSjFiozTKWgyJb1mTK9iBWStV+uPeceDypkgatRgkwgz17Zgn457UL8xo\n", + "RIb3Rzvhn1PaM6KKHv4wQMqvpqRXKRm+SScKgBhgUzc706tHx+sk3QXrFbfmTj3VwEqpASdMV8SQ\n", + "Rc7Pl7VdiwexHM38nPcgZguGyvH4NF1CZay1mT9d+wee9MfU3VHZJgMp057sUGFJIJZNmQAAASYB\n", + "n7dqQ38Ah1fDGltbSoFNBABy4LNfpqaOuQiA03rsvInHR01iNZMDGQE2sq9jRvjWYcCsjv8TgHDx\n", + "TelM9UgK8aIkbW5xZBO7YH31DMzHB/HcoCKmBUni45/7i/CIo8gF1pGPr0DAA7wV6D09MIgWLTIz\n", + "u2RlgzWHXLOhQSqpesq6gEgghz4eO+szzJWiaji2cgnbFYV7gS1iXMpBIisJc8i3U9gywhFgtGxt\n", + "IPW/7TiYEwGOLwxyjZX1HkROuSI8lAAdZBpungwbYVpPKSngzu3PnOIcBqes7c29MHD8jRPn7Zrt\n", + "720E/jZ4jB2yT62h5AEs+TCYeJmiY6lwGwXm58hIVqeMFafCwAYhd3vDCtfE6mymrvYwtLYQ0YeE\n", + "Ebj2MbA5+zEAAAFwQZu6SahBbJlMFEwR//61KoAWx89GABUe1i4OfaowcQHQyqHCv9PnwkHOB5jh\n", + "ZaY1nqaJvfgMHLxnx0HRU319XsFiIgZ3fycxZ7MoTbod+V6rFy2y2Qtld8RvCt0Ug4PVQuLFLU9x\n", + "N6gbeWntqj92UVkXYHO8rtnoyHbc5vkyDRwK85+1rEknOmV2fCPAJQWJQHZKzqn/akJ6R91HlWya\n", + "u/8GgP8q7KTtX0XyZMALsB3jT/UhmW5AlGIwNHeW1rtDiMG/Xy+69i+m2kTOjww4y5o0/8WfwLLR\n", + "RKlhEE1LYjJQjoy3+hNy7YguxzdtR0GOg0UsPQLFZIBnnCwGmFharg9MSkzKoZck80tBnNzVcu5F\n", + "Ot8W+bdDLv2E/9UTXci1RXlM26z5jearPa/9d/CciU6kElsImbzJ5J2YpzVs+pvW89XbvAJMExZq\n", + "wXD26iUkefzti1p2cc2CbM5qN5CGCTCmR13du1Y9J/JQwXkxhEAAAAFiAZ/ZakN/AHwUpp6Dymc0\n", + "2L536BR5shJlFypABdlGcrzfdaw/6f5GB/atQKmEnLjISTsAvG6zfbdBMs7bm2yeFrIQxXuK81kC\n", + "9pAAAXcBlvswH72knWeKBsU0Ht1g5h3YcKtQv4e82ah693wXobc+mdHgPA3TBKIFWUv/iM+/E90G\n", + "S/NmTeZC+lgt/zT/+HMt/QSFK9C1+AMdH9l6Wmy5eJzA8pumBNuqAArwclv8LW1AC9Ryj7J7dIqZ\n", + "2nhKIYQ08cavMFAGExrDHt7RiTs4Auer+jpijDT1MWhCFcQjNZn9nbOp1MdYUZ3batlHR94YKH39\n", + "SB9iaEe1H+vDrSDRsP3b0PfVLevCUtQQ7tTMju5YxLigI0SkXHby6oMGwH35DOmYdZ/QEHihEbbH\n", + "ljlaWypqm6TR7b/zNBCPoaZiHS0IlbTr/gzMbXxGasP7GssB89XtUV2jZihKJYcij8456L2VAAAC\n", + "WkGb3knhClJlMCCH//6qVQAvW48vGhnpxPcAFRvWsRQfCH0ZQNKlkI/Fmy/VFBZqjdqwlFWyRDRU\n", + "ATa/x8nSCThm/LYIboN0iejGj3Uchm8nyLv3P3+HOOnCw7+XGsyycSpaT/SKI8hu4RwjrdDxqaYn\n", + "k6pZ6qjZtX+IZ04XS8X44piBkZKHHklQnddyez3eJG0JjT0fN5b/c72jAD+sOeXlR6iPKkSUzu0o\n", + "3ha2oHN6UEDmISbP1cbB3piI/SHrisHlFNjIuHiEdkqSzG95tlcEE5RmJMFHyIZtmV+VUnHUg//H\n", + "WOVjyT0+oFlaS4c8th8dtoQJgchjo9u+OPpSDxEJgWI6zeeh28ogNTGzlwRqjfRSsrTItvjA1MD/\n", + "oBFhKLk5Gm5LLSkMpDHu9T5I2IaoH3PKDFRJp5FswrHAqK+C6EMiKJRw3UfQ++e71IzTL0xpDNJL\n", + "z6AeitOHT7WHH1q0lcaxtRKIXyzlri2FOeAU+zEh7DbcM3wvbzCPYrbD4ePmP1flYALif0DM+F20\n", + "woqO1ciEp6KvfcdLwkVhOi6HukmunTXGsruYaqjkaLT2QlUIMJVPTAaXGvEAsJSG/0vfsDXKkk6Z\n", + "sB3ElNrSO3yHej1aIEgW5xnCNisEQsWn6TKnOYGilPN4ZN8EB64V0F8PWNB9Aq0baX+T8kKesmFw\n", + "2y/668NRP8ypn4s+0TEew3V5nLH+An+XxWolypflMoVnWhEhG2W+IIgxfWfPuSgDmqBKtSemnfnO\n", + "mj2z1HJ4yEmqNoBjJwYnWfK8e0PHHb381Mk1zGGJOgWAAAABUEGf/EU0TDv/AFlVerlP4Rak+BQA\n", + "rfH1MAekqKZtO9rI3YpPu0XbIusXd4D2mikBBjNWCs5ZCx1/nIkAW78LpHSyCScRX686DgqeELvg\n", + "+6gjEvz9oPv/Q5SyPMBeMNrb/QJ3ato+Qw19nLJWjl0bduh+HilMsrklIYKHCWBaC/dNC4s7Xl/r\n", + "RCzM7ZJuRKmUY/D5sEAdr/H6TIVmiD0u2jiehC8y8Gw6flB5fdlWyz5ArpMes88RS9cHH1n4Dp5A\n", + "9YiKoxa6XsjMVtwy/Q1CE1CcjEE8nX1x2wi3FF+AiuFwqQsSRlHtfUsVksDBdXLvE8zjbyOIuIMV\n", + "pnJU22cEHHqRAVAAAQz/a8I3JUwtCYefKDlHQuITIdlhxtkj1S9/MOKY0At1R1tnioLMWN7HUVCo\n", + "b6XS9uoGwS6oOJgKcTFbR1vNa4wchWq0XCPds0DBwQAAAPYBnht0Q38AeTSjvudgsbkOLNHOwJSE\n", + "7MIAOT4Tae/DlzyAOhFcKHSt+XmND2K3krM1WAe1ksxoXOx8R5ib25iI4yoXHAvjcPvcDoLvQIYy\n", + "rfzkEj8FCsgVqTty2M7mcrrsvBMmGI/tSEAq1Wpq/wSUg2I4oZj0GjiChzewD+uw3YnWAi/Ntf5Y\n", + "Cv2dU9qEo9e3jPCavhxnj6HVQyqcvxekJ6cEcAGQvRh8PwiQyys4LYMz+Th6jmnZO6zDQlY1h459\n", + "aXiX/1NPDVjhvbOibPxdXy1nW8ZFN/ZpmMtUtTAz4mvuGfLCJYTZv8r0n1cztBPRieehovEAAAGy\n", + "AZ4dakN/AHwTrqiSAEDVZr7cfUIfCi6SEtf6z4BBmn/qEvCbGFYoG0hJzipIIEfgPxGLOPb5hgYo\n", + "3EqlxYfhyi3ADlPB0rSvUe/2K1c1bOHHkBdbN7v2fRCe6cTgBUViIyBzKbW8+YVzs1NjLsftvDLF\n", + "Jws+AVbFUOsz2XZO6+tJqS4okplORVfI8Zh8pjE7ly6+HI7Omo301kEp6VZks8VHiVKJOuTRsuFe\n", + "1lak9cDIgZS7IV3MkEjdmu8V6wPVTOui5KhgRegdKpe7dvKwiZROacSHUyEpgoiQ49NAkgd9ICSC\n", + "nOG96XtcVUK5qLGXI1ECEXtJcuaFVMtCmmOBBiFL8jC1MpHbxQ+4k2qRSUjP3JvFi0NfrsxeXbrH\n", + "Ebg5vBmNpJE6T+wdC73c70xC+Mtp+wYFzu5kfTKcL8d+Nzu4GlIr338e6SWwNSpXRGjfdLp9o3Ic\n", + "2PzMtQmrlpbEeUDp1vnkaZoqSF5M9xanIk/zohgoPX5++NN/ebYvr56WROjUeIUdsOf6nrJlmboT\n", + "DZEat6r4aY15lVCgiz4Mpb/mqSazxzrszmdRYRxGsW8DnzAAAAHfQZoCSahBaJlMCHf//qmWALFy\n", + "5oM61QiAB+cxK4+jNCOHXw6RALujtnWF0llKsvjvaSIz+44BdTBn8Dqmduydu0Ab2yYLL8rBa9BR\n", + "bM/WBrO6FCt4pfpaT57HiAbORTevnWHgnUCdwsiqbddvhjkiuJYbgCMD0kEP1SURu/b2Z5hWsq5s\n", + "eIdJwlVUmffx/GFsHH2OVg2kldaudIzyWEsMXsnZccvZ4+1TTMECSDKdUtlhUW9AAgPUraaePKP1\n", + "hatMAsKbsEP5g1nzjTlmyHjs7FjRbwjKng4/qsqVQ+s9Z8Le9mq44VPerxrlkKxdRgf8PQXTEpxP\n", + "gMR8UP9I/vRSJBbzTafYsMhPytfC8ESUe9ySga0pNZKSvC+bN1h7zO9OEjqF3rsnXJU2SZN7NAbS\n", + "01WCPkWQIdWN39TZ8BwhuM2E1/XfXA9OxCI/7PAG40Z8M1rKVJPTY+iwZnIQA6cEF3rnJVasn/JZ\n", + "rircnzzi1JQr5NiwthCEkD02k7GAoyHtF8lIKArvw+GqH7Ox1Tpd6DhPPJm2hmyijeFH6E+9UCJk\n", + "Iiolc9K3UW1rmUlHlF/p9jHAvsiiJUpuG/KCfna2LEYj9yn6P2oNlWfqq5P2HNtctaJeVRZv9Qb/\n", + "mNVjyjAAAAErQZ4gRREsO/8AZUEtk8LzOoS4AAhIFC88oI10PfUAs3UxxCOOtSzHREgn4/jgVfHt\n", + "0r483Tf2Y8D+zGlycQw2lUV6Nidlo0k0sASUCm4dEwF8Hb0+IzseFE0dYexJdLqvhcI7IIUIH6RG\n", + "uv8cjTXFD8CTksvYGpGc+uBYXhlwc3/jHhNGtm8G24uHniey+Zy/NtEpSl5dub3bE324kx+/N1gF\n", + "sU/CxkQF6UQWvd6Br4nL+i2L6udCLqM/JAVJhScc01UR/bE+NX2i3upx0qofgxfWL8unNZ/BP9Vc\n", + "CvVXAtxPw+0JopAnWMlwtBFG9wd+oP4zOIJ88u/VEvyZQd0JJP1Y3qhYk13Deyiv0C1r6ci1z7CQ\n", + "UwYqgUT64pT/hlIvHeCzEZxqH+WbUbEAAAGYAZ5fdEN/AIteE+hbrZmAAHNd3/IVGxTYP4E6C+Wr\n", + "63le3xAHjzqOqEil1tIAAUY3LvF62/277H30QskV8sEjceHvPe7bE0mfZ44avBY2gS0AAAMByRDk\n", + "EKOyh31Y2H0mdsy+zcGsPrGm3pHtO2riBcgILxHO0F5398HG90hK8UgtDUfp9CQyPOvDSyEU4WTb\n", + "6/WT9Z3aca6tb4C53W6p8Geyjq/mwbvNpnCVbbqIcx1ZT2+dencovmeYmPlI7jrhk6KwLYEd+5gO\n", + "J2YeKk4iWai6BsaO9+Tb5P52jBVHcSZ+Vws5QhTxkBSpdHlWJRcbh50V4ViVltwUN//XNx+jx2bk\n", + "KsfglI41FGmS2xAJtr8ZhKDk1VRRL2tGsNB5nztuRXCFd8q4MIuVVWGjim0ntcxZ/R18mzJZN+sI\n", + "qKUvfsxoaeZp+oIaU1hLeXzgcHEe+3/6emdZeJWoDNhUqhkfWzWzVZbEzUKpDBS9AbVIA5KR27LD\n", + "3HEfRMw9yt8eYILg7m/Rm2ubtU8u6V2QuxVXq1OHry5oY2TAAAABvQGeQWpDfwCPV5unds/RGF4o\n", + "aWlq+XwTSVpG+igacFOApaqyNJIXSXT4q7gA4DkP0YAYAumNCN0MwD7HSEeIsv3Q3L9kZ2RagxvU\n", + "jle4yQq6Zl5W7AgdlZnaBngH/w8xYsqWx5t90zzi7s9VyRY9jaNshfxuJAZcRgFILNTmQNCPoCtl\n", + "wyo5Ht91VCy2qSby6JDLeTD096PzM4KOK7/I+amuefuT0S/QnDNs952oi11JV2mbadqtKDqJE9x4\n", + "nX/OjU9PBP1uhsFLNkjsz6ZHlTOcsZvWUxabbw0HBNFuLXWIYqtAYdWN7c/QUoqY2IlVBR//v+NN\n", + "Bxf/rxPv+9QlTTeUOAVhzyU/kQACorW+VEL2KFNUPF85LUxlbSGEYQv/98/fAQAu6hKRw3yoJoPy\n", + "tyr7S7Za9gGurMYseuvuasNoB+fPCmp37VWgm4yNZQ0LM+8CPtaQgShVMs2/RIG2cXksHuYVqEB7\n", + "PJtzP2tl8EYDen8RohIb2UO5d/Xdc8aoi/Nu4IzGq8ApuZIxjC5J9bUYtMDEDA6eChGKPjb20vqg\n", + "2PRBI2fSXJrcSROGTC4m+VsF+VagO1LnjrakndEAAAHtQZpDSahBbJlMCG///qeEAVH55ayIAL6z\n", + "9D9Go2JR/VsPgULYIy+HM1JNQWUio64eqKV59gHDbxQ77xKGvVi/RlMeepNHF+Cplpp4rKqgivaK\n", + "14o0jVVjKwdzXmYfm8QJck76NrSj9rXzMi3Th9DbQ5HQHvlFr1+Ft6fGVXaubVoF+Bx3J4nvsWO+\n", + "FhXDphKaWh9geM/3PqX1TK4zqhRL2wKgDCWdLvIi2s2e48RSWR1zksj0SjkMINJfgjA7wVj0dW8Z\n", + "NZGlcRPjgkoSgpomI+x9/l7dJ5fHEj4WOkMQMTJnj+KOqaXfgtXbhBachZ0Av1Z6rh+qw/iObJOy\n", + "7q2gUdlftEWI7In7KZjqqg18Bg+z35wI2FmknOyXdEiDAPaFiRrhqkKOLfgLssw1BdohiuTGWlKn\n", + "NvPL4EzIbAUeS+0qv5cFdXvRjnn1zOMYTMpyN1CZYg4pqjj8mGtGdm1F7w0Xo4Mnm3hRmvZyyOaW\n", + "yf38s1SCwyOkhQcwJhrAAebvkxMWrAUWrTq9K9PdCUqFbMVB9+93aovoux8zBfM/WLangtLLXd/D\n", + "T9TcgY0eosWGZeAhQk2sxNC3bgvMT328AT2T2XCg2nG4jsOakPWfscwbc0zKfItj/1eXvyR2tk+K\n", + "fpgdg9dJ/OdcXINTUAAAB95tb292AAAAbG12aGQAAAAAAAAAAAAAAAAAAAPoAAAnEAABAAABAAAA\n", + "AAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAEAAAAAAAAAAAAAAAAAAEAAAAAAAAAAAAAAAAAAAAAA\n", + "AAAAAAAAAAAAAAAAAAACAAAHCHRyYWsAAABcdGtoZAAAAAMAAAAAAAAAAAAAAAEAAAAAAAAnEAAA\n", + "AAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAEAAAAAAAAAAAAAAAAAAEAAAAABsAAAASAA\n", + "AAAAACRlZHRzAAAAHGVsc3QAAAAAAAAAAQAAJxAAAAgAAAEAAAAABoBtZGlhAAAAIG1kaGQAAAAA\n", + "AAAAAAAAAAAAACgAAAGQAFXEAAAAAAAtaGRscgAAAAAAAAAAdmlkZQAAAAAAAAAAAAAAAFZpZGVv\n", + "SGFuZGxlcgAAAAYrbWluZgAAABR2bWhkAAAAAQAAAAAAAAAAAAAAJGRpbmYAAAAcZHJlZgAAAAAA\n", + "AAABAAAADHVybCAAAAABAAAF63N0YmwAAACzc3RzZAAAAAAAAAABAAAAo2F2YzEAAAAAAAAAAQAA\n", + "AAAAAAAAAAAAAAAAAAABsAEgAEgAAABIAAAAAAAAAAEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\n", + "AAAAAAAAAAAY//8AAAAxYXZjQwFkABX/4QAYZ2QAFazZQbCWhAAAAwAEAAADAFA8WLZYAQAGaOvj\n", + "yyLAAAAAHHV1aWRraEDyXyRPxbo5pRvPAyPzAAAAAAAAABhzdHRzAAAAAAAAAAEAAABkAAAEAAAA\n", + "ABRzdHNzAAAAAAAAAAEAAAABAAADMGN0dHMAAAAAAAAAZAAAAAEAAAgAAAAAAQAAFAAAAAABAAAI\n", + "AAAAAAEAAAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAABQA\n", + "AAAAAQAACAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAA\n", + "AAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAABQAAAAAAQAACAAAAAABAAAAAAAA\n", + "AAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAABAAAMAAAAAAEAAAQAAAAA\n", + "AQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAAB\n", + "AAAEAAAAAAEAABQAAAAAAQAACAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEA\n", + "AAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAABQAAAAAAQAA\n", + "CAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAABAAAM\n", + "AAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgA\n", + "AAAAAQAAAAAAAAABAAAEAAAAAAEAABQAAAAAAQAACAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAA\n", + "AAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAA\n", + "AAEAABQAAAAAAQAACAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAA\n", + "AQAABAAAAAABAAAMAAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAAB\n", + "AAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAAAgAAAAAHHN0c2MAAAAAAAAAAQAAAAEA\n", + "AABkAAAAAQAAAaRzdHN6AAAAAAAAAAAAAABkAAAGhgAAAl8AAAFjAAAAvgAAAXYAAAHzAAABDgAA\n", + "ATYAAAFIAAAB9QAAAOIAAAD6AAABWgAAAbAAAADTAAAB8wAAAN4AAAH+AAABEAAAAOIAAAG2AAAC\n", + "DAAAAWUAAAGkAAABmgAAAckAAAEdAAABfQAAAPMAAAFxAAABIgAAAjYAAAEmAAAA5AAAAXoAAAH+\n", + "AAAA/wAAAT0AAAFnAAACAwAAARQAAAE3AAABTwAAAckAAADrAAACFwAAAP0AAAHzAAABIQAAAOAA\n", + "AAHKAAACOwAAAVQAAAHFAAABugAAAdQAAAD3AAABUgAAARIAAAFuAAABLwAAAhAAAAERAAAA9gAA\n", + "AZkAAAIqAAABIgAAAV0AAAGIAAACSgAAASgAAAFEAAABggAAAegAAAD+AAACCgAAASIAAAIdAAAB\n", + "KAAAAQcAAAHbAAACFgAAAT0AAAITAAAB2gAAAi8AAAEGAAABrQAAASoAAAF0AAABZgAAAl4AAAFU\n", + "AAAA+gAAAbYAAAHjAAABLwAAAZwAAAHBAAAB8QAAABRzdGNvAAAAAAAAAAEAAAAsAAAAYnVkdGEA\n", + "AABabWV0YQAAAAAAAAAhaGRscgAAAAAAAAAAbWRpcmFwcGwAAAAAAAAAAAAAAAAtaWxzdAAAACWp\n", + "dG9vAAAAHWRhdGEAAAABAAAAAExhdmY1Ny44My4xMDA=\n", + "\"\u003e\n", + " Your browser does not support the video tag.\n", + "\u003c/video\u003e" + ], + "text/plain": [ + "\u003cIPython.core.display.HTML at 0x7f84b2253b50\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "import time\n", + "import traceback\n", + "\n", + "from matplotlib import pyplot as plt\n", + "from matplotlib import animation as anim\n", + "import tensorflow as tf\n", + "from tensorflow.contrib import autograph as ag\n", + "from IPython import display\n", + "\n", + "\n", + "@ag.do_not_convert(ag.RunMode.PY_FUNC)\n", + "def render(boards):\n", + " fig = plt.figure()\n", + "\n", + " ims = []\n", + " for b in boards:\n", + " im = plt.imshow(b, interpolation='none')\n", + " im.axes.get_xaxis().set_visible(False)\n", + " im.axes.get_yaxis().set_visible(False)\n", + " ims.append([im])\n", + "\n", + " try:\n", + " ani = anim.ArtistAnimation(\n", + " fig, ims, interval=100, blit=True, repeat_delay=5000)\n", + " plt.close()\n", + "\n", + " display.display(display.HTML(ani.to_html5_video()))\n", + " except RuntimeError:\n", + " print('Coult not render animation:')\n", + " traceback.print_exc()\n", + "\n", + "\n", + "def gol_episode(board):\n", + " directions = tf.constant(\n", + " ((-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)))\n", + "\n", + " new_board = []\n", + " ag.set_element_type(new_board, tf.int32)\n", + "\n", + " for i in range(len(board)):\n", + " for j in range(len(board[i])):\n", + " num_neighbors = 0\n", + " for d in directions:\n", + " ni = i + d[0]\n", + " nj = j + d[1]\n", + " if ni \u003e= 0 and nj \u003e= 0 and ni \u003c len(board) and nj \u003c len(board[i]):\n", + " num_neighbors += board[ni][nj]\n", + " \n", + " new_cell = 0\n", + " if num_neighbors == 2:\n", + " new_cell = board[i][j]\n", + " elif num_neighbors == 3:\n", + " new_cell = 1\n", + " \n", + " new_board.append(new_cell)\n", + " final_board = ag.stack(new_board)\n", + " final_board = tf.reshape(final_board, board.shape)\n", + " return final_board\n", + " \n", + "\n", + "def gol(initial_board):\n", + " board = initial_board\n", + " boards = []\n", + " ag.set_element_type(boards, tf.int32)\n", + " # We are being explicit about tensor constants to ensure the loop\n", + " # is not unrolled in the graph. This may change in the future.\n", + " for i in range(tf.constant(NUM_STEPS)):\n", + " board = gol_episode(board)\n", + " boards.append(board)\n", + " boards = ag.stack(boards)\n", + " render(boards)\n", + " return tf.no_op()\n", + " \n", + "\n", + "with tf.Graph().as_default():\n", + " # Gosper glider gun\n", + " # Adapted from http://www.cplusplus.com/forum/lounge/75168/\n", + " _ = 0\n", + " initial_board = tf.constant((\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,1,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_,_,_,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,1,_,_,_,1,_,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_ ),\n", + " ( _,1,1,_,_,_,_,_,_,_,_,1,_,_,_,_,_,1,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,1,1,_,_,_,_,_,_,_,_,1,_,_,_,1,_,1,1,_,_,_,_,1,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,1,_,_,_,_,_,1,_,_,_,_,_,_,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,1,_,_,_,1,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ))\n", + " initial_board = tf.pad(initial_board, ((0, 20), (0, 10)))\n", + " \n", + " tf_gol = ag.to_graph(gol)\n", + " game_ops = tf_gol(initial_board)\n", + " with tf.Session() as sess:\n", + " sess.run(game_ops)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "7NgrSPCZxs3h" + }, + "source": [ + "#### Generated code" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 2323 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 753, + "status": "ok", + "timestamp": 1532101593840, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "hIGYeX0Cxs3i", + "outputId": "e0b62eb1-3e12-4e53-dc54-8a3fa56d823d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "from __future__ import print_function\n", + "import tensorflow as tf\n", + "\n", + "def tf__gol_episode(board):\n", + " try:\n", + " with tf.name_scope('gol_episode'):\n", + " directions = tf.constant(((-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1),\n", + " (1, -1), (1, 0), (1, 1)))\n", + " new_board = ag__.new_list([])\n", + "\n", + " def extra_test_2(new_board_2):\n", + " with tf.name_scope('extra_test_2'):\n", + " return True\n", + "\n", + " def loop_body_2(i, new_board_2):\n", + " with tf.name_scope('loop_body_2'):\n", + "\n", + " def extra_test_1(new_board_1):\n", + " with tf.name_scope('extra_test_1'):\n", + " return True\n", + "\n", + " def loop_body_1(j, new_board_1):\n", + " with tf.name_scope('loop_body_1'):\n", + " num_neighbors = 0\n", + "\n", + " def extra_test(num_neighbors_2):\n", + " with tf.name_scope('extra_test'):\n", + " return True\n", + "\n", + " def loop_body(d, num_neighbors_2):\n", + " with tf.name_scope('loop_body'):\n", + " ni = i + ag__.get_item(d, (0), opts=ag__.GetItemOpts(\n", + " element_dtype=None))\n", + " nj = j + ag__.get_item(d, (1), opts=ag__.GetItemOpts(\n", + " element_dtype=None))\n", + "\n", + " def if_true():\n", + " with tf.name_scope('if_true'):\n", + " num_neighbors_1, = num_neighbors_2,\n", + " num_neighbors_1 += ag__.get_item(ag__.get_item(board,\n", + " (ni), opts=ag__.GetItemOpts(element_dtype=None)),\n", + " (nj), opts=ag__.GetItemOpts(element_dtype=None))\n", + " return num_neighbors_1,\n", + "\n", + " def if_false():\n", + " with tf.name_scope('if_false'):\n", + " return num_neighbors_2,\n", + " num_neighbors_2 = ag__.utils.run_cond(tf.logical_and(tf.\n", + " greater_equal(ni, 0), tf.logical_and(tf.greater_equal\n", + " (nj, 0), tf.logical_and(tf.less(ni, ag__.utils.\n", + " dynamic_builtin(len, board)), tf.less(nj, ag__.utils.\n", + " dynamic_builtin(len, ag__.get_item(board, (i), opts=\n", + " ag__.GetItemOpts(element_dtype=None))))))), if_true,\n", + " if_false)\n", + " return num_neighbors_2,\n", + " num_neighbors = ag__.for_stmt(directions, extra_test,\n", + " loop_body, (num_neighbors,))\n", + " new_cell = 0\n", + "\n", + " def if_true_2():\n", + " with tf.name_scope('if_true_2'):\n", + " new_cell_2, = new_cell,\n", + " new_cell_2 = ag__.get_item(ag__.get_item(board, (i), opts\n", + " =ag__.GetItemOpts(element_dtype=None)), (j), opts=\n", + " ag__.GetItemOpts(element_dtype=None))\n", + " return new_cell_2,\n", + "\n", + " def if_false_2():\n", + " with tf.name_scope('if_false_2'):\n", + " new_cell_3, = new_cell,\n", + "\n", + " def if_true_1():\n", + " with tf.name_scope('if_true_1'):\n", + " new_cell_1, = new_cell_3,\n", + " new_cell_1 = 1\n", + " return new_cell_1,\n", + "\n", + " def if_false_1():\n", + " with tf.name_scope('if_false_1'):\n", + " return new_cell_3,\n", + " new_cell_3 = ag__.utils.run_cond(tf.equal(num_neighbors, \n", + " 3), if_true_1, if_false_1)\n", + " return new_cell_3,\n", + " new_cell = ag__.utils.run_cond(tf.equal(num_neighbors, 2),\n", + " if_true_2, if_false_2)\n", + " new_board_1 = ag__.list_append(new_board_1, new_cell)\n", + " return new_board_1,\n", + " new_board_2 = ag__.for_stmt(ag__.utils.dynamic_builtin(range,\n", + " ag__.utils.dynamic_builtin(len, ag__.get_item(board, (i),\n", + " opts=ag__.GetItemOpts(element_dtype=None)))), extra_test_1,\n", + " loop_body_1, (new_board_2,))\n", + " return new_board_2,\n", + " new_board = ag__.for_stmt(ag__.utils.dynamic_builtin(range, ag__.\n", + " utils.dynamic_builtin(len, board)), extra_test_2, loop_body_2, (\n", + " new_board,))\n", + " final_board = ag__.list_stack(new_board, opts=ag__.ListStackOpts(\n", + " element_dtype=tf.int32, original_call=ag.stack))\n", + " final_board = tf.reshape(final_board, board.shape)\n", + " return final_board\n", + " except:\n", + " ag__.rewrite_graph_construction_error(ag_source_map__)\n", + "\n", + "def tf__gol(initial_board):\n", + " try:\n", + " with tf.name_scope('gol'):\n", + " board = initial_board\n", + " boards = ag__.new_list([])\n", + "\n", + " def extra_test(board_1, boards_1):\n", + " with tf.name_scope('extra_test'):\n", + " return True\n", + "\n", + " def loop_body(i, board_1, boards_1):\n", + " with tf.name_scope('loop_body'):\n", + " board_1 = tf__gol_episode(board_1)\n", + " boards_1 = ag__.list_append(boards_1, board_1)\n", + " return board_1, boards_1\n", + " board, boards = ag__.for_stmt(ag__.utils.dynamic_builtin(range, tf.\n", + " constant(NUM_STEPS)), extra_test, loop_body, (board, boards))\n", + " boards = ag__.list_stack(boards, opts=ag__.ListStackOpts(\n", + " element_dtype=tf.int32, original_call=ag.stack))\n", + " with ag__.utils.control_dependency_on_returns(render(boards)):\n", + " boards_2 = ag__.utils.alias_tensors(boards)\n", + " return tf.no_op()\n", + " except:\n", + " ag__.rewrite_graph_construction_error(ag_source_map__)\n", + "\n" + ] + } + ], + "source": [ + "print(ag.to_code(gol))" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "p8zZyj-tq4K3", + "Lkq3DBGOv3fA", + "r8_0ioEuAI-a", + "7NgrSPCZxs3h" + ], + "default_view": {}, + "last_runtime": { + "build_target": "", + "kind": "local" + }, + "name": "Simple algorithms using AutoGraph", + "provenance": [ + { + "file_id": "19q8KdVF8Cb_fDd13i-WDOG_6n_QGNW5-", + "timestamp": 1528465909719 + } + ], + "version": "0.3.2", + "views": {} + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb index 0702273fac15da61a72d66d8344a5add32ad12a6..7e9cc54d4cafa64e4cd3b48f9376b1b2b4d3575e 100644 --- a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb +++ b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb @@ -1,49 +1,20 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Dev Summit 2018 - Autograph", - "version": "0.3.2", - "views": {}, - "default_view": {}, - "provenance": [ - { - "file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K", - "timestamp": 1522238054357 - }, - { - "file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ", - "timestamp": 1521743157199 - }, - { - "file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-", - "timestamp": 1520522344607 - } - ], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python2", - "display_name": "Python 2" - } - }, "cells": [ { + "cell_type": "markdown", "metadata": { - "id": "g7nGs4mzVUHP", - "colab_type": "text" + "colab_type": "text", + "id": "g7nGs4mzVUHP" }, - "cell_type": "markdown", "source": [ - "# Experimental: TF Autograph\n", + "# Experimental: TF AutoGraph\n", "**TensorFlow Dev Summit, 2018.**\n", "\n", - "This interactive notebook demonstrates **autograph**, an experimental source-code transformation library to automatically convert TF.Eager and Python code to TensorFlow graphs.\n", + "This interactive notebook demonstrates **AutoGraph**, an experimental source-code transformation library to automatically convert Python, TensorFlow and NumPy code to TensorFlow graphs.\n", "\n", "**Note: this is pre-alpha software!** The notebook works best with Python 2, for now.\n", "\n", - "> ![alt text](https://lh3.googleusercontent.com/QOvy0clmg7siaVKzwmSPAjicWWNQ0OeyaB16plDjSJMf35WD3vLjF6mz4CGrhSHw60HnlZPJjkyDCBzw5XOI0oBGSewyYw=s688)\n", + "\u003e ![alt text](https://lh3.googleusercontent.com/QOvy0clmg7siaVKzwmSPAjicWWNQ0OeyaB16plDjSJMf35WD3vLjF6mz4CGrhSHw60HnlZPJjkyDCBzw5XOI0oBGSewyYw=s688)\n", "\n", "### Table of Contents\n", "1. _Write Eager code that is fast and scalable._\n", @@ -53,37 +24,39 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "uFcgBENZqkB2", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "uFcgBENZqkB2" }, - "cell_type": "code", + "outputs": [], "source": [ "# Install TensorFlow; note that Colab notebooks run remotely, on virtual\n", "# instances provided by Google.\n", "!pip install -U -q tf-nightly" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "Pa2qpEmoVOGe", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "Pa2qpEmoVOGe" }, - "cell_type": "code", + "outputs": [], "source": [ "import os\n", "import time\n", @@ -96,170 +69,172 @@ "import six\n", "\n", "from google.colab import widgets" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "ZVKfj5ttVkqz", - "colab_type": "text" + "colab_type": "text", + "id": "ZVKfj5ttVkqz" }, - "cell_type": "markdown", "source": [ "# 1. Write Eager code that is fast and scalable\n", "\n", "TF.Eager gives you more flexibility while coding, but at the cost of losing the benefits of TensorFlow graphs. For example, Eager does not currently support distributed training, exporting models, and a variety of memory and computation optimizations.\n", "\n", - "Autograph gives you the best of both worlds: write your code in an Eager style, and we will automatically transform it into the equivalent TF graph code. The graph code can be executed eagerly (as a single op), included as part of a larger graph, or exported." + "AutoGraph gives you the best of both worlds: you can write your code in an Eager style, and we will automatically transform it into the equivalent TF graph code. The graph code can be executed eagerly (as a single op), included as part of a larger graph, or exported." ] }, { + "cell_type": "markdown", "metadata": { - "id": "snaZRFdWd9ym", - "colab_type": "text" + "colab_type": "text", + "id": "snaZRFdWd9ym" }, - "cell_type": "markdown", "source": [ - "For example, autograph can convert a function like this:" + "For example, AutoGraph can convert a function like this:" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "9__n8cSIeDnD", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "9__n8cSIeDnD" }, - "cell_type": "code", + "outputs": [], "source": [ "def g(x):\n", - " if x > 0:\n", + " if x \u003e 0:\n", " x = x * x\n", " else:\n", " x = 0\n", " return x" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "gq0eQcuReHET", - "colab_type": "text" + "colab_type": "text", + "id": "gq0eQcuReHET" }, - "cell_type": "markdown", "source": [ "... into a TF graph-building function:" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "sELSn599ePUF", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "output_extras": [ - {} - ], - "base_uri": "https://localhost:8080/", - "height": 413 + "height": 431 }, - "outputId": "bb0c7216-1ca3-4da1-d1fb-589902cdcd1a", + "colab_type": "code", "executionInfo": { + "elapsed": 69, "status": "ok", - "timestamp": 1522345737505, - "user_tz": 240, - "elapsed": 243, + "timestamp": 1531750911837, "user": { - "displayName": "Dan Moldovan", - "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", - "userId": "112023154726779574577" - } - } + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "sELSn599ePUF", + "outputId": "2858bde5-ae05-4c32-be01-7770ac914f02" }, - "cell_type": "code", - "source": [ - "print(autograph.to_code(g))" - ], - "execution_count": 0, "outputs": [ { + "name": "stdout", "output_type": "stream", "text": [ "from __future__ import print_function\n", "import tensorflow as tf\n", - "from tensorflow.contrib.autograph.impl import api as autograph_api\n", - "from tensorflow.contrib.autograph import utils as autograph_utils\n", "\n", "def tf__g(x):\n", - " with tf.name_scope('g'):\n", + " try:\n", + " with tf.name_scope('g'):\n", "\n", - " def if_true():\n", - " with tf.name_scope('if_true'):\n", - " x_1, = x,\n", - " x_1 = x_1 * x_1\n", - " return x_1,\n", + " def if_true():\n", + " with tf.name_scope('if_true'):\n", + " x_1, = x,\n", + " x_1 = x_1 * x_1\n", + " return x_1,\n", "\n", - " def if_false():\n", - " with tf.name_scope('if_false'):\n", - " x_1, = x,\n", - " x_1 = 0\n", - " return x_1,\n", - " x = autograph_utils.run_cond(tf.greater(x, 0), if_true, if_false)\n", - " return x\n", + " def if_false():\n", + " with tf.name_scope('if_false'):\n", + " x_2, = x,\n", + " x_2 = 0\n", + " return x_2,\n", + " x = ag__.utils.run_cond(tf.greater(x, 0), if_true, if_false)\n", + " return x\n", + " except:\n", + " ag__.rewrite_graph_construction_error(ag_source_map__)\n", "\n" - ], - "name": "stdout" + ] } + ], + "source": [ + "print(autograph.to_code(g))" ] }, { + "cell_type": "markdown", "metadata": { - "id": "j74n-8hEe6dk", - "colab_type": "text" + "colab_type": "text", + "id": "j74n-8hEe6dk" }, - "cell_type": "markdown", "source": [ "You can then use the converted function as you would any regular TF op -- you can pass `Tensor` arguments and it will return `Tensor`s:" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "AkVaY0-dfEbH", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "output_extras": [ - {} - ], - "base_uri": "https://localhost:8080/", "height": 53 }, - "outputId": "4ffe3757-c44d-424c-c2a8-7ddc973bfcce", + "colab_type": "code", "executionInfo": { + "elapsed": 83, "status": "ok", - "timestamp": 1522345737841, - "user_tz": 240, - "elapsed": 257, + "timestamp": 1531750911965, "user": { - "displayName": "Dan Moldovan", - "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", - "userId": "112023154726779574577" - } - } + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "AkVaY0-dfEbH", + "outputId": "f04541ad-b1d3-4663-bf27-4d902648283d" }, - "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "g(9) = 81\n", + "tf_g(9) = 81\n" + ] + } + ], "source": [ "tf_g = autograph.to_graph(g)\n", "\n", @@ -272,77 +247,72 @@ "\n", " print('g(9) = %s' % g(9))\n", " print('tf_g(9) = %s' % tf_g_result)" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "g(9) = 81\n", - "tf_g(9) = 81\n" - ], - "name": "stdout" - } ] }, { + "cell_type": "markdown", "metadata": { - "id": "trrHQBM1VnD0", - "colab_type": "text" + "colab_type": "text", + "id": "trrHQBM1VnD0" }, - "cell_type": "markdown", "source": [ "# 2. Case study: complex control flow\n", "\n", - "Autograph can convert a large chunk of the Python language into graph-equivalent code, and we're adding new supported language features all the time. In this section, we'll give you a taste of some of the functionality in autograph.\n", - "Autograph will automatically convert most Python control flow statements into their correct graph equivalent.\n", + "Autograph can convert a large subset of the Python language into graph-equivalent code, and we're adding new supported language features all the time. In this section, we'll give you a taste of some of the functionality in AutoGraph.\n", + "AutoGraph will automatically convert most Python control flow statements into their graph equivalent.\n", " " ] }, { + "cell_type": "markdown", "metadata": { - "id": "u0YG3DPgZxoW", - "colab_type": "text" + "colab_type": "text", + "id": "u0YG3DPgZxoW" }, - "cell_type": "markdown", "source": [ "We support common statements like `while`, `for`, `if`, `break`, `return` and more. You can even nest them as much as you like. Imagine trying to write the graph version of this code by hand:" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "xJYDzOcrZ8pI", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "output_extras": [ - {} - ], - "base_uri": "https://localhost:8080/", "height": 35 }, - "outputId": "6c244ee4-b141-4ad6-eefa-cfffa71f33c6", + "colab_type": "code", "executionInfo": { + "elapsed": 169, "status": "ok", - "timestamp": 1522345738402, - "user_tz": 240, - "elapsed": 483, + "timestamp": 1531750912183, "user": { - "displayName": "Dan Moldovan", - "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", - "userId": "112023154726779574577" - } - } + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "xJYDzOcrZ8pI", + "outputId": "f392b475-bf87-4d90-919d-44f895ee9fc7" }, - "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sum of even numbers: 42\n" + ] + } + ], "source": [ "def sum_even(numbers):\n", " s = 0\n", " for n in numbers:\n", - " if n % 2 > 0:\n", + " if n % 2 \u003e 0:\n", " continue\n", " s += n\n", " return s\n", @@ -358,77 +328,74 @@ " \n", "# Uncomment the line below to print the generated graph code\n", "# print(autograph.to_code(sum_even))" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Sum of even numbers: 42\n" - ], - "name": "stdout" - } ] }, { + "cell_type": "markdown", "metadata": { - "id": "_YXo4KOcbKrn", - "colab_type": "text" + "colab_type": "text", + "id": "_YXo4KOcbKrn" }, - "cell_type": "markdown", "source": [ "Try replacing the `continue` in the above code with `break` -- Autograph supports that as well!" ] }, { + "cell_type": "markdown", "metadata": { - "id": "xHmC0rBIavW_", - "colab_type": "text" + "colab_type": "text", + "id": "xHmC0rBIavW_" }, - "cell_type": "markdown", "source": [ "The Python code above is much more readable than the matching graph code. Autograph takes care of tediously converting every piece of Python code into the matching TensorFlow graph version for you, so that you can quickly write maintainable code, but still benefit from the optimizations and deployment benefits of graphs." ] }, { + "cell_type": "markdown", "metadata": { - "id": "UEHWGpBXbS7g", - "colab_type": "text" + "colab_type": "text", + "id": "UEHWGpBXbS7g" }, - "cell_type": "markdown", "source": [ "Let's try some other useful Python constructs, like `print` and `assert`. We automatically convert Python `assert` statements into the equivalent `tf.Assert` code. " ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "qUU57xlEbauI", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "output_extras": [ - {} - ], - "base_uri": "https://localhost:8080/", "height": 53 }, - "outputId": "add3db4a-2077-4dd5-f7a7-a5b5a4529c26", + "colab_type": "code", "executionInfo": { + "elapsed": 56, "status": "ok", - "timestamp": 1522345738697, - "user_tz": 240, - "elapsed": 253, + "timestamp": 1531750912292, "user": { - "displayName": "Dan Moldovan", - "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", - "userId": "112023154726779574577" - } - } + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "qUU57xlEbauI", + "outputId": "c9cd536a-4a95-4eb0-98c0-aafce5d79580" }, - "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Got error message: assertion failed: [Do not pass zero!]\n", + "\t [[{{node f/Assert/Assert}} = Assert[T=[DT_STRING], summarize=3, _device=\"/job:localhost/replica:0/task:0/device:CPU:0\"](f/NotEqual, f/Assert/Assert/data_0)]]\n" + ] + } + ], "source": [ "def f(x):\n", " assert x != 0, 'Do not pass zero!'\n", @@ -444,61 +411,35 @@ " \n", "# Uncomment the line below to print the generated graph code\n", "# print(autograph.to_code(f))" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Got error message: assertion failed: [Do not pass zero!]\n", - "\t [[Node: f/Assert/Assert = Assert[T=[DT_STRING], summarize=3, _device=\"/job:localhost/replica:0/task:0/device:CPU:0\"](f/NotEqual, f/Assert/Assert/data_0)]]\n" - ], - "name": "stdout" - } ] }, { + "cell_type": "markdown", "metadata": { - "id": "w5hBZaVJbck4", - "colab_type": "text" + "colab_type": "text", + "id": "w5hBZaVJbck4" }, - "cell_type": "markdown", "source": [ "You can also use `print` functions in-graph:" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "6NdzRKLEboRv", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - }, - "output_extras": [ - {} - ], - "base_uri": "https://localhost:8080/", - "height": 35 - }, - "outputId": "fb82dfc3-790f-4127-87f6-361805be9e9b", - "executionInfo": { - "status": "ok", - "timestamp": 1522345739013, - "user_tz": 240, - "elapsed": 247, - "user": { - "displayName": "Dan Moldovan", - "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", - "userId": "112023154726779574577" } - } + }, + "colab_type": "code", + "id": "6NdzRKLEboRv" }, - "cell_type": "code", + "outputs": [], "source": [ "def print_sign(n):\n", - " if n >= 0:\n", + " if n \u003e= 0:\n", " print(n, 'is positive!')\n", " else:\n", " print(n, 'is negative!')\n", @@ -512,62 +453,58 @@ " \n", "# Uncomment the line below to print the generated graph code\n", "# print(autograph.to_code(print_sign))" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "1 is positive!\n" - ], - "name": "stdout" - } ] }, { + "cell_type": "markdown", "metadata": { - "id": "9u_Z3i3AivLA", - "colab_type": "text" + "colab_type": "text", + "id": "9u_Z3i3AivLA" }, - "cell_type": "markdown", "source": [ - "We can convert lists to TensorArray, so appending to lists also works, with a few modifications:" + "Appending to lists also works, with a few modifications:" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "MjhCQJVuiTNR", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "output_extras": [ - {} - ], - "base_uri": "https://localhost:8080/", "height": 35 }, - "outputId": "dc320b87-595b-4392-d29c-994486fd8a0a", + "colab_type": "code", "executionInfo": { + "elapsed": 148, "status": "ok", - "timestamp": 1522345744470, - "user_tz": 240, - "elapsed": 5391, + "timestamp": 1531750912595, "user": { - "displayName": "Dan Moldovan", - "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", - "userId": "112023154726779574577" - } - } + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "MjhCQJVuiTNR", + "outputId": "96bf9131-c7c1-4359-ee82-9c38575e7ab4" }, - "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0 1 2 3 4]\n" + ] + } + ], "source": [ "def f(n):\n", " numbers = []\n", " # We ask you to tell us about the element dtype.\n", - " autograph.utils.set_element_type(numbers, tf.int32)\n", + " autograph.set_element_type(numbers, tf.int32)\n", " for i in range(n):\n", " numbers.append(i)\n", " return autograph.stack(numbers) # Stack the list so that it can be used as a Tensor\n", @@ -580,65 +517,62 @@ " \n", "# Uncomment the line below to print the generated graph code\n", "# print(autograph.to_code(f))" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "[0 1 2 3 4]\n" - ], - "name": "stdout" - } ] }, { + "cell_type": "markdown", "metadata": { - "id": "UdG8ZFrkTAF2", - "colab_type": "text" + "colab_type": "text", + "id": "UdG8ZFrkTAF2" }, - "cell_type": "markdown", "source": [ "And all of these functionalities, and more, can be composed into more complicated code:\n" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "DVs6wt8NKaGQ", - "colab_type": "code", + "cellView": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "output_extras": [ - {} - ], - "base_uri": "https://localhost:8080/", "height": 53 }, - "cellView": "code", - "outputId": "0a4b8d08-8f65-4bbc-85ba-dc4c60563519", + "colab_type": "code", "executionInfo": { + "elapsed": 555, "status": "ok", - "timestamp": 1522345745186, - "user_tz": 240, - "elapsed": 658, + "timestamp": 1531750913176, "user": { - "displayName": "Dan Moldovan", - "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", - "userId": "112023154726779574577" - } - } + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "DVs6wt8NKaGQ", + "outputId": "8729229c-4f08-4640-d3a1-0d3f9c697a87" }, - "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The prime numbers less than 50 are:\n", + "[ 2 3 5 7 11 13 17 19 23 29 31 37 41 43 47]\n" + ] + } + ], "source": [ "def print_primes(n):\n", " \"\"\"Returns all the prime numbers less than n.\"\"\"\n", - " assert n > 0\n", + " assert n \u003e 0\n", " \n", " primes = []\n", - " autograph.utils.set_element_type(primes, tf.int32)\n", + " autograph.set_element_type(primes, tf.int32)\n", " for i in range(2, n):\n", " is_prime = True\n", " for k in range(2, i):\n", @@ -663,45 +597,36 @@ " \n", "# Uncomment the line below to print the generated graph code\n", "# print(autograph.to_code(print_primes))" - ], - "execution_count": 0, - "outputs": [ - { - "output_type": "stream", - "text": [ - "The prime numbers less than 50 are:\n", - "[ 2 3 5 7 11 13 17 19 23 29 31 37 41 43 47]\n" - ], - "name": "stdout" - } ] }, { + "cell_type": "markdown", "metadata": { - "id": "JQ8kQT99VqDk", - "colab_type": "text" + "colab_type": "text", + "id": "JQ8kQT99VqDk" }, - "cell_type": "markdown", "source": [ "# 3. Case study: training MNIST with Keras\n", "\n", - "As we've seen, writing control flow in Autograph is easy. So running a training loop in graph should be easy as well!\n", + "As we've seen, writing control flow in AutoGraph is easy. So running a training loop in graph should be easy as well!\n", "\n", "Here, we show an example of such a training loop for a simple Keras model that trains on MNIST." ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "0CrtGWgwuLJr", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "0CrtGWgwuLJr" }, - "cell_type": "code", + "outputs": [], "source": [ "import gzip\n", "import shutil\n", @@ -754,66 +679,67 @@ "\n", "def mnist_test(directory):\n", " return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "2zu1U9Nqir6L", - "colab_type": "text" + "colab_type": "text", + "id": "2zu1U9Nqir6L" }, - "cell_type": "markdown", "source": [ "First, we'll define a small three-layer neural network using the Keras API" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "x_MU13boiok2", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "x_MU13boiok2" }, - "cell_type": "code", + "outputs": [], "source": [ "def mlp_model(input_shape):\n", - " model = tf.keras.Sequential([\n", + " model = tf.keras.Sequential((\n", " tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),\n", " tf.keras.layers.Dense(100, activation='relu'),\n", - " tf.keras.layers.Dense(10, activation='softmax')])\n", + " tf.keras.layers.Dense(10, activation='softmax'),\n", + " ))\n", " model.build()\n", " return model" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "Wuqg3H8mi0Xj", - "colab_type": "text" + "colab_type": "text", + "id": "Wuqg3H8mi0Xj" }, - "cell_type": "markdown", "source": [ "Let's connect the model definition (here abbreviated as `m`) to a loss function, so that we can train our model." ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "W51sfbONiz_5", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "W51sfbONiz_5" }, - "cell_type": "code", + "outputs": [], "source": [ "def predict(m, x, y):\n", " y_p = m(x)\n", @@ -822,63 +748,63 @@ " accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n", " accuracy = tf.reduce_mean(accuracies)\n", " return l, accuracy" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "035tNWQki9tr", - "colab_type": "text" + "colab_type": "text", + "id": "035tNWQki9tr" }, - "cell_type": "markdown", "source": [ "Now the final piece of the problem specification (before loading data, and clicking everything together) is backpropagating the loss through the model, and optimizing the weights using the gradient." ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "CsAD0ajbi9iZ", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "CsAD0ajbi9iZ" }, - "cell_type": "code", + "outputs": [], "source": [ "def fit(m, x, y, opt):\n", " l, accuracy = predict(m, x, y)\n", " opt.minimize(l)\n", " return l, accuracy" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "PcVRIacKjSwb", - "colab_type": "text" + "colab_type": "text", + "id": "PcVRIacKjSwb" }, - "cell_type": "markdown", "source": [ "These are some utility functions to download data and generate batches for training" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "RVw57HdTjPzi", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "RVw57HdTjPzi" }, - "cell_type": "code", + "outputs": [], "source": [ "def setup_mnist_data(is_training, hp, batch_size):\n", " if is_training:\n", @@ -896,16 +822,14 @@ " x = tf.to_float(tf.reshape(image, (-1, 28 * 28)))\n", " y = tf.one_hot(tf.squeeze(label), 10)\n", " return x, y" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "2zEJH5XNjgFz", - "colab_type": "text" + "colab_type": "text", + "id": "2zEJH5XNjgFz" }, - "cell_type": "markdown", "source": [ "This function specifies the main training loop. We instantiate the model (using the code above), instantiate an optimizer (here we'll use SGD with momentum, nothing too fancy), and we'll instantiate some lists to keep track of training and test loss and accuracy over time.\n", "\n", @@ -913,33 +837,35 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "UUI0566FjZPx", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "UUI0566FjZPx" }, - "cell_type": "code", + "outputs": [], "source": [ "def train(train_ds, test_ds, hp):\n", " m = mlp_model((28 * 28,))\n", " opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n", + "\n", " train_losses = []\n", - " train_losses = autograph.utils.set_element_type(train_losses, tf.float32)\n", + " autograph.set_element_type(train_losses, tf.float32)\n", " test_losses = []\n", - " test_losses = autograph.utils.set_element_type(test_losses, tf.float32)\n", + " autograph.set_element_type(test_losses, tf.float32)\n", " train_accuracies = []\n", - " train_accuracies = autograph.utils.set_element_type(train_accuracies,\n", - " tf.float32)\n", + " autograph.set_element_type(train_accuracies, tf.float32)\n", " test_accuracies = []\n", - " test_accuracies = autograph.utils.set_element_type(test_accuracies,\n", - " tf.float32)\n", - " i = tf.constant(0)\n", - " while i < hp.max_steps:\n", + " autograph.set_element_type(test_accuracies, tf.float32)\n", + "\n", + " i = 0\n", + " while i \u003c hp.max_steps:\n", " train_x, train_y = get_next_batch(train_ds)\n", " test_x, test_y = get_next_batch(test_ds)\n", " step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)\n", @@ -956,173 +882,144 @@ " return (autograph.stack(train_losses), autograph.stack(test_losses),\n", " autograph.stack(train_accuracies),\n", " autograph.stack(test_accuracies))" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "cYiUQ1ppkHzk", - "colab_type": "text" + "colab_type": "text", + "id": "cYiUQ1ppkHzk" }, - "cell_type": "markdown", "source": [ "Everything is ready to go, let's train the model and plot its performance!" ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "K1m8TwOKjdNd", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "output_extras": [ - {}, - {}, - {} - ], - "base_uri": "https://localhost:8080/", - "height": 988 + "height": 585 }, - "outputId": "f9d3eef3-5bea-45c1-ddf9-4edee73e4436", + "colab_type": "code", "executionInfo": { + "elapsed": 17094, "status": "ok", - "timestamp": 1522345800262, - "user_tz": 240, - "elapsed": 52391, + "timestamp": 1531750930585, "user": { - "displayName": "Dan Moldovan", - "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", - "userId": "112023154726779574577" - } - } + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "K1m8TwOKjdNd", + "outputId": "9f63da19-c3bf-498b-cf00-29090bf3b4f0" }, - "cell_type": "code", - "source": [ - "with tf.Graph().as_default():\n", - " hp = tf.contrib.training.HParams(\n", - " learning_rate=0.05,\n", - " max_steps=500,\n", - " )\n", - " train_ds = setup_mnist_data(True, hp, 50)\n", - " test_ds = setup_mnist_data(False, hp, 1000)\n", - " tf_train = autograph.to_graph(train)\n", - " (train_losses, test_losses, train_accuracies,\n", - " test_accuracies) = tf_train(train_ds, test_ds, hp)\n", - "\n", - " with tf.Session() as sess:\n", - " sess.run(tf.global_variables_initializer())\n", - " (train_losses, test_losses, train_accuracies,\n", - " test_accuracies) = sess.run([train_losses, test_losses, train_accuracies,\n", - " test_accuracies])\n", - " plt.title('MNIST train/test losses')\n", - " plt.plot(train_losses, label='train loss')\n", - " plt.plot(test_losses, label='test loss')\n", - " plt.legend()\n", - " plt.xlabel('Training step')\n", - " plt.ylabel('Loss')\n", - " plt.show()\n", - " plt.title('MNIST train/test accuracies')\n", - " plt.plot(train_accuracies, label='train accuracy')\n", - " plt.plot(test_accuracies, label='test accuracy')\n", - " plt.legend(loc='lower right')\n", - " plt.xlabel('Training step')\n", - " plt.ylabel('Accuracy')\n", - " plt.show()" - ], - "execution_count": 0, "outputs": [ { - "output_type": "stream", - "text": [ - "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/autograph_mnist_data/train-images-idx3-ubyte.gz\n", - "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/autograph_mnist_data/train-labels-idx1-ubyte.gz\n", - "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/autograph_mnist_data/t10k-images-idx3-ubyte.gz\n", - "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/autograph_mnist_data/t10k-labels-idx1-ubyte.gz\n", - "Step 0 train loss: 2.244329 test loss: 2.2499208 train accuracy: 0.12 test accuracy: 0.161\n", - "Step 50 train loss: 0.64771986 test loss: 0.56013924 train accuracy: 0.82 test accuracy: 0.836\n", - "Step 100 train loss: 0.49011207 test loss: 0.42143965 train accuracy: 0.84 test accuracy: 0.879\n", - "Step 150 train loss: 0.3768609 test loss: 0.39319593 train accuracy: 0.88 test accuracy: 0.883\n", - "Step 200 train loss: 0.36007702 test loss: 0.37089333 train accuracy: 0.9 test accuracy: 0.881\n", - "Step 250 train loss: 0.182115 test loss: 0.28543878 train accuracy: 0.94 test accuracy: 0.915\n", - "Step 300 train loss: 0.2119576 test loss: 0.22305593 train accuracy: 0.92 test accuracy: 0.93\n", - "Step 350 train loss: 0.12932214 test loss: 0.29057172 train accuracy: 0.96 test accuracy: 0.906\n", - "Step 400 train loss: 0.22937602 test loss: 0.2200287 train accuracy: 0.92 test accuracy: 0.925\n", - "Step 450 train loss: 0.23444137 test loss: 0.19857481 train accuracy: 0.94 test accuracy: 0.94\n" - ], - "name": "stdout" - }, - { - "output_type": "display_data", "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAe8AAAFnCAYAAACPasF4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzs3XmAFNW9Pvynlt5mYdhmQMHggnGN\nS9zCD0ElKug1edUY9ZoQTYze3GuiRk1uYjRqRHNj4n5NrhKjiUYlbihGQFRUFDSoKIvgICAO6+xL\n711V5/2jlq7qZaZnpnumZ3g+/zjTXV1dXSP91PecU+dIQggBIiIiGjLkwT4AIiIi6h2GNxER0RDD\n8CYiIhpiGN5ERERDDMObiIhoiGF4ExERDTEMb6JeOOigg3DllVdmPf6rX/0KBx10kGe766+/3rPN\ne++9h9mzZwMAtm3bhkMPPdR57osvvsCPfvQjzJw5EzNnzsTZZ5+NV199FQBw0003YdasWZg1axYO\nO+wwnHLKKc7v4XDY8x7JZBLz58/v9edavXo1Lr300oK2XbBgAebMmdPn97J19/rZs2fjhRde6PO+\niYY7hjdRL3366aee0Ewmk1izZk3WditXrsQnn3xS0D6vu+46TJs2DYsXL8bixYtxyy234LrrrsPO\nnTtxyy23YNGiRVi0aBHGjRuH3//+987vVVVVnv188sknfQrUI444Ag8//HBB2y5fvhxTpkzp83vZ\n+vt6oj0Zw5uol0444QQsWbLE+f3tt9/GV77ylaztrrnmGtx+++0F7bO+vh5HHnmk8/uRRx6JxYsX\nY/z48QUfV3NzM3784x/jo48+wkUXXQTAbAF48MEHMXPmTOi6jlWrVuHcc8/FrFmzcOaZZ2L58uUA\nzFaB0047DQBw//334ze/+Q2uuOIKfP3rX8d5552HxsZG533ee+89HHzwwVnv9cEHH+Bb3/oWTjvt\nNJx//vloaGgAAOzevRsXX3wxzjzzTJx66qm4++67cx5rPu+99x7OOecczJo1C9/+9redC6Vc++3u\ncSEE/vd//xczZ87EKaecgjlz5kDXdQDAwoULcdZZZ+GMM87AN77xDbz33nsFn3eiwcDwJuqlM844\nAy+99JLz+z//+U/MmjUr53ZCCCxatKjHfU6fPh1XXnkl/va3v2HTpk0AgHHjxkGSpIKPa+zYsbjm\nmmtw1FFH4YknnnAeF0Jg8eLFUBQFv/71r3HppZdi0aJFuPzyy3HTTTfl3NeiRYtw/fXX49VXX8WY\nMWPw7LPPAgA2bdqE2tpaTJgwwfNe4XAY//mf/4lrrrkGS5Yswfe+9z1cddVVAIBHH30Uxx13HF5+\n+WUsWLAADQ0NMAwj57FmikQiuOqqq3DDDTdg0aJF+OEPf4jrrrsOhmHk3G9jY2Pex1944QUsWrQI\nzzzzDJYsWYKGhgY8+eSTAIBbbrkFDz74IBYuXIibbroJr7/+esHnnWgwMLyJeun444/Hxo0b0dLS\nglgshlWrVmHKlCk5t73++uvxhz/8AYlEott9/v73v8d3vvMdLFiwAGeddRZmzJjhBEt/nXzyyc7P\n8+fPxxlnnAEAOOaYY5zqONOxxx6LCRMmQJIkHHLIIdi5cycAYMWKFTk/6wcffIBx48Zh6tSpAICz\nzjoLX3zxBXbs2IExY8bg7bffxvvvvw+/34+77roLdXV1BR376tWrMX78eBxzzDEAgJkzZ6KtrQ3b\nt2/Pu998jy9duhTf+ta3UF1dDVVV8e1vfxuvvPIKAGDMmDF46qmnsH37dhx77LH45S9/WdjJJRok\n6mAfANFQoygKTj/9dCxcuBCjR4/GiSeeCFXN/U/psMMOw3HHHYdHHnkERx99dN59BgIBXHrppbj0\n0kvR2dmJRYsW4fbbb8fEiRMxbdq0fh3vyJEjnZ8XLFiAv/3tb4hEIjAMA/mWNqiurnZ+VhTFaV5+\n5513cMkll2Rt39nZiYaGBk8LhN/vR2trKy655BIYhoFbbrkFjY2N+M53voOf/OQnBR17a2srRowY\nkXVsLS0tefeb7/Guri48/PDDmDdvHgBA13WMHj0aAPCnP/0Jf/rTn3Duuedir732wvXXX4/jjz++\noGMkGgwMb6I+OPPMM3H33Xdj1KhRPfbZ/vSnP8W5556LiRMn5ny+tbUV69evd6rWESNG4Pzzz8ey\nZctQX1/f7/C27d69GzfccAOefvppHHLIIfj8888xc+bMgl+vaRrWrFmT8yKkrq4O+++/P5577rmc\nr7388stx+eWXY8uWLbjsssucSronY8aMQXt7u/O7EAIdHR0YM2YMVFXNud+pU6fmfLyurg4zZszA\nd7/73az3+dKXvoTf/va3MAwD8+fPx7XXXotly5YVeGaIBh6bzYn64Oijj0ZjYyM2btzYY4VWV1eH\n73znO7j//vtzPh+Px3HllVd6wmLr1q34+OOPceyxx/bquFRVRTgczllRt7a2oqKiAvvvvz80TXMq\n0EgkUtC+V69ejYMOOgh+vz/rvY488kg0NTXh448/BgA0NDTgZz/7GYQQ+PWvf4133nkHgBmSY8eO\nhSRJ3R6r7YgjjkBzczNWrVoFwBxfMH78eEycODHvfvM9/vWvfx0vvPACYrEYAOCpp57C888/j9bW\nVnz/+99HOByGLMs48sgjezXWgGgwsPIm6gNJknDaaachFotBlnu+Bv7BD36Ap59+Oudze++9N/70\npz/hvvvuw5w5cyCEQFVVFX75y196RqAX4phjjsEf/vAHTJs2DW+++abnuYMPPhjTp0/HzJkzMWbM\nGPziF7/Ahx9+iNmzZ+O///u/e9y3fYtYvve67777cOuttyISicDn8+Gqq66CJEm48MIL8etf/xq3\n3norhBCYMWMGpkyZgh07dnheryhK1ntWVFTgnnvuwa233opoNIrRo0fjrrvu6na/I0eOzPk4AGzc\nuBHnnHMOADPYb7vtNowePRrTpk3Dt771LSiKAp/Ph9tuu61X551ooElcz5uIiGhoYbM5ERHREMPw\nJiIiGmIY3kREREMMw5uIiGiIYXgTERENMUPmVrGmpq6i7m/UqAq0tUWLus89Ec9j//Ec9h/PYXHw\nPPZfsc9hbW11zsf32MpbVbPvKaXe43nsP57D/uM5LA6ex/4bqHO4x4Y3ERHRUMXwJiIiGmIY3kRE\nREMMw5uIiGiIYXgTERENMQxvIiKiIYbhTURENMQwvImIaNh6443XCt723nvvxI4d23vc7sMP38cN\nN/y8P4fVbwxvIiIalnbu3IFXX11c8PZXXXUt9t57QgmPqHiGzPSoREREvXHXXb/D+vXr8Mgjc2EY\nBnbs2I6dO3fgnnv+iN/+9jdoampELBbDD35wOaZOnYYf//hyXHPNz7F06WuIRML44out2L59G668\n8lpMmTI153u89toSzJv3dyiKgoMOOgS33XYL6us34M47fwefzwe/349bbvktdu7cnvVYdXXuqU8L\nsceGd0c4gfc3NOLYg+sG+1CIiIa9f7z+GVZuaCzqPo87uA7nz5ic9/l///fZeO65f+D7378MDz/8\nIDQthT/+8c9oa2vF8cd/DWeccRa2b9+GG2/8BaZOneZ5bWPjbvzhD/fh3XeX44UXns0Z3tFoFA89\n9AAeeeQJVFRU4Oc//yneffddvPzyyzjnnPMwa9a/4YMPVqK1tQUvv7wg6zGGdx9ceecbaO2M46ZL\njsOk8X0/gURENDQccshhAIDq6hFYv34dXnzxOUiSjM7OjqxtjzjiKABAXV0dwuFwzv01NHyBiRO/\nhIqKCgDA0Ucfg/Xr1+PEE0/CH/7wP2ho+AJf//ppmDRp35yP9cceGd5b23YiPOFNSMnD0dwRZ3gT\nEZXY+TMmd1slDwSfzwcAWLJkETo7O/HAA39GZ2cnfvjD2VnbKkp6gREhRM79SZL3OU1LQZJCOPbY\n4/HnP/8Ny5cvw5w5N+PHP74652Nf/eqxff4se2R4f7ztCyjVbTBG70RLZ3ywD4eIiEpAlmXoup71\neHt7O/baa2/Isow333wdqVSqT/vfZ59J2LbtC0SjEVRUVGLVqg9x1VU/xrPPzsOUKSfi9NPPgBAC\n9fUbsGXLpqzHGN69dPykA7G4CZArO9DSwfAmIhqOJk3aD59+ugH33XcnKiurnMdPPnkGfvGLa/DJ\nJ2vxb//2TdTV1eGRR+b2ev+hUAhXXHEVrr32J5AkGUcccRSOPfZY7NzZghtv/AWqqqrg8/lw/fU3\nob7+06zH+kMS+doDykxTU1dR93fjit+ipTOCQyLn4yfnHlHUfe9Jamuri/632dPwHPYfz2Fx8Dz2\nX7HPYW1t7m7dPfY+7y+P2Q+SL4mmcOtgHwoREVGv7LHhPbFmPACgLdk2yEdCRETUO3tseI8JjQIA\nxBFGPKkN8tEQEREVbs8N74rRAADJH+egNSIiGlL22PAeW2FW3pI/xtvFiIhoSNljw3uME96svImI\naGjZY8M75AvCLwcg+eNoZuVNRDQs9WZJUNtHH32ItjbvnUjlsAyo2x4b3gAwMlDDypuIaJjq7ZKg\ntn/+88Ws8C43e+QMa7a6ijFojDWiqSt7UnoiIhra3EuCXnDBRbj99lvQ1dUFXddx9dU/w+TJB+Lx\nxx/Fm28uhSzLmDp1Gg455FAsW/YGtmzZjDlz7sD48eOz9pu5DOjVV1/nLANaWRkCIJdkGVC3PTy8\nxwItQKfWPtiHQkQ0rD332UtY1bimqPs8uu4rOHfyWXmfdy8J+uijf8YJJ/w/fOMbZ2PLls24994/\n4J57/oinnnoc8+cvgqIomD//WRx33NcwefKXcc01P88Z3LmWAf3ww/fx1ltLcc4552H27AuxaNHr\nJVkG1G2PDu/a0FgAQAysvImIhrM1a1ajvb0Nixe/DABIJMzu0pNP/jquvvq/cNpps3D66bN63E+u\nZUDr6zc4S362tOzClCknlWQZULc9OrzrKszwTildMISALEmDfERERMPTuZPP6rZKLjWfT8VPf/oz\nHH64dy2L6677JbZu/Ryvv74EP/nJf+Chh/7a7X5yLQMaCAScJT/XrFlZsmVA3fboAWt25Y1gFNE4\nZ1kjIhpO3EuCHnro4XjrrTcAAFu2bMZTTz2OcDiMRx6Zi0mT9sX3v38ZqqtrEI1G8i4lCniXAQWA\nVas+xEEHHYpnn52Hzs4OfPOb38QFF1yE+voNzmOnn36G81ix7NGV96hgDSQhQw5EEYmnUBXyDfYh\nERFRkbiXBP3hD3+E2267Gf/1Xz+EYRi4+urrUFVVhfb2Nlx22fcQClXg8MOPwIgRNTjqqK/ihhv+\nG7/97Z3Yf/8DPPvMtQzokUcehVgsihtv/AVGjaoBIJdkGVC3PXZJUHvZtp++fgvicQM/O+pa7L/3\niKK+x56ASwj2H89h//EcFgfPY/9xSdABEpCCkNQUIvHUYB8KERFRQfb48A4qIUiqhs4oJ2ohIqKh\nYY8P7wo1BABoj4YH+UiIiIgKs8eHd6XPvFevIxEZ5CMhIiIqzB4f3iMClQCAjjjDm4iIhoY9PrxH\nVZgj+Xa1tw3ykRARERVmjw/v0RXm7WE7OjrQHk4M8tEQERH1bI8P70qfOWBNUpNYvallkI+GiIio\nZwxvn9nnDSWFpvbY4B4MERFRAUo6Peodd9yBDz74AJqm4T/+4z9w+umnO88tX74cd911FxRFwfTp\n03HFFVeU8lDysm8Vk9QUWjvZbE5EROWvZOH97rvvYuPGjZg3bx7a2tpwzjnneMJ7zpw5ePjhhzFu\n3Dh897vfxcyZMzF58uRSHU5eITVo/qBoaOviRC1ERFT+Shbexx13HI44wlx6bcSIEYjFYtB1HYqi\noKGhATU1Ndhrr70AACeddBJWrFgxKOHtV/wAAJ9foK2NlTcREZW/koW3oijOYuXPPPMMpk+fDkVR\nAABNTU0YPXq0s+3o0aPR0NDQ7f5GjaqAqipFPcba2mqM1M3K2+8XaI8kMXZsFSSu690r+SbOp8Lx\nHPYfz2Fx8Dz230Ccw5IvCfrqq6/imWeewV/+8pd+7aetLVqkIzLZK78IISBLMiTFQCKpY+u2NlQG\nuTRoobgKUf/xHPYfz2Fx8Dz237BYVWzZsmX4v//7P8ydOxfV1ekDqKurQ3Nzs/P77t27UVdXV8pD\nyUuSJPhlP2TVXHi9jYPWiIiozJUsvLu6unDHHXfgwQcfxMiRIz3PTZw4EeFwGNu2bYOmaVi6dCmm\nTp1aqkPpkV/xAbIZ3h2R5KAdBxERUSFK1mz+8ssvo62tDVdffbXz2AknnICDDjoIp512Gm6++WZc\ne+21AIAzzzwT++23X6kOpUd+xY9kyhxpHo5xXW8iIipvJQvvCy64ABdccEHe54877jjMmzevVG/f\nKwHFjw6YS4IyvImIqNzt8TOsAYBf9kMXZmhHGN5ERFTmGN4w+7wNGIBksPImIqKyx/BGeqIWyDrC\ncYY3ERGVN4Y3zD5vAGZ4s/ImIqIyx/CG2ecNAKrPYJ83ERGVPYY3rPu8AYRCEitvIiIqewxvpPu8\nQyEgHNMG+WiIiIi6x/BGus87GABiCQ26YQzyEREREeXH8Ea68g5YS3tH46y+iYiofDG8Afhls89b\nVc2KO57UB/NwiIiIusXwRrryllUBwGw6JyIiKlcMbwABJQAAzrKgrLyJiKicMbwBhFQzvCXVrLjj\nSVbeRERUvhjeAIKKNVJNNu/xjiVYeRMRUflieAMIWpW3IZkVd4yVNxERlTGGN4CQGgIAGJJZecdZ\neRMRURljeAMIWgPWdCQBsM+biIjKG8MbgCqrUCQFmhXe7PMmIqJyxvAGIEkSgmoAKWGFNytvIiIq\nYwxvS1AJImkkAABxTtJCRERljOFtCaoBJHQrvDlJCxERlTGGtyWkBpHQk1BkNpsTEVF5Y3hbgkoQ\nAgKBoOCtYkREVNYY3hZ7opZgCIiyz5uIiMoYw9sSVM0pUitCQCSWGuSjISIiyo/hbQlZ85sHQwJJ\nzUAixaZzIiIqTwxvi115B4IGAFbfRERUvhjeFrvP2+c3wzvM8CYiojLF8LbYzeYqw5uIiMocw9ti\nV96yz+zrZngTEVG5YnhbglblLavmbWIMbyIiKlcMb4tdeUNheBMRUXljeFtC1mhzQzJDm+FNRETl\niuFtCTK8iYhoiGB4W+w+b3tNb85vTkRE5YrhbfHJKmRJdtb0TunGIB8RERFRbgxviyRJCClBZ01v\nneFNRERliuHtElQDiGlxKLLEypuIiMoWw9slqAYR1xJQFRmaJgb7cIiIiHJieLsErWZzRQE0Vt5E\nRFSmGN4uITUAAQHVLxjeRERUthjeLj7Fb/5XNRjeRERUthjeLn7ZBwCQVYGUzj5vIiIqTwxvF5+s\nAgAU1YCm9b/ybutK4MEX16G5I9bvfREREdkY3i4+xay8FaU4fd5PvFqP9z7Zjb8u3NDvfREREdkY\n3i4+u9ncZ0ArQrN5PKl7/ktERFQMDG8Xu89bUQwYQsAw2O9NRETlh+HtYjebS4rZZM5Z1oiIqBwx\nvF2c0eayGdq8XYyIiMoRw9vF7vO2K+9i9HsTEREVG8PbxWk2tyvvItwuRkREVGwlDe/6+nqceuqp\nePzxx7OemzFjBi666CLMnj0bs2fPxu7du0t5KAWxK2/I5ujwfjebC1buRERUfGqpdhyNRnHrrbdi\nypQpebeZO3cuKisrS3UIvebPCG8OWCMionJUssrb7/dj7ty5qKurK9VbFF1Ws3mxwlsqzm6IiIiA\nElbeqqpCVbvf/U033YTt27fjmGOOwbXXXgtJGtyUs6dHFZLdbM5mbyIiKj8lC++eXHnllZg2bRpq\nampwxRVXYPHixZg1a1be7UeNqoCqKkU9htraas/vcf9IAIBqLi6Gqqpg1ja94fObp9enKv3aT7kb\nzp9toPAc9h/PYXHwPPbfQJzDQQvvs88+2/l5+vTpqK+v7za829qiRX3/2tpqNDV1eR4Lx1IAgJSW\nBAA0t4TRVBPo83ukkpq1Pz3rvYaLXOeReofnsP94DouD57H/in0O810IDMqtYl1dXbj00kuRTJoh\nuXLlShx44IGDcSge9mhzQ+KANSIiKl8lq7zXrl2L3/3ud9i+fTtUVcXixYsxY8YMTJw4Eaeddhqm\nT5+OCy64AIFAAIceemi3VfdA8St2n7dZMevs8yYiojJUsvA+/PDD8dhjj+V9/uKLL8bFF19cqrfv\nE6fyBitvIiIqX5xhzUWRFEiQYMCsvDnDGhERlSOGt4skSfApPqfy7uk+7x3hXXjsk38grsUH4vCI\niIgADOJo83Lll33QhTVKvIc+7/s+eghdyTDGVdTi9H1PGYjDIyIiYuWdKagEkDQSAAC9m8p7W2MY\nXckwACBpJAfk2IiIiACGd5bairGIGREEv/oqtic3593ulfcbnJ8lzn9KREQDiOGdYXyFORe7pGpY\nrb2af0N3i/ogT+tKRER7FoZ3hnGV6YVUVPjzbifAe8CJiGhwMLwzjK+oTf8iCquoZTabExHRAGJ4\nZxhfOc75OYEINEPLvaGn8GZ4ExHRwGF4Z6j2V+EHX/4h9I4xgCTQGm/r8TXs8iYiooHE8M5h/5pJ\nMLpGAQCaYq05t/GMV8tTebNXnIiISoHhnYOqSBApc7BaLJV7KVLhSmbeKkZERAOJ4Z2DqsiAYU4+\nl8g7AYsnvYmIiAYMwzsHVZEhDAUAkNBzh3chzeZERESlwPDOQVUkQDfDO5knvHuD4U5ERMXE8M5B\nkiQo1pot21o6cm8kvNsTERENFIZ3Hgp8AICV9TuwsyWS9TxHkhMR0WBheOdhhzdkHZ2R7pvO2SxO\nREQDieGdhyqlwzsX4bpXzBD5lw4lIiIqNoZ3HnZ4S0qe6VFd3EFORERUagzvPHyKz5yIRdaR1Lqv\nrA2w8iYiooHD8M5DlRXAUCApOpKp7KZzd7HNZnMiIhpIDO88fKp1r7esI5nqofJmszkREQ0ghnce\n5ixrKiRFQ0LLUXm7f2blTUREA4jhnYeqyAVX3nqe8GZBTkREpcDwzkOWJXN+c1lHIpljxLn7VjEO\nWCMiogHE8M7DMIQ5YE0WSGiprOe9zeY9lNicw4WIiIqI4Z2HYQhAN+c3j2mJ7rdlnzcREQ0ghnce\nuiGcZUFjqXj2Bp5bxdi5TUREA6eg8F67di2WLl0KALj77rtx8cUX4/333y/pgQ023RAQyQAAIKKH\nu92Wfd5ERDSQCgrvOXPmYL/99sP777+PNWvW4MYbb8R9991X6mMbVIYhIBIhAEBMdGU9z1vFiIho\nsBQU3oFAAPvuuy9ee+01nH/++Zg8eTJkeXi3uJuVdzfh7VmYhM3mREQ0cApK4FgshoULF+LVV1/F\niSeeiPb2dnR2dpb62AaVIQREMggASCLXet7pwK7f1pZzxDkXLCEiolIoKLyvueYaLFiwAD/96U9R\nVVWFxx57DJdcckmJD21w6a5m86Scq8873VS+uy2CpvZY9hZ2djPDiYioiNRCNvra176Gww8/HFVV\nVWhubsaUKVPw1a9+tdTHNqgMwwAMFUJToSvRrOfdlTckkQ5q9zZW5c0KnIiIiqmgyvvWW2/FwoUL\n0d7ejgsvvBCPP/44br755hIf2uD60rhqAIBIhKCrkawAzry3O3ezub1taY6RiIj2TAWF9yeffIJv\nf/vbWLhwIc455xzcc8892Lp1a6mPbVBdcsbB+N7Mg+DTqwFZR0cyo49fSieyxMqbiIgGUEHhbYfP\nG2+8gRkzZgAAkslk6Y6qDFQGfTj56AkIiBEAgMZok+d54b63WxI5VyGxA53ZTURExVRQeO+33344\n88wzEYlEcMghh2D+/Pmoqakp9bGVhZAwP+fOsDe8vROziJwBzcqbiIhKoaABa3PmzEF9fT0OOOAA\nAMDkyZNxxx13lPTAykW1MgotALZ37fY8nll557rXO+lrARSJfd5ERFRUBYV3PB7H66+/jnvvvReS\nJOGoo47C5MmTS31sZWGkbzSAXM3mmaPNvQm9qf1ztO31OvwVYyBaTy71YRIR0R6koGbzG2+8EeFw\nGBdeeCHOP/98NDc344Ybbij1sZWFmmAVhACi1uIkH3zaiBfe3gJkNJvruje869s2AQCUmhb2eRMR\nUVEVVHk3Nzfjrrvucn4/5ZRTMHv27JIdVDmpCKpAVIJm6ACAB55fCwA4cLLrukcS6EqGcdt7D+Oc\nyf+GQ8cchNZ4KwBApHzs8yYioqIqeHrUWCw9g1g0GkUi0f0a18NFZVAFhAxN1z2Pp9y/S8DqjlXY\nEdmFBz5+GADQEm8DAIhkiH3eRERUVAVV3hdccAHOOOMMHH744QCAdevW4aqrrirpgZWLiqAPEBI0\n4Q3vpK65fhMQGQndaoe3prLyJiKioioovM877zxMnToV69atgyRJuPHGG/HYY4+V+tjKgll5S9AN\n74xqKS0d3lLGgDUhhFN5QzYY3kREVFQFhTcA7LXXXthrr72c31evXl2SAyo3duWti+6azYUnoBN6\n0pk+VZJ1DlgjIqKi6vOi3HtKNVkZVCGEbC5U4pIy3GEunAFtABDX4+mnFH2POVdERDQw+hzekiQV\n8zjKVkVQBSBBhze8tYzKO2GkB/DFtHR4S7LOAWtERFRU3Tabn3TSSTlDWgiBtra2kh1UOamw+ryF\n8PZdp3QNAdd2yTzhzT5vIiIqtm7D+4knnhio4yhbiixDEjIMpKC5J2KR3TOsGXkrb7DPm4iIiqzb\n8J4wYcJAHUdZkyUJAgZSWrqpXJJdt4pJQMod3qmoazsDBpjeRERUPH3u8y5EfX09Tj31VDz++ONZ\nzy1fvhznnXceLrjgAjzwwAOlPIx+kyADEEhprn5v1R3eAkmRDu+2RIfn9ULSQEREVCwlC+9oNIpb\nb70VU6ZMyfn8nDlzcP/99+PJJ5/EO++8g88++6xUh9JviiRDSAaSrvD2VN4QSBnp9c2d8DbM0ysk\n721mRERE/VGy8Pb7/Zg7dy7q6uqynmtoaEBNTQ322msvyLKMk046CStWrCjVofSbLCkABOJJVwgr\n3klaUsIV3vF28wctCAAQYOVNRETFU7LwVlUVwWAw53NNTU0YPXq08/vo0aPR1NSUc9tyoMgyJFmg\nI5JuGpdUb+Wtwd1sboV3yhwf+2w3AAAgAElEQVSPzsqbiIiKqeAZ1gbbqFEVUFWlqPusra0uaDtV\nMU+TUFzXOlblLTQVkj/puQu8PWk1m1vhDVkv+L2GouH82QYKz2H/8RwWB89j/w3EORyU8K6rq0Nz\nc7Pz++7du3M2r7u1tUW7fb63amur0dTUVdC2spABCdi2M31vu2SHt+5zqvB9qvZGQ3gHuhJh87lU\nABIAQ9IKfq+hpjfnkXLjOew/nsPi4Hnsv2Kfw3wXAiUdbZ7PxIkTEQ6HsW3bNmiahqVLl2Lq1KmD\ncSgFUWTzNHVE0/3akDXz/m3NvP6pwlhMm+AdnCecZnP2eRMRUfGUrPJeu3Ytfve732H79u1QVRWL\nFy/GjBkzMHHiRJx22mm4+eabce211wIAzjzzTOy3336lOpR+UxUF0IHOqGvaU1UDdBWQzHu4fQhA\nldOn0y/7ENet5nb2eRMRURGVLLwPP/zwbpcNPe644zBv3rxSvX1RqbIV3rH0oDQoGoSuOn3fighA\nkdN98j7Fh6iumE0bDG8iIiqiQWk2H2pUK5S7oq7R5opZeUtOePuhSunwViU1fZ+3zGZzIiIqHoZ3\nAXyKGcrhuN3nLZzK2x6spghvs7kqqxCaFeYyK28iIioehncB/NatYl0xK7xlA5IkzD5v2A/5PM3m\nqqxA6NbvisaVxYiIqGgY3gXwWfeX68KqoJ3bxNzh7Tebyi2qpMIwrN9lnUuTEBFR0TC8C1Dh91k/\nmRHszGvuCm9J+Jy+cQBmFW5V3pJVeXdGk7j/2dXY1hgekOMmIqLhieFdAL/PCm/JmkdNsSpwwzXj\nm65mNJur6cpcMdf0/ufyrVi1sRn3Pbt6AI6aiIiGK4Z3ARTJOk2SgCSlK293szkMNavZ3A53STYr\nb3s98GSKA9iIiKjvGN4FUOxbwCSByqAPvoDVg+2qvCXd22yuuprNoegw2OlNRERFMmQWJhlMslV5\nS5JAZciHsGrAACB0BYmNR0EZ2QRZqvbcKqZIKgAZQpedypuIiKgYWHkXIN1sbuDEr4yHL2D1fRsq\njLbxSG35CoRhB7b9GsXZxu7zdkjSwBw4ERENSwzvAshWEP/7qZNx5tcmweczk9i5jxuArouMZnPV\n2UbKvM+bVTgREfUDw7sAduU9fmwIkiRB8dmVtyu8hchoNndV3jL7vImIqHgY3gWQrSVBDWGGtqxa\no8Vdo811XXjmNreb0IWuAIoGwzDSO2SzORER9QPDuwB2Fa1b4S1Z93m7m80NQzgD2wCkg1xXIUlA\nyuDiJEREVBwcbV4AO5S/6NyGz9o3A0rKfMJwVd6GAclVUctOs7n537iWXguceuetj3dgQm0lDti7\nZrAPhYioLDC8C2D3eS/e+joAQLZWGfMMWDPSk7CYr/Fuc+fqe3AELhqQ4x1OYgkNjy7cAAD4yy9m\nDPLREBGVBzabF0B29WUDgJCyJ2nRDYHHXql3frfDW1LNKj2hJyBghntnJIkHnlsDg6POe6TpRs8b\nERHtYRjeBVAk72kSMMy7vQxvn/fGhnbXa8zntN2T0tsgXZl/UN+Enc2REh0xERENZwzvAmSGNwAr\nuNN93PGk7unztkebG51jobWMN3+Gd05znfeP9YhniIgoG8O7ALKsZD+oe4cLRGIpz+8KXK+xKnQD\n3hHnDO+esWeBiCgbw7sAco7KWxgZ/eAwB1c5r5HdK45Z94lL3srbYHj3iOdoePvX+t247I6l2N0a\nHexDIRpSGN4FyN9s7hV2Vd+K69TaQS/YbN5rDO/h7c8vfQLdEFi2eudgHwrRkMLwLkDmaHMATjXt\n5g7j5vZk+glhbtssbYLkT1cYKY6k7hFH5A9v/PMS9Q3DuwC5Ku9RVaFuX7P4vW3pX6yg362uQ/Co\nt5yHUymGd08Y3kRE2RjeBcjV5z1+VBWqQj4AQCjQw1w3OZrYAVbehWCzORFRNoZ3AZQczeaqrDrL\nfFZX+Lp9vcjRxA4AyZSe83FKY3jvGbhWD1HvMLwLIOf4ZnEv/1kdyg7vQyeNxrUXHoWvHTouu/KW\nNQACb3e8jHe2vwcAWPDOFsxd8ElRj3s4YHYTEWVjeBcgksq+jcW9/Gd1hT/r+ZHVfhy272jzOeE9\nzZI/DskfxxfJT/HEp88CAJ5ftgUr1u0q8pEPnLWbW7BibfGPn5U3EVE2hncBJo3YBwDw5VGTncfM\nZnPz5xGV6fAWmlmFj6kYCQCQ5exmcykQg+RPrzJmrxMOwGmKz2f+ss34+LPmPnyK0rrrHx9j7kvF\nbznggDUiomwM7wJU+6vwwIw7cOa+pzqPqa5Z1/yq7Axei6+ZisTGo3DUhP0BABUBNavZXArEIAVi\nzu+NkXQYdxdWndEkXnznc9z7zOr+faAS6unio7fKObxffGeLs+IZEdFAYnj3guIKbFVWPfNu71NX\nBQCoCYzAdbPOwKTx1QCAiqAv655wSUlB8qfD+4GP/+KsEa7p3YR3JJn3uXKR1Io7gr6cm83nL9uC\ntz7eMdiHMaSV8bUZUVljePeCu59blVQ4y2ZIwMGTRgEAamtCzs8AUBlSs/q8IQlP5d2aaIW692YA\ngN7N7WPhaCrvc+Wi2CPoyzm8iYgGC8O7F7Iqbye7JZxxwpfwzan74rJvHOp5TWXQlzUPOmTDCe+v\n7zPd3F9tAyBrWZV3Uk9i4ZZX0Z7oQGe0/CvvRLHD23U6/ufvH6KxPZZ/40HCC4y+4y1iRH3D8O4F\n9/3e7j5vSQJURcbZ0/ZH7UjvzGuVOZrNIRmQ/HEoIoBzDzwLB4a+AknVIPnj0DIq73/Uv4CXtryC\nFzctQke4/MM7WeRZ49x93vUN7Zj32sai7r8YONlO37HZnKhvGN69oHbT551PZUjN7vOWDEi+JFQj\nCAAwdOt5SUBzVXHtiQ6s2LnS+b0jo897xY6VWLBpUS8/RWkVu/IWGVVtOS7mknnBRURUagzvXvBU\n3pKCQtK7MuiDENnN5lBSkI0AAEDT0o+7+7zvWzU3/RJJdgashQLm/h7f8DQWbX0dulE+M7UVu887\nM6zLsYk6VeRBekREPWF490L2aHMzSLrrtzNvFcucpCUBSQIk3QzvlDUOTZIM6Fafd2ckieZYi/Oa\nqBZzKu+qjBndolr59AMnSthsnuv3cqAxvPuNfd9EvcPw7gXPaHO5h8VILLIsZd/n7TMnaJE0c3IX\n3S5WJQOaYQbBtX98C7rQceDIAwAAsVQMHZEEAMCvKp77qbuS4d5/mCJyH0vxR5tn/l4e4a27Dox9\n3n0nCup8IqJMDO9eUFyBrcqq606xHsoG4X1e8pshbM/GJgzreUkgkdTxxqrt0GWzyq5UK+BX/Ihq\nMcQTZjDqhkBcT8/QFklF+vyZisHdtF30Pu/MyrtMwlvT0sfR3b35ROXglZUNWL+1bbAPg4qI4d0L\n7nW9PQPWemzyywhvnxnMwqq8Dd16Xjbw7Fub8bfFn0JSzI5wvxxEhRpCTIs5FZ4hBLqS6cAO55h7\nfSC5w9tdeacMDXd/+Ces2LEy18sKkt1s3uddFZW72uaANSpniZSOp17biN8/uWqwD4WKiOHdC1kD\n1iw9ZXfdqFDOx42kWXk74S0Z2N5kNoFLqtkRHrDCO6rFnYFRhiEQTrnDe5Arb1d4ufu8t3Y24LP2\nLXh8w9N933eZjjZ3BzYHrFE5K5fWKiouhncvSK5RNYprkpae3Pz94zC+60QkPj3G83gqYTbD233e\nkiTgXApY06X65QBCaghxLY6UZm5oCOFpKh/sZnMtT+WtGVquzXsl84unXAasuQepsc+7H6w/Z+bY\nBiLqHsO7j3yyAvf0qN0J+lVMCnwZRkcthKv/OxaVYRgCuqvytkmqGXw+KYAKXxACAklh9pUbRmaz\n+WBX3q4+by0d3rmWUu2tzLDOvO97sLgvWDjavP/K5aJsOOK5HZ4Y3n2UOT1qTxTFOtVGeluR8iMS\nT0HX0n3e6RdYlbcUQIVaYT1vPmYIIJxKjzAPJ7NDMvMfbCKl4911u5zqvZjczebJZPrnLtcxLnrv\niz7tO/N7Rx/gLyJNN7BkZQOice+88u7AZp93/7Fpt3TKpauJiovh3UfmwiSmQu5R9dnh7VqkROgq\nWjsTcPLUGm0OpPu8VRFASA1ab2pW45l93pnN5l/s7sIPf7cUb3603XnsuTc346EFn2D+si0FfT63\ndVta8doH2/I+7xlt7ro4CLtuYVv4Xu/fF8jRbD7AX0TPv7UZT762EX9fUu953N1Uzmbz/mN4l065\ntFZRcTG8+6jQ+7xtimIlvHuFMV3BLY+uRDhid3obTsVsh7cCPypUc8CbZFXjhiE8TdI7Irs8s6wt\nX7sLAPDU6585jzU0dgEANu3o7NVxA8Cd8z7C35fU521+y9fn7b7/3JD7tiJa1gxrA/w9tHFbBwCg\ntTPheZwD1orD/nOyabd0mN3DE8O7j1RZ6dWiCnblLQz7vxKc028FuiS5m83NKlsRfgTUgPVYesBa\nXDPv8z669itoT3Rgbct656V2FSP7EqhvMwPcp5qj4/vTbJ6vOvKMNk+6wtvVIiAUb/gV/J5Z93kP\nbFC2dZnHPbI64Hnc22wuEI2nsO7z1gE9tuGEAVM6bNUYnhjefaRIhU2PalNVO6itjQ1X5S7Sk7TY\n7CpbNgLpJnopfatYzArv0yedAgB4a9sK57VOv/A+a3DvqofwcdNa+K33T/ajSszXt5tvkhZ35d3X\n8M5s8hvoUcntYfO4R1T4PY+nXIP0NM3AnfM+xp1PfYT6hvYBPb6hzv6nM9AXZW5CCCz9cBt2tQ7u\nfAmlwlaN4Ynh3UeqrOLkoycAAA760qgCtvc2m8vCHd7Wn8E1YE3yJyAMCbLwwWc10UtyepKWmBaD\nLBQ89sIuHDhyf2xo24hdkd3m7uzAC5lBMn/Ty/D5zPdI9WPu8XwDX9yjzd39v+5BdYbct+VMM9/S\nEAKabmTNvFYq9mfOfD8tY5KWLTvN7oimMlxvvJw5zeaD2POweWcnHnulHr+a++7gHUQJsfIenhje\nfeSTFXzntC/jjh9NwWH7ju5xe6fytprNZeSqvN3hHYNIhqDrrv512Wo2N4CYFofQfdi8vRMnjDfv\nH/+0bZP5vN1vnjJHqTdGmyGpZgWZ7EezuZ5jGlAhhGeeb/e0oRHXKHhD6Wt4Zw9Yu/z3b2DO3z7o\n0/56w90FkDkoTcszYI0LbPRNb6vD9zc04sEX1xWlqozEzC6q4VqgsvIenhjefaTKKmRJwtiRuWdP\ny9o+Y7S5e7S63Q9uN5srqg7Jn4RIhKDpBnyKtYqY5K6844BuTtEaUioBAM+/vRHN7bF0FaOkB4nF\nVXOFMvfgqriW7hMvRGbl3RZvx/ee+ylWtb0PqGY4u0MtYbgCW+1bs3n2gDXzd7vSLaXWrvT88cmM\nFgv3eXT/LGekdyyhIZbo/2Q1A03TDWzd1TVg79fb6vCP89fivU92Y3cBTd2vf7itV5/FMATunPeR\n526NoYyV9/DE8O4j91SphVCt0eb2JC2q5FrW0x6wZjWLT5xo7lskQkhpRlazOWCGt6GZj8swt4+m\n4nhx+efpK20lHRqblXcANenp835iwzO4d9VD+KhpbUGfQc+oPjd1fI6ElsCyliUIHvkG4Is74W0I\nA5qhQYHVV9zHyjuzz3sgFwGx108Huq+8NU/l7Q3vK+5+C1fc/VaJjrB0HlrwCW55dOWA9eH3tTp0\n5k/Io7E9hsdfqcctj+afXz+ztaSxPYZ1W1rx10Wf9umYyg2ze3gqaXjffvvtuOCCC3DhhRdi9erV\nnudmzJiBiy66CLNnz8bs2bOxe/fuUh5K0Vz+le/hrP1metb2LoRdeUtWda1KKn71PWu61IxmcyVo\n9puKRAU03Ug3m9vN6rIBXegwUnZ4p5/fvKPTudIWcgp1FWMBAElEoY773FMl2iPUN1rN7T3pbrIH\nSTGgjGx0giypm8EXRJV1Avp2q1jSSHpaEIq95Gh3uqLp982cRU3zDFhzDTTsY7N5S6wVN7xzOza0\nbuzbDors/Q2NAFBQZVsMfa0OMy8oMxXy/0vmn2y4dX2w8h6eShbe//rXv7B161bMmzcPt912G267\n7basbebOnYvHHnsMjz32GMaNG1eqQymqI2sPxxn7fb3Xr3Oaza3wliDjgL1roMiS0w/ujDb3m1+Y\nduWdsjPErrytMBO6VZFb64VLio4dzRGzipEMQNYxOjAKFx30LfP5jACt9JnN7YVOr6plfAnYI95t\nysgmZxR2QrdmiDMqrffuW+W9UnseoWNegz20qbezRdW3fYZH1j2BVB/mWe+KuirvjLECqTxzm/e1\ngnz1izfRlmjH3DWPFbS9EAJPLKnHui2lvT2tKuTreaMi6Gu+9HSPfUE5LHX765DHPu/hqWThvWLF\nCpx66qkAgAMOOAAdHR0Ih8M9vGr4UuzR5s7tZVbftyJD2KPN7T5t1aq8k0GkdAML3ramFrUGrNnL\nhcIKbwjF83wkrjkBH1KDOGj0gZ7nbVU+c0BbOJk7vNvi7Xh03ZOQrIuJzCrHvtf8lJFnw4hXQK5u\ncyrUlNXfrYgghC47y6D2VhhmOMk1zX16/b2rHsL7uz/CxwV2Dbh1uirvzJDwDNJzN6FrfWz+tVpy\nNFHYRcb2pghe/WAb7pz3UZ/erzvukfWZF2yl4q4OOyNJrN3ckndbz/H11I3ShzJ6uGXdnlh5G0Lg\nd3//EP9c8flgH0rJ9G6asF5obm7GYYcd5vw+evRoNDU1oaqqynnspptuwvbt23HMMcfg2muvzeov\ndBs1qgKq2rum6p7U1lYXdX/dGdlsNT/azeaKitraavhUGYmUt887FJKBlFlZ+/wqWtpTwCjXJC5W\neAvdrIpGjrA+hxXOmiGchU1GVY/AXnXWrWzW/u3PXREIAl1AzIjmPBfLPnkbK3evQuAIGfH3T0f1\niJBnu+QXZrhVh6ogkgHIwSg0w0BtbTVi7eaAMkX2QST9gJrM+R7vbVuFUcEafHns/p7Ho/EUKoLp\nqk8Zux1GR61nm978/Xyh3v+9U64vPUOSPK/3B9LHJrv6XYMVfmc7dytBT+8dCJj/FDVDK+g4O+Lp\nC7Fi/3/c1pluUQm5Pk8pqT7FeZ9fPLQEja1R3Hftydhv75qsbSOx9EVVVXXQeV2u44y7rrnyfY6a\n1phnm5he+N9tKGgKpy+cC/k8w+Ezh6NJfNrQjk8b2nHJN78y4O8/IP9mSv4Olsz7ZK+88kpMmzYN\nNTU1uOKKK7B48WLMmjUr7+vb2orb91ZbW42mpoEbTdvVZX1BWAEqdKCpqQuyLOWYpMX6YhYyOrvi\nUCUFCddrneZva8BaW6s5ktsO/65Iup9Y0hR0tVnPWzO0NTZ2QpIkdMbMintXuAlf7GyCX/Z5+vI7\nwzFnv1IwjJaWCJpC5nsmkjpefOdTqOOARFQ4rQApI4mmpi7s6jAHOukpCdD8kIKRrPNtCAN3vvMQ\nAOCBGXc4j2/Y2oY7nlyF807e32yokAClphkpyfBML9ubv19LR5dn+22NYUgSMKG2yrPdZ9s7cO/T\nH+Pq849EY0u6RSIWT3le3+EKuIireb2tPepsF0+mq+jujvWtxmVY9Nkbzu/23ydTMqXj3U9247iD\n69DWnv73UKz/j1es24VdLVEctl/61sfWtuiA/DuJu85vo9XPvnFLC6p82Y2Dja576Ztawmiq9uf9\n99zSkm7ty/c5OjLOZVNzz68ZSlpb0/8f9/R5Bvp7sVQiroWEBvrzFPsc5rsQKFmzeV1dHZqb002d\njY2NqK1NV05nn302xowZA1VVMX36dNTX1+fazbDh3ELk6vMGkNHnbQ1YU+1FjmVougG/Yo3Ylg0E\nfIrTbG73ecPwNpvHk5qzTYUagk/2eZ5/7q3NeGfNTqfPOqEncd1bv8b1Cx/CZ9s7nGN2z58uBSOe\npuKuWNJpAZAMn3MsQtagG4bTbA5dgdB8kBQdCc3bdJ7Qc98+ttIaLLVw5RanA1JSNchVfR/5HE15\nJ0/59V/+hRsf/lfWds8s/QyRuIZnlm5yms1HVPg8zea6YeCL3el/nO4+b/e98O6R/d01Xc5bu8Dz\ne0TLfaG6YPnneHThBjz56saSNO3OXfAJFiz/HM0d6XM1UPO25+qXzddk7668e1qOtZAxEpnvM9xW\n4ervx+mIJPHBp03FOZgBMtz+hrmULLynTp2KxYsXAwDWrVuHuro6p8m8q6sLl156KZJJ88t85cqV\nOPDAA0t1KGUhff+vNe847D5vKV1NWuEtK1Z1bshIaQYCavo+74BPTt8CZjWb6xrM6t0K51hCd6rz\nkBqCIiuQhAzJev6fK7bi4X+uzxpwFg5twd3/SPehulcFU0Y1oiWRHhxlGMK5QJCFz6m8JUWDpgkk\nrNHmwlAgUubFR1vcezUaTaXf372wiv1FLqvmY0I3L07kEd5+UPMiQcsK5lw6k7mvhA1hYHc0/cVk\nT6aj6Qa6oklUhXwI+BVPiK1YuxtrXQPFtDyD19yz2el5phAzRPbjLbHcg9B2tZih/vmu0t7j3tKR\n/ruUMrzdrXG5Lm7yjST3hHcPo80LGayV+d7D7YvffQ76MjPh//z9Qzzw/JohNfVvrgmlhpuShfdX\nv/pVHHbYYbjwwgsxZ84c3HTTTXjuueewZMkSVFdXY/r06c5tZKNHj+62yXw4kK0Ba3a/tV15T6yt\nAmA1nctWVW6FNwwFKV0gqKbv8w74FSek7cldNN2AJBSn2Tye0JyAt5cTlaB41wuHQEJPoNJeKxyA\n0FTPVbq78lZrt+PvDQ8imdJR39BuTlpih7er8oaiIaUbSOr2iHgZ0Mzwrm/x3pIW1dKh61772/4y\ntS8OjC6zGVcOeQc8aprAw2sfw8+W3YRIKuoJccMQePHtLc5kOJ3J3IG3cMur+M27v3fudbfvCtB0\nga5oCtUVPvhUb3hv3tHh2UfKM2DNtba5PUJd0pHQcg9Ei2vZrQ/5Rv/b/w/phijpl1OLq0uglMud\nunMkZ3jnCdGwK7x7Or5CgjgrvHvY519eXo///r/lPe63XLg/X19Gntu3Cw6lqX/zXSwPJyXt877u\nuus8vx988MHOzxdffDEuvvjiUr59WXEqbzugrdHml5xxMPbfewdeiSsw7AFp9n+FDE0zYOj23Oe6\n2Wwup8MdsKoj4Qp1pG/NspcTlYXqHW2uaBAQ2K9mknO/t4hXpudgR+4Q+cvL6/Gv9Y046/9NgqRo\nELoC3YCn8tZ1A0nDWr5UV2AkzGOYt+kZHD/hSAStVdJirvDuSHRiZMAcnGR/v8jWoDsRr4DQFUhB\n7/GkdANrms1j//mymwEAPzjsIhwz7ij8a/1uzH97C0LHCEABOhPp4HdXa+/sMCfvWLV7DY6qPdwJ\n70RKRziWwoSxlYgndU94B/2utdxlAy0j/gUlPgJSMIIW3Q/AHHyXTBmAZCB45FuYv6kT3z3sW1nn\nM7P1AzAHreVi37FgGKLHirM/8lXeiZSOrmgSY2sKm1WwJ+4gyZWx+YI3Ek+fn54uYgoZaZ35Pj0F\n/turd1rbGVDk8p/nyhPehkAP89rkNZRaJIbSsfZV+f+fN0yMqQlaP3mbzasr/Pi3KftClc1wkkc2\nIpwKQ7Kq8ZRuIJGy/keUzD5vJ9ytyjulGea93q5wloLm1fLY0Bjzd6E4zeZAetWyoBLElV/5sfmg\nrHtmrAqnIhgd9C668q/1Zn/0xoYOc1CcrkI3BISRWXlbzea6DL3xSxBJM7Cjrv5cd3gv3rrUqdad\nudntixFdhYhXWp/JfZtQdoDtipjHZ37BC2cZ1Q5X5e2e6tReS317s9msbs+E12atJmZW3rInxGLW\nQLTbLjsBoTFtSFR/Dv8Bq+GbsAmrxItOU3hS0yH545D8CWzsyD0NbVw3gzKoBHDKPicCQM570qOp\nmPP31Q0BrciVhbs5tdm1drm7JeGOJ1bh539a4Zl5rj/0HirCfBcoUdd0sz1V3oWFd+ZtgIV98Wd2\nKQgh8PFnzWU3Ha773PYn1IbSLWdsNqeiGVUdwG2XneBUywq8k18okgJJ1RD48ofYEdllzaomIaUZ\nSKYEhCFDkg34fa6QFq5lPo2McA5EASFhTMhscpZyVN4AsGZjJ3738GdQjRAgG051J4RAOBVBlTWR\nS6akpluVt2p++en2RDEaNF044W1oMiBk6O3mYEU7oAHvILKPm9bipS3mGIms6V11FUas0hz17k+/\nRtMMyJL3f2FP8Lmmh+1IdDrv51621J4AJ2m9zl533V6UpLrSD58qw3AtwBK3ngv61Zwz7dmfMakZ\nkHxmOLfEWz2f3WZX3idNnIqJVXtnfwaYs9XNee9ObAq+CkBYK6sV98vJHUSeytsVjvZ88u5m6/5w\nh0GuUMn3GfU83RQ5t+1L5a27jyv//jOPb/naXbj3mdX466INPb7nQHJ/hP5c8w2lanYoXWj0FcN7\nAO01phIHpk6D3joOU+r+n+e5zBHGftkHVTFHmyeSulllywZURXaazYWr2VwYCiRfylkARA5GIeuh\n9LzoIqMyt5qk7XlzzIsD3ak8E3oSmqGhUq1AcrN5n6Q9hzoAJDTdDEddhaYJT5/3X15ej664GZS6\nZq+mZr7WDnXAW3kDwEeNa8xNnT5vb+UNAFIo3XSe0DSnYjyq1jzGlGEHp56ezAaAgMD7uz92nks/\nYU9ba430z2hTHFFhhjeQDji7sgr6FShq9rdh0khCNwxsbwxD8iec998VzZ4C2J7oJqQGnb9VKiPk\nP2hcjY5kJ8LyLsjVbVafd3Er77hrBTXPimnWZ3avsFasL3F3tZ85h735Prk/o2dq2j40my98dytW\nb2rJu02+VfIyZVbe9iDGTdtLv2hOb3gGBvbjNoVi/z9XSkPpQqOvGN4D7IpZU3H1cT/AtMO+1O12\nqqzCp5qVdyJlhndNtYq6kaH0wDO72Vw3zIFhAEJfXQqoCUj+BJRUFYQQ+OPzaxCNCUiygNPsbM8X\n7p5iVU734dn93RVqJTsCVIYAACAASURBVPTmCdA7R8OADsBuEk5Bks3Qjic1T5/3Z9s6sGqTGVS6\nZlXyVmVu94UDQDSjv7cl3obGaHO6/9OpvBWIlNns7p5mNZyMQkDgyNrD8c39Z5rnwqpaw7GU83q9\nrQ5CAG82LEdjW8QTRPY99kKyBt9l3F49wmo2B9Jf1vGEBglAwK84I+LdknoKTy/dhKde/wzwpZug\nd4R3ZW1rV95BNQjVuqVPM7zhvXLXh87PytjtVp934V9OWzq+wJ/XPo5wMpJ3tHE8zxzg9mduaDSv\n8tQJG/HytpcKfu/u9NRsnm+kuztce2w2z9hvNK7h6Tc24Z6nP05v002fd+b+3ecvc8rcZmtA11in\ni6w86D20cBS8nyE09Vyxu5XKEcN7gPl9Cg6eNKrb2eQAwCer8Cmy1WyuQ5FVCDkJQ04BkvWl4fR5\n604VDgDKSPP+eilZiVhCw/ufNmXdCy75zdCwQ1FYfeZ25b0zYgbNCL81QYDzeiu8EXFeH4lrnsob\nSFfYuqZ4Xp9wVd6fN5mVynXH/Bhn7GtOpdueaE9/Qcqu+9l17/EDQJc1rWuVr8IJPrvyjsRS6dHq\nsSroLXthV2wXfvXss3jh7S3pE23tLwnzizfzy7raVXl/vH0zVu/YjFhSRzCgQJYkyEqu8E5imTWo\nya68ge7DO6QE0pV3RrN5S7wN1b4qyEKFXNmRNWBtV2Q33tuZf33zf9TPx6rG1bjx5b/i9sdzbxfP\n009rn4+GRnNMgG/CJqxu/xDhWApPLKnvVxO6O0dyZUrmMqw276IwvWs2D8ey++u7azbPvIBwH1Pm\nc01Wd8OoEYFuj2mgGT3cklfwfoZQNbsn9HkP2Axr1DuqYt5fHEtqSGoGAlAQTnXhXflvgGwu4iKE\nq9lcl5wFFeRqs0lQSoUQtkfmWkHv+9IGpD4/DFLAXrnMHDls6GZzvH070spdq8ztIxMAtDmVM2Qd\nMFTEpU4oMEeoR42Uq/K2mrqtCwwtZVW2OZrN12zdBXUsMMJf5dzSFtPi6S8J2T52NT0JnSss7TnZ\nK32V8Cne4AvHNE+fubbjAKhjd0Ie0YpVG9OTB9n3w8eMMO54/37ElNEA9nKe19ROrK94BsrY/fBk\nwyLz/RJnOyPOJVflbUSrIFeEkdCTqAgoiCU0p88bAHZEssM77qq884V3OBnBmNAoSMlKdIR2Q0fK\nEzi3vncnAKAmMAL71UxCwJ7Uxz4uawBdomIbNm34ctYxAN5mczc7HKMJLf33ADD3pbVYs6kNmiHw\nvZkH5XxtT/L1ecuSBEOIvCuCefq8ezlgzT1ffa73BrxVW+b+3bPmuS/0hBDOQL5yC7nM0eZ9NZQC\nkc3mNGh8soqAT0VXxPyySfc3Cydw7C/7ZMqAcC2bKVeYVZIwZIStLys7PNW6bZBrmiFb4W3Ezfu8\nDatvWlENGMLA6uZPEMIIvPCKNWGIYc+/bo149pnNqCJe4am81boGyKN2Oc3Qeko2mxFzhLd7xLs7\nvJ2Kxj2TnD0Lnavytu9Dr/RVOLPI2f3F4VjK2b/QVAgtPdGNh7WNhhS2djag0f+x5+k2YxeSUgT+\n/dMLm8STOkKBdDcBACTWHwe9dbzzGYP281blXaFU5q68dbvPO2Teiw/vrWIpQ0Ncj6PaV4VqqRaS\nBBiBzpxNyvd/NBd3rLwPQgi8/uE2ayY2gdZ4m3ksqubchZApka/Z3AooTRee127aaf5/YfSjedId\nJO4+b7v1J6nluaDoplk7U+aXeFeOkfLdNptrmeGt53zOfftavhYDIQSefXOTZxbDgeAZbd6Ppu+B\nWqSmGPaE+7wZ3mXi9Emn4OBRBzr3OvtkFUG/4vzDE3L6S8eerGREyAy8aELzNM/KlWZ4G7qcbtbU\nXaOiJQEpEDOraWsCFfteclk2oBk6UkYKWjQEZ35Su9ncqnxl64vciFeiPZxIr3AGwDfhMwgrZFMp\nCaOqA1AlMzyT1rSpumFO8iKEOUNb0BXedsUl5HSfd2azPwBENKvyViucCxk7+CKxFKCmK3e4lk1N\nnwcjPSFOPkp2c3IslUDQnx5dDwBC8zvvsaO1AxV2ePviEJoPtYFx6Eh2eia+AbwD1pZ+YIb79tb0\ngCd7lrsqfyWqYN72JwKdeQcP7Yo2YmP7Jjz+Sj2WvN+AjmSXZzIcuTJ7MFVcS2Bly/KsVeeAdEBp\nugHZdZ99NGVdlAT7vmSonmcglWKHd54Q9FbehQ9YE0KgM9q7ZvPsyts1sM89IY/r4iffRcfnu7rw\nzxVbcftj+bs4SsGd17kGBvbEPb/AUFGs1oZyxvAuE//fAWfgJ0dfZt0iBqiy2WxuS0npL2A7qGsq\nrfCOa5B82TN1CUN2+vjcfeKSrEMKRK0mc8nZFgBk1XACUNcl1768fd72hCkiXoH2cDIdrjCb0u3K\nW+gKVEXGiKDZPG/fLhWOpszPofmh6cKpvONaHAnNACCQ8rUDwgxGu9nefTucHUqVvgrzVjtIzoC4\ncCzlDG4TKX+Oyl3Af9BKz2fPJQVr/vdPj4HeYYYnanYh4LcvqlyD6qxz8MTrG8zKXElBCkZhRKsw\n2mfeKpdZfcdc4b16o1khN7anJ5SxZ56r9lVBMsygFJLebRW0rvlT5+ftHebAwUlV5gBJKZQ9Teyz\nG1/EB13L4Nvn06zn7PBKaYZ5+6HFvmjpzz3NIk+zuT1oMpmnP7uvfd66Yc6alykz4LuvvHM3m7tb\nLqJGFz5sXJ31Pok8XROl1t8Ba3Z4r9ncUvKpeYuluwuw7vx10QZnEp5yx/AuM4p137JPVhH0eatl\nN2HIGGmFdyyhQW8dl7Uvs/K2vmxEOoilQBSSqkEkXTNluSpbu8/VHinuft4OTykQMydesSZnCfpV\nnDzCmkFMNiAk3ZqaVIJPlVFTYb5XJGmGVWs4BikQgxGvQDJleJrNkykdck0z9EA7KhP7mHO4Z1T+\nABC1Ku8qfyUkSYJPVhFJxLHgnS1WeNvN5n4AMoQhpcPfl4AywgxLvWW8s09ZT48U3m+vaucWPpEM\nQsTMufn9B6xG20irerIH1bmqe8jmjGxyVbvZzN01GmN85t9nc8fnnr+RHd6bGiLOHPbuPm+7X7/K\nX+lZtz39hZT9ZRxOpPvZd0fN/v0vjzBnN5QrO7NGnO+KmhPb5Ap2O7xSuuFtclfSa8c/99YmvPbB\ntqzX9iTfaPN05e0Nu2ff3ISXln+OsNTsdH/0ps9b13NX3lpGuOVbqx0AYnmazd2tBLvHvYSH1z6O\n19au97z2b1/8H/wHZy+GUyhDCM/FQ8Gv62cVav89GhrD+M2j72PTjg7c98xqRON9v3ArplxjI/py\nwZJM6Xjzox34y8vre964DDC8y4w9baoqKZ7KO4sho6bKbPKOJjSkPj8MX459A0YsPamK0CWn8naW\nEQUgWc3u9qxn9v4As9nZvlXJXXlnjVZXXCPMAVQEVewd2MfaRoMOzemHVhUZFX4zFCNWsOzobIYk\nCYh4BZKajpCSEd6VZr9gZXQ/81hzNJt3psyFEsYEzYrYp/iwszWM55dtMf/B+lyVN2BeaNjH75rn\nXa4I43jpAnNbpB8/cOJIp5lbaH7rIsDU5WswH5fTg+Ls1gF17834rKUBcrV5cWCER2Kczzw3G9q8\nM63t6mqF0GX88dkNCPnM/bv7vO1b9qp9Va7WAyNdfarZYbR5d5vzc7u1GEyNOhpGIgg51JX1ZaZI\n9oWZAXWfDVDqvjB/l1zN5prhad2RrM8djafw0vKt+PuSeuyM7Ma8T5/POV97Lkae+7ztqYTXbmnF\nR5+lBxf+c8VWvLBuGT6vfhnq3uY8+T32eXtmFzN6rLx1XXQ72txTeWu5K2+7p+mJ1z51LpQMYaAj\n1QZlRO5FZwpx3zOr8V93vdVta4emG57lMM337t993plTwP7+yVX46LNmvPnR9l7vq9iefXMTfnTn\nm9jZ4p06ubtBh24bt7Wjrcv8/zVfyLd0xEs6HXFfMbzLjF15GxCe8E6PJbcYCmoqrfCOpwChoEau\n9TRf64bkDFjzfPFat4l5mrqtn9uqV6fvv3Y1J2eFp6x7Xl8RUOFTFXMOckWH4QlvCRX/P3vfGW9H\nVa/9TN/19H5OzknvIR0SEjpEulIFiShYLyI2BEQR9PpD5aJX5d5XQbHAtYAIypULWABpIXRIg5De\nc0pO3XXKej+sMmv2npOQkJAE5vlAOHvKXrNm9jzr356/wfTMWax0xyDt5EUKSdiOhxjTYM+5eRSY\nJjgdgxEYq+w2H3D7YGkmKkxqERuqESB33eR9z/k51OD4GZztI2F5FehIjxDu/mRMxylzRvgxascA\nsf34rgGLndIG8VhnODZGNZaFPvlpkZvgZSphKnG0pVqwrm+DiGl7xENPoQskT5vT8LwDuXXqoBTz\nlhdQPO6rsAXKmMqRmBc/EwCwTYqZ9xeY7CuJg2QroJhF9GT7AxKnQo42MQijeQN0Rt5xU4fteMg5\nebylPwa10idSYXnnfCL58Su348mtS/D0tucwHAghIg8j2DDD30f+/Cf3Bd3PWgO18DU2Fsfx8Pra\nnmFL1uRzOR4JlXYtzXrfXZ33cAlroXFuhYhEtqCG/b4RAReWkRvHlOJbv34Bn//RUwGyGS488Xah\nlogfcC/Du9Uudnd4aMlGAMCK9cFFUVAlL/yaO3uz+O7/vIyb734RQDjJ9/Tn8dWfPov/vPe1sm0H\nGxF5H2Lgcp8e8QJu85OSl+CyKR8RWeeEqEjFDWiqItxXhq4G4rfE8RPWvLxvkYsabznWy+uwY9vx\n5zUPBT4D4Mufqi4AAqjB2vKERevS4eo0EU3xyds0NCTMIHl35ujLl1reXiDmXXRcEVsnpCRhTvXd\nxUNuH+pitejpz+P7v30Zjh20qDXTptYwczcTT2rqwv51drbD3dXC+qYbgOohldBx2xePRW1lDBk7\nA0MxAaIGLG9DYeSt+AI1gfkC/GQ3x4DjemhPt8EhDr551xMAgK5cD4jiwcumEDM1DGR4jbwjXLfC\n8jZTUtzft7x5XH989RjUaC1iO8cga4X6qwfXw8vSmv2/L1+BL972NJ5bSePv/YX+wHGinaylwXE9\nPLVlCfr1jZClCbjl3ZcpAIoL64h/iYXGa10rUIqubA8Gi0P4zSNv4KofP4Wt3ZlhNbdLX7YD2SJW\n71oLrWETVOba91gI47W1PfjRH1/Dd38d7o4W51IdvLTzFQxk/UUst4qD3+1hlf009DYa/y+zvAsS\nebthbnP/M0X1hFUnt9flHqF9xe6M561d9HmRPQHDLYzeLrRS5aJDEKW6GYFF2zBW86ad9J70MC3/\nsORH3tt+1cbesm0HGxF5H2KQyduSyLs53YA5jTNgKiwm66mIWzoqkqZI7DE0FfUVPkm7riLI29ky\nDsUNkwFIwiEy2Uj//0bvWwDoAsHfTv/fHLNMxHJlyzwRM6DrKrW8VQdEcaEpPB6uIcXc5isHX8fD\n6/+B3iJzKRcSKNgutnXmoCoqNnX3omh7Qq5UfAdhMWtOiEYBLhxk+k1c87MleHNzHwaGXMhtTxW9\n6LvMAboA4W5zVodOHJal7hIYGvMU6P6POGNnYaksN0Amb7Yw8VAU4QNSQt6KXqTKdVDgekR0U4Pq\nghCC7Sx5jeTSKBRdZLJ+jTx/6QvL20j6fd+lmDe3vFNGCiopDy2IVquOKch7xY4NAIA7HlwJQgj6\nCiWlSxqXf6WWt0tCrErNRUXSxMBQEUosCzVGX3KqomJd/4aApekRD//x4m345Yrf4cnXaDLQ+m0D\nw8a8S8l73dYB/PjV22GOXClkfVHiiXpdcq/L4C9xo/0N3Lf+TxhM+fFM/rIujXFvJstgtKwXf3O8\n2rUcAwU/L6DUba5YGcSP/Jv/5ap/H4ekKgO5MmRf8HZ6cpcuSDj25DbfuGMA/3X/soDrfTjyPpRy\nuEuH+HZi3p0lLU7DKjhMQyv77FBBRN6HGHi3MZd4Abd5OkGJQybvmKkFpBh1XUVrbYX4m3gaeofY\nKp9ocDtHiAYn/Bz+viGPQpjbHIA1eSn9n5KYNz1GBzQXRHVE85WYqSNp+eP86/q/+brmtonnV+7E\nt3/zIlxbw2CBveTYGF23xDvAPuelal2d0o+LaDSBTKUdxVy1ECBcriInn58vWlzPg6XSfXVDJu8M\nYiodu6gVB+ApLC9AsUXSXqAcDzSWbjHCdlxPiKcomgvXI9gyRInMy6XYi1AR96efuXeHbE7eKSGB\nC9XzrT5meadNSu6EoKScboiGDYgKkqXPhpyYtq2/F45EzqpnsnI6D6ahwnY9IfIiQzc8NFbHA+1n\nTxt5MmY3TAfgl8ABtAFNxsliTd86keCn6wrk05JhyAYA1m4rr4tWNV8NcHcQOvkshGEnt4ltnHxl\nK01uHSvvs7p3LX6+7C68YD8ItXon9Kb1cFwPHgsDFG0XetPGkkF6tIwS/n0EEBDuKRsv8fZIzm8n\nbC1n4e+N5X3lfzyOl1d34bkV5Tr85eMoP9ef1/zfbtX+9hV7Gnep5e0GLO/wY3mf8mSMl5mW73co\nl5lF5H2IgVvepJS848wFzcmbqIgZQfI2NEVYhAAATxUPKIXix39RalmHrDBD3OoyFATd5lTpTYei\nuVAU1gwFQNzUUBEP9oC2XZ6lreI1Fssjju5b1oxcbVt+80iWM3P9J9W0fz2uCkUliM/5BxQzD6J4\nAVc37bxGaDy9pDOb6xKYjFw1g373uv4NsD0HMS3BxufPnY0C8k4BLmxhvYfNkaHQc7oegcoFDVUX\nRdtFX44nDkpa2GyBMsAWXZt6ekA8BU7RD4koiivkTLmLO2kkqTEqhwYA5NwMVI/OPfdCFIlv+e0c\npLHCZHYU8svno4KwzHvdEfK8A5LL18vTcyUTCpK8xpuPQU+IOZRlcLn17xFPJPFpqor/emCZf97d\nSHgOZotlOR+q7ore67sDf4nzygoS8/MBuFXtegSKlYU17Um83hOMsfMXf3+BHpdVemGNewVG+5so\n2g4efHo9rvrxU1i5cRfUip7AsUrA8vYTqoazvPNOAf/+3K349crf7/aa9qTbrVZ24U3mPQPefsxb\nnvdk3F+YD7eYKP246Nr4+6YncNeqe3Y7vr2F63n42h1LcO9j4W11AZQ6YgJW9HCaCNt76LuxtoL+\n/sLc6283Uc0jBKs29r6jxi97i4i8DzHwhDW3JObNLW9D5a5XD5apo7bSJ0VdV8vIuxQyAQXIhoQ8\nCnsgd/m7EjFK3rL1yWO0MVNDMhaU7LSJLb6DJxEpngHD8jCiIeWXAknhQSK3PWX/5vNAR2MaJ8xs\nDYxXYdnqPMnMMjV/PjSnrDOb43rCbc47hf34lTsAUPUzOugY7E1UCtT2CljeQ12w3mBN4FwyTGbN\nP7NsOx54YpMYe9HxROa9fJ9UaIDioj9TxKadg+jNDQKOibVbB0RCG1RPlCxx8t64Nc+6z0neBcVD\nkRRgEDZ+fq3En1SejY5CEqZdA5PF8hXNgc403Tlx0fmk280YEd4WbnnHtLjwLshKev1539LnBNfd\nnwsmj9VtwD82/QuATzAXnzyOnstxxaLW3jyOnch5W+QtkvGYkp6iEtF5T7a89Za1UONZPLL1kcDx\nolQupMd63inikefpPX3hra1Q48GMZyhyzFsi72Es78c2P4nOXDde3PkqVvS8OSxp7l6m1IM14SX8\nz9q7sXGAVkS83WzzjTv9+/R2Er5K8wGyUmfEMG/NviKTc9DVl8fGnYPIOTlsGiwvS1R3Y3kPN/4u\nFs/mW12XAEYe8SMfwWObnsTTr2/H7//xVuixpXjxjU6ahf9WePjmQCAi70MMgZh3wG1OicVQfOvN\nKnWba6qwfADfsm6o9gl+WPIOURIjw7jNOUbUVfqHqwomtlcFysc42cRMXciJcnDxEz6GuKVjYmsD\nHGLjcxeNRnUFHWdBTiJ2Zbc3V3BTkU4YaKyOB65HJOUxy7syYYpriM96TKjQnbuQkoHjEphsMdLX\n9Dh6cr2iZGt2zZHivM6OUXD7a2ATGy/upPrvXBa11G0OACZbbK3fPuhnzGsOirYrVMrkudVVHVA9\nPL9qJ2761QtQDBq394h0P1TXLxdi5L1lexH5okv34d4JVmGQGeT3UQFxNbjwJ7UvT4nZKRiIW5oY\nLzQbpk7H1S/FefkCSTdcn7wNej5L8cm74PrWZU/Od3trjLxL1dOMjlV4YM1DcDwXhAAT26tw1GRa\nG19winCJC7evDs72MaKigTeMCUPfUAGPv7JVkJDcjc4cuRJqRXfA8pZDQNLFis5hgYQzBqphH1zA\naH3t0gLDldzmEnlb4eQtJ/r9v9fuxLLulaH7DVceV3RtaHV+WODxzc8AKPdqEELws9d/jf9dG1yo\n9PTnWRc8EiDm4cgvb7t4+LmNWL2ZlmxmbT+GvCvfF3rMvoCX5xUdF//96p34/gs/KRM7KvXWBGr3\nh1ns8NACfw4c1xNVDH9a81f86onnsKnTv++7s8K3sERBfr/fDUTkfYhBVcOzzbmVoTGZUUV1ETM0\n1ErkPba1UsiE0pPQ40c2paXPJHeYbJmXan5Lx7O9yzaPb6kVC4yBrI3KlIXjjmgX23lTkpipiZcc\nh6uxlxnLJm+pTeDoFkqSd6+6V4xHdpvLMWthgbsa0gkDDdWJwPWUlsNVJM3A9ei1NN5cnaIuccfz\nhCeBqDbuXf0AvcaqMWhPjQheuEv3W9e3EaYSA8nR+f33y+eXzRG3vAH4Cxtmeedt+sKvTib8/TUD\niurhjU19gOJC0VwQx4TrefA8iJg4J29uUaowaRmT7DYXjVmkBZur0wx5Bp4QZ+cNxC1dkLeiOeLe\n9hcly7tAnzdV84TbnBOXqcZD3ea9WXo8IUy6Vy+KEkYK/9nryXK3ugKTkXPe4wsxQ1wDNCcQ81ZT\nvbjnzT8LBb8f3PMq7n70TSxdxWK3hpSAVbMT1sQXBQm6HvGrGmRIuQWDdjl5F11byMNyD4ilxv3K\nDtUT3gWZvLmGQSl6MoMwSAIN8ToAwxPgcKpyj21+MqDBz/NKZPL9xV9XYe2urVjWvRKPbHxMfH7/\nk2tx9zPPIT7zceitawLqdjIxquke4bnY3p3BH59Yi+/9lraslWV4d2a7Qse4L8ixDP9C0cP6Aerp\nKG3y44sJ2VjbtwHLi0/AGE1DII7r4anXt+FP/1obOIaHRGQJYPkdEZu6BErCf/YzuxGl6WFW/HCS\nvgcCEXkfYjihbSEA4JSO42GZ5daAcFUzy7u1jr4oJnVUY1RzhbAeAQh3LHe5A74rm26XasI7R8Dp\nbIPT2eZvl9zQJJeCN1SJRNHvuGVqJs4/kVoZU0ZS17HIqAZQLPjkHbf0gIAMjAIjW7pPU20Ccxpn\nYFrdJKzr34AhbTs7h3TxcsyaK615GtIJk1qBsopcSUb9rPH1qEun/HMxxboYUzVzXSL01wFfxtXQ\nDJhG8GfCSSTjZGEp/uKptT6FUpiaLITj66sXbQ95FhNorvGTDGO6IVnOvshMNu/QlzBbwMiWNyGA\n6rG+6l65d0K27ImniVp2wE+kKuboPbKYWA50h1U7EAzZQ1DsONA5GvYW1pVMc0SiD0+aMxCDxa5X\nJm/umvcGqJiOVrsNg3JrTql0atsQJVtVVYVlXfAYKTAvCl3EObAMSU9/1HI8ufVZPLLhn3A9V5RM\n9bA2nUqImI3sNlfCyrcUV7yMB/dgeXMPSEz1PUCK6olqD368O1gFNZ4RLm0ZWTuHQk7DB0efAcDv\nA5ArOPjt31eL/Rw3PKntjV1BFy8PXcge7KGcjd89/1TZsX99diNyBiVEo3VtoFc5J38lNgRr0guw\nJtM6/lK1uqyUUf+/ax8WvyEZL+54Bc9L/elLUXSL5fr/kuXN8asVv8OPX75dhMe4Vfzwhn/ihy//\nP2zxVkKv2waoDhzPw6/+7w08tGRjwAshpH+55e2RMg+kKpP3btrfdrPnbLgGPwcCEXkfYphcOwE/\nOf67mNVwBCyj/PYYnGBUDzFTQ1XKwq1XHI0vf5hm+fK4LQBBvgH3ouwelC1vosHeMBVeVs5WD24v\nrJyPKUnfhWxqBj588nh89zPzMGMctRZiElnlmfEbs6jbvLBsIYobJ4rtinR+njRycvvx9OsUjyXE\nlMfdrSnPCsubeNTytgzNT3aDH1fk1xC3NMwd7y88eDcxUzOggCa1aFKHXMI8Dbyvugw5NGCqscC2\ns0afiqq+2eLvQHvOgHyqi6JXBCEKWup8z0jcNH0vCCccx0Qmb9MsbE+lMe+Cr3QH1wAhCLjNdU3y\nTngl91+aJy7/6hQMxE0NMS3OzmvT5iu6DZe4UPKViO+a5hOo6kiWN53LJ1/y+8bLMW+e8ObsGAli\nGzDa30Rfwbcqq6sly3Dlb2BNfRquPgRNVaEqCoqk3PImqitCSfLcPrrxMVz1xNegN9FSL9cjwoPh\nZYOLq0DCWgi5Q/WEBSqTN/GYfKtr+6ED9jwljARkHf3BnA3Xc7FhYBNMLyUWMLe8eBs2D/ou7oJt\nU8liR0c2xyxCRn6PLN0kyc8SbM9tx5cfvwlPrH8hMNwEy81wulqhQRM6/4E4t+KhR6WJX2rp619a\n5PXYdBHVP1TwNQXYgpiXBZaSmWx5bx7ahqU7yrPOf7Xy9/jNyj+Iv4u2G9B8v/Wl/8Y1T92EXfle\n/HrFH7Az0yme9VIZ1NV9a6HX00UQJ+A1fesD+6iJwYDbnH+XJ2nYc0+G63riPnLIev6lynUyIvKO\nAADQVPYjUspdedzyVlRPuNJrKmJCwjBgeTOrNpDMEaKqxlGZNANxW0teCDA0VfrkbmoGFEVBY7Xv\n9pXJm7+EYqbGXKBKILNazlavTtPjWpK+znhpkhx3hauJId8t7lLL2zTUACnpsWLgHKauBRc2jNhM\nzYSmqXC8oOXNyVtX9fLYqpQ3YKlWYNOpI09ERW6cv13zr5d7PbTa7XijawOKjg14KtobZcvbpN4F\nkDLL2/OISNrjpMFnvQAAIABJREFULwlFt0EcAx4jb3gaFIVlC3P3eWkSIRfaAc1GB2huQNzS/fun\nOYiZmp87UIhTS5yoILaBIjJSzLsI4mpYsqwbdz9MXZM524/r8nixl6mE09kORSEYcCh5X/Ghqaiu\nCU6vmhhCwaRuV8NQy8ibXoODdFJ+PoOWqNEuNVlhiwsu7AJQFz63vF3J8k5qKcQd2kRGUT3YbJ4H\n7SHEtBgaN58HZ/toAIBDbL8Gmn1Hykj4vyvFQ6Ho4q3e9cg5eaSdVnj9dWIMXUO+KtiWHuZKdw0M\nZei4RJMdRhp6y1rE5vwdD+/6LYrI4c8rngxc85Cdpde1fip01RALKNntrbe+BcegnhAufyyseOn3\ns9T5E17dvAFf+q9nxCLHigWJqbQ3Oq9l93rpb3ht34bAdtlb4HgObNfG13/+HP7th/8Sn29l5ZM3\nPPtdvLDzZfxr6xLkmOVdCHNJMw8cH2NNrDqwWUkMBkrF1vZuxuceuwbLu94Qn8ltb6EHr4mrJAI0\ncY4QgnvefACvdvnhCdvx0McSE4frQX8gEJH3IYzm2gROmNWKL15whPgskE0eAqOEcDVVCcgbkuEs\nbwC1lbEAoafiQWICgNYaP0lNjudyWJLb3M821/06TEk0hYu4AEBVih4XsFRLM+Cl97OaYGVWsuWt\ny+QddJsbuio0vGWYqgFNU6h2t7SY4Mk3pmqUkbec9CeTsxibNN+xEMtbjWXxt/7f04Q1TxPudtNQ\nYeq+Z0V0RXNMZPIOtSCY5e1fqA04BmzH893mAGJxlLnNP3bqBEHufFvOy9JnytMRs3RhvampPjqn\njLy9giU8EKQYR8YblFzGRX9O2Hf15XwrLONkaEzZMcR+3EqLWRo8zd83zix/T6X3z9BUOGD3UnwH\n/d5kXAqTlHTVcwerpG1sIWeb8DIVYpz8he95HqAX4RXi+FjHlTAddqzqsg531PJOm0nYDoSHpugV\nxQKAex8qrKT4XXHPx7Iu6vKOF5vhDVXD3joGAPDcm742+NZeupghjo6BQWZpMsubX6Wa3hUoAyxk\ng7+/jJ1hc6RAgyFCF778bT/05vVAMQEvkxZiQaVqfRxLNvtlc+NHVOH8U/zcDzW1C/IP8ub/eQld\ng9TFbO8YgYQex/qSJjy25xPj7ct+gy8/eQN2ubQpztbuTKjVammmkKQNI0au9Oc4dCxyxjtA3d6y\nbsCDq/8BAPjjW38Wn/FjHdcLvEMAQElI5J230Vvow5Nbl+Dny+4Sn+8azIuZiCzvCACoxfzRRRNw\nxBh/tT4hdQTcwSoU3pwdekwpucdMLaiQJJM3CZKZkDdlkBOpONpqq6T9yxcSlaZvRfKXWMyULT//\n/LpE3tzy1lTNJ9mSxUVx3RHwWMIUb0nKY96moQlXOOCX9nDi0jU1IBwiX4OuKtjUOYTfPbpOfM7L\no/RQ8vYXKLy0SoYWIG95MVOSw6C6UIkuLNiKhCnlNDjCCiCOQd3mhJM3LwVzoai0tj5XcFDgbnMA\nlim7zTUoACxDCyTNAUCB5JDQaC5CwtIRN+j86rU70KmsgWKypKd8TMwDKcThEgdEp5nJil70xXDY\ngi3LuscRQtDv7PLbz7Lvz7t0e6ezCTuTVNr0kxMvw1ktF9Ahs/71pqHCUYKVA2JRKB5PAhhFjEx3\n4OjmuXRqpC588iKosHIezGKNyDsAaLKiYhQB2/QXSACgUMvbIx6G7AzSZgq27Sc2Op4jkTf9jsp4\nWlLCo9v6cixbv0jnmeTpwAekBc6OPhZGcA08vIS6yGWyo99B5X7zrx1L50hx8M+XttA+5ZkiuocG\nxBzpii5i5tzw1iq7oShAYeN4EMeEogDZYkHEkkvj/tttP8FLU5WA0Iw1+XmYE18AJ/A1W/rx6rrt\nbJ4NjKrsQHd+FwaKfqWCHMte2fMmrftnv+MbfrEUP/uLb81yFN2i0DRwPSJCFv7AWNWJ68sJK0RF\n7oVF9JqsbMBtvnEbnfNdhV6YE16A3voWc6F71I3O3iFpvQLENpnbnC0M8g5eWeu3C/WIhy2dQ3h+\nVac/3ihhLcJwiGtxFFfNg9dfH7q9lLwtUyuxvOWENXr7501uxPc/O5+2/pMs7/GtJf5MMMuCn1sr\nt7wbEv5Cg1tIPGv5e5+Zh7PmjQ0dK7e8AYiM5dIMYJJP4SOTzg1+oWR5u92tZePxX8R+9q0MQzWk\nemH/+3gs2ND0snri0fX+NRoh5C3PtxJI6C8JA6gudMUQZXQVSRNtKRqX16q6AuSbZZa3r89OAuSe\nLTjCbQ4Ahkl8kvc0aBpLAJOS5gACm+TAeSNmakjr/uIrhz5R1uQWLDEPXoH1ZlcGac9yzRPhEL5Y\n4q1fu3O7YKMgLF5ueXsKJYo3B/3yqE1bbdz10Dq2ncdXM7Br3gKgCNLjC4BYjL2U9SIUBUjoSVw0\n4TwoTgyKbqOphu4vXP/FGEBUmAodP7f+O6uepIsgT4Preb4YDot5d2W74REP9fE62K4nFp02sSWl\nO2Z5mwkpt4Fu4yEEl1Vf8DnK2QUUbRe/fGgVVm1hjXpcXWyX8wb4dRLHFGI7ikoT2VZv7sMdf10B\nWymI+VWhS25ztsBgdegkmxZz2DkwhKLtQW9eC60mqKrW7/o1y6qqBKRhAdAOaZK1nnPZ78s10Jyg\nZX49OT80kA35/cmu+tfWlau6FdyicJtD8VhIyQfPc+GLqEwxA89mioKuCkVzgqV10m9Qq+wRTXgc\nhzDLm97Hj7Z/Bl6mIuClsl0Pv3/CL9+7c/n/4DtLbsMDT/niMZHlHWFYqHu4Y2aJNWwZWlD3N1Aq\nRh/kptoE6qvi0FQ1QO6TO2rF/08fU4svXzjdj8ejNL5OURvzCZ94Kl08MJd5Q3UCR030CVa4iAGk\nErIrmrfwLL/YMXV+TJx386pImNA1Bc7WsdQqIeUxfuIRVFoVpaeDqRmi5EgJqXU3VCMgvfjZD07B\n5R+YIf7Ww8hb2t+TpEcntNUFd9RtqIqOuGR5z2+eCxAFesNmP6Pe1ZHJ2zR2yaw6a+ozfptXx0Au\nHyTvbGIjI2h6nzVNoeTLXtqJlEvdpooHz6afxS0dST2Jwlv0+lzFFpa3V4gL0RauVpbxBkUSk8hl\nYN+/dscu/Owvy/HEm5ScSYaFW3jZGrdwTN+789yrg4J8XOY2R7IXUF2MVuaCFBOB73hsgJYUcosx\nriawoycLt6hDt1xRiREgb0BkxOftIjziIR9jFmM+IfIKAFoOt2pjL/64lNbzt6aaqQyqwlu32uVu\n81hSkD/XyOdVBbativtJPy/gsZe34ull27F5Fy2Rq02m/aQ/j7vNFYgcCFsqeWT3dzBrY8122mKX\ne0Dyeep2J4SAe43VWAbEU0AKcXGN37l7KdZvH4AxgvUzcHTkXliEpNMIBzZ4GZ+mKgErmkP+zdhM\nuY84hig5zEreroxdImID3+1N57A8abDgFkTCGkqSyVJGUhCrIyzvrK+q6OqA7gT7juvhSWe268F1\nCfNuqMgXJE8Zu8ai7QbG8GrXcmjpXvEbEfu8S4jI+z2GcLe5VPIVEvMOWJYSudek/Bfr/KlNmDra\nJ3MAAUEYDpnc4WkBlzk9p2+5VyV88RiZ8MQChBGVSGarS6IuLnkDPA26ptDEKkUBoIAUEtAhuarZ\nS8ojBIs6TkBV/wwa72OQyZk37pARqJsHVZKrkhYBCb085i27zT2pfj5VojKnKABcFQ3VCUwbXYu5\nExtQHauC5VZBiQ8FMuppqZhfh6omhkQmLHENDOVtFGzfbd6XXI7p09l99VToqsK6zrH5GPMMVNZb\nmj8TBduFoijwhmhoxEYOipmHQhTAlmLezPLut/tgxIPkzc/Vn83h+VWd+PsKSt6lljePLTpMMCb/\n+kLs7CmIuHavthHfff5HUHU6B5br51qIlynJQzFzIt5tKQlk8g6Ia8BVCjBYtUYpefNcjbybD1iD\nzrYxrByPC9FQ8n19K81gbk01U8ubu80JJe+EpQuXdtKK+d4PI0jeXPdAdPDzCsiKen3672lzx4rf\nnS2XWqksROKYoGI7qng+sgUnQJwAkMl6ICBwPMePeceyIIUEAFXyDri49wnfclR0ByAqFM8MzLWm\nKqHlcmD3UYkPQq3sogtqT4NC6Dhkb1fGLre8lVjWJz+jnFjzbkGUivE5cnc14gOpy2CqlvjMcT04\nnoO8mxeqisTVoaiOmGN6fSXfwd3ujkcXAJoNuAZts+zySgJ/n7LjpTkChkmqO0CIyPsww5566IZa\n3oGYd3mdNycbz/MCCWvyQqBUfrB0eygUglhJrbpsrVt6+PG8QQh/iR49tQkfP20irr5oBkzNpCtu\nNv50wixrSiCTN38RVqUs6KqO6sKkQHtUUzVEOdCYmjZ888jrMK5q9LDXaGhqoJZ9XHMdPnDkCNx0\n2Vz/slUFhVVzoearsKDZF27hgh4yKhMJ6JqKL104HfOnUq+CiQR9YRh+0h0tFSOB8h7enAWOgf6h\nYHY9APSxzm10kaMGLG8A0OtYwhT7bOG0ZurZYQRAyTEPnSQAKNA1BQumNWFEFQ3Z7Mr3wUqweHKJ\n5a0YBVimImWrJ6Brip/YptlQFF+qlQghGVVoxW8Z2ibmwHN8qV23zw8ZKUZRWN4WErSch32HbrJY\nrsVkMJnHIGXRf9/Y0oVfPkL7NDudbSDFOAtNcPJmI0pQi7Ml2URj3sxtTsnbRdzSka70UBmjrV1L\nyZ+7r4u8RxD7DRb1XRhwugHFhcZEgyqsBGrScRAiuc0VOW4vJe0x0ti4Y1C4r0lJ7kHRsyl560W6\nwGChB+Fh01ykkiE04AYXWSqzvHUvjuKGSbC3jaLbmSWqN9MFDsnR3vSKS8chk3e2pH4bAPSGLYjN\noNnmhuWTYFKpggIFBafot2FlY/EKcShOnC7CJLc5j6kHLG/NoUTMUGrdK6oHKJS4HZewcj0ahvIX\nOPQ7MnlbkLfT3QJnRzubA7o9bmmR5R1heBT3QN56iaVoGcGENeLJ2xXpv1wmskSqkyGsLWCY5Q0A\nY6voD5vYFsa0BF3VMtEamoFvf+JI3HrF0SXnpeTI5V1jpoZjp7eIuHiVVcmuRRMNW2QYEnlfcfZ0\nfP7caRjTSo/RNZW9YPh1aSJO1VSTQGOqBhWmbJkH57M0/p00E/jwiePQ3ugfo6kKvMFaJDedgOqY\n/7nc7IGjpabc2rcU+oJVOem4GnIFF7miCy3mZ1VziyWQxyBZ+l051vCFuc0NPRgWEYlVnobPnD0F\nNRUxen+IBuJqKHg5KLotXsS6ruITZ0zGNefT+9WT3wU9zsnbEucCAK1yF5qmv+W77l0dLbVJsVDQ\n67Yj1rLJF3MJkZYFAM9gCnBF+twkYjq83iZUDbIKDL0o9NKTajWyeceP+zJLTjGZNeZpuOXf5iPN\nyHv11h68vtFPsgJo8ppX4vZWrBwsNYZ7/rYJBP4C1CU0YU2zCsg4Q+iobGGeJhWEKFA1Rt6eDVM1\nYLOsZrHAqejBC7gPWsNmaJXsGow4aiuo0EvOLhey4fFuIkkFr9vZi9gUKpzCFy5y3NzziEgMEwtX\nISTjoChJ2Y6qYhnlJeENVSXoLw4grqThdnb4izUtaBUX3jgyMA5ZMjUjZYIHcmMAQPEEeRc3TMIM\n71xYmomiWxAxb1GD7eooFF0YqinKHh2X+Cp2IrFRh6J5yBSkeWTkrQ42wB1gZWWqC9vxqKdDs0Fc\ng3lwuOVNv38wa4v5cLvaUB1jybvsGY9behTzjjA89mR5lwovWKaGie30ITthVmvoS5KngHgEQQlR\nibiUEPIu7fTEccX0T6Bi2wkgmUqcefTIYcdqqDra6lOoqYiVfS6j1Hr33dYkKNTBj5dUz2pSScwc\n71tqhqYG6n0BX7iBZ30nDD9cUGZ5l2Sex8JKxbgbnhDELR1nzO/AvMmNWDituWxfSyuPmXPyVrhl\nzd2sRRfElPpCM3KXQx2Vdf6LWGQrexp0VaVubzlhx2KuVtfXntfZfSaOgSFniMqzshc5d5vH9TgS\nehy7cr2iJI+/zD+4YIw4f6eyRri947qJz35oqug0BwBoXYmCW4ACJZDMJ5frOBq1evM5Rt5snKpD\nCXj2tCTMup3wCnGkvUYqYcleuq/iL4DqUPJm46urjPueE83P6OcvfM8jovWqxvu6aw5yWQVLWJtM\ngy1aHTh0MZ2gNdod6XamSgfAU6FoLJud2DA1U2RNBxfQ/iKNz21N2gI8Ddtz23HHsrsgMvqlccLT\nxMKoM+tnO8ulcIBP3rwlqli4Sm5z3oa3Wm3CN46/KuCh4ffC1bLwiIdxDS04YVYrxrXUse22P5eA\nOG57Fx1vgLyZZXzVjE/jqhmfDswBvRd8gRLDUMaFpZk0YU3EvNn8cfJWTJFQ5rgeduWpp8kX86H/\nDhX8MSi6DS+TRmbVLH8Bwo7f5qyHogDeUGXAbS5yC3K2mA/iGJg7voWek2kimHpkeUfYDfbU67fU\nhZyMGWitT+G2Lx6DxaeMLy9XAkRrPyrmIFnGEonK33v6yJNRH68NxH5lWJqJq886ATd8bI7I+JXB\nm6+UeglKt/NyH8sILjgqmeWt6DbSyZBac0XOXA9+h6YpActbBpf7TOp+LH5PlndY9yQeo+eqcecd\nNwafPnsKmmuTqDQqA/uGldvFWekWfzEeN82vr1WkF7/O483SgqxdnVZ2PuL6lrdcIy7I3/W158e3\nV2H+lEY0pqtEaRBPaNOlhUtNrBo9+V7U1zOyZy/CU+YEdeCJ4oB4KmaOa0RTTQLfuuyowPa8W4Cl\nmaiuKF8EAUBRoyV7OUbeosENK9dzk53wFAfurkbs7M0hm7eF29RGHlrtdroAkcSBYixPQdHcMnf0\n8nW70DfAXMUa70jmBBZIwuOkuFTVLk5Jo6OizV9oen5M2iU2DCk8M5yXAQDqE3WoqYgJ1/1rXcuR\nJf2wJlBJUd/y1oU7l3sv7O0j4Q0wi5aR8zPbn4dLiBAb4QtX/syYY19FhtDx1+ktSFspGLoKj7e5\nZZamrdHjm5J1+OiiCWirpgaBYmVhjnsZWsUudk56n555lWaqb+rpFdfG3ea5IR1JI/heUHQbnsEs\nZ9tAf6YIS7NQCIl5wzGQLzq+qJLqoug6uH/NX+k1MhU7fo0i1q7QOm6xuJcaBdmOh60OFW5xu1tD\nLe+hrB2o8EjH6Hvig8eOwHc/PQ+WoUUx7wjDY97kJswaX4+vLZ4Vup1nVNdZ9Zg2uhZnL6Qu7GSM\nJWaFkTezvUvbBcqiJvKmM0Yvwk3zrw0mp5WgtjKGUc3h5M7JebiYOd8u9MdLkt7SJn0BKbqDdLyc\nvGXLu1RIxtBUv+SoBJwYZMtbLyFXTmAfnXQhxlaNwsiKkqYlAM6Y34HT5rXjU2dNLtv2hWlXBWr0\nwzL2E5rk1lc0jGzyCb+m+xhfI54n+kj3tEFvxw+O/ffgCT1VinlL99RgbnfPz3jXVBWfOmsK6lL+\nvXOKLAFLWrjUxqphezZ67R7qvuS92y3//GkjhYq0CsXTce6xNI9A04KLy135PliaiY+dOhFnzO8o\nmwvCwgCdPVRHnN8jTmI7MszqtC1s685Qy1tWA+WhBdvCSbOobn/CYG1Nx7wu+otzwn9pdZcIJ2ga\nK8nTnMACSSgPcnI26QJjRLoVlsmS+jwNHmhfe9uz0dMnJToNoxxYvXURLM2kSoeyVCk2+vMhW95C\n598RcyD2Y8f/c9OT6NPXQIlTD0ap5a1oHvQxNJuee5Fk8halWCo9vi5OibE6Sc9jtKyHVs3ugfQc\ncm8N11bo7suhM0sJ/Sf3vFn221fTPSC1G0CKFrxsBdZvH0BPn4OcUxAiLVzbgdgx5G0XGgwxxgIZ\nQme2G9PrpooWvdzyzhTZ74Qt1OIaj/v7mgeO6yGDXpCiBZJL0wx1fs/ZImkoZ0uuewOVTGggFgcq\nUxZMQ0XRdvdoYO0vhJs+EQ5ZWKaGK88tt644UkYS35p/HSrMVHhMOqS1J3/Zlbb+k6340pZ77wS+\n5R1O/pogb/riLiXvuJThHeY214gpvqd0gUHJV8XE+Cx0NFQFtnELf3duc+5Wntc8B/Oa54SOP27p\nuOD4saHbUrFYwPIPu0dJPQk4/vbqtH+9llcJe/0UWJOfB1GZFSBZhZahIaZbSOoJP8bo6dBUBTFL\ng9vTgqJuw+zw5SHh6mVd32TLyCkyYpeItyZO44V9hX5U6JXg7RsURUFh9UxY41+B7dlIWAZqrKQI\njWglnouck0M6UYdpo2sxbXQtHlqyEYWVR6FuyhoMEhazJ4Bjq+hoTQjPByfvXqaRbqoxbO/JImbq\nouc44MtbLpzSjounUNnadExanNWzpL1Aq1z6Ha45CKj1rCpAmmON11mzeD57oSeNhL/wJCo8uEhY\nOlzVLbG2gwsYTt5xRp7phBH4neaJn+XtDXBi8suYeLlVIJ9B+v+COgjVGoKXjwsPguyB4z9z/rsy\nNBW9/QRWo+/9Kap0DPUsVl2TKM/VCHw/I/8iyeP1tT348f8+g9gR6+EO1ACuIQiZQ6vugqIAxS3j\nAU9HvujCLChQzaLoC6ym+kEIdWsXii7i4HF5FzZT4RvoKxeEGsxnAZjQ0rS6okKvQhcQ0DywHQ8O\n8ZUCs3lb/K54eCJrboNVxWrfPRWVcbqIzjssVBUbgDbiDeSKp5XNzYFAZHm/B1EXrxk2mexblx9V\n9hmnZbIbgi61yt8JRFx+mFOqSnncXkZCcmvL7U55rJlnm4dZtfzlf3TNSTh7zKmh35PYjdv8ncLU\ntYALN8z7kNb9a7JUU7jhAZoMN29C0NqX3eomW4BUxZi1ThSAKNBUBcmYgfqqONydI6ES2UrSAhYz\nAD+jHxAvYpng5Xr+ilgSs8bXi0XloglzUUmakXcLyDm5gJiPripCHpQjVhL394aqMdM8xf/A1QEo\naK5Jipp8txict7pUBTp7cxjMFuF2t+KktuMB+PKWDekKUXWRNMt/G7LkbYJtH4qvg9FOFzky2QkJ\nYJ6Mp9iIaVbwuXV05L0cYpVDVMSmpLJDBifvGHvuUnED8o8jD2r1FlbPFORbEaf7ajXbpQ575RoO\nAPWsKUYRhCWr3fjxuThnQfnikqvrmYYq7jl3mxdUujyrZ5Z3Q0U5eZeqNxJPRc7JY9naHiEA43bS\nZ3XZup7gHHDPgOwV83QoCi+1I1CT/SC5FFRCyZ9b3takpbAVapWv3iDVkrt8AcF6rTdSQZZRFvOI\nSZZ10XHhoCg8BnLuBPds8GeBjRhxg3fQo+SdSa6F0bwBm3qD7UoPFCLyfp9hREMKN867BjcvuEF8\nxt08Lvt3XP+5uOGoqwPH7U/y5pa1S8KTO9QSy7s05t2UpOpN9bHaQO05fzlzyzuMGPnLP6xjG/+e\nZIjlffHJ4zB+RFVACW5fQL9fETHN/kJ5b2ceFgAASw+St6oquPC4oDu+QrIkeQ9s/pKloQefMK5f\nPBvzpjQibUrk7Oplcyxn3E8f2YxPnDEJx83wBXZ4xj9A5+vKc6dhFksMvPCEsRjZQL8/5+QDSXma\npsLZOg7FdVP9awxJ2otp/vg4cTbX+Za36ypI6v51N1VVwiME67cPQFVUnNi+AIBfTiff0zE1I/zs\neAauugbQOefQG1g3L4mYDFWjiyJmeXuqLeLoHPa2MSDwUGxaRj/gmvNmubdJMQsgRKEd5cDIW/N/\nG0VlqGwMHfXU82GOXAU11ReYJ/n7AKDIiJfY1LXb0ZTG+NagZgPgaxYQ4ru9eYJWERkYqiEWdfXp\n8pBYaSIeHAMFL49ETIfCeoDzkM+ytUHyVpmSn+w14Za8Yuap7oHmwstUwjI1arm7vuVcTG1mx0jN\nhQIxawIt3Y8RqTbUWLWB8Sqai6ydp78VtmjpzxQDx1tSCaCzg4Z3+D1/bPNTeHzz037mPQn3KO5v\nROT9PkRDog6VVvnKmbvGLSWFpmRDYFtIXtY+QxXkHX5SlcfaWcy71FoZXdmBzx7xcXxp9hXB47ik\nNCPv0pp3gFs1VIq0FMJtHmJ5nzJnBK67ZFawZn4foCgKrr5oBmbWURWz0sQdAEiYMVHrbGkmkjFd\nkLKmBUkLAJKmFONn19DMFjgcfOlVmbLw6bOmoDImddjytLJEx3qplOeoiW1YMK1ZzB2AwPOTCLsG\naQ5ly5vfS/klHUbeimv6nhNO3jVJcbzreUhLY2itoZ6ATN5BIqajwkoHqiHkMVZYaehvLkL+9YXi\nM/k+xIzdh5scj8BQLKiJQSjxAXhKOXl7fY1I6Wm4Zl/g+NLKCQFXF2JKybgRVC5TWaxXImdu9QF+\nXH+4RLiiTscwsq4ON1xKQz1hiZIJk96zwWxRsjqZ2xw5VJgp8ZzIz5x8DYCfsEk8DXllANvct0SN\nNo9Db9hRrtYG0JJD/qyLOTviaRH+IPkkYqaGfNHBUM6fI0/3NQ9Kx2N0rGJzRFATr0KMe5mE5e0K\nsR5O/rmCFPPWHDom1QOxTdibJtFxSc/tfW89KMJYils+twcCEXm/j/GlC6ejrT6JY46gJQ+cvGWC\nGsvqoxtq4uUn2EfwOHRYpjbgW+Ya++2kQmq5p9VNLluAcLe5RuiPKszyPml2G7568UyMaPDJ6+On\nTUR7QwpjWqk1EbS8939ayOSRNbjsiPNxycTzsajjxLLtluRaNzUqQiMatygKNFULvDiC5E3nroy8\nSxwnSSNoeZeiPu6Tt0zEHBVSA5rQBUiAvP2xcs+HHDoI08h3XCI8LPwl3taQFDK6iZiBasn676j3\n3fiJmA5VUZHQ/WssXfAYugqST6Hw5mwU101FXaU/3ngIecvEWSi6mKgfDUVzoTdtgIci4iELkLRR\nCaIEO7uVhif880slmpoaIG9P9ZOkOJRA1QDvXS/FsaUua45OiW/e+A7RwS6szDNlMNlbqVaeyt8S\n2MghLXljShd79BqY5rzJa8jpta8k//TVANl5t3aXS6USxwCIhoZqdi8kmeOWDno9Zx45HjGTajP0\nZf2ySRKKQnq/AAAbb0lEQVQbCJxfHo+i29CbN7BrTIjx8Xtijl6OdblV9CBXFzkAckKbodN7Egif\nlNxzy6I/suaaYEXJgUJE3u9jTBtdi29/4ihhhXLXuKym9qULp+Nri2dhTMv+eyC55T0cefPt6YSO\nb19+JCrfpqv6iDHUHdZWzVyKIdZFzNQxqaM68PI5dnoLbrr8SBh6iOUdco79AV3VcXTLkaFjNAzV\n713NSJeTN19Y8aoCUzUQtyTVOmF5S33R6ZkCf6UMyfIOJW/frZowysm7UnqRlxJj6THyNXLrkkus\nAggo1vE6esfxfCEPRkS1FTGcMa8DC6Y14XPnTEW15Sccjm32a/m5cE9ausbSaxC1+P31cLvbxLMD\nhJO3vMAp2C7aY+MB0HI7opAyyxsAKqQmL9yKa2sIL1NEqbWmlv825PvUW/Sbhvglf/52N6RxkRyO\naU+34YNjTgsQZMqS480avEwaWsUuaPVbQBQvcDwAjKroQEqpFp2+xjbV4rR57SK0YW8eL/bV0n1C\nOnU48FBGQzUTKUr6IaW8QRu3jKqvQ8LSkc07yG0ZAQyy6xS96/05mDLC9x7yTPWkkfS9H9K+y7LP\n0jE4BkawBQ4kt7mha9QL4egY1ZzGly+cXrbodPUcTM1EOvHOQmtvFxF5RxDglrfspo5bOsa1VQ13\nyD5B3UPMW1jm8IZ/2YXg46dNxFXnHYGFU6hs4R7lW4eBoRnCZbuv53gnMHVNvMy5vCTPOOfZ2jUx\npg6lKIhJ8WruNi9VsCq1vNvSkmBMyAtVJtQwy1te1ISRe3wYy1tkrEtWolySyMvRbMcT18jjoYqi\nIBEz8IkzJqO5NonqmL+gbK3zn9EjxtJrr5LIPVGywCgNxUwf689XmEuYuJqwVfNFF0nTAnE1EVMP\nI++URN7cyjv32NE4//gxZfsSV99ziZFENos6ThL/z/UQZOudZCqRe/HkQGy/QiJfRVGwqOMEJGzf\nQ5OWyRsKiutoAqLesLnseAC4es7nME+7UCwARjdX4YLjx4pyPrenFflXj5PGT3uNDwfujeGWt73N\nn6ch0Bh52kxhVEsFXI+gp9dD88CxwXOwRe8lp4zHF844DpW5CQAANUkt85SRRJznHYQtJFzd98rx\nksGqbrg1a2jioWvgklMmYOro2jLvQ09uV6gH5kAhIu8IAsfPpAlJs8aHtxvdXxhTORJAmHVIMb1u\nCh1P24K9Om/M1DFjXB3SZhKWZgaSrvYWPEZ6INzme4IpWd5claqmgvc7py8MTmxFtwhLiqNazHug\nqzqumvFptPfRspVSWhhd6ddUf2NxeQWCjHgIecsItbyHiXnLXp20Qq1dXv8L+Ja37bpoSNDnkBBg\nzsRgDgYAVPMFDILiOVzJri3V4o+x1PKWyPu4GS1oqfWvIWUl8K3510IfkhY4ro6PnEItyUVzR9De\n6LYpuqrFJaW9tnrqrm9IBpvoANSDcvq8Dnz9yC/jxBHHBM4vo3YoqONAXE2QCQBMrh+Lr5TkfJSF\nPzxdlNQBCP09aNLzLSc+AsBXPngs4GqC+NIhxxMoYlw8h+Wy0yYKWWRSjAnLXHZpA0DlllNwztgz\n/HOxPIiKBPME9jYF8hIA6k2Z2O7f95baCmiuf2/5d+SLDlRFRYfH+ruzkreUZHmHClY5BkYIqWMF\nbj99RovJbfQjVy+rfvHnItwDc6AQ1XlHEDh9XgcWTmt+227qfcWF4z+ICdVjMatxeuj2SbXj8b2F\n3wyWK+0FNFXDNXM+H3AN7y0Sehx9hf6DYnkbuubXm7JabRHzZqRTKxGX/DIxpSz6CTVjYbkZAD1l\n7D0i5WeOj24O96x8ceZnsHFwS5m7tBRhZYmyNR6WkAYA7eZErCg8g+aUb/0J8nY8zG2cgXV9GxDL\njMTZx00qO77GCo77psvmYiBbFHM1qmoEsCV8DHweZ42vx8dOnRgcu6WjLl6LOSNH47lupn3u6Zg/\npQknzaZCL6+s7mJSpdQzInsqvrZ4Nrr6cuhVN4nvT5oWvvrJo0TYoCXVhIUtR+GxzU/R87s6IHHC\nCGUatrxYg/icf9APSohGVYKJi6RE2lh87vj3Jox8PdZ61elpQoKFX266bC5Wb+7D5JE1UFdUwovv\nYseXPwfHTm/GP5dyOWBK3o01CXz90jlYuWEX/vPe10DsGHXtl1xDtVGHk9tnYkzlKNz94t+xsZMu\ndqaMqsH0TbUY0ZjCX5esByGK8C6kzBTGj/AXXifPbsPmt6rQyYVYuDgMlzw2LZCiJRZZKTOJmDK8\n5U1cHUdPbcIf/klbpBbfnIPYnL/BNQbE9pgxvOt/uGf9QCCyvCMIKIpywIkboC7Go5pn79aqTUuZ\nrfuCpmQjUua+kT9A3c5pI1VWc/5uwNJV1roRqGSJYaUxb9EUAcGyt9KSLz6DpIS9Dc3AhOqxGF9V\n7sLlGFc9Bie3Hzfsdo6w8IdcBx6WkAYAU5Nz8NkjPo6zR/v19kdNpkQ+sb0auqrjkknn47w5c0Q+\nggzZbQ4A7Y1pTB3lx65HVUqysiXPkio66ZW7qrkVP76+zf/Q1f0sZQCmqQUy5mXhoLilo70xHcgb\nOG56G1rqgs9jXPYGuDpOnOV/XzpBVdZ4YlmYlRhIOvR8adLW+iTOOYY1B5LqpsPCHzWDM+H2NqCt\nOE/McXtjGiczmVut6Lv+wyz3uso4KhL02r2S52DyyBpceMJYv1lKyTV091PCHVXZjqnG8aLne1NN\nAl+4YDqV2iWqyI8wVB2WZiIVN/CJMybhyxdOR3tjGpMa/C6AHz91Mlrrk2KR1TdUCCRHUstbCx0P\nAMAxkIobuP6js3Hy7DYACohtwVN5A52g5X3JxPOFpxAID58cKESWd4QIIfjY5ItQcIvvaAGxr9B1\nFW5nG2yjgCvPOx8AUJP21a+AYDKWXH5klpK3SJ0t/56rZn66/MO9wHnjzsIDax7CxJpxZdsaE37o\nZTjytkwd0+qCNevnHDMacyc2BKoBhgOPaZdm1nPwpD7e31lGW30SmzuHUFc1/MtWvi7iagGXv6Vr\nAZd02Eu7OdmI5mQjtmd2oqmy3LshW84nzhyB8SP8fToa0wAUmEocBZIVVutpR7ULyeOEEYcChS7M\nJCI6YWYrTpzVhhNnt+Hz/1WEmuqDSozQZ/mCo+bi0aWNWPyhCaFzoBerwIVd08N4ssZVjcZLna/5\n1QESqtIWyFZekkUt81HNVP50guT+lhdn3PuSZImH3kAN1FgWtudn4C+QmvzMbZyJf215BgBNPj12\nuh8uaapJYHl3HGqKJr8ljSRivN+BY+L85o/h3rcegJryLWuAVtmMba3EkZMace+W5diapS4U4hgB\nsaKjW47ElNqJeK2b9q1/N2PeEXlHiBACUzOHVak70DCYhKuzdRzqE9R6a29M4ZxjRokOaTweXB+v\nDVjbZoj4DDCsmN07wokjjgnGbSXIRCFaNZYgbKyqqgTaq+4Opmbg5gXfGHZxAAC5l06iL9sPBD//\n6AcmoK0hxayr0rHTfyulxjtfuXB2YB/TUAPkHQ/pLqcoCq6Z83m81rUC0+unlm3XVA03zbsWf1n3\nMOY1B88/aSQlNiPXgEJsA5QEJRdNU8X9VhUVcT2GrJMrkTv1O7DpJIbCiqNRlQ4nlTEtlbjinOHl\nluOZDgwZW2FU9pZpP3AsnnQBptVNxuyQMFgypkvtR6llfvyMFpxz7KhABYvrlmfX88WS21/ni+WE\nYGTFCDQmGtCSKs+h+dAxo5B5eRJeGKKqZykjAVPKjxhd3Q4vlxbkXZr1P7atEvW91YK8DcUs03pI\nmynoqg7HcyLLO0KE9zOSMQOXfmBCwPpUFAVnLRgl/q6NV+PauVehxqoWtdMAy1QPwbvUKyGARR0n\n4G8bH0d7upwggXIvwb6gcpjOdhy3f+kUhDlPYqaO044qb4RSio9Nvggv7HwFExpaA5/HTC0QTx7u\npW1qJuY2zRz2/PWJWnxy6uKyz6tSFlrrkujcUgN97Aa/R3XJjeTiIh111Th/8Ww8+vwmHD2Fkpii\nKHBcD4CCUU27n6fhoCkGiqtnY+KY6mFzH3Z3jRPaq1D7Vgp96BYKZJapBcIbANA7SGPSFSG9Cnij\nkXFVo8u2AfQ6bzjqK6GeBUPXcMmcU/DCE4/T79aswH6WqQXaKB83pTyMJHdPTJrhXRJrrCp05roD\nuQ8HGhF5R4hwCIJn/u8OYaSol3Tt8t9T7z57nz36VBzVNCvUnQr4mfEHEqX913eHay6eiQeeWheY\n+yObZuHIpvIOfg3VCZw4bQyeHqB61/EDYHGdd/wY/PTPWRQ3TII3SC3x0kXYpJrxWLVrNU4ffRLG\n1ldibFu4FT1hxL6Ve/L5U/YxPUpTVZw//Rj8YvkGuF10XsMWmJUpujiZPKqmbBscE1+c/CWMqKsu\n38awu/CWoer4ztHXI+fky/bTVQWktxVesg+jnWNw6YfKEyPlBWKlVU7eAK3+6Mx1v6sJaxF5R4jw\nHkLpy4n/fRAMbyiKMixxA8O7+A8WJnZU42sds/e8I8MHpx8JrO2GSzyMZuWP+xMzxtZhbGsVVm30\nPQSliYeXT/kICm4xkMAYhpHDtOfdEy47bRLufvRNXHRSeV7D28XMhmn45lFfxdeefx1AeF+Bs44e\nicqkhWOOaA58fvnpk/Dy6i6MaWh6R9LE1bEqhFG/rqtQMrUoLF+A6qnhYYFKSU2wKhHufeClm2Hh\nkwOFiLwjRHg/4GCw9x5Qmhl/uCFhxHHxxPMO6Hc01iSwamOv+LvU8k4YiVBteY5vXDoHb27uxbi2\nfVNIbKlL4tpLyj0Pe4vGZD14NnxYuMTQNZEhLmPhEc1YWELo+xO6poqSjPgwuvNT6yahVZuADVuK\n+MAJ4eWtnLwjt3mECBH2Cj/43AK4XnnSz26SzQ869kfM+72OxupgedecCeHW4XAY3VKB0S37ZnUf\nKBxKizZFAbhBP5znPWkkcN2xl9Me4lY4ZY6pot6RltSBW2iUIiLvCBHeA6geJpt4yqgavPRmF2aN\nrwvdfjBxqLnND0U0VvtW9e1XH79XMfxDFeYwCmXvJj555iS8sbGPlX3tObSkKsqwxA0A46vH4gfH\n/ntkeUeIEGH/4NjpLRjVVPG26qbfbZjvASI60JgyqhqTOqpx9NSm9wRxA76lezBx9NRmHD2VWsn7\nK6fz3SRuICLvCBHe01AVBR1N+67xfiDwqbMmY/32gVDVtAhBGLqGr148fKnZ4YTT53Xg2eXbUfUu\nqDjuDfzQ0qEYXBoeB3Qpd/PNN+PDH/4wLrroIrz++uuBbc8++yzOP/98fPjDH8Z///d/H8hhRIgQ\n4RDC/ClN+MjJ4/e8Y4T3FM4/fgx+eOXCQBOZQwEXn0wz6WVltsMBB8zyfv7557Fx40bcc889WLt2\nLa6//nrcc889Yvt3vvMd3HnnnWhsbMTixYvxgQ98AGPHjj1Qw4kQIUKECBHKILvQDyccsCXQkiVL\ncPLJJwMAxowZg/7+fgwNDQEANm/ejMrKSjQ3N0NVVRx33HFYsmTJgRpKhAgRIkSI8J7CAbO8u7u7\nMWWK322lpqYGXV1dSKVS6OrqQk1NTWDb5s2bd3u+6uoE9P0cI6uvP7RigYcronl854jm8J0jmsP9\ng2ge3znejTl81xLWSjV59xa9vdn9NBKK+vo0uroG9+s534+I5vGdI5rDd45oDvcPonl859jfczjc\nQuCAuc0bGhrQ3d0t/u7s7ER9fX3otp07d6KhYe/EByJEiBAhQoT3Kw4YeS9YsACPPvooAGDFihVo\naGhAKkVrTdva2jA0NIQtW7bAcRw8/vjjWLBgwYEaSoQIESJEiPCewgFzm8+aNQtTpkzBRRddBEVR\ncOONN+L+++9HOp3GKaecgptuuglf+cpXAACnn346Ro0atYczRogQIUKECBEAQCHvNBj9LmF/x2Gi\n2M7+QTSP7xzRHL5zRHO4fxDN4zvHYR/zjhAhQoQIESIcGETkHSFChAgRIhxmiMg7QoQIESJEOMwQ\nkXeECBEiRIhwmCEi7wgRIkSIEOEww2GTbR4hQoQIESJEoIgs7wgRIkSIEOEwQ0TeESJEiBAhwmGG\niLwjRIgQIUKEwwwReUeIECFChAiHGSLyjhAhQoQIEQ4zROQdIUKECBEiHGY4YF3FDmXcfPPNeO21\n16AoCq6//nocccQRB3tIhzRWr16NK664Ah//+MexePFibN++Hddccw1c10V9fT3+4z/+A6Zp4sEH\nH8RvfvMbqKqKCy+8EBdccMHBHvohg1tuuQUvvfQSHMfBZz7zGUybNi2aw71ALpfDddddh56eHhQK\nBVxxxRWYOHFiNIf7iHw+jzPPPBNXXHEF5s+fH83jXmDp0qX4whe+gHHjxgEAxo8fj09+8pPv/hyS\n9xmWLl1KPv3pTxNCCFmzZg258MILD/KIDm1kMhmyePFi8o1vfIPcfffdhBBCrrvuOvJ///d/hBBC\nfvCDH5Df/va3JJPJkEWLFpGBgQGSy+XIGWecQXp7ew/m0A8ZLFmyhHzyk58khBCya9cuctxxx0Vz\nuJd46KGHyB133EEIIWTLli1k0aJF0Ry+A/zwhz8k5557LvnTn/4UzeNe4rnnniOf//znA58djDl8\n37nNlyxZgpNPPhkAMGbMGPT392NoaOggj+rQhWma+PnPf46Ghgbx2dKlS3HSSScBAE444QQsWbIE\nr732GqZNm4Z0Oo1YLIZZs2bh5ZdfPljDPqQwd+5c/PjHPwYAVFRUIJfLRXO4lzj99NPxqU99CgCw\nfft2NDY2RnO4j1i7di3WrFmD448/HkD0e94fOBhz+L4j7+7ublRXV4u/a2pq0NXVdRBHdGhD13XE\nYrHAZ7lcDqZpAgBqa2vR1dWF7u5u1NTUiH2iefWhaRoSiQQA4L777sOxxx4bzeE+4qKLLsLVV1+N\n66+/PprDfcT3v/99XHfddeLvaB73HmvWrMFnP/tZXHzxxXjmmWcOyhy+L2PeMkikDvuOMNz8RfNa\njn/84x+477778Mtf/hKLFi0Sn0dz+Pbxhz/8AatWrcJXv/rVwPxEc/j28Oc//xkzZszAiBEjQrdH\n87hnjBw5EldeeSVOO+00bN68GZdeeilc1xXb3605fN+Rd0NDA7q7u8XfnZ2dqK+vP4gjOvyQSCSQ\nz+cRi8Wwc+dONDQ0hM7rjBkzDuIoDy089dRT+NnPfoZf/OIXSKfT0RzuJZYvX47a2lo0Nzdj0qRJ\ncF0XyWQymsO9xBNPPIHNmzfjiSeewI4dO2CaZvQs7iUaGxtx+umnAwDa29tRV1eHZcuWvetz+L5z\nmy9YsACPPvooAGDFihVoaGhAKpU6yKM6vHD00UeLOfzb3/6GY445BtOnT8eyZcswMDCATCaDl19+\nGXPmzDnIIz00MDg4iFtuuQW33347qqqqAERzuLd48cUX8ctf/hIADX1ls9loDvcBP/rRj/CnP/0J\n9957Ly644AJcccUV0TzuJR588EHceeedAICuri709PTg3HPPfdfn8H3ZVezWW2/Fiy++CEVRcOON\nN2LixIkHe0iHLJYvX47vf//72Lp1K3RdR2NjI2699VZcd911KBQKaGlpwXe/+10YhoFHHnkEd955\nJxRFweLFi3H22Wcf7OEfErjnnntw2223YdSoUeKz733ve/jGN74RzeHbRD6fx9e//nVs374d+Xwe\nV155JaZOnYprr702msN9xG233YbW1lYsXLgwmse9wNDQEK6++moMDAzAtm1ceeWVmDRp0rs+h+9L\n8o4QIUKECBEOZ7zv3OYRIkSIECHC4Y6IvCNEiBAhQoTDDBF5R4gQIUKECIcZIvKOECFChAgRDjNE\n5B0hQoQIESIcZnjfibREiHC44ZZbbsGyZctQKBSwcuVKzJw5EwBw3nnn4UMf+tDbOscdd9yB8ePH\nCz3rMHz0ox/Fr3/9a2iatj+GHcDOnTuxbt06zJ8/f7+fO0KE9yOiUrEIEQ4TbNmyBR/5yEfw5JNP\nHuyh7DUefPBBrF27Fl/60pcO9lAiRHhPILK8I0Q4jHHbbbdhy5Yt2LZtG6699lrk83nceuutME0T\n+XweN954I6ZMmYLrrrsOs2fPxvz58/Fv//ZvWLhwIV5//XVkMhncfvvtaGxsxIQJE7BixQr89Kc/\nRV9fH3bs2IGNGzfiqKOOwg033IBCoYBrr70WW7duRVNTEzRNw4IFCwI9ijOZDL7yla9gYGAAjuPg\nhBNOwJlnnokf/ehHIISgqqoKl1xyCb797W9j48aNyGQyOPPMM3H55Zfj/vvvx9///ncoioKdO3di\n9OjRuPnmm2EYxkGc4QgRDk1EMe8IEQ5zbNmyBXfddRemTp2Kvr4+3HTTTbjrrrtw6aWX4vbbby/b\nf+3atTj33HPx29/+FpMmTcLDDz9cts/KlSvxk5/8BPfddx/uv/9+9Pf348EHH4TjOPjjH/+Ib37z\nm3jmmWfKjnv22WfhOA5+97vf4Q9/+AMSiQRaW1txzjnn4Oyzz8Zll12Gu+66Cw0NDbj77rvxxz/+\nEQ899BDeeOMNAMCyZctw66234r777sO2bdsOSy9DhAjvBiLLO0KEwxzTp0+HoigAgLq6Otxyyy0o\nFAoYHBxEZWVl2f7V1dUYN24cAKClpQV9fX1l+8yePRuapkHTNFRXV6O/vx+rVq3CkUceCQCor6/H\n7Nmzy46bNWsWfvKTn+ALX/gCjjvuOFxwwQVQ1aCNsHTpUuzYsQMvvPACAKBYLGLTpk3ieN4+debM\nmVi7dq3okxwhQgQfEXlHiHCYQ3YrX3PNNfjWt76F+fPn4/HHHxfNPGSUJqSFpb2E7eN5XoCIS0kZ\noL2M//KXv+CVV17BP//5T5x33nl44IEHAvuYponPfe5zOPXUUwOf33///fA8b7fjihAhAkXkNo8Q\n4T2E7u5ujBs3Dq7r4pFHHkGxWNxv5x49ejReeeUVAEBPTw9eeun/t3eHOAoDYRTHHyGYJlwAMAjg\nAFROSC0STCWCIJCYBhwOwxEqegIkuqLBbRN0LQaBxkBZsdkaDJutmeb/05PJ517eZCbz9bYmSRLF\ncazhcKggCOQ4jm63m2q1mh6Ph6SfVv97VJ/nuXa7XdH+z+ez7ve7Xq+X0jTVYDAobX6gSmjeQIUs\nFgvNZjO1Wi3N53MFQaAoikrZezqdKo5j+b6vTqcj13XfGnq329V6vVYYhqrX6zLGqN1uy3VdrVYr\nNRoNLZdLZVkm3/f1fD7leV7xVWq/39dms9HlclGv15MxppTZgarhqRiAj1yvV6VpqvF4rDzPNZlM\ntN1ui3fn/3U4HHQ6nbTf70vZD6gymjeAjzSbTR2Px+J/4tFoVFpwA/gbmjcAAJbhwhoAAJYhvAEA\nsAzhDQCAZQhvAAAsQ3gDAGAZwhsAAMt8AxJ5C+54P8QOAAAAAElFTkSuQmCC\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAEcCAYAAADUX4MJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsvXeAVNXd//++ZdrONsqyNBUECxZQRBHUoKLoE+lP0F+i\nxMT4tRDFWBKVxG7UJPaK8mBBE40lQBAVFAQE6bAU6WWBZXvf6bec3x+3zu7M7iw7w+4Onxd/MDO3\nnXtn9rzPp5zP4RhjDARBEATRAnx7N4AgCILoHJBgEARBEAlBgkEQBEEkBAkGQRAEkRAkGARBEERC\nkGAQBEEQCUGCQRDtzLp16zBq1KiE9n399dfxxz/+sc3nIYhjgQSDaHeuvPJKnHvuuaitrY36fMKE\nCTjzzDNRXFwMAHjooYdw5plnYtu2beY+hw8fxplnnmm+nzp1Kj7//HPz/cyZMzF69GgMHToUl19+\nOe677z4AwNixYzF06FAMHToUZ511FgYPHozzzz8fQ4cOxTvvvJPK240Jx3FJ2bc15yGI1iK2dwMI\nAgD69u2LhQsX4sYbbwQA7NmzB+FwOKoD5DgOubm5ePnllzF79uyoz2Mxd+5cLFiwAB988AH69u2L\nqqoqLF26FADw5ZdfmvtNnToVEydOxP/+7/+m4tYIIm0gC4PoEEyYMAFz584138+dOxeTJk1qst+k\nSZOwe/dubNiwocVzbt++HZdeein69u0LAOjWrRumTJkSc9+WCh68/vrruOeee/DHP/4RQ4cOxfjx\n41FYWIh33nkHI0eOxBVXXIEff/zR3L+8vBx33nknhg8fjmuuuQafffaZuS0cDuOhhx7CRRddhLFj\nx0ZZTMax06dPx4gRI3DVVVfhww8/bPFeY7F//35MnToVF154IcaNG2eKJQAsX74c1113HYYOHYpR\no0bhvffeAwDU1NTgjjvuwIUXXojhw4fjpptuOqZrE+kJCQbRIRgyZAj8fj8OHDgAVVXxzTffYPz4\n8U06crfbjTvuuAMvvvhiQuecN28eZs+eje3bt0NV1Ta1cdmyZZg0aRI2bNiAQYMG4Xe/+x0YY/jh\nhx8wbdo0PPLII+a+9913H3r37o2VK1filVdewYsvvog1a9YAAF577TUUFRVhyZIlmD17NubNm2ce\nxxjDHXfcgUGDBmHlypV4//33MWfOHKxatapVbZVlGXfeeScuu+wyrF69Gn/+85/xwAMPoLCwEADw\n5z//GU899RQ2bdqEL7/8EhdffDEA4L333kPPnj2xdu1a/Pjjj7j33nvb9MyI9IIEg+gwTJgwAfPm\nzcOqVatw6qmnokePHjH3u/7661FSUoIffvih2fONHz8ejzzyCFatWoWpU6di5MiRbYpPDBs2DCNH\njgTP87j22mtRU1OD2267DYIg4Oc//zmKi4vh8/lQUlKCzZs344EHHoDD4cCZZ56JKVOmYP78+QCA\nb775BnfeeSeysrKQn5+PqVOnmtfYunUramtrceedd0IQBPTt2xdTpkzBwoULW9XWgoICBAIB3Hbb\nbRBFERdffDGuuOIK0xXndDqxb98++Hw+ZGVlYdCgQQAAURRRUVGBoqIiCIKACy644JifF5F+kGAQ\nHYbx48fjyy+/xNy5czFhwoS4+zmdTkybNg2vvPJKi66ksWPH4t1338WGDRvwxBNP4NVXX231aN2g\nW7du5mu3240uXbqY8RO32w3GGPx+PyoqKpCTkwOPx2Pu37t3b5SXlwPQXE49e/aM2mZQXFyMsrIy\nXHTRRbjoootw4YUX4u2330Z1dXWr2lpeXo5evXpFfWZvw6uvvoply5bhyiuvxNSpU1FQUAAAuPXW\nW3HyySfjlltuwdVXX90uCQBEx4UEg+gw9O7dG3369MGKFSswZsyYZvedPHkyGhoa8O233yZ0bkEQ\ncM011+CMM87A3r17k9HcuPTo0QN1dXUIBALmZyUlJabFlJeXh5KSEnObkQUGAL169ULfvn2xbt06\nrFu3DuvXr8fGjRsxc+bMVrfBfg3jOkYbzjnnHLz55ptYvXo1Ro8ejT/84Q8AgIyMDDz44IP47rvv\nMHPmTLz//vumK40gSDCIDsUzzzyDDz74AG63u9n9BEHAXXfdhVmzZsXdZ+7cuVi+fDn8fj8YY1i+\nfDn279+PwYMHJ7vZUfTs2RPnn38+XnzxRUQiEezatQuff/45xo8fDwD4n//5H7z99tuor69HaWkp\nPvroI/PYwYMHIzMzE7NmzUI4HIaiKNi7d2+TwHhLDBkyBBkZGZg1axZkWcbatWuxbNkyjB07FpIk\nYcGCBfD5fBAEAV6vF4IgANDiNIcPHwagiYcgCOY2gqC0WqLdsafFnnTSSXG3NWbs2LF455130NDQ\nEHP/zMxMzJw5EwcOHICiKOjduzcef/xxDB06NOFrtAb7eV544QU89thjuOyyy5CTk4N77rkHI0aM\nAADcddddeOyxxzB69Gjk5+dj8uTJmDNnDgCA53nMnDkTzz33HEaPHg1JktC/f3/cc889rWqLw+HA\nW2+9hccffxxvv/02evbsib///e/o168fJEnC/Pnz8fTTT0NRFPTv3x/PP/88AKCwsBBPPvkkampq\nkJOTgxtvvBEXXnhhUp4P0fnhUrmAUmlpKf70pz+hsrISgiBgypQp+PWvfx21z7p16zBt2jSzo7j6\n6qsxbdq0VDWJIAiCOEZSamEIgoCHH34YgwYNgt/vx+TJk3HJJZdgwIABUfsNGzas1T5agiAI4viS\n0hhGXl6ema7n9XoxYMAAM0uDIAiC6Fwct6B3UVERdu3aFTPgWFBQgIkTJ+K2227Dvn37jleTCIIg\niFaQ0hiGgd/vx9SpUzFt2jRcddVVTbbxPA+Px4Ply5fjmWeewaJFi1LdJIIgCKKVpNzCkGUZ06dP\nx4QJE5qIBaC5qowJTqNGjYIkSU2qljbmOGgcQRAE0YiUp9XOmDEDAwcOxM033xxze2VlJbp37w5A\nK4sAALm5uc2ek+M4VFQ0NLvPiUJeXhY9Cx16Fhb0LCzoWVjk5WW16fiUCsbGjRuxYMECnH766Zg4\ncSI4jsO9996L4uJicByHG264AYsWLcLHH38MURThdrvx0ksvpbJJBEEQxDFyXGIYqYBGDBo0erKg\nZ2FBz8KCnoVFWy0MKg1CEARBJAQJBkEQBJEQJBgEQRBEQpBgEARBEAlBgkEQBEEkBAkGQRCEDZ/P\nh7lzPz+mY//0pz/A7/clvP+7776DTz75qOUdOwgkGARBEDYaGuoxd+5nMbepqtrssX//+8vwejNT\n0awOQadcQOmR715AqDIHd1w8CdleZ3s3hyCINGLmzNdRXHwUt9xyI4YNG44RIy7Be+/NQrdu3bFv\n3x58+OGnePjhB1BRUY5IJIwpU36JceMmAgCmTBmP2bM/RCAQwAMPTMe5556H7du3IC8vH8899wKc\nzvj91d69u/H8888hHA6jT58+ePjhx5CZmYnPPvsE8+f/B6Iool+//nj88b9i8+aNePXVF/RFuzi8\n8casqDXkU0WnFIzdVfsADvhg0SDcPTm1y20SBNF+fLp0H9bvatuSCILAQVGs+ckXntkD1185MO7+\nd955NwoLD+Ddd/8JANi8eSN27tyBDz/8FD179gQAzJjxGLKyshAOh/H//t+vMWrUlcjOzgZgrbpY\nVHQETzzxLB588M949NGHsWzZUowZc23c6z799OO4774HMWTIeZg9+2289947uPvu+/DPf36Azz9f\nAFEUTXfXJ598hPvvfwjnnDMYoVCoWSFKJp3aJVUfab5IIUEQRDI466yzTbEAgE8//Rd+85tf4fbb\nf4vy8nIUFR3Wt1jC1KtXbwwYoAnTGWecidLS4rjn9/t98Pt9GDLkPADAtddeh4KCzQCAgQNPw+OP\n/xmLF38NntfWVz/33CF49dUX8fnnn6ChoR48f3y68k5pYUQOnQnnKbvAvDXt3RSCIFLI9VcObNYa\nSIRklAZxu93m682bN2LTpg1455334XQ6cffdtyMSiTQ5xj7q53kh5j524lVp+sc/XkFBwSasXLkc\n77//f/joo89w002/wciRl2H16pW4/fbf4uWX38TJJ59yjHeXOJ3SwmD+HACA5CDBIAgiuWRkZCAQ\nCMTd7vf7kJWVBafTiUOHCvHTT9tj7teaMn1ebyays7OxdWsBAGDRoq9w3nlDAQBlZaU4//wLcOed\n0+H3+xAMBnD0aBFOPXUAbrzxZpxxxiAcPlyY+A22gU5pYagBrYCWH1Xt3BKCINKN7OwcnHvuENx8\n8/+H4cNHYsSIS6K2Dx8+EvPmfYHf/OZXOPnkU3DOOefatloxDC0gnTgzZjyO559/FuFwGL1798GM\nGY9BlmU8+eQj8Pv9ABhuuOFGeL2ZmDXrLWzatAGCIKBfv1Nx8cWXtHj+ZNApq9Xe/MQi+PsthuCQ\n8dpVT7X6i0knqBKnBT0LC3oWFvQsLE7IarXvPzoGQiQHTIigLlLf3s0hCII4IeiUgsFxHFyqppSV\nwep2bg1BEMSJQacUDADwIBsAUBGgOAZBEMTxoNMKRqagZUrtrYif20wQBEEkj04rGJeeoeVmbz96\npJ1bQhAEcWLQaQVjxOn9AAaEWOKVIQmCIIhjp9MKhsAL4FQnFK752ZMEQRCtoS3lzQHg008/Rjgc\njrnt7rtvx+7du4753O1NpxUMABCYE4yPQFaaLzlMEASRKM2VN0+Ezz77GOFwKIkt6jh0ypneBg7O\nBUkMoCEgoUuWq72bQxBEGtC4vPm0adPxr399iO+//xaSJONnP7sct9xyG0KhEB599CFUVJRDVVXc\nfPOtqK6uRGVlBe6++w7k5ubilVfeinudb7/9Bh999D4A4OKLL8Gdd94NVVXx3HNPYffunQA4XHfd\neFx//S9jljhvDzq1YLh4D4JcFap9fhIMgkhD/rPvS2wu39amcwg8B0W1Clqc3+NcTB44Nu7+jcub\nr1+/BkVFhzFr1hwwxvDgg/dhy5YC1NZWo3v3PPz97y8DAAIBPzIyvPj3vz/Ga6+9rZc7j01lZSVm\nznwd7733T2RmZuHee3+PlSuXIy8vHxUV5fjgg08AwCxnHqvEeXvQqV1SHkFbMKTCR7O9CYJIDevW\nrcX69etwyy034pZbbsThw4dQVHQYp546EBs2rMPMma9jy5YCZGR49SMY7GXOY7Fr108YOnQYsrNz\nwPM8rr76WhQUbEbv3n1QUlKMl19+HmvXrjbPGavEeXvQqS2MDIcHCAM17ai4BEGkjskDxzZrDSRC\nW2tJMcYwdepvMH78pCbbZs/+CKtXr8Lbb7+Oiy66GL/5za0JnzNWGb+srCy8//7HWLt2Nf7zn0+x\ndOm3ePjhR2OWOD9ea2DY6dQWRpZTU9/qABUWIwgiOTQubz58+MVYuPC/CAaDAIDKygrU1NSgsrIS\nLpcLY8Zci1/+8ibs2bNbP96rV5eNz1lnnYMtWzajvr4OiqLgu+8W4bzzhqKurhaqqmDUqCtw6613\nYu9e7ZyxSpy3B53awsh2e4EGoC7U/JdDEASRKI3Lm0+bNh2FhYW4447fAtAE5ZFHnkJR0RG88cYr\n4HkOoujAAw88DAAYP34iHnhgOrp3z2sS9DYqa3fr1h233/573H337QCAESMuxaWX/gz79u3FM888\nAcZUcByHO+64O26J8/agU5Y3B4CKigYs3rcK8w/PxymRS/Cnaye0d5PaBSrdbEHPwoKehQU9C4sT\nsry5QY+sLgCABpl+DARBEKmmUwtG36x8AECA0VKtBEEQqaZTC0ZXTxdAFRAWKK2WIAgi1XRqweA5\nHqKcBdXZAEVV2rs5BEEQaU2nFgwAcLMccLyKcl9tezeFIAgiren0gpHt1KL+e0vL2rklBEEQ6U2n\nF4zeOV0BAHtKy9u5JQRBEOlNSgWjtLQUv/71r/Hzn/8c48aNw5w5c2Lu9/TTT2PMmDGYMGECdu7c\n2apr9M/LAwAcqa5sc3sJgiCI+KR0prcgCHj44YcxaNAg+P1+TJ48GZdccgkGDBhg7rN8+XIcPnwY\nixcvxpYtW/DYY4/h008/TfgaeZm5AID6CNWTIgiCSCUptTDy8vIwaNAgAIDX68WAAQNQXh7tOlqy\nZAkmTpwIABgyZAgaGhpQWZm4tZDt1KbIB1U/1M45aZ0gCKJTcNxiGEVFRdi1axcGDx4c9Xl5eTl6\n9uxpvs/Pz0dZWeIBbCPozYQwGgJSchpLEARBNOG4FB/0+/2YPn06ZsyYAa/XG7UtVikro0BXcxg1\nUbqqGdoxjjCYwLe5Vkpn5ES853jQs7CgZ2FBzyI5pFwwZFnG9OnTMWHCBFx11VVNtufn56O0tNR8\nX1paih49erR4XnsxMSc8CDnC2H+oBrnuTl2At9VQYTULehYW9Cws6FlYdPjigzNmzMDAgQNx8803\nx9w+evRozJs3DwBQUFCA7OxsdO/evVXX8IpecI4I9hyhyXsEQRCpIqXD8Y0bN2LBggU4/fTTMXHi\nRHAch3vvvRfFxcXgOA433HADRo0aheXLl+Pqq6+Gx+PBs88+2+rr5GXmokauxNKNhzDh0n7IcDtS\ncDcEQRAnNikVjAsuuCCheRWPPvpom66T49LMLFUIo6o+TIJBEASRAjr9TG/AypSCI4x6f6R9G0MQ\nBJGmpIVgZOlzMThHBHX+cDu3hiAIIj1JC8EwLAzOEUIdWRgEQRApIS0EIy+jGwCA9/hR5yPBIAiC\nSAVpIRi9vb3AgQOfW47qAOVbEwRBpIK0EAy36AIDA+8O4oC4or2bQxAEkZakhWAAwPCeFwAAgq7i\ndm4JQRBEepI2gvGL08YBAHjJ28KeBEEQxLGQNoKR4cgAwplgnNzeTSEIgkhL0kYwAIBXBTCeBIMg\nCCIVpJVgcMwB8ApUprZ3UwiCINKOtBIMXi+NFVFoISWCIIhkk16CwbSig2GFJu8RBEEkm7QSDEG3\nMMIK1ZMiCIJINmkmGJqFEZRC7dwSgiCI9COtBEPkNMHwR0gwCIIgkk2aCYYTABCUyCVFEASRbNJM\nMDQLI0AuKYIgiKSTVoLh5HXBkMnCIAiCSDZpJRgO3SUVIguDIAgi6aSVYLh4FwAgKJNgEARBJJv0\nEgzBAwAISMF2bglBEET6kVaC4RbcAICgTIJBEASRbNJLMETNwggq5JIiCIJINmklGB7dwggpZGEQ\nBEEkm7QSDJdDBFMEqiVFEASRAtJKMESBA5MdCKvkkiIIgkg2aSUYDpEHFAcijASDIAgi2aSVYIgC\nDyaLkFiEVt0jCIJIMmklGE6HACh6iXOavEcQBJFU0kowsjwOMFkvcS4F2rk1BEEQ6UV6CUaGA0zS\n6kn5JH87t4YgCCK9SDPBcAKyLhgRXzu3hiAIIr1IK8FwOwVwClkYBEEQqSCtBIPjOHjEDACAL0KC\nQRAEkUxSKhgzZszAyJEjMW7cuJjb161bh2HDhmHSpEmYNGkS3nzzzTZf0yt4AQANErmkCIIgkomY\nypNPnjwZU6dOxZ/+9Ke4+wwbNgwzZ85M2jWznF7UAqgPk2AQBEEkk5RaGMOGDUN2dnYqL9GEHHcW\nAKAu1HBcr0sQBJHutHsMo6CgABMnTsRtt92Gffv2tfl8mW43mCKgLkKCQRAEkUxS6pJqibPPPhvf\nf/89PB4Pli9fjt///vdYtGhRm87pdTvA6jNQI1aDMQaO45LUWoIgiBObdhUMr9drvh41ahSeeOIJ\n1NbWIjc3t8Vj8/KyYn/eNROsPANSRgMcWQxdPMfXJdYexHsWJyL0LCzoWVjQs0gOKRcMxljcbZWV\nlejevTsAYOvWrQCQkFgAQEVFbJcTUxSooQwIAHYeKcRpXU5tXYM7GXl5WXGfxYkGPQsLehYW9Cws\n2iqcKRWM+++/H2vXrkVtbS0uv/xy3H333ZAkCRzH4YYbbsCiRYvw8ccfQxRFuN1uvPTSS22+ZoZL\nBAtplktFsDLtBYMgCOJ4kVLBeOGFF5rdfuONN+LGG29M6jW9btGsJ0UFCAmCIJJHu2dJJZsMtwNg\n2m3JqtLOrSEIgkgf0k4wPG7RFAyFyQkdwxhrNtZCEARBpKFgeN0imNo6C+Optc/jlc1vp7JZBEEQ\nnZ52TatNBR6nCDBt7oWcoIVRFqhAWaAilc0iCILo9KSdhcHzHDwOLeitUAyDIAgiaaSdYABAhtsF\nAJDVxCwMgiAIomXSUjCy3JqFISVgYahMTXVzCIIg0oI0FQw3ACAsSy3uS24rgiCIxEhLwcj2aC6p\nkNSyYMiMBIMgCCIR0lIwMj2ahREhC4MgCCJpJCQYX331FXw+bQW7V155Bb/73e+wffv2lDasLeRk\naBZGRG456J1o6i1BEMSJTkKC8dZbbyEzMxNbt27FypUrMXHiRDz99NOpbtsxk5OhWxhK6ywMCoAT\nBEHEJyHBEEVtft+qVaswZcoUjBs3DuFwOKUNawtZXieYykFKIK3WHsMg9xRBEER8EhIMjuPw3//+\nFwsXLsSIESMAAFICAeX2IlMvQJhIaRC7SCgUACcIgohLQoLxl7/8Bd988w2mTJmCk046CYWFhRg+\nfHiq23bMeN0ioPIJCYA9hqGQS4ogCCIuCdWSGjp0KN58803zfb9+/fDII4+krFFtxevRLIxEBIMs\nDIIgiMRIyMJ47rnn0NDQAFmW8atf/QrnnXce5s+fn+q2HTNupwAwHmoiFoZKMQyCIIhESEgwfvzx\nR2RlZWHlypXIz8/HokWL8O6776a6bccMx3HgwENFAhYGIwuDIAgiEVo1cW/9+vW4+uqrkZ+fD47j\nUtWmpMBDAEPLMQl7gUKyMAiCIOKTkGB069YNf/nLX/DVV1/hkksugSzLUJSO3bkKnADGqVBbWEkv\n2sKgoDdBEEQ8EhKMF154AQMHDsRLL72EnJwclJaW4re//W2q29YmBE4AOBWhcPPCJlPQmyAIIiES\nEoyuXbvipptugtfrxb59+9CzZ09Mnjw51W1rEyIvgOMZGgLNTzC0i0SiS7oSBEGciCSUVrtt2zZM\nnz4dTqcTjDHIsozXXnsNZ599dqrbd8w4BAcAoMYfQn5Xb9z9yMIgCIJIjIQE469//SueeeYZc5b3\nmjVr8NRTT+GTTz5JaePaglMQAQWo9Yea3U+xTdxLJA2XIAjiRCUhl1QwGDTFAgAuvvhiBIPBlDUq\nGTj1+ld1gebbGT0Pg4LeBEEQ8UhIMDweD9asWWO+X7duHTweT8oalQzcorZMa12gJQvDFsOgUucE\nQRBxScglNWPGDNxzzz1wOvW1siUJr776akob1lY8DkMwAs3uFx3DIAuDIAgiHgkJxuDBg7F48WIc\nPHgQjDH0798fY8aMwbJly1LcvGOniycbqAXqwr5m91No4h5BEERCJCQYAOBwOHD66aeb71kLE+La\nm64ZWQCA+hYEQ6bSIARBEAlxzGt6d/TSINkuTTBCavMuKYWKDxIEQSREsxbGvn374m6TE1gvuz3J\ncmYCACJoIYZBFgZBEERCNCsYt912W9xtLpcr6Y1JJlkOTTBkNJ8lJavWyoEU9CYIgohPs4KxdOnS\n49WOpJPp1GZ3K0LzpUHCSsR8TRYGQRBEfI45htHRyXRogsGEMBRVhSSrqGloKh4h2fqMYhgEQRDx\nSVvBEHkRHBPBCRJCEQUvf7YF97+xCrW+aNE4UFpjviYLgyAIIj5pKxgAIDAR4BWEIwp2HtKEoSEg\nRe3jC1ulQ+yLKREEQRDRpFQwZsyYgZEjR2LcuHFx93n66acxZswYTJgwATt37kzq9QXOAfAqghHL\nchD4RunAvLUtKDcfICcIgjiRSalgTJ48GbNnz467ffny5Th8+DAWL16MJ598Eo899lhSry9wIjhB\nwaqtJRC6lsB5xjqE5UjUPpxoWRUkGARBEPFJqWAMGzYM2dnZcbcvWbIEEydOBAAMGTIEDQ0NqKys\nTNr1Rc4B8Aq+WXcYzoFbIORUY1/DXgDAtxuO4I7nlwG8AjXsBgAE5I5dgZcgCKI9adcYRnl5OXr2\n7Gm+z8/PR1lZWdLO7+Qd4HgVgFXGpCpUDQD4+Lu9iMgqwMuA5AIYcNRXgtpwXdKuTxAEkU60q2DE\nqkeVzJIjDl5bdc8epygLldgupoLjGZgigmMOVIdq8OdVf03a9QmCINKJhIsPpoL8/HyUlpaa70tL\nS9GjR4+Ejs3Ly2pxnwyXG5AAzmW5mioj5fDk8HCduR5SST/tQ0UAVN6Uz0TO3ZHobO1NJfQsLOhZ\nWNCzSA4pF4zmqtqOHj0a//znP/Hzn/8cBQUFyM7ORvfu3RM6b0VFQ4v78EzQ/s+w9q0JV2PJzjXg\ns6vgyq7S2qiKYLyVbpvIuTsKeXlZnaq9qYSehQU9Cwt6FhZtFc6UCsb999+PtWvXora2Fpdffjnu\nvvtuSJIEjuNwww03YNSoUVi+fDmuvvpqeDwePPvss0m9vkvQFlHiXFYBQgaGf+76PHpHRQD42HWk\n/FIAHtENnkvrKSsEQRAtklLBeOGFF1rc59FHH03Z9Z2GYDi1dFk1kBVlbRgwVYh5fGWwGk+u+Qem\nnD4el/UZEXMfgiCIE4W0Hja7BC3obcQwlOr82DvGEYxSfxkUpqDEX56S9hEEQXQm0lswRK0EO+fU\nBaO2B/p5Tmu6oyqALxze5OMGyQ8ACMvNV7wlCII4EUhrwXCLRgxDn8EtO9DV0TQLi6k8uIYe6J99\nSlSswhfRlncNKyQYbUVVGRSV1hshiM5MmguGbmHoAW0mO83MqShUHqrKIPA8VKaamV0+3cIIkWC0\nmfvfWIXpr6xs72YQBNEG2nUeRqrxOJzWG1XQXE9cDMFgAhTGIOjbVKZC4ARTMMjCaDt1/kjLOxEE\n0aFJawvD47CWkeUUTTyatTB0wTDWxfBFdAsjRgxj/9E6vDF3G8IRWkODIIgTg7QWjAyn23zNq5p4\n8FxTo4rZXFKATTBMC6Pp6PivH27Ext0VWLmtpMk2giCIdCStBcMreszXAtMD4LEsDMZDsVkYsmpY\nGC0HvRWFArkEQZwYpLVguEXLwjAEg48VtlG1x8A3dklJ2gzx5oLe8QufEARBpBdpLRgem2A4OD2e\nwZresjHDQ/F7AAAgAElEQVTTW9HDEXuO1EBSZYQULR1XVmUoKsUqCII4sTlhBEPUBSNm0FsXkcMl\nWsziixV74dfjFwaUKUUQxIlOWguGyFvuJ5cuGDFjGLpLqs6nVazNyXKgIRItGDQXgyCIE520Fgw7\nTkGzNjjEF4yIpEUkvB6hiYURK7WWIAjiROKEEQw3r1sYMQoNMt3qYLpwhCXZzJDyihkArBTbpscm\nvakEQRAdkhNGMFy6haEqMZaA1YUCTNsWliWz8ODJ2X0BANWhmqhDhG7FcA9bBL+anmuAl9UEUOtL\nvlXV3IJaBEF0bE4YwRD1dNqv1xxtutEUDO3/kCybFsXJWbEFw3HqVnA8Q6G0LUUtbl8efnsN7nt9\nVdLPS3pBEJ2XE0YwOF63LNQYt8yiLYyILCMoaym1fTJ7AQCqQ7XRh4S1SYFBRks/tgaVFIMgOi1p\nLxhy2ckAgGxeXys8xjwMQBcTm2AoqgwA6JGhHdfYwjAEI8Dqk9zi9IZcUgTReUl7wegdvgjBddeg\nb9cu+icxYhgGuphEFBmyPtvbI7qR5chsIhjGefxqHVRG5UGawy4SKukFQXRa0rq8OQA89KsLUFUf\nQqhRVVnGAK6RdjDdwpAUGbKqvRZ5EV3dXXDUVwyVqeYCS5ygWSAyIjjqK8VJWb1TfCfHj2RbAXY3\nlEqKQRCdlrS3MFxOAb27e8Hb7jS48UqENl7VdGfdwuA4hoisTeITOAFO5oXMFDToqbbaBtl8uat6\nT0ra3l4kO85gPx25pAii85L2gmHA280JxQmoMYwr3cIApyKiaIIg8gJ27NUC4Ha3FCfIYLIIDjzW\nlW5Kq44w2Sup2q2KjmxgbNxdgcNllMRAEPE4cQSDbyZ2YWAExDkGSRcMgRPBIlqA2xAMSVYAQQYL\nZyCf749ifymK/aUpaXd7kGy3kV1LO2qWlKyoeGPuNjz+3vr2bgpBdFjSPoZhIMQQjOCmK6PTbG0W\nRlh3SfEcb2ZEVYdqMe+HA/jvqoPwXKRAVQR4oQXTG5cS6cwku1O3n491UBND6aDtIoiOxIljYTSO\ncAOAHO2aMkqDgGOIyDI4cJAkBiZpa2k0RHz476pCK36hiDAeoWwrf76zeg8KKran4jaOC8mPYXR8\nl1Q6uRQJIlWcOIIRw8JwORvVlTItDIaIIkHkRS27Sq8/JTNNKIwMKaaI4Fj0sq4A8HrB/2HWtjnJ\nvoXjht0llYyOVO0EQe9kx20IIh05cQQjhoXhaSIYepYUr0JSZIi8gFBENoVE1ifzRVkYrKmFYWDM\nFu9s2AUjGa6aqLTajioYHbRdBNGROHEEI4aF4XE1CuHYLAxZVSBwAgJh2YxzGIFwu4VxqFSLXZhi\nYqM23DkLE9o1QlHa3pGyTpAlRfNDCKJlTmjBcDtjC4bgrYfMFPgDCv46ZyMYixYMiFpAHLIT/oBm\nWRgzw+2zvjutYERZGG331US5pFrRMW/YVY7f/W0pSqpSn1BAFgZBtMyJIxgxYt4ZrmiXlOrPQaYj\nE3yXMiiiD4oSXbAwohoWhiYYTHaYLimj9lRYiZjn++7QcqhMBWMMSzcVobQ6kNR7ShX2zlNOhoVx\njC6pd7/aCcaAZZuL29yGlkimhaGoKpZtPor6QKTlnQmiE3HCCEastFp3Y5eUKmJkrwut92YV29gW\nBlMcZmaVYWGEbHGLXTV7sbN6Dw6U1OOjxXvwl1lrk3ErKSdeDONQaQO+WL6/1aNxtY1ZUrES3Foi\nLDWNKTVHMgXjhy0lmLNoN96c23kz5QgiFieMYHAxg95Np6FkOb3WG0MwDFFQNaHgTJeUPehtWBjR\niw4V+0oRCGnbVCh4o2A2NpVvPfYbOQ7YO3hFsVxSywuOYuHqQyipap2ldKwuqWPtwn/YWow7X1iO\nzXsrEj4mmS6pitogAOBAMVUyJtKLE0YwYloYjbOkAHgcGdYb1TiGB2OApGdCGYLBZKdtlT5NFEK6\nYAicdu4Sf5lZH5f31mFH9W7M3v5RW28npcSzMCKyJh6hSNMAf3Mcq0vqWBXj2/VFAICVW0sSPiap\nMW/zp0ZxESK9OGEEI1bQWxCafpYheszXzL52hsrb0mptMQzd+vAFNVdUSNYEY8wpV0DkRZT4y6zl\nNljT69U0hBFppfsk1dhFQra/1q2NcOTY3T0dNbaczJnenP6F2+91b1EtZi/cYT7D9mDVthJ8vmx/\nu12f6Pyc0ILBxVgbwy4YUYstMR6SEfQ2XFK2eRiBiPaZYWFkODzIEbvgaH05Xvz3Fu043upoGWOo\nrg/h/jdW4c15HcvXbe/o7C4pST42wYhXSyoQkrBxd0XcyXxMH6G3NobBmQLdijYmUzBiXP/ZjzZh\n1bZSbNlXlbTrtJbZC3fiqzWH2u36RGIEQhLmfLML5bprsyNx4ghGgr1OhsMuGBxGnN0TYy48CVB5\nax6GKIGpHKAKtpRbXTD0oLdbcKG8SoLMJJiuCcHqaINyEAdLtMqoW/e3XycSi6gYRpSFob1ubUA5\nnkvqrXnb8cbcbVi/qxwAsPtwDZ58fz3qfNFxoFjC3hzHECNPWgxj/9E6LFwdv1OWlPa3JimFuGPz\n5epDWFZQjDf/s629m9KElAvGihUrcO211+Kaa67BO++802T73LlzMWLECEyaNAmTJk3C559/nuom\nIcMl4vSTcmP2LN6oGAaPrAwHHCIPxmwuKVECZAcAzoxzGGJiWBhu0Q2oAjieAZw+UrZZGIX1R9Bg\nS7sMyiGsOroWSowZ45v2lmHBmn1tueVWERXDUJq6pEKtzUCKCnpbr38q1Kr/Fldq8yx2FNagsLTB\nFNLGIYBV20pQVpOa1ORkdaJ//XCj+Zp10BgGTVLs2BhJMnX+jpeWndJqtaqq4qmnnsL777+PHj16\n4Be/+AVGjx6NAQMGRO133XXX4S9/+UsqmxLFq3+4DDzH4bPvm3bCjV1STgevWSeqzSXFK2BG0ULT\nwtA60bAew3AJLr04IbRSIrIzatGlN7bMxmnCRQC6AgDm7luIVcVrURWqwfgB10a16Z0d70LIqsGo\nwFPIznA1afPhsgYcrfBjxDk9W/8wYqDEmbgnHWMMI56FwXGa28b4LCJr5/WHpOgTcEBRhQ+zF+4E\nALz70JXNX9B0CSXeMaakWm2sU3aAvlpRGMSm+R5EB8HwnneAn0oTUmphbN26Faeccgr69OkDh8OB\n6667DkuWLGmy3/EuSGe6p2JYGA7eYb5mjIfLIUAUeIDxCMu64vNWQUJTMHQx8ellzr0OD5i+j2lZ\n6P93dWsl0asjWtqnxyWgMqi5pbZX7WzSJiFLG4nvKyuLeT+Pv7ces77c0bSjPUZYnIl78jHGMKLK\nm9teG9+DoUkRSXvhD+pJBbZzGKOu+NdQ8fyqt7GudJMVdNa3lQcq8fKmmagKNl6X3SIVy7J3UL1I\nyuz9YyEQknCwhFKNW4QzkiY6wq8lmpQKRllZGXr16mW+z8/PR3l5eZP9Fi9ejAkTJuCee+5BaWnq\nFiK6dvjJ+OVVp8XdfuGZPaLna6gcnKIuGCoPcCq6ZLnA8aqZHWVO3NNdSUY5kFxXDqDooqILhSEc\nkwZeBwAI6gKU5XGaIlIeqIzbvgOVzT8bo8NtK/FKg5hZUq2OYcQ+t5GIYFoY+nl9IQm+oCV+HGIn\nLdgpD1RgXVEBPtjxSZNtH+38DHtrD+CLvf+Ne/xx8+u34jLBsIxguPUpzC11NO219seTH2zAUx9s\nMOepELExfukdUC9S65JKRCGvvPJKjB07Fg6HA5988gkefPBBfPDBBy0el5eX1er2/P7686Pee3X3\njihw+Ntdl6Ffr2w4HQLcQgZCiuYr79Y1AxFJBSvlwfEq8rt5cIhXwQwxMFJleRV5eVnwq35wHIeT\n8/NNC8N0RelB7755eQAAf1hzXzkcAjiH9qwkVcL+0F5c1Oc88Hy0nleEapq97+wcD/K6eWNuKyyp\nxyNv/4iHfn0hzj61W7PPqaja+oP2ZrrNaxrfJifwLT7/3YeqwXEcTj+5C6oCVuefneMxjxV4DhIA\nt9uhfSZo9/vlj4fw5Y9W4Dgjw4luXa37inXtBsGyHhwO7TxOp4i8vCwonL4YliP+76akzpqh35rf\nViAkweMSY04MjXWuzCztee46VI1Fqw9h2i+GwCHGHreNu38+AGDBCxMSaosvEMEvH/kaYy/tj9sn\nDY66tj2dN7eLF12z3QmdM5mU12i/KyYIx/T32xaO9/XagsejeTk4jutw7U6pYPTs2RPFxVYdoLKy\nMvTo0SNqn5ycHPP19ddfj+effz6hc1dUtH3t5WBQG+EzBnTxiKir1UQii89BSAmAc4YQDkmaC0bl\nwXGAIOqdfyOXVEiSUFHRgEpfNbIdWSgr85kxDI5XwABwvHZsxMfAgYfC6YHysIQan2Wqv/jjLNx1\n3q0Y1PV0RGy1qUrrK5rct/0HVVJWDyGOu+H9/25HbUMYr/17M566dXizz6XaFliuqQmY1wyFdSuq\nPhTVjj1HarFs81H89ueDzM7vgVd/AKDFG2qqY5/P6GN9/jAqKhrQ0Cg7yiAQjKC21jrHrP9swbkD\numFAb+u3c7TassxM11lYRkVFAyKS9pyliBL3d2O/55Z+W2t+KkVOpgsOkcczH27EL0efhqsvPCnm\nvo3PtftgFQ4eqcG8lQcBAKf3ycawM7W/CZWxmNl8if7W9xdr1u2XKw/i9kmDo46zWyrl5Q1Qwslx\nXx4LtbUBVFQ4j9v18vKyktJfHC9CumtZVdWkt7utApRSl9S5556Lw4cP4+jRo4hEIli4cCFGjx4d\ntU9FhVW+YcmSJRg4cGAqmxTFWf20gPMVQ/tEfZ7j1DoizhWEUxS0CX66MDhdWmdkWg+qVXyQMYa6\ncD1yXTnarGjTwlCi/j941A+m8KaLKhRRcKAsOrW2Pqz9UHy2pV+DsERlR2E1bnluKXYVVpuf2V1S\nQTmIz/bMR0PEpzWTGXMaWk46jSo+aA96c34ArIlL6rl/bsKaHWUo2NfUnSYratxaUlYMI9ol1aQ9\nKovK1vrvqkL8dc7GqH0CsmUVMS76PEYFYaWZQEVr5mG8s2AH/vHxZmzao/12P16yF8s2H03o2K/W\nHDLFArC+j6/XHsKtf/selbq7Jl7cpzliVTMwsD/b9ophdAa2H6jCii2pL3bZHJwZw2jXZsQkpRaG\nIAh45JFHcMstt4Axhl/84hcYMGAAXn31VZx77rm44oor8OGHH2Lp0qUQRRE5OTl49tlnU9mkKAad\n0gUv/P4S5GZGj3Z6ZuZhn38XOF6By8EjIvOmMBRlaCPnxhaGwhT4JD9kpsDDe3GkrMF0WxmWhRHL\neG/hPrjPFgBBQddsF6rrw3ApQXDMiQGu83CArTNX9/NL1sjX7zqChogPWc5MfLxkLwDgw+WrAUcI\nkNxRncJXB7/DsqJVOOorwR+G3mF2ynwCQwSmMkCMACpvdtQH6gohnfEtxOL+CEeiXVqc2w+hSxlU\n9awm56ppCMfNkrJiGNp7o/RIYxSFtdjJ2Z+TKuiuD92JVhfSRLesPn7QO9EYhr0dWR4rQWLOot24\n/Pw+sQ5pFpdT+0I++16bgb1lfxVGX9A3KsgvySqcjpbTmuwDhsaps2Hbs23v9cuPpZjk8eLFT7VJ\ntj8b0rvd2mDGMIz/GcPhMh/65Hm1eGo7klLBAICf/exn+NnPfhb12fTp083X9913H+67775UNyMu\nXbKapqlOPP1qLN+5B1JJfzjO4iEKvDlBz8fpFpEhGODAmCYYxmh++x4/Nh/aAqGbkVarB70NS0MR\nwFTNwuia5UZ1fRicIIHJTuzcE4HrNGteh3FOJjnAOSSsKdmAq0+5XAsK8zL2ur+B53wguO7aqJG/\nsWTsoYYi/b0uGLYMjHjWhqKq8AxdCjXshqKeAwDYXa11aI7eBxEqvihqf9fZP4ITFBRHDgKITu39\nfN9cfW6J1pnaR/JxLQxBgpBbAaWqFwAOispa7OQCNsFo6FIAFA0BGKCoCoKKH+CA2kj8DB1FjR7R\nx3s2stx0XkpbaDwp0RBXvy3oH4ooCQmGZBOFmobo1R6jLYz2FYyOOHJujKqyFhMtUkajWmTrd5Vj\n5vyfcOXQPrhpzBnt0yadE2amd2vwONyI7DsfzJ8LABB5DpwjehKNkR2lveGhMsWsVMv02AVrlCUF\n3hb/UAWAV3WfPwNEOao2VUTVrlcf0gRDLusHMA4bywoAaB1K4zbZR+heUZuAaMRAmGlhcFh86HvM\nWPV0lLvLjk/Wrsm7QmZpEJG3OqzGabWGEMpoOtFob8Mu7Pb9BHDaeaJcUk2ypLR9nP23wTlgK4S8\nIwA0AWtRMGwuqbCnBJxX8+f7pID5B6jy4ZgrIwLRa3o3Z21INpEoScL6Jo1Fx7i0L2QXjMQypYx5\nLABQ2qiisN36SMYqiolSWh1oMgGtvQUrEaQ41u7xoHEtst1HagEA63Y2zTA93pBgxOGZ2y7GpMv6\n47STciEIPDhno1RA1TbiU3koULG7qCp6m/4/7w4AnArOFdRFhNMFQ4HLof3PcQxQHOYxRqmRyqDW\n8alBL5i/C474iiGpsjY/QrT9IYqRqFFkWLW2qUw1O2qe4zB//9eojzRgdfH6mPdeI1nxFKP4oGHp\nAEBIit2BSUwTTCvXniHMQlChgvPolpIhDkoEkZNWg8+pMMXM6PB4fd4J79XOoygsZicXDMtm525Y\nGMPyz9MeR14RGLRYjp36SHQQsSZUi/21hdFus2Y6NPszLqlsWTBueW4pNu6O/4feeIEqNY6FkQj2\nTs5+vLbNOsfxSiFmjGHGO2tw72sroz5XErTMUj0PISwp2FFYbT4P+/Xs4nu8MWuR6e+NZnUEVx4J\nRhx6ds3AuEv6g+c4OAQOnKsZwWA8GBSs31MStc2wNMT8w3CctFsbsVf1wvCztJRbjlchijAtBfsK\nfuV1flTUBrFgk7Z2BgtmQY1o5zOsBrMIIgBnv59QHCwy39sXclpyeAVUPSbCc0A3txbs31a5I+a9\n19kEw+ioa8OWOyckWokKhrABgE9uwM5DNXjqgw3aB7b28Rna8cYf5+6afVAzy+E6Y6M54jRGwcw+\nQx7QXVJNO5nfv7QCr3+h1dvx68Iw/lRtljznCAGMNRGMunC0W+qx1X/Di5vehF+2xUD0Symq0qRU\ni6SoAKdC7HkQxbVWTCTbGz/rZ86i3XG3Nb4v08I4BsGwWxGSrKLWlnUWbgcLI968oHirODYEIman\nXV4bxO/+9j1WbYsuUb+vqM7MBmsrc77Zjec/KcCan0qbtPd4WBjltUE88OYq7C2qjfrcKl4ZLWSJ\nJKykGhKMBBAEHkp1tG+eRQkGh7As4XBFbfQ22z5Cdy2LRqnqhT7dvRh0UncAgCgycE6tc2cRt+nG\nWrPzKB6cuRp8Rj2YIoCFMsxyJIYY2F1SQtcyfF8zD/uOan9MQZtgzNv/FeoytE6L5zm4BK1zO1h/\nOGo/QAtu7w/9ZL43Rlp1YWtkLuVr1XUDIQlPrPmH+Xm1VInVRwrM95zNAuIztOMNwbB3xEfFjTjS\nUGyN6owZ8oarq5kYhpGZZVgYOa5s/YLasQH9/pik3XNdIwvDiPUEFMuCUvXJbw+seBSvFcyK2l+S\nVYh99sJx8m6gj/accjKdZipvLLo1M+ehqUvKsDAsKy5Rl5TdXbZ43SHc9/oqFJZqAtkeWVLxKg/E\niv3sOVKLe15dif+sOAAAWLNd68SNcjAGz3y0sUmGXCzqfGHUNMRO0zbYul/77Rws1n4TgXB0okGq\nWbDqIKrrw5g5/6eoz824lv6TtwSj9dc4WuHD58v2J+07J8FIgByvE9LBcxAptGUBKZYYaAFs1YpV\nGBZGyAsW0YLqnD5/g6kC3E4BXpfWiQgOm2CE3QAz4h4qwCngPH6ogSyYbiwADSFdMMRGMQNexfwf\ntD84Y10OA5nXOlSO40x/v8pU7KiyRr+MMby48S1Uy5YLxReQsLNqD/bVaUFvpa4buIw6bDlUhLte\n/gE1YWt0dCS0HxsjX4PTrQm7BcRn1gK8bJbgsLuGKt3b8fGuL6yRp25lGfenKGqLo+KAHESGwwOR\n10rOc4IClVkuKTWoTfybtW0OygNNV+Lzyz4Yf6EqYwjIQURUCXtrD5hpuYDWkQi52vGcU3vGWR5H\nVPpxY5oXjNguqZYsjFjuGskmCkfKtOdbU6+10e5ikY9TDMEXjC0YscR/2wHNqv16zeGkXPve11fh\n/jdWNbuPEUMzEkKOt2Bk6ll2TZ6TEXMzBUP7P9GK23aemrMBX605hI27E199sjlIMBIgv2sGHvv1\nCPx53HXWh3pw+qW7LwWLeMA5w+Yo+srzTtb2YTxCWy+LPpkqwO0U4eK1Ea8gKmZ8hEU85nnBK4Ao\ngeMYWETrcAzrwxcOAWBaJwwgvGeotl1yIcOt/QgbWw6GEPE8h5AcAqf/+2zPfEiKBEVVsXD9XjMV\nlelVeOsCIawqXqt/xkOp1WapL9tfgHhwDv3adgsjsw6us9aYHWLjWIKTt7l09OM4dwDgFCgqMztk\nvksphPzCJtf0SwFkOrVAP8e0uJCiqKZgsGCmue+yoqYdyfd1/4F76BKAU6GqDA229hX7rJIskqxq\n7QIA2QEOgNspQpaZfm/MDPAbZGVo34nj1K1wnBztBmw82pbMVQ2tDn7TngoU7LXmuDw6ex2mv/JD\nk3uwJz3U+bRnaIjDdxssd+Xxqlbrj1P/K5aFITRKgEgWzcVBGmfpBW3ttT/LzXsrzPkxycSr/602\nFifr+7EGMEDiFgZjDDsLqxGKyKabraVabIlCgpEgp/TMQj+9pAegFSn87f+ciRyvE0q5NstX6KGN\njjJdthGlKkIN299rFoZT0Kf/Cyo4l80lZRYsVK2ihYZPX9+29WAZ+NwKCF21YoQs5IUayAQnSvC6\ntX2DSgheMQMPXHCXdgpO74Q5hpAShlzfBVJlTzRIPtSG6/HV6kOYu0brzOSyk6HWafdaFwiA5/TM\nrf2DwfzapMafiovMjlGp64bIviHmLRqussZZXHyGDypjqAnV4utCrQilGtKqAwucNafBsEw4UYLn\nwm8R5GrNUanrtAI4T9llrnpoEJACyNTXY+f0DDRZZQhKIf06lmB0cWnZb/b4i3Y9GRAjUBlQbwvy\n76zeY7tOSLMmAUCMgEErLaMyBklS4Th5FzwXLgbfpRTiSbugTXRUAV6G2L0YYs/oEbSsMESUCBwD\nCsB5a82OKmxLLFi3sxyvfmGtA19U4YM/pE0U3by3An94bSWq6kJRnZzhglIUFTUNYRSWWgJ4vGIY\ngTguqVgWRpMU1iS565uL/5hZemosC0M7rrIuiNe+2IYH3159zG2org/h9f9sQ3mj0vweV+xZDcbz\nYY0sjETXhdm6vwr/+KQgytWVLCEmwThGpk0cjMv0yT0PTbgaALS1LwBkuT3RO0vWXA+m8nA7BTh0\nwdiozrO5pBpZGKaLS/9MtzCWFhwC77UCf0xyaIFiQTLrKIXkENyiG909WoBb5YxKu/ofsSzqa3oA\nYSWMo5V+0zLQhEs7T00giCMVesC6vqs1GVGQo1xwSnUvXOwZq20zBEO3FOQqK/7DGPCZrQhgeMcI\nvb36WuiCas1X0QnxNXonZ/3oxV4HIHTT4kKSIiGiSqZggIngeAWKwkz3GwtYgmEUioyVVszxShML\nY8XR1WY6bnXIcsFxDq3NxmSqyrogxJ5aDSzXaQVw9CoE5/Fh9U+lZsaXdpNWR6ooKjaUbYHYrRTu\ns9eYHX04gUKSEUnFa19sQ70/ghVbiqMyoczzq6xJxlQy/NmL1x/B9gPNL/zVGgujsbulpa4x0Tkw\nDXHcYgDA8wycK2B20AFbuRQjHmTEkuz9bWFpfdRaNi3xr+/2YtOeCnzwTXTyQ7y4XGMLsLUxjCPl\n2mDHvjBbsgw3EoxjxO7yObVHHpyC5VLJz8mE02E9WjPrBzBdUkbgWoEEzhnSOmhVNK0IIbcSYm+t\nhARTBW3tCHswWD+nQ83U0nEVBzgOCEohHK30IygH4RHd5voeKq9nYhnrkSsOK4iuhBGKKFHBd6Md\nlfV+lNQYCxoJWhsBbU6JLhiGiDiYdi2tI2Xgs7UfrFw8AEqdJlyyKkfFBKDPPTHmsLgztG1yVU8z\n/lNcXY/NeysAm8Xi6H0QzgHbAFiikOnMwIotxZAlDuAVyKrNJSW5ENl9AQCY14o5D0VQwBhDQ0Tb\nluPMRnWoBosOfQ8AqI1YmVG8JwDnoDWQnFp5lkdmr2tyOk7P9LILPOeyRpolVQFsO2C5mzbzX6Cg\nfJs114VTwGXoqdVqdCVae1B5wY+F2HbAKBOjwtF/G/icCpTXBPHOAm2kaYxo2zoPIiwp+GTJXrz4\n6ZZmR66+QMQaVNiIlSUVq6wJn10F8DIW/FioXVeWwLm076W5GIO9w20sllHt6LYP7iErUMNrIh8M\nW22V4gh2MCzjr3M24tMYa+nEw6jj1bjN9vRi+/ca18JIUDBifSfJckOSYLSS+4ZOw+DuZ2NI3tnm\nZxzHobueqgoAORkZeHX6ZfjddYO0D6LmbAhwOQUInPUZ5wrYgujWr0LsZqXpCjxvncdmfXRvGA6A\n01JyAazacQRPfPkJwkoEHtENgRfgFJxgumWh8LovVhHNa+4trsK2I0Vw9tdcUizianQt/Yet8qb4\ncYJiWQL6viLTXW9iBJzHByGnGkpdNy1+oAuNpESQ49QzmcKZ2v2qIsJ6qrA3Wxe0iBuRQ/rzE2Ts\nKKwB74ox74FXTMEoLAri/a93aRYZr0BWmCnsTHFA1WNBxrViWxgyFJuF8asz/xcCJ2BbxU6tVlgk\nuryIkFWLqowt+sEx3B+6YHA2weAzGsx9V24rwcb9tnRovhabyreas/Yd/XbAfY42XyUsKVEj68bB\nUmPlQj6rBmLeUbjO2IgFPxaiqEJfo8V9bILhC0pRKx3a/eH7iuKnuG4J/gDPsG8BRzjKqoll4TRO\nGdUBXG0AACAASURBVC1R9sN15no4T92GuXrm1OJD38M95AfwXUpRVF/S5BwG9ooHDXql5IgiYfb2\nj7CzynIvylnac68QdkNVWVSBRimOBdMQiEB116ImmHhRQEMMGmui/bu0W2OqKRjG/CRtv+bSarcf\nrDLvO5Y4JGtOCwlGKxmQ2w+3D745yqIAgG4eSzCcghNOh2CO6KIsDKa5pK7rP8b8iBOU6DTdxqgC\nnCJvm59gCUbEGHTLxjbJHNn3d2mi5hE8UFy14Lx1OJS1WGuGKpjn++KHPVEdGot4rJnsvKKvMMgB\n4CAw2xyJRllhX63UgsOcI2LGIVRfLgDOtEIiTMKuI1r7QruG6seL5qi/78n6XIxAllXt1yitos+F\nUX1WlVoICvy6NXDoqCEO2tK4ZTU+lNbVWc9Hb0NIDmPzngrUhzXTfdzJ43CaMNw8X3lNED/sKAQA\ndPd0g4Nz4VBFDZYXFKNe1s7X3WG52VzQ3F1mMNyGZWFYLinnqdvhOtcKvBvZVgYO3mH63oXuWiE8\nMf8wnl/xEd7+cru5X60vjltEjD2qNrJyYsUwGGPYWbUHSw+vaLLtyffX4+G315hCYff1Hy5rwJqd\nxZAVFXW+6LphRdDmyAi5ZVGxhFgWRuNRcbWi/Zb4HMv6MuYNuU4rwMvbXjMXHWuMPYXYsDCe+Wou\nNpVvxZtb3zW3CbL2vQWcxXhj0cqoCgZGsLix66vSXwf3OatxOGdhzGvHwhDoxnEa+3Owx3tMC6PR\n/cQTjK37K/Hiv7fgrXnabyOWNiQrz4EEI0nYLQxXIzGxp+ACgNspINPpxbhTr7HtE7+sF1MFnHda\nd2t+Aq+YgdeI3tf07qIFcjlRMjvy+Qsi+GFrMer8WkfrPtsWuJOtWeWcIJsBdqUhFyycYbuWNlHN\neJ/tdWvioQuJ0T7thQAmi5pLqpGYGBZGfTCA0toG/Tjt56fKvFkKJcRrI3g1mNVkAp8xW1wqOg1y\neV+z7XVhv3VP9mtyKg5X1gJMqzZsuOB2HKnAa//ZhrV7tNIjny85gp8O6B06r+CNudvMLK5sZxYU\nSQAnyPh+81EE9LIpl+Zei1OyTjKvA6Dp5E6j7YIE3h2AGrTWi+dt4mK4Ag3CShgRSdFcEJJ2T0Ju\nBUqFHdhSaQW/lx9dYQbW7fEd3tYOzuU3S9JogsFQK0V3tC99ugW/+9v3eH3L/+GLfV/CF4m2vCrr\nAnCduwKf7JoHILpU+vyCDfiw5GX88V+f497XV2LpJqtqL6enR/NZtVHHxJrpbc1jYfj20DL4mR4r\nsv1d5Dpzo46xx5PsRFkYumCUBLV2uQUrAYVj1t/lT8HVUccZFkYTwQhoAwZFaJSFqBOQAiitDuDl\nz7Y0qTzcuMO3p2LbRbixBRavirNBlb6WixGzUE2LxroeBb07GD0yojOoAFvVyUZi4HZqP9QshxWI\nRTMWxjmn5GHcyH4xXVL6GkxwGX8IeufPFB4Ah/e+2hXl+zeQy0+K7pANwSg7Obo9hktK/+PPyXAB\nqhgd9LYJourPBp/hM+cqGGJiWBhLtxy2soxsM+IlFobY8wB8rBo8eM2NZVoY+ig9s0ZbA9yfY2uf\njJqAXnZEFwxr4SpN1DgmAuDQv4fW4dQFtM66rF7vcGSneQ+cIOvxnDB4JsIjujWxE2R4XCIkpnVA\nua4c3HXerdrt81qasxEEz/QPtGacCzJ4XeiU2vwm3wOXUQehizbvxUiPNmJKOV5nk98ObDXMdsmr\n4ehVCPf5S+GyDQbswuUe8gOcp23W2uVxQMg/hK/r5mBd6Sb4ghIqa4PmHAiDssbzVBwR8J4ANlZr\nMZoo102u5i6K9NkA99Al+HanlW4tqppA8hn15loqQPTI+l+7Psc3hUvMDlroWoJ5+79CNbRsMuO7\nNObG2JHi1AWzz9iuD4Qwc/1HELtrLqwcwVZpmbfNE8poiBKMQCQIX8TfpIJybSh+ActN5Vvxxx8e\nxxNz/4ut+6uwYqt2TaP/53kOBeXb8O/d88BYdLmbYEiGrKhY81NplHvqQHE99uhuP9mW1LC35gD+\ntesLKKrSRIhCakBzC9s+TkahTOA4VKs9UejltToDI2X27P5dMbBvDnqc0RubGw6a242smkynJRjN\nuaT69chFZobDCi732W8GkcMh7Ufn5LSAM++t1zp42/mYypkZXAAwKOdsbLIF2DleAThj/oXuRtM7\nJtcZG8Fk69ouhwCmCOAzfHD029Gk7XLJAAg51RC66nMXDDExr2V3ZfFR2xwn70GY5SLT6YWf8dGC\nxqngM+t0V5XDFkuR8e/lO+E8xSbM9vsSFPD6SHLMsH74oASmxcIEW0kW0Qjm6+a/IwQnMsBxnGY1\neRR4XDwa5AjAAR6HC27RBQ4c6vijcJ3lg1KjLw5Wn49Tc06x2m6Ufom4oFTnm+nQAIOYd9Rsg1qb\nBzAtxhKWFGRlOBBq7F4y53hYHQDnkMA5bEvaNrJ0hBxNELxuh/m9fLjzU8gNXyOy7zwA0Vl9ZYEK\nDMjthyUbi7CjsNoUQkCb7GkXDHsRTk6UEcksQllNAHm5HiicdhznDkRlIBmdl0/yY1WxJkKXsf+n\nP/9GI2H9O919uBZ1IT+YIkAqOg3OU3aZKdONsXf8FeEybD9guf/s7WC8BMY4qHXdIORWwu+zLKvl\ngU/x1cpaBNddA3tc0V5aRlEVCLainKuOavOVkLcfqOhhDgzNET8PzNr+IQDg2n5XRnXitb4IPl6y\nF9/bLDQAeHrOBvO1PWj+8uaZAICzup0Bvy1lPBiWsZZ9As9QCcr+oWAKoNb2SNpERLIwkkSvTEsw\njHkLToeAGTddgFN7Wu6qHrkec0SQZROMU/O7xD232+HS4iH2UiM5WkZMSP9bzuP6AbILYs9CcKIU\n1YkbqasGP+3X/fpGp+uIxO3EAd3Npb93OQXzOMOtkuG0pQ0bM9sNq8ZMCbbiL11zdNcR49Et22Va\nLwAQVANwCa6oY3i3D5zbD45X9Vnv0efjmri/LOuI4xWAieA5Dh6XA1BEy5VmCIbkjC5HwqmAIwIH\n00bITI8PFVXVotrvB1M5uByi+T0D2sREo2MNBUS49ew0TpDBGWm0iojIgcHWvBxeMcUrvPsCaAkA\nDoSVMMKSApcTUUJg3C/QNO6hfWi4xmIXRfR6RG1yKLSOn8+sg5CnB9xtAfsVR3/E90dW4p/f7sHm\nvZVR82nW7i1sNAksenTr8zE8/PYarN5eAlUfwXO8ioKqzeY+isowf//X+GjnZ+ZnZufJR1sNxiDg\nHx9vRqWvQRN3Q0SKLMuoMliF8oAW77ALRkiJFpWGUNDswFVeAhQRql9LwqhnVrzEcInxuRXgnNbz\ntE84/XL38qhzu0Xtd8tn1kHIPwSfXpGBqQyctxZlmZZwheSQaWnxORX4v6VrmoiF7SmAzy2HxJp6\nCurD9VHJD5V1ISjQ3gsDNsF1+iYA8et6tRayMJJEpiP2WtoA4BasDvW5O6zO236M8WMDADWUEeXj\ndotO8BzX1D0BmCN4t+CCw5cPKfuwlhoatM7NAtlQ6rtAyNbjA4rW0ZnFEXseMjv6Jp2u2SjtmHNP\n7YbdpdGjwAyHC8afUeM2mi4po3JvVg04XgWTNZdZbpYLDbZ5CRKLWM/COMbbAOdA3dWhOKKuY3eN\n9evRBQfqbBaP7rrjVDdEgYPLoWea6Z1uhOkjcVs8B7wCzhEGxwG8rHWuiqRtq/b74OQVPQlBvy9b\n7MCwJP535CB4jHsQ5P+/vTMPr6LK8/631rvl3pt9D1khJEAgAcIWdmQTJGkWhRe1WxRFWxRwQXrU\nntHWmcYHp/vpx8exfbrtxWec0R573ufFcXoGX0VfEW1axBZwWFQSIAkhZM9dquq8f5yqU1X3XiAo\niCT1+QdS66lTt36/81vO77AEAKLfR+tJBu9qpsrOsHZCPvbM/dEQrVnl1kujdKRDSG4znxdm3EPr\nCYJP6mT3giLFVVY2FF6SR2Lns316jS2rUmjsPonG7pMAP4+6Hy0Wxnst/w/DBLNEjnWftX0ff3EK\nsHhc32l/E5BmAVE3FFXBn/Q0ZQNjBMydJ2BPrx21pXwrxDz28T3/AABQT1SiZkQGDKupX6XfEYnI\nAE+gIIrX/u8x1FXlQOOiIKoIEnXrx/YDoHEeA9eIv4BEJYQ+ngPALP0PAH86/SamFVYj2RXEZ1+0\nQ+bNb1guPISmaAaAcmiEwD3qA1jzqvpVY+kADa5yWhur/8OF9gcWopCKPgNUEWJmE5TmQrx3oBJ1\nVTnskLZQO3r7zcFmbyix1ZVojs7XwbEwLiO3jVqDtRWr4rbHBcF1rBaGWzSPCf91KrIj49jfXsk+\n4rahj84FgYdILItBxQp8y7lMoFpiD8aIlQniGEFg3Ke2IhMen/3H55XNQGJlfoZtH2LuJeV8gQ6t\nlSkgr0sCHxNjMRWsOXrlPb0xbTfdVUYW1aIpZbhr2Sh2bcMlBU2EKPCQJUGvEkyPjyIEmXMB4G0T\nEg1hzOmCRInw7F6Gu08U4z8do8+mVRSy4CqnB71pmyVb2zneuk93BSoCzhquEd1S0HqDCB80srgM\nhUGVgtKWixJ3JbuX26uAEzSo58z3YMR2fB7pvAI+UZzLmHhpVSZfRj/F7tC/xD0zQxf41gWtzOvR\nfbGjfoAuGMbJ5sRHdg6bx0EAga4Zw+IaMCbVmQJeGHYQn4TMkf8XZ6gVEvmqki5Cxit488MT+JsX\n99K5SYrI+j6il+ePrSLASVHW7/2qPSHgz6cP4L/+3IifvXYAB0/YYz/tfZ147vVPcbYr3hrsV0K0\nbIulb901u2wWlpjRCDGtGWImtQI5bxd+9QYtxmik5Z/oakJbxKz91hEyFZoJGdBE0IHgKIzLyPis\ncZiSMyFuO88JCY6mglHk6Y/VI1mFvYhUyfzgPYbLRxMR+nQa1G5rtggVqqLAsTgGEB8TIYawAphA\nNeIVNoyPsTdo20w0ARyoHzys2T94aymUZdOG29P6mIURG7wVWLtjS4hYra04Yhan4t294EX6MYwp\nzqJ+Y0MR6rW4NJWn/SPy9HyjbDofhgS97WxCogroQlDTrS6j76Sig3pCgZBwohnn6odP8kLgBfMZ\nBMUsPBmbxaUrO7ZGivFcutDo1Vd31HoDccv9GnEKEvGYmT+CAneSns4c8qH/41n6OVTo+twSIEbg\nhh/jhRvYOQDAiQlcXEaJllilYH1mfaARbRyuX4ee09qjB2rbzNGwoZxCoEJ3Zv5UjEmnc23CSgTy\nyA/Z+ZEv9Tk4hsIQFBrEVU3XrFEdIOEETMM9Z9QzUyQ9ecFUQLTvRWaBRbWw7ZlslxMj0DQN3Rq1\n9JRmGqP6n45jaGmn76IzRli3dnfhz5+fscV8DPqVEBRVs/32OTFqm+BpxBUNSJi6SOn6NvT5jnQc\nx1fBnWwA0R1O3BeOhXENcb4qkxzHsUwpm8IAkOo2zUyfvu+Bm8ZhTd14WyE9A1HgIXNm2mZsKq/N\nOlHjLQwDluranYolefXmDo2H1y0mXLbS6zKVkUcWY6wZe1mT2O0CzyGpc7Rtn0s4v8Jgqxhqpjst\nJZd+ZC5RhiQKTFm6hlM3lhrlIQg8C9hTIURAhAibbMgEsqufCS0lIlF/uP48gr+DChNNQDDB+he8\nKwS/TGMsIi9C4kXwvk5zBr0S605T6WjW1l8iOEEDQNBJ6MhR6w3aLCoA4APt+r4AzeQCVSYun24p\nhj1A1A2tJ8DOccs8OCkMGR5oUaOWma4wdMGlfFUJF9HnlRgKw514ZUYIUXBiFEJvJpTTpbRqs369\n3sAh1nZDmRjtCGnU+gjIfrgFOsiJkAh4t+lKU9vyqJvUSEKwVihg82ki6O6L4Me/j587wgpYWtyB\nRBV0a5SwthhVEgAgrIXAwVSQysnhiJ4u0q8XxeftxxHiu6C05SB6YiQIAXojvTQOZ7lXibfc9nei\n2e4hJQRVJXHK2FYRQE5sDfZG+0BAQIjlW9Sv0xOhfcsSMAA6V0nrwtuNF67eOxAchfEtYA2OxmLU\nP/LGKIwMrxko9+gun8qiVMwdn4/a4QVx1xF5Dh7eku0SY2EE3JaYBps3kaBdFuFV4LPcR+Ph0yd+\nzR1mX6PdWgbFJQuILYUCIP6jMUqgCDwy1QpEjlWxXdaYT13WjITnWe/Rqa/V7RJkyBIPrSPTdoqq\n6BaGJACKTOMT/nZwHEE0bM+sEoJn2WS5aJhHR3fY/mECyEsNskmZPxi1xrYvYEmVLg0WgxMVljbL\n+tbmTlPs/WVZ1pfz9FDfe9TNLLTcLBe23TIWQuAcDdZG3eZgQ1DBu43yLh7WT5yg4YHVVYiQMDie\nQIIHasS0cgBTKah9SVCbS+g2MQLe3w4huQ2C6mGVio1Kxka1ZD90a1iRqFXCaeCDbbS6cVuePd4E\noF9feyQgm8ouooXNwYD+PoxFxgBAyGw076G/q7AaxpHTbQgVvQ0AiDaOgHKWWjRs5C4amXAyJE42\n+1ZXQD7Zg1HDMlkbkv0uJqi1sAuwxHi+7KJtUM9lAeAAVaJl8I0MJDEKLezBkgK9phqz0CxlhCTq\nau6N9OsWhm6hnSzV+9SqMOh5N5Wstr0rowqB2poPtbnYdq8e3RWodadAOZPH+r1d/hyvHvl3fFMc\nhfEtkJdEf8RjM0bH7TMsDLfooi4TnaJMM188yWVXJiWZMXECUMHrEUylIPEinrl7KsaU0OvwmtUl\nZZYhCR+eyASAfR/g4k2LhRAeyUm0HfWli/HszCcROlCHaONw5LoKzXMkIa4UCkAtFhvEtDACXpmu\nBWJcw+KSWla6ALnuYeZpTOjGW0cuQYZLFEAiHuYyMNpAYxg8lDN0wp9r2BEAQJeRJWkR2oK+RGxP\nN/DBwRZz3ghrn2ldTMgahxJ+PPvbGpdaVW5aaETjLMrO4l4SFHbv5CSZKYZ5tdmQ3SqIItvbxysQ\nfL0Ap0HrplaoVzLjJWYBSSPuRd/7sFwP2iI0pdZDUqBEeXYOADZXROtPQl+v/nsQo+D0kv1ZkXGI\nHKmG1pcEcATJfpkpjGx3nv5cIjghimBArwLbka6P4O3W0alOah0F5CTmumtKfSMmRsfRd6wPNIz0\nY63Pb1EYUXzRaVb/JRE3KzLJSWHw/rMQ0/T0bkWCZJTQFxQ2CVRWk1i9NQgKAl6ZlaAhETfrf06M\nsNG7UUyUKCJCagjhqAo+0AbeFQJUAZnBII2b68rKmOXv66zAoQP0fby57zhaO/qZwtB6kvX0bdou\nztcBIaUVhAATc8aA0wSmFA6dpXWsSNQFrd/IxqP7enWFYc0mg6BAwYUXkxoojsL4Fkh1p+Dv6x7D\n7aPXxu0zBIwsyCxv+4l1tSjLM2MIse4qq1Ay6Asr8ImmgJ9YnoPUgNucvGONIVhy57WuNFbKnGIq\nD06zWwrpQSqYeI6n9alCSVBOl7J2A/a0W8AapJYQ+mSGZTttgyhwtFS3Ygphq4XhcYnwy5YMNEPo\nRjzQQubzcuAgCRJy032oHp6OCaWmdURUEQLPQxYFaJ3pVGD79NURozK7QviwPf5EVBH//t4XUFqG\n2awMq8IAAAHm3wHdJQUAGZ40UxmrElyyiCdvn2S+CykCjifI8Adw45wyOjlTVyZzJmYjrIbN2JOx\ndK/SiLcaqQtGC1NhkSSbAk8TdJeUIdQs5VBO9dMRcpKWgUiEp7Emyyx6EpUBRbYIySgTRElCgM6W\nD3vAccC9KyvY+i9FQdrXdD6LgtxMe+wn1sIw4i9BVwAe68zrGBcM0QRbDEMLeaC25bPf1LmeXvzH\nXjNIThTJJuDFLEspeSLArSefcLzKhLhXS4NsJCiICqKqBjFJXx2yz2/ri96IRRgDgCIhrPUjElUh\nDTvM3lPQ6wKnSbqA1yCX6bXGoh7WFz3hPpwJtUIqOKK/Lxkk5NNTogmkAlrZluMAt0tEwO2DKNP+\n+/cj/wWoEtS2XNbHctkBCKmn0aN0mX1h6XcFA6+ueyEchfEt4ZeTErqmDJcUVRj0BWemUEG4bvRa\nLCmeb5scBMTMENcJeCX4JGvWlV3J8Kop1OIC4uGYcuw6GiFmmwmf0G8PmAvBALDXvALsEwiVeCtH\n4HlMr8qlAslou+DC/SvH4o4lleA4DgGX5XmNaxMe4QPTWbaIJEjgOA48z+He5VUYV2h1pwkQBU6P\nv3BmCitAZ3kbh/WZwt7WXk1E1OYys/cDT8y/rcqc53hz8StFgqpqyE33MSEuF1IhU5CejAW1w+Bx\ni6zv2kPnQECQ4be3CaAzigGwa/tc9P2JGY1QjOKSMZbJrz57Gaf08hj97X5EIhqdk6KnJfPufmj9\nSRg5LBmV+XROESdGzPiHYMR6aJ8E/BxdD4RwyPYnY9a4XL1iMkEwzZ75Zfw7rzYbnLsHQkYTBM2N\nTG9GwgQHI1gPjdYE41x0Do4R9F08kbphQmrETEuOuOhgQFeUvByBqCtytZNa2Sw2JqjwpOjVAfqC\nOHUmRAcEQhSn2nrpipFRN6C4zHIzYoSt+24qQgkqVOw+0MQC1NHGEeB5DjxxUYUhmgFvviPPzJQT\nFPsSBRE3tLCXPq8cYuO26CnqHvRJXmhSH4TMr6DyIag9AZCI1/atyWWf4AuiL1+rSKYVLijQuPOn\nK18KjsK4yhiL+filJDZSFwX6a6nJrMKi4nlx5/gkc2RdNzYX96+swrQxObYJdMYo2Ai4c+r54xux\nghIApo3JRllekNXI4sRInMJ49NYJmFOTh+oR6Wwbx3F214LFmkm03e+VML48A//4w5mWtrtQVZqG\nKaNpgb9kt2VGfIzbIsVFLTGZtygjAMkuUwEZLin2p8UymVJRgBEFetaZItvjFVZLyaLs3DEWn6BZ\nFYa9L41ArtaTjPJhKXHXBQCvaFhunF6sEfiwmU64Ks1Kwy/un2GLE7E26QojoK+/wid1oVdoocrC\niE/pM6e/6mpEe+QsSFTG4S/68HljB+1L0eLGCnuQnuzBD+brylGMMjcIc9vo/dCn9FOFokjweiTc\nsnAkND1773/wLr0ey2jTkwYkDby3GxwH5ChVcAkyunrtgkwLeYGoG5nJHowtobEF99h3bf2W7NXf\nn2XiY7SxHADHLEZXcjdLj40co4t7ifpvhE86B9HTD6JIaDsLLJtabCpPMQIihaD20PdoXE/K/QLH\nO7+09QH7Tej127SwG1oX/RYkuMHJYXNNmDN5UBXebm3pbZd68gDFBaL/LqXSA+AFDUTlkROh5WKM\n9VjkokO2e9sGYRZIxLRmeF8HVC5qq5D9dXEUxlVmWm4tvl+5GuWpZSjM9iMvw3fBMsaAXWE8fMtE\nVJWmQxR4uCXLaD6m2JgxpwBAvMJIkHW17vpK8DzHsrU4Vz8CMQqjOCeAtfPLael1C7bgpW0mMGdZ\n4J6ek5mi+2CtzxxTHSLda5kFH9P2FDcVUrFzXZJdZuoxUQWmhAEzPREAJpTms7LfAMeCnDzHo7LQ\ndNVZLTSPZL9Xrs+sXBuIcRdGvhoJEpUw0lVL54ggPmXZrSuMmhEZmFkyDi7ehY9a6Mxoj+iB1y3i\nb2+rxTBSYzvPUBheyW4hcqpl5r3lXXRGOi0uOIDjCHhXP3PdkKgMTSPwy0m0/Iv/HBN4xj0MIdQX\npdlkRJFoZhyA22uXQiAy+jU9vVQVMboklQn6kBKCKOuz7BUZR5o68L8/OGJruzEq9rpFdITtpdON\nZ/G4JJqRxZsTH3m9ijLpC0DrS4LqPwWS1MbaAQBpoHEtKfc4iEDb3tMfxeiSNKT5/OBFBYvqqJIy\nrG4jeQAAQoSWJmHKWBfW1HWnmNYDAJ8eT5RLP2FtiCrUqiME4LzdZu2xvjL9nvR3KfjPgfN2IuD2\n4cHV1QCA1n5zFjq9d0xsy0L0VIm+3DPdJ+UdB/GcYwkG3wRHYVxlZEHGxOxq8ByP7y8aiR//YOJF\nz0lxJ2Pl8GXYMv5u+7VkgZWdOBemPvrKIipsq4vy2XFG/MDjElGSG7C7aGIo8NOApkeSUD0iPtie\nkPNU3r11YTmCxshfz5PPTDaFtzECCql2X3a2pU6XITSeuXsq/v7OyUyhSTEKI9Vtr2wqWCwM6yz4\nJMlnsz4Ml4ZHcCPoNa+Z7DUtB3dM2u+SCebaKF6LMgcAtaUIoY/nYFp5iem600SED9WyY9J0K04U\neNx8XSUKg6Y7zbA+slK8eHjuTfh+5Wrz4npbjbXMDXjVFAzKqVJInEWBxKz+CIBNliNRF8JRFSIv\nguvIB+8KQUg5oz+XBwGfzEa0Lx9+FZwUBYnKdAY9gNqKbGQlmckNRJFQlO3HjrtnAQDeP/0R5k2l\n7qHjjf14+vd/gdKab6tHZQT9PS4RTT2nbM/F0psFDhyvlzZJp242v5EFSMzEBgiKTcCnCFkYnlwG\nTg4jTPoARURuuk/vQw8kt4KRZW57PxEeoU+m256J/V+1KAxRQW5yEE+so++1NjgbRBHB+2g8ROZd\ntMAi4aGeyQfv7oOoZ34ZvycjC83AL/tYSfqCpFzbPkPx2+ZX6Rjl/+OKnl4gXX2gOArjOwTHcXGj\n9fMxq2AaSoJFtm1uSWCZM8aodU5NPratHY/66aXmgfoo/b4VVchJ8wKqBKGtDBVCTAorgMXF12F2\nQR22zbyDZkCdh4fXVONHN9NsIbUzsWLRNIKyZOp/NjJsMlLMEdyDE36IMemVmJpbazsv22dJk9U/\nghS/C5kpXqYwYl1SAi+gIf8maCEPtO5UiJb5I2qXmYHmO4/CUImKJI+pMFycKYRjlZMk8phTMB0i\nLyLHZ0/ppXA2C+p7M0qw1KJkioL2NGlrIUt3zKhwQtY4TM6egNrsGty3YizuvGEU0j1p4L+YxALs\ngmY5R5Uw1jeV/TkqP4cJoegXMVl7UReum0DbInXnsc1EFeCWJTyxrhar6uis8rMhfSEpRbYp46DF\nFUhUEa3n+pnSA4C3mvRZ2Cy+ISP8V7N9rNSNLGBa7iRb8wxlIvI8KyjJe6k147aU5bAtiWwRxpkc\nvgAAFtxJREFU8KLAI91jWqsF6SnYciNNcy0JFiGqKdjXoseHrNcIe01LzSqg9Wu7KmgBxaxAEHkZ\n1MIszciB0mJm6vkkD1vRT23T0131+AYb+SsyIkfNWJnVcrxr7A+QKZjXY0kiSoLBGYvL2U11NdGx\nl4ijMAYRLklA9MtRiDaVob50EQBaUrksP2gTikb8QBJ55iKSzoxCTdr4uGvKgoQVw29ga4Ofj/Jh\nKSjVM7u0zgxET5aiyluHm+ePYMeoGsGkbF2p6EI74DU/wAJ/Hu6q+n6c6Wyr06WPFg0BnHoelxQA\njEorR/jATJB+P7xuy8dicc/5ZR8k0bJuQBd9zpAaZusVGGuwGyNhI25i5XtlS/DszCfhERMnEFgd\nc0umFmFpbTn7Oy9m9GhVOt6Y63Ech5srV+HWypswtiwdkyqpciFdGQgfnIwcbTRSQxVI8ZsCLyCb\nQjw3OYWlWmvdabbFqB5YMYnFc1xRM8OLKBJ8bhF+r4yZJTU2F5zEuVj2HGBXGOlJSbhuYgGrZmDF\nNlK3JB4Ygxm3LGLliGUYlzHG3GfUPhN4/GjyJtv1rKnYVrebVcBLAm+LbWX6A6yfjBU0P2r5S/w1\nwLE4H5s5DqA4RtFb331mssfWt92WQlKGK9Fg0cQy9n/N4i61Zj0mu4IYhrHmSUb/kfhBnNG3WmeG\nrSpEW/s3D3w7CmMQ4ZIFQBOhnCpLKLhMU1VflIfnMFUPLC+ZWmRzD31TlJPDUe4ej9k1+Vh/QyUC\nPhkTK7JQmVaOB8bfg++PXoX1SysvGq8xmJozEZWp5Vh3fSVumFbEtpsuqXjT3Mg6AwC/165QNo3d\niDtG3wyP6LEpU8OdUZlWztYrCPpkyIKA0P6ZSGqahYnZ1XH34jguYRbcTXOHQxJ5jCy0VyPmOR4z\n86diTsF0SDECNccikD1SYgUUiyjyIH1BjOCn4oeLpuMnG6axfUGX6U4LyH4MyzLjLJolQ84q7GVB\nBukz/67Q2y8LMlaOWMa2zxtXauu/ZIsy3bC0GqW5QXAch3UxKeW2YK3NzWO4pARIvIiiQEHcPlHg\nkO3LtGULWgcZomYpkWMZVQsCh6ClfdZvpDhQaMbXABDF7r5RWwtsbQAAn5qNGcEl7G+rRZAacGFq\neTH7Oxw2r11XUQTrEGJ4bjqeXj+ZXt+iMGLffapo/i58kg8TR9KBhW0FSgAjcvQkFMIjcthirV9g\nkbaB4lSrHURcyGUEAOHDtXopCl1hCFSQPb9lJmRJQGfv5cnVZugW8eTKbEyuNH/sxcFCFMcP0i/I\n/6pYmXA7C3rz8RaG12X+vA03zMYVVTh2shNlafkAqHKwWV+qBOHQfKy7czr6iglOn+3FzQvK8S+7\njgCKC66o74Iz92OZP7EA8yfGz8wHgFUj6hNutwpJcYCZLRuXV+EP7xzDwknDEPDJyMgwlYTVOgzI\nfvgyLLP+LQLKOodEFHmo57LAJ3XSkicWhZtvsYh8MTGboMWasQrxmswq/IvkM+s+2Xzv1oQH+n9D\n2VvbZMYwaP+nulPQHaUuKTp5sRcuWcDT62bhRx/sjrsPz3M2C8Mq4CVBQoY3jZVIJxEXinMCKM0N\n4PjpLvzo5tlY/+vfQOtJxuiSVBz+6hyWTivCud4e7NZj81ZrkOM4rJo+Gvve/QN9HsENoxwjITy8\nfBL6tG4QjYfEi5CM9FuL8pRjftMLa4vxJz1hbNPy8Th4UMNHh1sRPjgFj24owzP7fgEAaD5jqV1F\neD1BQLvwMtADxLEwBhEu+SI/CE2wuWOIvtCvrCsawz1kdWd8HQwBWT4s+SJHfnMyPelYWDgHM/On\nxu2zpqL69WcbV5aO5TNLbcfZFAYAF+eFW3QhNeDGI2vHIz8jCafO0s89mPTNA4cXQ+RFXDdsFoB4\nd9X5KMkN4MHV1XGZbACQ6gli1Yh6ZHrSURwchorCFEwbnY07llSymFehv8Am4JfPLIHSRu9dmVJp\nu55X8uDGEfVIknwoCgyz7Svw03MkXrQLe1jWvdd4WMvSrJk3nC13y3vpxDPjt2i9htXCAOyJDR7Z\nFLRBj0UhWq0XYs+ei7XCc7xm7GjNrFF4aHU11lw3An9zywRwHAe1dRhIXwCFWX688OBsFOcEkOIz\nrZzYZA2rS2ndItO11tsfZen0RsxGZoM9jrmR+hR7xV+XJGBTzQbMyp+GAn8eJlZkQuA53LGkEmmW\n2MzyGcNt5xluViNu+E1wLIxBxMUsjFnjcvH2fjPzJLZEN8dx+NnGujgBeqncNHc4GqaXXFyBXQY4\njsPS0oXn3WcQ65KyYk25BZBwzkN3H7W+aisSBbUvP8tKF2FR8bzzlsa/FFySgJlZU21Kdd0SqgQq\ni+ohSEvhkz22/hpdnIZfbVmMzvB0c20PCzPyp2J63pQ4l2JxsBBPTfsbcBwXF7BPd6fiq65GWiLe\nQpJHwuphy7Fj33PoPkWVeWYqFbb5fovCVI15SnpKttda0ZkqBlUltjZpFrcaIQRZlnO8Me0zrpfm\nTsG88YWIZerobLz/12akBszz3LKA6MlSSHnH4jKZrO3IS0nB8PwuHGnqRG9/FHmBTJzsb2QLZFnL\nAkWOjkNwxP9gYdHcuDaUJRezxJH0oAe/fGg2ezaDuqocVgZ9XFk6XMHJOID/w9xq3wRHYQwiLqYw\nblk4Emvnl6OxtQefN3YgOzU+ZnEhwXpJbfkWlMWlYLikEiHFKM7Z1Xlxx2xcUYX9R9owZVR23L4r\nAcdxl0VZANbRazwXs5is8Y9Yzhd/ssZCrFRnVmFf6ydx271uEQX+dPy07m9x5wc0iypLz57z25Yx\npuLKr7/L2QV12NW4GwVJuZB66TtU9USFm8ob8Lu3Dphr1INaGLIl1hUbjJ+ZPxUCL2B2QV3C9n9/\n0UhUD09H9XBT6QR8MpSTZfBHCzB+9riE5wF0lvniyYX42WsHMG9CATwZSfiwZR/bb/0NPnvnXEji\nfHuixkXgOA4NZdfHpc5WFKVAVZOx973ZQPT838BAcRTGIOJiCgOgftzCbD8Ks88vCAYjfu/5Pxbr\n8pX/eG9dQrfO6OI0jC5Oi9t+LTCQ38W3wdiMUZiQNQ4dZ2R8atlupJJLotnOjGTTXfS3Ux7Gm4c/\nxFu99L2kJ5vK5KfTfwwOHP5jD11kyBhoT8+bgvTp5XjtnWP44nS3vo/uTPekoa3/LOy5azQetrRk\nwXnbLwo8xpfbLUyfW8LT66cgySslVKC3j74ZRzqOISAnYWyZn8ULVS0NxYFClKdQi4rjOFw/pRA5\nad6v7facN2xm3Dae4yDKvC3V+JvgKIxBhEsWsKF+NBudOZhcyMIw1qj2uMSEyuJaxyV/N0KVPMez\ncvDaTILb/4Eu0+pOYI1a3aLpnjRMzpiKt7APxTl268WwwhItaFVRlIpHi1Jx29+/BcBUJptrNuC9\nU3sxKbsm7pyvQ1YCS92gOnMMqjPN+IVh7Qm8gAcm3GM7Nja2djkQeA6CMLBMxIFwxRXG7t278dRT\nT4EQguXLl2P9+vW2/ZFIBA8//DA+++wzpKSk4Nlnn0Vu7sACfQ7xGKl2DpQlU4vw2RftbC2PRPSF\n9bURLsEFcC0x0Mmg3yY8x+HJ2yfh4JftKMk1lcC9y8fYStwYlOUHsfnGsRienziRYiBC0fDyB10B\nXF983ddq97UGz3PQNHLxAwfIFf1CNE3DE088gZdeegmZmZlYsWIF5s6di9JSU5O+9tprCAaD+NOf\n/oQ33ngD27dvx7PPPnslm+UwhPjejBJ8b0bJBY8xXDaJYjrXMo/eOoEF67+L5Kb7WGkOA2t8IJYL\nuQQvpBR9Hgm9/dG45IahgFsW0NN/eSrVAlc4rfbAgQMoLCxEXl4eJEnC9ddfj127dtmO2bVrFxoa\nGgAACxYswJ49e65kkxwc4lg5uwzXTSjA+htGXfzga4jinACqStMvfuAgIJFLyuAnd03FxJGZmFOT\nf95jBhuP3joBs6vzUDMigxX4HJ5/iZOfEnBFLYyWlhbk5JiLwGdlZeHTTz+1HdPa2orsbJp5IggC\nAoEAOjo6kJx85XP4HRwAGrtYPW/4xQ90+M7i8+iT+hIojtL8ZGyoj1/tcjBTnBNg8Z7RxWnYfONY\nlOR8xxVGbIntgRxDCBlwuQgHBwcHAJhQnonjE7pQNybn4gcPQS5Xht8VVRjZ2dk4dcqcKNbS0oLM\nzMy4Y5qbm5GVlQVVVdHT04Ng8OKa0Fr6YKjj9IWJ0xcmQ60v7lsdXzzTYKj1xZXiisYwxowZgxMn\nTuDkyZOIRCLYuXMn5s61z16cPXs2Xn/9dQDAm2++icmTJ1/JJjk4ODg4fE04MhC/0Tdg9+7d+MlP\nfgJCCFasWIH169fj5z//OcaMGYPZs2cjEongwQcfxKFDh5CcnIwdO3YgP3/oBKccHBwcrhWuuMJw\ncHBwcBgcfPdm9Dg4ODg4fCdxFIaDg4ODw4BwFIaDg4ODw4C45hTG7t27sXDhQixYsAAvvPDC1W7O\nFWfbtm2YOnUqli5dyrZ1dnbitttuw4IFC7Bu3Tp0WxYMfvLJJzF//nwsW7YMhw4duhpNviI0Nzfj\nlltuweLFi7F06VL89re/BTA0+yISiWDlypWor6/H0qVL8Ytf0JXWmpqasGrVKixYsACbN2+Goijs\n+E2bNmH+/Pm48cYbbanugwVN09DQ0IC77roLwNDtizlz5uCGG25AfX09VqxYAeAyfyPkGkJVVTJv\n3jzS1NREIpEIueGGG8jRo0evdrOuKB999BE5ePAgWbJkCdv205/+lLzwwguEEEL+6Z/+iWzfvp0Q\nQsjbb79N7rjjDkIIIfv37ycrV6789ht8hWhtbSUHDx4khBDS09ND5s+fT44ePTok+4IQQvr6+ggh\nhCiKQlauXEn2799P7rvvPvLGG28QQgh57LHHyD//8z8TQgh5+eWXyeOPP04IIWTnzp3k/vvvvypt\nvpL8+te/Jlu2bCF33nknIYQM2b6YM2cO6ejosG27nN/INWVhDKQ21WBjwoQJCATsJZ2t9bcaGhpY\nH+zatQv19XSd6LFjx6K7uxttbW3fboOvEBkZGaioqAAA+Hw+lJaWoqWlZUj2BQB4PLQ+UCQSgaIo\n4DgOe/fuxYIFdD2HhoYG/Pd//zeAwV+vrbm5Ge+88w5WrjTXff/ggw+GZF8QQqBp9hUNL+c3ck0p\njES1qVpbW69ii64O7e3tSE+nReUyMjLQ3t4OwF6XC6D909LSclXaeCVpamrC4cOHMXbsWJw9e3ZI\n9oWmaaivr8e0adMwbdo0FBQUIBAIgNertmZnZ7PnPV+9tsHCU089hYceeoiVFDp37hyCweCQ7AuO\n47Bu3TosX74cr776KgBc1m/kmloAgDhTRi5Iov4ZbHW5ent7sXHjRmzbtg0+n++8zzfY+4Lnefzx\nj39ET08P7rnnHhw7dizuGON5Y/uCDKJ6bW+//TbS09NRUVGBvXv3AqDPF/vMQ6EvAOCVV15hSuG2\n225DcXHxZf1GrimFMZDaVEOBtLQ0tLW1IT09HWfOnEFqaioAOkJobm5mxzU3Nw+q/lEUBRs3bsSy\nZcswb948AEO3LwySkpIwceJEfPLJJ+jq6oKmaeB53va8Rl9car22a4G//OUveOutt/DOO+8gHA6j\nt7cXTz31FLq7u4dcXwDUggCA1NRUzJs3DwcOHLis38g15ZIaSG2qwUjsSGDOnDn4t3/7NwDA66+/\nzvpg7ty5+OMf/wgA2L9/PwKBADNFBwPbtm1DWVkZbr31VrZtKPZFe3s7y3QJhULYs2cPysrKMGnS\nJLz55psA7H0xZ86cQVuvbfPmzXj77bexa9cu7NixA5MmTcIzzzwzJPuiv78fvb29AIC+vj689957\nGDFixGX9Rq650iCJalMNZrZs2YK9e/eio6MD6enpuPfeezFv3jzcd999OH36NHJzc/Gzn/2MBcb/\n7u/+Du+++y48Hg+efvppjBo1OBYF2rdvH9auXYsRI0aA4zhwHIdNmzahqqoK999//5Dqi88//xxb\nt26FpmnQNA2LFy/Ghg0b0NjYiM2bN6OrqwsVFRXYvn07JEkaMvXaPvzwQ/zqV7/C888/PyT7orGx\nET/84Q/BcRxUVcXSpUuxfv16dHR0XLZv5JpTGA4ODg4OV4dryiXl4ODg4HD1cBSGg4ODg8OAcBSG\ng4ODg8OAcBSGg4ODg8OAcBSGg4ODg8OAcBSGg4ODg8OAcBSGwzXNqlWr0NDQgOuvvx6jRo1CQ0MD\nGhoasG3btku+1u233z6gctePPPII9u/f/3Wae0kcPHgQ//mf/3nF7+PgMFCceRgOg4KTJ09ixYoV\nF6w+apSKuFZ49dVXsWfPHuzYseNqN8XBAcA1VkvKweFS2LNnD7Zv345x48bh4MGDuOeee9De3o6X\nX36ZLaizdetW1NbWAgBmzpyJl156CcXFxVizZg2qq6vx8ccfo7W1FUuWLMH9998PAFizZg3uvvtu\n1NXV4cEHH0RSUhKOHTuGlpYW1NTU4OmnnwZAa/M89NBDOHfuHAoKCqCqKubMmYMbb7zR1s62tjZs\n2bIF586dAwDU1dXh9ttvx3PPPYe+vj40NDRg0qRJ2Lp1Kz7++GPs2LED/f39AICNGzdixowZOHHi\nBNasWYMlS5Zg3759iEQiePzxx1FTU/Ot9LXDEOGbLNbh4PBdoampiUyePNm27f333yeVlZXk008/\nZdusi8scPXqUzJo1i/09Y8YMcvz4cUIIIatXryZbtmwhhBDS1dVFamtrSVNTE9v37rvvEkIIeeCB\nB8jatWtJNBol4XCYLFy4kOzdu5cQQsiGDRvIL3/5S0IIIY2NjaS6upq88sorcW1/8cUXyWOPPcb+\n7urqIoQQ8q//+q9k8+bNtrbX19eTs2fPEkIIaW5uJjNmzCA9PT3kq6++IuXl5WTnzp3s2WfNmkUU\nRRl4Jzo4XATHwnAY1JSUlGD06NHs7y+//BI///nP0draCkEQ0Nraio6ODiQnJ8edu2jRIgCA3+9H\ncXExTpw4gby8vLjjrrvuOogi/ZQqKytx4sQJ1NbWYu/evXjyyScBAPn5+cySiWXcuHH4/e9/j2ee\neQYTJ05EXV1dwuP27duHpqYmrFu3jhWkFAQBjY2N8Hq98Hg8WLx4MQBgypQpEAQBX375JUpLSwfa\nXQ4OF8RRGA6DGp/PZ/t706ZNePzxxzFz5kxomoaqqiqEw+GE57pcLvZ/nuehquolHTfQdRbGjx+P\n119/He+//z7+8Ic/4MUXX8Tvfve7uOMIIRg1ahReeumluH0nTpyI26Zp2qBa68Hh6nPtRAAdHC4C\nGUD+Rk9PD6tO+sorr5xXCVwOamtrWVnpkydP4sMPP0x4XFNTE5KSkrB48WJs3boVf/3rXwHQtS6M\nMuYAUFNTg6NHj+LPf/4z23bgwAH2//7+frzxxhsA6BKlAFBYWHh5H8phSONYGA6DhoGMprdt24b1\n69cjJycHkyZNgt/vT3h+7LXOt+9Cxz366KN4+OGHsXPnTpSUlKCmpsZ2P4M9e/bgt7/9LQRBACEE\nTzzxBABg2rRp+M1vfoP6+npMnjwZW7duxXPPPYft27eju7sb0WgUBQUFeP755wEA6enpOHLkCFau\nXIlIJIIdO3ZAEISL9omDw0Bx0modHK4Q4XAYkiSB53m0tLRg5cqVePnll1FQUHDZ72VkSb333nuX\n/doODgaOheHgcIU4fvw4HnnkERBCoGkaNm3adEWUhYPDt4VjYTg4ODg4DAgn6O3g4ODgMCAcheHg\n4ODgMCAcheHg4ODgMCAcheHg4ODgMCAcheHg4ODgMCAcheHg4ODgMCD+P4xSKOOE0RxSAAAAAElF\nTkSuQmCC\n", "text/plain": [ - "" + "\u003cmatplotlib.figure.Figure at 0x7f97f1e98d90\u003e" ] }, "metadata": { "tags": [] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAe8AAAFnCAYAAACPasF4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzsvXe8XVWZ///e5dTba3pCQiAJCSWE\nIJGmoSSgjsg4gmCb4Tf+dCwURUdEQXGs41gYFQvDiIyIiKIIJIAgEBJCgJBKertpt59z76m7fv9Y\nu55zboiQBCL783rllXt2WXvttfden6et55Fs27aJECFChAgRIhw1kF/vDkSIECFChAgR/jZE5B0h\nQoQIESIcZYjIO0KECBEiRDjKEJF3hAgRIkSIcJQhIu8IESJEiBDhKENE3hEiRIgQIcJRhoi8I7yp\nMW3aND796U9Xbf/iF7/ItGnTQsfdcMMNoWOWL1/OBz/4QQB2797NCSec4O3btWsXH/vYx1iwYAEL\nFizgkksu4bHHHgPgpptuYuHChSxcuJCZM2fy9re/3fudy+VC19A0jfvvv/9vvq/Vq1dz1VVXHdSx\nDzzwAF/72tde9bVcvNbz3wi46667+P73v/96dyNChFeE+np3IEKE1xsbN24kl8tRX18PCBJas2ZN\n1XErVqxg/fr1IZIeCZ/97Gd597vfzW233QbAqlWr+PCHP8zDDz/MV77yFe+4+fPn8+1vf5vTTjut\nZjvr16/n/vvv55JLLvmb7umkk07i9ttvP6hjly5dyvnnn/+qr+XitZ7/RsAHPvCB17sLESIcFCLN\nO8KbHm95y1t49NFHvd9LlizhxBNPrDruuuuu4+tf//pBtblp0yZOPvlk7/fJJ5/M4sWLGT169EH3\nq6+vj09+8pO89NJLXHHFFYCwAPz0pz9lwYIFmKbJypUrufTSS1m4cCEXX3wxS5cuBYRV4IILLgDg\n1ltv5atf/Sqf+MQnOO+883jve99LT0+Pd53ly5czffr0qmu98MIL/OM//iMXXHAB73vf++jq6gKg\nu7ubD3/4w1x88cWcf/75fO9736vZ18p7ueqqq1i4cCHz58/njjvu8PatXbuWSy+9lAULFvCBD3zA\nu85I26dNm8b+/fu9893fy5cv5/LLL+fqq6/mM5/5DAD33nsvF110ERdeeCFXXnkle/bsAcC2bb7x\njW8wf/58FixYwC9+8QtvrL74xS8CsH///pD15MknnwTAMAy++MUvsmDBAi644AI++clPVllMIkQ4\n3IjIO8KbHhdddBF//vOfvd8PPvggCxcurHmcbdssWrToFds855xz+PSnP82dd97J1q1bARg1ahSS\nJB10v9rb27nuuus45ZRT+PWvf+1tt22bxYsXoygKX/7yl7nqqqtYtGgRH/3oR7nppptqtrVo0SJu\nuOEGHnvsMdra2rjvvvsA2Lp1Kx0dHYwbNy50rVwux8c//nGuu+46Hn30UT70oQ9x9dVXA/C///u/\nzJ07l4ceeogHHniArq4uLMuq2VcXP/nJTxg/fjyLFi3il7/8Jd/97nfZt28fIISiq6++msWLF3P+\n+edzyy23HHD7gbB+/Xouv/xyvvvd79Lf389Xv/pV7rjjDh555BEmTpzIj3/8YwD+9Kc/sXr1ahYv\nXsx9993HXXfdxerVq0Ntff7zn2f69OksXryYn/3sZ3zuc59jcHCQJUuWsHv3bhYtWsQjjzzC1KlT\nWbly5Sv2LUKEQ4mIvCO86XH66aezefNm+vv7KRaLrFy5knnz5tU89oYbbuA///M/KZfLB2zzO9/5\nDldeeSUPPPAA73znO5k/fz533333Ienv2972Nu/v+++/n4suugiAOXPmeNppJU477TTGjRuHJEnM\nmDHDI85ly5bVvNcXXniBUaNGceaZZwLwzne+k127drF3717a2tpYsmQJzz//PPF4nP/6r/+is7Pz\ngH2+8cYb+dKXvgTAhAkT6OjoYPfu3Wzfvp3BwUHOPfdcQJitb7311hG3vxKSyaR3P21tbbzwwgue\nteO0007zxuepp55iwYIFxGIx6uvreeihh0LWlkKhwPLly/nIRz4CwKRJk5gzZw5PPvkkra2tbN26\nlUcffZRiscg111zD2Wef/Yp9ixDhUCLyeUd400NRFC688EIefvhhWltbOeuss1DV2p/GzJkzmTt3\nLnfccQezZ88esc1EIsFVV13FVVddxdDQEIsWLeLrX/8648ePf80TfXNzs/f3Aw88wJ133kk+n8ey\nLEYqVdDQ0OD9rSgKpmkC8Mwzz3gEFcTQ0BBdXV0hC0Q8HmdgYICPfOQjWJbFV77yFXp6erjyyiv5\n1Kc+dcA+r1mzxtO2ZVmmt7cXy7IYHBwM9U1VVVRVHXH7K6Gpqcn72zRNfvjDH/L4449jmib5fJ7J\nkycDMDg4SGNjo3dsOp0OtTM8PIxt21x++eXetkKhwBlnnMFJJ53EjTfeyK9+9Ss+//nPM3/+fG66\n6aZQexEiHG5E5B0hAnDxxRfzve99j5aWlpo+2yCuvfZaLr30UsaPH19z/8DAAC+//LKntTY2NvK+\n972Pp59+mk2bNh0yLa27u5sbb7yRe++9lxkzZrBjxw4WLFhw0OcbhsGaNWtqCiGdnZ1MmTKF3//+\n9zXP/ehHP8pHP/pRtm/fzr/+678yZ86cA17r+uuv58Mf/jDvf//7kSTJG4OWlhYymQyWZSHLMrqu\n093dPeL28ePHI8uyJ3xks9kRr/nQQw/x+OOPc9ddd9Ha2spvf/tbHnjgAe+6g4OD3rF9fX0kk0nv\nd1tbG4qicN9991FXV1fVtrs6IJPJcMMNN3D77bdz7bXXHnAMIkQ4lIjM5hEiALNnz6anp4fNmzdz\n+umnH/DYzs5OrrzyyhHNuKVSiU9/+tM8/fTT3radO3eyatWqEaPKR4KqquRyuZoa9cDAAOl0milT\npmAYBvfccw8A+Xz+oNpevXo106ZNIx6PV13r5JNPpre3l1WrVgHQ1dXF9ddfj23bfPnLX+aZZ54B\nYOLEibS3tyNJ0gH72t/fz6xZs5AkiT/84Q8Ui0UKhQLHHHMMo0eP5pFHHgHgd7/7HV/+8pdH3A7Q\n0dHBhg0bALjvvvuQ5drTWH9/P+PGjaO1tZXBwUEefvhhb2zmz5/Pgw8+iKZpFAoFrrjiCjZt2hQa\n93PPPZff/OY3ABSLRb7whS+wb98+7rvvPn70ox8BwgoyZcqUgxrvCBEOJSLyjhABkCSJCy64gLe+\n9a0jkkEQ//Iv/4Ku6zX3jR07lp/85CdeVPiFF17Itddeyxe+8IVQBPrBYM6cOfT09HD22Wd72qaL\n6dOnc84557BgwQIuu+wy5s+fzymnnOKtPX8lLF26NOTvDl4rFovxwx/+kFtuuYWLLrqIT3ziEyxc\nuBBJkrj88sv53ve+50W4z549m3nz5h2wr1dffTWf+MQneNe73kWhUOCyyy7jS1/6El1dXfzgBz/g\ntttu48ILL+TPf/4zN998M5Ik1dwOwvJx88038+53v5tUKuUt8avEO9/5TjKZDBdccAGf+cxnuOaa\na9i/fz/f/OY3ufjiiznrrLO48MILec973sN73/teTj311ND5N998MytWrGDhwoW85z3vYcKECYwZ\nM4bzzjuPdevWceGFF3LRRRexZcsW/vmf//mgxjxChEMFKarnHSFChAgRIhxdiDTvCBEiRIgQ4ShD\nRN4RIkSIECHCUYaIvCNEiBAhQoSjDBF5R4gQIUKECEcZIvKOECFChAgRjjIcNUlaenuHD2l7LS1p\nBgcLh7TNNyOicXztiMbwtSMaw0ODaBxfOw71GHZ0NNTc/qbVvFVVeb278HeBaBxfO6IxfO2IxvDQ\nIBrH144jNYZvWvKOECFChAgRjlZE5B0hQoQIESIcZYjIO0KECBEiRDjKEJF3hAgRIkSIcJQhIu8I\nESJEiBDhKENE3hEiRIgQIcJRhoi8I0SIECFChKMMEXlHiBAhQoQIRxkOK3lv2rSJ888/n7vuuqtq\n39KlS3nve9/LZZddxo9+9KPD2Y0IESJEiBDh7wqHjbwLhQK33HIL8+bNq7n/a1/7Grfeeit33303\nzzzzDFu2bDlcXYkQIUKECBH+rnDYyDsej/Pzn/+czs7Oqn1dXV00NTUxZswYZFnm3HPPZdmyZYer\nKxEivGmhGxZL1+6jWDZe76542NuXZ822/te7G0cNXtjYy879wyxduw/Lsl/v7rxq9GWKrN8x8Hp3\nA4D9AwVWbekDoKyZPPdyN7Y98tjmSzovbOw54DFHGoetMImqqqhq7eZ7e3tpbW31fre2ttLV1XXA\n9lpa0oc8Z+xICd8j/G2IxvG143CN4d2PbOTXizdw3twc11x+6mG5xt+Kf/nm4wDc/+13oSiHTn/4\ne3wP9/Tm+NEf1ni/48k4F8075rBe83CNo/vcf3XzQpobEoflGn9rX+79+jv4+d0vsmzNPmRV4aK3\nTq55/I9/8SzPv9zNdVecytvnTHjF9o/Eu3jUVBU71JVuOjoaDnmlsjcjonF87TicY7hhu9BwN+wY\neMM9p737syTjh2YK+nt9D7dWaKobt/dz2tS2w3a9IzGOXXsz6K3pw3qNg0V3zzArN/YAsGnnAKcd\n117zuA3Oc3hh/X5mTWw+YJuHegzfUFXFOjs76evr8353d3fXNK9HiBDhtcE180lIr3NPqqEZ1uvd\nhTc8SroZ+m2aR/+YvZFcOJZtY5jiG1EPYAVqrheWgsHh8hHp18HgdSHv8ePHk8vl2L17N4Zh8MQT\nT3DmmWe+Hl2JEOHvGq6LTnrjcTdGRN6viHIFeRtHsc/bRb6kv95d8GBaticQqcrIH0mLY+bP5N44\n5H3YzOZr167lW9/6Fnv27EFVVRYvXsz8+fMZP348F1xwATfffDOf+cxnALj44ouZPLm2ryFChAiv\nHW9E8tYj8n5FaHp4jEzz6CfvQun11byDQWeWZeP+UuSRddn6VAyAzAE072x5iKZE4yHp48HgsJH3\nrFmz+NWvfjXi/rlz53LPPfccrstHiPCGwf6BAo3pOOmk+Nx6MkXSCdWbEGqhe6BAQzpGOukf0z1Y\noLk+QSJWHbiZzZUxLZvWxmRou+Wazd+A7H0kzOYDQyUUWaKp/rUHSFm2TVd3jgmj6pEliZ7BAk11\nCRLx8PMoayZ9QyXGtde9pusVSjq7e3OhbYPDJbJ5jaa6uLetN1MkGVdoSMcrm6BYNtiyJ8u49rqq\ndwOEANWXLTKmrbqvA0Ml4jGF/mzJu+dK2LZNV0+Ose11ntnZtm329OUZ116H5IxTXeBdz5cM9vbl\n6WxJeedYts2W3VniMZljRjfSkynSmI6FYiJ2dQ8zpq2OmFqbZGudUwslzbdmmJZV9bdl2WzenSGV\nUEknVOJxxfuOhgo6fZkidakYqYR/naV7V/B/G+7lIye8n4s7zjng9Q8VjpqAtQgRjkaUNZMbfvYs\njXVxvv+pswD499uWIQG3//v8mufohsXNd6xg9nHtfPQfZgLQny1x48+X8455k7jk7ClV51z7388A\n8D8VbbpKhvw6cLdpmWzL7mBq85SawsOR0Lw/++OlQPW4vBosfm4X9z6xlcvPO45Tj2vn33/6LBOm\nZ7A7N3LdnH+jOdEEwDf/70V2dg/z7Y/No7059aqvd/MdK+jLlkLbNuzKcO2tS0L3c8Mf7sYup/nF\nxy6vauOexzezZNdKmuoVvnvlZVX7n1i5h3v+spmvXHU64zvqve2mZXljB3DlBcdz3pzxVeev2TbA\n9+9dxZknjuaqd5wAwOMv7uH/Ht3E+88/jtOmdfLvP32W9iZfcFi5uZdfLd7I208dxwcvnAbAqi19\n3HqfiKq//v2z+c7dK5k+sZnPXSFWSGzcNci3fr2SudM7+fgls6r6kc2V+ffbljF1XBM3fHBOjdH0\nEdT8g0vvXFJ/YVMvP7l/rbddSuZomLUSuXkaVqaTz922jHHtddzy/73FO+avu5cAsKJ7JRefeGTI\nO0qPGiHCYYRmiAlhKK8BYDj+tQMZP4tlg7Juhvxre/pymJb9N/vcfBPhkWfvezf/ie+v/CkrulfW\n3K8bZs3tbyT8cevDfGHJLWimxsrNIsh21ZY+9vYXIFair/FZ+kuDdA3v8c7Z2S0ijQcOMrhJt2qb\nkSuJuxZ2De0hPmkDieNfrLm/qzdH4riXKI15ofY1MkVsoKsnrOFXmrbXjrAuf9veLADPrNnvbXtx\nUy8AK17uYSivIaWHyE+7H7mpx2lLRG4/8aI/Zv2Be3VzAGzYlfHb3LsRKV5gxYaemv3oHiwCsGVP\ntuZ+0zIxLfG+BX3uZoC8NSe+oC9bDJ0rN/Wjy3kxxoo4d09fPnRM2RDPOqkcuSVwEXlHiHAYURlf\ndDDapghSssnGtzGsiUm1NyMmt6DPc1XvWnoLB0524pL366F5P71HJF7al+/2tlkBf+PR4PN+ZOcT\nDGnD9BbD45zJlVEa/WVcRaOaaJUDBEC52DW8m2v+egNL9jxb+wDJIjZlNXJzd2iz+1yf6FpywPZ7\nC/6qHsuuHm83IK43EyasSvIeyVRdy6Li3rdhWpiWTWzsVtHGpA0j9jN4vf4KoaW/OMCSwu+Jz3hu\nxPMrz6nEf734E7723HeBcLR7Tisi1WWIH/8Cw8ZwVV8AJFXz/pbragsHZVMck1CqXReHCxF5R4hw\nGFG5tEc/iKU+Zd1Ead9LpvU5/mfdrwF/cnU1hf35Hn625k7+47n/8jQGoCoDl6d315hkNw5sYU3f\n+oO+l1eLxri/TjVI2G/0pWLuhAygW+EI6d5MESnuE0ZBD5MfgHIQEtOjO/8KwIPbH625X2ndj9q+\nl8TxYeuFu7xpW3YnALZR7QEtaQZFxSfvWn0cibzzAQKT6rIMpNbXJH+5xj2qTuCXYdqifcVpq0Yf\na12vp6Ivq3qFCVtOjEzQwf6XtDD5dg3vYcfQLnoKfWim7l9L0fn+y98mOfNZlOZetiUfreoLgBQr\nB/7WqIWyeeSj0CPyjvC64I2UZvBwwqwgU10/OPKWG4RWtze3D/AnJ3epUG9RTMq6pYcmm0rh4EBW\n8x++9DNuW/2/r9ifkfBiz2q+8dz3KRrVpBCc6IPEFyTv16p5W7bFrSt/7hFg9X4bpW0vcsv+mvtf\nCbuGdnt/l4zw5NybKSLFAuRtVCeRMi2bIW2Y32/+MzktX7U/eI2JDeOq+g4gN4nnbNvhB1jWTWzb\nJlN2TMuKUUWufZkScr2vKeZr9NGNZu/LhImx4JiWpbosyZnL2Bd/gd3De6vOryWfuEuuTMuirJtI\nqmjLNqsDNF3BsxAwZe9zTNKphAgEfMkhb9saWRjakFvtPefKe1m+z3cZ5PScZzavJGJNyZLT86G+\nACEhDTV8zu83/5lfrL0LzXnHa1lgDhci8n6TwPUvWrZdMYGaNY97pW2vBat61/LJJz7PtuyOmvst\n2+Kl3rUU9dIBr23VEAAOVV9/uf433LT0m6/6fLcfQfI2TCtEriMJMJpmIsUFIbYmW4Cg2VycP1jy\n/YHByaaSED2zeeC3bdvkdJ9MKv3ouiGIwbKtmtqWi9vX3sXu3F5W9673zgHxXILm/JAGG+hfppzh\nzvX30F3oxbQsBofL2LZd9QxFX6rHqrfYz4bBzdy/9aGq4yzbxjBMYsesIzaxtrm2ss18SQ+ZTHcN\n++RdOSn3ZkpIcX/cCjUEGNO0+c7z/81fup7imb3Lvevphskftz7MZ578Mn0lIaTJkh+xXiwb7O4f\nRG7sQ2l0yLscDnzTdJMhLYdhi/5KEhR1v49lzWT7/iGkhE/Y+RoCxLDdR+yYtfQMD4W254o6iZlL\nSc70a04MlAZDxzyzdzm7zZeRm3tInb6IPbl92LbtRZAXk7tZMfCMr3lXQB2zleuXfImCXgwJoK5F\nJp2IYds2e/Ou8CXh2pJ0w2JgqESuqDNYyrIz/gyJ414C/JgDwzK4e+PveaFntX9fWt5/xnJ1v7Kl\noWqzeby25m1ZNn/peoqVgfaPJHlH0eZvAnQPFvjCT5/l4jMm8fLOQbbvG+J//n0+Dy7bwX1PbuPm\nf57LxFEN/PWlPdy5aCPXv382MyYJ0vjNXzbzyIouvvWxeXS8hsjZIO7fIibbv3Y9w5SmY6r2L937\nHHdv/D2N+jF0r5zOj649J7QsA2BTV4Zv/t+LfPySWcydLrLzPfp8F3c/tpkbPjgHPdVNS6KZ0XWv\nLnPfc/tFAJBhGajy3/aZvLxzkO/cvZIPLZzGceP9VIr5khEycRumTUyt1iZKuomUEGSwfVeZNW39\nXhCNaxbvCfgyQ5p3gBwf2LaY4eQQ0IYkSZR1k49/90nOOGEU557lB9Zc96On+MAFM5h/6niG8hrX\n3LqEc04eiz3xRV7u38R/nHUjsQOMgaZb/P//+SRnnTiGf3nHDD73k6VkpN0kRCBxyKToE7PFnwZv\nB6A50cSG5Z1s2JUhEVcoayafuewUZk5u5cWe1WzdrPDw091V0dvd+XDw0n/d8xLb9g3xb+85ke/+\n5iWuuGgikmKCbGFaJorsE+SqLX384Herue59JzNrShvb9w1xyy+fB+CbH5tHZ3OKVV27vON/8dBq\nxqnT/WsnXkJp9f3QtUzSmfKgR3j3P72de34Dn79iNt/69UpSpz8ROvalbfvZ2TlMc32cz922DOnY\n5SSm+wKQVKHxLVu3n98/v5LkTH9btpynLp6mpBlc/+Ol5EsGiVk+mQxr1Zp3X/ol1NR+8oqBbvhR\n0kPlAnJdmNAHy2F/76833AdAfLLQqH+y5EFSPafQ0SKeUXncc7yUAzlZ+x5iEzZj2LBpcAu5cnXf\nDNPi9kWrKDrmckm2QBZC4pduX06PE6TWcEwXeJ+5ze0PvsykUQ3stzZXxRIM6znyJdFfqYZQ8Z3f\nPUeb7FtB5MZ+5PQw2BJIdugeilr1+bWEuMOFSPN+E2CjE7X50LM72b5PfJCWbXPfk9sAPzr0waXC\nf7Z07T7v3EdWiIIxm7p8Te+1wtXmRlp77EbuZhFmuv6hamn2ryvFMfc+4ZeSve+vIjBm+cbd/PdL\nv+CW5f/5mvsa1BoPFktWi34/tGxnyOddKOkhzbsye5aLkmYguf492eLhZ3d6y1hcTb7HMZvHlXhI\nU3DJ0bAMFu34CwPNKwDxvIediPdn13fTlfMjfVEML3p2l6O1PLVqN893v0TeKJAp1Q7ScbEvI/Yv\nWbMPy7YZGCqHTI1lI0jezrNP+pO1bulelHDZuc+la/ezYWAzt6+9i8cHBUm8vCus+e2vIO91OwYp\nlk1+/dgG1NHbWbRGmFslSZivg1i0XBDzn5buAMS6eq/d/gLL973Attxmb5tml0NLBOzOTeJ/UwgE\ntczmPaVe729TEtaRxc+NUIBJ0dm0O0NXTw7dsFCawgFykmqA5L87f1ixKqQVA2RLeedehCY7arRN\nLOW/vzm9uo+u8UFp2093xo84HypWa+lBzTtoNZJi4t4yWYOd3TnvGVYimbb454unM2/mKN4+2yfI\nn6/9FXvG/L7KvVHWTZZt3hHaJsU0hgs6PYNFL9+BlvYtJA2NYoy27xtClqvzIeS0PANDzvtYg7yF\nZu5bshLTxfcjmXHn+v67nCtWZ4orRWbzCIcSlZGiUrzAQCHrmbfcCdUNsKn000LtwJRXCzenkSzV\nfv08U63j56t15ZST8CS47MM1t9nqa5N+g6biVxOIogdyJQfHslAyQj5vbQTyHtZySJK7QNsMCS8e\neRd8Yqg1BtlymKzcyF8Xe3P+RCkpJprmulWcyzb4wlqmXE3ewTEKEpebgSpE3gEBSDMsUDXkQKT2\nYDHnBWC5SMQVT4iT04JUKpPT7A1EsQ+X/D70JlcTm7iR4lh/nfJAKSx8uglz+lIr+dnqX4aEqoHC\nEHe+fA92zH+PJMW3mtiYge2Oz7aW5q0NBI4TRDE40lI/xaA3U/RiG1yhIISA1hcb5QsBVkEEBA47\nhNubKSI39DM0cTGmFCDvcjUhm5Lfn64B35ozXK6+n6DmXWt5m6aJ96w0AnnbisbZJ43lX981k7NO\nGlO1X2kJC2NlKUd8imOSdueCWJk9TuKaGcc0ITf1hPz6l54nhILebBHd9L8LN2ZgWM95Yyyp1fcg\nqToZR8gNCktSsQXbkpDrhjyXVqYQHs/GeEOkeUc4tIhVJNxPnvIUNy3/ukfq7oTvEnStmsG1siu9\nWnjBOCO8fqZDDF6QTsW1C3qRVervUNp3UyxXTxSGUjs46GAR/AC1V6F5uzm7K8k7XzIOSvMeChCv\npBj0Z/0J1jQtbNv2JlLN1MgXq33Kg+UwWelGONZhMKhNK4bXF89H3uATT7YGebtL2AAKZoA43Ykx\noKGEzeYWyROXED/Gj3LPFKsrMCVisqfpuYFKlQUt9hd88t7W7U/87uQaxEDRHw/TMulveB65foBy\n0xZW9a1DM/yJfqhUYwJWDIYdTasQ8C3bloRqJ8g774zhPV+bQS2gPTvrg0dapy8pOn2ZkhfbYBuB\n4C6XuFS/j5Ll77cKIrmKaxbvzRRDhKZaooJXvkLz3jS4FSvhH7c34z/zXMDEbmbbkGyZTDDOooal\nwRVkhgvV38yY5Hh0S/e+p2DSFu/8eFhrjU3Y4AluFB33k6qxs1tsax6TJTFNuLdkW4yHkhTj25sp\nUTQDz0lLOPeV9yPTlWrNWVI133IQ0MylvTPBjCHFyyROehqAgYL/DZzVeiGtyRaKRumIBeNG5P1m\nQ0CajKsVmrdyZDXvkczmXiCRM2lVLrfaNbybItmQ9haEJudqbg9ix9CuUDRxELkAMb0as7k7gcdU\nudpsHiDQkTSUIT3Qf8UMBVaZlk1eL2AENJ+hgJbktl+pLZtWOFguUwoICLJBWQ8njwlOpBkt7PuE\nsHBQNH1hyU0sEgzyqQxYq4zyHa6hESZiCrudSHsMYbKsDCQKmnF39PmWCKxqrTWoea8f2EivuoHE\nCc+BLO47GImdr0HekmJ4VoVcYLy1DXNRiFN0iNHVzmPHrmJDYZV/vqPlDeVqvE+2BIpBT7ZAr5sg\nJBCZHdNF/Elw3EzE3+ZQC9aQKBE6rPmad5D8k5YgviB59xT6+MHKn3r3D9CdC0SmOwKKOdSKtuUU\nVCsVGsOagVmOcDGU10LzjLZsY2voAAAgAElEQVRtFi0J0Qc3ULJWamC5gryDEfbGUKM3Bjv2i/dR\nSfnHj5VEPIIuFVBkib5MMWzCdsZzqJwj4zyDWj5vggKSs9/oHYdeSHrjLzlj1p0V42V0T6TdmEZK\nTWLaZkjjP5yIyPvvEAOlQTQzaEoNkETg5Yx55C32u2ZzzSpXSemHMsmHa3IdSZu3bLe/Yn/l8ifX\nZOx+XJU+tqLtE9NI0dI/X/Mrfvly7dz6w4Go3FdjNvfIW5FCVox8yXCehY0yagd7ctVLbwAKhk/e\nUkVErGnZVcQ8FCC/2uQttO6gmX446ANWzCrNO6g51zKbBzX3vOFf39e8S9imgm0qlMxqn3cQtZaa\nxWOyPz6KDtgVS+L00Du6dzAgyFnV01qwv7XKoxYD1oOc5vfX1WpRDE+wcTXvuvxUrFwrshX3rDWu\nEKS2+W4J28bT8mwIERtA0m5Ckm36snl6B4uOUO2/N0nTIe9gwJfzHWtbT8Z2hJu8JvrQmy2FiClh\nC7N60KKU06sF3IGCP0Yll7wHRoMZQzHTDGnDXpayWm4C1zIwVNA9rdUcGIXZN576mMidviWzXRxb\n49sXAl8wsMA/xsoJ8pcSBU/zjiXEOOp7jmVS8ngAslqWtqYkvZli2IK2XUT2DRQDgmhgjLRts5x7\nEGOcSqj+flMV30dIKLTodSL0bSNGb6ZIWhWBevkaY3M4EEWbH4XY1T3ML/68nk9eeiKdLeGi9nm9\nwJeWfoPx9WP5wunX8H+PbOIvL/oaZnACiFVq3g5Db2m+l889bdO89VLv2B/9YS3nzxnPFRccf1B9\nfOCZ7by8c5Dr3z/b+1DveXwz2ZxGLqWBAiDxg3tXsdkpSPCpfzyJyWMasQhr3pWlI7tdf6/zcXUP\nFkKBa3nTn4SeXtPFI8v3ceOHTvMi1otGkUw5S4MzET350h4ef3EPLQ0Jpk9soXNyteb98PKdrHi5\nhxs/dNorWiFcYUOpMJs//uJu9vUXkFI54pM2cHfXBo4ZfQ0dHdNC5xfMvC9WKz7hqopMMbGXb6y4\nN3R8vkLzfu7lbh7ZvAncVNWySV+2xH/+5iVQyySmr6Bo+WSlxk1PAHIzuAXJ+4k1W2nMdHHh3An8\n/qmtDA6VGTfL13SDxPenZ3aI8+MlbC2JpOrsGxxC003iMSUsSOL4CZ3+j++oZ3dvDqV1L08Ul1G2\nXQ1JRBkPFzRu+eXzDAyVuPyiseLWJBnLtugeGgScSHTJH3NbjyPFNJ7esI1ZiX5mTWmrGVQUFCCW\nb9hLYgZItoK2+VSSJz8VIkPN0kgACScVpmzF0S2DsqlVuUJsSwJLDWt5FRpfPpNAaRVBcXv6ZEa1\npukPHJM0WxlmK/Gpqyi91Iytpfz2jJj4B+zoHeCWX66gN1MiMcrC/WpiOLWoy4M8u24/v35sM23j\nstDqdlJEUvfkMlzzvb/S2ZRkz2CWeDNgim9GNtPY2GTKQ7SlWnjmZT8S37+voNbqru0W53em2wG4\nc/09nNwxq3a8i2yKNtzgMMdaUVpzJraewDZU1NE76Fk3FmjwrmFmOmlKNEFJBLt2NI9l3fYBVm7d\nBzEorT4Lu1SHjEJ/UVhrmuvj5J0xHNf/DrYMlGHKWi/4rqkuTlmvuIf++fR0OMl0VIP+vAZpMUZ9\nmRId44UrIK8XSODniT9ciDTvoxD//fs17O7Nc/+S7VX7smUhDe52tJYgcYP/UUE1eXuk5Ex++/rD\n2vdjL4i2Hti6iJuWfjNkuq3EH57ezoZdmdBktvi5Lp5d3+1V7zEsg1Vb+ymUDTI5jXXbhfZkOaTq\nkXeFGd9dJuVOYNv2DbFuh29CzZm+VnnnY2vZ11/w2ga8NchlS5DDLxdtpKsnx+qt/fz2iS0hs7mr\nzdz7xFZ27B9mYFhM/I/tepIvLf1GTbO6YYj+xlQ51Hd3PIPEuLXGWveiJTRZ24gJE51kkUooJOMK\nQx3LveMkXZBVIUBGmmFy2x/XMRQ0dct+pLrascf3IzqIxy3vOXmacUwTpk5bwlSKvLBR+JT/vHQn\nz6zd7yWPUWWVklXh/5QNpJiOrSWxTRXd0ti0W5hcS3p4vNpTbRhogM2oVnE/8amrKdhhbV+KldnX\nX2D7viGyeY2N+4VmO8FJbtJfEM9/zrQOkulAycfhZmxLwlIK/OB3Ivgp6Au1ikIjdCO1g+M1Sj8Z\nu5wS5tsa5OvmsVY1oRWu69/gCUFuxrPyunnYhhr2V1eQt2vilhQD07LpaEqScuSQs8fOo8mY6B3b\n0C76Kak6tiWDrXiad1e+i+37hsgVdU8rbUk0M8Y6EXO4ha58F89u3k6uqLN70DeB1xtCECrbBbbu\nzrJsXbfXx6ljBOm675rrLtm63/fne/0P3CMO8br7zh53BhMbxmNjM1QerhKgZEMoIak5j3uBeZKq\nYVsSdrEejDj67uORZBu5LksqoVK2xHc0a2In582ayrFNk9kwuJlp08XzH8yL91wkh5EYk5jAgN6L\nlMgzujXtPceGRJpPvWc2tiV5yk1bU9IXnB3yTlvtTFBO8PqWLfrfaaGkM7vzJE5sn0FnXTtHAhF5\nH4UYcgJC6pPVfiP7gCUvCJnN3UxIru9VqTRlSbVNzot2Pk5faaAqorkWatbudYSDsiHuo9NZu+ul\nAK2INh9Z8xb34i6BclG2AxOxc7/BW3Ozk2mmVm1Wl02yAZPycCk8ybhc/IctDzJQGqyZdcowLZAN\nCsmuqr7Hj3+exPTnvd+1AuLKtiBDu5T2+pSIKSiyhGwEAn1KQrovB8hINyykZA65MbBGOKC9h9Jo\n6qItJRYgb9MCbKRYmeZEI7aeQIqXqtJn7sntJ6HEmdQwAc0uhd6Vjsni2tZwC5gKyCb5ongPKrN8\n1cXS4n1QjFCZSxcJSbwbUkwjmw8kRTHFxDyrbTqqpNDPDuIxiX+7ZBbHT/K1HqtUJywA8ZJnBXH9\ntbGhiRyfOA2A/ny1sCNZCiCBWaE5O/uTagJFlogPC3J9dt/zvrAqmyT0Nuxio/C31iB/M9sqtErX\nv+1sb29OolllpjQdw+XT30NCrqO8+RQAzjhFVC5D1T2N2y6lMTPtKI0DyM3i25Ad8rzm1I+RUJKY\nvULI2Ws6Firn2zH6RzPJnOe0GXgXnb5ceubxpBMqaOJdcZMDlZx3Ttt6EqUX345VrAuR97hRTh4B\nh/iS8RjHNYtqeHkjXxWVPaV9lD+8DQNCaFUMMGOeddF2+iCpGsm44rXxrxefQjKhcsGkc0UDDb1M\n7KzHkp3+OO/8xLiwcClt+xjVmvb6m5QTzD6uQ5i9nW1j2tJV1gNVkZkxfpTTB92L1bDNGLppM731\nOD520j8TV0Yu9XsoEZH3UQg3pWFDuka6wVcIsAp+YO5yD9eXqCgyECAb+cDZygyrOjBDMzVRLEEO\ntx3KmuWQd8kh77FO3WOXICqXihkBn7duGV6gkrf8JlS9yaZsB5f4iD4G/es9gexfwSUvUqJA6rRH\nWbTjL962XLmCvKtyh9fI8mZaxI9byZ66p9haeDm0T2nuC/2u5VPXKGDbYDlZtSTFEOStSEiaX3fZ\ndLRGzfKfeX+5j8TMZUiq4S83CvrNAwFKdllMikrM94drugmqjiTbNMQbsEsppHiJTD6Q7U6y6Cn2\nMrZuNC1JQSZBa4LWuB3bkjB6JmBbigjGGhSknXeimG1TYczwOYK8Ee9lMmVWvXMtsrOkKFYmGxDS\nipYg77H1Yzix/QSM2BAtHWUkSfKIxb1HW0tCrOwJGG4msobisXTUif5nQwF8jvbs+DhtUw2Rr/ve\nJeQE8ZiCVaynM93OjuwuQd6ShSTbWIZ/vhCgrND5Vq4Fu9jgkYvSIN7rlqYYNjZJ1dHsFckjLtcq\nIyl6IChNwth/DBAonOFcI6UmUWQJa0jYyPP0e+MNYPaOp0FtAFsKPUM1bnrnx2Iylkve5QyWbVGS\nRTu2HgdkQdJObAKAGjP8sUMQn/usl+xZ7uVkd3FSu59tJnHcS6RmPYuk6khmjNaGROBaQEwnHlMo\nODEPKUX0rTMlNN5sOYuqytiyLtwWtqC5MbFjkWwFdew28vWbkRQD25KIqWIcG2PNSIkCclOvqG8e\n8Hm7z8H1a8tNfdhjnRUThloVVHskEJH3UYxaUeFBM26tJQth8hbHFsoOwcmSZ+6CEaIxAyjVIJ4l\ne5fzu81/Iu6UKHQTHlQm+we8oLr6VIzm+rgXqeznwq4OWOst9PmEqRiA7ZVelOoyJE56GjsogDj3\nIAX81K7mDWHylOurE9G4ZnMXlZp0Lf+pYfpJNoaMAye3qWV216WiiLB2J+eA5m1LgQxteYe8bf8e\n9pZ3ICkmetfxGN2TgLDmHXz+linGRFZNz/pSMjSSJz0FQL1aj1VOIUli+ZUXSZ7MY9kWY+pGe+lb\ng9HphprHLtWBkRBaqwQ9Q4Jsi6azpGr/MSQK47wJXU6UeEL/XxLTnwtZB1plx7edyntCK0DRsa40\nJxqZ2ijiMFIteeceAuRddDRvyRcw3ICidCzFqAZB3iUr8Jwd8rZ0550xYk6kcfC9EwlyknGFsm7S\nnmwjbxTIlYre+YYhuwMi/ne/rQpSsDVBCLGJGyFWoqlRXNclJUWWQBcEVrByoh+q7hEj+OlTvWVy\nAdO+KsvYWgoZGSvmm91BmHx1A5JyGjk9jNK2l8TMZ5AdQSKlJokpMmbJ1byz/HHrwxjNwuftWg2E\ni8dGaXfW5scCPnkH7rNetm8Fd738W4KYN2Yu7516iX8/ySFQdFRJCEiiLdcXrpGIyRSMIkkl4WXO\na3LqqWfKQyKHhaO5e5kiTJVYYTSSbLFOe1ospzNVYoo4/8KxFyFJoI7ewdi2tDf/ueMcU2XqnMC7\n2Litfl/N2EEVHDrUiMj7KEN/IUPylCeQW/bXXCccJCOj1gtVg7xdYrVtO+QTr6V5u/5qqC7WAHjL\nJJTGAZANr22fvG1PA3LN5om4Qkdziv6hEoZp+ZHyjoYeLIPZEyBeSRZtuZp3/NhVyE7mLjcgxtMw\nAm0EyTtoqQhOhi4KevgeNcMK+fprJWUIErxsB9s8sLAFIg7AUHJYpTS25ZyrGCTiCrIsYztadHnT\nqZ42plt+H12t08o3CpO1c76LEHk7yT0kxcS0bAzTIqsNeoFCti152rmUKLK/wmffnGyixZkw3XSu\nSBaWpHt+WDdCtzebc8bL0byNGLph0eFoS7HxIpuZXJ8N9bfDnoptg9wUWAoGDNvid0uiBVsT10qk\nxHlFowSWjLZtltBuXXOrI2C4qTjr4knGNAvhQx29E6VjlzceAGXNyXtQSiPJlne+uz+pCGIp6xYt\nSeH3HigP+uTtnO8SnEsGleZYs38Mck6YY+Vkgfp6cZ6vecvYegJsWJ9ZizJqJ5IEiuW7GWwthW0H\nnoOsE1fiKLLiLAGVSMuNSIkCx09o8oPLzBjDeY3ZTWcgqQbxY1cj1w1jJ7NOH5IidqMk+rJjaCeP\n7XrSfxCGK4CIMY5PWRsao+A35Uac10JKTXJK58zQNkm2ScpJLzmPa2lQO/YwNGoJBb1ISvXT5SbV\nBEklSaacJaZIQrM2VRJxcb5hWpT2jQ9dwzZjnvvw2JaJIj4hVqalIeG9h+51Fdm3HoTa0OO159rD\njIi8jzI8vnOZSBRw3Es10xAGyeBHf1hbtT84eWu2+LusmTy8fKfIchUMOlGq28+X/PaD5kkXwQpS\nUsqv4OOlHJRsQbrgVeJJxAR52zbc9sd1dGcdE6ZD8oWSwc/+tI7t+4ZYvlVIvHaAmAbcDGSBicI1\nobkfYFk3uOuRDdz57OPszPjpX3/1WKAkZg0ff7DYA8AdD7/MMxt2+PsDWt7zG3q4/+ltXoY1gBUv\nB7JG1bBkDBULWJbN/zz0Mt/9zUrW7u0CyRZBOs49SorQvFVZwpZ1UnIdVqbTm1SCfmQ3iMc2Yz75\nB4Uwl0D2noDpaeZim6abFAMa6JT6qb5Glyiw28ls5b5DxbzM48sdM6yreTt+U9fE6fZxy0AXS9fu\n8zVcI0ZXT47lS2JYxXqkej/gUJIgZTfz2TmfADOOlWsRVhGnbbmhnyFpPzNaj6cp0YBWEtey4jl+\nt/lPDJYzKHoDZt94QAqR992PbWbzfiG8NSTqGN/a4l03Ptl5FxzXgusxsYvChy4lnVgKR4BKqkkS\nMZmybnoWiKyWDWhsYdLxfMoBUhjVmgYkFIe8pXiJtMNHSVX0W5Yc06/zPONOXeykEqg1YMvCwuCQ\ntyXrYc0dyA6oSDGd3WN+i5zyg62GCjpnj52HbVbTQVJJEFcV8jlQibNreE9ovyuY6Hum+hslv5JY\ncL16LeKb0Xo8lxx7MZIk0ZRoqNqfjqX9zHqm6iXsKSf3M1jOkI6F6y00JxrZm9/PQMOLYi4zVS+W\n4p7Ht1AeaGFy9p3e8dZwi5dpsi6pYhtxJFVHVWTf8uhp3lLoHqxcE+WNp4KexDBtNnVl+PkD6w95\nIaeREJH3UQYvwtGSR9C8fXJdvbWvan9wOYcZINpHV3RhmFY4KrZm1R0/GKyW5h3URCXVr9Lk19AN\nrNX1yFtm1mThk3txU6+fwcohnadW7eXZ9d3c8svneXHnDsDRLBHLSTzLgeFrIp3pDq8PAC9s7OXJ\nrmdZXlgUShm5bmcgM1egb66/2fXLu9jTm+dXj/vJN4JLjH58/1r+9MyOUCYwSxJ/nzdnfDga10Hf\ncJ6+TJElq/exbscgK7YJ4cQq1vtai2z4ZnPZICY596m7+Zb9PnqWF1MNkH+15l3YO9Zr31svr1ue\nyVnfNY2J6SkhzXt3r/PsHRLq2qvRtdsJ7nK1UvcenWdh9ApNRx23lb+8sMeL5G9K1lPWTdZvz2L2\nj64al3pzFJObJqEbFtZwC5IEckoIdW5ynrdPOFvcS05MY7uNjSLeApDxScMn7yKPPt/FcKmIbcP0\n8e00pfzJ2DYVLjhtgvfeubGK9YrQqpV0nqb6uDdeKVVohZpmehaIgVLGF5Zcn7kTeCincsSOXYXa\nIVZttNbVMXl0g9NH8b6NHS15edBd8vVQIfx1NjaGfttaCjlRIjZlFSYaKZf8HfK2HdO7jS0sHDbU\nxVNccf5xjG2vJ2ZWL29SZMVZlSJR6vOjqMubT0HvOs57zu3pZibEnWWkqk5RFXPP5I4OLxVqXQ3N\n+x+mLOSCSW8T/ayxfGx0U6OnOYPvv3aRVivJ2zGdpzYiyRa2odJQkRBmzqQp3t9mpsPLQJlMqCTk\nJHJcFwKPa4Gq4bcHMHomYGU7aUzHMEyLp1ftZdm6/fRnj0x+84i8jzJ4ZGHEvIQQQYQCoCo0ye99\n8kzGjvJfZMM2mDymkUmjGiiUDQzDqliPWi0cZEv+MqNaPu9g6khJMTzy9uoDy7XIW+GMmaOZM80h\nXOe6biajYPUeOT0szLmOyTc4oU0d3eH9PaqCvLsHi6GsX36DZu2/HXPgLv1lnt+/MnxOILCndi7j\ngHnc6d+0Cc186rJpVUdqZjkkhA3oTgnIYr0n8UuqCNBRFAkUHQXXzxj0AYoJzrUE2IbqTToN9YHP\nXNHF0idLEe3bEpYsyLikGb7mbsQo66YnxMjJgle8xBUWTC0WIka3L+75AP/90X9gUuMElPpBhu0e\n9sbEWH78nbO5/v2zAbDyTVXjIluOS8AwPdLxVg441291TNXZrF1V79pSfWuEbz1wBQwDhRjzZo5B\nkiTU3aeKZUKKiT12LRPGiOuVyqKm9LX/cBYA8+c1M3/2uEAwWIJETMEGGmMueQ/6AW+u5u1o7kpr\nN2rbPuQ6IYTc8E9v83IPWI5PefrxKe+7cjVvL6eMHo7GnzVugvf3tz8+jwmdghzV9n0YUtkjb1fz\ndp+Vi3Qsxa1Xn8Ox45qIqTIzx4r2bEOl05jBW8ecDvhLSs0+EX9QrzZgDY5G6fdzPnz742+lNS2+\nydiY7QzYuzm+ZSpffN+5/MvFM4BqzbshVs/ExrAZ28qFBZJxLS2hnPZSxZyUrqHNB2FrqVBFwpnH\ntPD2U8czPjlZXG+oDdW5P1mSOG5MBzYW31/zQz/4L0DeDTFfwDH7x5KMKzTVJzBMS9R4l6Ct6dBU\nX3wlROR9lMElC3dyrUTIh1rxosdUORSJbdg6MVUmnVTRdKeggHJgzXsoQN7lGpp3KFtWgLx9zdvv\nk0fejmTtfaTudR3ydgPzpGQOuW4IK9vmE1egv2ogAdKYOmfpiUMm/dlSyKzuJVEIEHZQcLED2ZTu\nWH936B6DwVmVZnWlfTdKu798zG1TUYTJOwjbktCscMrUYVMEuNmlOo+0pFiZZFxBloVA4+ZxxlKw\nLRkppnnpJstWwIXg3INhB56pE8ErGEEiIdWhSeKZarqFZrmFMWLifdATyCjIyYDP2yFRvaSCGcM2\nFc9c6xKrazaPqTJtyRaQID9KVGhS+o9lcvNEjxRcK0pojB0TsW5YHmkpDYNI8aInILgTaX8mXPEL\nwFQC5K255F1EqsuKwCzbJ8J0cRJG9zEAPLN/GT3SJm98VUVmVLodWZJZ2buGPnmrFxOQiiW9dzet\nOMVB9Kz/jjvjbzlL+jwyQBBZc6LJS0lslcWzzpQy3mqKSh+xvP2tnNJxovd7Zsdx3t8xVWFWy6zQ\n8UmPvMU4G/smc0LsLG99eqXw3ZRwnoNkc4w1jytnvFcc5xatGWrj7JaFvHvM+wHoqCCphrgjPIze\nSVxOcOnUd4b2B8n7golv4wunX0MlyhtOp7R2nve7M91BIjYyTbkCigvTDs95drE+RN5u8NtFoy6l\n+OLbwYyhBoJZ61RxDz3FXuRkwUmyI85RFZn6eB0fPuFyJmbeAbZMQzqGqsjohk1vtkRrQ7KqENTh\nQkTebzDolsEL3atGXPJVMv3JtbbPO1A4voJ8Y6rirSEGQLZEBKVTYSlb0MKm3Rqad1+gwENNzTto\nNlcM8k4ke9Eh76CJ17CdJTfOB5WIK0h1Wc8n7loO3OVZSpvwVZt943yTslJtogY/o5N7P2U9LJgk\nqHP6GPQHB9ZDl/2JqXISDd5DZWrP+JS1xKesCbTpkLcsoVNhTrNUdFsLZR3TLFdzjnvFFKR4mURM\n8QOeLH+JkK3HQfXJW7c1Z3mM4pnNTcIJQkLEJTWiSQWQRIpUL3LdUB3BSyItNSKlh4jPWYSUzHkC\nUbkozKlWoQEplUNp24OccM+Pe/fdFHdIIZHHzHQwpnwasiT7BXMMv7b4jBZhnVBNJ5LesDxBQB29\nk+QpTwrLhy15/s7eTMl7Z+aOOhWAMaXT/HE2Y9iGitLc65fRDFilEjEFK+dr/xa+5qwqMnElzsJj\nzmNYy/F88REUZy11OuYHU6WoR5UUitKQJxDalkIqoYAR9zK9uVAlX5sDMHSFhBJnsJxlq5NCdErT\nJIKQyg28a8qF3u+JjX5lrrgqc+boed56cPDJ0hUQsFSmpWbzDqeNyY1+8hcARXIEViksCfnr6yVa\n9anETfE8O5rDxOmSN8C5o89hQsPY0H41UBP+H45d6AsLQVgqdqGJy6dcwTWzP8bcUbOrqskFMXfU\n7NDv9x1/ifftg1jn71aQA0g6wlZSjXvvnRog2/p4WJMXFj4xfm5g2+mjT0XVxfuSTsaIKRKGKQJn\nK8fkcCIi7zcYFu94nP9Z93/cv/Xhmvu9IDFbqql5B3OaV5KvLNuUrTC5xlWZdDIG2BhNO/xoVaiK\nNpcb+1jc/cfqvgQQIrMKn7c6ertXHxfAQiz18j5OtRSuUSyHydsNGDKHW3yTskNoqiJj4pN3a7KV\nhJIIpYN1NSZzqIVOyzFhBwQc19xp5ZrQd87wtlea+4Jthgs01Fiap7h542VPsDGHm/nQ8R/ANhVM\nWw/lHNesssiFbckhzTsekz1BwF0/DIARR1I16p01/yaaFyTkBqwFyRtVJyb5ZFkvi0lIbhhk7eAa\nj7xtM0beqaLVoDoR5bKN2tnlkVCx4JiF841IEsSPXYM6QQRTeZYRSQpN0ma2jQ4nKU88oFG17lnI\nl97yWT4y/QOUN84hVRJJRXTDCsUyiPHQUOwEsiRjWlaoZOqxzZP477d/i3GcGDonKIyBsxzPQTyu\nYA2OQt80h45Um3+Q5QsYFx9zPh+Y8T5/V66JZEz1a0obNu2pdqz4MI0Nzn2ZKumEeBZuJjcXHzxB\ntOUSgmHatCRb6C8OsGlwK82JJi8ILvhadaTamdQwgYXHnIcs++MXU2WSCVUkxnFwaufJ4hoBzTKm\nysyfcDafnfMJPjDjn0J9mtYqgs7M3rApOzPsv++9mZIXhNpWURmsMemblMc3d3IgjFQO2MVJ7TM4\nrmUKkiQRj/vvu7Z9JlY5yWzlHXz5LZ9leutxofPG1o/m+jmf8n7bxbqQ5u0+r2CKYzVQddHVvF2Y\nw63e30GN2nUDphNqiPzbm4+MyRwi8j6ieOz5Lrp6wqkphwoaDyz1g5y2O8kLdgyJZSuPrOjinsc3\ne8uhvCQNstCUlq3bz12PbOS3T2xhqKCFfd4V5Js3CuGkIgHNW27qJT55HWpnIA96KEDGJhYo4wjV\nAWuWbYfK5EmKEYo2V8eJ7E5mthUz60ySkuV9nHvl1aH2RE1rvw61p7kYcZ/YHD92XVL1lr5JG9/G\njq4yacXPmAR45KdvO9ExHVdq3k7U9daTwIxTeukcZDPBkBZ+ZkENasPgZm596o/c+9ctxBM1stsF\nNG+3kIax5zjmjj0RLAXN0vnLCr82s47uCCaSuE9bglhZWCVc4UMPrO/V40iK5UUoIxu+VcJdR+ya\n6yUTSbZIyP6k26AKv3HsmHX8pe8BhmNOX4yY9+wanWPASTiiashWjGLJoqUhUdNnbet+bEWQvO1S\n2iPvYKnaJI2MruskEVOxsh1YTlSxHtC8XcjJAoolnv/9T2/HtGxithCwOlLtSJJUpa25goxVrMPM\ntnGccoa3TxwrYWY7mBYkA0vxtFZJkpg35jRa4+K9NbonElMV7zovbupF1uqRFJN0Y9k737VqeX57\nQNtyMjNahb/YNWnbNqT0jnwAACAASURBVExrOZaSWaZgFJnaPLlm8Q5FVvjc3E/xrikLKrZLwrxs\nJLC1BAoqJ7YLAVRRwiQPMLlpkhfU6eLE9hOIbTsXfdf00PZgVbvebJGCM1e11CdCxzUHyLsj3Uot\n/MeZX+Rrb72h5r4g4oHnlwz8bfZOoLzqbYxPTmZUXW0BIRiBbpdTNc3mSlCgCZJ3haBuF/0I+CDJ\nu0pJXSoW2t4RkfffH/b05vj1Y5u56X+eC23/5cMb+MNT2/ijk6fc9dmokkJfpshv/rKZxc+JZTa6\nqfumV4e871y0kcdf3MOi5bt4fkNPyOddGdzh1om2yk6QkWx6Pu9ahelD/uBE0VtD7aLSbL5++wAl\no+RPtorOcN5J0lLWQRITsbb5VH8Nsmx6H+cwvdiWRGnVOZhZ5+OXA+StaiKBhy37EbxOn+rTMXRL\nx9bjFLJJfvC71aTUdCi5hr+EJ4ahy1X36Js7nWIMRh2y1iisCZIFWMSnP4fS0uMtWQHYYDzDw8/u\n8pa+ubBtKeTzdjNC2UYMWZbEGnDZ4PHnffK2EMk3Jo6qByTQ48LnHTCbG5r/2bpaaSJtihSwiu6T\ntvMcOjqcNe+uoBPQLprjzc44OlYBxe+jm9K0PuaTr6Tqjt88TqFk0NaYrE3eAW3ZM5sjfPluLedY\nYFJWHRLz/LPOulndsJDNMEkAyGaSkmbw4DIh7F7UcQVXTv8nprUI7TFoKhX37rgjSmm0jXOZnvTN\n6u77Z9swJu2n6cSWQxM7wPunXIm2YwZm/1hiqkzKuc79T29n5y4np3Z6nTjdUkgnVc49ZayXZAXC\na59PmSpMvP9w5jGcMdrv07njzwx0vur2PbQ42cckSfJIpLT2TK4c93FPuw1mF4yrI5ugAS47ay7Y\nMmec4I/DhXP9wLjeTNEjrvGdgqynTxTvUFOAvNuStcm7OdHkrYmvhfEd4t0MCl+1zOYH8oMDTJFP\nQ983GZBFeteKtoKat+dWIOxDNzPtmAP+OAQ17wWnC5fD204ZGyJvNxvckUBUVewIoViuvfZv/4CY\nLF3Tn0veiqzSE8gnPVzQ6Q/UL0a2KGtmyHReLBuUpaDmHSbkYUeDtMtpSJRANompCnXJcO7lifHj\n2aVtCpO/G6S07xhGG7Pon/DnqoC1fFlDUkyxbjemISkG/UMlLMumZJSRZJvpbZP5+GfO56tPbKef\nHpAt74PSEZnF7HLaXx8qW77ZPKb564fLKSRkSAhLREdTim5LCwWaqXZSRKzLplgj6yWmUCmXbVER\nyCHsKWMbSU2qZ1sOT7CoS6pCY0oBqoYkWSL5DAifbkX0uhtjEJfjDL94BvETlgc0b9kb/1s+IqKX\nZTuGpYhc4u4MLSkGthHnuHHNfPby2Vz/2FKkZE5MHE77Wjkwm7sJJGI6LQ0xioqF5Wb0slQScoJU\nyuCmj8zllj8+AAjTopsfqjXRChUp6t1o9JyjeadUn4ileAlUDb2QwrJt0kmVn3ziHfz48TgbSi+i\nNGRoUBspBsgqpHlrqZqatxfxK0tIkh+kqBsWsRqEI5kJT7iYPKaBK+efSl+fbyFprwim0ndNI3Hc\nS+h7BbknAqbYoJY3OqTNSSGTKMC4pk7MHuGLjqkyHQHTsV1hGscU39aHFkyjdetuFu0Sgkaw1vak\n0Q386NpzhGUFeNv4M2lPtYX93QcoV/Ctj83zAh49Td2Ih4g0SE6V91OJd501hVMmt4a01ffNn8q7\nz5rMt3+9kr39ec+d0tqQ4NZrziYVF8cGY0Nqrek+GHz5I3PRdCtErkGzubftAH5wgOPUuazrEgpR\nkLxdn/dImncwT4W2KRA3QVggPPeUsZw+YxTppMpTq/wA1WT8yFFqpHkfIVg1UpUCVaYxw3KLhMih\nYhCFkkFf0c/JLSt+SktX8ivr1oE1b6fghuf/U0zH5+1XParPT+WczvnO/mCqVD/pQn3MCc6p8Hkv\nyQo/va0lhd9WFVWSBoZLXtnIpkQ9qiLTXu8kvohp3sdZtovexGa7NZklV/O2QdWwveAmmTq5Ednx\ng7c3J0WQn+l/1JYernYkMi4JE2nZ4V3h47dJxhUM1zfsCACphIrlraUuh0zwthFnsiYKIXjJLZzx\nPnPs6SiWWKftEroiS2TKQ0hIjKpvce6gVhIVE0yVeFymPhUThUEUC1s2sJVgoJjTDyeoTZcLtLW4\n5nKfHBpiDWS1Idqakshp8fyntvqaVGstDckQZnt3kp7deiqzW+eIcUgNI8m2l9WsLqkSjym0SZPR\nd8yEgQl8aMpVBNXFUGCSLdf0eYeIXJFFwiBElbRa0buSmfDM+lPGNFV9R5WBQ9bgaN7ffg22YyUI\nam5BIvdWKQT6EkRdYAKPq3LIxxn0N4MTsJZUkSSJ9nTAOmGE1x2nEiqyJCFJEv90/Lt5+4Szqu53\nJKiKHCLaWvcUJKr4K5C3JElV7cnOto7mJLphsddZMphOxqhLxjyiDa7jrmXyPxioilxlNamteR+Y\nvINCyiuZzYPHnth+AnVqmium/2NVm3WBQlCSJHn9DL67ifiRo9SIvI8QauUZr7Xd07wlhd6MT475\nUrXm7aIuJV4iTTdDRSrCmrfN5sw28Zdjcg6bzZ3lN6UpXuIKKWg293Ihq6QSKkk1GdK8dctgW2GD\nc7AFZswj/L5Myat85Urnk5pEQJJclyURU9AtQ0RKuyZ3JxmDu9YbVRdm4YD/s0FpEfV3FZ2OppQg\n74DmrZXcQLhAZivHZFly5CK1bT9Kx26x3MPSPHJXFYlEXMEoOwJATAv5usHGGhjDhORkp9604Res\nUBNiQgsUtFBkiaw2RH28zsvFbLqme1dIkiyRWMJUfFOuQ85lCuiKIF+9GCAMxyeXNftobnL8pwGz\nbGO8kbxeYG+xy8vHPGOUr9U1JeqIyeIeZdstpOFkbnPMo+lEgg/NfC9WOemZ122nIlk66QpbNnax\nAXXvKbSkwlHESSXBuPix6LunoioyTfV+JLoLNaAdKrLkFXoQmncN8jYSnvm2crKH2r7HukCyjrBZ\n1m+/MR7O8hXsl/gd9h8Hr2NraYrPBXzRAZ93Q9zXhG0zTN6viFfBg0HNVKkIbHu1cO91pxO3U1dJ\nskqck9pnVvnjXytqEXWlUHWg/alEtQl+pIC1hng93z7nZs4c+5aqNmu9ZxAm/1cSKg4lIvI+QhiB\nu6vglsNUZYW+rK9590gb+O2m+/0DA8TqlgYtaWaIUIOat9zcy7J9ItLbKjnLpGJlYorsmM2dNddy\niua0Y/JSqoO9MEVO6sZ4AwOlQTQnA1le9zOvGb0TsE0VJSau35sp8v/Yu/P4qMqzf/yfs81MJpls\nkAAJ+yabICgo4i5Qt69WWxUXcKlaRVu1daFUpbUPuFT9Wbva1trqQ12hllddeLpp1YLWlcUVtAjI\nkkD2zHaW3x9nmXMmM5mQZCYZ5vP+h8xkZnLmJMx1rvu+7uuOWevL7eA9uXqMeVyl+/CrDx7Gqzv+\nbZ6npJ7Y/knrzKBmNUZxFy/ZhVRCoB0DyvxQDc0zbG533rKDrjkkbZ6rcLsrWJTvRWNwM3a173Iy\nd0kU4VckaBFX8PZUrsdR3xSBz96yUo45vxO/5IMkCOb6Z8mcKxdFc7cjuwMUAGiqfYFiPq+4OLGk\nx+nnbL3fiNaGmGiNnESCieYeVrOaFr0egZB1fK7gXR4wA+nKj58xHx8NYGBxYs4x4JfNoXMAJdoQ\nCLGg01TEzmz9PslsEqO5A5V5UWF/gNt75IiC0GGeWBAEnFJ9DtQvx6KqPODMwbqzMzkp8/YOm4uY\nV3YhYlumIbLpaMS3j4PSNNK5uEgOIgBQXtJx7tGdOaWbUxUEAQtGXYzohzM7HFcyRRZR2mFnP8HZ\nIcuI+Z2LG89FQYoe+r3NfUHiHjbvjeAdjWnmErqkQCUIAr459RKcMvLkbv+MVFLOb2e4oHFfdLmH\nsv0phs2TL9DSCaYY4QAS9RoAg/dByT1s/t6n9dANsxfuftd2llu/bEJMtTJcw0BdY9jJAPeXJ/aA\n1ttLrAIq8zXNDy8De/WtqI/sT/xQV/C1h5dlQYbeWG0uMSpuhqJ4h819QgClwSLo4SDEUKPzGu7M\ne2d9G0q1oYjpcTyx/jWs/2C3s7etumeY9fqJIri6pjBiVqFdsbWOcnRlrbmOdsBufNGyHau2/MU8\nUCt428PmghL3NOZwF0KVyFaLVCWC0lJ7eU7iP09Ts13oZm1VKCUqsdvaBGCH+SErVdShrug980nW\n/2NZMiuW7QsdqWq7N/OW4tjXHIEMa3jWmuMHzEzTzLytD3YljrgRQVxXUe4aQraXfCnDzKYgpSF7\nIwvZmUqwq5TbjVZExCYYutnD2mn5GPfDiPuwH19ik/oP8xQ0JqqI7S077SmX2JbEOmDAXPs/sMgM\n3pE2H9o3HO08xh42d9bhC4lhUfu47OBk/32bc9YdPwztAJuuGtcdJCVR8BSs+WQRgwODoe0fAqO9\nFOquMdBVn7MbXjDFvvbuzMo5Blfm7Q48/qQ51YmV46C3mFXlyRciboospXyvV0y5GPH35gGaL2Xm\nndziMxv8nmJAd/DufnAZ6JqKSHXBlC2ZsuxMz3FPz9gXAuky784Up/g7AwBZTrwWg/dByJ15P7Rq\nA155dyfuXvmOp9HK8sfeRn2zOTcc0+PY1xxFZWkAJQEFQsRd9GP9J7IztiIFUtUObCsyd/s5acDp\nAOBtB2oF4UvGLwIMEXpbGUR/GIZoNfiQzbaZPsmHoF+GVl8LQdQhVe72PB+agoaWKN57y/xD/vN7\nr+PXaz7Axm3m4+zgamgydMEMmvubo4gLVvC2Mm9REJ0Mz3OeUvTrdg9ZuzPvErnEeZ8lQSvwuTJv\nLe7q2Caaeyy7s57wl0M9VePun2suvZGgt1ZgQukkSKFGSFWJZXSxz6bCMIC4nZl7Mm+/uYey9f7E\nYDPaNHsLy0TWO7bYWspTuQcQVZSUJC5A7A+BgUHz8a/sfx5hcb815SGgusIOgmaTFA0xRIw2xHeM\nxfTBiTXqdvAGzPXtRpv5evaccNAvozpoBqq2Fsks7LOCS0u7+Tu3i3wG+BJroO3gbVfXjq01f85h\nYwc6owLuoFhZGoAAYGhVx9854B16lCUhkXlrZuadnDFqmp5YrpNuODPpQzlV4RLQ8QPXPUfaWYGX\nnbFVliay/AGl5haVMsy/02CK4J343XXN6CHm//3Dxg3M8MhERul+T6LYu5k3kH4IORvsn+WTRQyz\nKtyryjpvhuK+6PLMSSuJkbVU3+/KcSSTPXPeuQverDbPESczKWmAPPQTfLjTu7ymTdgLsbzOmeON\najFEYqq5jlY30KaJEACE6meiWbaWFllV1MGA7GzacP74ryKyx9zooXaIjMtPnAXdMPBKXQPW7/3M\nyXy11lKIZXUIi/sQ9I81s1PV3NtWFAV8e958/PKjTzF1qoh3/55ocFLiC6IZifWPdr/o3U1WW0+7\nGMfOOiUVMVWHjihEeCtSJ9YOxsdNiZaR5vPNDz17pACAOd/tt+daXQ1GrOB95IwifN76mfVzXWug\nnUYuGgTr/BiajOKAbA25CmZBmL9jsxnJNSw4o2w2Pmr+wNmJKbLhWHO/agDtLQJQYgV915y3JEWc\nJVRicTNaVfPnuzPvb596PJa/vBX75E8hKDFUlgWxwzpGe8574UmH4uebXI1rrO5XQwYU47yTxiIU\n9OGdLwdgXcM/UOoL4ZRDvo5h1SFcEDYb5OwT/us89fCRI3HWCWbryTsunYn9zebWh9VNZlCwh8JP\nnjEUf39nBzTdQHFAdoYd506agj98bG7KcsnJhyEkDMCU0WbWfszUIRhUUYTRNWaf7GWXzkSFK6hV\nlRfh9kuPwODK1FXInjlvSUQsrsEwDKfaPDljVDU9MSef5kP1vmuPxusbduGZl825fk+Rmiu4JQc0\nd5CXU2Tw9187B63huJN1L7t0JprazIs+ewcrO4ja2Zq7u9hti7xVzJnMnjIYA8sCGF2ToiNZkvsW\nH43G1ljSnHfXC9Y6M6DU3BfdMHIbvAM+GcsunYnykB+KJODLfe2oTXMRaHNfdCkpRlnSFay53X/t\nHLz18V488Tdzu9p0GXqqi4NcYPDOEbswTSzdB6m0Ac2tewAkrh63la6Fuyg3psUQi5vLqEQB2Cuo\nKJaLIDQMhVi1y1xcJOowAAQUGYJsZn2TB0zA3z7ZZ3bv8oedtZhavVWQ5jOvnu250jZhPwRBgCDH\nYaiJhgMTBtdC/FhE3JpntTPvimAJmvfHzbXWmuQUpe1vbzH/muxqcSdwxs0GNIpVze5aQlJRFAK8\nsdvJrGPbJiIweb35GnIMorVlpN6ayFxDivke3t3/Nt7d/7Z5pyvzdobQ5Rj8483vG9EihII+TB0z\nEOs273aCfak+BOMGV2P9f8zzJEuCk50FjUrobaUQi5s9xwgATc0wg7ccd4oI/ZIPoiA4PbvF4ia0\nqOZzy1xz3n6fhOpQGfaFzfqD8jLRXLalJ4bNq0PeavD4DrOJSHFAdrLYE8dNw4mY5nmcX5FQWQpU\nxkc79w0vH4RqK3sqtiqFAWDW4MPx7KsfQ9s/BJNHVngyQ3e2NW5AotBt+qihngsxURBwyPBEtfWI\nwR23dxw5OH3w6ThsbjhD58mZt98nQdONRJerNMOZpUGf50PeHdB8nmHljnP0smQeQ6oP9oqQ31lf\nDQChoA+hoLeRjP1+3BcCK+bcDkkUUaIcWMFa8rntTFmJH2VJ8/2pmrR0hyyJqAz5sa85mnYIOVvc\nf0/2KE9n5DRLwVIWrKW4QAPM3/PIFH/Hydw1BRw2Pwg5w+ZWT+WInmKHKxd7yZdfkcwPJ1GDIvoQ\njWuQkpYYybIASUlsU1jXGIERC6BdS6x7tVtzhvzmB669WUMM7TAMwwrePkSsHbxkUUaFvxx14X0Q\ny/dCHmAPi7u2WlQVZ/mUvduYMyft7GGsojUchVhsZuYhV+FOyrWg9rB7WzmiH1vLk5QoxFAj9EgR\nEHd1B/OluPp27fhlX0BIFXsgKHFIrYOh7jSDn7OUyPp9+EQ/Lp9yEcT95m5DdsEaYFZdq9aOSoYu\neLL7BmsBgFi6z5mXtzd+QDwAI+aHEGzGjlZzyL0maSlSib1LkRxDaYld7Z0YNncXOk0UToTeYI6q\ndDXzce+6lG7tbUD2I/blKECXMLC8KG27R3exXbHcvXW86aQqWItZ65cVyRu8Az4JqmZkHDYHktY4\np8mQUs2P2wVv7u1dD4Q9kuD+PZX5Qx365OdCb2XeQOJiLpeZd3d4Mu8U1eBdybzNx2U+X+6Lg1R/\nS9nC4J0jTsGatYGCs/uTLWlLQ7tHud9ndmkSJA2KYG4Dam9q4ARvSXSGtQNyAPWNYQhqAO1qO1Td\nvD+shiEKIoKKGbTsIdKI0WbuYiQYgKZ4CuiqigagOdYC//h3nPvawq41yZriFGm1WNXmRofMW0Wj\n8hnE4hZUxMd4AkiqYOLeYcoZQg81QJDjHdbRFrn28lWsYUnPPLp1DGKRtca8bZIzn+tklFa2LMIq\nHrP+I8qS4BS6tEfi0PbVmIFb9cFd6mq0l0JvLYNUXg+p2mxp65f8zsWa3h6C6I9g475NKJKLMCxU\n63kPIcVV+e/XneO2P2R8kmvNtpgYdTiQzOeiCeeiSC7C5AET0j4mZm0vW1bs82Qi7vXSgiDg4gnn\n4uyxp3d7HW86SoqlYnbzEZ8ieoJOQJGg6ZmHzYH0WZV7Pa6U4jF2Zt3Y0vlFdjr2h36uM9RUpG4U\nZ6VjX8wV+/v+fXXGezHoyox9nS8VS9aF2J2x8U22MHhnQTiqoqXduyuY0yXMyvTiRlJ3LtVbgOEE\nb8Xa9UtUAV1GLK5Bttbl2vPjih28NRkCBNQ1heEXzMBoN2Zpj4dRJAcgSaK5VCfuh2EIaFWb8ceP\nVgEAtP2DPB9WA4OuTRosEVenOMOqKPcrIiL2Bh3OnLcVOBUVMdnMumvh3bIwOXifNOxYs2DKZjVZ\nEcvMPa6T23C650EXTVqAG2dcA3XXqMTx6e75bxHlYiLrtT+c7WAfMMzXtocYJUl0/qO3RVRA9SH+\n3ymIb0/sya3IImCIiG01h6ztna38kh/2SgB7HXZYi2B8+egOGzKU+q3aASXmbCBiaHLK5TElUuLi\npegAMp+ja2bivuN+2GlbSltxkeL5MEquDp9dMxNzhx/f5Z/dZUnLxlQtfebt90nQNHPY3C4sTCdd\n5uS+v7Pg3dDazeCdIvPuK+5h855edOVL5q2kec+pMu/OCtbc2/Wm09MLou5i8M6C6x96Ddc/9Jrn\nPrt6Fk7w9gZ3Q9CgR4LmGlZRcTbZ8CsSSopkCJKOPfti5iYMgt061GroIgnOhhRtERXhqIZia3/h\npqg51xpWwwhamar5QWj2zd7eth0fNXyKQfIIaPW1GDwgEVA9OyxZhg8yX3dAqd8pShs62Oc0QnEy\nb6dtpwpVtIbsFe/8kXsI8VuHXdlh/99xQ8xWlfb/PfcmAYD3P+DQkhqMLR8FGCnmvGF1RLOqdodV\nlzgZROyzqYhvH4ehhrk8ys4AZVFwisbs4VmtvhbaPnP4fNSQUqdHtxEtcvrFA4Bf9jkdLY32xEhA\nqsy3LJAI3s7fhC55AtL/G30Kjhx8OAJS9pbq2PPcQyqDnkrsQTnaaMHdrEiWBOiG4azEUBTJO2yu\nmHPebRHVHJXqJCBJXVjDW2o1jXFn92NqzIu5IQO6N8ztVyQU+eU++2B3S3Vx0l2DrL+T0mJfhkf2\nrXS/d/v3ka63eTK7F7zYyd9YV9eJ97b+ffmUp+xCG8MwnA+WRPA2/9XQMXhDC0DdNQbV46LYGfkC\nsLbLnDV5IF54C04wcipXreCtSCIgxWFEfE5L1XJ/KRoANESbMArmnLddLFUe8mPP/nZz/tgXRZEc\nwHdnX463yxow3bUcZXz5mMTxqTK+MuJknHTUZLz7aT3iqo6nt5hrzysrJWyPxc1kU1NwzNQhEMsF\n/CeyCZKswfBFYBhCh0zbfXt8xRgIgoAbzp0Gnyxi9/52zJo4CDe/9ifnfert3jluWRKwaOL52NL4\neYcLjbISH5paDRiGGfyrS8px+lEjURr0YfaUwSgOKLj27EPx8z9thLprDIQae7g8kXnbRU32nuTj\nhpbhzGNGYV9TBDPGV+FnqzbA3GFcgN40EKK1I5tn2Nx1wTFj0FQkq7CaqMiDvsCGfebPOe+YyZ6i\no1NGmu1qX3rjC+e+AaW9u2/wLRdMx0dfNGDK6AGIxTV8/YQxB1Qo1VPupZR2sLH/litCfs8oix3I\nW9oT+5ink/yhe+uF0xGNe7OpMTVluPTUCc4GGwBw8hFD4fdJmDHeu/NWV10wd5wzrN/XejN4H35I\nFRbOH4+jJg/utdfMhuRs+vuLDkdTa+Iz1/130dn5GTE4hEtPnYBDhqcfteqrCzQG7yxStcSmCppm\nz3mbHxya4A3eEPREYxIkdtzy+yRnq0l7DbOdeTt7RUsCdCEOQwtiZ521UUdxJT4PA/sjDeZuZLrq\nZN5V5QEzeFsXEjXFQ1CsFOG4ad4sa3jpUAwOVmN3+15EP5yFY4+ag1DQh+Om1WD9B4lK7WDQgICw\ntbm9gLOPHY09cRn/ec8cNheUKBD3IVDi/aAt9lQrm+996hgzCE8YYfX/1v3QxXarUKxjRe+RVYfj\nyCGHdzj35SV+8z+rIQCCgepQGfyKhLlHJPp6H35IFYJ+Ge1R1cmU7Yst93CsnXkfOnoAJo9MVH+7\nq5zj28cDhoihlRVQRNmpcTDCJdAjRThhzHTPHL1znEWJC5KdrealwNETRnV4HODNEFJ1EOuJytIA\njp4yBIBZiX3aUSMyPKN3uTNve5jX3rSnqjzgyYrt77dFVFRXdF44l5xVpbsYOW5ajee2KAgd7jsQ\n44ZmnqLIld4M3pIo4sQZQzM/sI8lz0PbIympZJpKyPR30JWitmzo+zGdg5j7Cl/Tra+tYXNnj2Xz\nljlfavfztpc7WTtuOXt0W8HSZ+/HbDdOkTSn4GxHnVn1XVtqZtANkUa0W/PRRdY+t1X2jkuy+bo1\nJemvom8+4jpEP5wJI1zq3bQ+oCSGyJUwRH8EmpUZK7LobK0nyCoEXxRGzO/Zlxfo2s5DJa3mHLPe\n2DED6uxDKRS06wLMoFDiSz386QzJG4bntuyZ844797l55v00H+LbJmGYPsN6Qet+Q0R0w3E4b/xZ\nKX9+kc/nFA8CZuFdukpud/FVLqtac8Gdedvnede+xI5x7mFz9+890/RBbwaufCX1g6H7XMvlUHa6\nTaeyrfB+qzlkL7sCEsPmguBu2WmxN9+wMm87wxZE1QreiblQAAiIAc9rGFYWb2gydlrBe0SlOV+8\nP9Jo7kcNIGgFVLvNYeyzQzEiNAynjpyb9j0E5IDTKtL9HyIYkJ3g/Z+ItZtYuGPwhq/dXI8eD3To\nhdyV4F0ePgSRjXMQ+++UDt/r7EMped/idEt07Kvu5P9/dntUIJF5JweCVEU79lW49+VStwwFzGVP\n0Q+OcgJ4mb/jDlm2rhTP5Bv7nRquM2af50TmXeQ59+7fe6bCqT76XO1XCvECJpdD2U5ilmMcNu9F\numF45lI8mbfmLVhzb7cJwargtoKz0SHztrqLWXPefsneDMMM3ppgXQioMnZY2/UNG1AJn6hgQ/1m\nJxjYa4ZDRebws948ELfMPK/L7y8580bS7kh24xdFFlFkWM1gfFZns5i/Q1WwLMo4YegcDAqmn1eU\nZbFDoZqts0KT5Cvv9MHb/Dd52FwUEsHbnitLfs1Uy4DsD8p0u8h1PE4RRqzIXG5WuRdiJzsu2Mv4\netJoo78RBAGGYSRl3lbw3tcOvyIhFFSSNjFxZ96dz3l39fdwMGPwzi7nsz3HDp5PgT62e387rrjn\nn/j7267+1/HEsqrkgjVB6ph5G9awuW7vNuWLwO9zZ97mtZZTdWy9hu7KvJvaYigt9iHgk+GTzCD9\nft0mAMCospEA9OeBHQAAIABJREFUgPJQ9ypFk5tcOHtuW/RwCLIkmPv/Wpm37rOat8T9Kfv+njv+\nLBw39Oi0P7OzZRzJnbHcyoq9c8IBOXWBlz13XGQdWyITTPS/brcadSRn+ikzbyl1Jp+JfcEW19MX\nOdmvOXxQ560h84m9JMtd4S675rQHlgc6jES4f++ZMu+DbXqhOwrxHOTygqWvLqaZefeStz7aCwBY\n+ddPnPvcm444QyuiO/M2AAiJPautYFgSrwX8m6AM+wQ++WRnztvOvINyEIh3zLxrKspR7h/gVMi2\nurbpBIDRZWYR0qSRlTh99ghMH9e1Stprzz4UX9a3ej4EKkJ+nHroNLypbcWYkrF4+8P9MNpKofgT\nFfHmkjd7NzJft1oHdjZ3lep7t1wwHe9vrcfXjh8DUQT+ZT9WTP2zrz17Cl5Y/wVOn20VaLlesqqi\nCKOGhPD5LnP0IPkDIdV8qzNsbkVaWRJxwdxxad8DAFx51hSsa9iBrZFdnuHjZGcdMwqqpuOsY1IX\ntOWj75w/Df/3n+04+fBEEdSx02rQ0h6Hbhg4ekqiHuPCueMQ8MnY+mWip26m4D24MohTjhyOKaMq\nO33cwcyvSDhzzshO29MebIr8svmeh6R/zxfPH9/lTUk6M2N8FU6cXotjpw3p8WsdCAbvA7S18b+o\nC9fjqCHezQXswCYEWiGW1UPbMwLRlJm3GagF0XA2FnGG0q3g7YsOQoV/KBqKdwBSvMOcd5EcgBAX\nnNakLaq5DehXpo/F7JpEj+sLJ3wNL3z+NzRGm5znAeaQ8NeOTywDy+TwQ6pw+CEdA/3Xjp6Mq6uO\nwsdb67B+rbkft/sqNCgH0BSzh/SVbgZvMem22XMaSD1sPmFEhVOpfv5J4/Avc5dMKGLq4dXqiiAu\nPTWx/lpAYthbFARcd85UfPfnr6c8lmCKLlPJAf74aTU4cXpth8e5nXncGEzYfiZ+uWEfFhxydtrH\nBQMyFn7lkLTfz0dDBhTjklO869/H1pbh21/vuKzOXimwbXeLc1+mYXNBEHDeiWN74Ujz21ePHZ35\nQQeZTO/5pF6qmpclsU/+XzJ4H6AH3vkFAGDW4Bmebln2XHdgqtmcJdJa4Q3emrdgDYCZfetyIhu3\nhs1jcQ2KHgREQBOirszbqjZXJBTFi9BqBe9P2z+CKIiYPND7ITin5kjMqTkSb+95DwNTNFzpLe4P\nUPeSnqASRFMssZtXd7bLSw6YPlmCqplDy501TrAdWzsbr+5ch9HWlEEmiepz89+yksQUQ3Kmn7pg\nzTts3tWGVhWBciyddWPXHlzg3Bdt7o0/iAoJ//K7SdVVZ04ZAJKnWARRSxo2TypYgznsbcQDEKwm\nJPYcciSmQdB9ZvAWI4iq1lIxe523LCIoF6FNaoXgb8OeyJeYWDnes4mF2+GDDuvRe83Ep4hmP2rd\n8GTe7mpyo7uZd9J8kt8nOXPQXWn1eN74s/DVMachIHdvXbS3mYP3WFIOm9tz3vbwd+FNN2ad5ClY\n40cYFSYWrHWTmlRY1CGQGELSsLm9zjuRedubeiSGzc3gFo1rEKyGJHEjirC1TttemuWTRQSVIkCO\nQRpgNvaYOWh6z99UNwmC4HyIuueQPOuVXZttHIjkbPdAd0USBfGAAnfyum+3mKp5bqfaijIx523/\nfEbv3uYtWOvfG2QQZQuDdzfFda3zBwg6/vi3T7Hps30AXJm36FoTaFecJw2bb9vdgi92mvPcUSOM\ntri53tVu0qLIEkqUIATRgFS1A7IgY2rV5J6/qR6wP0QVxTtsbjNUxbOTU1clD5tne79cZ847xfda\n2uOe26kL1rpXbU5dx8ybiMG725Izb7ufucMKyA88/T6AdMPmKgZVFGFghbWHtWvplb0dZtQIO3tx\n25m3IotOYBT9EQwtHppoitJHjpo0CANKAzh8fLVzX1BJtAOVoXSrjaA7eI8cHMKF88b37EAzSZEo\nf+/iGZgwvByzJ3v34lZkEbMmVnsKopKHzZl4976JIyowqDKISSMrUFHau21iifIFL1u7SdW9WVgs\nufuVNY9tf3inK1i7Y9FMPP3uK3izHYAuosgvIRzVnK0129VE8LaboiiyiGI90XQk5Ov7db9nHjMK\nZyYtYXIPm/vl7q0td+/zfPslR2R9swdnnbfr1zRuaDluuXBGx8cKAq4+y+z89vQ/twBwVZsbicdQ\n7xo3tBx3XXVUXx8GUZ9i5t1NquEdNjdbV7rms63MuzJkZsTJvc0Bs1GLTxFRVGT9GgzRqdy2s+zW\nWBva4+3wiT5nWN0niyj3JdYvhtIUqvU1d8FadyrNgUTBmiSaLUaz3Xwh0XGte+Pe9pJBnfVqRJRF\nWc28V6xYgffffx+CIGDp0qWYOjWxdnPlypVYs2YNRFHElClT8P3vfz+bh9LrkofN46ruZNsAnK8H\nWMN6qea8BVmFJIoIBqzgrZutIOubIs6weVu8De1qGAEpALs1hSKLKJMSwbs0zaYbfc09593duWq7\nOMkO2tnecEBI7pd6gOSkJi3MvIkoG7KWeb/55pvYtm0bnnrqKSxfvhzLly93vtfa2opHHnkEK1eu\nxBNPPIGtW7fivffey9ahZEVyG8u4qnn7lVvBu9Rqv6m72qMaqnnNJPniePHzv6MdjQDMOe9ia39i\nSfdBgIDWeBva42FnO0/ALFgr8yeCd1mgf3ZOch9z8qYkXWVn3nZGm+3t9xLD5t2M3slLBhm7iSgL\nsvZJuG7dOsyda+5WNWbMGDQ1NaG11exzrSgKFEVBe3s7VFVFOBxGWVn6/Vb7Ql1jGI+t/djZDjJZ\nqszbvVOY0/LUCgLujUnsrFoYsAN/+XwtXtn5uvVY0Wk6IUkiipUgGqNNiGgRTxaryKI3ePeDOe9U\nZDExsNPtzFtK7K8N5K5Pc7eLxa0n9tU2gURUGLI2bF5fX4/JkxPLlyorK1FXV4eSkhL4/X5ce+21\nmDt3Lvx+P04//XSMGtV5v+aKiiBkuXeXCVVVpZ8rXrHyHWzZ3oiyUABXnNVxO8rikOJ5viCJiXXb\ngJN5y4qEqqoQJFkCYEAQAD3uAwLtHY+nrBhl1hy5IokYXl6DD+o+BQBUliSC9eDqEAJFiYA9fNAg\nVA3su3nvdOfRCNYC7wB6OIjSEn+n5zudygpzSkCWRef5AZ+E8cMruvV6mVx46kT86JE3cN68Q7r1\n+iWhAKqqQrj6nKn45aoNOHXO6C69TjbeS6HhOewdPI89l4tzmLNqc/cwZGtrKx5++GG89NJLKCkp\nwSWXXIKPPvoIEyZMSPv8hoaOwa4nqqpCqKtrSfv9fY1h69/2lI/b19CCOiVxf2tbzDNsPn54CB/s\nBMLhOOrqWtAeiSWK1TQZhi4msnPLVacfin+tM3+uKAoY5B+ED2AGb9lINKNoaQ5DiyZ+dVq72Ol7\nyabOzqMAH04oPh8vvl0HjDO6dYzhtqj1WnCe/7MbjoMgICvveVRVMX57y4kQRaFbr9/cHEZdXQtm\njhuIw7v4Opn+FikznsPewfPYc719DtNdCGRt2Ly6uhr19fXO7b1796KqytzcYuvWrRg2bBgqKyvh\n8/lwxBFHYNOmTdk6lG6xLzbSjdJ2GDbX9KRtPs3MW7NeJ2q0QvBHrBcXALXjdZMsys7wuiyJqA3V\nON9zL7tSJNFTCFWi9M9hcwCo8g0GNB/8Svf+1Ox13u4qc9GqPM+W3hqaL8StGIkoN7IWvOfMmYO1\na9cCADZv3ozq6mqUlJhBpra2Flu3bkUkYgazTZs2YeTIkdk6lG4xUqzTdY8edChYi2uA7B42N7Nq\nOxjvqFqDwNRXrRcSYcQ7NpdQRBmqtaRMEgUMK0kEb/dabrt/+IiQuctSd/t254IdfANK9wZ5ZDm3\nc91ERPkga8PmM2bMwOTJk7FgwQIIgoBly5Zh9erVCIVCmDdvHr7xjW9g0aJFkCQJ06dPxxFHHJH5\nRXPIvdRH1VX88aNVmO3aBlQ1OmbeYlFiqMQQzO87Veae1xZgtJVBLDYff+GEr+HThs9RVTQQmlYH\nwCxYqykZjONqZ0MWZcypmYUnsN45JgD47uGLu70eOVfsgjNfN1qjAole6WKWq8x7C+vUiCgXsjrn\nfdNNN3luu+e0FyxYgAULFmTzx/eIu8nGxvoP8cbut/HG7red76tJvc1jqg6xsinxfGgQBQGaYeCL\nvc3eFzdE6K3lQPUOAImtO4FEm1VZFCAKIs7vZH9nScxun+/eYGfe3a02l6zny3mSeff3iykiOjjk\nRzrTBxKZd+pGG8lz3jEtBiHYAr3NrArXoEIUBei6gR/8YZ33yboIvS310rhZE83+2ccfVpPy+/mm\nImQO6Q8o7V7v9UTm3b+D9xGHmPUcIwf3zzX3RHRwYW/zDARBgF/s2Jc7ntzbXG6EIBjQWiogBJuh\nG6qzx7Wn8xoAASLu/8ZXsOaLCMaVj/Z878hJgzC2tgyVKTZc+NkNxyY6teWJMbVluPvq2RhY1r3g\nbQ+79/fg/c2zJuO8ligGlhVlfjARUQ8xeKfhDJsLqbt6JQ+bq4K5lE2PBiHpkpN5R2Kad/03zD2m\ny0sCWDTp/JQ/e0CaQJevexdXl3c/oLl7m/dnkigycBNRznDYPI3EUjEBmqF3+H6HLUFFa9vOmB/Q\nRWiGBkkUsLehvUPmDZ2nvavsXuH9PfMmIsolRpE03FXDWlKWDXirzQ3DgC5Za7jjPhi6BNWIQxIF\nGAY6ZN727mCUmZ1550vBGhFRLjCKpJEp845riYC8e387oJidwIy4HzBEqIaayBalpODP4N1lSp7M\neRMR5RKjSBruOW/N6Dzz/vmfNkFQYgCAUn8I0BKZNwAIHQrWGIi6yqdIkCUBRT6WZxAR2Ri803A3\nadFTDZu75rwjMRWCEoUiKvjhJbNRO6AUcT0OwT67ycPmev9fn91fyJKI755/GM4/aWxfHwoRUb/B\ndCYDM/NOVbCWCOiabkDyx1DmC6G02I/yYDF2RXRIkpW+JxesaflZNd5XDhle0deHQETUrzDzTkN3\nNWlJOeftWuet6ToMKWoOmQPwS+YabdGa6xaS57w1XjMREVH3MXin47RHFVLPebuGzXUhBggGQtbu\nXgEreAv2RiVi0rC51rHpCxERUVcxeKdhrxRLV7AW01QnOzdEs9K8WDG37fRbu3wJaTNvDpsTEVH3\nMXhnIKYpWPt8dyN+9Zy5B7kmmpXmQTt4S1ZmbQftDpk3h82JiKj7GLy7INWcN0QNb31sbt9pWMG7\nWDaDtzNszsybiIiygME7A90wUg6bu7umGZJZvNZh2FxUARgQ/OGkF2XmTURE3cfgnYFupMm8reBt\nGAYMKXnY3NoRTFQhDdoGsbjZ05iFTVqIiKgnGLwzMAwj5Zy3ORRuQDcMCLKdeZu7StnD5pBUSKX7\nAQCXTlqQk+MlIqKDH4N3BuaweYrMGwAkFf/e+SaU2q0AgGDSnLchqBCKWmDEFVQFB+bkeImI6ODH\n4J2BYaRYKmavAZdUPPnpaufu5DlvTYpADIRhREKQRc5zExFR72DwzkDXOxasybDntL33FyctFYvI\n9eY3wqEO+38TERF1F4N3BobRcT9vUU/MaeuRIud+RTSXgNnD5mFxHwBAiIYwpHgQJMOH+M4xEFiv\nRkREPcDgnYFhGNCT5rwF3cysBUmFEU0Eb8GKyvawuV1ULmh++CQfZukLoe4cl/2DJiKigxqDdwYp\nC9bsJiuSCgjmBPjo8Hzn285SMYtgPV4wmHITEVHPMXin8NH+TyH4zMYqqQrWjLgVjK3gbRgCSvUa\n5/uKKENxFagJujeYExER9QSDd5KWWCt++t5v4J/2CgAz8/7vnibPY1S79kxSIQgGYAiQRG9WHfKF\nnK9FnbuIERFR72HwTtIWbwcAp6hM1XTsaWjzPMauX7MzbxgCxKQzGfKVOF+LOnuZExFR72HwThLT\nY57bcVV35rVtumadNsnsXW4Gb++pLHUFbwHmELoB7+sQERF1B4N3koga8dyOxXVAMAvWoh8dgaJw\nLbR6c37bnXlLSeu/Qkpi2Dz5e0RERD3B4J2kPSl4x1XNybz15gEI7T0aajQAABCUqBO8haQzWepP\nBG8hKXgn3yYiIjoQ7NmZJBz3bt8Zs4bNDQMABLRH44Dqg6EqEAJt1lruFAVrimvOW2SwJiKi3sPM\nO0lY6zhsLgg6YJinqj1ilprr4WIIgTAEUYNhCB0CtGepGDNtIiLqRQzeSbyZt4G4pjtD4wAQjpql\n5kakGIJgQPBFAUPskHlLouR8bX+L5WpERNQbGLyTeDJvwUAsrgGCDlmUMLqmFLo5fg4jXJx4nCFA\nTMquJw+YABgCYtsmdPgeERFRTzB4JwnHXcFbVJ05b1EQ4VcS2bQeDSYel2LYPOQrwYTGi6DtGclh\ncyIi6lUM3knCqmvYXNSdanMR3uAN3fV1ig5rAKwit8SwORERUW9g8E4Sdi0VEyTVWectQoLf5w7Y\n7lPXMfMG4AyxC4zeRETUixi8k3gzbw2abkBwhs1dp8u9Q1iKOW8gEbyd2M2KNSIi6gUM3kncTVoE\n0W5ibgZvn2vY3NATpy7VUjEAmDKyEgBw2NiBnvs5BU5ERD3BJi1JYpqrt7lkB28doiBBkdJn3qnm\nvOfOHIZDhldgWHVJh+8RERF1F4N3kqh7Y5KkzFv2BG9vIE+VeYuCgBGDQx3uJyIi6gkGbxfDMBDX\n4s7txLC5DkmQkoK3O1innvMmIiLKBs55u8R11bttp5TIvCVBhCy5h8q9WXiqYfNkrFcjIqLewODt\nYs93S4JZmCaIGiCqEATAJ/o9mbe7YC3dsHk6zNGJiKgnGLxdolbwDohW9zRRg+Azq89LlNABF6wl\nG1xpvu7omrLeOWAiIipInPN2iVvFakVSEG1aCyCp5sYjAEqVEGQxdcGakWadd7KTZtSiOCBj+riB\nGR9LRESUDoO3i5N5C2aGLEgaBMXMvEt9pZCTsm33110ZNpclEXMOHdJ7B0xERAWJw+YuMavS3O8M\nm6vOsHmZrzT9UrE07VGJiIiyIWPw3rp1ay6Oo1+IWcPmPqMIgJV5W8Pm5f4yyHLP5ryJiIh6Q8bg\n/e1vfxsXXHABVq1ahXA4nOnhec0eNldgBm9zztvMvCuKyrwFa8jc25yIiCgbMs55P//88/jkk0/w\n4osvYuHChZg4cSLOPfdcTJ06NRfHl1N2gxZBV2BoIgRJBXwRGLqIkFKMqBRJ/URD5LA5ERHlTJfm\nvMePH4/rr78eS5YswdatW7F48WJcdNFF+O9//5vlw8stO/OGIQG6DLG4GWKgHdq+wVCUpA5rbhw2\nJyKiHMqYee/cuRN/+tOf8Je//AVjx47F1VdfjWOPPRYbN27EzTffjGeeeSYXx5kT9py3oUkwNAmC\nYt6v7hwHRfL2Nvd0WwOYeRMRUc5kDN4LFy7E17/+dfzhD3/AoEGDnPunTp2aceh8xYoVeP/99yEI\nApYuXep5/K5du/Cd73wH8XgckyZNwp133tmDt9E77A5rhi4CmnlqDEOAEQtAlrztURXZvc5b5Jw3\nERHlTMZh8zVr1mDkyJFO4H7iiSfQ1tYGALj99tvTPu/NN9/Etm3b8NRTT2H58uVYvny55/t33303\nLr/8cjz77LOQJAlffvllT95Hr7CXihmqBEO39u6OKwAESJLgqTZP7rbGYXMiIsqVjMH7e9/7Hurr\n653bkUgEt9xyS8YXXrduHebOnQsAGDNmDJqamtDa2goA0HUdb7/9Nk466SQAwLJly1BTU9OtN9Cb\n7DlvXXNl3roMSTSryd0BO3nZGIfNiYgoVzIG78bGRixatMi5fdlll6G5uTnjC9fX16OiosK5XVlZ\nibq6OgDA/v37UVxcjLvuugsXXHAB7r///u4ce6+z57x1VYSzFEyTnEDtnvNOzrwZvImIKFcyznnH\n43Fs3boVY8aMAQBs2rQJ8Xg8w7M6MgzD8/WePXuwaNEi1NbW4qqrrsLLL7+ME044Ie3zKyqCkGXp\ngH9uZ6qqQp7bwqfmMUqiD7D28jZ0CQFFQlVVCEUlifddFFDgXMIYAgYOKO7weoWiUN93b+I57Dme\nw97B89hzuTiHGYP39773PSxevBgtLS3QNA2VlZW49957M75wdXW1Z7h97969qKqqAgBUVFSgpqYG\nw4cPBwDMnj0bn376aafBu6GhPePPPBBVVSHU1bV47mtpN39GuN2A4Lf28tYlSKKAuroWxOJa4sGu\nixEYApoa2+EvwOQ71XmkA8Nz2HM8h72D57HnevscprsQyDhsPm3aNKxduxbPP/881q5dixdffLFL\nmfecOXOwdu1aAMDmzZtRXV2NkpISAIAsyxg2bJizTnzz5s0YNWpUV99L1tjV5mocTuYNXXIqy93z\n3JJnqRg7rBERUe5kzLxbW1vx5z//GQ0NDQDMYfRVq1bhtdde6/R5M2bMwOTJk7FgwQIIgoBly5Zh\n9erVCIVCmDdvHpYuXYolS5bAMAyMHz/eKV7rS1E9BlmUoaqAPedtqDJ8VtB2B2jJ9bVhCDBARESU\nGxmD9w033ICamhq89tpr+MpXvoLXX38dP/jBD7r04jfddJPn9oQJE5yvR4wYgSeeeOLAjjbL4loc\nPlFBXNUhfjEdyvCPEd5+CJSqjgMU7gI1QTAYvImIKGcyDptHo1HceeedqK2txa233orHHnsML774\nYi6OLeeiWgw+yYeYqkNRy1C291hA9UNJUSgnJbdKNRi+iYgoNzIG73g8jvb2dui6joaGBpSXl2P7\n9u25OLaci2kx+CQz81ZkCbpuBmR3NzWbuymLJAuoCAVydpxERFTYMg6bn3XWWXj66adx7rnn4rTT\nTkNlZSVGjBiRi2PLuZgeQ7lYigZVQzCgQNV0AHDmvN3cwfvsY0alDPBERETZkDF42wVngLmka9++\nfZg4cWLWDyzXDMNATIvDJ/kQ13Qosoj2iAogc+bNGW8iIsqljOmiu7vaoEGDMGnSJCeYH0xUXYUB\nwwzeqg6fLELVzcxbSbEVqOgJ3kRERLmTMfOeOHEifvKTn2D69OlQFMW5f/bs2Vk9sFyLWq1RFVGB\nqhlQZBGaZs15KykK1kT3rmIM30RElDsZg/eHH34IAHjrrbec+wRBOOiCt92gRbY28VZkyZnzTpV5\ne4fN9RwcIRERkSlj8H788cdzcRx9zt4ONBG8RSd4y3IiUFeVB1DXGPEOmzPzJiKiHMoYvC+88MKU\nc9wrV67MygH1leTM2yeLUK1hc9k1RL78yqMQi2t4+p9bnftYsEZERLnUpQ5rtng8jvXr1yMYDGb1\noPqCvZe3CHN+293HXHb1MZcl0dkaVI8UQQyEUSQX5fBIiYio0GUM3rNmzfLcnjNnDq688sqsHVBf\nienmsLmExLC5rUM3Nfs5Hx+B4NCdOPb4g2v+n4iI+reMwTu5m9quXbvw+eefZ+2A+krMybzNU+Ju\nzCKLqZbGGTCixZD3TIFPUlJ8n4iIKDsyBu9LLrnE+VoQBJSUlOC6667L6kH1BSd4GzIA3ZN5y510\nTzv4VrwTEVF/lzF4/+Mf/4Cu6xCtoq14PO5Z732wiOnu4B3zbEYipxk2JyIi6gsZo9LatWuxePFi\n5/ZFF12El156KasH1RfsgjUY5ilxr+2WUgybc3UYERH1lYzB+9FHH8WPf/xj5/bvfvc7PProo1k9\nqL4Qt9Z5Q7fmvBV3wVr6wfGDsVUsERH1bxmDt2EYCIVCzu2SkpKDMmBFtCgAQLCCtzvzdq/ztjHx\nJiKivpJxznvKlCm44YYbMGvWLBiGgVdffRVTpkzJxbHllB287cxbUdzrvDnnTURE/UfG4H3bbbdh\nzZo12LBhAwRBwJlnnolTTjklF8eWU1HVCt6anXm7C9YOvpEGIiLKXxmDdzgchqIouP322wEATzzx\nBMLhMIqLi7N+cLlkZ96GZgZt91KxUNDX4fFVZQEAQG3VwXUeiIio/8s4Hnzrrbeivr7euR2JRHDL\nLbdk9aD6gp1566oZvH2yiOVXHolLTjkEIwaHOjz+lCOH44KTx+HKMybl9DiJiIgyBu/GxkYsWrTI\nuX3ZZZehubk5qwfVFyJaBIqoQNPM24osYsiAYhx/WG3KxyuyhHkzh6XMyomIiLIpY/COx+PYujWx\ng9bGjRsRj8ezelB9IaJFEZD8iKvWHt6ddFUjIiLqSxnnvL/3ve9h8eLFaGlpga7rqKiowL333puL\nY8upqBpFQPYjxuBNRET9XMYINW3aNKxduxarVq3CkiVLUF1djWuuuSYXx5ZTyZm3z9UelYiIqD/J\nmHm/9957WL16NV544QXouo4f/ehHmD9/fi6OLWd0Q0dUi8Ev+xHXmHkTEVH/ljZC/eY3v8Fpp52G\nG2+8EZWVlVi1ahWGDx+O008//aDbmMTuax6Q/IjHzYo1Bm8iIuqv0mbeDz74IMaOHYs77rgDRx11\nFICDt4931FrjHZADaGPmTURE/Vza4P3yyy/jT3/6E5YtWwZd13H22WcflFXmABCx1nj7JbNgTRBS\n7yRGRETUH6RNL6uqqnDVVVdh7dq1WLFiBb744gvs3LkTV199NV555ZVcHmPWOZm3VbDmk6WDdpSB\niIjyX5fGhmfOnIm7774br776Kk444QT8/Oc/z/Zx5VRYjQCAWbCm6hwyJyKifu2AolRJSQkWLFiA\np59+OlvH0ye8mbfG4E1ERP0aoxSA9ngYAFAkB9DcHkdx4OCqpiciooMLgzeAdtUM3qLuRzSmoao8\n0MdHRERElB6DN4D2eDsAIBoxT0dVeVFfHg4REVGnGLwBtFmZd6SNwZuIiPo/Bm8kMu+WVvM2gzcR\nEfVnDN5IzHk3NZnd1TjnTURE/RmDN4C2eDsUUUFLmxm8K0L+Pj4iIiKi9Bi8YQ6bFytBRGPmpiQ+\nhduBEhFR/8XgDXPYPCgXIRLX4FNEiGyNSkRE/VjBB2/d0BFWIwgqRYjFNfiZdRMRUT9X8ME7rEZg\nwECxHEQ6NoGoAAAYmElEQVSUwZuIiPIAg7ddad6so6E5Cr+PwZuIiPq3gg/eMc3co3zL9jYYADNv\nIiLq9wo+eMd1M3gbunkqGLyJiKi/Y/DWVfMLBm8iIsoTDN5W5g3dDNo+peBPCRER9XMFH6ni1pw3\nDPNUBFiwRkRE/RyDtzPnbWfeDN5ERNS/MXhzzpuIiPIMg7cz583gTURE+YHBW/MOmzN4ExFRf5fV\n4L1ixQqcf/75WLBgATZs2JDyMffffz8WLlyYzcPolDNsbhWsiSI3JSEiov4ta8H7zTffxLZt2/DU\nU09h+fLlWL58eYfHbNmyBf/5z3+ydQhdkrxUTNP0PjwaIiKizLIWvNetW4e5c+cCAMaMGYOmpia0\ntrZ6HnP33XfjxhtvzNYhdEksqcOapht9eThEREQZZS1419fXo6KiwrldWVmJuro65/bq1asxa9Ys\n1NbWZusQukR1qs3NzLu4SOnDoyEiIspMztUPMoxERtvY2IjVq1fj0UcfxZ49e7r0/IqKIGS5d4vJ\nqqpCkD63bugi5h85Al89aTwkznsfkKqqUF8fQt7jOew5nsPewfPYc7k4h1kL3tXV1aivr3du7927\nF1VVVQCA9evXY//+/bjooosQi8XwxRdfYMWKFVi6dGna12toaO/V46uqCqGurgXN7ebrGrqEuTNq\nsH9fa4Znkpt9Hqn7eA57juewd/A89lxvn8N0FwJZGzafM2cO1q5dCwDYvHkzqqurUVJSAgA45ZRT\n8MILL+Dpp5/Gz372M0yePLnTwJ1NqqvaXBILfuUcERHlgaxl3jNmzMDkyZOxYMECCIKAZcuWYfXq\n1QiFQpg3b162fuwBi7matMgSh8uJiKj/y+qc90033eS5PWHChA6PGTp0KB5//PFsHkannI1JdAmy\nxMybiIj6v4KPVqquWg1aBBaqERFRXij44B3T4xAMs4qdmTcREeWDgo9WcSt4CwJboxIRUX5g8NZU\nCKw0JyKiPFLwESuuxwFDYqU5ERHlDQZvPW4tEyv4U0FERHmioCOWYRiIaXFAl1lpTkREeaOgg7dq\naDBgwGCDFiIiyiMFHbxjWsz8QpcgcdiciIjyREFHLDt4G5rEYXMiIsobhR28rb7mhsaCNSIiyh8F\nHbHszFtn5k1ERHmkwIM3M28iIso/BR2xYnoi82a1ORER5YvCDt4sWCMiojxU4MHb3stb5FIxIiLK\nGwUdsRLrvGXOeRMRUd4o6IjlLBXTRQ6bExFR3ijs4O3qsMaCNSIiyhcM3gCgSdzPm4iI8kZBR6zE\nsLkERSnoU0FERHmkoCOWe9hcYcEaERHliYKOWFFnqZgEHzNvIiLKEwUdseJWhzWDmTcREeWRgo5Y\nMVfmrchS3x4MERFRFxV08I46c94iFLmgTwUREeWRgo5Yqq5CggRAgI/Bm4iI8kRBRyzVUCEKMgAw\n8yYiorxR0BErrschwpzrZvAmIqJ8UdARK66pruDNgjUiIsoPBR28VUOFYJingJk3ERHli4KOWKqu\nQrAybxasERFRvijoiBXXVQgG57yJiCi/FGzEMgzDzLw5bE5ERHmmYCOWqqvmF8y8iYgozxRsxIpr\nVvDW7cyb1eZERJQfCjd423t5W8PmLFgjIqJ8UbARy868DZ1z3kRElF8KNmLFrMwbmghBACRR6NsD\nIiIi6qKCDd6qlXnrugBFFiEIDN5ERJQfCjZ423t5G5oIRSrY00BERHmoYKOWXbCmaQJ8CivNiYgo\nfxRu8LaHzTWBmTcREeWVgo1acd0O3iIUpWBPAxER5aGCjVpxa85bUwXIYsGeBiIiykMFG7Xcw+ay\nzEpzIiLKH4UbvK2CNV0TITHzJiKiPFKwUcvpbW6IkCVm3kRElD8KN3jbvc11ETKrzYmIKI8UbNSy\nm7TAENkalYiI8krBBm9nP29dhMTMm4iI8kjBRq2Ys6uYBJmZNxER5ZGCDd5x97A5C9aIiCiPyNl8\n8RUrVuD999+HIAhYunQppk6d6nxv/fr1eOCBByCKIkaNGoXly5dDzOGSrYgaNb/QJBasERFRXsla\n1HrzzTexbds2PPXUU1i+fDmWL1/u+f4dd9yBhx56CE8++STa2trw6quvZutQUgqrEQCAocssWCMi\norySteC9bt06zJ07FwAwZswYNDU1obW11fn+6tWrMXjwYABAZWUlGhoasnUoKUXiZvCGJjPzJiKi\nvJK1qFVfX4+KigrndmVlJerq6pzbJSUlAIC9e/fi9ddfx/HHH5+tQ0kprEYhQLCqzZl5ExFR/sjq\nnLebYRgd7tu3bx+uvvpqLFu2zBPoU6moCEKWe2/f7XA8Ap/kRzsElJYEUFUV6rXXLjQ8dz3Hc9hz\nPIe9g+ex53JxDrMWvKurq1FfX+/c3rt3L6qqqpzbra2tuPLKK3HDDTfgmGOOyfh6DQ3tvXp8YTUC\nBQoAIBqNo66upVdfv1BUVYV47nqI57DneA57B89jz/X2OUx3IZC1YfM5c+Zg7dq1AIDNmzejurra\nGSoHgLvvvhuXXHIJjjvuuGwdQqci8QgU0QcALFgjIqK8krXMe8aMGZg8eTIWLFgAQRCwbNkyrF69\nGqFQCMcccwyee+45bNu2Dc8++ywA4IwzzsD555+frcPpIKxGUSGXAgAL1oiIKK9kdc77pptu8tye\nMGGC8/WmTZuy+aM7FddVqLoKQTffPoM3EVHfevnlv+OEE07u0mN/8pP7ce65C1BTU5vlo+q/CjJq\nRa0GLbv2xgBw2JyIqC/t2vUl/va3tV1+/PXXf7egAzeQw2rz/iSimcHb0MzqdS4VIyLqOw88cA8+\n/HAzHn30N9B1HV9+uRO7dn2JBx/8Be66607U1e1FOBzG5ZdfhTlzjsV1112F73znFvzzn39HW1sr\nvvhiG3bu3IFvf/u7mD17jvO6qqpi+fIfdHj+J598hPvvvweiKGDKlGm49trrU95n/5zRo8di1aqn\n0NjYiOnTD8eTT/4v2tvbcd11N+Ldd9/Gyy//HbquY/bsObj11u+ipaUFd955G9ra2lBSUoI77vgf\nXH75Rfj9759AMBjEhg3v4cknV2LFih93+5wVZPCOWsEb9rB5DtuyEhH1Z0//Ywv+89HeXn3NmROq\ncd5JY9N+/4ILFmL16qdx2WVX4pFHHoaqxvGLX/wWDQ37MWvWUTj11DOwc+cO3H77EsyZc6znuXv3\n7sF99z2E9ev/jT//eZUneLe0NKd8/oMP3oebb16KsWPH4Uc/ugO7d+9KeV86W7duwRNPrIbP58O7\n776NX/zitxBFEeeddxauvfabeOKJxzFr1myce+4CPPXUSrzzzls47rgT8dpr/8L8+afgtddewbx5\nX+nROS3I4G33NTc08+0z8yYi6j8mTpwMAAiFSvHhh5uxZs1qCIKI5uamDo+dOvUwAObyZHcXz86e\n/8UX2zB27DgAwO2335n2vnTGjh0Hn89crRQIBHDddVdBkiQ0NjaisbERn3zyEa644hoAwPnnXwQA\nqKmpxW9/+0vMn38K3n33bXzjG1cf+IlxKczgrSU2JQFYsEZEZDvvpLGdZsm5oChmD46//vUlNDc3\n4+c//y2am5txxRULOzxWkhLNu5KbgaV7fqpNsFLdJwiJxE5V1Q7Ht3v3Ljz11Er87ncrEQwGsXDh\nedZrSTAM3fNaY8eOw759+/Dhh5sxatQY+P3+zk9CBgUZtSL2piR25s2CNSKiPiOKIjRN63B/Y2Mj\nhgypgSiKeOWVfyAejx/Q66Z7/siRo7B5s7ni6a677sR///t5yvuKi4uxb5/ZbGzjxvdTvn5FRQWC\nwSA+/vgj7N69G/F4HBMnTsLbb/8HAPDcc6vw4ot/AQCcdNI8PPDAPZg375QDeh+pFGTwthlx88qH\nmTcRUd8ZMWIUPv74Izz00P2e+0844ST8+9+v4vrrr0FRURGqq6vx6KO/6fLrpnv+9dffhJ/97P/D\nNdd8A6FQKUaOHJXyvjPPPAf3338vbr75egwcWNXh9ceNG4+ioiCuueZy/P3v/4ezzjoHP/zhD3Hu\nuRdg06YNuO66q/Dvf7+G448/EQBw8snzsHfvXhx++MyenTAAgpGq6Xg/1Jvt5uJaHNf89hnojdWA\nIeKWC6ZjwojOe6tTamyn2HM8hz3Hc9g7eB57rrNz+Pzza7B79y584xvfPKDXS6Ug57wVSYHeMNi5\nzcybiIiy6Z57/gdffrkTd911X6+8XkEG72SsNiciomy69dbbevX1CjLl1HXvTAEL1oiIKJ8UZPCO\nxr1VjRw2JyKifFKQUSvWIXgz8yYiovxRkME7OfOW2B6ViIjySEFGrWjc2/mGmTcRUd96+eW/H/Bz\n3nvvHTQ07M/C0fR/hRm8Y0mZN+e8iYj6zIFuCWp7/vk1BRu8C3KpWMdhc2beRER9xb0l6PnnX4gV\nK36IlpYWaJqGG264GWPHjsP//u/v8cor/4Qoipgz51hMnDgJr776Mj7//DP8z//ci8GDzd4dfbEN\n6OWXX+VsAxqLReD3F2VlG1A3Bm+w2pyIyLZ6y1/w7t6Nvfqa06sPxTljz0j7ffeWoL///W9x5JFH\n4//9v6/i888/w09+ch8efPAXePLJ/8Vzz70ESZLw3HOrMHPmURg7djy+851bnMAN9M02oOeff6Gz\nDejixVfiZz/7VVa2AXVj8AabtBAR9RcbN25AY2MD1q59AQAQjZobSZ1wwsm44YbFmDfvFMyfn35j\nj77YBrS5uTkn24C6FWTwrgz54ZNF6IYBVTMgCgzeREQAcM7YMzrNkrNNUWTceOPNmDJlquf+m276\nHrZt+y/+8Y+/4lvf+iZ+/es/pHz+wbwNqOfYe+2V8sghwyvw1IrT8fBNJ+DXN5/Q14dDRFTQ3FuC\nTpo0Bf/618sAgM8//wxPPvm/aG1txaOP/gYjRozEZZddiVCoDO3tbSm3Ej2YtwH1nLNefbU8Iksi\nBEHgfDcRUR9zbwn69a+fj507t2Px4itwzz3/g8MOm4GSkhI0NjbgyisX4dvfvhqTJ09BaWkZDjts\nBm677VZ89tlW57X6YhvQ+++/x9kGdOHChVnbBtStILcEBbj1XW/heew5nsOe4znsHTyPPZd8Druz\nDWjy66VSkHPeRERE2dbb24C6MXgTERFlQW9vA+rGCV8iIqI8w+BNRESUZxi8iYiI8gyDNxERUZ5h\n8CYiIsozDN5ERER5hsGbiIgozzB4ExER5Zm8aY9KREREJmbeREREeYbBm4iIKM8weBMREeUZBm8i\nIqI8w+BNRESUZxi8iYiI8kxB7ue9YsUKvP/++xAEAUuXLsXUqVP7+pD6tU8++QSLFy/GpZdeiosv\nvhi7du3CLbfcAk3TUFVVhR//+Mfw+XxYs2YN/vCHP0AURZx33nk499xz+/rQ+417770Xb7/9NlRV\nxTe/+U0ceuihPIcHIBwOY8mSJdi3bx+i0SgWL16MCRMm8Bx2UyQSwRlnnIHFixdj9uzZPI8H4I03\n3sD111+PcePGAQDGjx+PK664Ivfn0Cgwb7zxhnHVVVcZhmEYW7ZsMc4777w+PqL+ra2tzbj44ouN\n2267zXj88ccNwzCMJUuWGC+88IJhGIZx//33GytXrjTa2tqM+fPnG83NzUY4HDZOP/10o6GhoS8P\nvd9Yt26dccUVVxiGYRj79+83jj/+eJ7DA/T8888bv/71rw3DMIwdO3YY8+fP5znsgQceeMA455xz\njFWrVvE8HqD169cb3/rWtzz39cU5LLhh83Xr1mHu3LkAgDFjxqCpqQmtra19fFT9l8/nw29+8xtU\nV1c7973xxhs4+eSTAQAnnngi1q1bh/fffx+HHnooQqEQAoEAZsyYgXfeeaevDrtfmTlzJn7yk58A\nAEpLSxEOh3kOD9Bpp52GK6+8EgCwa9cuDBo0iOewm7Zu3YotW7bghBNOAMD/z72hL85hwQXv+vp6\nVFRUOLcrKytRV1fXh0fUv8myjEAg4LkvHA7D5/MBAAYMGIC6ujrU19ejsrLSeQzPa4IkSQgGgwCA\nZ599FscddxzPYTctWLAAN910E5YuXcpz2E333HMPlixZ4tzmeTxwW7ZswdVXX40LLrgAr7/+ep+c\nw4Kc83Yz2B22R9KdP57Xjv72t7/h2Wefxe9+9zvMnz/fuZ/nsOuefPJJfPjhh7j55ps954fnsGue\ne+45HHbYYRg2bFjK7/M8ZjZy5Ehcd911OPXUU7F9+3YsWrQImqY538/VOSy44F1dXY36+nrn9t69\ne1FVVdWHR5R/gsEgIpEIAoEA9uzZg+rq6pTn9bDDDuvDo+xfXn31VfzqV7/Cb3/7W4RCIZ7DA7Rp\n0yYMGDAAQ4YMwcSJE6FpGoqLi3kOD9DLL7+M7du34+WXX8bu3bvh8/n4t3iABg0ahNNOOw0AMHz4\ncAwcOBAbN27M+TksuGHzOXPmYO3atQCAzZs3o7q6GiUlJX18VPnl6KOPds7h//3f/+HYY4/FtGnT\nsHHjRjQ3N6OtrQ3vvPMOjjjiiD4+0v6hpaUF9957Lx5++GGUl5cD4Dk8UG+99RZ+97vfATCnvtrb\n23kOu+HBBx/EqlWr8PTTT+Pcc8/F4sWLeR4P0Jo1a/DII48AAOrq6rBv3z6cc845OT+HBbmr2H33\n3Ye33noLgiBg2bJlmDBhQl8fUr+1adMm3HPPPdi5cydkWcagQYNw3333YcmSJYhGo6ipqcFdd90F\nRVHw0ksv4ZFHHoEgCLj44otx5pln9vXh9wtPPfUUfvrTn2LUqFHOfXfffTduu+02nsMuikQi+P73\nv49du3YhEonguuuuw5QpU3DrrbfyHHbTT3/6U9TW1uKYY47heTwAra2tuOmmm9Dc3Ix4PI7rrrsO\nEydOzPk5LMjgTURElM8KbticiIgo3zF4ExER5RkGbyIiojzD4E1ERJRnGLyJiIjyTME1aSHKN/fe\ney82btyIaDSKDz74ANOnTwcAfO1rX8NXv/rVLr3Gr3/9a4wfP97pZ53KwoUL8fvf/x6SJPXGYXvs\n2bMHn332GWbPnt3rr01UiLhUjChP7NixAxdeeCH+9a9/9fWhHLA1a9Zg69atuPHGG/v6UIgOCsy8\nifLYT3/6U+zYsQNffvklbr31VkQiEdx3333w+XyIRCJYtmwZJk+ejCVLluDwww/H7Nmzcc011+CY\nY47Bhg0b0NbWhocffhiDBg3CIYccgs2bN+OXv/wlGhsbsXv3bmzbtg1HHnkkbr/9dkSjUdx6663Y\nuXMnBg8eDEmSMGfOHM8exW1tbfjud7+L5uZmqKqKE088EWeccQYefPBBGIaB8vJyXHTRRbjzzjux\nbds2tLW14YwzzsDll1+O1atX469//SsEQcCePXswevRorFixAoqi9OEZJuqfOOdNlOd27NiBxx57\nDFOmTEFjYyN+8IMf4LHHHsOiRYvw8MMPd3j81q1bcc4552DlypWYOHEiXnzxxQ6P+eCDD/DQQw/h\n2WefxerVq9HU1IQ1a9ZAVVU888wzuOOOO/D66693eN6///1vqKqKP/7xj3jyyScRDAZRW1uLs88+\nG2eeeSYuu+wyPPbYY6iursbjjz+OZ555Bs8//zw++ugjAMDGjRv///bu2CW1MIzj+NcONQQRQi3W\nYnBsjDoSBFKNOVaEo0M4REO4HGyrKQin5ob+gDBaoiVyECEipakhWkKkQKFoiERPd5DOzYxLlysX\njvw+4+F5X97tx/PyHh7S6TSHh4eUy2VP3jKI/A/qvEU8bmJiAp/PB8DQ0BC7u7u8vb3x8vLC4OBg\nW73f78c0TQACgQBPT09tNZZlYRgGhmHg9/t5fn7m5uaG6elpAIaHh7Esq23d1NQUe3t7bGxsMDc3\nx8rKCj09rT3CxcUFDw8PXF5eAlCr1bi/v3fXf4xPnZyc5O7uzp2TLCK/KbxFPO7ztbJt22xvbzMz\nM8P5+bk7zOOzrw/Svnv28l2N4zgtQfw1lKE5y/j4+JhiscjZ2RnLy8scHR211PT19bG+vs7CwkLL\n90wmg+M4fzyXiDTp2lyki1QqFUzTpNFocHp6Sq1W69jeY2NjFItFAKrVKldXV201uVyObDaLZVnY\ntk1/fz/VahWfz0e9XgeaXf3HVb3jOOzs7Ljd//X1Na+vr7y/v1MoFBgfH+/Y+UW6iTpvkS6SSCSI\nx+MEAgFWV1exbZuDg4OO7L20tEQ2myUWizE6Oko4HG7r0IPBIKlUiv39fQzDIBKJMDIyQjgcJplM\n0tvby9raGre3t8RiMRqNBvPz8+6o1FAoxObmJqVSCdM0iUQiHTm7SLfRr2Ii8iOPj48UCgWi0SiO\n47C4uMjW1pb73/m/ymQy5PN50ul0R/YT6WbqvEXkRwYGBjg5OXHnE8/OznYsuEXk76jzFhER8Rg9\nWBMREfEYhbeIiIjHKLxFREQ8RuEtIiLiMQpvERERj1F4i4iIeMwvRph4T/csGFUAAAAASUVORK5C\nYII=\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAEcCAYAAADUX4MJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsvXeAHMWZ/v/pNGlnc5S0ymmFUE6WEAgQ2UJkGxtjsMEG\nbDD+YnNwZ3PnH+fD2GcwnDFHMBmcwETLIiMJ5YByzqvd1eY0eTr9/uie7p7dlRACHQ7z/LM73dVV\n1dXd71NvqLcE0zRNcsghhxxyyOFjIH7eHcghhxxyyOHvAznCyCGHHHLI4ZiQI4wccsghhxyOCTnC\nyCGHHHLI4ZiQI4wccsghhxyOCTnCyCGHHHLI4ZiQI4wccviUWL16NXPmzDmmsg899BC33377Ce5R\nDjmcGOQII4dPjTPPPJNx48bR2dmZdfyiiy6ipqaGhoYGAO68805qamrYvHmzU6a2tpaamhrn99VX\nX81LL73k/H7kkUeYO3cukydP5vTTT+e2224DYN68eUyePJnJkydz0kknMX78eCZNmsTkyZN57LHH\nTuTt9glBEE5I2Rxy+FuC/Hl3IId/DFRXV7NgwQKuuuoqAHbt2kUqlcoSjoIgUFRUxAMPPMATTzyR\ndbwvvPLKK7zxxhs888wzVFdX09bWxvvvvw/AX/7yF6fc1VdfzcUXX8xll112Im7tnwa6riNJ0ufd\njRz+hpHTMHL4THDRRRfxyiuvOL9feeUVLrnkkl7lLrnkEnbu3MnatWs/ts4tW7Ywe/ZsqqurASgt\nLeWKK67os+zHJSx46KGHuPXWW7n99tuZPHky8+fP58CBAzz22GPMmjWLM844g+XLlzvlm5ubuemm\nm5gxYwbnnnsuL774onMulUpx5513Mn36dObNm5elMWWu/d73vsfMmTM566yzeO655z72XgG6u7u5\n8cYbmTlzJjNmzODGG2+kqanJOd/V1cW//uu/cuqppzJjxgxuvvlm59y7777LxRdfzJQpUzjnnHNY\nunQpYGl/K1asyBqHjEmsvr6empoaXnrpJc444wyuvfZaAG699VZmz57NtGnTuPrqq9mzZ0/Wvd97\n772ceeaZTJ06lauuuopUKsUNN9zACy+8kHU/8+fP57333jume8/h7wM5wsjhM8GECROIxWLs27cP\nwzB48803mT9/fi9BHggEuPHGG7n//vuPqc5XX32VJ554gi1btmAYxqfq46JFi7jkkktYu3YtY8aM\n4brrrsM0TT788EO+853vcNdddzllb7vtNvr378/SpUt58MEHuf/++1m5ciUAv/71r6mrq+O9997j\niSee4NVXX3WuM02TG2+8kTFjxrB06VKefvppnn32WZYtW/ax/TMMg8suu4zFixfzwQcfEAgEuPvu\nu53zt99+O6lUioULF7J8+XJHwG/atIk777yTO+64g3Xr1vH8888zYMCAI7bTU6Nbu3YtCxcudLS+\nOXPm8M4777B8+XJOOukkfvjDHzpl7733XrZt28Yf//hH1qxZw+23344oilx88cW89tprTrkdO3bQ\n3Nx8zL6dHP4+kCOMHD4zXHTRRbz66qssW7aMYcOGUVFR0We5L33pSxw+fJgPP/zwqPXNnz+fu+66\ni2XLlnH11Vcza9asT+WfmDp1KrNmzUIURc477zw6Ojr49re/jSRJXHDBBTQ0NBCNRjl8+DDr16/n\nhz/8IYqiUFNTwxVXXOEIxDfffJObbrqJ/Px8Kisrufrqq502Nm3aRGdnJzfddBOSJFFdXc0VV1zB\nggULPrZ/RUVFnH322fh8PkKhEDfccIOjiTU3N7N06VLuvvtuwuEwkiQxdepUAF566SUuv/xyZs6c\nCUBFRQVDhw49pjERBIFbbrmFQCCAz+cD4NJLLyUYDKIoCt/97nfZsWMH0WgU0zR5+eWX+fGPf0x5\neTmCIDBx4kQURWHu3LkcPHiQ2tpaAF577TUuuOACZDln9f5HQu5p5vCZYf78+Xzta1+jrq6Oiy66\n6IjlfD4f3/nOd3jwwQe57777jlrnvHnzmDdvHrqu8+677/KDH/yAsWPHcsopp3zi/pWWljr/BwIB\niouLndl2IBDANE1isRgtLS0UFhYSDAad8v3792fr1q2AJbyrqqqyzmXQ0NBAU1MT06dPByyNwzAM\npk2b9rH9SyaT3HPPPSxdupTu7m5M0yQej2OaJo2NjRQWFhIOh3td19jY+Klm8t57MQyD+++/n7fe\neouOjg4EQUAQBDo6Okin06TTaQYOHNirDp/Px/nnn8/rr7/Od7/7XRYsWMCvf/3r4+5TDn+byGkY\nOXxm6N+/PwMGDGDJkiWcc845Ry176aWXEolEeOedd46pbkmSOPfccxk9ejS7d+/+LLp7RFRUVNDV\n1UU8HneOHT582NGYysvLOXz4sHMuEwUG0K9fP6qrq1m9ejWrV69mzZo1rFu3jkceeeRj233yySc5\ncOAAL730EmvXrnV8AqZp0q9fP7q6uohGo72uq6qq4tChQ33WGQqFSCaTzu+WlpZeZbwmqjfeeIMP\nPviAZ555hrVr1/L+++87ZsXi4mL8fr+jRfTExRdfzOuvv86KFSsIBoNMmDDhY+85h78v5Agjh88U\n99xzD8888wyBQOCo5SRJ4uabb+bxxx8/YplXXnmFxYsXE4vFME2TxYsXs3fvXsaPH/9ZdzsLVVVV\nTJo0ifvvv590Os2OHTt46aWXmD9/PgDnn38+jz76KN3d3TQ2NvL88887144fP55wOMzjjz9OKpVC\n13V2797dyzHeF2KxGIFAgHA4TGdnZ9YMvby8nNNOO42f/OQndHd3o2maY666/PLLefnll1m5ciWm\nadLU1MS+ffsAqKmpYcGCBWiaxubNm3nrrbey2uzpY4rFYvh8PgoKCojH49x3330OoQiCwKWXXsq9\n995Lc3MzhmGwYcMGVFUFYOLEiQiCwL333ntUDTOHv1/kCCOHTw3vDHXgwIGMHTu2z3M9MW/ePCoq\nKnqF3mYQDod55JFHOPPMM5k2bRr33XcfP/nJT5g8efIR2/808NZz3333UVdXx6mnnsr3vvc9br31\nVsdHcPPNN9O/f3/mzp3L9ddfz8UXX+xcJ4oijzzyCDt27GDu3LnMmjWLu+66q0/NoCeuueYaEokE\nM2bM4Morr+xlZvrFL36BLMucf/75nHLKKTz77LOARVL33HMP99xzD1OmTOHrX/+6owHdeuut1NbW\nMn36dH7zm99w4YUXHvGewdIS+vXrx2mnnca8efOYNGlS1vk77riDUaNGcfnllzNjxgzuu+++LNK5\n+OKL2b17t0OuOfxjQTiRGyj927/9G4sWLaK0tJQ33nijzzI//elPWbJkCcFgkHvvvZcxY8acqO7k\nkEMOJxivvvoqL774Yq8Q2xz+MXBCNYxLL700a4FWTyxevJja2lrefvtt7r77bv7jP/7jRHYnhxxy\nOIFIJBL8/ve/58tf/vLn3ZUcThBOKGFMnTqVgoKCI55/7733HHV+woQJRCIRWltbT2SXcsghhxOA\npUuXMmvWLMrLy5k3b97n3Z0cThA+17DanuGJlZWVNDU1UVZW9jn2KocccvikmD17NuvXr/+8u5HD\nCcbn6vTuy32SS8yWQw455PC3ic9Vw6isrKSxsdH53djYeMTVwV6Yppkjlj6w+1AHtz2whLOmDeLW\nKycdsdyFP7BWLL/+y/m5cQSu+veFdMfSnDdzCN+9vO+1A5kxe+JHZ1NREvq/7N5x48ofLSAYUHjq\nrqOvifm/RmYsM3juJ+dRlO//VHXe98I6Fn1UR2VJiN/+6Ow+2/vpjbOYMLL8U7Xz94LMPT/6r3N5\n8vWtrNrayOCqfB66/cxPVe8JJ4yjBWHNnTuXF154gQsuuIANGzZQUFBwTOYoQRBoaYl8lt38u0V5\neb4zFoebrL+xeOqYxudwYzeK/I8TWe0di08Cv2KNQWdX4mOv37anBWFoyXH17/8S5eX5pFQD01T/\n5r+Vg3UdqGV5n6qOZMpaC2IYxhHvt+sYnu8/Glpbo6TTGgC6/ulyscEJJowf/OAHrFq1is7OTk4/\n/XRuueUWVFVFEAS+/OUvM2fOHBYvXszZZ59NMBjkZz/72Ynszj88kpkXwzi2SGlVM/6hCON4EfBZ\nn0EipX1s2cb2OGP/DgjDNE003cAwzL95jTyaUP9vGjpxKwj+ZqHrn+09n1DC+Lg8QQD//u//fiK7\n8HeFV5bsQ5FF5s0aclzXJ9M6AFqPl0TTDR59fSunnNyPiSNdDU7tMeMwTJPH39jG5FHlTKv5eNPg\nkbB6exMbdrdy/YUnIX4Ggmr7wQ4+3NjANy4Y84kJ7u3Vtai6wRdnDgFg4cqDxFMal80Z7pQJ+qw9\nIHYe6uRXf9rI9fPGkB/y9VlfY1u8z+PHCtM0eXrhDkZWFzF7fL9jvm7jnlbW7mzmG+ePQRSPPqYv\nLdpLSZGVB8swTVTNwKdIJFIaD7+yme64ypTR5Zw7bRCPvLaFc6cPomZwcVYde+q7+MvyA5w7fRDv\nravj+nljHGI9XpimydKGlQihbsy4Gz0ZswnDNE1eeGcXowYWMX1MZda1ndEUzyzcwVfPHkV5UZBn\n39zBkH4FnDahP0fC+l0trNja2Ov42wc/YEndCn404zaCspWRIJnWeGLBdubNHMLgqnyn7CtL9iHL\nIudOG8hv/7KNc6YPQhDg7dWH+OYXxyAKAk/+dTtnTBrAqIFFR+zLB+vr2X6gHd0wue6LJxHwSTy5\ncAszavpRVhTkpUV7+cYFYwgHlaOOoWGaPPfWTsYMLu41Ru3dSZ57aydfO2e0c+wXv19Pfsiq87Og\njlzywb8hvLH8AMBxE0bKJoyequfO2k7W7Wxh3c4WnrzTtWFqWna5uuYoq7Y1sWpbE9PuPH5b5yOv\nWUn65s8eStVnYO//799b0TcnDyth1snHLmQB/vC+tZdDhjBeXLQXIIswAn7rM0imdTbva2PBioNc\nOXdkVj2CYE1Qu2Kp47qHDJJpnQ83HebDTYc/EWE8+NImAM6ZNoiBFb0TEHrx15UHs9tUdXyKxKHm\nKFsPdADQHUtTkh9g4942Nu5ty3ovAH7+wkfohsmmvW0ArNhSzBmTq4+5v31hS9t2/rDzFQKjAyTW\nn+4cz2gY0YTK+x/V8/5H9VnCcEf7bl5eu5Y9eytpj2zmrmumsmhDA2xoOCJh7O7Yy0MLt2AmXOGv\n2xrGa3sXAlAfPcyIIiur73vr6li3s4Wt+9t5+DZ3hX3mmxxVXcjanS0E/DLLNh3GBEYMKCQUkFm9\nby8fRRfxmytvQBazReob+94iko7y7lsuIb+5+iCD+iusDz7HmlWDKY9PpaE1RuHivXz9vBqOhuaO\nBIs3NLB4Q0MvwvjtX7axo7YT5QN3/5JoQiWaUBH8cUw+/beYs0f8A8HRMHqYpOJHMLX01DAy139W\n+KyTCLR19y2sNUPrs62jte89F/Rnf+R9jYMsiUc890lgfMoxMY7R3OhFMq2zs30Pf2n4MwhW/6MJ\nFekomkrGrCnkdYKkEgpkz3y70xHePvgBunHs47Hy8DrrHymddTyaUDFMgx1te+hrHvzrDY9TL68H\nOUVje5y0mmnT4L9W3c/Le/6SVd4wDR5Y/yiBccuy6lM1g7Tumr/SejrrnFjYgjFkFRtatvTqQyRu\nXdfUHndqTKQ1UqqO/+RlyFUH2dDcO1/YmwfeY1nDqqx+mCbsi1i5vuSqg6iabtenO/3f3bG3z/e3\nuaNvDVc3dGqL30Cu3oUiZYt1sbCFwIQlJPx1fV77SZAjjL8RaJ+BQyp5BOdW0kMY3pdQ7aFhpNTP\nmjA+m3qCfstk1BntTRgN0UZuXfRvvFu7uNe5tOreX8/xTXvuvafZrOe4eMt8WsI4VptyR7KTrW07\nevUn84w/CZIpjf/Z8Bh7YzsR861913XDxMRECHXjG7GepfUre10nFjcSGLsSZeBOeloWf7n2N7y2\ndyEfNW/qdd2KhjV9Po+D3XZG3VS2hhRNqLy463We3v00UqmV+TcSt4R5d9p1Uguyagl9ezzEgg4a\nYo28V7skS1uOqwn3HvLbnf9VzeBQpN75HVOzha/cfy9SUSsv7sqO4gKIJlUEX5zD7VGHaFXNQNMM\nBNF+pj0GyUtOSO5zkyWRjnRH1m9w39E/7XqNB9Y/yrrmjb364TWJxpJu/S2JNkx/DKX/Pgrzss2p\nUqG1GDoldvWq75MiRxh/IzheQaRqOnsbupw6hGCEpsKlJDUrpfXuuk7W7Gy2CitJNh5yzRVb9reh\n6Qa76zqtvReSfQuj7liaw22xY+pPbVME5DRCsBvdMDnUHKUj0lvQ17VE+WhXizO7OhqKwn7kfnvZ\nl9rE4bYYLZ2uQFh52MrY+vq+N3td53Wm9iTDlGe89R47+Xn71BlN0dQedz7mhJo9O/44dMfTNLS6\nY9fXxKC2KdKLDH+8/B4e3vgknamurPtNqTrN8RY6U9kf//aWvexv7OxTgzkcdVOaC7Lb/2hCQ648\niFTSxO93vuwcb+uy3h2l2kojLxa0O+S7pXU7d3z4/9GWtASxZljvzKHmKA2tMZYcWM/zO17klT0L\n2LinlbSqs35XC3/e9i4dKYusDDF7DOtaoyypt7bHFQKWQFy5rQnTNNnZ7ppXBMW6LqNhSCWufyLz\n/bR0Jlm5y33HMwQJ1thn+g0QVd3nEjHakeyykXSUVdsbWbuj2Tm/uHYVgYlLSBbudrSvnYc62RV3\ntZGE5qaRB9hS56ac9457Q2uMpqSVHNLURYcwDhyO0NAa48N6a0vdPa3ZGoGmG/z10EL8J60ATJra\nrfdibeP6LHI50Nht/aOk8I34CKnEakslu3/Hg5wP428ExzNzBPjv59exYvNhbr9yIsm0hlxeRzJY\nx57O/QwMDONnz3/klA1OWsTjexYB5wHw4gd7WbDccgJ/4/yarFl3Bg3RRn76wROk9kzgkVvOR5Gl\nI/ZFNwx+8tQafKM2IxW1cChaw2//ZM0Ye9rI//2J1QB8/bzRnD7xyNuJAsgSyAP20KqH+NHjpVn1\ntdtCqMTf2+HoJYxkSifPY1ZJpjUK7JmYN6pMHrCLneHFrG5UmV41mcde30pDWxzdMJEqD9A+eAeH\nIoMZmJ9tO09oSVJ6iiJ/YdbxH/5mGZpu8ugPT0eRxSzCeLd2MZXBCn711OGse1IN913oSnXT2elu\n5BRPqfxy3aP0z6vi+5NvBGB72y4e2vxbzHSA6092d//LoDHu7gsu+JKOP+ZQpA653J1xx9UEISXI\n7f+7HDAQApZANVNB0jaJvrxnQZagTWgJuqIp7n56Dbphogzajmwnb3jwlbV8oWYgK3cdJDhpkdsH\nOTsqavPBRoKZva1MS3j+/t3dHPQvYX2LZ5ZtC92uWBowkIotwhBMKWvC9YclW/DXZNpyiVjVDKIe\noo157mOT9h7YCoJu6jz6149AzaToN2kOWxqYWNgKh4cBsC+6C7/hrm6PezSWw20xHv7ravwZ/7Os\ngt2VNTua8U9oQrSXnkiS1XBbd5IfP7GcoL3X1vvrGri8xnAIZdX2w6jFe61ZvpKmuSNOVbnCU9t+\nnzWeO2o7AAGl/16kEpf0eo778SCnYXzO2Nd1gOe2/4l46thnrpqh8dy2P7Gv6wArNlvCpqUraWkY\n9kee0JJEssIVvTNPj/qe0kDUeLnlCbZE1/Zq65FNTyHkdSH320c0cXRSy3y0UpE1o13ZvKJXmc5U\nF49vfg4hYKX7jnn6uKN9Ny9sf7GXXTxhRhBEE1Po3X5H0iKM4kBvwuiMuTPzbEI2Wd+yyTGneM1E\nctVBNCnBxpatmKbJwaYI3bE0CAa+wZaJ6EB3tlO5NdHGv3z4E/5r1f29+q77u5BKGhytRbXbEnwJ\nXtmzgEc2P9Wr33s79zv/d6W6s3xQjYkmYmqcxpgrCFoSrXadSZ7Y9Thi2DV3AFnaiOBLUhS2JNVH\nvJJVzjv7FgJxx8IiyKqjYfQPV2VdE1XjtEdSrs/D7wpNQVGpa445moFzXNJBcN9B7+wbyX0fssgC\nMIwonQdWEImriMXNCIodXSXoJNLudVnt2f/Xr36SSDRKV6o7q+8ZJOjEiBWgHh5i1eFzZ+OZb6on\npEIrIEA9NArINnF1x9IIfvf9y+qToCP4rHOCZCDJ7jsj5nVnXdNqa3sHuw+xtMPdbEzwWd93T7Oa\n1THrfTH17MndsEFH36PmWJAjjM8Z9617mJWH17KtfYdz7OOcxR81b2Jl41ruW/ewcywvoPQgjESW\nDd/7ASBlCzUx3ElajLBLX96rrbakJXxM1Z8l3PtCMmU77aJWyGRnuqNXmUc2PsWGls3I/Syh6J3d\n/3rD4yw/vIZaj50ZIClYH5Ep9m6/w+6fX8peKdyZ6uKJgw+gDLYitrwzUKmsntfrXubJLVYKbs0x\nSZkgWuW60910x9Ik7HvyCkJnKmpjY8tWDNMgriWIqO6+F2lVJ3DycnwjNtEcs4SLrhuI4Q4CE10b\nf0Z4uPW5Zo7udMTtu6RSn7DGLaJGHRu5LGY7pMViV6MQi5pZ3f2+p62ko1n1RGui3Xaqm/hGeyYP\nctrRMCQhW2TE1FiWJpcxKYFFBLGkmkUIercdLeSx6XuFqSBpDKrsOwrM0KN0HlxBJJ7GN3gbpiFg\nJKwFfwnNHXdBSXn+t+oeMP2byEowizAyGoZmaJiCjqkpmGlLqGa+F8GXIDB+qVufl0jsdvRua11O\nTHPvfWf3dnxDtmeNhXeMvO4OFY/PJeya0JDTNLZbdf732oeoVbdl9SMaV4lrvQnDuX+jhzVA+vQa\nRs4k9TeCtMdubpgm0lHWL/S0lYJlDoqnUwglCadMWnDr9M6SBEnD1F0hk0UmR4IufewCq2TGTyBb\nwiClZ9vldUPnUNQyUZmqJeD7cgKrRnY7qhSxZjaiYc1MbbOFZmh02U5Rb8QLQF2kAR0NufIQWtNg\nkmndIWKxwBLeJiYJLen2QUk5H3JXqtv5WCF7jBJq9nh5TRGdqS7HLNXs8T3s7NjDkNIqVN1ALGrO\nul4saEdvtcxyHzVvYkm9q5l1pSOItnYUnPIe2zy32ZnqpCJUTlLLJhzRQ27+UR9lnUNJUZTn46BH\ny0zvOxnfsC00x1qJhVWEvG5Ev0cwejSMuN3WndO+z71rHiCmxok6zl0zezYuq0SjKkKBdX6wOpO9\nyUNQ0IEgq5iaDyGvC/+Y1VltDanKp7ap94ZTzWtXoMbbePBn38McahKuGknzijUopSbJg1sYdOoP\nqV/zDIZQj2molM0cSPFJFqHse+9nbJ/4bbbVr2bfcxvJG1TE/ob1bB24ih/ffbfVgC5jpizzX6xp\nOy2rX8IUUygb0wy+fCxSyI9pxmnc+CeSnXWIqxNUnjEIfypE9+42/vT4U7znf5WioiLU+WEaP9iP\n5JMoP2UQgpzmwOL7GTD9GyhVm9jx65WEhxQTr+vGd9ZYmjavJdl5COQIReNLqDpjKIKSYv2GjTx6\n3wvsbtmHKIsMu2Yi+5/fSMXkcqLJycRUld2/XUf1haMJ2kQrKGnMZLYJSjDFLBPc8SJHGJ8Cj72+\nleqKMBd8YfDHll27o5n31tXx/740AZ/S2w+wfHMTYAmaX/xuPd+/YgKvr9iNIMCX5libSv158V40\n3aB0RG/BrWoGcbPLEXhr9zTQb6ibF0kIel4WMVvD8M4KAXx2qoy31rqmEUSjF2F8uLGBrQfa+fb8\nsTzxl+2I9uRTsGePqulKt39/YhVCqAsyoeO2SSKjYUTSroCIqjF+984u8kMKF8wcjOGLuqqwpIHm\n44/v70aVXVNLysgmjPakO1MT/AmSac1pSwy5kTc/enoRna3WjNtLCm2JLnYeyq7D6V86gW4YPPDi\nJmaMqaRRd/vRmewCe01aoydQYE/3Ps5ltkVOerZGkJkRNnXEeXbVe5AHp/afyYcNK+hOdRNIZZtw\nMuhIdlERKidhE7PW2h9faTOGp6+mJiPYBB4S84kF4hT4fM4MX2+vwIhZ791bm7YxavoUZzast1eC\npCEVtpHUUqRUnX2NbQiKyLLVUZIb57DCFFlt2DNpSSO1cY6nbQXDMKCznOSGOew2/OjmaLSG4Zj2\n7NdMB0husK6RShoRC1sZXFUAG9090zOonD2BZFOCC274Aet5lc51PhLNbQz68jTMhjkYUeh/6lxC\n47dgqDq7H11HwUjL1yQIAjvTFjGl2xMMu2ISgeIyzPc6efbFV2EwmLqMmbSEbv6IPIK+m5FKGoh0\nvUnj2+30O30cTWtWIPoGMmTObfgnLEZPpkltTlP3wQ5GXDeFSmEW3z3lbH665efW/dn3KfhS9t8E\nUkE7qbYEA848jeoL06jpA5TVnIevsgNl2Eb2PrmZxKgkgao2nv3lvdzz01/wh8hLpFIJ9JbhlExp\noWvPFnYWnsaqvRswdYNgZRhTlxAkHd+odSTXneU8Y/XQSIoGtmWZ4I4XOcI4TpimycptTbCt6ZgI\n4+FXLTPDtgMdWautM2iLxsgQxu66LhZtqGeJ8RSmIfIl7sU0TRassGznlw/pHVmk6gYJwRVcta0d\nxPpZAr60wE+XR033mgPAQLRnhULm5bZZ509LNxPIbJ8t6r0I46mFlhlt3qwh2StqbdVX8xBGXUsM\nqbKOjDFEsEkrE6FUH3UFRDQd4911ll3+zMnV2dqRqGHi463Vh5BKDuMbYR1P6T0JwzWHCZJGMq1b\nhCEYCEGXnLrT3UAZYKIM3OVWIBi8unwXYAn3LA0jneJwa5yt+9vZur+dorFNYKdC6vD4Cxq6XZ9A\nl+2c13TDGZ9Qqpq4v85x5m7Z105STyIBs8pO48OGFXSlIwiq3uOZ2fdo15mJiNOaBpNXHEf3R7F8\nVgJmOoAgW/dbKJYTV/bhl9IOqZu6jJkIY2oKCV8zb66qdWamelcpYoF1D3E1wZZ97cTVJAISopKZ\nHRhoug6IzkTE1BS7jky4qfXXMHCseYKoW85eUUQ3PSHO4U5KigVkScA0rUhVpWkcauVmRJsIu9Jd\n4ANT9REor8BXFCDVbOeSalzHzqUbQVNIR1Ko3d3IFbVWqpRoAUIgSaikADlcjq5ECBcNYMee/UiD\nsTSMZAgVe2coAAAgAElEQVQjGUIz66lb9Ti61g5SAtlfweDScvbu66BqwjzARPSlkM0COjtqCQ8p\nwlcUoN3YwMrdE62bMUFrGoSpg1RxyHr37O/QVxBEjp2BEVuBmNdN5PA6utd/CIKK2mmS6ugCOQ+l\nSCQiFqGhQrICvWUoRSfvo+n9j6gfGKH90DpKJvWjguHUbqsiMG4ZgqRbwQ32M9ZaqvEPitOuu+/j\n8SLnwzhO9Ey/cawwj7RAX8wWCM2dthAXrY/Jigyx0BrrnUBN1QxU0bXPIqnUtViC4sJThhIMecxT\nGR+GqBOY/D5SiW3zNi3C0DQDwzSz7bWSSxh7Ovez6NAypLI6lOEbaenwmEQE3YlL1/EIDUCyTUHW\nj8yqdOt8c9wN/Yx6/AD3b3jIcS5a11njJPgS+Ea4TtGeJikvYZAhDN1AGbTD0sJMwa7Hnvn540i2\ncDRVW+OwP24hrwtlwF6nuriazHqKpmchmtdG3p5w+xDXrdmdphvOhyx1DrTqtwV0bVPENhdKiIYP\nRZTpSnWRTGsIHvtzRdCacOzptBZ/ZQgDTYZ0yNIoMuYIu63/N/kmCgQr3Uta6XDMhugKgyrzMbqL\nEf0JyyeUaUtXQLPGIqHHiSbSlilJV/jSmSMITFxMYOJifKOsSKHAycvwjV6LGIoQmLgYuf9+EHQC\n45ZaPht/FLGgzbnu9usHMe2sJue3MmgngqTzSsPvqShREATQO8vpPjgAvbMcMa8bQdSJaDYpaz5E\nyeqff9RHjCnvJNVcx8hvT2XIWTcQrCjC0HR8Q2zbv01oAbEAIxFGEE1SZoqUaj3nYZWlgIDRVUb9\nwu0U10xi+BWXUX3haEzN4LSRYyw3V7iLqgofCAZjq/sxwJM4URAN3k09ab0XyUICYhCtfgSCaODz\nPBdBCFptRYtIdyTorHuHEdePY+T1c8grr0HrtFaoC5LGa13/a9Wn+UD1kxcIkz+8hGjzRiJ12yke\nV4nZWo2ZyEdtsFauC/6EO8nQFQrkEnxS376rT4IcYRwn+lrcdSS0JToITPwAsbAl67jXuS30cEQf\naM+2cze1xwETqayO2m7LKWwa7uPTdANN8URYSBqHbDuwTxadWSzgkJPgSzjmCsg4lU10w6S1K5nt\n2/BoGL/66H95cfdr+IZtQS49TG17b4FuNYBr/pJTiIWtGCnbqZjRMOwxaEm4dURSrkbRmMg2TWSE\nrVdLKPIX9iKMTMy/dY1KMq2hGabjP0jvH2t1K2Bf57H3aq22GcMmE9leTJZBQku6z05Oo/pcYsi0\na5om27QPneMZ56Smm05biYjlx8kQxsHGiDV+ukxKMygJFNOe7LSc3vZzUvQwd07/PmXBUtY1bSCu\nJhyflqnL6HEr/YMYtLQMQUkR0ssYUTSUAsEimqjQ7BCQqcvkh3yOWapN2YNUbI2RqfkcX1Pc7Ka2\nvc0aE122khp2WvZFQUkC1n2Zmg9Ts7UyOY3cf58zjqbmw+ioRK21Yk3rY41Zi+wyaEk24Q+a9jWW\nEURrGYDokzH0JHHDevam6sfUPUYSqRMpKCPKIsm2CPGG7KALQdIsTSqlOSlDkkYcVbfGNuy3828l\n8zBSOr5CGUFWad/QCKbA2LIa8keU0LlvFSm/pVH7NR9SuJrI7iRp22el2d+JEixGSDRhJMLEGyJE\nOzsQ82yys18fMxVET+mIPgnRL5FuDBJr2YnePJBgXn/UaJp4g/Vd63ER0zQZVjCUkmkVtGx/hdCA\nAqSgQn2T6tQHOBqGaYhgikzNP527Z97Za6w/KXImqeNEz7QaALXddeyvTzBr5HAUWcQwTdbvamWX\nsRTBl8I3YgO6foZTPsuM4o0aCUY4rLaR8XRs3NtMR0RFzG/HN2wLhzPWJU8UxKY9bZgFUTAES1BL\nGrV1liaiyFLWLNgRut64bDUAStIS8IZMXXO0F2HEEmqfC+0aOj3CWe5hOpE0MGTE/A4E0URrHYA4\nYK9DJJpmsGxLA+vbXHNQdyoKFPRpt3c0DJsAvzr6Mt7Zv4yo1sq2/W3UHe6iJXWY/V3uoilkjb31\n3fQvy0OQVYxYAUbMcjTkF+qk6tyxUA+NtGZyeKJNMuTWVerY8zNrVvwnrcAUVcsMI2mOZlPb1kZM\nsEjQSAVI+ZNEEylWbWtyxj/SKRMwcQiktjlKoFrDTPtZta2J4tJimuIttGh1jqCRuwdxoD7G2KKx\nLD68hHV1u0lmggt0mURHPkopSOWHEIuaEUSTZERm8YZ65GQZpilRl96NMtg1H+UXKZjddhK+4u3O\ne2dqCoYdAXQgtpe60IcgWOGaq3c0kd49Cd+kd0FWEfwJS3NTfZjpgKUlFbZhJj35izQFENC7SlGA\nDc2b2W1rST1hZhb3ZQhB9SGHFEJVFax78vcU1BTiZwiioZA+WINv8A4aBx4kujRG08OrkcVWAkWD\n7f6bCIL1jI2khKabTnRVVypCLKVRSMhJRGimglSeMYT6hS+jbA4QGlhAyhQp8hcy8NRx7PvrWjb+\n+SEQBUZdWU1Mq6by5MvZ/7s/ASZyno/h10wkXF1Dqm45+xc8R94IEX9pyF33YgduGKkgwUFhgv3C\n7HxoNUqgjGDxUECkIDKRwZcdpv4v2zE0A8HMY8DkGkYVj2Br/21IAZGSSVY+MkO1EwymLcKQ++2z\nfJP2+IV8PvJ9R89BdizIEcZxomfiPt3Q+fna/wFg+66vc+NFJ7N002GeXriD/hNbwQcYEgnPegBv\nDLXgIQwrB46LB1/5iEElJb3CYREM8kM+IvE0O+vbCFRGMOMFVtiepNFt57/xKyKm5PVhZMI0M06x\nUYihbqTSRkfAb2/Zh1LtWWVraxjrdmVrSQDN3RHIiJketnZB0jBVV8CbSVt9t4Xwxr1trGhfhNLf\nDQWNZKI5PGYYPVKElN/pjpMtZNvadRpbU0gFGnc89CEgIA/cgdLPQK0bgVK9B0HS2LCnlQ17mwlO\n0zE0BdNelCVlTFKOTV9xZtWCkrQCADIhlgemwIS3SeopVDsiTAzY8fSyiqiFaEtYZq3/+v1ylLGg\nNVdbZORP8r2H3wfVj2+MhmkIFuHrisfcZIKkYep5vLu2Dt+QBFIFNBa/j8+ORu3qMvn579ZTPcKA\nEvjj7j9jyvYs3RTRuotQALnM1czSCZln3twJgDK8nEhpI2JGjusy4aAbTpoFTcFI52OqCnKFu+pY\nDEb57RvbMUxLQImhCIEJS6wuaAoYMnpbP+sa2+8gdw90TJ5mMoxihhyyEBD4cuVNPL/5dacdof9u\n0HA0iAyJV593CmJBO4KskvxoNGd8fQ5rDlkmp6SQoPqrJxEQQ3SsPM1qt3wHcIARX/4acngHkhxm\nyJzbMFWLgMMnjySUCgHbCSlBQMNMhigcX05hjbvZ0ujpVwKQ5ytm0KUnOcdPH3cm+Wsklqo6pefN\n6hGZlMcPf/Rz/ufljwhOfTdraIecfROyJqLb79qgS6w6UzunYHRZ7QbS+YxWvkno238EQG0Yilbn\n46Sykfxpg/Xe5o8occcd8JsWKYghWwuzAz0CviMvuP0kyJmkjhM9NYxme/EUwOrtlkq/z07Z0RW3\nPhrTELPWA8Q0T+RSH07NDARRo7Y548j0HJd0CsLWiyLmdSGIJka0yI6McV9cRRYxPITRU8MwNcX5\nMDPn6hPZi9MyUVKxPhbvtXt8KkLPWO8+2hJN2THBdcfSjo9CShTjl3xE7Ygp12nXH1oH91lfdwQw\n7HmPTUKZ64yovZgv0yfJ7QOaYglYKZFdRpO5cJplMhF8KQZV5CMoKUxdol9xIaYukVCTpDSjlwYk\nqnl0pSOkdRVDsAlS87k+EZs0BVm1Z36C7SBOO/0XBNOZFRqpIL1gR1jV19rCV/aadATQ/Jjp7DUp\nGQIEMDqzd5zLmKRMtTdhmLZGYESzU5+bhuSkIDHV7Igv2bTqccYea11Ov9gpngpERhinuj8xmTVm\nEN+ddhWTS6xlzo3avqz7dbQ+XxLBH8eI53PV2aO5ft4Yrj9nalYf8hU3Q22GCH1DtyJIumvCsutD\ndgMA8ny2hpHuPe43zLOiP4JiftbxylAFXzt7FP9x7TR8PVwERrSIQRVh9/304AeXTWf2uH4Y8QKM\nuFun939ZEqgqDZHaPg0z7XdCrzd+uJbdj6+l6qxhDM0fglQ3BQyZorCP//z66VntZL6THGF8zujp\nw2iI9g4DzJCDKGUMltmE0Z3yOK+dUNc+nOJStiD0oiDfeoQZk4UeKbY+fiXtCEFJBiTVTQbo+BWs\n8yWhsPsh2W2kNDvssqvUmg3bGkYqbc+OPfDaoTMLpRw7dg+NwNR8Vqih2Nu0JR6YSZG/iC579ud1\nzJaGM05AeywywlfzuStae2hOVeHSrD44JKpbgtBn5qEKcfucq2GcNtZyHKKkGFQVRvAnMdMBqsvD\nln9BT6NqRtaCO61xMELamrZ/cOhDN4eRprjCTk4jD9yBGIw54ZboEoIvbZmQQvYCxcysug/CyNj0\nzXQAI9l3umovQQCOoAHQu0t7FBYI+qS+NQy7jxnzHcDs/jMYFHfNqpnV1hmU5RXafXClp5EMUxTO\nlqaFRjUPnH4PY0tr+GrNZUiiyPjhpVQXZffPebaagmlaa1YEAcx4mJrBxUiiyPQRgykLlCCLMpMq\nxnN6v9Pd6+PZAj4gWfdZErLfJ1l1vpOwL5R131747b1A8noQRnmwFJ8iMbgqH9MT7ZXeOx40n5OO\nRu8qRcKtt6a6wtJeDYnUllMY4B9ivZeeZyeJApUlIYxIKckNZzghv+ef/0V+/L8/Y9bsU7l50nVU\nYKXqLykIUBLOo+DwaaS2zbBuJWb199PuZZJBziR1nPDmBDJNk4Zo741aMoudRMkua4hZSe/aeoR+\nWoV7C1JB0qwgyZ4mKSAUykT72I7PVBCjuwQpvwMxvx2js5LX619yzgmBBMqAvWjNAx1tYES/MtbU\ntme1ldRS4AO1biT+0WudKKmkpiKIJnpnGXpbf3zDN3kcxiZyv/2YJugd5cjlDS7ZZcpoCromZt2n\noKQQ1TxSSYHSQDFN8WYQNQTbOW9qClXFhXR67tNLQJkP3Aq59Ttj6RdDCAh9ajkAQTGPLpoAN9QV\nXabAlw+m1a+64GKEtIqp+qgsCUKHhGqmSau6syq3Mj2eA7X9CA6tBaxEiMpAPP2zSF0qbUSusHwr\n7sI46/n5hm51H6pNGHpXH9sVO2s4BFJbZiGGO/HXrHVs8pBNGKmdUzBTHmJRA4wqqGFXtxUSbSbz\nrDDqPoRkpm+GZ0+Jr9RcxsPbNwMtWWOZQUAMuved6U86QFFpNokZpokiynxnwjezjs+pnoUoiAzM\nH8Db+5ew0dGIBNAUh6CMeIEVzAGIgsh/zPwXTNNEEiUi8TTPYq3OHhQeTO0u3VnAWJIXphsYUFrI\nLl1CUNKIpjVpKA4UAJZ14Cv9b+Cpv+5AGbKNEZXucwhIQSezzpSKCUiiO26Z8GC1YRh6mxU4kVnT\nlN45lSvPG41QVks0HUUUxKy8bFcN/xq/fnkzCVxLgCSJR9xPZu6g05g76DT7/oWsv/5UBUY0yuDO\neezYZ1kxPqudNXMaxnHCq2E89PJmlh305sI3eWt1LRv2WGaqhO7amJdsbOAHv1nGDx9exs5GTwqM\noxBG5pwv4LaZSbFQG7TTPkiuQDZsQZMJE21OWmTmnWlKth0YYOzAKlcjsDWE1qgdhaTLljARdeJJ\njZitTZi67NEi3BBOMRTB6C7FiJTY5/oQ1rpkCUw5TSaSRzaDpFSdPMme+fkTjoYRkP2UBSqsuPxQ\nt3WNrWHEYgJkZqGiq2GYhoBf8hOQ/S5ZeUgLICSFARPfiI1O/itTU5BECUH3I/iTHEpbfhwjUoxf\nkZBQMFDpSHXgG24984BZCAjEunrvlmZqiiOwM2QBONpeJmIo65oMKegKqe3TUBuGec65c7zivDyM\n7jLmyNeS2jrTLWObpExDdOzhXnxp6JepqL2MxPrTMZNhR/CqDUPRGgdz57Tvc8fU7znlja5S9O5i\nrj3pq9b9emar6r5xaI2DLU3UhALJ3r5W9RKGv1fK7SPt6xGQA5w9+HRqSkZy3divY6ZDVBRnk5Bp\nWuG2smffB1EQHeHtTTJZXZ6P0VnhaGahgFWmqiRkmQP9CcTCVoRUPpVhV7vpX1COmQ6S3jWF2cXn\nO8eDinsf3zz5qqy+OyHzpquBC4Jgj69AMmUwp3oWXxx2DuDuJW/1J8+a3HggScIxbUCW2aUv8zez\njsqnF4Lmd+r6LJDTMI4TXsLYcOgAgWKPM1jS+OP7vdMyI2pZi98+OlBL5vtyBGsfZqfMMcVnkAZS\n26chV1o+hiitWAI0EyapYCZsf4Q/AZjEtThiogitYTi+/vsxRZ3+VQqqItEFDC4rwUzYaQWCEaev\nAKP6l9AgKhiiRgroiEUhz1LfC0tLaAIKC0UKggaGqNEMmKmQa4bx234a2TaJ6bITIRKc/L6V7E0A\nn2l9GPFuxel7hoiCcoCg7MfszkMq6CAw5V3XBxLVPFFNacyEPV6GjF+WCEgBYpLttLZJZkBxMeWF\n5ZRVRjncsMddh4IrkAv9BXTq1jM1Yvmoh0YhjxbxC3kkpC7WR62QWb2rlCADgE6MVO+Pe9zgKjZu\nVjENwVmfotYPR++01kQU0p90ohgj6AkB9chSI1JKuTIQv5yiRavP0h5GDypi5dYmGpu1LDt55h6O\ntD+SLIkEfYqTjfULY6vYUduBblQwqDLfycT7pTMMJ6tvLDmKaVVDgGx7uJkOotZamQgQNQomlAP1\nWRqGqAcJBrJFzbFsJBXyy5wzbSBD+xVwsDHCYqx3yegu4dSThvQyczntiQLzZg2hrDDA5FHlmILA\nTsqI0UhC7GDK6GmcMq4fi1f5EPyWGfCU6slZRBgOevrr6eqXJp3G3sV7OXf4ab3a/ebYq3h++4uc\nVDwJIRxkUKU1+fnxNVN5fel+5kzMznA8fngZm/a2MbgqH38fPgZZFKksDvGFkyqtRcJHwNfOGYUg\nwFVnj7Lv3zpumCY3XXwym/a0UlHUhz/sOJAjjOOE1+mdlTAMPLmaTKSyekQ79YaXDARfAqmkCcGU\n0ZN+O1bd7B0JBc7MOUMY86aP5C9rTDd1sZy2BbLghiEaohUWKqvopo6oBQCBkeo57PIvJL9fC/u7\nLRNKcSjsONsy2TIzff32vAk8vHk1TVFLW+mIRyAPJg/vx1nDJnL3ync4eVQeqxvfdrprpv0Y3SWY\nhkigooVo/UhL+NtOVG+4rtLvAAB+wRK2sYgCPpswbNJSBL/luI8VIgZjDlmYhkB7JIUpZaKabHVe\n0jB1GZ8iEZQDCJLtRLcjdvoVFfLtOeN4v7Z3csSM9jG8tD/rmi3C0BqHgO5DlkTyKSNBA43GXkxV\nIb1zCsnBlmTuy+cwdXg1GzfWYkRKHOe+1jDcIc1pNRXUh8vYF3P74kSSYZkZ7vnWDDRjCk+/s5kV\nmvuuDanMZ+XWJg4191jImVmUeAT7gSQKWcJRkUW+deHYXuXOmzGoz+sD/iM4UL3OXc//kh7o5XQ1\njmEZkyAIzla5M06qpHHdaHZ27USrH8k3bh5z1GsvPc3Vyu78+jSW7/LxwPpHOGfw6cz4wjgg25x2\n/qjZ+D0+hnDQoyF5CCPPH+Cn59zUZ5tTKicwpXJCr+PV5WG+c8m4XscHVoT5169NOeI9SJKAKAp8\ne/5Y5k6p5r+eW9dnuZKCALdcNt75ndEwTNN6v6bVVByxjU+KnEnqOOENq80IIjFtxzlnVvKWHsY3\nzLPdo3fXrSorT5NsBDET1voAb8RGFjL1KVabhUHLFKE1WpFD/lEfWZu/2AIZBGQziOCP4x+zyuqj\nZs0mFXsmnyELsGbw6ApGKmjnWXKJK98fxC/50AXLfNSVtMgvpAQJyZaAbI67EWJg29AN2TJlKBGU\nYZusML+MIOuR7hogKFpj195qvZK+wdtRBlob+CiCD0US0Vtck1p6/1jSuyfRGU31zjBqL37zySJB\nOWiNn5J0MuTmS5Y5r6iPlOgZQd4vz90vORMlZJgmJZnNHsBe7CaSyKQf13rPeDOOVL3Ds/+y6X52\niizik+2oqGSIacp8x/4N1kxXEAQUScHXY0/monzL1JPZutbnmDhs34PZt1SWJfFTRc18UgeqpId6\nXXM8W9VeUHUuh1/194raOhaMLB7G7JbxTCw52TmWeV/KxGqK/IX4PTnegh5S/LTb6h4vvCY3fx/5\n546ETD6549nO9+PwT0UYa3c0H/POcQBbD7Q75ZfVr+LB9Y+x6NAyHlz/GCnVE29tE0Y6Ypt1MkK/\n15oETyimLTTDTV/AsGeUYjCWlSJEtGc8gqSBnKbbb4UaFoes8hlBKYatqCKvfdtPHoKiWnWCE32h\nkD0LvnnC9c6MxIgUIShppPI6pIJ2ezcw2UkpoAzZSjRtEUbYZxGGX/JxwEM+Vr9sG3pmEZG9JsBU\nA4693IuS5BiqJWvG2NSH5q0IPkvDiJSSPlhDet/J6C0DMboqSKV1d92Eo6VpmJpHwxCsPToE0UBr\nGkSRYtmqi/wFWe3oHRVkhG2VhzAymkNbV5Iqv0tahm3Gc/ercG1AWtMgpGgV+b48u+5K+xpXewCL\nMBz7sikwrGBYFqH0lagyg4BPptJj4y7Jt94Hvd0itUtGfLHP6yRJ+FSE8UmEF1ihtgGlp4bxyYWZ\npBl07l7/8QWPgJdf+hOplKvdavUj0CNFTAudC1imrAwET7boz5IvdP3Yd9b07rkuHmX/9Z5wNYwc\nYRw3uqIpHn51Cz96fNUxlVc1nfv+sMEp/7udf2ZXxx5e3P0auzr20Kl6NqXxxzFNHD+AQxR9PS8x\n21fR1OheJwSijrlletVkTi+4xDouq/g9+xP0K7EEXc9QSO8aiJ6z0YyJQBEUfPb+CacNmMWY0lFO\nkYxQcyN2rBfvi0MtJ51cUYdUbS0AG5jfD0mUmFHVW6XOxPR7yUnvLkatHU1Bno/UzsluWU1mgDrN\nEaz0CAkFCAhhZJto9KYh6K3V2e1lNAwlZa9lsO7Xp9gaBm5Kc729Etk28hb63N3xUtunkd7t9qsq\nz1Lj8+QQX5w5BIAR1YUUhcKOZpcJLvBubas1DsJI5KEeHENB8yn4MpEwqp/k5lmkdkzP6rsiiRT5\n8537KAhlayleQpgwIjtqKuCTGFLlRjBl9pEw4wXcPv7HzB14GtPHWPcxdog7K5dFkZKCQK/6jxVe\nshk+wHoX+9v5lE4aUuwIuglcSGr3RBRJ7mXGOnlYySdu9+knH0GNt1G37H94+GFrkezvfvcc3/rW\n17n22q/y5JOPAZBMJvmXf/k+3/jGV7nmmitZuHAhL730B1pbW7jllhu59VbLpKS39yO9/QsUBQp4\n+unf8q1vXcOBxffTtOnPAEwaWUY61sbzj/yEa6/9KtdddzUNDVagygsvPMM111zJN77xVR599DcA\n3HLLDezcaUWfdXV1csUV8wFYuPAv3HXXndxxx//jtttuIZFIcOut3+G6667mmmu+wtKl1t4oM0+u\npLtuHQcW/4qDSx5g6YLHiMfjXHHFRYSD1viVF4hcccX8oxKP6DFJfdb4p/FhJNVjZ3bg43eX09zw\nN9GfsNIhZGyits9h6thiNtmLuX1qCWmlHUHWkIUApcUSHYaAronki4VowMk1QbbssGZANcUj6U5b\nAk0IRrN24qosCvOf189gb9ce/njQTcAnKCrhoEI0oeLvQRhCOjOzFcj35dOWbGdcmWsHlkQBvbOc\nfMqJ2CGTGfIaXjSEfKmQiN6F6E+iNQxl9BlW7Pe4spOy9m8A+NcvzaIsXMi6Vh8v77MIRj00GjNW\nREF/H60NFWiNg5GrDmJqlvbgmizcmZR6aBR6exWhYWEUj3r+X9+awdqdLbyyxNK4fvSVmfxqxyLL\nxODJwqrIEtjpHjIRY0a8wJnRF/o9C6aS2TP/fnmVfHvcNQzKH0CRv5DZ4/tRWRwimlBRa2vQmgc6\ncfEZk9SVZ47gD+5eRUii4BAdgJnI1mgAZFnk4hEXsHhjA2r9CAKTJX51y2zASo8+sMJN5zBxRBnX\nnl/D03aW4IBP4tLThjFpZBmSKFKc73cWjQZ9fgRB4LovnsQVp49g+ZbDbD1g+UkkSeC8GYMYNbDo\nmKJweiLgk5EH7kAqaSSWH6BqqDWrrTIMXm1bReUpJqYJO9M6yiCVuLiLx3Z/gH+CvTo56OPt6Cre\n7rFf16SKcVw6Yt4R273pplvYu28vj//2BYJ+mTVrVlJXV8vjjz+LaZrcccdtbNy4gc7OdsrKyvnF\nLx6wxiIoMHWqyR//+Ht+/etHKSjIfg6yJHDZZV/m2muvJxJP88tf3M3y5Uu56eJZfLTgl3ztm9cx\ne/YcVFXFMAxWrlzO0qVLePzxZ/H5fEQivZOBWnDf5a1bN/Pss38kHA5jGAY/+9kvCYVCdHV1csMN\n32D27DlMGGDyx9YVjD3rFmKqzIzRhYRCISZPnsKm9av4+Y0z+fD9vzLo9LlI0pG1vIwyciI0jH8a\nwhCOsiFRX/i4zYISzqY1JigppEQxWo/V0uEwELdCJwdXF3KQ1Yj57ZSZZdZCsbi12tcvBtEAjZSb\nUVP2kySIqUvZ2VptDCjLo6L4JA7qk1m2qd6J9Mns+RASLPu8IsrcOP4b/O/WRkBFECxhmNbTjCwe\n7tTnUyQSKZPx+kWYFftY3vZ+VnsFvgIiiS7MtB+haYwznhUhd9Z77uAzKQ4UMaLKmtUWxz0fph1m\nmZlBqw3DEQJx1EOjUEb3bVM3U0HMVAhZErPiyPuV5hFQ3FTNAyvCBHblEVNSrjlQl/ErIrKc0T7S\nVuJDXXFsw7Loef378D9MKHcdwZXFlmDND1p+ogxZgDXmfkVi3PBS/uCJjpMlIYvo+oIii4SUIOoB\nq62AX3JCUHuGogIMKHeJLeCT8CkSowdZ2oNXQGSISpFFSgsDTsglWEQmCAKjBvbhwzkGeLUFr9lE\nsl7OvssAACAASURBVO9VEAQrd1MmlTnZ35/8KUI8JVEg6Lee2+rVq1izZjXf/OZVmKZJIpGkrq6W\n8eMn8pvfPMgjjzzEzJmzOeusU0kkbN9cH2q/KAqsW7ea3/3uOVKpJJFIhFEjRzJx4mS6OtuYPdva\nr0NRrDFcu3Y1X/zihfjspd35+fm96uyJadNmEA5b74xhGDz66ENs2LAeURRobW2ho6OdDRvWcdbc\ns9mWCoGaxh+w3rl58y7id797jtmz5/D22wu4444fH7Utx8T88cP5ifFPQxifFD0JwycqpD07waV0\n2xYqWInNArKPZIYwZBVREBDkTNK6MvoNHsRBczVSSSO0jUJHdXwOQTFIDEgZSQT7W/RLfuv1TuYh\n2NpFdbg/pw1w4+0VSeG2U77FB6/+GZ+S4ltf+CK/XW9F0fQXxnDR5AmUBIooCRQjmNZaDEEQuOak\nK1ENFcUjMH2ySCJl7fw3INhjNTBWfDxYkSU+z4KjYr8rdMaW1jC8aIjzO19xhWrGz+Bsh6r5SO+y\nzFmyLGaFFRrRAsRwt+MjUGSxVyoW78zdp0j4hSBxOe5ESpmqgiKJ+GSP2c7ug1fI5UtFdGtdWX6D\noyEv2HutBVhC0DsuVjtiVj/7Qk9C+Tj/gN/TRk9HcpZQFrPr9fb7k06eevVBkdAO1aAdquE/7zzz\niOUWbajn2Td30q8qn3/56iS+c7+Vb+rWG77gEPCngWmaXH31tcyff0mvc0888TwrVizj0UcfYteu\nzVxxxdVHrEfXNO6//xc8+eTzlJWV8+STj5FOW0EeR2oXeo+hJEnOam/rehfBoGuefeedN+ns7OSp\np15AFC0TUyqVdgg/U3OG/8eNm0Bj48/ZsOEjDMNg6NBhHA0nUsP4p/Fh6H1klz0avPtX17VEScaz\nBUU0kw7DdmIHFZ8zSxWVNGWFAYdUTE2hUClCNoIIgTiabpA0kgiGveLYFmppM+msp/BLfnyymJX+\n4bwhczllwIzenTVk0tu/wJTKCc5MMi/gY0TRUEoC1uwzbM/sg36JkBKksIfDt7TQ9jvIEiXB3qYT\nE3vmbkieaByyVro6foi+ftuJ5zIC0jtTVmSRoEf4pXZOJbVthpOCWpFEZ9V8xmneU9AGpCCC5K6+\nNhNhEARnbMFNV+FdxPSV6utJrj271/0eCflHIAxJElGU7D7JkpBFTn2h5wrcjyvvHfujOa5TPUyw\n4SP0+0Qi0z9ZErOIsCeZHStCoRDxuJuwc8aML7BgweskEta3aM3UO2htbcXv93POOefxla98jW3b\nttnX5xGL9Q56EdEQBCgoKCQej7No0XtO+YqKSj78cBH8/+3deXxU1f038M+9d2Yy2ReyEjBCEAWM\nAsomNMgiQcKSFKIsVm1Q3BGiCNIifUqr/YHlKTwqlmKlVV7Sal36M6htQUULYl0ALaKCYkggC4Ts\nyyz3PH/cmTuZbDMJmSQz+bxfr76aO3MzOXNk7ne+59zzPQCsVisaGxswdqz2d50T6FVV2he6pKRk\nHD+u/a133/1Xi7/jVFNTg+joGMiyjM8++wTFxdpNIddcMxbvvvsv2B03ljTUu9qakTELv/jFz5CZ\nOddjPzm/HIQEdf1/8z6TYXT0rozqJgHjnY8L9LBvLUqFMfkkyqqrAARr+0xD+7Y/69qh2Ft3CCkD\nTZgWNxifWBxbVzrWBBgagmBVamG122GxW2CUItEAINhkhFkxo1GthxxdqxW6C43HgCtM2HcmAWeh\nZQfNL/JOt994hT7xuiLnKrzz8WnccO1At3Puy74Sez76AZlt7A54X3Ya3vjwe/w4fTAMBoGYE8lI\n6+cakokLjcF31d9DrQ9zK2kAuLKv5uWTw4zux4C2O194iBGzr7sUj2w76Og7GUMHRmHssHiYTQom\npfXHZ9+U4e2PtbuvDAYZE0Yk4lRxNW64Vpvwbn6hHRAdjeLSH/RbcZ3lLNwDhpZhNL1gpQ2Kw/Uj\nB8JkkPGP/zQpid6GfpFmTB2djJIL9YgMNeHAl45V9HbV7ds/oF38w4KNuHHcJahtsKG8ugH9Isw4\nXlDh2N/EFfhWLRyJ4wUViPOwwKpp37dW7mHdbdfi0LESDLs0BufPu/YM6cqAMbh/BDLGDsS1l7d/\nf78zAzIoknv208kyFRERkUhLuxq33bYQ48Zdh3vvXY5Tp07h7rt/CkALKOvWbUBh4Wk8/fQWyLIE\ng8GIX/96AwBg7twsPPzwcsTGxmHLlm149JbROPhlMa4dMRBz5mTj1ltvRlJSfwwb5vp3//Of/x9s\n2vQ4duz4PYxGIzZs+A3GjZuAEye+wdKlt8JkMmL8+IlYtuxeLFq0BOvWPYp33nkL11wzps33MWPG\nTKxenYc777wVQ4ZcjpQUrXbZoEGDceutufjt//t/UIUEUTAE98y/1vE7N2LHjmcxffoMj/20cNpl\nMBpkZE1qPxPpjD4TMOwdDBhNh6TCgo1AowrVUacJySfRqDrKYjvmHAySAbPHXIG97wNh4SrGDkvA\nB582aIvpVAVGgwwDTIChEjZo6WqwYkY1AHOQAaHGYJxvOA85SCuJ7RwCmjUyDc99qU1sR5paHytN\nv9p1335yXBhyM1suakrqF4qlmcNbPO4UHR6E22+8Qj/eMOVBt+dzhs7FocNVsBamICjW/QP/iwmr\nUdFYqd+R5KSViwaCJDOcMz4hZgNuv9G9fUaDjBCzAXfPc90jP2RApB4wjI45jFszXGU0mo+qhAe5\nByfn4rembWotw5BlCbdmXI7PvynzKmBIkoRbZriX8zjwZTEaLPYWGYZzTD9nyhC3x785XYHf7NJq\nGzkvnsMujcGwSz3fOdS0nERrQ0uDkiIwKCmixW2YXRkwZEnCzVMv83hecJMMo6mLmcN47LENbsc5\nOQuRk7PQ7bH+/ZMxdux4/TguLhxlZdWYP/9mzJ9/s/74ZQOicNkAbUj1jjvuxh133N3i7w0YMBBb\ntmxr8fiSJbdhyZLb3B675JJL8ac/vaQfO1/vxhtn48YbXZP5kZFRePbZP7b6/mbOzMS/votCeVUj\nJqa51vwcOfI5rr9+GkJDPe9pERUW1O5n/WIwYLShpt41BilLEiRZdZQB1z54+hyGI8MwyAaYFCOC\nDWZUWbS7Jupt9XoZa5NRhkEKgiQJ2GXt22WoY4evIKOCEGOIXozQVurKDgaEuYJBRBsZRncINgQj\nvOpK1Kv1LdYGRAZFtJr9yJKMX123Fke/rcCfoN3RpLRykfNmYri5pkUcASDM6Brisl+Id5UfaSXD\naHUMv5PXMOeF2K4K/XZGp7aGl5rOJ3h67821ty6jPT0xJBXUZsDoMyPhF8dxyfrd7zbho48O4skn\nt/Rse9CXAkaTPbj3HzmDlIRwpCS2/MYuhMAbJ97BB7YPIIeNhloTjeo6K2CyA6pZ2zcZTSe9HUNS\njrUNEaYIVFq0Mc16W4Ne9MxkUGCStAuW3VAHBUBEkHaRM5sUfW2EWh8CUedaHxAb7PrW2XSSuic4\n747pSOXLaHMUgg1NbkFu5SLqaYiitQtM8zH60CYBo+l6ioSQJsX3HHNMVbUtV5o3L/zmrbYmwYG2\nA0bTi3dHq4h6muNoS2cDzcVoOiTV1MVkGH2B89+i84q1YsWqnmtMM30m1NubFK/Z+dZxvPHh962e\n91Hxp/jn6X2QjFZHZVSgqs4CSCqEKusZhr4VqbPOk+NiHmkKR621DjbVhjpbvV6O2qBISIrSAkGQ\no8rpgOh+UGQJ/WND9Q2Y1NpITBntWk0sSzIWXp6Nm4e2vBOkuzk3u4+NbGXvhHa43XrZygWvMxmG\ncyhh2mhtTsOstFzwd/nAKIQYQ/QsbdwQbf5mUFIrmVonr2FtTYK3J7RJIb6Oftt2ZkfhIZ3LGDob\ncDojMtQEk9G1SND576Z5JkbuenP39J0Mo9mQVFVdy2+ZAHDgzMf6zwYDYAdQWdsAqZ+AJBRXcT/F\nimnXDMC732jrAYyKI8NwLAYrrTsHi90CYXUsvpMkJEVG4kgFYIMFI+PSMOuydMxcriA4yIC3/q39\nK8m6ZjRmDnatvgaAHzW5lbYn/XTWMMwcl4Kkfh27JdKt5EJrAcPTraetPJ+SGI7fPTCp1QvnM3np\nsFhVRDjWMeRdcy++vXASI/pdgYUTbK0Oz3T2M9pehtHWIGjTINGZfQqeWpEOo6HjLX4mL73TmVRn\nBAcZsOme6/Ry448vG49Gq/2ib+vtK3qohFW7+mzAaG1hnipUFNWc0Y8jwhSUAaiu14afTIoBjfq2\nmlaMuiwW755wZhiOW9kM2sX0bK1294yzrpIkAcFG1wRs7ojF2i2pjv8Cd111Gw6e+Q+mX3pdr/1A\nGRTZbeWxtzxlGJ6+ZbeVgUQ0Wdg2Kj4NpxtO49p+Wplqc5M1b0GKCVc6VrV39Vh+cHs1mbz4wHcm\nYISYO/ex7apd1zoivEmpE4Mic/7CC66Pf++LGH0nYNjdO7+2WcAQQuBc3Xk02i0wS2FoEDVw7pVS\n3dAIAwBFcmyPaTVBMjZiQHyYPodhUrTnzAYtQJyt1Uo06HWOmq0JaLp+AQAujbgEl0a0Xk7a3zXN\nMFobjvC0uYs3m78YZAPuGrMEZWVtlWloX2djdHsXYW8+7ryAUlt6X7joSwGjWQH+2gYb7KoKxXFP\n/rbXv8TnZV/ANASIkhJQLGrgqAQAq2prGTDMtQgJUiDJzoDhvgjvrVPawh1nwDAbFZQ3qT/VlzS9\nM6q1DKM3jGl3tAKrU2hwOwHDizGFjlQhpb6hX4QZZRUNCA9ufYOonuTzgLF//348/vjjEEJg/vz5\nWLZsmdvzZ8+exerVq1FdXQ1VVZGXl4fJkyd3eTtau622tsGm1zb65OsyKLGOrVCFNrlrNDqW6jdZ\nawEAsJkgSY6tVx0BI8jgWEltcJ8QXnx9Gi6cicDll0QBF5IAoNUKr4HMLcNoLWB4uGh2xxDd0IFR\nmH1dCkYPbbmlaXuS+oUi+0eD9HpOD908Er/9y2GPv7d68Sh8W1jZar0o6tvumD0c//jPacydeGlP\nN6UFnwYMVVWxYcMG7Ny5E/Hx8ViwYAGmTZuG1FRX0btt27Zh1qxZWLhwIU6ePIk777wT+/bta+dV\nO6fVgFFvRUSIybUK3BEYZKF9iA3O4W5HUHBmGMFKCCwAqi01gKT9jp5hKO4BY2hiIpKHaIHiipjL\nsHrMciSHJnXZ+/IHngKGp3jQHd/BJUnCj9NTPZ/YijkTB+k/jxgUg4HxYThdWtPupOXll0TrQYao\nqZgIs77TYG/j0wHUo0ePIiUlBcnJyTAajcjMzMTevXvdzpEkCTU1WgmDqqoqJCQktPZSF635HAbg\nmvgur3Ls1OYIDJLzVliDM5Boj+sbGtm1eYoaa22TDEP7HXOTDCM18lL0D3Wt1gSAS8IHtJi/CHSe\n5jA8DUn1ghGrDvGz5hJ5zacZRklJCZKSXN+mExIS8MUXX7idc//99yM3NxcvvPACGhoa8Pzzz/uk\nLc3nMABXwCi+4Cho5pjA/vpULUypgKw4Aogji5BaCRh6kHEU12saMG4amtVr73jqTp7u/fc8h+Fn\nfehnzSXylk8DhjeTfvn5+Zg/fz5uv/12HD58GKtWrUJ+fr7H34uL81yDvqngkJYLuyRFQVxcOGzf\nOfZWcFz8nWXHTWatrr/z8eEpcTh9BLhueAr2lR2BMFmR0j8UZwAMHhCLuLhw1BtdpcFTkhIQE9yx\ndnZGR/uiu9VYXcG6aVunXjsQ+z45jauuSEBkWMv/PjdOuBRvHTyFMWlJ6BfZflG+1l6/pxgdE+hG\nk9Kj7ekNfdFbsC+6hk8DRmJiIs6cca1rKCkpQXy8e4XLV155Bc899xwAYOTIkWhsbER5eTliYtov\nxNbR2ycrK+tbPHa2rBplZdU478wwZOdeDY7yHxYLQs1G1Dke7xcWjGcfmoTvq7/HvjKguPw8RgyK\nxJkCQG1UUVZWjfoGV8mKxiqBsprO3ebpLWdhtd6sssJVkrppW5dMG4KbJg+Gpd6CsvqWCylzJg/G\nvOtSoFpsXr3H3tIXNpv276Wx0bt2+0Jv6YvegH3hcrGB06dzGGlpaSgoKEBRUREsFgvy8/Mxbdo0\nt3P69++PAwe0vRpPnjwJi8XiMVh0RmuT3s4hKYujLpHkGJISqgFCAHZhQ2iwUb9LyigbYTIqCHPs\n81BtrYXVsamSwbFwz9xk0tvQw7Wfeou27oKSJMljjaOeqIFERK3z6RVNURSsW7cOubm5EEJgwYIF\nSE1NxdatW5GWloYpU6Zg9erV+PnPf46dO3dClmX8z//8T5e2obSiHmFmY6tzGM7FexbHN0LnXVJQ\nZUBVYBM2hAUbcM6qfft1VkR17vNQY6lBiFFb2e2sJeVcuEcufW2tQfMd04gChc+/AqenpyM9Pd3t\nseXLl+s/p6am4qWXXmr+a13is2/K8NSrXyA4yICMMQNbPF9Tr627cGYYzklvqDIgZNiEHZEhJkj1\n2oK7CMd+FGGOIFFjrdVrSDkDhizJkCAhOaxv3TrbntZKmgeyhJgQnCqu7nCRRqLeLqDHTM5VarfL\n1jfaWpTDBoAaRwFC5/af+qS3UABVhk21Ys7ES1H+5ccoBRDpKCyoyApCDSGottSgzlYPo2xw26ti\n65QnfPiu/E9fyzCW3DAUybGhmOqopEsUKAI6YDTdx7uhlYBRGvI5vjwXAoujUrnzFlmoMoQqw6ra\ncGliBGKLJZSWa3tdOIWZQlFcp9WLuixqsNteFbLE+kBNdWdJ7d4gLFjbgpYo0AT0lc3WZKK7+Q5t\nMDagMeobbDv6vGsOo+mQlKrApmqRpNJShSDF5DY/Ue7YHQ8ALo/unasye4u+lmEQBaqADhhNM4wD\nXxa7PRfUpISPPofhzDCENofhvAOqylKtz184OSe+r44dgWmXuM/RkLu+lmEQBarAHpJqZx/vsHAJ\nztUBFpsKJeYslIhyCFUCIOlDUnbVjhpLLeIj3YvS3X3V7ThR+T0mJ/fe/St6C/YPUWAI6IBhs7e8\nldYpOFi4AobVDtOQIwAASXYEGVWGKlRcaKyAgEC0OdLt9weE98eA8P6+aHbAYYZBFBgCfEiq7QzD\nFOya09DnMJpy1IYqrdP22o4OiuraxvUhnMMgCgyBnWG0MSQlRxfjbJhrzwJLK3dQCUd5kMJqrbRJ\njJkBo7MYMIgCQ4BnGK0PSRmTvnc7Pnu+rsU5wqIVuztZqZ0bzYDRab1hRz0iuniBHTBayTDWLBmN\n/v3Cmj3qfp7JKGP2tcMAACcqTgHgkNTFemB+GtbfPqanm0FEFyGwh6RayTAGxIUhvCwIxU2315bd\nh6TSBvVDapwROAM02LXV4lHNJr2pY0Zd1rGtT4mo9wn4DMOQ9B2MqYfhzCIUWYLavCpcs4DRaLMj\nxuzaPlOChBCDd/sxEBEFqoAOGFa7DcaB38DQrxgwaHWjFEVCjbXG7TypWcCwWlW3gBFiCGa5DyLq\n8wL6KlituDZvks11SEkIhyJLqLI020xFsQNCm5iV6qKxaPplWikQx94WIUZmF0REAT2H0ShX6T9L\nQXVY/9MxsNitqLc1wCAZYLUCksGmD0n1D03E6uuX6xsfRZjC0FDfoO95QUTUlwV0hmEXrm0/JbN2\n62xFYyUAYHjUlbCVpGjPKTZAEgg3hbntkucsNhiscF8DIqLADhiSVf9ZCnIGjAoAQIw5ErBrq7kl\ng3ObVfeES5G051Vw6zQiosAOGHAFDNlchxMV3+O7ygIAQGxIDITq2C/aETCMzQOGrD1vd5Q5JyLq\nywJ6DkN1ZBjCrkAOq8T//Wyb/lxCWAx+lDYQh6q/ajPDMEjasV20XcSQiKivCOgMwxkw1NqIFs9F\nm6NwSZxjMZ7SesBw1o+KbLYXBhFRXxTgGYYNQgCiLhyIuOD2XHRQJIIUbRclSR+SMrqdM/+yOTAb\nzMhImdo9DSYi6sUCNmD885PTsAkLJFWB2hCqPy5Bwk1D58FsMMNs1AKEpGhzFAbHnIWT2WDG/Mvm\ndF+jiYh6sYAdknrpX98Cih2SasSg8MH64z8fl4f0AdcBAMwGxz6tbWQYRETkErABA3BkDnYFa3Mm\n64/FmGP0n4ONQY7zHHMYknuGQURELgE7JAUAUGxQHftaPDZ+FSobK2FSXFlEkJ5haENSzDCIiNoW\nwAFDhSSrUG1a1pAQEoeEEPcS284AoWcYSgB3BxHRRQrcISnHRDbUtoOAHjAcGYZz3QUREbUUsAFD\nMtcDAISl7TpQpmYZRfOV3kRE5BKwAUMO1kqYq3VtL7prPmfRfOEeERG5BG7ACNEChqhvvn+3iyIp\naLr5HjMMIqK2BdwV8u8ffo+TZ6ogBWu76rWXYUiSBKiKtoESmGEQEbUn4K6Qr3/4PQAgaLgNQpVw\n95yr2/8FVWbAICLyQsAOSUFSIUPB2GEJ7Z6mlzgH12EQEbUncAOGrEISXqzcVl1d0LyWFBERuQRu\nwJBUSN68PWYYRERe8XnA2L9/P2bOnImMjAxs37691XP27NmDzMxMzJkzBw8//HCX/F3JywxDbrJY\nj3MYRERt8+kVUlVVbNiwATt37kR8fDwWLFiAadOmITU1VT/nhx9+wI4dO/CXv/wFYWFhKC8v75o/\n7mWGkZoYjZNV2l4ZvK2WiKhtPs0wjh49ipSUFCQnJ8NoNCIzMxN79+51O+evf/0rFi9ejLAwbb1E\nTExMay/VcbI26e2Js2ItwAyDiKg9Pg0YJSUlSEpK0o8TEhJQWlrqds6pU6fw/fffY9GiRVi4cCE+\n+OCDrvnjkgrJi4BhNrgCBjMMIqK2ebxClpSUICGh/VtT2yKaLqNug91uR0FBAXbt2oUzZ85gyZIl\nyM/P1zOOzhGQZAFZ9RwwghRmGERE3vB4hZw/fz5GjRqFxYsXY8KECR168cTERJw5c0Y/LikpQXx8\nvNs5CQkJGDVqFGRZxoABAzBo0CCcOnUKV155ZbuvHRfX9gpuSFqgMshK++cBiC5yPZ8YH6Wt/vYz\nnt5jX8K+cGFfuLAvuobHgLFv3z7s2bMHv/vd77BhwwYsWbIE8+bN8yoDSEtLQ0FBAYqKihAXF4f8\n/Hxs3rzZ7Zzp06cjPz8fWVlZKC8vxw8//ICBAwd6fO2ysuq2n5RU7f+F3P55AFSLK0CcO1fj8e/2\nNnFx4R7fY1/BvnBhX7iwL1wuNnB6DBgmkwlZWVnIysrCZ599hry8PPz2t79FdnY27r33XvTr16/N\n31UUBevWrUNubi6EEFiwYAFSU1OxdetWpKWlYcqUKfjRj36Ef//738jMzISiKHjkkUcQGRnZ6Tck\nSYCQtYCheDPp3WRIioiI2ubVoH1RURF2796NN998ExMmTEBOTg4++ugjLF26FK+//nq7v5ueno70\n9HS3x5YvX+52vGbNGqxZs6aDTW9JCKFVn3VkGLIXGyIFGRgwiIi84fGKevfdd+Obb77BwoUL8eqr\nryI6OhoAMHr0aOzZs8fnDewIu6rNXUgdyDDMzDCIiLziMWDMmzcPM2bMgKK0vPi++eabPmlUZ9nt\njruy9AyjY3dJERFR2zyuw4iMjERdXZ1+XFVVhYMHD/q0UZ3lzDDgyDAMXgQMs6HtLVyJiMjFY8DY\nuHGj2x1RYWFh2Lhxo08b1Vl21XF3lCPDULyYw+CQFBGRdzwGDCGE29oEWZZht9t92qjOcs1haO1T\nOCRFRNRlPAaM0NBQHDlyRD8+cuQIQkJCfNqoztLnMJyT3l4EDO6BQUTkHY9jNqtWrcJ9992HIUOG\nAABOnDiBp556yucN6wzXkJQWOIQX5c0jTOEwK0EYkzjal00jIvJ7HgPGqFGjkJ+fj8OHD0MIgVGj\nRl3Uwjpf0ie99ZXenst8KLKCJ9N/6ZclQYiIupNXC/ciIyMxefJkX7flojVfh+HVFq0AgwURkRc8\nBozjx49j/fr1OH78OCwWi/74V1995dOGdYbdLgDFCslUrz2gBu4OtERE3c1jwPjFL36BFStW4Ikn\nnsCOHTuwa9cuhIaGdkfbOsyuCgSNOAjZrK0bEQwYRERdxuMV1WKxYMKECRBCID4+HitXruy6TY66\nmF1V9WABAAaVi/KIiLqKx4Ahy9opkZGROH78OC5cuICioiKfN6wzVNV9wyazLbaHWkJEFHg8Dkll\nZmbiwoULWLZsGRYtWgRVVVtUm+0tbM0CRlRocA+1hIgo8LQbMFRVxYQJExAdHY309HR8/PHHaGxs\nvMjtU33HbhcQqgxJVjFWmY/Z4y/t6SYREQWMdoekZFnGz372M/3YaDT22mABAFa7DZKsop+cjNsm\nj0NwEPfoJiLqKh7nMFJTU1FYWNgdbbloFrt2269RMvVwS4iIAo/Hr+Dl5eWYO3currnmGrcaUlu2\nbPFpwzqjwd4IADDKDBhERF3Nq0nvzMzM7mjLRWu0OTIM2djDLSEiCjweA0Z2dnZ3tKNLNNgbAAAm\nmSXLiYi6mseAsXz58lZrLfXGISmLygyDiMhXPAaMKVOm6D83NjbinXfeQWpqqk8b1VnOgBHEDIOI\nqMt1eEjqxz/+Me655x6fNehiOO+SMimc9CYi6modrs4nSVKvvc3Womp3SXGfbiKirtehOQwhBL7+\n+mtMmDDB5w3rjEbHbbVmA4sOEhF1tQ7NYSiKgtzcXIwcOdKnjeqsRlXbByPM2Dv3HCci8mcBdVtt\no3AGjN5bvoSIyF95nMNYtGgRKisr9eOKigosWbLEp43qrEahrcMIN/XODZ6IiPyZx4BRV1eHyMhI\n/TgqKgo1NTU+bVRnWVEPoUoIMXIOg4ioq3kMGKqqoq7OtYtdbW0t7Ha7TxvVWVbRANhMMBqUnm4K\nEVHA8TiHMXv2bOTm5mLRokUAgJdeeglz5871ecM6wyo1QNiCoCgtV6YTEdHF8Rgw7rrrLsTHx2Pf\nvn0QQmDhwoXIysrqjrZ1iF21Q5WsELZwGJUOLy8hIiIPvNphKDs7u9ffLVVcVwoAEFYTFAYMF7dU\nkAAAFDhJREFUIqIu5/HK+sADD6CiokI/vnDhAh588EGfNqoz3jq1FwBgP5/EDIOIyAc8XllPnz6N\nqKgo/Tg6OhoFBQU+bVRnlNaVQVKNUCviOYdBROQDHgOG3W53uyvKarXCYrH4tFGdUW9rgKwaIUGC\nIjNgEBF1NY8BY9KkSVi5ciU++eQTfPLJJ8jLy0N6errXf2D//v2YOXMmMjIysH379jbPe/vtt3HF\nFVfgv//9r9ev3VS9rQGSaoSiyK3u30FERBfH46R3Xl4efv/73+M3v/kNAK221Lhx47x6cVVVsWHD\nBuzcuRPx8fFYsGABpk2b1mI/jdraWrz44oudrlElhECDrQGKGgIDh6OIiHzCY4ZhNBpx//334+mn\nn8YNN9yAv//971i7dq1XL3706FGkpKQgOTkZRqMRmZmZ2Lt3b4vztmzZgjvvvBNGY+d2ymu0WyAg\nALsBBk54ExH5RLsZhs1mw759+/C3v/0Nhw8fhs1mw3PPPed1JlBSUoKkpCT9OCEhAV988YXbOV99\n9RWKi4sxefJk7NixoxNvwbWXN+xGZhhERD7S5tfxJ554Atdffz12796N2bNn4/3330dkZGSHho2E\nEB6ff/zxx7FmzRqvf6c1DTYtYAhmGEREPtNmhvHSSy9h1KhRWLZsGcaPHw8AHZ5MTkxMxJkzZ/Tj\nkpISxMfH68e1tbU4ceIEfvKTn0AIgXPnzuHee+/Ftm3bMGLEiHZfOy4uXP/5glSm/WA3IMhkcHuu\nL+hr77c97AsX9oUL+6JrtBkwPvzwQ/zv//4vNm7ciMrKSmRlZXW46GBaWhoKCgpQVFSEuLg45Ofn\nY/PmzfrzYWFhOHjwoH78k5/8BI8++iiGDx/u8bXLyqr1n8+eLwcA2K0KpGbPBbq4uPA+9X7bw75w\nYV+4sC9cLjZwtjl+ExERgSVLluDVV1/F008/jcrKSjQ0NGDJkiXYvXu3Vy+uKArWrVuH3NxczJ49\nG5mZmUhNTcXWrVvx7rvvtjhfkqTODUk5tmZVbQrnMIiIfEQSHbhCW61W/POf/8Rrr72GP/zhD75s\nl0dlZdUoqCrEy9/+HcNiLkP+9/+E5WQaBpmHY+1PrunRtnUnfntyYV+4sC9c2BcuF5theFV80Mlo\nNGLWrFmYNWvWRf3RrrL7m9fwQ9VpfFd5CoA26W2x9c69OoiI/J1f31IUFRTp/oDdgPOVDT3TGCKi\nAOfXASM2OMbtWNgNqG2w9VBriIgCm18HjBbsnVspTkREnvl1wLCr7vMVwm5AanJED7WGiCiwdWjS\nu7exqe7DT4umXIGJIwb0UGuIiAKbX2cYNuGeYQzpH4MQs1/HQCKiXsuvA0bzISmzSemhlhARBT6/\nDhjlNfVux2YTswsiIl/x64BRXe++5oIZBhGR7/h1wLA3m8MIMjJgEBH5SkAFDFlm4UEiIl/x64Ch\nwhUwbGXJPdgSIqLA59ezxKpQIVQZDZ9NBwSzCyIiX/LrgGEXNkDI2v+IiMin/PpKq0JlsCAi6iZ+\nfbUVsEOofv0WiIj8hl9fbbUMg3MXRETdwa/nMATsgDBg2ugBGDs8vqebQ0QU0Pw8YKiAKuOmqakw\nGrhoj4jIl/x6SEpI2qS3ovj12yAi8gt+faUVjrukZInzGEREvua3AUMVKiAJSP77FoiI/IrfXm2d\nu+1JXIdBRNQt/PZqa3NsniSBk91ERN3BbwOGs1Kt7L9vgYjIr/jt1VYfkmKGQUTULfw4YDgzDAYM\nIqLu4LcBwy60DINDUkRE3cNvr7ZWZ4YhMcMgIuoOfhswnHMYCgMGEVG38NuAYbE7h6QYMIiIuoP/\nBgybFQCgyAwYRETdwW8DRqMzYDDDICLqFn4ZMKpqLbDYtCEpAzMMIqJu4Zf7YSx57C0kDq4AYjnp\nTUTUXXyeYezfvx8zZ85ERkYGtm/f3uL5nTt3IjMzE/PmzcNPf/pTnD171qvXLausAwAYZL+MeURE\nfsenAUNVVWzYsAHPPfcc3nzzTeTn5+PkyZNu5wwfPhyvvvoq3njjDcyYMQMbN2707sUlFQADBhFR\nd/FpwDh69ChSUlKQnJwMo9GIzMxM7N271+2csWPHIigoCAAwcuRIlJSUePfishYwjAoDBhFRd/Bp\nwCgpKUFSUpJ+nJCQgNLS0jbPf+WVV5Cenu7di+sZBucwiIi6g0+/ngshvD73jTfewH//+1+88MIL\nXp0vSdprh4cEIy4uvFPtCxR9/f03xb5wYV+4sC+6hk8DRmJiIs6cOaMfl5SUID4+vsV5Bw4cwPbt\n2/Hiiy/CaDR69+KOISnVBpSVVXdJe/1RXFx4n37/TbEvXNgXLuwLl4sNnD4dkkpLS0NBQQGKiopg\nsViQn5+PadOmuZ1z7NgxrF+/Htu2bUN0dLT3L+4YkjJxDoOIqFv49GqrKArWrVuH3NxcCCGwYMEC\npKamYuvWrUhLS8OUKVOwadMm1NfX48EHH4QQAv3798czzzzj+cUl56Q35zCIiLqDz7+ep6ent5jI\nXr58uf7z888/36nXlWRnhuHlEBYREV0UvywNAgBwTHqbDAwYRETdwS8DRnR4EOcwiIi6mV8GDINB\n1u+SCvL2rioiIroo/hkwZFnPMIKYYRARdQv/DBgGSV+4Z2KGQUTULfwyYCiya0jKzElvIqJu4ZcB\nw6BI+pCUmRkGEVG38NOA0WQOw8A5DCKi7uCXAUNRZEiyCqHKMBm40puIqDv4ZcAwKjKg2AC7AqPB\nL98CEZHf8curraJIkIyNENYgBgwiom7il1dbWRGQDDYIaxAUxS/fAhGR3/HPq63SCAAQVhNkSerh\nxhAR9Q1+GTBUgzNgBPVwS4iI+g6/DBh2uR4AINkZMIiIuotfBgxVbgAAGEVwD7eEiKjv8MuAcVoc\nAQAEqRE93BIior7DLwNGg1QJ27n+CFHjeropRER9hl8GDABQq2IQEsSyIERE3cV/A0ZdOIIZMIiI\nuo1/BgwBiPowmE2sI0VE1F38MmAEl6cBQkFYMEubExF1F78MGOvmLsaVg2Iwd+Kgnm4KEVGf4ZeT\nAEMGRCHv5pE93Qwioj7FLzMMIiLqfgwYRETkFQYMIiLyCgMGERF5hQGDiIi8woBBREReYcAgIiKv\nMGAQEZFXGDCIiMgrDBhEROQVBgwiIvKKzwPG/v37MXPmTGRkZGD79u0tnrdYLFi5ciVmzJiBm2++\nGWfOnPF1k4iIqBN8GjBUVcWGDRvw3HPP4c0330R+fj5Onjzpds4rr7yCyMhI/OMf/8Btt92GTZs2\n+bJJRETUST4NGEePHkVKSgqSk5NhNBqRmZmJvXv3up2zd+9eZGdnAwAyMjJw8OBBXzaJiIg6yacB\no6SkBElJSfpxQkICSktL3c4pLS1FYmIiAEBRFERERKCiosKXzSIiok7wacAQQnT4HCEEJEnyVZOI\niKiTfLqBUmJiotskdklJCeLj41ucU1xcjISEBNjtdtTU1CAyMtLja8fFhXd5e/0V+8KFfeHCvnBh\nX3QNn2YYaWlpKCgoQFFRESwWC/Lz8zFt2jS3c6ZMmYLXXnsNAPD2229j/PjxvmwSERF1kiS8GTe6\nCPv378evf/1rCCGwYMECLFu2DFu3bkVaWhqmTJkCi8WCVatW4auvvkJUVBQ2b96MAQMG+LJJRETU\nCT4PGEREFBi40puIiLzCgEFERF5hwCAiIq/4XcDwVJsq0KxduxbXXXcd5syZoz9WWVmJ3NxcZGRk\nYOnSpaiurtaf+9WvfoUZM2Zg3rx5+Oqrr3qiyT5RXFyMW2+9FbNmzcKcOXPw5z//GUDf7AuLxYKc\nnBxkZWVhzpw5eOqppwAAhYWFuOmmm5CRkYG8vDzYbDb9/ECv16aqKrKzs3H33XcD6Lt9MXXqVMyd\nOxdZWVlYsGABgC7+jAg/YrfbxfTp00VhYaGwWCxi7ty54sSJEz3dLJ/6z3/+I44dOyZmz56tP7Zx\n40axfft2IYQQv//978WmTZuEEEK899574s477xRCCHH48GGRk5PT/Q32kdLSUnHs2DEhhBA1NTVi\nxowZ4sSJE32yL4QQoq6uTgghhM1mEzk5OeLw4cPiwQcfFHv27BFCCPHYY4+Jl156SQghxK5du8T6\n9euFEELk5+eLFStW9Eibfen5558XDz30kLjrrruEEKLP9sXUqVNFRUWF22Nd+RnxqwzDm9pUgeba\na69FRESE22NN629lZ2frfbB3715kZWUBAK6++mpUV1fj3Llz3dtgH4mLi8OwYcMAAKGhoUhNTUVJ\nSUmf7AsACA4OBqB9Y7bZbJAkCYcOHUJGRgYArS/+9a9/AQj8em3FxcV4//33kZOToz/20Ucf9cm+\nEEJAVVW3x7ryM+JXAcOb2lR9QXl5OWJjYwFoF9Ly8nIA7nW5AK1/SkpKeqSNvlRYWIjjx4/j6quv\nxvnz5/tkX6iqiqysLEycOBETJ07EwIEDERERAVnWPtKJiYn6+w30em2PP/44HnnkEb2k0IULFxAZ\nGdkn+0KSJCxduhTz58/Hyy+/DABd+hnxaWmQria4ZKRdrfVPoNXlqq2txfLly7F27VqEhoa2+f4C\nvS9kWcbrr7+Ompoa3HfffS22DQBc77d5X4gAqtf23nvvITY2FsOGDcOhQ4cAaO+v+XvuC30BALt3\n79aDQm5uLgYNGtSlnxG/Chje1KbqC/r164dz584hNjYWZWVliImJAaB9QyguLtbPKy4uDqj+sdls\nWL58OebNm4fp06cD6Lt94RQWFoYxY8bgyJEjqKqqgqqqkGXZ7f06+6Kj9dr8wWeffYZ9+/bh/fff\nR2NjI2pra/H444+jurq6z/UFoGUQABATE4Pp06fj6NGjXfoZ8ashKW9qUwWi5t8Epk6dildffRUA\n8Nprr+l9MG3aNLz++usAgMOHDyMiIkJPRQPB2rVrMWTIENx22236Y32xL8rLy/U7XRoaGnDw4EEM\nGTIE48aNw9tvvw3AvS+mTp0asPXa8vLy8N5772Hv3r3YvHkzxo0bhyeffLJP9kV9fT1qa2sBAHV1\ndfjwww8xdOjQLv2M+F1pkNZqUwWyhx56CIcOHUJFRQViY2PxwAMPYPr06XjwwQdx9uxZ9O/fH1u2\nbNEnxn/5y1/igw8+QHBwMJ544gmMGDGih99B1/j0009xyy23YOjQoZAkCZIkYeXKlbjqqquwYsWK\nPtUXX3/9NdasWQNVVaGqKmbNmoV77rkHp0+fRl5eHqqqqjBs2DBs2rQJRqOxz9Rr+/jjj/HHP/4R\nzz77bJ/si9OnT+P++++HJEmw2+2YM2cOli1bhoqKii77jPhdwCAiop7hV0NSRETUcxgwiIjIKwwY\nRETkFQYMIiLyCgMGERF5hQGDiIi8woBBfu2mm25CdnY2MjMzMWLECGRnZyM7Oxtr167t8Gvdcccd\nXpW7fvTRR3H48OHONLdDjh07hnfeecfnf4fIW1yHQQGhqKgICxYsaLf6qLNUhL94+eWXcfDgQWze\nvLmnm0IEwM9qSRF1xMGDB7Fp0yaMHDkSx44dw3333Yfy8nLs2rVL31BnzZo1GDt2LABg8uTJ2Llz\nJwYNGoTFixdj1KhR+Pzzz1FaWorZs2djxYoVAIDFixfj3nvvxaRJk7Bq1SqEhYXh5MmTKCkpwejR\no/HEE08A0GrzPPLII7hw4QIGDhwIu92OqVOn4uabb3Zr57lz5/DQQw/hwoULAIBJkybhjjvuwDPP\nPIO6ujpkZ2dj3LhxWLNmDT7//HNs3rwZ9fX1AIDly5cjPT0dBQUFWLx4MWbPno1PP/0UFosF69ev\nx+jRo7ulr6mPuJjNOoh6i8LCQjF+/Hi3xw4cOCCGDx8uvvjiC/2xppvLnDhxQlx//fX6cXp6uvju\nu++EEEIsWrRIPPTQQ0IIIaqqqsTYsWNFYWGh/twHH3wghBDi4YcfFrfccouwWq2isbFRzJw5Uxw6\ndEgIIcQ999wj/vCHPwghhDh9+rQYNWqU2L17d4u279ixQzz22GP6cVVVlRBCiL/+9a8iLy/Pre1Z\nWVni/PnzQgghiouLRXp6uqipqRE//PCDuPzyy0V+fr7+3q+//nphs9m870QiD5hhUEAbPHgwrrzy\nSv341KlT2Lp1K0pLS6EoCkpLS1FRUYGoqKgWv3vjjTcCAMLDwzFo0CAUFBQgOTm5xXk33HADDAbt\nozR8+HAUFBRg7NixOHToEH71q18BAAYMGKBnMs2NHDkSL774Ip588kmMGTMGkyZNavW8Tz/9FIWF\nhVi6dKlekFJRFJw+fRohISEIDg7GrFmzAAATJkyAoig4deoUUlNTve0uonYxYFBACw0NdTteuXIl\n1q9fj8mTJ0NVVVx11VVobGxs9XeDgoL0n2VZht1u79B53u6zcM011+C1117DgQMH8Le//Q07duzA\nCy+80OI8IQRGjBiBnTt3tniuoKCgxWOqqgbUXg/U8/xnBpDIA+HF/Rs1NTV6ddLdu3e3GQS6wtix\nY/Wy0kVFRfj4449bPa+wsBBhYWGYNWsW1qxZgy+//BKAtteFs4w5AIwePRonTpzAJ598oj929OhR\n/ef6+nrs2bMHgLZFKQCkpKR07ZuiPo0ZBgUMb75Nr127FsuWLUNSUhLGjRuH8PDwVn+/+Wu19Vx7\n561btw6rV69Gfn4+Bg8ejNGjR7v9PaeDBw/iz3/+MxRFgRACGzZsAABMnDgRf/rTn5CVlYXx48dj\nzZo1eOaZZ7Bp0yZUV1fDarVi4MCBePbZZwEAsbGx+Pbbb5GTkwOLxYLNmzdDURSPfULkLd5WS+Qj\njY2NMBqNkGUZJSUlyMnJwa5duzBw4MAu/1vOu6Q+/PDDLn9tIidmGEQ+8t133+HRRx+FEAKqqmLl\nypU+CRZE3YUZBhEReYWT3kRE5BUGDCIi8goDBhEReYUBg4iIvMKAQUREXmHAICIir/x/apbYj523\no60AAAAASUVORK5CYII=\n", "text/plain": [ - "" + "\u003cmatplotlib.figure.Figure at 0x7f97f1330850\u003e" ] }, "metadata": { "tags": [] - } + }, + "output_type": "display_data" } + ], + "source": [ + "def plot(train, test, label):\n", + " plt.title('MNIST model %s' % label)\n", + " plt.plot(train, label='train %s' % label)\n", + " plt.plot(test, label='test %s' % label)\n", + " plt.legend()\n", + " plt.xlabel('Training step')\n", + " plt.ylabel(label.capitalize())\n", + " plt.show()\n", + " \n", + "\n", + "with tf.Graph().as_default():\n", + " hp = tf.contrib.training.HParams(\n", + " learning_rate=0.05,\n", + " max_steps=tf.constant(500),\n", + " )\n", + " train_ds = setup_mnist_data(True, hp, 50)\n", + " test_ds = setup_mnist_data(False, hp, 1000)\n", + " tf_train = autograph.to_graph(train)\n", + " all_losses = tf_train(train_ds, test_ds, hp)\n", + "\n", + " with tf.Session() as sess:\n", + " sess.run(tf.global_variables_initializer())\n", + " (train_losses, test_losses, train_accuracies,\n", + " test_accuracies) = sess.run(all_losses)\n", + " \n", + " plot(train_losses, test_losses, 'loss')\n", + " plot(train_accuracies, test_accuracies, 'accuracy')" ] }, { + "cell_type": "markdown", "metadata": { - "id": "HNqUFL4deCsL", - "colab_type": "text" + "colab_type": "text", + "id": "HNqUFL4deCsL" }, - "cell_type": "markdown", "source": [ "# 4. Case study: building an RNN\n" ] }, { + "cell_type": "markdown", "metadata": { - "id": "YkC1k4HEQ7rw", - "colab_type": "text" + "colab_type": "text", + "id": "YkC1k4HEQ7rw" }, - "cell_type": "markdown", "source": [ "In this exercise we build and train a model similar to the RNNColorbot model that was used in the main Eager notebook. The model is adapted for converting and training in graph mode." ] }, { + "cell_type": "markdown", "metadata": { - "id": "7nkPDl5CTCNb", - "colab_type": "text" + "colab_type": "text", + "id": "7nkPDl5CTCNb" }, - "cell_type": "markdown", "source": [ "To get started, we load the colorbot dataset. The code is identical to that used in the other exercise and its details are unimportant." ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "A0uREmVXCQEw", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "A0uREmVXCQEw" }, - "cell_type": "code", + "outputs": [], "source": [ "def parse(line):\n", " \"\"\"Parses a line from the colors dataset.\n", @@ -1137,7 +1034,7 @@ " A tuple of three tensors (rgb, chars, length), of shapes: (batch_size, 3),\n", " (batch_size, max_sequence_length, 256) and respectively (batch_size).\n", " \"\"\"\n", - " items = tf.string_split([line], \",\").values\n", + " items = tf.string_split(tf.expand_dims(line, 0), \",\").values\n", " rgb = tf.string_to_number(items[1:], out_type=tf.float32) / 255.0\n", " color_name = items[0]\n", " chars = tf.one_hot(tf.decode_raw(color_name, tf.uint8), depth=256)\n", @@ -1169,23 +1066,21 @@ " dataset = dataset.repeat()\n", " if training:\n", " dataset = dataset.shuffle(buffer_size=3000)\n", - " dataset = dataset.padded_batch(batch_size, padded_shapes=([None], [None, None], []))\n", + " dataset = dataset.padded_batch(batch_size, padded_shapes=((None,), (None, None), ()))\n", " return dataset\n", "\n", "\n", - "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv\"\n", - "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv\"\n", + "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/train.csv\"\n", + "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/test.csv\"\n", "data_dir = \"tmp/rnn/data\"" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "waZ89t3DTUla", - "colab_type": "text" + "colab_type": "text", + "id": "waZ89t3DTUla" }, - "cell_type": "markdown", "source": [ "Next, we set up the RNNColobot model, which is very similar to the one we used in the main exercise.\n", "\n", @@ -1193,17 +1088,19 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "9v8AJouiC44V", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "9v8AJouiC44V" }, - "cell_type": "code", + "outputs": [], "source": [ "def model_components():\n", " lower_cell = tf.contrib.rnn.LSTMBlockCell(256)\n", @@ -1227,12 +1124,13 @@ " Returns:\n", " A Tensor of shape (max_sequence_length, batch_size, output_size).\n", " \"\"\"\n", - " hidden_outputs = []\n", - " autograph.utils.set_element_type(hidden_outputs, tf.float32)\n", + " hidden_outputs = tf.TensorArray(tf.float32, size=0, dynamic_size=True)\n", " state, output = cell.zero_state(batch_size, tf.float32)\n", + " initial_state_shape = state.shape\n", + " initial_output_shape = output.shape\n", " n = tf.shape(chars)[0]\n", " i = 0\n", - " while i < n:\n", + " while i \u003c n:\n", " ch = chars[i]\n", " cell_output, (state, output) = cell.call(ch, (state, output))\n", " hidden_outputs.append(cell_output)\n", @@ -1261,50 +1159,51 @@ " A Tensor of shape (batch_size, 3) - the model predictions.\n", " \"\"\"\n", " (chars, length) = inputs\n", - " chars_time_major = tf.transpose(chars, [1, 0, 2])\n", + " chars_time_major = tf.transpose(chars, (1, 0, 2))\n", " chars_time_major.set_shape((None, batch_size, 256))\n", "\n", " hidden_outputs = rnn_layer(chars_time_major, lower_cell, batch_size, training)\n", " final_outputs = rnn_layer(hidden_outputs, upper_cell, batch_size, training)\n", "\n", " # Grab just the end-of-sequence from each output.\n", - " indices = tf.stack([length - 1, range(batch_size)], axis=1)\n", + " indices = tf.stack((length - 1, range(batch_size)), axis=1)\n", " sequence_ends = tf.gather_nd(final_outputs, indices)\n", + " sequence_ends.set_shape((batch_size, 128))\n", " return relu_layer(sequence_ends)\n", "\n", "def loss_fn(labels, predictions):\n", " return tf.reduce_mean((predictions - labels) ** 2)" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "JjK4gXFvFsf4", - "colab_type": "text" + "colab_type": "text", + "id": "JjK4gXFvFsf4" }, - "cell_type": "markdown", "source": [ "The train and test functions are also similar to the ones used in the Eager notebook. Since the network requires a fixed batch size, we'll train in a single shot, rather than by epoch." ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "ZWQMExk0S6X6", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "ZWQMExk0S6X6" }, - "cell_type": "code", + "outputs": [], "source": [ "def train(optimizer, train_data, lower_cell, upper_cell, relu_layer, batch_size, num_steps):\n", " iterator = train_data.make_one_shot_iterator()\n", " step = 0\n", - " while step < num_steps:\n", + " while step \u003c num_steps:\n", " labels, chars, sequence_length = iterator.get_next()\n", " predictions = model((chars, sequence_length), lower_cell, upper_cell, relu_layer, batch_size, training=True)\n", " loss = loss_fn(labels, predictions)\n", @@ -1319,7 +1218,7 @@ " total_loss = 0.0\n", " iterator = eval_data.make_one_shot_iterator()\n", " step = 0\n", - " while step < num_steps:\n", + " while step \u003c num_steps:\n", " labels, chars, sequence_length = iterator.get_next()\n", " predictions = model((chars, sequence_length), lower_cell, upper_cell, relu_layer, batch_size, training=False)\n", " total_loss += loss_fn(labels, predictions)\n", @@ -1340,16 +1239,14 @@ " # Here, we create a no_op that will drive the execution of all other code in\n", " # this function. Autograph will add the necessary control dependencies.\n", " return tf.no_op()" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "iopcs5hXG2od", - "colab_type": "text" + "colab_type": "text", + "id": "iopcs5hXG2od" }, - "cell_type": "markdown", "source": [ "Finally, we add code to run inference on a single input, which we'll read from the input.\n", "\n", @@ -1357,17 +1254,19 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "DyU0wnnAFEYj", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } - } + }, + "colab_type": "code", + "id": "DyU0wnnAFEYj" }, - "cell_type": "code", + "outputs": [], "source": [ "@autograph.do_not_convert(run_as=autograph.RunMode.PY_FUNC)\n", "def draw_prediction(color_name, pred):\n", @@ -1389,16 +1288,14 @@ " draw_prediction(color_name, pred)\n", " # Create an op that will drive the entire function.\n", " return tf.no_op()" - ], - "execution_count": 0, - "outputs": [] + ] }, { + "cell_type": "markdown", "metadata": { - "id": "Nt0Kv5OCHip0", - "colab_type": "text" + "colab_type": "text", + "id": "Nt0Kv5OCHip0" }, - "cell_type": "markdown", "source": [ "Finally, we put everything together.\n", "\n", @@ -1406,218 +1303,132 @@ ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "-GmWa0GtYWdh", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "output_extras": [ - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {}, - {} - ], - "base_uri": "https://localhost:8080/", - "height": 668 + "height": 415 }, - "outputId": "61f4af1d-c81e-44db-9079-1a7b8ed8ce58", + "colab_type": "code", "executionInfo": { + "elapsed": 15536, "status": "ok", - "timestamp": 1522345877153, - "user_tz": 240, - "elapsed": 75500, + "timestamp": 1531750946373, "user": { - "displayName": "Dan Moldovan", - "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", - "userId": "112023154726779574577" - } - } + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "-GmWa0GtYWdh", + "outputId": "2e7a9856-9809-43a3-8b43-3c8514ea43e9" }, - "cell_type": "code", - "source": [ - "def run_input_loop(sess, inference_ops, color_name_placeholder):\n", - " \"\"\"Helper function that reads from input and calls the inference ops in a loop.\"\"\"\n", - "\n", - " tb = widgets.TabBar([\"RNN Colorbot\"])\n", - " while True:\n", - " with tb.output_to(0):\n", - " try:\n", - " color_name = six.moves.input(\"Give me a color name (or press 'enter' to exit): \")\n", - " except (EOFError, KeyboardInterrupt):\n", - " break\n", - " if not color_name:\n", - " break\n", - " with tb.output_to(0):\n", - " tb.clear_tab()\n", - " sess.run(inference_ops, {color_name_placeholder: color_name})\n", - " plt.show()\n", - "\n", - "with tf.Graph().as_default():\n", - " # Read the data.\n", - " batch_size = 64\n", - " train_data = load_dataset(data_dir, train_url, batch_size)\n", - " eval_data = load_dataset(data_dir, test_url, 50, training=False)\n", - " \n", - " # Create the model components.\n", - " lower_cell, upper_cell, relu_layer = model_components()\n", - " # Create the helper placeholder for inference.\n", - " color_name_placeholder = tf.placeholder(tf.string, shape=())\n", - " \n", - " # Compile the train / test code.\n", - " tf_train_model = autograph.to_graph(train_model)\n", - " train_model_ops = tf_train_model(\n", - " train_data, eval_data, batch_size, lower_cell, upper_cell, relu_layer, train_steps=100)\n", - " \n", - " # Compile the inference code.\n", - " tf_inference = autograph.to_graph(inference)\n", - " inference_ops = tf_inference(color_name_placeholder, lower_cell, upper_cell, relu_layer)\n", - " \n", - " with tf.Session() as sess:\n", - " sess.run(tf.global_variables_initializer())\n", - " \n", - " # Run training and testing.\n", - " sess.run(train_model_ops)\n", - " \n", - " # Run the inference loop.\n", - " run_input_loop(sess, inference_ops, color_name_placeholder)" - ], - "execution_count": 0, "outputs": [ { + "name": "stdout", "output_type": "stream", "text": [ - "('Successfully downloaded', 'train.csv', 28010L, 'bytes.')\n", - "('Successfully downloaded', 'test.csv', 2414L, 'bytes.')\n", - "Step 0 train loss 0.37890616\n", - "Step 10 train loss 0.18515904\n", - "Step 20 train loss 0.0892782\n", - "Step 30 train loss 0.07883155\n", - "Step 40 train loss 0.08585831\n", - "Step 50 train loss 0.09302989\n", - "Step 60 train loss 0.089012615\n", - "Step 70 train loss 0.07275697\n", - "Step 80 train loss 0.06644974\n", - "Step 90 train loss 0.0854013\n", - "Test loss 0.13216865Colorbot is ready to generate colors!\n", - "\n", + "Test loss 0.138294\n", + "Colorbot is ready to generate colors!\n", "\n", "\n" - ], - "name": "stdout" + ] }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ - "" + "\u003clink rel=stylesheet type=text/css href='/nbextensions/google.colab/tabbar.css'\u003e\u003c/link\u003e" + ], + "text/plain": [ + "\u003cIPython.core.display.HTML at 0x7f97ee42bb90\u003e" ] }, "metadata": { "tags": [ "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ - "" + "\u003cscript src='/nbextensions/google.colab/tabbar_main.min.js'\u003e\u003c/script\u003e" + ], + "text/plain": [ + "\u003cIPython.core.display.HTML at 0x7f97ee42be10\u003e" ] }, "metadata": { "tags": [ "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "text/plain": [ - "" - ], "text/html": [ - "
" + "\u003cdiv id=\"id1\"\u003e\u003c/div\u003e" + ], + "text/plain": [ + "\u003cIPython.core.display.HTML at 0x7f97ee42bd90\u003e" ] }, "metadata": { "tags": [ "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"b102d936-3379-11e8-ac70-0242ac110002\"] = colab_lib.createTabBar({\"contentBorder\": [\"0px\"], \"borderColor\": [\"#a7a7a7\"], \"tabNames\": [\"RNN Colorbot\"], \"initialSelection\": 0, \"location\": \"top\", \"contentHeight\": [\"initial\"], \"elementId\": \"id1\"});\n", - "//# sourceURL=js_e223a56194" + "window[\"a6045494-8903-11e8-99f9-c8d3ffb5fbe0\"] = colab_lib.createTabBar({\"location\": \"top\", \"borderColor\": [\"#a7a7a7\"], \"initialSelection\": 0, \"elementId\": \"id1\", \"contentHeight\": [\"initial\"], \"contentBorder\": [\"0px\"], \"tabNames\": [\"RNN Colorbot\"]});\n", + "//# sourceURL=js_02f896cbda" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2ab810\u003e" ] }, "metadata": { "tags": [ "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"b103532a-3379-11e8-ac70-0242ac110002\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_b8c6a821fb" + "window[\"a6045495-8903-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_7e8f9f77a0" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2ab710\u003e" ] }, "metadata": { "tags": [ "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"b105b28c-3379-11e8-ac70-0242ac110002\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_44805e254b" + "window[\"a6045496-8903-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_5531553c2f" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2ab6d0\u003e" ] }, "metadata": { @@ -1625,17 +1436,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"b106197a-3379-11e8-ac70-0242ac110002\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_a63d3c6c47" + "window[\"a6045497-8903-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_d1f809ec17" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2ab990\u003e" ] }, "metadata": { @@ -1643,17 +1454,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"b1069f44-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"b106197a-3379-11e8-ac70-0242ac110002\"]);\n", - "//# sourceURL=js_7e203b8bce" + "window[\"a6045498-8903-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"a6045497-8903-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_3a3123cadb" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2aba50\u003e" ] }, "metadata": { @@ -1661,17 +1472,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"b1070f38-3379-11e8-ac70-0242ac110002\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_d53293d4a7" + "window[\"a6045499-8903-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_1a0e1f7d6f" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2ab890\u003e" ] }, "metadata": { @@ -1679,17 +1490,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c6d90d5c-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"b105b28c-3379-11e8-ac70-0242ac110002\"]);\n", - "//# sourceURL=js_3000dc2c05" + "window[\"a8e54762-8903-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"a6045496-8903-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_6213539615" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2abad0\u003e" ] }, "metadata": { @@ -1697,17 +1508,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c6da872c-3379-11e8-ac70-0242ac110002\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_4136f669a3" + "window[\"a8e54763-8903-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_0bd7f95c6e" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2ab950\u003e" ] }, "metadata": { @@ -1715,17 +1526,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c6dac868-3379-11e8-ac70-0242ac110002\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_2f70dd9aee" + "window[\"a8e54764-8903-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_215f004f6b" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2abb10\u003e" ] }, "metadata": { @@ -1733,17 +1544,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c6db07d8-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"c6dac868-3379-11e8-ac70-0242ac110002\"]);\n", - "//# sourceURL=js_7226726048" + "window[\"a8e54765-8903-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"a8e54764-8903-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_a06186c8ad" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2aba90\u003e" ] }, "metadata": { @@ -1751,17 +1562,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c6dcc6fe-3379-11e8-ac70-0242ac110002\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_72e7709865" + "window[\"a8e54766-8903-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_383fbaae67" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ee2abc50\u003e" ] }, "metadata": { @@ -1769,14 +1580,14 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVQAAAFZCAYAAADHDNdrAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAB9JJREFUeJzt3E1Lle0ax+HTF4jeEAyMBhE0DawI\nwsCH0AIlaGBWNJBo0CDoA0TQhmDXuKAGDioiCA2KlEAlnl05FD9Co8BeaGCQoBDa2jPZsXt4Bvu/\n0+o4Rmvd1zW4rsmP84bFamo0Go0C4H/WvNYHAPhVCCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKDy\nUxgeHq5Dhw7V4OBgPXz4sHp7e+vWrVt15cqVOnnyZN2/f78ajUbdvn27+vr6qqenp65du1YrKytV\nVfXhw4e6cOFC9fX1VV9fX01PT1dV1dzcXHV3d9eDBw/q+PHj9ccff9TExMRaXpWfWOtaHwD+zuvX\nr+vOnTs1MTFRbW1tdf78+dW16enpGh8fr/b29hobG6upqal6/Phxbdy4sS5evFgjIyM1NDRUly5d\nqv3799fw8HC9efOmTp8+XVNTU1VV9enTp2pubq5nz57V5ORk3bhxo44dO7ZW1+UnZkJl3Zudna2D\nBw9WR0dHbdiwoQYHB1fX9u7dW+3t7VVV9fLlyxocHKytW7dWa2trnTp1qp4/f16Li4s1MzNT586d\nq6qqXbt21YEDB1an1OXl5Tpx4kRVVe3Zs6fevXv3Yy/IL8OEyrr3+fPnamtrW/2+ffv21c//+Xxh\nYaHu3r1bjx49qqqqlZWVam9vr4WFhWo0GnXmzJnVvYuLi9XV1VVVVS0tLbVp06aqqmpubq6vX7/+\nX+/Dr0tQWfe2bNlSi4uLq98/fvz43X0dHR3V29tbQ0ND3zxfXl6ulpaWevLkSW3evPmbtbm5ufyB\n+W155Wfd6+zsrJmZmZqfn68vX77U2NjYd/cdOXKkxsfHa2lpqaqqRkdH6+nTp9Xa2lqHDx+u0dHR\nqqpaWlqqy5cv1/v373/YHfg9CCrrXmdnZw0MDNTAwECdPXu2enp6vrvv6NGj1dPTUwMDA9Xf318v\nXryo7u7uqqq6evVqzc7OVn9/fw0MDNTOnTtrx44dP/Ia/Aaa/B8qP4NGo1FNTU1VVfXq1au6efPm\nX06qsFZMqKx78/Pz1dXVVW/fvq1Go1GTk5O1b9++tT4W/BcTKj+FkZGRunfvXjU1NdXu3bvr+vXr\ntW3btrU+FnxDUAFCvPIDhAgqQMi6+WH/kX8eXesjAPytf/3jz79cM6EChAgqQIigAoQIKkCIoAKE\nCCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQI\nKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgq\nQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpA\niKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCI\noAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIig\nAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAC\nhAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKE\nCCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQI\nKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgq\nQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpA\niKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCI\noAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIig\nAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAC\nhAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKE\nCCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQI\nKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgq\nQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpA\niKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkBI\nU6PRaKz1IQB+BSZUgBBBBQgRVIAQQQUIEVSAEEEFCBFUgBBBBQgRVIAQQQUIEVSAEEEFCBFUgBBB\nBQgRVIAQQQUIEVSAEEEFCBFUgBBBBQgRVIAQQQUIEVSAkH8D1Aj8lNhhe7QAAAAASUVORK5CYII=\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQwAAAENCAYAAAD60Fs2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACL9JREFUeJzt3F+IlXUex/Gv2ziiBRGVOQaFd2JBzOg5aiH+IZGoJgmM\n/uhVGIlgFE0QEYHdFQaRGBJ10VX0D5TAi8jKomGmILsYjEAkmBwbRIxKGDV/e7G7w8ouux9jd911\nX6+rcx6e85zveS7e/J7zb0ZrrRVA4A8XewDgf4dgADHBAGKCAcQEA4gJBhATDC6Kp59+urrdbt13\n3301OjpaK1euvNgjERCMS9yaNWtqeHj4Yo9xnq+++qqGh4frs88+q7fffruqqmbMmHGRpyIhGPxH\n/fbbb/XDDz/U9ddfX7NmzbrY43CBBOMS9tRTT9XExERt2bKlBgYG6vXXX69vvvmm7r///up0OrV+\n/foaHR2d3n/Tpk318ssv1wMPPFADAwP18MMP18mTJ6uq6vTp0zU0NFRLly6tTqdTGzZsqBMnTlRV\n1eTkZG3ZsqWWLl1a69atq3feeWf6mDt37qxt27bV0NBQLVmypN5777169tln6+DBgzUwMFA7d+78\nm7kPHz5cmzZtqk6nU3fffXft37+/qqrGx8er0+lM7/fMM8/UrbfeOn1/aGio3nzzzX/tSeR8jUva\n6tWr2/DwcGuttWPHjrVut9sOHDjQWmvtiy++aN1ut504caK11trGjRvb2rVr2/fff9+mpqbaxo0b\n244dO1prrb311lvt0UcfbVNTU+3cuXNtbGys/fLLL6211h566KG2ffv2dvr06Xbo0KG2bNmy6ed8\n5ZVX2k033dQ++uij1lprU1NT7f33328PPvjg9IwjIyNt5cqVrbXWzpw509auXdt2797dzpw504aH\nh1t/f387cuTI9OsZGxtrrbW2bt26dvvtt7fDhw+31lpbtWpVO3To0L/rVNJas8L4P9D+/HOhvXv3\n1qpVq2rFihVVVbV8+fK6+eab69NPP53e9957760bbrihent764477qhDhw5VVVVPT0+dPHmyjhw5\nUjNmzKhFixbV5ZdfXseOHauvv/66nnzyyZo5c2YtXLiwNmzYUHv27Jk+Zn9/f61Zs6aqqnp7e//h\nrAcPHqxTp07VI488Uj09PbVs2bJavXp1ffDBB1VVtWTJkhodHa3jx49XVdW6devqyy+/rPHx8fr1\n119r4cKF/6Kzxt/Tc7EH4D/n6NGjtW/fvvr444+r6k8hOXv2bC1fvnx6n2uuuWb69uzZs+vUqVNV\nVXXPPffUsWPH6oknnqiff/65BgcH6/HHH6/Jycm68sora/bs2dOPmz9/fo2NjU3fnzdvXjzj5ORk\n9fX1nbdt/vz5NTk5WVVVnU6n9u/fX9ddd111u93qdru1Z8+e6u3trcWLF1/A2eD3EIxL3F9/+tDX\n11fr16+v7du3X/Bxenp6auvWrbV169Y6evRobd68uRYsWFC33XZb/fTTT3Xq1KmaM2dOVVVNTEzU\n3Llz/+4M/8zcuXNrYmLivG1Hjx6tBQsWVFVVt9utF198sfr6+qrT6dTAwEA999xz1dvbW91u94Jf\nFxfGJckl7tprr63x8fGqqhocHKz9+/fX559/XufOnaupqakaHR2tH3/88Z8eZ2RkpL777rs6d+5c\nzZkzp3p6euqyyy6refPmVX9/f7300kt1+vTp+vbbb+vdd9+twcHB3zXvLbfcUnPmzKnXXnutzp49\nWyMjI/XJJ5/UnXfeWVVVN954Y82aNav27t1bnU6nrrjiirr66qvrww8/PO8NUf49BOMSt3nz5tq1\na1d1u93at29f7dq1q3bv3l3Lly+v1atX1xtvvDH9Hsc/WgkcP368tm3bVosXL6677rqrli5dOh2F\nHTt21Pj4eK1YsaK2bdtWjz322HmXORdi5syZ9eqrr9aBAwdq2bJl9fzzz9cLL7wwvcKo+tMq46qr\nrpq+1PlLKBYtWvS7npPcjNb8gQ6QscIAYoIBxAQDiAkGEPuv/R7GxN7+iz0C/F/rG/z6b7ZZYQAx\nwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQE\nA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMM\nICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCA\nmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBi\nggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJ\nBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYY\nQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAA\nMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHE\nBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhAT\nDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEww\ngJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOI\nCQYQEwwgNqO11i72EMD/BisMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBi\nggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiP0RoqNMBlokHDIAAAAASUVORK5CYII=\n", "text/plain": [ - "" + "\u003cmatplotlib.figure.Figure at 0x7f97ee42bb90\u003e" ] }, "metadata": { @@ -1785,17 +1596,17 @@ "outputarea_id1", "user_output" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c70592aa-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"c6da872c-3379-11e8-ac70-0242ac110002\"]);\n", - "//# sourceURL=js_25c3aaf79a" + "window[\"a8e54767-8903-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"a8e54763-8903-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_28bd08ac10" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ea9efc10\u003e" ] }, "metadata": { @@ -1803,17 +1614,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c70842c0-3379-11e8-ac70-0242ac110002\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_984c56b816" + "window[\"a8e54768-8903-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_ae2887f57d" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ea9efb50\u003e" ] }, "metadata": { @@ -1821,17 +1632,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c708dec4-3379-11e8-ac70-0242ac110002\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_e0451a1217" + "window[\"a8e54769-8903-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_608805a786" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ea9ef710\u003e" ] }, "metadata": { @@ -1839,17 +1650,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c7092726-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"c708dec4-3379-11e8-ac70-0242ac110002\"]);\n", - "//# sourceURL=js_7aa23d7385" + "window[\"a8e5476a-8903-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"a8e54769-8903-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_3d87cf7d0f" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ea9efa90\u003e" ] }, "metadata": { @@ -1857,17 +1668,17 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c7099044-3379-11e8-ac70-0242ac110002\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_5722756ddb" + "window[\"a8e5476b-8903-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_5e91101199" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ea9efa50\u003e" ] }, "metadata": { @@ -1875,24 +1686,149 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" }, { - "output_type": "stream", - "text": [ - "Give me a color name (or press 'enter' to exit): \n" - ], - "name": "stdout" + "data": { + "text/html": [ + "\u003cdiv class=id_45185901 style=\"margin-right:10px; display:flex;align-items:center;\"\u003e\u003cspan style=\"margin-right: 3px;\"\u003e\u003c/span\u003e\u003c/div\u003e" + ], + "text/plain": [ + "\u003cIPython.core.display.HTML at 0x7f97ee42bd90\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"a8e5476c-8903-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_45185901 span\");\n", + "//# sourceURL=js_f43052a94e" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f97ea9ef750\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"a8e5476d-8903-11e8-99f9-c8d3ffb5fbe0\"] = window[\"a8e5476c-8903-11e8-99f9-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", + "//# sourceURL=js_bfc0fb76ce" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f97ea9efb10\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"a9e9b8b0-8903-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_45185901 input\");\n", + "//# sourceURL=js_7f167283fa" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f97ea9ef610\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"a9e9b8b1-8903-11e8-99f9-c8d3ffb5fbe0\"] = window[\"a9e9b8b0-8903-11e8-99f9-c8d3ffb5fbe0\"].remove();\n", + "//# sourceURL=js_016ae4bf21" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f97ea9ef250\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"a9e9b8b2-8903-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_45185901 span\");\n", + "//# sourceURL=js_e666f179bc" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f97ea9ef550\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"a9e9b8b3-8903-11e8-99f9-c8d3ffb5fbe0\"] = window[\"a9e9b8b2-8903-11e8-99f9-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", + "//# sourceURL=js_cbb9d14aec" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f97ea9ef1d0\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" }, { - "output_type": "display_data", "data": { "application/javascript": [ - "window[\"c7baac12-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"c70842c0-3379-11e8-ac70-0242ac110002\"]);\n", - "//# sourceURL=js_cdd622e58f" + "window[\"a9e9b8b4-8903-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"a8e54768-8903-11e8-99f9-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_2967a79665" ], "text/plain": [ - "" + "\u003cIPython.core.display.Javascript at 0x7f97ea9ef1d0\u003e" ] }, "metadata": { @@ -1900,21 +1836,98 @@ "id1_content_0", "outputarea_id1" ] - } + }, + "output_type": "display_data" } + ], + "source": [ + "def run_input_loop(sess, inference_ops, color_name_placeholder):\n", + " \"\"\"Helper function that reads from input and calls the inference ops in a loop.\"\"\"\n", + "\n", + " tb = widgets.TabBar([\"RNN Colorbot\"])\n", + " while True:\n", + " with tb.output_to(0):\n", + " try:\n", + " color_name = six.moves.input(\"Give me a color name (or press 'enter' to exit): \")\n", + " except (EOFError, KeyboardInterrupt):\n", + " break\n", + " if not color_name:\n", + " break\n", + " with tb.output_to(0):\n", + " tb.clear_tab()\n", + " sess.run(inference_ops, {color_name_placeholder: color_name})\n", + " plt.show()\n", + "\n", + "with tf.Graph().as_default():\n", + " # Read the data.\n", + " batch_size = 64\n", + " train_data = load_dataset(data_dir, train_url, batch_size)\n", + " eval_data = load_dataset(data_dir, test_url, 50, training=False)\n", + " \n", + " # Create the model components.\n", + " lower_cell, upper_cell, relu_layer = model_components()\n", + " # Create the helper placeholder for inference.\n", + " color_name_placeholder = tf.placeholder(tf.string, shape=())\n", + " \n", + " # Compile the train / test code.\n", + " tf_train_model = autograph.to_graph(train_model)\n", + " train_model_ops = tf_train_model(\n", + " train_data, eval_data, batch_size, lower_cell, upper_cell, relu_layer, train_steps=100)\n", + " \n", + " # Compile the inference code.\n", + " tf_inference = autograph.to_graph(inference)\n", + " inference_ops = tf_inference(color_name_placeholder, lower_cell, upper_cell, relu_layer)\n", + " \n", + " with tf.Session() as sess:\n", + " sess.run(tf.global_variables_initializer())\n", + " \n", + " # Run training and testing.\n", + " sess.run(train_model_ops)\n", + " \n", + " # Run the inference loop.\n", + " run_input_loop(sess, inference_ops, color_name_placeholder)" ] }, { + "cell_type": "markdown", "metadata": { - "id": "AHJ2c47U-A5W", - "colab_type": "text" + "colab_type": "text", + "id": "AHJ2c47U-A5W" }, - "cell_type": "markdown", "source": [ "# Where do we go next?\n", "\n", - "Autograph is available in tensorflow.contrib, but it's still in its early stages. We're excited about the possibilities it brings — write your machine learning code in the flexible Eager style, but still enjoy all the benefits that come with running in graph mode. A beta version will be available soon -- stay tuned!" + "AutoGraph is still in its early stages, but is available in [tensorflow.contrib](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/autograph). We're excited about the possibilities it brings. New versions will be available soon — stay tuned!" ] } - ] + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "Dev Summit 2018 - Autograph", + "provenance": [ + { + "file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K", + "timestamp": 1522238054357 + }, + { + "file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ", + "timestamp": 1521743157199 + }, + { + "file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-", + "timestamp": 1520522344607 + } + ], + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 2", + "name": "python2" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/tensorflow/contrib/autograph/examples/notebooks/graph_vs_ag_vs_eager_sum_speed_test.ipynb b/tensorflow/contrib/autograph/examples/notebooks/graph_vs_ag_vs_eager_sum_speed_test.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..32742bec7ee4a412aabb6640b5a1329353ebfc9d --- /dev/null +++ b/tensorflow/contrib/autograph/examples/notebooks/graph_vs_ag_vs_eager_sum_speed_test.ipynb @@ -0,0 +1,519 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "moMkWaT_TTHi" + }, + "source": [ + "This Colab illustrates the differing overhead* between a custom, vectorized graph operation and a loop over a tensor\n", + "that computes the same function. The loop is implemented in TensorFlow Eager mode using Python syntax and control-flow, and using AutoGraph which takes a python function and converts it into graph mode. In AutoGraph the Python loop is converted into a tf.while_loop.\n", + "\n", + "The actual computation, summing a small number of scalar values, takes very little time to compute, so the graphs below are showing the overhead of the differing approaches. As such, this is more of a \"micro-benchmark\" than a representation of real-world performance of the three approaches.\n", + "\n", + "*Note the differing scales of the included plots" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "a0X_rfvuav98" + }, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "EdxWv4Vn0ync" + }, + "outputs": [], + "source": [ + "!pip install -U -q tf-nightly" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "erq3_S7QsjkU" + }, + "outputs": [], + "source": [ + "from __future__ import absolute_import\n", + "from __future__ import division\n", + "from __future__ import print_function\n", + "\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "import matplotlib.pyplot as plt\n", + "import math\n", + "import time\n", + "import random\n", + "from colabtools import adhoc_import\n", + "from tensorflow.contrib import autograph as ag\n", + "from tensorflow.python.framework import function" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "1JgnsXooa2RP" + }, + "source": [ + "### Testing boilerplate" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "UyD5LLjVZzny" + }, + "outputs": [], + "source": [ + "# Test-only parameters. Test checks successful completion not correctness. \n", + "burn_ins = 1\n", + "trials = 1\n", + "batches = 2\n", + "max_elements = 2" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "4_NBL0RQa8gY" + }, + "source": [ + "### Speed comparison parameters" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "Yq6daecyiJV5" + }, + "outputs": [], + "source": [ + "#@test {\"skip\": true} \n", + "burn_ins = 3 # Batches not counted in the average\n", + "trials = 10 # Batches run per vector-size (and averaged)\n", + "batches = 1000 # Number of random vectors summed over per trial\n", + "max_elements = 100 # Vectors of size 0 to this-1 will be executed and plotted" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "fiR8m13CbKH2" + }, + "source": [ + "### Random input" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "d8vrTlyNXuxc" + }, + "outputs": [], + "source": [ + "# Construct a random num x 1 tensor\n", + "def get_elements(num):\n", + " return tf.random_uniform(shape=(num, 1), maxval=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ILJ6SbF3bXFQ" + }, + "source": [ + "## Graph mode" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "vovRf597X55n" + }, + "outputs": [], + "source": [ + "def tf_sum(elements):\n", + " # Using custom vectorized op\n", + " return tf.reduce_sum(elements)\n", + "\n", + "def run_trial(num):\n", + " elements = get_elements(num)\n", + " return tf_sum(elements)\n", + "\n", + "\n", + "\n", + "graph_means = []\n", + "for num in range(max_elements):\n", + " with tf.Graph().as_default():\n", + " durations = []\n", + " foo = run_trial(num)\n", + " \n", + " with tf.Session() as sess:\n", + " \n", + " for _ in range(burn_ins):\n", + " for _ in range(batches):\n", + " sess.run(foo)\n", + " \n", + " for _ in range(trials):\n", + " \n", + " start = time.time()\n", + " for _ in range(batches):\n", + " sess.run(foo)\n", + " \n", + " duration = time.time() - start\n", + " durations.append(duration) \n", + " \n", + " graph_means.append(np.mean(durations)) " + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 301 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 278, + "status": "ok", + "timestamp": 1532447361278, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "Jm9Blkyx90Eq", + "outputId": "d83cd51f-7e56-4d73-f7df-bb157dee46df" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAa8AAAEcCAYAAABwNTvaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3WdgVFXegPFnZtI7kJ5QQwlCKIGERGroEoSAYZFVEFAR\ngV1XXHvbFWEtK6xlUVgRXMuLiqBSZQUh9E5ChxRIn/ReJjNz3g8hNxkzCUMJkHh+X2DmtnPPnXv+\n95R7ohJCCCRJkiSpGVHf6QRIkiRJ0vWSwUuSJElqdmTwkiRJkpodGbwkSZKkZkcGL0mSJKnZkcFL\nkiRJanaaJHitWLGCV199tSl23WINHz6cAwcONPlxXnzxRd5///0mP87dIikpiUmTJtGvXz++/PLL\nO52cFi0+Pp4HHnjgTifjmsaPH8+RI0du6T5/b/eVObcqX9966y3Wrl17zfWsbmTnffv2RaVSAVBe\nXo6NjQ1qtRqVSsUbb7zBE088cSO7vW5paWmMGDGCs2fPolY3n0rkiy++iLe3N0899dSdTspdoSmv\n46effsqAAQPYsGGD2eVbt27l888/5/z58/Tq1Yv//ve/JsvPnTvHyy+/TGJiIgEBASxevJjAwEBl\n+bvvvsu6detQqVQ88MADPPvssxZv29SmT5/OxIkTiY6Ovi3H++CDD3jsscea9BjDhw9n8eLFhIeH\n3/A+Nm3adAtT1LgVK1bwySefoFKp0Ov16PV67OzsEELg7+/Pxo0bCQwMxN7eHpVKhRACa2trDh8+\nfNvSeCPMlWG3Kl8fffRRpkyZQnR0NFZWDYeoGyopTpw4wfHjxzl+/Di+vr6sWLFC+W78+PE3nOjr\nJYRQLrjUfDXldUxPT6dz584NLndzc2PmzJnMmTOn3rKqqirmz59PVFQUR44cISoqinnz5qHX6wFY\nu3YtO3fuZOPGjfz000/s2rWLb775xqJtm4PruR7Z2dkcOnSIESNGNGGKbo7BYLjtx3ziiSeUsvHv\nf/87ffv25fjx45w4cYKNGzcCoFKp+Omnn5Tv73TguhP5VJeHhwcBAQHs3Lmz0fVu+jFXCFHvR/7R\nRx8pT6BpaWkEBgayfv16hg0bxoABA1i7di2nTp1iwoQJhIaGsmjRIpPt161bx7hx4xgwYACPPfYY\n6enpZo89ffp0APr3709wcDCxsbEIIVi+fDnDhw9n4MCBvPDCC5SUlJjdPj8/n7lz5xISEsKAAQN4\n+OGHlWWBgYGkpKQon+s2Cxw+fJihQ4fy6aefcu+99zJ48GB++eUXdu/ezZgxYxgwYAArVqwwe8xv\nv/2WjRs38umnnxIcHMyTTz6pLDt37hwTJkwgJCSEhQsXotPplGW//vorUVFRhISEMG3aNC5cuGB2\n/wAJCQnMnj2bAQMGcN9997F169YG121sv8OHD2fVqlVMmDCBvn378sorr5Cbm8vjjz9OcHAws2fP\npri4WFn/5MmTPPjgg4SEhBAVFWVyE06fPp3333+fadOmERwczKOPPkpBQYGyDEyvY3JyMtOnT6d/\n//6Eh4ezcOHCBs9hx44djB8/ntDQUGbMmEFiYiIAjzzyCIcOHeKNN94gODiYK1eu1Ns2PDycsWPH\n4uHhUW/Z4cOHMRgMzJgxA2tra6ZPn44QgoMHDwLwww8/MHv2bDw9PfH09GTWrFlKDe/QoUONblvX\nli1b6jW3rVmzhnnz5gGg0+l4++23iYiIYNCgQfztb38z+W388ssvREVF0a9fP0aPHs3evXtZtmwZ\nx44dY9GiRQQHB/Pmm28CcPz4caKjowkJCWHKlCmcOHHC5BotW7aMadOm0adPH1JTU1m/fj0jR44k\nODiYkSNHNvh0vW/fPnr06IGNjQ0AK1eu5M9//rPJOm+++SaLFy8GoKSkhJdffplBgwYxdOhQ/vWv\nf5mUI99++y3jxo0jODiY8ePHc+7cOZ577jkyMjJ48sknCQ4OZtWqVUD965+QkKDsZ/jw4fznP/9R\nfsMGg8GkiT4kJITg4GCCg4Pp27cvgYGBSnnT2L1x9uxZJk+eTL9+/Xj66aeprKw0my+WsPQhobGy\nraac/fbbbxk8eDCDBw9m9erVJtuuXLmSUaNGERYWxtNPP01RUZHJtuvWrSMiIoKZM2cC8NRTTzFo\n0CBCQkKYPn26kq8NlWF181Wn07F48WIGDx7MkCFDWLJkCVVVVUBt+bl69Wql/Fy/fr3JuYaEhLBr\n165rZshNiYiIEPv37zf57sMPPxTPPvusEEKI1NRU0a1bN/H666+LyspKsW/fPhEUFCTmz58v8vLy\nRGZmpggPDxdHjhwRQgjxv//9T4wePVokJiYKg8EgPv74YzF16lSzx05NTRWBgYHCaDQq33333Xdi\n9OjRIjU1VZSVlYkFCxYoafmt9957T7z++uvCYDAIvV4vjh49qiwLDAwUycnJyucXXnhB/Otf/xJC\nCHHo0CFxzz33iOXLlwu9Xi++/fZbERYWJp555hlRVlYmLl26JIKCgkRKSorZ49bdV918nDJlisjO\nzhaFhYXivvvuE2vXrhVCCHH69GkRHh4u4uLihNFoFBs2bBARERFCp9PV23dZWZkYOnSo2LBhgzAa\njeLs2bNiwIABIj4+vt6xr7XfiIgIMXXqVJGbmyu0Wq0IDw8XkyZNEufOnRM6nU7MmDFDfPTRR0II\nITIzM0VoaKiIiYkRQgixf/9+ERoaKvLy8oQQQjz88MNi1KhR4sqVK6KyslI8/PDD4r333mvwOi5c\nuFB88sknQgghKisrxbFjx8zmZWJioujTp4/Yv3+/0Ov14j//+Y8YNWqUqKqqUo773Xffmd22rm+/\n/VZMnz7d5LvVq1eLxx9/3OS7J554QqxevVoIIUS/fv1EbGyssuzUqVMiODjYom3rKi8vF8HBweLK\nlSvKdw888IDYsmWLEEKIN998Uzz55JOiqKhIlJaWirlz54qlS5cKIYSIjY0V/fr1U+5BrVYrEhMT\nzZ57QUGBCAkJET/99JMwGAxi06ZNIiQkRBQUFCjrR0REiPj4eGEwGERxcbEIDg4Wly9fFkIIkZ2d\nrfyOfuvtt98Wb7zxhvI5LS1N9OnTR5SUlAghhDAYDGLgwIFKfj355JPi9ddfFxUVFSI3N1dMmTJF\nfPPNN0IIIbZs2SKGDBkiTp8+LYQQIjk5WaSnpwshqn+TBw4cUI5zresfEREhoqKiRGZmpqisrFS+\n+22ZJYQQS5cuFQ8//LDQ6/WN3hs6nU5ERESIzz//XOj1erFt2zbRo0ePevf0b61fv1788Y9/rPd9\nt27dTMqahjRWttWUswsXLhQVFRXiwoULIiwsTDnP1atXi6lTpwqtVit0Op147bXXxMKFC022ff75\n50V5ebmST99//70oKysTOp1OLFmyREycOFFJS0NlWM3x/vWvf4mpU6eKvLw8kZeXJ6ZOnSref/99\nIURt+fnhhx8KvV4vdu3aJXr37i2KioqUfW3fvl1MmjSp0fy4LR1FKpWK+fPnY2Njw7333ou9vT2R\nkZG0atUKLy8v+vfvz9mzZwH45ptvmDNnDh07dkStVjNnzhzOnz9PRkZGYwFY+f+mTZuYOXMmfn5+\n2Nvbs3DhQrZs2YLRaKy3nZWVFdnZ2aSmpqLRaOjXr5/ZfZpjbW3N3Llz0Wg0jBs3jvz8fB555BHs\n7e3p3LkznTt3brR2ZM6MGTNwd3fHxcWFiIgIzp07B8B3333Hgw8+SFBQECqViqioKGxsbIiNja23\nj19//RV/f3+ioqJQqVR0796d0aNHs23btnrrWrLfhx9+mNatW+Pp6Un//v3p3bs3gYGBWFtbM2rU\nKCWNP/30E8OGDWPw4MFAdY2mZ8+e7N69W9nX5MmTadeuHTY2Ntx3333KtjXq5rmVlRVpaWlotVps\nbGwIDg42m2dbt25l2LBhhIeHo9FoePTRR6moqDCpUdyosrIynJ2dTb5zcnJSnnZ/u9zZ2ZmysjKL\ntq3Lzs6OESNGKLWay5cvk5SUpDTBrVu3jhdffBFnZ2ccHByYM2eOsu66deuIjo5W+oA8PT3p2LGj\n2fPZtWsXHTp04P7770etVhMZGUmnTp349ddflXUmTZpEQEAAarUajUaDRqPh4sWLVFZW4u7uTkBA\ngNl9FxcX4+joqHz29fXlnnvu4ZdffgHgwIEDODg40KtXL3JyctizZw8vvfQStra2tG7dmkceeYTN\nmzcr5/TYY4/Ro0cPANq2bYuPj4+y77q/E0uu/4wZM/Dy8lJqheZs2bKFTZs28eGHH6LRaBq9N2Jj\nY9Hr9cyYMQONRsOYMWPo2bNng/u2xKRJkwgJCSE0NFSpnf6WJWXbn/70J2xtbenatSuTJ09W8vTb\nb7/lL3/5C56enlhbWzN//nx+/vlnZVuVSsWf/vQn7OzslHyaPHky9vb2yvrnz59vsBXLXFrnz59P\nq1ataNWqFQsWLODHH39UlltbWzNv3jw0Gg1Dhw7FwcGBpKQkZbmjo6NJq445NzRg40a0adNG+b+d\nnR3u7u7KZ1tbW+WmT09PZ/Hixbz99ttAbX+IVqs1+QE3JCsrC19fX+Wzn58fer2enJwcPD09TdZ9\n7LHH+PDDD5k9ezYqlYopU6aY7fswx83NTRm0YmdnZ/Yca87JUnW3t7e3Jzs7G6jOkx9//FEZLSeE\nQK/Xk5WVVW8f6enpnDx5ktDQUGVdg8FAVFSU2XWvtd+6abK1ta33ue5127p1q1IQ1uyrbsd63Wtu\nb2/faP4899xz/Otf/yI6OlrplzI3ku2311ulUuHj44NWq21w35ZycHCod7OWlJTg5ORkdnlJSQkO\nDg4WbftbkZGRvPPOO8ybN49NmzYxcuRIbGxsyMvLo7y83OTcjUajUoBnZmYydOhQi87nt3kF1UGm\nbl55e3sr/7e3t2fZsmWsWrWKl156iX79+vHcc8/RqVOnevt2cXGhtLS03jlt3ryZiRMnsmnTJqU/\nPD09Hb1ez6BBg4Daroea+zszM5N27drd0DmZu/51z8mcs2fPsmjRIlavXo2bm5uSxsbuDS8vL5N9\n+Pn5WZTehmzYsIG2bds2uk5jZRtUn3vdc/X19eXSpUvK+SxYsEAZECWEwMrKStkWTPPJaDSydOlS\nfv75Z/Lz81GpVKhUKvLz8xv8DTeWVl9fX5Nyxc3NzWRwlp2dncnvp7S0tN7D32/dtuBlKW9vb558\n8kmLBn7UBI+6PD09TfrI0tLSsLKyMik4azg4OPD888/z/PPPk5CQwPTp0+nVqxdhYWHY29tTXl6u\nrJudnX3Nm6CpeHt7M3fuXItGcfr4+DBgwAClP+BW7deS40ZFRfHGG29c97bmrmObNm2UvtBjx44x\na9YsQkND693gnp6eyg1aIyMj45Zcqy5durBmzRqT7y5evKj00XXu3Jnz588TFBQEVPdZdunSpdFt\n6/ar1jVo0CBefPFFzp8/z+bNm3nppZcAaNWqFfb29mzatKnewxdUX8O6fbN1/TZfPT092b59u8l3\n6enpDBkypMFtBg4cyMCBA9HpdCxbtoxXX32Vr776qt6xunXrZvJkDTB27FjeeecdtFotv/zyizKY\nxcfHB1tbWw4dOmT22nt7e5OcnGzxOd3M9c/Ly2PBggW8/vrrJiNBG7s3jhw5Uu/hKD093eKAe6Ma\nK9syMjIQQpCRkaHUvDMyMpTfjI+PD0uWLKFv37719puWlgaY5u3GjRv59ddf+fzzz/H19aW4uJiQ\nkJDrSmtaWppSU09PTzf7+21IQkLCNUfm3pZmw2s1wdU1bdo0VqxYQXx8PFDdHGGuyQugdevWqNVq\nkx96ZGQka9asITU1ldLSUpYtW0ZkZKTZIdi7du1StnVwcFCaSaB6wMamTZswGo3ExMTc0vdC3N3d\nGyxwzPnDH/7A2rVriYuLA6qbpHbv3m225jJs2DCSkpL48ccf0ev1VFVVcerUKWUQw43u91omTJjA\nzp072bt3L0ajkcrKSg4fPmxRDcjcddy2bZuyrYuLC2q12uw1vO+++9i1axcHDx5Er9ezatUqbG1t\n6dOnj0XpNhqN6HQ69Hq9yf8BQkNDUavVfPHFF+h0OuUpfMCAAQBERUWxZs0atFotWq2WNWvWMHny\n5Ea3DQsLM5uOmuand955h6KiIgYOHAigtAgsWbKEvLw8ALRaLXv37gUgOjqa9evXc/DgQYQQaLVa\n5Vr/9nc2dOhQrly5wubNmzEYDGzZsoXExEQiIiLMpik3N5edO3dSXl6OlZWVco+YM3DgQM6cOWMy\nkKR169aEhITw4osv0rZtW6XG5uHhwcCBA1myZAklJSUIIUhJSVHusSlTpvDZZ59x5swZAJKTk5Vu\nA3d3d1JTU5Vj3Mz1NxgM/OlPf2LChAmMHTvWZFlj90afPn2wsrLiiy++wGAwsH37dk6dOnXN490s\nS8q25cuXU1FRwaVLl1i/fj2RkZEATJ06laVLlyrBLy8vjx07dijb/baMLi0txcbGBhcXF8rKynjv\nvfdMgtu1yrDIyEg+/vhj8vLyyMvLY/ny5UycONHicz1y5IjJQ5U5Nx28zD05XWudxj6PHDmSxx9/\nnKeffpr+/fszYcIE9uzZY3a/dnZ2zJ07l2nTphEaGkpcXBzR0dFMnDiRhx9+mFGjRmFvb88rr7xi\ndvvLly8zc+ZM+vbty7Rp03jooYeUp4uXX36ZnTt3EhISwubNmxk5cuRNnWNd0dHRxMfHExoayoIF\nC665fs+ePVm0aBFvvPEGoaGhjBkzpsH3lhwdHfnss8/YsmWLMurovffeMylULN3v9ZyTt7c3y5cv\nZ8WKFYSHhxMREcFnn32m3BSNbWvuOp46dYopU6YQHBzM/Pnzefnll802zXTs2JF3332XRYsWER4e\nzq5du/jkk0+U90Ou9fv88ccf6dWrF2+88QbHjh2jd+/eygv21tbWLF++nA0bNhAaGsr69etZvny5\nsu8HH3yQiIgIJkyYwIQJE4iIiOAPf/iDRduaExkZyYEDB7jvvvtMCqS//vWvtG/fnj/84Q/079+f\n2bNnc/nyZQB69erFkiVLWLJkCf369WPGjBlKQT9jxgy2bdvGgAEDWLx4MW5ubnzyySesWrWKsLAw\nVq1axYoVK3B1dTWbV0ajkdWrVzNkyBDCwsI4cuQIr7/+utm0t2nThrCwMKWPq8b48eM5cOAA999/\nv8n3b7/9NlVVVURGRhIaGspTTz2lNJOPHTuWuXPn8swzzyjXv7CwEIA5c+awfPlyQkNDWb169Q1d\n/5rvMjMzOX78OJ9//rky2jA4OJjMzMxG7w1ra2s+/PBD1q9fT2hoKNu2bWP06NENXtdrsaQMBSwq\n20JDQxk1ahSzZs3iscceU5rtH3nkEUaMGMHs2bPp168fDz74oBKYzaUhKioKHx8fhgwZwvjx4+vV\n2K5Vhs2bN4+ePXsyYcIEJk6cSM+ePZk7d65FeZCVlUVCQsK1y1xxPdUiSZKkBiQkJPDCCy/w3Xff\n3emk/O6kpaUxcuRIzpw506wmbDDn7bffpl27dkybNq3R9WTwkiRJauaa62xDN+P3cZaSJEktnKXN\njy2FrHlJkiRJzY6seUmSJEnNzl33ntfN0OsN5Odf/zDvlqhVKweZF1fJvKgl86KWzItaHh6NvxB8\nN2pRNS8rK/PvoPweybyoJfOilsyLWjIvmrcWFbwkSZKk3wcZvCSpBSsq1VFcVv8FdUlq7mTwkqQW\n7J9rT/LBurhrryhJzUyLGrAhSVIto1GQnlOKrY3s25FaHlnzkqQWqri8CqMQlFfq0Rvq/z07SWrO\nZPCSpBaqsKT2T9MXl1XdwZRI0q0ng5cktVAFJbUDNeSgDamlkcFLklqoujWvIhm8pBamyYNXTEwM\nY8eOZcyYMaxcubLe8qNHjzJ58mR69OhR76+8QvWfTh8yZAhvvvlmUydVklqUgtI6Na9S2WwotSxN\nGryMRiOLFi1i1apVbNq0ic2bN5OQkGCyjq+vL2+99Va9P1ZX4/333yc0NLQpkylJLZJpn5eseUkt\nS5MGr7i4ONq3b4+fnx/W1tZERkaa/OlpqA5eXbt2NTud/+nTp8nLy2PQoEFNmUxJapEK6/R5FckB\nG1IL06TBS6vV4uPjo3z28vIiKyvLom2FELz99ts899xzyL/aIknXr6BU1ryklqtJX1K+maDz9ddf\nM2zYMLy8vK5rX81xduSmIvOi1u8xL4rL9djbaiivNFCpF0oe/B7zoiEyL5qvJg1e3t7epKenK5+1\nWi2enp4WbXvixAmOHz/O119/TWlpKXq9HkdHRxYuXNjodtnZxTeV5pbCw8NZ5sVVv8e8EEKQV1iB\nn4cjqVkl5BaUkZ1d/LvMi4bIvKjVHIN4kwavoKAgkpOTSUtLw8PDg82bN7N06dIG169bu/rnP/+p\n/H/Dhg2cOXPmmoFLkqRqZVdn1WjlZEthSaUcKi+1OE3a56XRaHj11VeZPXs248ePJzIykoCAAD74\n4AN+/fVXAE6dOsXQoUPZtm0br7/+eoOjDiVJslxBcXV/l6uTDS4ONnLAhtTiqEQLGw0hmwGqySaR\nWr/HvDhzOY/31p5kwsAOJKQVcuZyPp88MxQ/X7ffXV405Pf4u2hIc2w2lDNsSFILVPOOl5uTLc6O\nNoCc31BqWWTwkqQWqOYdL1cnG5ztrwavctnvJbUcMnjdJpU6A2u2nufoecvec5Okm1EzKa+bky0u\njtYAFMkpoqQWRAav2+RkfA4xseks/+E0KzeeobRCFiRS0ym8+oKyq6MNzg41zYay5iW1HPIvKd8m\nlzOLAGjjYsfBM1ouJBfwxIQedG3rdodTdnuUlFdx8lIO+cUV5JfoqKjUM3FQR7xaO9zppLVIBSU6\nVICLow3ODtU1L9nnJbUkMnjdJpczilEBf58dwi9HU/lp32VWbz3PP+aE3emk3RZfbr/A4XOmTaat\nXGyZMqzzHUpRy1ZYUomzgzVWGjUuV2te8l0vqSWRzYa3gVEILmuL8XF3xMHOmgmDOtLZz4WsvDKq\n9IY7nbwmpzcYiUvIpY2LLU//oTcvPhwMQFp26R1OWctVUKrD1ckWoE7Nq+UEr5LyKjnn6e+cDF63\ngTavjEqdgQ7ete9SeLdxQADa/PI7l7Db5GJKARU6A327eBDUqQ1d/N1wdbIhNbvkTietRarQ6anU\nGXB1qq5x1fZ5tYxmw0upBfz5/T38c+1Jsgta/v0jmSeD121wObP6RUiT4NXaEYDM3LI7kqbbKTY+\nF4Dend2V7/w9nMgrqqRMDly55WqGybs5Vte87Gw0WGnULabmde5KvvLvq6sO8cvRFIyyFva7I4PX\nbXA5oyZ4uSjfeV8dqJCZ1/KDV1xCDrY2GpPBKf4e1cE7VTYd3nIFJbVTQwGoVCpcHK1bzFD5mt/M\nlIgArDVqvv7lEv/dduEOp0q63eSAjdvgcmYRKhW09XJSvvNu8/sIXpl5ZWjzy+nX1QNrq9pnJX+P\n6rxIyy753Yy4vF56g5HSCj2uV2fIsFRhae07XjWc7W3IyL27HxTKKqrYeiiZvKJKist1lJRVMaS3\nL8P6+pmsl5Zdgr2tFWND23FvTx+WfHGUg2cyeWhUF6ytNHco9dLtJmteTcxoFFzRFuPn7oitde2N\n5e5qh0atQtvCg1dsfA4AvTq3Mfm+JnjJmlfDvvrfRZ7/ZD9FpdfX3KdMylsn6Dk7WqPTG6mo1N/S\nNN5K2w4ns/nAFQ6cyeR0Yh6XM4v5+UiKyTpVegPavHL8PBxRqVS4OtrQr6snOr2R88kFdyjl0p0g\ng1cTy8gtRVdlNGkyBLDSqPFwsyczr6xFj5pSgleAu8n3Pm0cUKmQgzYaUFpRxf7TmeiqjCSmF13X\ntgVmal41w+VrmhTvNlV6I7tPpuNoZ8U/ngjj44VD6d6+Fdq8MpMX+tNzyjAKQVuP2laMoIDqB6O4\nq32rUjVtfhknLmbf6WQ0GRm8mpgyWMOn/qzN3q0dKK3QU1x+Y30Rd3vQK6uo4lJqIR19XOo1fdlY\na/Bq5UBqduldfx51ZeSWsnzDKfKLmzYIHDidSZXeCNS+4G6pwt/0eUHtcPnrrcXdLkfPZ1FcVsXg\n3r54tXLA1kZDJ9/qB76kOsG75mGnps8UoIu/K/a2GmITcprVb6mprd5yng/XnyIhvfBOJ6VJyODV\nxMwN1qihDNpoZMThtzvj+b9fLtX7Pq+ogqc+2Mv2w8m3KKW33umkPAxGQe+ANmaX+3s4Ul6pNwkE\nCWmFrN1xCYPReLuSeV02xCRy9EI2/zuaUm9ZUZmO9JybbwYVQhATm45GrQJqH4AsVTuvYW3wuttr\nXjuOp6ICIur0b9UEr8SM+sHLr07Ny0qjpkeH1uQUVpDxOxi9a4n84koupVQ3o/58uP5vtSWQwauJ\nXc4sQqNW0dbTsd6yaw3aiE8tZNvhZP53NKVeoRgTm05JeRUHz2pvfaJvQn5xJTmF5RSWVHL8apNF\n3SHyddX2e9U2Ha7deYntR1K4eBf2X+QUlnPs6jntP52J3lAbYIUQfPh9HH9bfZicm3z3KDG9iNTs\nUvp29aCNiy2XM4quq0ZRWKrD0c7KZPCCU03N6y4MXkkZRSSmF9G7szsebvbK9518XQFMmk1r+kjr\n1rygtlk6LkE2HQIcOZ+FADRqFccuZLXI9+Fk8GpCeoOR5KwS/DwczY6Camy4vBCC73cnKJ93nUxT\n/m8wGtkTlwHAFW3xLZvk1yjETTW7nLyUwzP/3sdzHx/g6Y/2cfhcFm5ONrSrM8qyLj9lxGF1gZSZ\nV0ZCWnVBdSn17mvq2Hk8DSGqB9sUleo4lVhbUF5MKSAhrQi9QbD54JWbOs7u2HQAhvb2pYO3C0Vl\nVdfVTFlYUqnMrlGjtuZ19zUb7jyWCsDwfqajCl0dbWjjYkdiem3wTs0uoY2LLQ521ibrKv1eCTm3\nIcV3vyPntKhUED0sACEw21LQ3Mng1YTSc0qp0tcfrFGjsWbDM5fzuJBSQI+OrXF1tGH/qUwqq6qn\nkjqVmEd+cSU21mqE4JbUUsor9Ty7fD///fnG3pcxCsH6mERUKgjr4UVIoCd9u7jz4IguqFQqs9v4\ne9a861WD0HifAAAgAElEQVRd89p/OlNZdjH17qp5VeoMxJxMx8XBmicm9gBg79UHCIBth6qbb53s\nrdkbl0FuYcUNHae8Us/hc1rcXe3o3qGV0ldqadNhld5gdnh9zSwbhXdZzauoTMehc1l4tXbgng6t\n6y3v5OtCSXkV2YUVFJfpKCzRmTQZ1nB1tKGDtzOXUgspq7gzIypv9kXp3zah36icgnIS0osIbNeK\nEf38aeVsy57YjBb3lyxk8GpCjQ3WgOpOdAdbq3o1r+paVyIAU4YFMLi3L2VXCzWAmJPVT+YPDAkA\n4Fxy/k2n9XRSdUDcfTKds5fzrnv7ExdzSM0uYcA9Xsy5vwdPRvXkTw/0IrS7V4PbeLjZY2OtJjW7\nFKMQHDidgZ2NBk83exLSiu6qfq/9pzMoq9QzrK8fAb6utPNyIjY+l8KSStJySolNyKWznysPjuiM\nwXjjta9DZ7XoqowM7u2LWqVSHnwsHbRRaKa/C8DlarPh3RS8Siuq+GFPEnqDkeHBfqjNPOQo/V7p\nhXWaDM3X5HsFtMFgFDf0+71Ze2LTmbd0NxdTbuyhq7xSz5v/Pcqrnx666Vlnjlz9m4ED7vHCSqNm\nZH9/KqsM7L5abrQUMnjdpLKKKtbuuGS2Tfn81WlsOjZQ81KpVHi3cSC7oNykoD52IZsrmcWEdvek\nnZczQ3v7olLBrhPp5BVVEJuQQwdvZ4b19cPaSq0c52bUHVL7xc8XrmvCYCEEG/cloQLGh3eweDu1\nSoWfuyMZuaWcu5xPblEl/QM96d6hFZVVBpK1t3YYfUZuKa+tOsSWRgJLld5ITGw6b311nP/+fIH8\n4kqMQvC/o6lo1CplQMHgXr7VAfeMlp+v1rrGDmjHgHu88Gxlz57Y6mt1PfQGI7+eSEOtUjEoyAeA\n9lenFKsZ+NMQo1EQl5DD59vOA9RrNqyted1cs2FZhf6mA2BiehErN55h4Uf72HUiDRdHGwb29DG7\nbm3wKjI70rCumr7V2NvcdGgUgi0Hr6CrMvLpprOUX+e7dEIIVm0+R0ZuGWWVemJiM669USMOn8tC\no1YR3NUDgKG9/bCz0fDL0RQycku5kJzP0fNZN9w6cLeQM2zcpB/2JvHL0VS0eWU8NaW38n1uYQVH\nzmfh08bBZGaN3/Ju7UBiehE5BRV4tXbAaBRs2JOIWqUianAnANq42tGrUxtiE3L5+pdLCAFD+/hi\nbaWmi78rZy/nU1Sqw+U6Z2KooTcYiU3IpY2LHX27uvPL0VS2HExm4qCOFm1/Mj6H5KwSQrt74utu\nvmBpiJ+HE0kZxayPqa5pDuzpTV5RdQ3wUkoBHX3MB/7rVVSqY9m3seQUVrBuVwKujjYMDKotMCt0\nenYcS+WXo6nKDBUXUwrYfyqD3p3dycwr496e3kpQGHCPF9/svMTO46nkF1fi1dqBPl3cUatUjA/v\nwGdbzrH54BWmj+5mUfqMQvDZ5nOkZFXXXls5Vx/Hyd4ad1c7LmcWI4Qw2wR75nIen289T87VwqiT\nr4vJuQHY2miwsVYrf6SyMWcu53HyYg4TB3fEyb62b6msooq/rT5CaYWev88Kwb3O4ApLZeWX8Y8v\nj2EwCrxa2TOkty/3BvngYGe+KGrv5YxGrSIpvQjd1Wbzhmpe7b2dcXG0ITY+l437krC11mBtrUGv\nN1JRZUBXZSCwfSt6mGmevBnnLuejzS/H0c6KnMIKvv01nkfGBlq8/ZaDVzh+MZvO/q6kaEv45VgK\no0L80ajN1y1yCsrJK640OzONNq+MK9pigjq1Ua6dg50VQ3r7sv1ICi//55Cybr+uHsyfHHSdZ3v3\naPKaV0xMDGPHjmXMmDGsXLmy3vKjR48yefJkevTowfbt25Xvz58/z4MPPsj999/PxIkT2bJlS1Mn\n9bplFZTz6/HqgRSxCbkmo6K2H0nBYBTcN6C92eaQGjX9XhlXmw73nsogI7eMgUHeyjJAmSLn+MVs\nbG00SnNc9/atADh/E02H55PzKa/U07erO5MGd6KVsy2bD1y2aOoqIQQ/7b2MCrh/oGXBri7/q8Eu\nKaMId1c7urR1o0vb6lFmF2/RoI0KnZ7318WRU1jB0D6+ONhasWbreaWJJz6tkL99doTvdydSWWVg\nbGg73pkbzsz7AnG0t1aaYUb1b6vs08nemr5dPMgprMBgFIwJbatc5/CeXni42bEnNp2kjPrNffFp\nhRw4k6k0Dwkh+HL7RQ6e1RLg58LM3xR8HXyq+33MPSnnFJTz8YbTFJRUMqS3D6/PDOGVGf3xM/MQ\n4eJgQ2EjfSpVeiPf7LzEe2tPsuN4Kh//cFoZUSmE4L8/XyCnsILySj2fbjqL0Xj9fTw198VDo7qy\nZE4Y94W1b3T6KxtrDf4eTlzRlnA5sxiNWqWM0v0ttUpFv64elJRXsWFPEmt3xvPFzxf4vx2X2BCT\nyOYDV1i+4dRN9f2cuZxXr1lv5/HqASd/eqAX/h5O7D6ZbvGoxzNJeayPSaSVsy0LJgUxMKj64e3Y\nBfMvF5+8lMNrnx3mra+Om32lpKZrIbS7p8n348LaE9bDi0G9fIgMb88fR3Zh6vDm/bf0mrTmZTQa\nWbRoEWvWrMHT05Po6GhGjBhBQECAso6vry9vvfUWn332mcm29vb2vPPOO7Rr146srCwmT57MkCFD\ncHJquBZzu22IScRgFAzr68euE2n8uDeJp//Qm5LyKmJi02nlbEtYj4b7fMB00EZlOwM/7EnExkqt\n1LpqBHVqQxsXO3KLKhjQ3Qt72+pLF6gEr4JG+5cac+JidTNLcBcP7G2tmDaiC8t/OM0XP1/gmQf7\nNBp8YxNyuaItJiTQ02yBeS1+nrXX896e3qhVKtxd7WntYsul1AKT2sbxi9moVSr6dDE/9N4cg9HI\ne18dIymjiHt7ejNjTDdCAz1Z+m0sH60/xb09vatHYgkYG9qO8fd2UGoBQ9zsCbvHi10n0hDUNuHV\nGNzLhyPns3BxsGZgT2/le41aTdSgTvxn01kWfX6UPp3dGX9vB/KLK9l2+IoyotLaSk2/bh7YWGmI\niU2nracTT0/pja2N6cjUjt7OHD2fxeXMYpPajt5g5OMfz1BWqWfWuEAG9/JtNC+cHayVl8Lr1uCM\nRkFSRhFf/HyB5KwSvFrZ08bVjrOX8/lmZzwPjerKgTOZHD6XRYCfC26Othy7mM22w8mMC2tv8bUo\nLtOxNy6DNi62DO3j2+BAnt/q5OvCFW0xydoS/D0csdI0/Mz94IguhPf0plJnoPJqbcvaSo2tjYaz\nl/PZdiiZnw+nMHlI7f1VXqlnQ0wirVxs6dGhNf6eTmZ/86cTc1n6bSyd/Vx57o99sdKoyS2s4GR8\nDu29neni78rj99/DG2uOsHrrOSYO6kh8aiHxqYU4OVgzc2wg/nV+76cTc1nx0xk0ahXzJvXExdGG\nUf3b8uvxNLYfSTG5n4UQbD5whQ0xiVhZqfFws2P7kRTSckqZO7EH1ho1567ks/dUBlYaFX27eJik\n3cXRhjn397Aov5uLJg1ecXFxtG/fHj+/6lpDZGQkO3bsqBe8gHo/5Pbta28KT09P2rRpQ15eXqPB\nq6lmD9AbjPxn41kc7ayYEtEZe1srkjKKOHRWS3tvZx4e3ZXM3FJOJeYSn1bI2aQ8KqsMRA3u2OiN\nBqbD5bcfTaGgRMf4e9srzUY11GoVYwe045ud8QwPrh1S3MHbGTsbjfJnIqC6U/7AuSw6eTnh1dr8\nU2oNoxCcuJSNk721UuPp182D3gHVzZQ/H07mvgHmC6hKnYFvdsYDcP/ADo0epyF1m4DC6wSALv5u\nHDqrJTOvDJ82jqTnlLJ8w2mMQjDzvkCG9DZfUFfqDOyJS+dCcgEZeWVk5ZehNwgC27kx875AVCoV\n3Tu0ZvqYbqzZep7tR1Jo42LHY+O7061dq3r7s7HWMDq0ndlj3dOhNSOC/Qls71bvVYjwnt44O1jz\n077LnIzP4WR8bT9Mn87udPB25sCZTA6eqX5S9mrtwDNT+9QbAg61f0onKbOI/oG1T9TrdiWQlFFE\neA8vpY+sMc4ONlTpiymr1KPNK+dMUi4XUwtJSCukQlfdJDe4lw/TRnZBCFjyxTF2HEvF0c6K7UdS\nsLPR8Pj9PXCwtSI+vZANMYn06NC6XlBvyK8n0tDpjYwKaXfN+6KuTr4u/HqiuoWjoSbDGtZWajr7\nuZpd1sXfjf2nM/nlaAqjQ9pSU7x/uf0iB85Uj3T9jgScHawZHdKWyN/032692rcZn1bI97sTmDq8\nC7tjq1+fGN7XD5VKRVtPJ6IGd+T73YnKTPf2thqyCspZ9N+jTBvRhUG9fPhhTxJbDl7BSqNi1n3d\nCbj6TptXawd6d3bnZHwO8WmFdPZzpbBUx5fbL3DsQjatXWz50+ReeLjZs3LjGeIScnn5P4eoqNSj\nuzojy5Devg02w7YkTXqGWq0WH5/am8rLy4tTp05d937i4uLQ6/W0a2e+EKnx0GtbmTaiC6NC2ja6\nXg0hBMnaElq52CrvwZizcd9lpenodFIecyb0YMPVPpo/DAtQ+qfe+uo463YlkJ5TiuPVduZr8Wxl\njwpISC/k8Dktzg7WDQaL4cF+DO7lg02dCX41ajVd27oRl5BLXlEFGo2at78+oTT5dW3rxuBePrT1\ndEKjVqFWq2jtbKc83SdlFFFQomNgkLfSxq5SqZg1rjuvrz7M97sS6eLvZrZA+L8dl9DmlTE6pO01\nC5WGuDra0M7TiVbOtni1qg20Xf1dOXRWy6XUQnzaOPLdr/EYhcDGWs3nW89jY6UmrEdtsCutqFL6\nrEquTrdlb6uhracTgR3bEBna1qTAHNLbl0qdgdyiCiYO6qjUZK+HWq3iodFdG1zes1MbenRszfkr\n+ew4noaTfXWhWNMveP/ADlxKLeR0Uh7D+vg22GdpbtDGiUvZbD+SgndrB6aP6WZRLaZmiqjnPt5P\neWXtgBzv1g508XelXzdPetWZDeVP0b1YtOYIP+27DMBj47vjebXm9+i47iz9NpaVG8/w2swQk0mn\nzdFVGdhxLBV7WysG97p2oK2rZtAGgF8DgzUsYWutYVxYe9buuMTPh5OZ2641h85qOXAmk44+zozs\n35azSXnEJuTy/e5EurVtRWf/6t/9lcxizl3Jp4u/K8VlVfx8OIWOPi7EXJ2PMfSe2lrSfQPao1ar\nsLPW0KWtG77ujsTG5/DZ5nP89+cL/Lg3icJSHZ5u9syN6lHvVZpRIW05GZ/DtkPJdGvnxg97Eimv\nNNC1rRvzonoqv5M/P9CLDXsS2XYouTroBbShd2f3BoN3S9OkwetWzDOWlZXFc889xzvvvHPNdVs5\n27J25yU6+LtxbyNNKAaj4NDpDNbviufClXzcnG15ZVYo3drX78i9mJzP5oNX8Gxlz+A+fqzfFc8/\nvjyGEBAc6MmQkOpA4+HhTJ/DKZy8VN1WPXVkV9r513+SN8ejtYPyou7MyHss3q5GSA9v4hJyOZ9W\nxC+Hk8nMK2NYP3/yCiuIi8+pN3zXyd6a52f0p09XT7ZcnTpmWP92eHjUPkF7eMDz00N45ZN9rNx4\nlg+eGaaMWAM4cCqdmNh0Ovq6MDe69039KYoPnx2uzAZQY0AvP77YfpHk7FLSCyqITcilZ0AbHpvQ\nk5c/3senm89hY2dNhc7A8QtZnE7IRVdlwMnemgdHdWP0gPa4u9k1Wqj/cdw9N5zm6+Hp6aL8Tswt\nGxh87YctH3dHkrNKcHd3YvfxVFZuPIuNlZqXZoXS1teywiqwozv7TmXiaGfN4D7+BHfzpGdAm3oj\nE2t4eDjz4sxQ/v7pQQb38WPCsNp39iI8nLmYVsSmfUms3nqBF2eG1KtNGY0C9dVruu3AZYrLqoge\n3uW6f99t2jjhaGdFaYWeHp09TH6n1yt6VDe2H0lmx7FURoZ14MvtF7Cz0fDCI6H4ejgxYRicTcrl\n+Y/28u2ueN57aihqtYrPf74IwEP3dcfd1Z6F78ew8qczGAVEDQ3A39d08MSM8T1NPo/2dKFvdx/+\n+dVRziblMaSPH/On9DZb03Z3d2Ld7gSOX8zm+MVsHO2tmTu5B2PDO5jcIwBzo/vw+OTe9b7/PWjS\n4OXt7U16eu27BVqtFk9Pz0a2MFVSUsLcuXNZuHAhvXr1uub6rz0Wxgsf7eWfXx3jWaPR7BNIfGoh\nqzafRZtfPbS9W1s3LqYW8MK/9zF7XKDJ07yuysA/vzyK0Sh4ZGwg3du3orOPM//ZdJbCEh0TwtuT\nnV37NDxuQDtOXsrG2kpNeHdPk2WN8XS1IyuvDK9W9gR3bmPxdjXaXu3A/vTH0wCMCPbnqWnB5OSU\nkFVQzqGzWopLdRiMAl2VgUPntLy+8iAPjujMvtg0bKzU+Le2r3dcb1dbJgzsyA97k3jn8yPMmXAP\ndjZW5BdX8v7aE1hbqZk9rjsF+bd+Pjk7DTjaWRF3KZv4q4NRJg/uiIuthqem9Oa9tSf56LtYZX1f\nd0cGBnkzrI9fdS1Krycnp3potYeH83Xn6d2mrYcjh3NK+ft/DnDsQjZ2NhrmTOyBk7Xa4nML7+7B\nkL5j0FfolCCkK9eRXd5wc7uvmx3LFgzC3laj5GeN+8Pbk5RWwOGzmbzz+REeHd8dtUpFUkYRq7ec\nI6ugnF4B7oQGevJ9TCIateq67ou6Ovq4cDopDxdbzU1fy7Gh7fj6l0s8/9FedFUGZt4XiDVC2a+H\nkw1h93hx8KyWDTsvck+HVuw5mYa/hyNtW9ujUqmYProrqzafA2BANw+L07RwSm+0+WXVk3IXV1Ba\nbH64emRYez7+4TThPb2JHhaAi4MNeblN9xcYbuaB4E5p0uAVFBREcnIyaWlpeHh4sHnzZpYuXdrg\n+nVralVVVcyfP5+oqChGjx5t0fE6+7vxZFRPPlgXxwfr4njuj31NmrMOn9Py6aZzGI2Cwb18GDug\nHT5tHDmVmMsnP55m5cazxKcV0qNja/zcHdlxLI2M3DJG9vNXRvV1a9eKxY+FUVymqzdUuLO/K9HD\nAnB1tLmuYettPZ04nZRH9LCA6+oLqOHvWftkOqS3D9NG1T4he7rZc/+9HUzWH9rHj482nOLrqxP+\nBnf1aLDZZ/y9HbiQUsDJ+BzmLY3BzckGlUpFaYWe6aO73tAgDUuoVSo6+7kSe3XUVngPb6V5pbOf\nKwun9mZPXAZd/F3p0aE1rV3smiQdd4sO3i4cPpfFsQvZtPV0Yt6knibNrJZQq1S0drEju/L6Rts1\n1H9ibaVm/uQg3lt7kgNnMnGws8LWWsPWQ1cQAtq42HH0fBZHrza5DwryqdeXa6lpI7uQll16S67z\n0D6+bD2UTH5xJcFdPcw2Y06J6MyJSzl8vzuB+FR3jEIwJrSdcl8NDPKpfgfQKK7Zr1yXWq3Cp821\n75ngrh6s+OswpeYq1acSTfw3BGJiYli8eDFCCKKjo5kzZw4ffPABQUFBREREcOrUKRYsWEBRURG2\ntrZ4eHiwceNGfvrpJ1566SW6dOmijI76xz/+QWBg4+9PZGcXExObzpqt51EBfbq4Mzqk7dVO1kTs\nbDTMi+pJz06mM52n55Tywbo4sn7zsrFXawf+Nuvabfo3o6xCT0pWsdkBA5baE5tOdmEFUYM6olar\nrlnbyCuq4IPv40jWljBnwj2E3ePd4LrFZTq2HLxCanYp2rwycgsrCO7qwbxJPS0eMXYjth68wne7\nErC2UvOPOWE3XHC1hJpXSlYJiz4/SngPLx4a1dWk3/N6NEVelJRX8fZXx0m7Onm0u6sds8Z1J7Cd\nGylZJRw+l8UVbTEzxnQzmXj3Tjp+MZtD57N4eGQXk+bwujbtv6y8f9jK2Za354bf0MNlc9Aca15N\nHrxut5ob89iFLDYfuGIyJ1wrZ1v+MqU3bT3NDy6o0Ok5dyWf9JxS0nJKyS+qZOqIzg3OTXg3s6SQ\nqqwycCWzmC7+rtcVhPQGIxq1qkkDF0BqVgmvrz7MxEEdmXAD75DVaAnBC6rz/WYLz6bKi/ziSlb+\ndIa2Xk5MHtIJO5u7f7TbtfKiSm/glU8PkV1QwZSIgAYHUrUEMnjdBer+GIUQJKQVsf1oCuWVemaP\n637DzRbNTUspsEvKq3C0s7qpQNlS8uJWkHlRy5K8iE8tZPfJNP44qusNjUhtLppj8Gq5V4PqId+d\n/V2V4a5S81N3eiJJut1k+XH3apkNuJIkSVKLJoOXJEmS1OzI4CVJkiQ1OzJ4SZIkSc2ODF6SJElS\nsyODlyRJktTsyOAlSZIkNTsyeEmSJEnNjgxekiRJUrMjg5ckSZLU7MjgJUmSJDU7MnhJkiRJzY4M\nXpIkSVKzI4OXJEmS1OzI4CVJkiQ1OzJ4SZIkSc2ODF6SJElSsyODlyRJktTsyOAlSZIkNTtNHrxi\nYmIYO3YsY8aMYeXKlfWWHz16lMmTJ9OjRw+2b99usmzDhg2MGTOGMWPG8MMPPzR1UiVJkqRmwupa\nK6SkpLBu3ToOHTpEZmYmtra2BAYGMmbMGEaPHo2VVcO7MBqNLFq0iDVr1uDp6Ul0dDQjRowgICBA\nWcfX15e33nqLzz77zGTbwsJC/v3vf7NhwwaEEEyePJkRI0bg7Ox8E6crSZIktQSNBq/XXnuNM2fO\nMHbsWP7617/i7u5OZWUlCQkJ7N27l5UrV/K3v/2NPn36mN0+Li6O9u3b4+fnB0BkZCQ7duyoF7wA\nVCqVybZ79+5l4MCBSrAaOHAge/bsYdy4cTd+tpIkSVKL0GjwGjFiBG+88Ua977t168a4ceMoKCgg\nJSWlwe21Wi0+Pj7KZy8vL06dOmVRwsxtq9VqLdpWkiRJatkaDV5Dhw5tdGM3Nzfc3NwaXC6EuLFU\nNbDtb2tn5nh4yGbFGjIvasm8qCXzopbMi+brmn1eAG+99Rbz58/H3t6eGTNmcPbsWf7+978zceLE\nRrfz9vYmPT1d+azVavH09LQoYd7e3hw6dEj5nJmZSVhY2DW3y84utmj/LZ2Hh7PMi6tkXtSSeVFL\n5kWt5hjELRptuH//fpydndm7dy9eXl78/PPP9QZYmBMUFERycjJpaWnodDo2b97MiBEjGly/bm1r\n0KBB7N+/n+LiYgoLC9m/fz+DBg2yJLmSJElSC2dRzavGkSNHGDVqFF5eXhY14Wk0Gl599VVmz56N\nEILo6GgCAgL44IMPCAoKIiIiglOnTrFgwQKKior49ddf+eijj9i4cSOurq7MmzePBx54AJVKxYIF\nC3BxcbnhE5UkSZJaDpWwoGNq1qxZ+Pn5sW/fPn744QccHR2ZNGkSGzduvB1pvC6yGaCabBKpJfOi\nlsyLWjIvarXYZsP33nuPzp07s2zZMlxdXcnMzGTWrFlNnTZJkiRJMsuiZsPWrVszc+ZM5bO/vz/+\n/v5NlSZJkiRJalSjwSssLKzRvq0DBw7c8gRJkiRJ0rU0Gry+//57ANatW0dBQQFTp05FCMH333+P\nl5fXbUmgJEmSJP1Wo8GrZlqnI0eO8OWXXyrfv/LKKzz88MM8/vjjTZs6SZIkSTLDogEbWVlZ5OXl\nKZ/z8vLIzs5uskRJkiRJUmMsGrDxyCOPEBUVxbBhwwDYvXs3TzzxRFOmS5IkSZIaZFHweuihh+jX\nrx9HjhxBCMFDDz1Et27dmjptkiRJkmSWxTNsBAYGEhgY2JRpkSRJkiSLWBS8jh8/zrvvvktKSgoG\ngwEhBCqVSg6VlyRJku4Ii4LXyy+/zLx58+jTpw9qtUVjPCRJkiSpyVgUvOzs7Lj//vubOi2SJEmS\nZBGLqlFDhgxh9+7dTZ0WSZIkSbKIRTWvb775hhUrVuDo6IiNjY3s85IkSZLuKIuCV800UZIkSZJ0\nN7AoePn5+aHX60lKSkKlUtGhQwesrK7r71hKkiRJ0i1jUQQ6deoUf/7zn5UmQ71ez4cffkiPHj2a\nOn2SJEmSVI9FwWvx4sUsWbKE8PBwAA4ePMiiRYtYu3ZtkyZOkiRJksyxaLRheXm5Erig+u98lZeX\nN1miJEmSJKkxFgUve3t7Dh48qHw+fPgw9vb2TZYoSZIkSWqMRc2GL730Ek899RQ2NjYAVFVV8cEH\nH1h0gJiYGJYsWYIQggceeIA5c+aYLNfpdDz//POcOXOGVq1asWzZMnx9fdHr9bzyyiucOXMGo9HI\nxIkT620rSZIk/T5ZFLx69erF9u3bSUpKQghBp06dsLa2vuZ2RqORRYsWsWbNGjw9PYmOjmbEiBEE\nBAQo66xbtw5XV1e2b9/Oli1bePfdd1m2bBnbtm2jqqqKjRs3UlFRwbhx4xg/fjy+vr43fraSJElS\ni2BRs+H+/fupqKiga9eudOvWjfLycoteUI6Li6N9+/b4+flhbW1NZGQkO3bsMFlnx44dTJo0CYAx\nY8YozZMqlYqysjIMBgPl5eXY2Njg5OR0vecnSZIktUAWBa933nnHJHA4OTnxzjvvXHM7rVaLj4+P\n8tnLy4usrCyTdbKysvD29gZAo9Hg7OxMQUEBY8aMwd7enkGDBjF8+HAeffRRXFxcLDopSZIkqWWz\nqNmwZjqoGmq1GoPBYNF217tOzbHi4uLQaDTs27ePgoIC/vjHPxIeHo6/v78lSZYkSZJaMIuCl6Oj\nI7GxsfTu3RuA2NhYHBwcrrmdt7c36enpymetVounp2e9dTIzM/Hy8sJgMFBSUoKrqyubNm1i8ODB\nqNVqWrduTXBwMKdPn75m8PLwcLbklH4XZF7UknlRS+ZFLZkXzZdFwevZZ59l/vz5dO7cGYD4+Hg+\n+uija24XFBREcnIyaWlpeHh4sHnzZpYuXWqyTkREBBs2bKB3795s27aNsLAwAHx8fDh48CATJkyg\nrKyM2NhYZs6cec1jZmcXW3JKLZ6Hh7PMi6tkXtSSeVFL5kWt5hjEVcKStj2gsLCQkydPIoSgb9++\nuLq6WnSAmJgYFi9ejBCC6Oho5syZwwcffEBQUBARERHodDqeffZZzp07h5ubG0uXLsXf35+ysjJe\nfNEYr50AABg9SURBVPFFEhISAHjggQeYNWvWNY8nf4zV5I1ZS+ZFLZkXtWRe1GrRwSspKYmEhARG\njhxJaWkpVVVVuLm5NXX6rpv8MVaTN2YtmRe1ZF7UknlRqzkGL4tGG27YsIEnn3ySf/zjH0B139Vf\n/vKXJk2YJEmSJDXEouD1+eef8/333+PsXB2dO3XqRE5OTpMmTJIkSZIaYlHwsra2xtHR0eQ7jUbT\nJAmSJEmSpGuxKHi5ubkpf4gS4Mcff1ReLJYkSZKk283iiXmfeeYZkpKSGD58OHZ2dnzyySdNnTZJ\nkiRJMsui4NWxY0e+++47Ll++jBCCjh07ymZDSZIk6Y6xqNkwKSkJvV5PQEAAGRkZrFq1isLCwqZO\nmyRJkiSZZVHw+stf/oJarSYlJYXXX3+dlJQUnn/++aZOmyRJkiSZZVHwUqvVWFtbs3v3bqZNm8ai\nRYvIyMho6rRJkiRJklkWBa/Kykq0Wi07d+5U5h60cGIOSZIkSbrlLApejzzyCJGRkTg6OhIUFERK\nSorywrIkSZIk3W4Wz21Yl8FgwGAwYGNj0xRpuilyrrJqct62WjIvasm8qCXzolaLm9vw9OnTZr/X\naDTY2Nig0+mUWd8lSZIk6XZp9D2vFStWUF5ezvjx4+nduzfu7u5UVlaSlJTEnj172L17Ny+88AIB\nAQG3K72SJEmS1Hjw+vDDD4mLi+Obb77h3//+N5mZmdjb29O1a1dGjhzJV199hZOT0+1KqyRJkiQB\nFsyw0atXL3r16nU70iJJkiRJFrFotKEkSZIk3U1k8JIkSZKaHRm8JEmSpGZHBi9JkiSp2bEoeOXm\n5vLXv/6Vhx56CIDz58/zf//3f02aMEmSJElqiEXB65VXXqFfv34UFRUB0KlTJ77++muLDhATE8PY\nsWMZM2YMK1eurLdcp9Px9NNPM3r0aKZOnUp6erqy7Pz58zz44IOMHz+eCRMmoNPpLDqmJEmS1LJZ\nFLy0Wi3Tpk1T/gCljY0NavW1NzUajSxatIhVq1axadMmNm/eXG9GjnXr1uHq6sr27dt55JFHePfd\nd4HqKaiee+453njjDTZt2sQXX3yBtbX19Z6fJEmS1AJZFLysrExfBysqKrJoVvm4uDjat2+Pn58f\n1tbWREZGsmPHDpN1duzYwaRJkwAYM2YMBw8eBGDv3r0EBgbStWtXAFxdXVGpVJYkV5IkSWrhLApe\no0eP5rXXXqO0tJT169cze/ZsHnjggWtup9Vq8fHxUT57eXmRlZVlsk5WVhbe3t5A9ZyJzs7OFBT8\nf3v3HhxVef9x/L1sAlJMgpiQRaS0JraQGqAzKsERIYBZIITsBiIMUsKlpdoBKqFYwck4crXGyUhk\nOhIBKzRMa4HIJRBSgxI6XGy1hZkCRUEn3JJwS5NgypLN8/sjP3YbgrBWNvEkn9df7Nlnz373yzN8\nOGfPPqeKL774AoAZM2aQlpbG6tWrA/1MIiLSxt12hQ2An/70p2zdupXq6mr27NnDT37yE1JTU2/7\nukCOzm4cY4zBZrPh9Xr55JNP2LRpE506dWLq1Kk89NBDvvuJiYhI+xVQeAGMHTuWsWPHfq2dOxyO\nJhdgVFRU0L1792ZjysvLiY6Oxuv1UltbS0REBA6Hg0ceeYSIiAgAnnjiCY4cOXLb8LLi0v7Bol74\nqRd+6oWfemFdAYXXxYsX+f3vf09ZWRn19fW+7StWrLjl6+Lj4ykrK+PMmTNERUVRWFhITk5OkzGJ\niYkUFBTQv39/ioqKfOH0+OOPs3r1aq5evYrdbuevf/0rU6dOvW2tuj9PI92ryE+98FMv/NQLPyuG\neEDh9Ytf/IK4uDgGDRrku+IwEHa7naysLKZPn44xhvHjxxMTE0Nubi7x8fEkJiaSnp7O/PnzSUpK\nomvXrr5wCw8PZ9q0aYwbNw6bzcbQoUMZMmTI//YpRUSkTQnoTspjx45l69atLVHPN6b/STXS/yr9\n1As/9cJPvfCz4pFXQFcb9u/fn3/961/BrkVERCQgAZ02nDhxIpMnT8bhcNCpUyff9o0bNwatMBER\nka8SUHjNnz+fZ555hri4uK/1nZeIiEgwBBRenTp1YsaMGcGuRUREJCABfec1ePBgSktLg12LiIhI\nQAI68nr33XfJy8ujS5cudOzY0bcKxv79+4Ndn4iISDMBhdemTZuCXYeIiEjAAgqvnj17BrsOERGR\ngN0yvObPn092drZvlYsb6VJ5ERFpDbcMr4yMDAB+/etft0gxIiIigbhleG3YsIFly5bx6KOPtlQ9\nIiIit3XLS+WPHj3aUnWIiIgELKDfeYmIiHyb3PK04fHjxxk0aFCz7fqdl4iItKZbhtf3vvc98vLy\nWqoWERGRgNwyvDp27KjfeImIyLfOLb/zCg0Nbak6REREAnbL8Hr33Xdbqg4REZGA6WpDERGxHIWX\niIhYjsJLREQsJ+jhVVpaysiRI3E6nTe97N7j8TB37lySkpKYMGECZ8+ebfL82bNn+fGPf8zbb78d\n7FJFRMQighpeDQ0NLF68mDVr1rB9+3YKCws5ceJEkzEbN24kIiKC4uJiMjIyyM7ObvL8K6+8wpAh\nQ4JZpoiIWExQw+vw4cP07t2bnj17EhoaSnJyMiUlJU3GlJSU4Ha7AXA6nU1W7Xj//ffp1asXsbGx\nwSxTREQsJqjhVVFRQY8ePXyPo6OjqaysbDKmsrISh8MBgN1uJzw8nKqqKurq6li9ejWzZs0KZoki\nImJBAd1J+X9ljPnaY66vm5ibm8vUqVPp3LlzwPsCiIoK+/qFtlHqhZ964ade+KkX1hXU8HI4HE0u\nwKioqKB79+7NxpSXlxMdHY3X66W2tpaIiAgOHz5McXEx2dnZVFdX06FDBzp16sTTTz99y/c8f74m\nKJ/FaqKiwtSL/6de+KkXfuqFnxVDPKjhFR8fT1lZGWfOnCEqKorCwkJycnKajElMTKSgoID+/ftT\nVFREQkICAPn5+b4xK1eupEuXLrcNLhERaR+CGl52u52srCymT5+OMYbx48cTExNDbm4u8fHxJCYm\nkp6ezvz580lKSqJr167Nwk1ERORGNhPol0kWodMAjXRKxE+98FMv/NQLPyueNtQKGyIiYjkKLxER\nsRyFl4iIWI7CS0RELEfhJSIilqPwEhERy1F4iYiI5Si8RETEchReIiJiOQovERGxHIWXiIhYjsJL\nREQsR+ElIiKWo/ASERHLUXiJiIjlKLxERMRyFF4iImI5Ci8REbEchZeIiFiOwktERCxH4SUiIpYT\n9PAqLS1l5MiROJ1O8vLymj3v8XiYO3cuSUlJTJgwgbNnzwKwb98+0tLSGDt2LOPGjePAgQPBLlVE\nRCwiqOHV0NDA4sWLWbNmDdu3b6ewsJATJ040GbNx40YiIiIoLi4mIyOD7OxsALp168aqVavYunUr\nr7zyCs8//3wwSxUREQsJangdPnyY3r1707NnT0JDQ0lOTqakpKTJmJKSEtxuNwBOp5P9+/cD0KdP\nH6KiogB48MEH8Xg8XLt2LZjlioiIRQQ1vCoqKujRo4fvcXR0NJWVlU3GVFZW4nA4ALDb7YSHh1NV\nVdVkTFFREXFxcYSGhgazXBERsYiQYO7cGPO1xxhjsNlsvseffvopOTk5rF27NqD3jIoK+3pFtmHq\nhZ964ade+KkX1hXU8HI4HL4LMKDxSKx79+7NxpSXlxMdHY3X66W2tpaIiAgAysvLmTVrFq+++ir3\n339/QO95/nzNnfsAFhYVFaZe/D/1wk+98FMv/KwY4kE9bRgfH09ZWRlnzpzB4/FQWFjI8OHDm4xJ\nTEykoKAAaDw9mJCQAEB1dTU///nP+dWvfsWAAQOCWaaIiFhMUMPLbreTlZXF9OnTGTNmDMnJycTE\nxJCbm8sHH3wAQHp6OpcvXyYpKYl33nmHefPmAZCfn09ZWRm//e1vcblcuN1uLl26FMxyRUTEImwm\nkC+mLESnARrplIifeuGnXvipF346bSgiItICFF4iImI5Ci8REbEchZeIiFiOwktERCxH4SUiIpaj\n8BIREctReImIiOUovERExHIUXiIiYjkKLxERsRyFl4iIWI7CS0RELEfhJSIilqPwEhERy1F4iYiI\n5Si8RETEchReIiJiOQovERGxHIWXiIhYTtDDq7S0lJEjR+J0OsnLy2v2vMfjYe7cuSQlJTFhwgTO\nnj3re27VqlUkJSUxatQo/vKXvwS7VBERsYighldDQwOLFy9mzZo1bN++ncLCQk6cONFkzMaNG4mI\niKC4uJiMjAyys7MB+Oyzz9i5cyc7duzgrbfe4uWXX8YYE8xyRUTEIoIaXocPH6Z379707NmT0NBQ\nkpOTKSkpaTKmpKQEt9sNgNPp5MCBAwDs3r2b0aNHExISwv3330/v3r05fPhwMMsVERGLCGp4VVRU\n0KNHD9/j6OhoKisrm4yprKzE4XAAYLfbCQsLo6qq6qavraioCGa5IiJiEUENr0BO891sjM1m+8rt\nIiIiIcHcucPhaHIBRkVFBd27d282pry8nOjoaLxeLzU1NUREROBwODh37pxvXHl5ebPX3kxUVNid\n+wAWp174qRd+6oWfemFdQT3yio+Pp6ysjDNnzuDxeCgsLGT48OFNxiQmJlJQUABAUVERCQkJAAwb\nNowdO3bg8Xg4deoUZWVl9OvXL5jlioiIRQT1yMtut5OVlcX06dMxxjB+/HhiYmLIzc0lPj6exMRE\n0tPTmT9/PklJSXTt2pWcnBwAYmNjGTVqFMnJyYSEhPDSSy/ptKGIiABgM7r+XERELEYrbIiIiOUo\nvERExHIUXiIiYjltJrxut4ZiW1ZeXs6UKVMYPXo0KSkprFu3DoB///vfTJ8+HafTyYwZM6ipqWnl\nSltOQ0MDbrebZ555BoDTp0/z1FNP4XQ6yczMpL6+vpUrbBk1NTXMmTPHd/HToUOH2u28+N3vfseY\nMWNISUlh3rx5eDyedjMvFi5cyGOPPUZKSopv263mwZIlS0hKSiI1NZWjR4+2Rsm31SbCK5A1FNsy\nu93OggUL2LFjB3/4wx/Iz8/nxIkT5OXlMWjQIHbt2sXAgQNZtWpVa5faYtatW0dMTIzv8Wuvvca0\nadPYtWsXYWFhbNy4sRWrazlLly5lyJAh7Ny5ky1btvDAAw+0y3lRUVHB+vXr2bx5M9u2bcPr9VJY\nWNhu5kVaWhpr1qxpsu2r5sGePXsoKyujuLiYRYsW8dJLL7VGybfVJsIrkDUU27KoqCj69u0LQJcu\nXYiJiaGioqLJupFut5v333+/NctsMeXl5ezZs4f09HTftgMHDuB0OoHGXvz5z39urfJaTG1tLX/7\n298YN24cACEhIYSFhbXbedHQ0EBdXR319fX85z//oXv37hw8eLBdzIuHH36Y8PDwJttunAfX/80s\nKSnB5XIB0L9/f2pqarhw4ULLFhyANhFegayh2F6cPn2aY8eO0b9/fy5evEhkZCTQGHCXL19u5epa\nxrJly3j++ed9vwu8fPkyERERdOjQON0dDke7mB+nT5/mnnvuYcGCBbjdbrKysqirq2uX8yI6Oppp\n06YxdOhQnnjiCcLCwoiLiyM8PLzdzYvrLl261GQeXLp0CWi63ix8e9eVbRPhpZ+qNbpy5Qpz5sxh\n4cKFdOnSpV3+qPvDDz8kMjKSvn37+uaFMabZHGkPvamvr+fIkSNMmjSJgoICOnfuTF5eXrv47Deq\nrq6mpKSEDz74gL1791JXV0dpaWmzce2xNzeyyrqyQV1ho6UEsoZiW1dfX8+cOXNITU1lxIgRANx7\n771cuHCByMhIzp8/T7du3Vq5yuD75JNP2L17N3v27OHq1atcuXKFZcuWUVNTQ0NDAx06dAh4nUyr\nczgcOBwO4uPjAUhKSuKtt95ql/Ni37599OrVi65duwIwYsQI/v73v1NdXd3u5sV1XzUPoqOjKS8v\n9437tvalTRx5BbKGYlu3cOFCYmNjycjI8G0bNmwYmzdvBqCgoKBd9CQzM5MPP/yQkpIScnJyGDhw\nIK+99hoDBw6kqKgIaD+9iIyMpEePHnz++edA4/d+sbGx7XJe3HfffRw6dIirV69ijOHAgQM8+OCD\n7Wpe3HhE9VXzYPjw4bz33nsA/OMf/yA8PNx3evHbpM0sD1VaWsrSpUt9ayjOnDmztUtqMR9//DGT\nJ0/mBz/4ATabDZvNxty5c+nXrx/PPfcc586d47777mPFihXNvrRtyz766CPWrl3Lm2++yalTp8jM\nzKS6upq+ffuSnZ1NaGhoa5cYdMeOHePFF1+kvr6eXr16sXz5crxeb7ucFytXrqSwsJCQkBDi4uJY\nsmQJ5eXl7WJezJs3j4MHD1JVVUVkZCSzZ89mxIgR/PKXv7zpPFi0aBF79+6lc+fOLF++nB/96Eet\n/AmaazPhJSIi7UebOG0oIiLti8JLREQsR+ElIiKWo/ASERHLUXiJiIjlKLxERMRyFF5iScOGDeOz\nzz5rkfdauXJlk1tlLFiwgPz8/G+83wULFpCSkkJmZuY33tetHDt2jJ07dwb1PURamsJL5DZWrlzJ\ntWvX7ug+L1y4QHFxMdu2bSMnJ+eO7vtGR44c+Z/Dq6Gh4Q5XI3JnKLykTfn888/52c9+Rnp6Oi6X\ny7f8DUCfPn1YtWoV48eP58knn6S4uNj33K5duxg1ahRpaWmsWrWKPn36UFdXx6JFi7DZbEycOBG3\n201tbS0Ax48fJyMjA6fTyQsvvPCV9bz33nukpKSQmprK7NmzuXTpEleuXCEjI4OrV6/idrt55513\nmrxmy5YtzJo1y/fY6/UyePBg3/qdq1ev5qmnniItLY1nn32WixcvAnDt2jV+85vfkJKSgsvlYvbs\n2VRVVfHGG29w4MAB3G43S5cuBRpXpHG73aSmpjJt2jROnToFNK5K4nK5WLJkCRMnTmTv3r3f5K9D\nJHiMiAUlJiaaTz/9tMm2+vp643a7zcmTJ40xxtTW1hqn0+l7/MMf/tDk5+cbY4z5+OOPzeDBg40x\nxly4cME8+uijpqyszBhjzNtvv2369OljvvzyS9/r6urqfO/zwgsvmEmTJhmPx2M8Ho9JTk42+/bt\na1bj8ePHzeOPP24uXLhgjDHm9ddfN88995wxxpjTp0+bhISEm362uro6k5CQYC5fvmyMMWb37t0m\nIyPDGGPMli1bTFZWlm/shg0bzLx584wxxrzxxhtm9uzZpr6+3hhjfK/fvHmzmTNnju81Fy9eNAkJ\nCebEiRPGGGP+9Kc/mfT0dGOMMQcPHjRxcXHm0KFDN61N5NtCR17SZnzxxRecPHmSzMxMXC4XTz/9\nNNeuXWtyV+3Ro0cDMGDAAM6fP4/H4+HQoUM89NBD9OrVC4Dx48c327e5YRW1ESNGEBoaSmhoKHFx\ncZSVlTV7zcGDBxk6dCj33nsvABMnTmTfvn23/Rx33XUXw4cPZ/v27UDjoqnXbyi5e/du9u/fj8vl\nwuVysWHDBs6dOwc03g5mypQp2O12AN8K6jc6dOgQffv25YEHHgBg3LhxHD16lC+//BKA3r17069f\nv9vWKdKa2sQtUUSgMWC6detGQUHBTZ+32Wx06tQJwHcDQq/X2yyYbnx8Mx07dvT92W63N7mg47/3\nc+N9kK6/7+24XC6WL1/OmDFj+Oijj8jOzvbt89lnnyUtLe2m7xeIm9X134+/853vBLQfkdakIy9p\nM77//e9z1113sWXLFt+2kydPcuXKFaD5P+7XHw8YMIB//vOfvu99/vt7MoC7776bmpqar13PoEGD\n2LNnj+87qT/+8Y889thjzd7/Zh5++GFqa2vJycnhySef9IXusGHD2LBhA9XV1QB4PB6OHTsGQGJi\nIuvWrfNdXHL9Dsl3332377u665/36NGjvlulbN68mbi4OIWWWIqOvMSSbDYbU6dOJSQkxHcksW3b\nNt58802WLl3K2rVr8Xq9REZG8vrrr/tec+M+oPGmfC+//DIzZ87knnvuYejQoYSEhNC5c2cApk2b\nxpQpU+jcuTPr168PuMbY2FgyMzOZOnUqHTp0oFevXixatKjZ+38Vl8tFbm4uGzZs8G1LTU2lqqqK\nyZMnY7PZaGhoYNKkSfTp04eZM2eSk5ODy+WiY8eOfPe732XFihUMGjSINWvW4HK5eOSRR3jxxRd5\n9dVXmTdvHl6vl27duvmO7ESsQrdEEQGuXLlCly5dgMYjkU2bNt2R33KJSHDoyEsEWL9+PUVFRXi9\nXrp27crixYtbuyQRuQUdeYmIiOXogg0REbEchZeIiFiOwktERCxH4SUiIpaj8BIREctReImIiOX8\nH4gzFtcS9o9MAAAAAElFTkSuQmCC\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7f47b20dd690\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(graph_means)\n", + "plt.ylabel('Time (seconds)')\n", + "plt.xlabel('Length of vector')\n", + "_ = plt.title('Time to sum the elements of 1000 vectors (vectorized TF operation)')\n", + "_ = plt.ylim(ymin=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "4KZg2WXjbhg5" + }, + "source": [ + "## AutoGraph" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "UQJBQWbCbinm" + }, + "outputs": [], + "source": [ + "# Sum written using for loop and converted with AutoGraph\n", + "def sum_all(elements):\n", + " sum_ = 0.0\n", + " length = len(elements)\n", + " for i in tf.range(length): \n", + " sum_ += elements[i][0]\n", + " return sum_\n", + "\n", + "def run_trial(num):\n", + " elements = get_elements(num)\n", + " return sum_all(elements)\n", + " \n", + "ag_means = []\n", + "ag_run_trial = ag.to_graph(run_trial)\n", + "\n", + "for num in range(max_elements):\n", + " with tf.Graph().as_default():\n", + " durations = []\n", + " foo = ag_run_trial(num)\n", + " with tf.Session() as sess:\n", + " for _ in range(burn_ins):\n", + " for _ in range(batches):\n", + " sess.run(foo)\n", + " \n", + " for _ in range(trials):\n", + " start = time.time()\n", + " for _ in range(batches):\n", + " sess.run(foo)\n", + " \n", + " duration = time.time() - start\n", + " durations.append(duration)\n", + " ag_means.append(np.mean(durations))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 301 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 310, + "status": "ok", + "timestamp": 1532448438694, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "DLDOmrRW99v5", + "outputId": "ae0e0573-39db-4004-a064-efc618dbf867" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAEcCAYAAADUX4MJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3XdYVFf++PH3DE1AinTEjgULioggNiTYjcZCoiZqjEmM\n6RuzcVc32exPE7O72dTNmmhi1u+abLJqLFFsib1E7GLDjkgbQUSKyDAz5/eHySARcYwMQ/m8nsfn\nce4999zPPXOHzz23nKtRSimEEEKIu9DaOgAhhBC1gyQMIYQQFpGEIYQQwiKSMIQQQlhEEoYQQgiL\nSMIQQghhkTqZMObPn88bb7xh6zBqlQceeICffvrJ6uuZOXMmH330kdXXU1NcuHCBUaNG0a1bN776\n6itbh1OnnT17ljFjxtg6jGpV2e9Jr9czZMgQcnNzq2x9tTJhdO3alfDwcMLDw2nfvj1dunQxT1uz\nZg3PPPMMc+bMsXoc6enphISEYDKZrL6uqlTf/mjfjTW/xy+++IKoqCgOHDjAhAkTbpu/bt06xo0b\nR1hYGJMmTbpt/smTJxk9ejRhYWGMGTOG5OTkcvPfffddoqKi6NGjB+++++49LWttEydOZNmyZdW2\nvo8//pinnnqqwjgiIyMpLS29p/pCQkK4dOnSPS3z1VdfMWLECMLCwujduzeTJk1i7dq191RHVXF0\ndCQ+Pp7PP/+8yuqslQnj0KFDHDx4kIMHD9K4cWPmz59vnvbggw9WWxxKKTQaDfLsY+1mze8xIyOD\n1q1b33G+p6cnkydPZurUqbfNKy0t5fnnn2fkyJHs27ePkSNH8txzz2EwGAD49ttv2bx5M6tXr+b7\n779n69at/O9//7No2drgXr6P7OxsEhMTiYuLKzc9PT2dAwcOoNFo2Lx58z2tX6PR3FP5OXPmsHjx\nYmbOnMnevXvZsWMHv/vd79ixY8cdl7H2344HH3yQFStW3HOyvJNamTBupZS6rdE/+eQTXnvtNaDs\n6HH58uX069ePqKgovv32W44ePcqIESOIjIy8rTeybNkyhg4dSlRUFE899RQZGRkVrnvixIkARERE\nEB4ezpEjR1BKMW/ePB544AF69erFH//4RwoLCytc/urVq0ybNo3u3bsTFRVV7gj010c3t/YK9u7d\nS0xMDF988QU9e/akT58+/Pjjj2zbto1BgwYRFRXF/PnzK1znkiVLWL16NV988QXh4eE8++yz5nkn\nT55kxIgRdO/enenTp6PX683ztmzZwsiRI+nevTvjx4/n1KlTFdYPcO7cOaZMmUJUVBRDhgxh3bp1\ndyxbWb0PPPAACxcuZMSIEXTt2pXXX3+dK1eu8PTTTxMeHs6UKVMoKCgwlz98+DDjxo2je/fujBw5\nkr1795rnTZw4kY8++ojx48cTHh7Ok08+SV5ennkelP8eU1NTmThxIhEREURHRzN9+vQ7bsOmTZt4\n8MEHiYyMZNKkSZw/fx6Axx9/nMTERGbPnk14eDgXL168bdno6GgGDx6Mr6/vbfP27t2L0Whk0qRJ\nODg4MHHiRJRS7NmzB4CVK1cyZcoU/Pz88PPz44knnmDFihUAJCYmVrrsrdauXXvbqZxFixbx3HPP\nATdPbfztb38jNjaW3r1785e//KXcvvHjjz8ycuRIunXrxsCBA9m5cycffPABBw4cYM6cOYSHh/PW\nW28BcPDgQeLj4+nevTsPP/wwhw4dKvcdffDBB4wfP56wsDDS0tJYvnw5/fv3Jzw8nP79+7NmzZoK\nv4Ndu3bRsWNHHB0dy01fuXIlYWFhjB492tw2t67v1h7QihUrePTRRwGYMGECSilGjBhBeHi4eR9e\nsmQJAwcOJCoqiueee47Lly8DN089fvPNN3zwwQdER0fj6OiIRqMhPDycd955567bOHToUMLDwxkw\nYIA56f+yD8TExDB//nx69OhBXFwcq1evLrcd165d45lnniE8PJyxY8eW+7vh7++Ph4cHR44cqbDd\n7pmq5WJjY9Xu3bvLTfvnP/+pXnvtNaWUUmlpaapdu3bqzTffVCUlJWrXrl0qNDRUPf/88yo3N1dl\nZWWp6OhotW/fPqWUUj/88IMaOHCgOn/+vDIajerTTz9VY8eOrXDdaWlpKiQkRJlMJvO0pUuXqoED\nB6q0tDR1/fp19cILL5hj+bX33ntPvfnmm8poNCqDwaD2799vnhcSEqJSU1PNn//4xz+qDz/8UCml\nVGJiourQoYOaN2+eMhgMasmSJapHjx7q1VdfVdevX1dnzpxRoaGh6tKlSxWu99a6bm3Hhx9+WGVn\nZ6tr166pIUOGqG+//VYppdSxY8dUdHS0SkpKUiaTSa1YsULFxsYqvV5/W93Xr19XMTExasWKFcpk\nMqkTJ06oqKgodfbs2dvWfbd6Y2Nj1dixY9WVK1eUTqdT0dHRatSoUerkyZNKr9erSZMmqU8++UQp\npVRWVpaKjIxU27dvV0optXv3bhUZGalyc3OVUkpNmDBBDRgwQF28eFGVlJSoCRMmqPfee++O3+P0\n6dPVZ599ppRSqqSkRB04cKDCtjx//rwKCwtTu3fvVgaDQX3++edqwIABqrS01LzepUuXVrjsrZYs\nWaImTpxYbtq///1v9fTTT5eb9swzz6h///vfSimlunXrpo4cOWKed/ToURUeHm7RsrcqLi5W4eHh\n6uLFi+ZpY8aMUWvXrlVKKfXWW2+pZ599VuXn56uioiI1bdo09f777yullDpy5Ijq1q2b+Teo0+nU\n+fPnK9z2vLw81b17d/X9998ro9Go1qxZo7p3767y8vLM5WNjY9XZs2eV0WhUBQUFKjw8XKWkpCil\nlMrOzjbvR7/2t7/9Tc2ePfu26QMGDFDffPONOnbsmOrYsaO6cuWKed6v41u+fLl69NFHzZ/btWtX\n7je4e/duFRUVZd7/5syZox577DGllFLffPONeuCBByqM7Va/3sbS0lK1detW82913759qkuXLurE\niRNKqbLf+l//+lel1+vV3r17VVhYmLpw4YJS6ubvKTIyUh09elQZjUb16quvqunTp5db57Rp09Ti\nxYvvGpslan0PwxIajYbnn38eR0dHevbsibOzM8OGDaNRo0b4+/sTERHBiRMnAPjf//7H1KlTadmy\nJVqtlqlTp5KcnExmZuYd61e39HDWrFnD5MmTCQoKwtnZmenTp7N27doKz4/b29uTnZ1NWloadnZ2\ndOvWrcI6K+Lg4MC0adOws7Nj6NChXL16lccffxxnZ2dat25N69atK+0FVGTSpEn4+Pjg7u5ObGws\nJ0+eBGDp0qWMGzeO0NBQNBoNI0eOxNHRscKjli1bttCkSRNGjhyJRqOhffv2DBw4kPXr199W1pJ6\nJ0yYgJeXF35+fkRERNClSxdCQkJwcHBgwIAB5hi///57+vXrR58+fYCbR+6dOnVi27Zt5rpGjx5N\ns2bNcHR0ZMiQIeZlf3Frm9vb25Oeno5Op8PR0ZHw8PAK22zdunX069eP6Oho7OzsePLJJ7lx40a5\nI+ff6vr167i5uZWb1rBhQ3OP9dfz3dzcuH79ukXL3qpBgwbExcWZj95TUlK4cOGC+fTOsmXLmDlz\nJm5ubri4uDB16lRz2WXLlhEfH090dDQAfn5+tGzZssLt2bp1Ky1atGD48OFotVqGDRtGq1at2LJl\ni7nMqFGjCA4ORqvVYmdnh52dHadPn6akpAQfHx+Cg4MrrLugoABXV9dy0/bv309GRgZDhgyhY8eO\nNGvW7Laj83uxZs0a4uPjzfvf9OnTOXz4MBkZGVy9evW2XmJMTAzdu3enc+fO5f5+3LqN9vb2xMTE\n0KRJE+BmL7dXr17s37/fXF6j0fC73/0OBwcHunfvTkxMTLle+8CBA+nUqRNarZbhw4fftl+7urqS\nn5//m7f7VvZVUkst4O3tbf5/gwYN8PHxMX92cnIy/9AyMjJ4++23+dvf/gaUnd/W6XQEBgbedT2X\nL1+mcePG5s9BQUEYDAZycnLw8/MrV/app57in//8J1OmTEGj0fDwww9XeC67Ip6enuZzrA0aNKhw\nG3/ZJkvduryzszPZ2dnAzTZZtWqV+S4fpRQGg8HcHb9VRkYGhw8fJjIy0lzWaDQycuTICsverd5b\nY3Jycrrt863f27p168x/fH6p65c/ZEC579zZ2bnS9pkxYwYffvgh8fHx5usMFd2B8+vvW6PREBgY\niE6nu2PdlnJxcbntD3xhYSENGzascH5hYSEuLi4WLftrw4YN4+9//zvPPfcca9asoX///jg6OpKb\nm0txcXG5bTeZTObkmpWVRUxMjEXb8+u2AmjcuHG5tgoICDD/39nZmQ8++ICFCxcya9YsunXrxowZ\nM2jVqtVtdbu7u1NUVFRu2qpVq+jduzceHh7mbVy5ciWPP/64RfFWFH/Hjh3Nn11cXPD09ESn0+Hp\n6Xnb72Hbtm0YjUY6depU7mDk1m38pdy8efNISUnBZDJx48YN2rVrV27bnJyczJ8bN25cbl1326+L\niopwd3f/Tdv8a/UmYVgqICCAZ5991qKL5xVdFPPz8yt3zSM9PR17e/tyX+ovXFxc+MMf/sAf/vAH\nzp07x8SJE+ncuTM9evTA2dmZ4uJic9ns7OzbdrTqEhAQwLRp03jmmWfuWjYwMJCoqCgWLlxYpfVa\nst6RI0cye/bse162ou/R29vbfG3rwIEDPPHEE0RGRtK0adNy5fz8/Dhz5ky5aZmZmVXyXbVp04ZF\nixaVm3b69GnzNZfWrVuTnJxMaGgocPMaVJs2bSpdtqI7tQB69+7NzJkzSU5OJiEhgVmzZgHQqFEj\nnJ2dWbNmzW0HPHDzO7zTnUS/blc/Pz82btxYblpGRgZ9+/a94zK9evWiV69e6PV6PvjgA9544w2+\n/vrr29bVrl07Vq1aZf5cUlLCunXrMJlM9O7dG7h5I0B+fj6nTp2iXbt2uLi4cOPGDfMyvxwg3cmv\nf9vXr18nLy8Pf39/PD09eeuttzh+/Hi5pAK3ny24dRv1ej0vv/wy7777LnFxcWi1Wp5//vlyy+Tn\n53Pjxg3zgWFmZiZt27atNNZbnT9/nieffNLi8pWpF6ek7nZ651bjx49n/vz5nD17FrjZ1a3odAqA\nl5cXWq2W1NRU87Rhw4axaNEi0tLSKCoq4oMPPmDYsGFotbc39datW83Luri4mLvgcPOi95o1azCZ\nTGzfvp19+/ZZvA134+Pjc0+3Cz7yyCN8++23JCUlATd/KNu2bavwCL1fv35cuHCBVatWYTAYKC0t\n5ejRo+YLwb+13rsZMWIEmzdvZufOnZhMJkpKSti7d69FR/oVfY/r1683L+vu7o5Wq63wOxwyZAhb\nt25lz549GAwGFi5ciJOTE2FhYRbFbTKZ0Ov1GAyGcv8HiIyMRKvVsnjxYvR6vbknFhUVBcDIkSNZ\ntGgROp0OnU7HokWLGD16dKXL9ujRo8I47OzsGDRoEH//+9/Jz8+nV69eAOae79y5c8338+t0Onbu\n3AlAfHw8y5cvZ8+ePSil0Ol05u/61/tZTEwMFy9eJCEhAaPRyNq1azl//jyxsbEVxnTlyhU2b95M\ncXEx9vb25t9IRXr16sXx48fNF+N/+OEH7OzsWLduHatWrWLVqlWsXbuWbt26sXLlSuDmb2zjxo3c\nuHGDixcv8t1335Wr89fxP/jggyxfvpzk5GT0ej3vv/8+Xbp0oXHjxrRs2ZKxY8cyffp0du/eTUlJ\nCSaTiYMHD1Z6t1VpaSmlpaU0atQIrVbLtm3b2LVrV7kySik+/vhjSktL2b9/P1u3bmXIkCF3rPNW\nOp2Oa9eu0aVLF4vK302tTxiW3Pr26zKVfe7fvz9PP/00r7zyChEREYwYMeKOt8U1aNCAadOmMX78\neCIjI0lKSiI+Pp6HHnqICRMmMGDAAJydnXn99dcrXD4lJYXJkyfTtWtXxo8fz2OPPUb37t0B+NOf\n/sTmzZvp3r07CQkJ9O/f/7628Vbx8fGcPXuWyMhIXnjhhbuW79SpE3PmzGH27NlERkYyaNCg2+44\n+YWrqytffvkla9eupU+fPvTp04f33nuv3F01ltZ7L9sUEBDAvHnzmD9/PtHR0cTGxvLll1+aDxYq\nW7ai7/Ho0aM8/PDDhIeH8/zzz/OnP/2JoKCg25Zt2bIl7777LnPmzCE6OpqtW7fy2WefYW9vf9f1\nws3TJp07d2b27NkcOHCALl26mB86dXBwYN68eaxYsYLIyEiWL1/OvHnzzHWPGzeO2NhYRowYwYgR\nI4iNjeWRRx6xaNmKDBs2jJ9++okhQ4aUS46///3vad68OY888ggRERFMmTKFlJQUADp37szcuXOZ\nO3cu3bp1Y9KkSebz9ZMmTWL9+vVERUXx9ttv4+npyWeffcbChQvp0aMHCxcuZP78+eZTRr9uK5PJ\nxL///W/69u1Ljx492LdvH2+++WaFsXt7e9OjRw82bdoE3Lw7asyYMfj7++Pt7W3+99hjj7F69WpM\nJhOTJ0/GwcGBXr16MXPmTIYPH16uzhdffJEZM2YQGRnJ+vXriY6O5uWXX+bFF1+kT58+pKWl8f77\n75vL//nPf2bixIm88847REVFERMTw8cff8yHH35oPhX36210dXXlT3/6Ey+//DKRkZGsXbv2tluD\nfX198fDwoE+fPsyYMYPZs2fTokWLO36Pt1q9ejWjRo3CwcHBovJ3o1H3cvh9j2bNmsXWrVvx9va+\n48WmxMRE3nnnHQwGA40aNWLx4sXWCkcIUYedO3eOP/7xjyxdutTWoVSZvXv3MmPGDLZu3XrPy+r1\nekaOHMlXX32Fl5dXlcRj1YSxf/9+XF1dmTFjRoUJo6CggHHjxvHll1/i7+9Pbm5ulW2YEELUdveT\nMKzBqqekIiIiKr06v3r1agYOHIi/vz+AJAshhKjBbHoNIyUlhWvXrjFx4kTGjBljvhglhBDi5s0L\nNaV3ATa+rdZoNHLixAn+7//+j+vXrzNu3Di6du1K8+bNbRmWEEKICtg0Yfj7+9OoUSOcnJxwcnIi\nIiKC5OTkuyaMXx6mE0IIUX2snjAqu6YeFxfHW2+9hdFoRK/Xk5SUxBNPPHHXOjUaDdnZBXctVx/4\n+rpJW/xM2qKMtEUZaYsyvr5udy9UCasmjFdffZXExETy8vLo168fL774IqWlpWg0GsaOHUtwcDC9\ne/dmxIgRaLVaHnnkkUqHghZCCGE7Vr2t1prkiOEmOXoqI21RRtqijLRFmfvtYdT6J72FEEJUD0kY\nQgghLCIJQwghhEUkYQghhLCIJAwhhBAWkYQhhBB1VFXfBCsJQwgh6qD07EJe+HAHe05kVVmdkjCE\nEKIOWrnjAsUlBlwbVM3Lk0AShhBC1GqLN5zi201nyp1+StUVcOB0Nq0au9OpZdW9NsKmgw8KIYT4\n7c6mX2PLoXQA/Bs5ExveBIBVOy8AMLJ3yyodqFUShhBC1FIb9qYC4Oig5ZtNZ2gR6I5Wo+HQmRyC\ng9zpWIW9C5BTUkIIUStdvnqdg6eyaR7gxgujQjEaFZ+uPMaSLWcBGNmnVZW/BkJ6GEIIUcMopZi3\n8hhXC0ro0MKLTi29aNXYHXu7smP8jfsuoYDBkc3o1MqbB3u2YPXuFHKu3aBNEw86NG9U5XFJwhBC\niBrmXEY+B05lA3A+I581u1Nwc3HgyWHt6RzsQ2FxKTuTMvF2dyIixBeAh3q35Gz6NU5evGqV3gVI\nwhBCiBpn19FMAJ4b2Qk7Ow3HLuSy40gmHy5NYlh0c+y0GvQGEwO6N8NOe7PXodVqeGlMZzKuFNEy\n0N0qcUnCEEIIG7mYVcD5jGv06xpk7hHoS43sPamjkZsT4W190Wo1dG3jS9/Ojfl05TESfroIgLOT\nPX06B5arz8nRzmrJAuSitxBC2IS+1Mi/Vhxl8cbTJJ7UmacfOpNDcYmRnp0C0GrLTis1D3Djz5O7\nE9725imouG5BODtV7zG/9DCEEMIGNuxNJefaDQCWbD5Ll2AfnJ3szaejenYKuG0Zlwb2PD+qE5cu\nF9LEt2G1xgvSwxBCiGqXm3+DhD0XcXdxYGD3puQV6lm9K4WrBSUcT8klOMidQG/XCpfVaDQ083cr\n1/uoLlZNGLNmzaJnz54MHz680nJJSUl06NCBjRs3WjMcIYSoEZZtO4e+1MSYmGBG922Fj0cDfth/\nieXbzqEU9AoNvHslNmDVhDF69GgWLlxYaRmTycR7771Hnz59rBmKEELUCGfTrrHnuI7mAW706hyI\no4Md4/u3wWhS7DqWhb2dlsgQP1uHWSGrJoyIiAjc3Su/Yr948WIGDRqEl1fVPsIuhBA1jUkpvtl0\nGoBH+7dB+/OdUWGtfegc7A1AeFsfXKpwhNmqZNNrGDqdjh9//JHx48fbMgwhhKgWe45ncSGzgMj2\nfrRp4mmertFomDCgLV2CvRkW3cJ2Ad6FTe+Smjt3Lq+99pr5/uOqfjuUEELUFCWlRr7bdh4Hey0P\n92t923wfT2defriLDSKznE0TxrFjx3jllVdQSnH16lW2b9+Ovb09cXFxd13W19etGiKsHaQtykhb\nlJG2KFMT2uLbH05xtaCEh+PaENLa19bh/CZWTxiV9Ro2bdpk/v/MmTOJjY21KFkAZGcX3HdsdYGv\nr5u0xc+kLcpIW5SpCW1xtaCEpZtO4+7iQL/OgTaL534Tp1UTxquvvkpiYiJ5eXn069ePF198kdLS\nUjQaDWPHjrXmqoUQosZYseM8+lIT4+PaVPvT2VXJqpG/9957Fpd95513rBiJEEJUv5JSI9sOpbMr\nKZMmvq706dzY1iHdl9qb6oQQwoYMRhMnUq4S0swTRwe7cvOKSwxsOZTOhr2pFFwvxdFBy4SB7Wzy\ndHZVkoQhhBC/wfJt51m/N5UALxeeHNae4CAPlFLsP5XNf388zbVCPc5OdjzYszkDIpri5uJo65Dv\nmyQMIYS4R5lXivhh/yWcnezR5V5n7lcH6N+tKZm5RRw7n4u9nZbhPVswKLJpjX0I77eQhCGEEPfo\n201nMZoUU4aG4ObiyJcJJ/lh/yUAOrb0YsLAtvg3crFxlFVPEoYQQtyDI2dzOHr+Cu2bNyK8rS8a\njYb/NyWSH/ZfIsDLhW7tfK3yetSaQBKGEEJYyGA08e2mM2g1Gsb3b2NODE6OdjzYs4Vtg6sG8j4M\nIYSwgMFoYvm28+iuFhPbNcgmLzCyNelhCCFEJZRSHDqTw9Kt59DlXsfD1ZGH+rS0dVg2IQlDCCF+\n5WpBCeczrnE+I58TF69yMasArUZDv65BPNS7JQ2d686dT/dCEoYQQvys1GDi281n2HIw3TxNA3Rt\n40N8v+A7vja1vpCEIYSol86k5eHsaE9jX1e0Gg05ecXMW3mMlKwCGvu4Et3Rn1aB7rQIdK/V4z9V\nJWkFIUS9s+NIBv9elwxAQ2cH2jTx4PSlPIpuGOjVKYAJg9rh9KvhPoQkDCFEPXMq9Sr/2XAK1wb2\ndGntw6nUqxw6k4O9nZbJQ0Lo0zmwzj5Hcb8kYQgh6g3d1et8svwoAM+PCiWkeSOUUuRcu4GTgx3u\nrrV/vCdrkoQhhKgXim6U8tHSJIpuGJg8JISQ5o2Am+/T9vV0tnF0tYM8uCeEqPOMJhOfrTxGVu51\nBkc2o2+X2v1eCluRhCGEqPOWbjnH8ZSrdA72Jr5fsK3DqbUkYQgh6rQf915k475LBHq78MyIjrX+\nJUa2JAlDCFGrXS0oYfPBNIpulN4270xaHv9aloRrA3teiu8sz1PcJ6u23qxZs9i6dSve3t6sXr36\ntvmrV6/m888/R6PR4OLiwl/+8hfatWtnzZCEEHVIid7I+0sOk55dxIrt5xkW3YK4bkHorhazds9F\nEk/o0Gg0TBvZqU6+n6K6aZRSylqV79+/H1dXV2bMmFFhwjh8+DDBwcG4ubmxfft2PvnkE5YsWWJR\n3dnZBVUdbq3k6+smbfEzaYsy9aEtlFIsWH2CxBM62jdvxMWsAq6XGGjo7EBh8c3eRhNfV6aM6EQL\n3/o9pMcvfH3d7mt5q/YwIiIiSE9Pv+P8sLCwcv/X6XTWDEcIUYf8eCCNxBM6goPceeWRLpSUGkn4\n6SJbDqUTHOTOsOgWdAn2xs/Pvc4nz+pSY07oLV26lL59+9o6DCFELXD6Uh5LNp/F3cWB50aGYm+n\nxd5OyyOxrXkktrWtw6uzakTC2LNnD8uXL+e///2vxcvcb9eqLpG2KCNtUaautsW5tDzmrTyGAv44\nOZK2rXzuukxdbYvqZvOEkZyczJ///Ge++OILPDw8LF5Oupg31Ydz1ZaStihTV9vibPo1PlhyhBsl\nBh4fEkKAu9Ndt7OutsVvUaOvYcDNC1N3kpGRwUsvvcTf//53mjVrZu1QhBC1WPLFq3y0LIlSg4mn\nhncgumOArUOqd6yaMF599VUSExPJy8ujX79+vPjii5SWlqLRaBg7dizz5s3j2rVr/L//9/9QSmFv\nb8+yZcusGZIQopZJyy7kx/1p7D6WiVLw7MhOdGvna+uw6iWr3lZrTdLFvEm622WkLcrU9rYwmkwk\nnb3CjwfSOHnxKgA+Hg2YNLgdnVp631Ndtb0tqlKNPyUlhBCWulpQwq6jmWw9nE5ufgkA7Zs3on+3\nJnRp7SPDetiYJAwhhM2YTIp9yZc5kZLL6Ut56K4WA+DkYEds1yBiuwbRxK+hjaMUv5CEIYSwCX2p\nkc9Xn+DA6WwAGjjaEdrKm87B3vTsFCDjPtVA8o0IIapdwXU9//zuKGfTrxHSzJNHHmhNMz83OeVU\nw0nCEEJUq8t5xXyw5Ai63Ov06ODPE0Pb42AvA2fXBpIwhBDVJj2niH98e4hrhXqG9mjO6JhWaDXS\nq6gtJGEIIapFqq6Af3x7mMLiUsY90JqBkfKwbm0jCUMIYTUmk6KguJSLWQUs+P44xSUGHh/cjpiw\nIFuHJn4DSRhCiCp3IiWX/2w4RU7eDUw/Pxus1WhkSI9aThKGEKJKHTiVzfzvjwHQqrE7Hg0d8XR1\nIrydL+2bN7JxdOJ+SMIQQlRqx5EMDp3JoW9YYzoHe1d6kXrHkQwWrU/G0d6OF8eE0qGFVzVGKqxN\nEoYQ4o5+EQoNAAAgAElEQVRKDSaWbj1HYXEph8/mEOjtQv9uTWjo4oi+1Ii+1Ej+9VKuFtzgyrUb\nHE+5imsDe155JIxWjd1tHb6oYpIwhBB3dOhMNoXFpUS298PeTkviCR2LN56+Y3n/Rs68MDqUIF8Z\nzqMukoQhhLijHUmZAIzo1ZLGPq6M7tuKg6ez0Wg0ONprcXDQ4ubiiJebE54NnWQ4jzpOvl0hRIVy\n8oo5cSGX1kEeNPZxBcDLvQH9I5raODJhK/I8vhCiQjuPZqKAPl0CbR2KqCEkYQghbmMyKXYezaSB\nox3dQ/xsHY6oISRhCCFuczwll9z8EiLb+9PAUc5ci5skYQghbrP9SAYAfbs0tnEkoiaxasKYNWsW\nPXv2ZPjw4Xcs89ZbbzFw4EAeeughTp48ac1whBB3kXmliM9WHePgqWyCfF1pGXh/74AWdYtV+5qj\nR49m4sSJzJgxo8L527ZtIzU1lY0bN3LkyBHefPNNlixZYs2QhKi3SvRGPlt1DEcHO8Ja+xAa7I1r\nA3uuXLtBSlYBh87ksOdEFkpBc383Jg8JQSNDj4tbWDVhREREkJ6efsf5mzZtYuTIkQB06dKFgoIC\ncnJy8PHxsWZYQtRLq3encOTcFQD2JV9Gq9Hg7GRH0Q2DuUwTX1ce6t2K8LY+kizEbWx6Nevy5csE\nBJSNXOnv749Op5OEIUQVy7xSxIa9qXi7N+D50Z04fiGXw2dzKLheSocWXrQIcKNloDttm3nKC43E\nHdk0Yaifhz2+laVHNb6+cm71F9IWZaQtyvzSFkopPvouCaNJMW1MZ7p3CqR7aP16H4XsF1XDpgnD\n39+frKws8+esrCz8/Cy75zs7u8BaYdUqvr5u0hY/k7Yoc2tb7D2p48iZHDoHe9PKz7XetZHsF2Xu\nN3Fa/bbainoRv4iLi2PlypUAHD58GHd3dzkdJUQVKiwu5dtNZ7C30/Jo/zZyXULcF6v2MF599VUS\nExPJy8ujX79+vPjii5SWlqLRaBg7diwxMTFs27aNAQMG4OzszDvvvGPNcISos0r0Rq6XGCg1mjAY\nTBy9mMeWfakcu3AFg1ExolcL/Bq52DpMUctpVGVdgBpMupg3SXe7TH1rC6UUpy/lseVQOgdOZWM0\n3f5TbuLrSlQHfwZFNsPern4+p1vf9ovK3O8pKXnmX4ha6PSlPBZvOEV6ThEAgd4uNPN3w95Og4Od\nlqaBHoQ0cSfQ29XGkYq6RBKGELXMoTPZfLryOCaTIrK9H7Fdg2jb1LPc9Qk5qhbWIAlDiFpk19FM\n/r02GXt7DS+N6UynVt62DknUI5IwhKihDEYTmw+kkX3tBiaT4nqJgcQTOlwb2PO7h7sQHORh6xBF\nPSMJQ4ga6rtt59iw91K5aY3cnHjlkS40kXdmCxuQhCFEDXToTDYb9l7C38uFaSM6Ym+vxU6rwdvd\nCQd7O1uHJ+qpuyaMS5cusWzZMhITE8nKysLJyYmQkBAGDRrEwIEDsbeXnCNEVcrJK2bhmpM42Gt5\nbmQnmvpJb0LUDJX+tf/zn//M8ePHGTx4ML///e/x8fGhpKSEc+fOsXPnThYsWMBf/vIXwsLCqite\nIeo0g9HEp6uOc73EwOQhIZIsRI1SacKIi4tj9uzZt01v164dQ4cOJS8vj0uXLlWwpBDiXhmMJhYm\nnORCZj7RHQPo0znQ1iEJUU6lCSMmJqbShT09PfH09KzSgISoj0oNRj5deZzDZ3MIDnJn4qC2Mu6T\nqHEsGivgr3/9KwUFBRgMBh599FHCwsJYtWqVtWMTol64oTfw4dIkDp/NoUOLRvx+bFcaOMq1QVHz\nWJQwdu/ejZubGzt37sTf358NGzbw5ZdfWjs2Ieq8VF0Bf/3qICcvXqVrGx9eju+Mk6PcBSVqpns6\njNm3bx8DBgzA399fustC3IdSg5Hvd6Wwbk8qJqWICWvMhIFtsdPWzwECRe1gUcLw9vbm9ddfZ9eu\nXUydOhWDwYDRaLR2bELUOfpSI3tO6Fi35yK6q8V4uzfg8cHtZIgPUStYlDDee+89vv/+e+Lj4/Hw\n8CAtLY0nnnjC2rEJUWcUFpeycV8qWw9lUFhcilajoX9EE0b3bSXXK0StYdGe6uXlxeTJk82fmzRp\nQpMmTawVkxA12rXCErRaDW4ujhaV1129zgdLjnD5ajGuDewZFt2c2K5BeLk3sHKkQlStShPGc889\nx7Rp0+jcufNt8woLC/nuu+9o0KABY8eOtVqAQtQkJXojbyzcyw29gZ6dAhkU2bTSd06cS7/GR8uS\nKCwuZUiPZozo1RInB7moLWqnShPGSy+9xHvvvUdKSgqdO3fG29ubkpISzp8/T3p6OuPGjWP8+PHV\nFasQNrc3WUdhcSmODlq2H8lgx5EMQoO96djSi3ZNPWni2xC9wYgut5jzGdf43+azlBpNTBrcjn5h\nQbYOX4j7UmnCCAkJ4fPPPyczM5O9e/ei0+lwcnJi8ODBdOvWDUdHy7rkQtQVO45kogFmPxlFalYB\n6xIvknTuCknnrgDgaK9FbzCZyzs6aHlpTGe6tPaxUcRCVB2LrmEEBgby0EMP/aYVbN++nblz56KU\nYsyYMUydOrXc/MzMTP7whz9QUFCAyWRi+vTpd33CXAhbSM8p4mz6NTq29MLP0xk/T2e6tfMl+9oN\nzlzK49SlPFIyC3B3dcDfy4WARi6EBnsT4OVi69CFqBIWJYwrV67wzjvvkJmZyddff01ycjKHDh26\n6+kok8nEnDlzWLRoEX5+fsTHxxMXF0dwcLC5zKeffsrQoUMZN24c586d4+mnn2bz5s33t1VCWMGO\nIxkA9O3S2DxNo9GYk0evUBn7SdRtFj0l9Prrr9OtWzfy8/MBaNWqFf/973/vulxSUhLNmzcnKCgI\nBwcHhg0bxqZNm8qV0Wg0FBYWApCfn4+/v/+9boMQVldqMLH7WBYNnR3o2kZOL4n6yaKEodPpGD9+\nPHZ2N+/ucHR0RGvBE6k6nY7AwLKjLn9/fy5fvlyuzAsvvMCqVauIiYlh2rRpvPHGG/cSvxDV4vDZ\nHAqLS+kVGoC9nTyNLeoni05J/folSfn5+Sil7rqcJWUSEhIYM2YMkydP5vDhw7z22mskJCTcdTlf\nX7e7lqkvpC3KWKst9iw/CsBD/drUmvauLXFWB2mLqmFRwhg4cCB//vOfKSoqYvny5fz3v/9lzJgx\nd10uICCAjIwM82edToefn1+5MsuWLWPhwoUAhIWFUVJSQm5uLl5eXpXWnZ1dYEnodZ6vr5u0xc+s\n1RapugIOn86mdRMPGmhrx74n+0UZaYsy95s4LepbP/XUU0RERNCxY0e2bdvGxIkTefzxx++6XGho\nKKmpqaSnp6PX60lISCAuLq5cmcaNG7N7924Azp07h16vv2uyEKI6KKXYcjCNtxcfQAEDIpraOiQh\nbEqjLDlvdB+2b9/O22+/jVKK+Ph4pk6dyscff0xoaCixsbGcO3eO119/nevXr6PVapkxYwbR0dF3\nrVeOGG6So6cyVdkW+UV6Fq1L5vDZHFwb2DN5SHu6tfOtkrqrg+wXZaQtytxvD8OihHHlyhW++uor\nUlNTMRgM5ukfffTRfa38fsgOcJP8GMpURVvkX9ezYW8qmw+kU1JqJKSZJ08P70gjN6cqirJ6yH5R\nRtqizP0mDIuuYTz33HN06NCB6Oho851SQtQl6dmF7EjKZNvhDEpKjXg0dCS+XzCxXYPQauXdL0KA\nhQmjuLiYN99809qxCFGtSkqNbD+cwe5jWVzU3TwC9WjoyJiYVsSENcbBXg6OhLiVRQmjS5cunDp1\ninbt2lk7HiGqRfLFqyxal8zlvGLstBq6BHvTMzSQsNbekiiEuAOLEsa4ceOYMGECAQEBODmVnctd\ntmyZ1QIToqqUGkwUXNejN5go0RvZejidbYcz0GhgUGRThkQ1x91VBtIU4m4sShivvfYa06ZNo0OH\nDnINQ9Qql/OK+etXB8gr1Jeb3sTXlSeGtqdloLuNIhOi9rEoYTg5OfHkk09aOxYhqpS+1Mi85UfJ\nK9TTtY0Pbi4OONrb4e/lQkxYYxniQ4h7ZFHC6NOnD9u3b6dv377WjkeIKqGUYvHGU6ReLqRvl8ZM\nHhJi65CEqPUsShhLlixhwYIFuLq64ujoiFIKjUbDTz/9ZO34hPhNth/JYNfRLJoHuPHYgDa2DkeI\nOsGihPHdd99ZOw4hqkROXjE/ndCxetcFXBvY8/yoTnLXkxBVxKKEERQk7yIWNVvSuRx+WHKE4+d/\nflWqg5ZnHuqIj4ezjSMTou6oNGG89tprvPvuu4wZMwaN5vanXeW2WmFrJqX4fucFvt+VAkBIM0+i\nOwbQrZ0fLg0sOh4SQlio0l/ULy87+sMf/lAtwQhxL4pLDCxMOMnB09n4eDTgjSd74OYodz4JYS2V\nJoxfXskaGRlZLcEIYYkSvZHEkzrWJ6aSlXudkGaePDuyE62CPGSQOSGsSPrsotbIKywh4aeL7D6W\nSXGJEY0G+ndrwiMPtJZnKoSoBpUmjNOnT1f4bgq5rVZUtwOnsvm/9ckUFpfi2dCRARFN6dulMV7u\nDWwdmhD1RqUJo0WLFixYsKC6YhHiNsUlBr758Qw7j2biYK/lsQFt5SltIWyk0oTh6Ogot9QKmzAY\nTWw7nMGa3SlcK9LT3N+NqSM6EOjtauvQhKi3Kk0YDg4O1RWHEMDN22R/OpbFqp0XyLl2AycHOx7q\n3ZJh0c2lVyGEjVWaMJYsWVJdcQhBenYh/7fhFGfTrmFvp2Vg96YM7SFDjwtRU1j9Lqnt27czd+5c\nlFKMGTOGqVOn3lZm7dq1/Otf/0Kr1dKuXTv+8Y9/WDssUYPkF+n5Yf8l1iemYjQpurXzZdwDbfD2\nkAvaQtQkVk0YJpOJOXPmsGjRIvz8/IiPjycuLo7g4GBzmYsXL/LFF1/wv//9j4YNG5Kbm2vNkISN\nGIwmVu9KQW8w4uxkj4uTPVfyb3Ai5SqXLhcC4O3uxGMD2hHWxsfG0QohKmLVhJGUlETz5s3NF86H\nDRvGpk2byiWMJUuW8Oijj9KwYUMAvLy8rBmSsJENe1NZvTvltun2dlo6tGhEp5be9OvamAaO8miQ\nEDWVVX+dOp2OwMBA82d/f3+OHj1arkxKSgoA48ePRynF888/T58+fawZlqhmV67dYPWuFNxdHHhh\ndGdKDEau3zDg2sCe1kEeODrIaLJC1AZWTRhKqbuWMRqNpKam8vXXX5ORkcFjjz1GQkKCucdxJ76+\nblUVZq1X09vi84ST6A0mnn+4C9Fdm1h1XTW9LaqTtEUZaYuqYdWEERAQQEZGhvmzTqfDz8+vXBl/\nf3+6du2KVqulSZMmtGzZkpSUFDp16lRp3TJm0E2+vm41ui2Onr/CT0czadPEg07NPK0aa01vi+ok\nbVFG2qLM/SZOq97YHhoaSmpqKunp6ej1ehISEoiLiytXpn///uzZsweA3NxcLl68SNOmTa0Zlqgm\npQYTX/9wGq1Gw4SB7SocIl8IUXtYtYdhZ2fHG2+8wZQpU1BKER8fT3BwMB9//DGhoaHExsbSp08f\ndu3axbBhw7Czs2PGjBl4eHhYMyxhRUopUnWFHDydzcHT2Vy+Wkz/iCY09av8FKMQoubTKEsuNNRA\n0sW8qSZ1t3OuFfOvFce4mHUzHns7LWGtvXliaHucnax/91NNagtbk7YoI21R5n5PSck9jKJKnEnL\n41/Lj5J/vZSubXyI7hhAp1ZecpusEHWI/JrFfTEpxc6kTBZvOIVS8NiAtjwQHiTXK4SogyRhiHtm\nMilOpl7l4Kmb1ymuFelxcbLn2VGd6NhCHrwUoq6ShCHuSeaVIr5Yc5ILmTdf39vQ2YHenQMZFt0c\n/0YuNo5OCGFNkjCERUxKselAGsu2nqPUYKJ7iB+xXYNo09QDO60MOy5EfSAJQ9xVzrVivkw4SXJq\nHg2dHXj6wQ5EhPjdfUEhRJ0iCUPckVKKHUmZfLvpDDf0RsJa+/D44HZ4NHSydWhCCBuQhFHPGYwm\ntBoNWm35u5oycopYsuUsSeeu4Oxkx5PD2tOzU4Dc/SREPSYJox7LL9Lz1n/2U1JqJCLEj6j2/rg2\nsGf17hT2nbyMAjq2aMQTQ9vj5S4vMxKivpOEUU+ZlOLzNSduvjfb0Y4tB9PZcjDdPL+ZX0OG92pJ\neFsf6VUIIQBJGPXW2p8ucvxCLqGtvHlxTCinUvNIPKEjr6iE2K5BhLWWRCGEKE8SRj10+lIeK3ac\np5GbE0892B57Oy0dW3rRsaU8dCeEuDNJGPWIUoqz6df4bNUxNGh4ZkRH3FwcbR2WEKKWkIRRD5Qa\njOw/lc0P+y6R8vNIsg/3C6ZtU08bRyaEqE0kYdRRpy/lcfB0NmfTr3ExqwCjSaEBurbxYWD3prRr\n1sjWIQohahlJGHXQ0fNX+HDJERRgp9XQzL8hIc0aEdM1CD9PZ1uHJ4SopSRh1DGX84pZ8P1x7Oy0\nPPtQRzq09MLJwc7WYQkh6gBJGHVIid7IJ98dpeiGgSeGhNC1ra+tQxJC1CEyzGgdoZRi0fpk0rIL\n6dc1iD5dGts6JCFEHWP1hLF9+3YGDx7MoEGDWLBgwR3LrV+/npCQEI4fP27tkOqU4hIDWw+n8+aX\n+0g8oSO4sTuP9m9j67CEEHWQVU9JmUwm5syZw6JFi/Dz8yM+Pp64uDiCg4PLlSsqKuKrr74iLCzM\nmuHUKSaTImHPRTbsTeX6DQN2Wg0RIX481r8N9nbScRRCVD2rJoykpCSaN29OUFAQAMOGDWPTpk23\nJYyPPvqIp59+mi+++MKa4dQZeYUlLPj+OMmpeXi6OdG/WxNiwoJo5CbDjgshrMeqh6I6nY7AwEDz\nZ39/fy5fvlyuzMmTJ8nKyiImJsaaodQZxy5c4S9f7iU5NY+w1j7Mm/EAI/u0kmQhhLA6q/YwlFJ3\nnT937lz+9re/WbzML3x93e4rttomJTOfxWtPsvdEFvZ2Gp5+qBPD+7RCo9HI8B63qG/7RWWkLcpI\nW1QNqyaMgIAAMjIyzJ91Oh1+fmWv9iwqKuLs2bNMnDgRpRQ5OTk899xzfPrpp3Ts2LHSurOzC6wW\nd02Sqitgw95U9hzXoYC2TTwY178NLQLcyckpxNfXrd60xd1IW5SRtigjbVHmfhOnVRNGaGgoqamp\npKen4+vrS0JCAu+//755fsOGDfnpp5/MnydOnMjMmTPp0KGDNcOq8fSlRvaevMzWw+mcz8gHbr6f\nYnRMMKGtvGTYcSGETVg1YdjZ2fHGG28wZcoUlFLEx8cTHBzMxx9/TGhoKLGxseXKazQai09J1UUm\nk2LXsUxWbD9PXqEeDdA52JuYsMZ0ae2DVhKFEMKGNKqW/oWua13Mkym5fLv5LJcuF+JoryUuogmx\nXYPw8ah87CfpbpeRtigjbVFG2qJMjT4lJe5OX2pkyZazbD6Yjgbo1SmAUX1byTu0hRA1jiQMG0q7\nXMj874+TnlNEkI8rTz7YnhYB7rYOSwghKiQJwwZMSrHpQBpLt5zDYDTxQHgQj8S2xlFGlRVC1GCS\nMKpZbv4Nvlx7khMpV2no7MATQzvStY2MKiuEqPkkYVQTfamR3ceyWLb1HNdLDHQO9uaJISF4NJQn\ntIUQtYMkDCvLyStm86F0dhzJoOiGAUcHLZMGtyOmS2N5nkIIUatIwrASk0mxLvEiK3dcwGhSuLk4\n8GDP5vQLC5I7oIQQtZIkDCu4cu0GX6w5walLeXg0dCQ+JpjI9n442MtFbSFE7SUJowoppdhzQsdX\nG09TXGIgvK0vjw9uJ4MDCiHqBEkYVSS/SM9/Npzi4OlsnBzsmDwkhD6dA+U6hRCizpCEcZ+UUuxL\nvsxXG09TWFxK26aeTBnWHj/Pyof0EEKI2kYSxn1Iu1zIf388TXJqHg72WsbFtaF/RBMZJFAIUSdJ\nwvgNCotLWbnjPFsOpaMUdAn2ZlxcG/y9XGwdmhBCWI0kjHtgMim2HU5n+fbzFN0wEODlwvj+bQht\n5W3r0IQQwuokYVjoVOpV/vvjGS5dLqSBox2PxLamf0QT7O2s+lp0IYSoMSRh3EVu/g2WbDnL3pOX\nAegdGsiYmFYypIcQot6RhHEHJXojG/elkrDnIvpSEy0D3Xh0QFuCG3vYOjQhhLAJSRi/YjCa2JmU\nyaqdF7hWpMfNxYHH+relV+dAuftJCFGvWT1hbN++nblz56KUYsyYMUydOrXc/EWLFrF06VLs7e3x\n8vJi7ty5BAYGWjus2xiMJhJP6Fjz00V0uddxdNDyYM8WDI5shksDyatCCGHVv4Qmk4k5c+awaNEi\n/Pz8iI+PJy4ujuDgYHOZDh06sHz5cpycnPjmm2/4+9//zgcffGDNsMopNZjYmZTBusRUcq7dwE6r\noV/XIEb0aoGnXKcQQggzqyaMpKQkmjdvTlBQEADDhg1j06ZN5RJGZGSk+f9hYWGsXr3amiGVk3ml\niM9WHefS5UIc7LXEhTdhcFQzvD1kNFkhhPg1qyYMnU5X7vSSv78/R48evWP5ZcuW0bdvX2uGZLbr\naCZfbTxNSamRPp0DGd1X7nwSQojKWDVhKKUsLrtq1SqOHz/O4sWLqzyOkxev8p/1yZQaTTg72qPR\naEjLLsTZyY5pD3Uksr1/la9TCCHqGqsmjICAADIyMsyfdTodfn5+t5XbvXs3CxYs4KuvvsLBwcGi\nun193Swqd+R0Nh8tS8JkMuHl4Uz+dT3Xbxjo2Mqb343rSoC3q2UbU4NZ2hb1gbRFGWmLMtIWVcOq\nCSM0NJTU1FTS09Px9fUlISGB999/v1yZEydO8Oabb7Jw4UIaNWpkcd3Z2QV3LXPswhX++d1RlFK8\nMDqUzsE+wM2ej0ajAZPJonpqMl9ft1q/DVVF2qKMtEUZaYsy95s4rZow7OzseOONN5gyZQpKKeLj\n4wkODubjjz8mNDSU2NhY3n33XYqLi3n55ZdRStG4cWPmzZt33+s+fCaHeSuPAfDimM7lxnuSd1QI\nIcS906h7udBQg2RnF1BSamRnUiYdWjQi8OdTS0op1iemsmzrOezttbw4JpROLevu4IBy9FRG2qKM\ntEUZaYsyNbqHYW3fbTvHj/vT0AARIX4MimzGjwcusee4jkZuTrwwOpSWge62DlMIIeqEWpswUnUF\nbDqQho9HA1ydHdiXfJl9yTcHCAxu7M7zo0PlwTshhKhCtTJhmEyKxRtPoRRMGtyOji28OHYhlw17\nU/Fv5MK4uDY42Muw40IIUZVqZcL4cV8q59LziQjxM1+fCG3lLS8yEkIIK6qVh+GL1pzAydGO8XFt\nbB2KEELUG7UyYRRc1zOyd0sauck1CiGEqC61MmE82Lslcd2a2DoMIYSoV2plwnhmVGd5l7YQQlQz\n+asrhBDCIpIwhBBCWEQShhBCCItIwhBCCGERSRhCCCEsIglDCCGERSRhCCGEsIgkDCGEEBaRhCGE\nEMIikjCEEEJYRBKGEEIIi1g9YWzfvp3BgwczaNAgFixYcNt8vV7PK6+8wsCBAxk7diwZGRnWDkkI\nIcRvYNWEYTKZmDNnDgsXLmTNmjUkJCRw7ty5cmWWLVuGh4cHGzdu5PHHH+fdd9+1ZkhCCCF+I6sm\njKSkJJo3b05QUBAODg4MGzaMTZs2lSuzadMmRo0aBcCgQYP46aefrBmSEEKI38iqCUOn0xEYGGj+\n7O/vz+XLl8uVuXz5MgEBAQDY2dnh7u5OXl6eNcMSQgjxG1g1YSil7rmMUgqNRmOtkIQQQvxG9tas\nPCAgoNxFbJ1Oh5+f321lsrKy8Pf3x2g0UlhYiIeHx13r9vV1q/J4aytpizLSFmWkLcpIW1QNq/Yw\nQkNDSU1NJT09Hb1eT0JCAnFxceXKxMbGsmLFCgDWr19Pjx49rBmSEEKI30ijLDlvdB+2b9/O22+/\njVKK+Ph4pk6dyscff0xoaCixsbHo9Xpee+01Tp48iaenJ++//z5Nmsj7uoUQoqaxesIQQghRN8iT\n3kIIISwiCUMIIYRFJGEIIYSwSK1LGHcbm6ouy8rKYtKkSQwdOpThw4fzn//8B4Br164xZcoUBg0a\nxJNPPklBQYGNI60eJpOJUaNGMW3aNADS0tJ45JFHGDRoENOnT8dgMNg4wupTUFDASy+9xJAhQxg2\nbBhHjhypl/vFokWLePDBBxk+fDivvvoqer2+Xu0Xs2bNomfPngwfPtw8rbL94K233mLgwIE89NBD\nnDx58q7116qEYcnYVHWZnZ0dM2fOZO3atXz77bd8/fXXnDt3jgULFhAdHc2GDRuIiopi/vz5tg61\nWvznP/8hODjY/Pkf//gHTzzxBBs2bMDNzY1ly5bZMLrq9fbbbxMTE8O6detYtWoVrVq1qnf7hU6n\nY/HixSxfvpzVq1djNBpJSEioV/vF6NGjWbhwYblpd9oPtm3bRmpqKhs3bmT27Nm8+eabd62/ViUM\nS8amqst8fX1p3749AK6urgQHB6PT6cqNxzVq1Ch+/PFHW4ZZLbKysti2bRsPP/ywedqePXsYNGgQ\ncLMdfvjhB1uFV60KCwvZv38/Y8aMAcDe3h43N7d6uV+YTCaKi4sxGAzcuHEDPz8/EhMT681+ERER\ngbu7e7lpv94PfvmbuWnTJkaOHAlAly5dKCgoICcnp9L6a1XCsGRsqvoiLS2N5ORkunTpwpUrV/Dx\n8QFuJpWrV6/aODrrmzt3LjNmzDAPI3P16lU8PDzQam/u0gEBAfVm30hLS6NRo0bMnDmTUaNG8cYb\nb1BcXFzv9gt/f3+eeOIJ+vXrR9++fXFzc6NDhw64u7vXy/3iF7m5ueX2g9zcXKD8OH5ws/10Ol2l\nddWqhCGPjNxUVFTESy+9xKxZs3B1da13Y29t3boVHx8f2rdvb94nlFK37R/1pV0MBgMnTpzg0Ucf\nZcWKFTg7O7NgwYJ6s/2/yM/PZ9OmTWzZsoUdO3ZQXFzM9u3bbytX39rlTir6e3q3trHqWFJVzZKx\nqcpAbb4AAAdwSURBVOo6g8HASy+9xEMPPUT//v0B8Pb2JicnBx8fH7Kzs/Hy8rJxlNZ18OBBNm/e\nzLZt2ygpKaGoqIi5c+dSUFCAyWRCq9WSlZVVb/aNgIAAAgICCA0NBWDgwIF8/vnn9W6/2L17N02b\nNsXT0xOA/v37c+jQIfLz8+vlfvGLO+0H/v7+ZGVlmctZ0ja1qodhydhUdd2sWbNo3bo1jz/+uHna\nAw88wPLlywFYsWJFnW+T6dOns3XrVjZt2sT7779PVFQU//jHP4iKimL9+vVA/WiHX/j4+BAYGMiF\nCxeAm9dyWrduXe/2i8aNG3PkyBFKSkpQSrFnzx7atGlT7/aLX/cc7rQfxMXFsXLlSgAOHz6Mu7u7\n+dTVndS6oUEqGpuqvjhw4AATJkygbdu2aDQaNBoNr7zyCp07d+Z3v/sdmZmZNG7cmI8++ui2C191\n1d69e/nyyy/57LPPuHTpEtOnTyc/P5/27dvz7rvv4uDgYOsQq0VycjJ/+tOfMBgMNG3alHfeeQej\n0Vjv9otPPvmEhIQE7O3t6dChA2+99RZZWVn1Zr949dVXSUxMJC8vDx8fH1588UX69+/Pyy+/XOF+\nMHv2bHbs2IGzszPvvPMOHTt2rLT+WpcwhBBC2EatOiUlhBDCdiRhCCGEsIgkDCGEEBaRhCGEEMIi\nkjCEEEJYRBKGEEIIi0jCEDXaAw88wNmzZ6tlXZ988km5oa9nzpzJ119/fd/1zpw5k+HDhzN9+vT7\nrqsyycnJrFu3zqrrEPWbJAwhfvbJJ59QWlpapXXm5OTw/9u7v5AmuziA49/ln7S8KOvWoghaI8KL\nihkJWon0R/Y8S2NYOL1IEFqE3gjRRZZEBcPyJqE/lDSIyBp2UV4IEVgGXeyiDKMVFnSRltTmaPr4\ney/Eh3KL9vYG7+vb73O182znnN/DYL+dHfY7fX199Pb2EgwGf+vYcz1//vyXE8b09PRvjkb9H2nC\nUPPS69evOXjwIDU1NRiGYZc+AHA6nXR1dVFdXU1FRQV9fX32c/fv32fnzp14vV66urpwOp0kEgna\n2tpwOBz4fD5M0yQWiwEwPDyM3++nsrKS1tbWH8Zz584dqqqq8Hg8BAIBPn78SDwex+/38/XrV0zT\n5OrVq9/1CYfDHDp0yG5blkVpaaldL+3ixYvs27cPr9dLU1MTY2NjAExOTnL69GmqqqowDINAIMD4\n+DidnZ08fvwY0zRpb28HZiojmKaJx+OhoaGBt2/fAjP/kDcMg5MnT+Lz+Xj48OE/eTvUn0KU+g8r\nLy+Xly9ffndtampKTNOUaDQqIiKxWEwqKyvt9tq1a+X69esiIvL06VMpLS0VEZHR0VHZvHmzjIyM\niIjIlStXxOl0ysTEhN0vkUjY87S2tkptba0kk0lJJpOye/duGRgYSIlxeHhYtm7dKqOjoyIi0tHR\nIUeOHBERkXfv3onb7U57b4lEQtxut3z69ElERPr7+8Xv94uISDgclmPHjtmvDYVC0tLSIiIinZ2d\nEggEZGpqSkTE7t/T0yOHDx+2+4yNjYnb7ZZXr16JiMjNmzelpqZGREQGBwfF5XJJJBJJG5tS6egK\nQ807b968IRqN0tzcjGEY7N+/n8nJye9OX9y1axcAxcXFfPjwgWQySSQSYf369RQVFQFQXV2dMrbM\nqZSzY8cOcnJyyMnJweVyMTIyktJncHCQsrIyli1bBoDP52NgYOCn95GXl8f27du5e/cuMFMYbvYQ\npP7+fh49eoRhGBiGQSgU4v3798BMefe6ujqysrIA7Oqsc0UiEdatW8fq1asB2Lt3L0NDQ0xMTACw\ncuVKNmzY8NM4lZo1r8qbKwUzH+qFhYXcvn077fMOh4OFCxcC2AfnWJaVkgzmttPJzc21H2dlZaU9\nD1pEUs4RmJ33ZwzD4NSpU+zZs4cnT55w9uxZe8ympia8Xm/a+TKRLq5v24sWLcpoHKVm6QpDzTur\nVq0iLy+PcDhsX4tGo8TjcSD1A3W2XVxczLNnz+zf8b/d9wAoKCjgy5cvfzuekpISHjx4YO8x3Lhx\ngy1btqTMn87GjRuJxWIEg0EqKirsRLdt2zZCoRCfP38GIJlM8uLFCwDKy8u5du2avUE/e5JeQUGB\nvfcye79DQ0N22fOenh5cLpcmCvXLdIWh/tMcDgf19fVkZ2fb35h7e3u5cOEC7e3tXL58GcuyWL58\nOR0dHXafuWPAzEEyx48fp7GxkaVLl1JWVkZ2djb5+fkANDQ0UFdXR35+Pt3d3RnHuGbNGpqbm6mv\nr2fBggUUFRXR1taWMv+PGIbB+fPnCYVC9jWPx8P4+DgHDhzA4XAwPT1NbW0tTqeTxsZGgsEghmGQ\nm5vLihUrOHfuHCUlJVy6dAnDMNi0aRNHjx7lzJkztLS0YFkWhYWF9gpGqV+h5c3VHyUej7N48WJg\n5hv3rVu3fst/LZT6E+gKQ/1Ruru7uXfvHpZlsWTJEk6cOPFvh6TUvKErDKWUUhnRTW+llFIZ0YSh\nlFIqI5owlFJKZUQThlJKqYxowlBKKZURTRhKKaUy8hf8CwfjbzhfpQAAAABJRU5ErkJggg==\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7f47b218dbd0\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(ag_means)\n", + "plt.ylabel('Time(s)')\n", + "plt.xlabel('Length of vector')\n", + "_ = plt.title('Time to sum the elements of 1000 vectors (AutoGraph)')\n", + "_ = plt.ylim(ymin=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "d7IAJ6Bwbk9t" + }, + "source": [ + "## Eager" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "XMu5-12yoOzY" + }, + "outputs": [], + "source": [ + "from tensorflow.python.eager import context" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "_vt9MzpyjQ4T" + }, + "outputs": [], + "source": [ + "# Sum written using for loop and run with tf.eager\n", + "def sum_all(elements):\n", + " sum_ = 0.0\n", + " length = elements.shape[0]\n", + " for i in tf.range(length): \n", + " sum_ += elements[i][0]\n", + " return sum_\n", + "\n", + "eager_means = []\n", + "for num in range(max_elements):\n", + " with context.eager_mode():\n", + " durations = []\n", + " for i in range(trials + burn_ins):\n", + " \n", + " start = time.time()\n", + " for _ in range(batches):\n", + " run_trial(num)\n", + " \n", + " if i \u003c burn_ins:\n", + " continue\n", + " \n", + " duration = time.time() - start\n", + " durations.append(duration)\n", + " eager_means.append(np.mean(durations))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 301 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 422, + "status": "ok", + "timestamp": 1532460024499, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "5gHVdMlD-A-T", + "outputId": "3b581cb7-7ef9-489c-92f1-3e52c0c2dc8a" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYkAAAEcCAYAAAAydkhNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3XlclNX+wPHPsC+CILviCiruhAiiqaAimqnglpZLWpkt\ntmh50+rXvZZ2q9veq5ulZYtlaZp7mQuaorjklkoqiwgKsojsy8yc3x9cBxFUVGBYvu+/mGfmPM93\nzjzMd85zznOORimlEEIIISphYuwAhBBC1F2SJIQQQtyQJAkhhBA3JElCCCHEDUmSEEIIcUOSJIQQ\nQtyQJAlg8eLFvPrqq8YOo14ZOHAge/furfHjzJs3jw8//LDGj1NXxMfHExERQc+ePfnuu++MHU6D\ndvbsWcaMGWPsMCo1btw4YmNjjR0G0EiSxD333IOfnx9+fn506tSJHj16GLZt2LCBxx9/nNdff73G\n40hOTsbHxwe9Xl/jx6pOje2L+lZq8nNcsmQJgYGBHDp0iEmTJlV4fvPmzUyYMAFfX1+mTJlS4flT\np04xevRofH19GTNmDDExMeWef+eddwgMDKR379688847t1W2pk2ePJlVq1bV2vE++ugjHn30UcPj\ngQMH0qNHD/z8/AzfD2+88UatxXOtRx55pM78zzWKJHH48GH+/PNP/vzzT5o3b87ixYsN2+6///5a\ni0MphUajQe5frN9q8nO8cOEC3t7eN3zewcGBhx9+mBkzZlR4rqSkhKeeeorw8HAOHDhAeHg4Tz75\nJFqtFoAVK1awfft21q9fz7p164iMjOTHH3+sUtn64HY+j7S0NKKjoxk0aFC57YsXL+bPP/80fD+8\n8sor1R3mTel0OqA0YUVHR5Oenl6rx69Mo0gS11JKVTiZPvnkE1588UWg7Ffi6tWrCQ4OJjAwkBUr\nVnD8+HFGjhxJQEBAhVbHqlWruO+++wgMDOTRRx/lwoULlR578uTJAPj7++Pn58fRo0dRSvHpp58y\ncOBA+vbty0svvURubm6l5S9fvszMmTPp1asXgYGB5X5p+vj4cP78ecPja3/979+/nwEDBrBkyRL6\n9OlDv3792Lp1Kzt37iQsLIzAwEAWL15c6TF/+ukn1q9fz5IlS/Dz8+OJJ54wPHfq1ClGjhxJr169\nmD17NsXFxYbnduzYQXh4OL169WLixIn8/fffle4fIDY2lunTpxMYGMiwYcPYvHnzDV97s/0OHDiQ\npUuXMnLkSO655x5eeeUVMjIyeOyxx/Dz82P69Onk5OQYXn/kyBEmTJhAr169CA8PZ//+/YbnJk+e\nzIcffsjEiRPx8/PjkUceISsry/AclP8cExMTmTx5Mv7+/gQFBTF79uwbvodt27Zx//33ExAQwJQp\nU4iLiwNg6tSpREdHs2DBAvz8/Dh37lyFskFBQQwdOhQXF5cKz+3fvx+dTseUKVMwNzdn8uTJKKXY\nt28fAL/88gvTp0/H1dUVV1dXpk2bxpo1awCIjo6+adlrbdq0qcJlmmXLlvHkk08CUFxczFtvvUVI\nSAj33nsv//znP8udG1u3biU8PJyePXsyZMgQdu/ezfvvv8+hQ4d4/fXXy/2C//PPPxk7diy9evVi\n3LhxHD58uNxn9P777zNx4kR8fX1JSkpi9erVDB48GD8/PwYPHsyGDRsq/Qz27NlDly5dsLCwKLf9\nRonm/PnzTJ06lcDAQIKCgnjhhRfK/Z+eOHHCcJnw2Wef5fnnny/XErjVefvFF18Yzlu9Xo+FhQVd\nunRh9+7dlcZTq1QjExISoqKiospt+/jjj9WLL76olFIqKSlJdezYUb322muqqKhI7dmzR3Xr1k09\n9dRTKjMzU6WkpKigoCB14MABpZRSv//+uxoyZIiKi4tTOp1O/fe//1UPPPBApcdOSkpSPj4+Sq/X\nG7atXLlSDRkyRCUlJan8/Hz19NNPG2K53rvvvqtee+01pdPplFarVQcPHjQ85+PjoxITEw2PX3rp\nJfXBBx8opZSKjo5WnTt3Vp9++qnSarXqp59+Ur1791Zz5sxR+fn56syZM6pbt27q/PnzlR732n1d\nW4/jxo1TaWlp6sqVK2rYsGFqxYoVSiml/vrrLxUUFKSOHTum9Hq9WrNmjQoJCVHFxcUV9p2fn68G\nDBig1qxZo/R6vTp58qQKDAxUZ8+erXDsW+03JCREPfDAAyojI0OlpqaqoKAgFRERoU6dOqWKi4vV\nlClT1CeffKKUUiolJUUFBASoXbt2KaWUioqKUgEBASozM1MppdSkSZNUaGioOnfunCoqKlKTJk1S\n77777g0/x9mzZ6vPPvtMKaVUUVGROnToUKV1GRcXp3x9fVVUVJTSarXqiy++UKGhoaqkpMRw3JUr\nV1Za9lo//fSTmjx5crltX331lXrsscfKbXv88cfVV199pZRSqmfPnuro0aOG544fP678/PyqVPZa\nBQUFys/PT507d86wbcyYMWrTpk1KKaXeeOMN9cQTT6js7GyVl5enZs6cqd577z2llFJHjx5VPXv2\nNPwPpqamqri4uErfe1ZWlurVq5dat26d0ul0asOGDapXr14qKyvL8PqQkBB19uxZpdPpVE5OjvLz\n81MJCQlKKaXS0tIM59H13nrrLbVgwYJy2yr7brjq3LlzKioqSpWUlKjMzEw1adIktWjRIqWUUsXF\nxSokJER9++23SqvVqi1btqguXbrc1nkbHh6uUlJSVFFRkeGYr7/+uvr3v/9daTy1qdG1JKpCo9Hw\n1FNPYWFhQZ8+fbC2tmb48OE4Ojri5uaGv78/J0+eBODHH39kxowZtG3bFhMTE2bMmEFMTAwXL168\n4f7VNb9WNmzYwMMPP0yLFi2wtrZm9uzZbNq0qdLr3WZmZqSlpZGUlISpqSk9e/asdJ+VMTc3Z+bM\nmZiamnLfffdx+fJlpk6dirW1Nd7e3nh7e9/0135lpkyZgrOzM/b29oSEhHDq1CkAVq5cyYQJE+jW\nrRsajYbw8HAsLCw4evRohX3s2LEDT09PwsPD0Wg0dOrUiSFDhvDrr79WeG1V9jtp0iSaNWuGq6sr\n/v7+9OjRAx8fH8zNzQkNDTXEuG7dOoKDg+nXrx9Q+gu9a9eu7Ny507Cv0aNH06pVKywsLBg2bJih\n7FXX1rmZmRnJycmkpqZiYWGBn59fpXW2efNmgoODCQoKwtTUlEceeYTCwsJyv5DvVH5+PnZ2duW2\nNWnSxPCL9/rn7ezsyM/Pr1LZa1lZWTFo0CDDr/SEhATi4+MNl25WrVrFvHnzsLOzw8bGhhkzZhhe\nu2rVKsaOHUtQUBAArq6utG3bttL3ExkZSZs2bRgxYgQmJiYMHz6cdu3asWPHDsNrIiIi8PLywsTE\nBFNTU0xNTTl9+jRFRUU4Ozvj5eVV6b5zcnKwtbWtsP2pp54iICCAXr16ERAQwMqVKwFo1aoVQUFB\nmJmZ4ejoyNSpUzlw4ABQ2iLV6XRMmjQJU1NTQkND6d69u2GfVTlvp0yZgpubW7mWja2tLdnZ2ZXG\nX5vMjB1AXeXk5GT428rKCmdnZ8NjS0tLwz/XhQsXWLhwIW+99RZQdr06NTUVDw+PWx7n0qVLNG/e\n3PC4RYsWaLVa0tPTcXV1LffaRx99lI8//pjp06ej0WgYN25cpdemK+Pg4IBGozG8n8re49X3VFXX\nlre2tiYtLQ0orZO1a9caRucopdBqtVy6dKnCPi5cuMCRI0cICAgwvFan0xEeHl7pa2+132tjsrS0\nrPD42s9t8+bNhi+cq/u6+uUFlPvMra2tb1o/c+fO5YMPPmDs2LGGfoPKRs5c/3lrNBo8PDxITU29\n4b6rysbGpsKXem5uLk2aNKn0+dzcXGxsbKpU9nrDhw/n7bff5sknn2TDhg0MHjwYCwsLMjMzKSgo\nKPfe9Xq9IaGmpKQwYMCAKr2f6+sKoHnz5uXqyt3d3fC3tbU177//PkuXLmX+/Pn07NmTuXPn0q5d\nuwr7tre3Jy8vr8L2Tz/9lN69e1fYnpmZyRtvvMHBgwfJz89Hp9Ph4OAAlPZvuLm5lXv9tf/7VTlv\nr30fV+Xl5WFvb19he22TJHGX3N3deeKJJ6rUAX71S/parq6u5fowkpOTMTMzK/cFdZWNjQ3/+Mc/\n+Mc//kFsbCyTJ0+me/fu9O7dG2trawoKCgyvTUtLq/TEqw3u7u7MnDmTxx9//Jav9fDwIDAwkKVL\nl1brfqty3PDwcBYsWHDbZSv7HJ2cnAx9VYcOHWLatGkEBATQsmXLcq9zdXXlzJkz5bZdvHixWj6r\n9u3bs2zZsnLbTp8+behD8fb2JiYmhm7dugGlfUrt27e/adnKRlgB3HvvvcybN4+YmBg2btzI/Pnz\nAXB0dMTa2poNGzZU+JEDpZ/htX1n17q+Xl1dXdmyZUu5bRcuXKB///43LNO3b1/69u1LcXEx77//\nPq+++irLly+vcKyOHTuydu3aCttv1CJ/99130Wg0bNiwAXt7e7Zu3WroN3FxcamQ5C9evEirVq0M\n7/lOztu4uDhGjhx5W2VqglxuqsStLt1ca+LEiSxevJizZ88Cpc3Yyi6VADRr1gwTExMSExMN24YP\nH86yZctISkoiLy+P999/n+HDh2NiUvGjiYyMNJS1sbExNK+htON6w4YN6PV6du3aZWgKVwdnZ+cb\n/mNXZvz48axYsYJjx44BpZcydu7cWekv8eDgYOLj41m7di1arZaSkhKOHz9u6My90/3eysiRI9m+\nfTu7d+9Gr9dTVFTE/v37q/SLvrLP8ddffzWUtbe3x8TEpNLPcNiwYURGRrJv3z60Wi1Lly7F0tIS\nX1/fKsWt1+spLi5Gq9WW+xsgICAAExMTvv32W4qLiw2/XAMDAwEIDw9n2bJlpKamkpqayrJlyxg9\nevRNy1b2qxrA1NSUsLAw3n77bbKzs+nbty+AoYW7aNEiMjMzAUhNTTV0wI4dO5bVq1ezb98+lFKk\npqYaPuvrz7MBAwZw7tw5Nm7ciE6nY9OmTcTFxRESElJpTBkZGWzfvp2CggLMzMwM/yOV6du3LydO\nnCjXoX4zeXl52Nra0qRJE1JTU8v9qPH19cXU1JTly5ej0+nYunWr4RyFOztvi4uLOXHihKFejanR\nJYnKfgXe6jU3ezx48GAee+wxnn/+efz9/Rk5ciR//PFHpfu1srJi5syZTJw4kYCAAI4dO8bYsWMZ\nNWoUkyZNIjQ0FGtr6xsOu0tISODhhx/mnnvuYeLEiTz00EP06tULgJdffpnt27fTq1cvNm7cyODB\ng+/qPV5r7NixnD17loCAAJ5++ulbvr5r1668/vrrLFiwgICAAMLCwgyjaK5na2vLl19+yaZNm+jX\nrx/9+vXj3XffrfSf91b7vZ335O7uzqeffsrixYsJCgoiJCSEL7/80vAD4WZlK/scjx8/zrhx4/Dz\n8+Opp57i5ZdfpkWLFhXKtm3blnfeeYfXX3+doKAgIiMj+eyzzzAzM7vlcQHWrl1L9+7dWbBgAYcO\nHaJHjx6GG0HNzc359NNPWbNmDQEBAaxevZpPP/3UsO8JEyYQEhLCyJEjGTlyJCEhIYwfP75KZSsz\nfPhw9u7dy7Bhw8olxBdeeIHWrVszfvx4/P39mT59OgkJCQB0796dRYsWsWjRInr27MmUKVMM/XdT\npkzh119/JTAwkIULF+Lg4MBnn33G0qVL6d27N0uXLmXx4sU0bdq00rrS6/V89dVX9O/fn969e3Pg\nwAFee+21SmN3cnKid+/ebN26tdz2J554wnBPlZ+fH7NmzQLg6aef5q+//sLf35+ZM2cSFhZmKGNu\nbs7HH3/MypUr6dWrFxs2bGDgwIGG/oXbPW+hdARcYGBgpaPYaptG3c7P5ts0f/58IiMjcXJyYv36\n9QDExMTw2muvUVRUhJmZGa+99pqh+SuEELUlNjaWl156ydA5XZ3Gjx/PxIkTiYiIuKPyDzzwAAsX\nLrzpPTO1pUaTxMGDB7G1tWXu3LmGJPHII48wbdo07r33Xnbu3MmSJUv49ttvayoEIYSocQcOHKBt\n27Y4Ojqybt06/vWvf7F169ZK+xbrmxrtuPb39yc5ObncNo1GY7ihKScnp8KoACGEqG/i4+N57rnn\nyM/Pp1WrVnz00UcNIkFADbckoHS0zsyZMw0tidjYWB599FHDnc8rVqyo0lBRIYQQta/WO65/+OEH\nXn75ZSIjI5k3b55h6JwQQoi6p9aTxC+//GIYeTN06NByQ8VupoYbPEIIUe9l5xVz/Gx6tX5f1vjN\ndNcH6+bmxv79+wkICGDv3r20adOmSvvRaDSkpeXc+oWNgIuLndTF/0hdlJG6KNPY6iL9SgFb9p9n\n17ELFJfo+ee0XrRyK51mxcXF7halb65Gk8ScOXOIjo4mKyuL4OBgZs2axeuvv84bb7yBXq/H0tKy\nVtZxEEKIhijjSiG/7I5j71+p6JWimb0lw4Jb09K18ulU7kSNd1xXp8b0y+BmGtuvpJuRuigjdVGm\noddFXmEJG/eeY+vBJLQ6Pc2dbRkW2IrAzm6YmZbvRajTLQkhhBDV61RCJv9de4LcghKa2VsS0a8d\nQV3cMTG59WwSd0KShBBC1BORR5JZvuU0AGODvRjc0xML88rnp6oukiSEEKKO0+n1/LQ9lt8PnqeJ\ntTlPj+5Gh5YOtXJsSRJCCFFHXczIY/fxi+z9K4Ws3GI8nGx4dlwPXB2say0GSRJCCFHHFBRp+WL9\nSY6cTQfA2tKMgX4tGN3fCxur2v3aliQhhBB1SGZ2IR+sPEZSWi7enk0Z5OfJPe2da7zv4UYkSQgh\nRB1x/lIuH6w8yuWcIkL8WvDQ4A41NmqpqiRJCCGEESmlOJt8hT3HU4g+mUpRiY7xId6EBbSs0iJp\nNU2ShBBCGMmhvy+xMjKWS5dL16d3tLNk+vBO9PKpuD64sUiSEEKIWqbXK9b8EcfGvecwNzMhqIsb\nfbp50KmVo9EvL11PkoQQQtSi/MISFq87yfG4DFwdrZk1uhstXKpvrqXqJklCCCFqScaVQt798Qgp\nmfl0bdeMx0d2wdbK3Nhh3ZQkCSGEqAUX0vN498cjXM4pYkivlowP8a5zl5YqI0lCCCFqWNyFbD5Y\neZTcghLGhXgxLLC1sUOqMkkSQghRA5RSJKTksPvYRfb8dZESrZ5pw3zo16O5sUO7LZIkhBCimh2L\nzWBV5FmS0vIAcGhiweQRHbmng4uRI7t9NZok5s+fT2RkJE5OTqxfv96w/dtvv2X58uWYm5szYMAA\nXnjhhZoMQwghas3Wg+f5YdsZTDQaenZ0oV93D7q0bYapicmtC9dBNZokRo8ezeTJk5k7d65hW3R0\nNDt27GDDhg2YmZmRmZlZkyEIIUSt0OsVK7afYevBJOxtLXh2bHfaetgbO6y7VqNJwt/fn+Tk5HLb\nfvjhBx577DHMzEoP3axZs5oMQQghaoRerzgWm0Fyei5pWYWcS83hXEoOzZ1teW5sd5xrcTrvmlTr\nfRIJCQkcPHiQ999/H0tLS+bOnUu3bt1qOwwhhLhjWp2epRtPEX0ytdz27l5OzBjRGZs6fu/D7aj1\nJKHT6cjOzuann37i2LFjPPfcc2zbtq1KZe92Qe+GROqijNRFGamLMjVVF0UlOv799QEOnkqlU5tm\njB3UHrdmNrg52mBl2fDGAtX6O3J3d2fIkCEAdO/eHRMTEy5fvoyjo+Mty6al5dR0ePWCi4ud1MX/\nSF2UkbooU1N1UVCk5eOfjxGTmEWXts14OqIblhal6zzkZBdQF2v/bpNljXe3K6XKPR48eDB79+4F\nID4+Hq1WW6UEIYQQxhR/MZvXvz5ITGIWPTu68MyY7oYE0ZDVaEtizpw5REdHk5WVRXBwMLNmzWLM\nmDHMmzePESNGYG5uzltvvVWTIQghxF3R6fVs3HuO9XsS0OkVQ3q1ZFyIV70d0nq7NOr6n/p1mDSl\nS8llhTJSF2WkLspUV12kXylg8boTxCZn42hnySPDO9G5Tf0akXm3l5saXi+LEEJUgyNn01m64SR5\nhVoCOrkyOaxjnZ+xtSZIkhBCiGsUlehY+0c8v+5PxMzUhKlDO9K/R/M6sZSoMUiSEEIIILeghO2H\nkth6KIncghLcmtnwxKgutHJr3MOKJUkIIRo1vV6xYW8Cm/ado7hEj62VGff3acOwwFZYN8D7Hm6X\n1IAQotG6klfM5+tOcOrcZRyaWDCmf2v69fDAykK+Gq+SmhBCNEp/J17ms7UnuJJXjK+3M4/c36lR\ndkzfiiQJIUSjUqLV88vuOH6NTkSDhnEhXgwNaNVoO6ZvRZKEEKLRSEzNYcmGkySl5eHc1IrHRnSm\nvaeDscOq0yRJCCEavBKtno17E9i49xw6vWKAb3PGh3hLx3QVSA0JIRq00+ez+PrXGC5m5ONoZ8nU\noT5093Iydlj1hiQJIUSDlH6lgLV/xLPnrxQ0wCA/T0YPaCeth9sktSWEaFCu5Baxenc8m6NKJ+Tz\ndLFlylAfvFs0NXZo9ZIkCSFEvafT6zkRn8nu4ykcOZOGVqdwcbAi/N52BHZ2w8RERi7dKUkSQoh6\nq6hYx/Y/k9hy8DxXcosBaO5sS/gAL3zbNcPMtHFM512TJEkIIeqdohIdO/5MZnP0OXLyS7C2NCPE\nrwX3dvOgjbsdrq72Mm16NZEkIYSoV1Iz8/lw1TFSMvOxtjRlZN82DOnVEhu5W7pG1GhbbP78+fTp\n04cRI0ZUeG7p0qX4+PiQlZVVkyEIIRqQUwmZvPHNQVIy8xnU05O3ZvYhvF87SRA1qEaTxOjRo1m6\ndGmF7SkpKURFRdG8efOaPLwQooEo0erYevA87/10lMJiHdPu8+Gh0A40sZbkUNNq9HKTv78/ycnJ\nFbYvWrSIuXPn8sQTT9Tk4YUQ9ZhSiuNxGUSfvMSRs2kUFOloYm3O06O70aGlTKVRW2q9T2L79u14\neHjQsWPH2j60EKIeWRUZy+boRACc7K0Y4NuCQX6eODW1MnJkjUutJonCwkI+++wzvvzyS8M2pVSV\ny9/tgt4NidRFGamLMg2lLk7EZfDr/kQ8nG2Z/aAfHVs53vYsrQ2lLoytVpNEYmIiycnJjBo1CqUU\nqampjBkzhpUrV+LkdOu5VGRIWykXFzupi/+RuijTUOqiqFjHe98fAgXThvrgZGNOenrube2jodRF\ndbjbZFnjSeLalkKHDh3Ys2eP4fHAgQNZs2YNTZvK7fJCiFKrdsZy6XIBQwNa4e0p3w3GVqOjm+bM\nmcOECROIj48nODiYn3/+udzzGo3mti43CSEatphzl9l2KAkPJxvC+7U1djiCGm5JvPvuuzd9ftu2\nbTV5eCFEPXL4TBpfbjyFRgOPDO+MhbmpsUMSyB3XQggj0+r0rIqMZcuB85ibmTD9vk60a25v7LDE\n/0iSEEIYTWJqDl//+jfxF7Nxb2bDE+FdaenaxNhhiWtIkhBC1LrcghLW7Ioj8kgySkFQFzcmh3XE\nykK+kuoa+USEELXm6uytG/cmkFeoxcPJhomD29O1rSwnWldJkhBC1LgSrY7IIxfYtPccV/KKsbY0\nY8JAbwb29JQ1H+o4SRJCiBqhV4qzSVfYdzKVA6dSySvUYmlhyv19WhMW0Apbmbm1XpAkIYSodhcz\n8vho1TFSLxcA0NTWgvt6tyYsoCV2NhZGjk7cDkkSQohqdTmniPd+PEJGdhG9u7jRt6sHnVo7yjrT\n9ZQkCSFEtckvLOG9n0oTRET/dozo08bYIYm7JD1GQohqUVSi46NVx0hOy2OQnyf3B7U2dkiiGkhL\nQghxx/ILtRw9m87hs+n8FZdBYbGOXj6uTBzc/ran9hZ1kyQJIcQdOZt8hU9+PkZ2fgkALg5WDOrp\nyci+baX/oQGRJCGEuG37TqTw5aYY9HrF/X1aE9jJjebOttJ6aIAkSQghqkwpxdrd8azbk4C1pSlP\nhHeTu6UbOEkSQogqUUqx/PfTbP8zGeemVjw7rgctnG2NHZaoYZIkhBC3pFeK5VtOs+NwMp4uTXhh\ngi/2tnJTXGNQo0li/vz5REZG4uTkxPr16wF4++232bFjBxYWFrRq1Yo333yTJk1kamAh6iq9Uny3\n5TSRh5Np6VqaIOSu6cZDo2pw/dCDBw9ia2vL3LlzDUkiKiqK3r17Y2Jiwn/+8x80Gg1z5syp0v5k\nYfNSssh7GamLMtVVF4XFWtbtTiD2whWy84rJzi+moEhHK9cmvDDxHppY1/05l+S8KOPiYndX5Wu0\nJeHv709ycnK5bX369DH87evry2+//VaTIQghbkNs8hW+WH+SS1kFaDRgZ22Ok70VLVya8FBoh3qR\nIET1MmqfxKpVqxg+fLgxQxBCAHq9Yn1UAuv3JKCUYlhgK8L7tcPcTCZlaOyMliT++9//Ym5uzogR\nI6pc5m6bTQ2J1EUZqYsyd1IX+YUlvPPdIQ6eSsXZwZrZE/3o5u1cA9HVLjkvqodRksSaNWvYuXMn\n33zzzW2Vk2uMpeR6axmpizJ3UheZ2YV8uOoY5y/l0qVtM2aO6oKtlXm9r1M5L8rU6T4JKB1bfa1d\nu3axZMkSvvvuOywsZISEEMZyLiWHD1cdJSu3mGDf5jwY2kFWiRMV1GiSmDNnDtHR0WRlZREcHMys\nWbNYvHgxJSUlTJ8+HYAePXrwz3/+sybDEEJc5+/Ey3y46hhFxTrGh3gTFtBSptQQlarRIbDVTZqP\npaQpXUbqokxV6+LI2XT++8tf6PWKx0Z0JqCTWy1EV7vkvChT5y83CSGMK7eghKycIrLzi0lMzWVV\nZCxmphqeGdudbu1k3iVxc5IkhGjA/jh6gW9++xudvuyCgY2lGc+O6057TwcjRibqC0kSQjRQx2Iz\n+PrXv7GxMqNXJ1fsbSywszGnWzsnXBysjR2eqCckSQjRACWkZPPfX/7C1FTDs2O749WiqbFDEvWU\njHcTooFJSsvlg5XHKC7RMWNEF0kQ4q5IS0KIBuDvxMus3h3PgRMppF4uAOCh0A707Ohi5MhEfSdJ\nQoh6rKhEx4/bzxJ5uHQiTUsLU+5p70xAJzcCOze8oa2i9kmSEKKeSrqUy2frTnAhPQ9PF1seH90d\nN3tLuWtXb7olAAAgAElEQVRaVCtJEkLUMyVaPb/tT2TdngS0Oj2D/DwZF+JFi+YOcgOZqHaSJISo\nR04mZPLdltOkZOZjb2PO1GFduKe99DuImiNJQoh6oLBYy/LfT7PneAoaYKBfC0b3b4eNlSwCJGqW\nJAkh6rjE1Bz++8tfpF4uoLWbHVOHdaSNu72xwxKNhCQJIeoopRTb/0zmx+1n0OoUQwNaMXpAO+mY\nFrVKkoQQdZBSip93xrFp3zmaWJvz6P2d6O5V/1eLE/WPJAkh6hi9UqzYeoath5Jwc7TmhQn34NTU\nythhiUbqlkni/PnzrFq1iujoaFJSUrC0tMTHx4ewsDCGDBmCmdmNdzF//nwiIyNxcnJi/fr1AFy5\ncoXnn3+e5ORkPD09+eCDD7Czk7VoReNUotVzObeIgkItCoVSsPNIMruOXqSFiy0vPOBL0yaWxg5T\nNGI3XXTo//7v/zhx4gRDhw7lnnvuwdnZmaKiImJjY9m9ezcnT57kn//8J76+vpWWP3jwILa2tsyd\nO9eQJN555x0cHBx47LHH+Pzzz8nOzuaFF16oUrAyBryULKhSpj7WRVpWAd/8GsP5S7lk55dU+prW\nbnbMmeBLE+uqj16qj3VRU6QuytTookODBg1iwYIFFbZ37NiR++67j6ysLM6fP3/D8v7+/iQnJ5fb\ntm3bNr777jsAIiIimDx5cpWThBD1XWzyFT76+Rg5+SW4OlrTwqUJjnaW2FiZYaLRoNGAjZU5g/xa\nyPBWUSfcNEkMGDDgpoUdHBxwcLi9hUsyMzNxdi7tgHNxceHy5cu3VV6I+upAzCWWbDiJTqeYPKQD\nIX6exg5JiFuq0li6f//73+Tk5KDVannwwQfx9fVl7dq1NR2bEPVadl4xO48k882vMfxr2YHS9R1M\nNDw7rrskCFFvVGl0U1RUFC+99BKRkZG4ubnx/vvvM2PGDEaNGnXbB3RyciI9PR1nZ2fS0tJo1qxZ\nlcve7bW1hkTqokxdrIviEh3zv9hHSkY+AGamJnRq04wnx/agjUfN3QhXF+vCWKQuqsdtDYE9cOAA\noaGhuLm5odFoqlTm+n7xgQMHsnr1ambMmMGaNWsYNGhQlY8vHVGlpFOuTF2ti1+jE0nJyCeoiztD\nerWkhYut4Sa4moq3rtaFMUhdlLnbZFmly01OTk688sorbNq0ib59+6LVatHpdLcsN2fOHCZMmEB8\nfDzBwcH8/PPPzJgxg6ioKMLCwti7dy8zZsy4qzcgRF2TW1DChqgEbK3MeDC0Pa3d7eQuaVFvVakl\n8e6777Ju3TrGjh1L06ZNSUpKYtq0aVUqV5lly5bdVpBC1CcbohLIL9LywEBvbGWEkqjnqpQkmjVr\nxsMPP2x47OnpiaendLwJcb20rAK2/5mEc1MrBkrntGgAbtoGfvLJJzl27Filz+Xm5vL111/z448/\n1khgQtRHa3bFodUpRvdvh7mZXGIS9d9NWxLPPPMM7777LgkJCXTv3h0nJyeKioqIi4sjOTmZCRMm\nMHHixNqKVYg6Sa9XHI1NZ+vBJE6du0xrdzsCZH1p0UDcNEn4+PjwxRdfcPHiRfbv309qaiqWlpYM\nHTqUnj17YmFhUVtxClEnnU26wpINJ7mUVQBAp9aOPBTaAZMqjv4Toq6rUp+Eh4fHHd0TIURDlltQ\nwqe/HCc7r4T+PTwY3LMlnq5NjB2WENWqShdNMzIyeOGFF3jooYcAiImJ4YcffqjRwISo677b8jdZ\nucWE92vLw8M6SYIQDVKVksQrr7xCz549yc7OBqBdu3Z8//33NRqYEHXZvpMp7D91Ce8WTRnWu5Wx\nwxGixlQpSaSmpjJx4kRMTU0BsLCwwMRERm6Ixikzu5DvfjuNpbkpj97fCVP5XxANWJXO7usXFsrO\nzq4w3YYQjUFqZj6frD5OfpGWCYO8cXW0MXZIQtSoKnVcDxkyhP/7v/8jLy+P1atX8/333zNmzJia\njk2IOkOr0/Pb/kTW7k5Aq9PTp6s7/Xs0N3ZYQtS4KiWJRx99lHXr1pGdnc3OnTuZPHmyjHYSjYJS\nimOxGfy8M46ktFzsbS14KLQD/h1dqjzJpRD1WZVngR05ciQjR46syViEqDZanf6uJtXT6vT8FZfJ\nuj3xJKSUziZ6b3cPmY9JNDpVShIZGRl89913JCYmotVqDds//PDDGgtMiDuVmJrDv5Yd4KmIbvh1\ncKlSGaUUv/wRz7HYDC7nFBrWntYAvXxcGdG3DZ4uMsRVND5VShJPPvkknTt3JigoyDDCSYi66kzS\nFZSCY7EZVU4S6/YksD4qAXMzE5rZWdLc2Ra3ZjYM6ukpyUE0alVKEgUFBbz22ms1HYsQ1eLS5dIp\nMs6lVm3RmX0nU1i7Ox7npla8MsUfe1uZbkaIq6qUJHr06MHff/9Nx44dazoeIe5a2v/mUUpOy63Q\nN/H7gfOsj0qgh7cTfbt6YGZmwpcbY7C2NOXZsd0lQQhxnSoliQkTJjBp0iTc3d2xtLQ0bF+1atUd\nH3jZsmWsWrUKjUZDhw4dePPNN2XCQFEtrk62p9UpLqTn0cqtbPnGvSdSyC0oYc/xFPYcTwFAo4Fn\nRvWghVxWEqKCKiWJF198kZkzZ9K5c+dq6ZNITU3l22+/ZfPmzVhYWPDcc8+xadMmwsPD73rfonHT\nK2VoSUDpJaerSaKwWEtiai5eze0ZG+zF7uMXOR6bQXj/dnRt52SskIWo06qUJCwtLXnkkUeq9cB6\nvZ6CggJMTEwoLCzE1dW1WvcvGqcrucWUaPU42VuRkV1IYmqu4bnYC9nolaJDSwc6tnKkYytHI0Yq\nRP1QpYHk/fr1Y9euXdV2UDc3N6ZNm0ZwcDD9+/fHzs6OPn36VNv+ReN16XI+APd0cMZEoynXeX3m\nfBYA7T0djBKbEPVRlVoSP/30E59//jm2trZYWFiglEKj0bB37947Omh2djbbtm1jx44d2NnZ8cwz\nz7B+/XpGjBhxR/sT4qqr/RGeLk3wcLLhfGoueqUw0Wg4k3QFAG/PpsYMUYh6pUpJ4ueff67Wg0ZF\nRdGyZUscHEp/0YWGhnL48OFbJgkXF7ubPt+YSF2UubYu8or1ALRv04xzl3JJPpRECRrcm9kSdzGb\nVu52tG3VzFih1jg5L8pIXVSPKiWJFi1aVOtBmzdvztGjRykqKsLCwoJ9+/bRrVu3W5ZLS6vauPeG\nzsXFTurif66vi4Tk0ktKFoCbgzUAh0+l4OZoQ1GxjnYe9g227uS8KCN1UeZuk+VNk8SLL77IO++8\nw5gxYyqdzOxOh8B2796dsLAwwsPDMTMzo3PnzowfP/6O9iXEtdKyCjAz1eBoZ0lrt9IhrYmpuVzJ\nLQagvVxqEuK23DRJXLp0CYB//OMf1X7gp59+mqeffrra9ysat7SsQpyaWmNioqGla+kvqHMpOVhb\nlp7qHaTTWojbctMkcXW50oCAgFoJRoi7kV+oJbeghLYe9gDYWJnh6mBNYmoOJiYamtlb4tTUyshR\nClG/yLqLosG4ehOd6//6IgBauduRV6glJ79EWhFC3IGbtiROnz5NUFBQhe13OwRWiJpwdfiri2NZ\nkmjt1oSDMaWXTaU/Qojbd9Mk0aZNGz7//PPaikWI25KTX4xdcdn6JldvpLu2JdH6mnmb2reUloQQ\nt+umScLCwqLah78KUR2KS3S8siQab08HZo0uHT6dVklL4uq8TbZWZjR3tq39QIWo526aJMzNZZlG\nUTfFJGaRk1/C4dNpnEvJobW7nWEdCZdrOqftbS3o29UdFwdrTGRNaiFu2007rn/66afaikOI23Is\nNt3w9+8HzwOlLQlHO0sszMvPVPzI/Z0ZeW/bWo1PiIZCRjeJekcpxbHYDKwtzWjh0oTok6mkXykg\nM7sIl2v6I4QQd0+ShKh3Lmbkk36lkC5tmzGqfzt0esWqyFgU4OIg90EIUZ0kSYh651hsBgA9vJwI\n6dkSWysz9p8qHebqKi0JIaqVJAlR71ztj+jazgkrSzMG+JaNwLt2ZJMQ4u5JkhD1Sn6hljNJV2jr\nYUdT29I10Qf6tcDUpHTkkquDjTHDE6LBkSQh6pWTCZno9Ipu16xJ3czeij5d3bG2NMPDSZKEENWp\nSutJCFFXGPojvJ3LbZ8ytCMPDGxvmO1VCFE95D9K1Bt6pTgel4G9jTmt3csvpGJqYoKNlTSMhahu\n8l8l6o0T8ZlcySumWzsnuXtaiFpitJZETk4OL7/8MmfOnMHExIRFixbRo0cPY4Uj6jCtTs+GqAQ2\nRJ1DAwR1dTd2SEI0GkZLEgsXLmTAgAF89NFHaLVaCgsLjRWKqMMupOexdONJ4i/m4GRvySPDO+PT\n2tHYYQnRaBglSeTm5nLw4EH+/e9/lwZhZkaTJk2MEYqoo/ILtazbE8+2Q0no9IqgLu48FNoBGyvp\nRhOiNhnlPy4pKQlHR0fmzZtHTEwMXbt25eWXX8bKSqZUaOwKirTsO5nK2j/iyM4vwbmpFRMHteee\nDi7GDk2IRkmjlFK1fdC//vqLBx54gBUrVtCtWzcWLlyInZ0dzzzzTG2HIuoApRSH/05j+8Hz7P3r\nIsUlOqwsTBk3qAPhA7wqzOoqhKg9RmlJuLu74+7uTrdupYvFhIWFsWTJkluWS0vLqenQ6gUXF7sG\nUxdZuUV8ufEUf8VnAuDmaE1QV3f6dW+Oo50lV7Lyb1q+IdXF3ZK6KCN1UcbFxe7WL7oJoyQJZ2dn\nPDw8iI+Pp23btuzbtw8vLy9jhCKM6PDpNL7aHENuQQld2zVjVN+2tGtuj0aGtwpRZxitF/CVV17h\nhRdeQKvV0rJlS958801jhSJqWVGxjhXbz7DzyAXMzUx4KLQDA/1aSHIQog4yWpLw8fHh559/Ntbh\nhZEkpGSzeN1JUjPz8XRpwuMjO9PCRUa2CVFXyXhCUSu0Oj2/7U/klz/i0ekVYQEtGd3fC3Mzuelf\niLpMkoSoUVeXGv1x+1lSMvNp2sSCR4d3pkvbZsYOTQhRBZIkRI1JTM1hVWQsf8VnotFAiF8LIvq1\no4m1ubFDE0JUkSQJUe0SU3NYtyeBP0+nAdCljSMPDGqPp/Q9CFHvSJIQ1SY9q4CVkbEciCldb9qr\nuT2j7m1Ll7bNZOSSEPWUJAlx1wqKtGzad47f9p9Hq9PT1sOO8H7t6CrJQYh6T5KEuGNKKfaeSGFl\nZCxXcotxtLNkbLAXgZ3dZL0HIRoISRLijpxLyWH576c5m3wFczMTRvZtw7DA1lhayDxLQjQkkiTE\nbSko0rJmVxzbDiWhgJ4dXHhgoDfODtbGDk0IUQMkSYgqOxabwbe/xZCRXYRbMxsmDelAlzZyv4MQ\nDZkkCXFLl3OK+GnHWaJPpmJqouH+Pm0Y0ac15mZyaUmIhk6ShLghrU7P1oNJrN0TT1GxjrYe9kwb\n5oOnq9zvIERjIUlCVKCU4sjZdH7eGceF9DyaWJszYag3/Xo0l1FLQjQykiREOafOXWb1zlhiL2Sj\n0cAA3+aMGeAlU2kI0UhJkhBA6Qpxy7ec5tD/ptLo2cGF8P7taOFsa+TIhBDGJEmikVNK8cexi/y4\n/SwFRVraezZlwqD2tPWwN3ZoQog6wKhJQq/XM2bMGNzc3Pjss8+MGUqjdC4lhxXbzvD3+SysLEyZ\nHNaRAb7S7yCEKGPUJPHNN9/g5eVFbm6uMcNodDKzC1mzK46ov1JQgK+3M5OGdKCZvZWxQxNC1DFG\nSxIpKSns3LmTmTNn8tVXXxkrjEalqETHb9GJbNp3jmKtHk+XJjwwyFtuiBNC3JDRksSiRYuYO3cu\nOTk5xgqh0VBKcSDmEj/tOEtmdhFNbS14KLQdfbt5YGIil5aEEDdmlCQRGRmJs7MznTp1Ijo6usrl\nXFzsajCq+qUqdaGUYv+JFFZsPc3Z81mYmZowdmB7xg1qj41VwxnSKudFGamLMlIX1UOjlFK1fdD3\n3nuPdevWYWpqSlFREXl5eYSGhvL222/ftFxamrQ6oPTkv1VdHDmTzpo/4jh/KRcN4O/jypgB7XB1\ntKmdIGtJVeqisZC6KCN1UeZuk6VRksS19u/fz5dfflml0U3yoZe62T/A5Zwilv9+mj9Pp6HRQGAn\nN4b3adNg73eQL4MyUhdlpC7K3G2SkPskGgi9Uuw8coFVkWcpKNLRoaUDU8I60ryBJgchRO0wepII\nCAggICDA2GHUa38nXuaHbWdITM3F2tKMqUM7yjxLQohqYfQkIe7cpawCVu44y6G/S6fSCOrixthg\nbxztLI0cmRCioZAkUQ/lF5awKjKWLQcS0eoUXi3smTioA+2ay1QaQojqJUmiHlFKEfVXCqt3xXE5\npwhHO0vGhXgR2MkNjVxaEkLUAEkS9URyeh7f/vY3p89nYWFuyqh72zI0sBWW5rI6nBCi5kiSqOOK\ninWsj0rgt/2J6PSKe9o78/QD96DR6owdmhCiEZAkUUcppYg+mcrKyFgu5xThZG/FQ6Ed8G3vjIuj\njYwBF0LUCkkSddC5lByW/36as8lXMDM14f4+bRjeuzWWFnJpSQhRuyRJ1CG5BSWs3hXHzsPJKKBn\nRxceCPHG2cHa2KEJIRopSRJ1gFanZ/exi6zeFUduQQkeTjY8FNqBzjKFtxDCyCRJGJFWpyfqrxQ2\nRCWQfqUQSwtTxod4M9jfEzNTE2OHJ4QQkiSM5fCZNH7Yeob0K4WYmZowuKcn9wW1xqGJ3C0thKg7\nJEnUsit5xXz/+2kOxFzC1ETDoJ6e3Ne7tUylIYSokyRJ1BK9XvHHsQusiowlr1CLVwt7Hh7WqcFO\n4S2EaBgkSdSC0+ez+H7raRJTc7G0MOWh0A6E+LWQWVqFEHWeJIkalJZVwM87Y9l/6hIAQV3cGRvs\nJZeWhBD1hiSJGpBXWMLGqHNsPXQerU7R1sOeBwe3x6tFU2OHJoQQt8UoSSIlJYW5c+eSnp6Oqakp\n48aNY8qUKcYIpVoppfjj2EVW7jhLXqEWJ3tLxgzwIqCzm1xaEkLUS0ZJEqampsybN49OnTqRl5fH\n6NGj6du3L15eXsYIp1qkXs7n680xxCRmYWVhyrhgLwb7e2JuJlNpCCHqL6MkCRcXF1xcXACwtbXF\ny8uLS5cu1cskUVCk5fcD59m47xwlWj2+3s5MGtKBZvZWxg5NCCHumtH7JJKSkoiJiaF79+7GDuW2\nlGh17Dh8gY17E8jJL8He1oJH7++Af0cXWQBICNFgaJRSylgHz8vLY/LkyTz55JMMHjzYWGHclqIS\nHb9Hn+PnHWdJzyrA2tKMiGBvRvVvh42VubHDE0KIamW0JKHVann88cfp378/U6dOrVIZY66hUFyi\nY/ufyfy6P5HsvGIszEwI8WvBfb1bY2djUauxuLjYyXoS/yN1UUbqoozURRkXF7u7Km+0y03z58/H\n29u7ygnCmOIvZrNkw0kuZuRjZWHK8KDWhPq3xN62dpODEELUNqMkiUOHDrF+/Xo6dOhAeHg4Go2G\n559/nv79+xsjnBvS6vSs25PApr3n0CvF4J6ejOrXFlu5rCSEaCSMkiR69uzJqVOnjHHoKjubdIVv\nfoshKS0PJ3srpg/vRKfWjsYOSwghapXRRzfVNbkFJayKPMuuoxcB6N+jOQ8M9MbaUqpKCNH4yDff\nNf48ncayzTHkFpTg6WLL5LCOtPd0MHZYQghhNJIkKL0h7oetZ9h9/CLmZiaMC/Ei1L+lrA4nhGj0\nGn2SOJmQybLNMaRfKaSVWxMeG9FF1ngQQoj/abRJIuNKIT9uP8PBv9PQAMODWjPq3rbSehBCiGs0\nuiRRotXz6/5ENkYlUKzV49XCnkmhHWntfnc3nAghREPUqJLEsdh0vt96hkuXC7C3tWBymBdBXd1l\nGm8hhLiBRpEk0rIKWLHtDIfPpGOi0RDq35JR97bFxqpRvH0hhLhjDfpbsqhEx+Z959i0LxGtTk+H\nlg5MCu2Ap2sTY4cmhBD1QoNMEkopDv2dxo/bz5KRXYhDEwvGD/QmsJObTOMthBC3ocEliYSUbFZs\nPcPppCuYmmgYFtiK+/u0kTumhRDiDjSYb84recX8HBnLnuMXUcA97Z0ZH+KNWzMbY4cmhBD1Vr1P\nElqdnu2Hkli7J56CIh2eLrZMGNSezm2aGTs0IYSo9+p1kvg78TLfbjnNhfQ8bK3MmDSkAwN8m2Nq\nIjfECSFEdaiXSSK3oISfdpxl97GLaIBg3+ZE9G9X6yvECSFEQ2e0JLFr1y4WLVqEUooxY8YwY8aM\nW5ZRShF9MpXvt54ht6CElq5NmDrUh3bN7WshYiGEaHyMkiT0ej2vv/46y5Ytw9XVlbFjxzJo0CC8\nvLxuWCbjSgEfrTrG0dgMLMxNGB/iTWgvT7m0JIQQNcgoSeLYsWO0bt2aFi1aADB8+HC2bdt20yTx\n1NvbySvU0qm1I1OH+eDqYF1b4QohRKNllCSRmpqKh4eH4bGbmxvHjx+/aRm9gqlDO9K/R3O5IU4I\nIWqJUZKEUuq2yyx9JZTCvKIaiEYIIcSNGCVJuLu7c+HCBcPj1NRUXF1db1rGzsZCRi9dw8VFpja/\nSuqijNRFGamL6mGUXt9u3bqRmJhIcnIyxcXFbNy4kUGDBhkjFCGEEDdhlJaEqakpr776KtOnT0cp\nxdixY2/aaS2EEMI4NOpOOgiEEEI0CnKTgRBCiBuSJCGEEOKGJEkIIYS4oTqfJHbt2sXQoUMJCwvj\n888/N3Y4tSolJYUpU6Zw3333MWLECL755hsArly5wvTp0wkLC+ORRx4hJyfHyJHWHr1eT0REBDNn\nzgQgKSmJ8ePHExYWxuzZs9FqtUaOsHbk5OTwzDPPMGzYMIYPH87Ro0cb7XmxbNky7r//fkaMGMGc\nOXMoLi5uNOfF/Pnz6dOnDyNGjDBsu9l58MYbbzBkyBBGjRrFqVOnqnSMOp0krs7xtHTpUjZs2MDG\njRuJjY01dli1xtTUlHnz5rFp0yZWrFjB8uXLiY2N5fPPPycoKIjffvuNwMBAFi9ebOxQa80333xT\nbiTcf/7zH6ZNm8Zvv/2GnZ0dq1atMmJ0tWfhwoUMGDCAzZs3s3btWtq1a9coz4vU1FS+/fZbVq9e\nzfr169HpdGzcuLHRnBejR49m6dKl5bbd6DzYuXMniYmJbNmyhQULFvDaa69V6Rh1OklcO8eTubm5\nYY6nxsLFxYVOnToBYGtri5eXF6mpqWzbto2IiAgAIiIi2Lp1qzHDrDUpKSns3LmTcePGGbbt27eP\nsLAwoLQufv/9d2OFV2tyc3M5ePAgY8aMAcDMzAw7O7tGe17o9XoKCgrQarUUFhbi6upKdHR0ozgv\n/P39sbcvPwv29efB1e/Mbdu2ER4eDkCPHj3IyckhPT39lseo00misjmeLl26ZMSIjCcpKYmYmBh6\n9OhBRkYGzs7OQGkiuXz5spGjqx2LFi1i7ty5hrm7Ll++TNOmTTH530zA7u7ujeL8SEpKwtHRkXnz\n5hEREcGrr75KQUFBozwv3NzcmDZtGsHBwfTv3x87Ozs6d+6Mvb19ozsvrsrMzCx3HmRmZgJw6dIl\n3N3dDa9zc3MjNTX1lvur00lCbuEolZeXxzPPPMP8+fOxtbVtlBMcRkZG4uzsTKdOnQznhVKqwjnS\nGOpGq9Vy8uRJHnzwQdasWYO1tTWff/55o3jv18vOzmbbtm3s2LGDP/74g4KCAnbt2lXhdY2xbq5X\n2fdpVeqlTq9MdydzPDU0Wq2WZ555hlGjRjF48GAAnJycSE9Px9nZmbS0NJo1a/jref/5559s376d\nnTt3UlRURF5eHosWLSInJwe9Xo+JiQkpKSmN4vxwd3fH3d2dbt26ATBkyBC++OKLRnleREVF0bJl\nSxwcHAAYPHgwhw8fJjs7u9GdF1fd6Dxwc3MjJSXF8Lqq1kudbknIHE+loxe8vb2ZOnWqYdvAgQNZ\nvXo1AGvWrGkUdTJ79mwiIyPZtm0b7733HoGBgfznP/8hMDCQX3/9FWg8deHs7IyHhwfx8fFAab+M\nt7d3ozwvmjdvztGjRykqKkIpxb59+2jfvn2jOi+ubyHc6DwYNGgQv/zyCwBHjhzB3t7ecFnqZur8\ntBy7du1i4cKFhjmeqrLMaUNx6NAhJk2aRIcOHdBoNGg0Gp5//nm6d+/Oc889x8WLF2nevDkffvhh\nhc6rhmz//v18+eWXfPbZZ5w/f57Zs2eTnZ1Np06deOeddzA3Nzd2iDUuJiaGl19+Ga1WS8uWLXnz\nzTfR6XSN8rz45JNP2LhxI2ZmZnTu3Jk33niDlJSURnFezJkzh+joaLKysnB2dmbWrFkMHjyYZ599\nttLzYMGCBfzxxx9YW1vz5ptv0qVLl1seo84nCSGEEMZTpy83CSGEMC5JEkIIIW5IkoQQQogbkiQh\nhBDihiRJCCGEuCFJEkIIIW5IkoSocwYOHMjZs2dr5ViffPJJuWmk582bx/Lly+96v/PmzWPEiBHM\nnj37rvd1MzExMWzevLlGjyEaN0kSolH75JNPKCkpqdZ9pqens2XLFtavX897771Xrfu+3smTJ+84\nSej1+mqORjREkiREvREfH89jjz3GuHHjCA8PN0w9AODj48PixYsZO3YsoaGhbNmyxfDcb7/9xrBh\nwxg9ejSLFy/Gx8eHgoICFixYgEajYcKECURERJCbmwvA6dOnmTp1KmFhYbz00ks3jOeXX35hxIgR\njBo1ilmzZpGZmUleXh5Tp06lqKiIiIgIvv7663Jl1q5dy9NPP214rNPp6Nevn2GOsiVLljB+/HhG\njx7NE088QUZGBgAlJSW89dZbjBgxgvDwcGbNmkVWVhYff/wx+/btIyIigoULFwKlsxREREQwatQo\npk2bxvnz54HSO9XDw8N54403mDBhAn/88cfdfByisVBC1DEhISHqzJkz5bZptVoVERGh4uLilFJK\n5ebmqrCwMMPjjh07quXLlyullDp06JDq16+fUkqp9PR0FRAQoBITE5VSSn311VfKx8dH5efnG8oV\nFGMkQHAAAAOxSURBVBQYjvPSSy+pBx98UBUXF6vi4mI1fPhwFRUVVSHG06dPq3vvvVelp6crpZT6\n4IMP1HPPPaeUUiopKUn17t270vdWUFCgevfurS5fvqyUUmr79u1q6tSpSiml1q5dq1599VXDa7//\n/ns1Z84cpZRSH3/8sZo1a5bSarVKKWUov3r1avXMM88YymRkZKjevXur2NhYpZRSK1euVOPGjVNK\nKRUdHa06d+6sjh49WmlsQlRGWhKiXkhISCAuLo7Zs2cTHh7OQw89RElJSbmVCu+77z4AfH19SUtL\no7i4mKNHj9K1a1datmwJwNixYyvsW103M83gwYMxNzfH3Nyczp07k5iYWKFMdHQ0wcHBODk5ATBh\nwgSioqJu+T6srKwYNGgQGzZsAEonYLu6eND27dvZu3cv4eHhhIeH8/3333Px4kWgdKr0KVOmYGpq\nCmCY9fR6R48epVOnTrRr1w6AMWPGcOrUKfLz8wFo3bo13bt3v2WcQlxVp6cKF+IqpRTNmjVjzZo1\nlT6v0WiwtLQEMCw2o9PpKiSA6x9XxsLCwvC3qalppesjK6UqzMV/9bi3Eh4ezptvvsn999/P/v37\neeeddwz7fOKJJxg9enSlx6uKyuK69rGNjU2V9iPEVdKSEPVC27ZtsbKyYu3atYZtcXFx5OXlARW/\nRK8+9vX15cSJE4br8tf2YwA0adKk3ELxVRUUFMTOnTsNfQY//vgjffr0qXD8yvj7+5Obm8t7771H\naGioIbkNHDiQ77//nuzsbACKi4uJiYkBICQkhG+++cbQyX511bkmTZoY+lKuvt9Tp04ZphFfvXo1\nnTt3luQg7pi0JESdo9FoePjhhzEzMzP8Ml6/fj2fffYZCxcu5Msvv0Sn0+Hs7MwHH3xgKHP9PqB0\nAZZ//etfzJgxA0dHR4KDgzEzM8Pa2hqAadOmMWXKFKytrfn222+rHKO3tzezZ8/m4YcfxsTEhJYt\nW7JgwYIKx7+R8PBwPvroo/9v5w5xGASiIAwPBoMhHADNBRCcgtUEzQWQSByChAOgSHB4joVBLqlo\ngnumadK0/T/51LrZyeat1nW9Z2VZ6jgO1XWtIAh0XZeqqlKWZWqaRuM4yjmnMAyVpqmmaVJRFJrn\nWc455Xmurus0DIPatpX3XkmS3E0FeAVfhePnneepKIokPW/W27a9ZRcC+Ac0Cfy8ZVm077u894rj\nWH3ff/pIwNegSQAATDxcAwBMhAQAwERIAABMhAQAwERIAABMhAQAwPQAVSnSA55bZkwAAAAASUVO\nRK5CYII=\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7f47b8e3bd90\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(eager_means)\n", + "plt.ylabel('Time(s)')\n", + "plt.xlabel('Length of vector')\n", + "_ = plt.title('Time to sum the elements of 1000 vectors (Eager)')\n", + "_ = plt.ylim(ymin=0)" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "Autograph vs. Eager vs Graph sum", + "provenance": [ + { + "file_id": "1olZkm32B7n7pQwlIAXR0_w8fZhRHCtkX", + "timestamp": 1531755808890 + } + ], + "version": "0.3.2", + "views": {} + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb b/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..e7dfb13e15a8c30fd905f0ed9db9f0f67d9b6e88 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb @@ -0,0 +1,1129 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "u3B7Uh50lozN" + }, + "outputs": [], + "source": [ + "!pip install -U -q tf-nightly" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "qWUV0FYjDSKj" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "from tensorflow.contrib import autograph\n", + "\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "kGXS3UWBBNoc" + }, + "source": [ + "# 1. AutoGraph writes graph code for you\n", + "\n", + "[AutoGraph](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/README.md) helps you write complicated graph code using just plain Python -- behind the scenes, AutoGraph automatically transforms your code into the equivalent TF graph code. We support a large chunk of the Python language, which is growing. [Please see this document for what we currently support, and what we're working on](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/LIMITATIONS.md).\n", + "\n", + "Here's a quick example of how it works:\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "aA3gOodCBkOw" + }, + "outputs": [], + "source": [ + "# Autograph can convert functions like this...\n", + "def g(x):\n", + " if x \u003e 0:\n", + " x = x * x\n", + " else:\n", + " x = 0.0\n", + " return x\n", + "\n", + "# ...into graph-building functions like this:\n", + "def tf_g(x):\n", + " with tf.name_scope('g'):\n", + "\n", + " def if_true():\n", + " with tf.name_scope('if_true'):\n", + " x_1, = x,\n", + " x_1 = x_1 * x_1\n", + " return x_1,\n", + "\n", + " def if_false():\n", + " with tf.name_scope('if_false'):\n", + " x_1, = x,\n", + " x_1 = 0.0\n", + " return x_1,\n", + "\n", + " x = autograph_utils.run_cond(tf.greater(x, 0), if_true, if_false)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "I1RtBvoKBxq5" + }, + "outputs": [], + "source": [ + "# You can run your plain-Python code in graph mode,\n", + "# and get the same results out, but with all the benfits of graphs:\n", + "print('Original value: %2.2f' % g(9.0))\n", + "\n", + "# Generate a graph-version of g and call it:\n", + "tf_g = autograph.to_graph(g)\n", + "\n", + "with tf.Graph().as_default():\n", + " # The result works like a regular op: takes tensors in, returns tensors.\n", + " # You can inspect the graph using tf.get_default_graph().as_graph_def()\n", + " g_ops = tf_g(tf.constant(9.0))\n", + " with tf.Session() as sess:\n", + " print('Autograph value: %2.2f\\n' % sess.run(g_ops))\n", + "\n", + "\n", + "# You can view, debug and tweak the generated code:\n", + "print(autograph.to_code(g))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "m-jWmsCmByyw" + }, + "source": [ + "#### Automatically converting complex control flow\n", + "\n", + "AutoGraph can convert a large chunk of the Python language into equivalent graph-construction code, and we're adding new supported language features all the time. In this section, we'll give you a taste of some of the functionality in AutoGraph.\n", + "AutoGraph will automatically convert most Python control flow statements into their correct graph equivalent. \n", + " \n", + "We support common statements like `while`, `for`, `if`, `break`, `return` and more. You can even nest them as much as you like. Imagine trying to write the graph version of this code by hand:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "toxKBOXbB1ro" + }, + "outputs": [], + "source": [ + "# Continue in a loop\n", + "def f(l):\n", + " s = 0\n", + " for c in l:\n", + " if c % 2 \u003e 0:\n", + " continue\n", + " s += c\n", + " return s\n", + "\n", + "print('Original value: %d' % f([10,12,15,20]))\n", + "\n", + "tf_f = autograph.to_graph(f)\n", + "with tf.Graph().as_default():\n", + " with tf.Session():\n", + " print('Graph value: %d\\n\\n' % tf_f(tf.constant([10,12,15,20])).eval())\n", + "\n", + "print(autograph.to_code(f))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "FUJJ-WTdCGeq" + }, + "source": [ + "Try replacing the `continue` in the above code with `break` -- AutoGraph supports that as well! \n", + " \n", + "Let's try some other useful Python constructs, like `print` and `assert`. We automatically convert Python `assert` statements into the equivalent `tf.Assert` code. " + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "IAOgh62zCPZ4" + }, + "outputs": [], + "source": [ + "def f(x):\n", + " assert x != 0, 'Do not pass zero!'\n", + " return x * x\n", + "\n", + "tf_f = autograph.to_graph(f)\n", + "with tf.Graph().as_default():\n", + " with tf.Session():\n", + " try:\n", + " print(tf_f(tf.constant(0)).eval())\n", + " except tf.errors.InvalidArgumentError as e:\n", + " print('Got error message:\\n%s' % e.message)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "KRu8iIPBCQr5" + }, + "source": [ + "You can also use plain Python `print` functions in in-graph" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "ySTsuxnqCTQi" + }, + "outputs": [], + "source": [ + "def f(n):\n", + " if n \u003e= 0:\n", + " while n \u003c 5:\n", + " n += 1\n", + " print(n)\n", + " return n\n", + "\n", + "tf_f = autograph.to_graph(f)\n", + "with tf.Graph().as_default():\n", + " with tf.Session():\n", + " tf_f(tf.constant(0)).eval()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "NqF0GT-VCVFh" + }, + "source": [ + "Appending to lists in loops also works (we create a tensor list ops behind the scenes)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "ABX070KwCczR" + }, + "outputs": [], + "source": [ + "def f(n):\n", + " z = []\n", + " # We ask you to tell us the element dtype of the list\n", + " autograph.set_element_type(z, tf.int32)\n", + " for i in range(n):\n", + " z.append(i)\n", + " # when you're done with the list, stack it\n", + " # (this is just like np.stack)\n", + " return autograph.stack(z)\n", + "\n", + "tf_f = autograph.to_graph(f)\n", + "with tf.Graph().as_default():\n", + " with tf.Session():\n", + " print(tf_f(tf.constant(3)).eval())\n", + "\n", + "print('\\n\\n'+autograph.to_code(f))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "iu5IF7n2Df7C" + }, + "outputs": [], + "source": [ + "def fizzbuzz(num):\n", + " if num % 3 == 0 and num % 5 == 0:\n", + " print('FizzBuzz')\n", + " elif num % 3 == 0:\n", + " print('Fizz')\n", + " elif num % 5 == 0:\n", + " print('Buzz')\n", + " else:\n", + " print(num)\n", + " return num" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "EExAjWuwDPpR" + }, + "outputs": [], + "source": [ + "tf_g = autograph.to_graph(fizzbuzz)\n", + "\n", + "with tf.Graph().as_default():\n", + " # The result works like a regular op: takes tensors in, returns tensors.\n", + " # You can inspect the graph using tf.get_default_graph().as_graph_def()\n", + " g_ops = tf_g(tf.constant(15))\n", + " with tf.Session() as sess:\n", + " sess.run(g_ops) \n", + " \n", + "# You can view, debug and tweak the generated code:\n", + "print('\\n')\n", + "print(autograph.to_code(fizzbuzz))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "SzpKGzVpBkph" + }, + "source": [ + "# De-graphify Exercises\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "8k23dxcSmmXq" + }, + "source": [ + "#### Easy print statements" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "dE1Vsmp-mlpK" + }, + "outputs": [], + "source": [ + "# See what happens when you turn AutoGraph off.\n", + "# Do you see the type or the value of x when you print it?\n", + "\n", + "# @autograph.convert()\n", + "def square_log(x):\n", + " x = x * x\n", + " print('Squared value of x =', x)\n", + " return x\n", + "\n", + "\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " print(sess.run(square_log(tf.constant(4))))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "_R-Q7BbxmkBF" + }, + "source": [ + "#### Convert the TensorFlow code into Python code for AutoGraph" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "SwA11tO-yCvg" + }, + "outputs": [], + "source": [ + "def square_if_positive(x):\n", + " x = tf.cond(tf.greater(x, 0), lambda: x * x, lambda: x)\n", + " return x\n", + "\n", + "with tf.Session() as sess:\n", + " print(sess.run(square_if_positive(tf.constant(4))))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "GPmx4CNhyPI_" + }, + "outputs": [], + "source": [ + "@autograph.convert()\n", + "def square_if_positive(x):\n", + "\n", + " pass # TODO: fill it in!\n", + "\n", + "\n", + "with tf.Session() as sess:\n", + " print(sess.run(square_if_positive(tf.constant(4))))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "qqsjik-QyA9R" + }, + "source": [ + "#### Uncollapse to see answer" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "DaSmaWUEvMRv" + }, + "outputs": [], + "source": [ + "# Simple cond\n", + "@autograph.convert()\n", + "def square_if_positive(x):\n", + " if x \u003e 0:\n", + " x = x * x\n", + " return x\n", + "\n", + "with tf.Graph().as_default(): \n", + " with tf.Session() as sess:\n", + " print(sess.run(square_if_positive(tf.constant(4))))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "qj7am2I_xvTJ" + }, + "source": [ + "#### Nested If statement" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "4yyNOf-Twr6s" + }, + "outputs": [], + "source": [ + "def nearest_odd_square(x):\n", + "\n", + " def if_positive():\n", + " x1 = x * x\n", + " x1 = tf.cond(tf.equal(x1 % 2, 0), lambda: x1 + 1, lambda: x1)\n", + " return x1,\n", + "\n", + " x = tf.cond(tf.greater(x, 0), if_positive, lambda: x)\n", + " return x\n", + "\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " print(sess.run(nearest_odd_square(tf.constant(4))))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "hqmh5b2VyU9w" + }, + "outputs": [], + "source": [ + "@autograph.convert()\n", + "def nearest_odd_square(x):\n", + "\n", + " pass # TODO: fill it in!\n", + "\n", + "\n", + "with tf.Session() as sess:\n", + " print(sess.run(nearest_odd_square(tf.constant(4))))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "b9AXIkNLxp6J" + }, + "source": [ + "#### Uncollapse to reveal answer" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "8RlCVEpNxD91" + }, + "outputs": [], + "source": [ + "@autograph.convert()\n", + "def nearest_odd_square(x):\n", + " if x \u003e 0:\n", + " x = x * x\n", + " if x % 2 == 0:\n", + " x = x + 1\n", + " return x\n", + "\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " print(sess.run(nearest_odd_square(tf.constant(4))))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "jXAxjeBr1qWK" + }, + "source": [ + "#### Convert a while loop" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "kWkv7anlxoee" + }, + "outputs": [], + "source": [ + "# Convert a while loop\n", + "def square_until_stop(x, y):\n", + " x = tf.while_loop(lambda x: tf.less(x, y), lambda x: x * x, [x])\n", + " return x\n", + "\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "zVUsc1eA1u2K" + }, + "outputs": [], + "source": [ + "@autograph.convert()\n", + "def square_until_stop(x, y):\n", + "\n", + " pass # TODO: fill it in!\n", + "\n", + "\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "L2psuzPI02S9" + }, + "source": [ + "#### Uncollapse for the answer\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "ucmZyQVL03bF" + }, + "outputs": [], + "source": [ + "@autograph.convert()\n", + "def square_until_stop(x, y):\n", + " while x \u003c y:\n", + " x = x * x\n", + " return x\n", + "\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "FXB0Zbwl13PY" + }, + "source": [ + "#### Nested loop and conditional" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "clGymxdf15Ig" + }, + "outputs": [], + "source": [ + "@autograph.convert()\n", + "def argwhere_cumsum(x, threshold):\n", + " current_sum = 0.0\n", + " idx = 0\n", + "\n", + " for i in range(len(x)):\n", + " idx = i\n", + " if current_sum \u003e= threshold:\n", + " break\n", + " current_sum += x[i]\n", + " return idx\n", + "\n", + "n = 10\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " idx = argwhere_cumsum(tf.ones(n), tf.constant(float(n / 2)))\n", + " print(sess.run(idx))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "i7PF-uId9lp5" + }, + "outputs": [], + "source": [ + "@autograph.convert()\n", + "def argwhere_cumsum(x, threshold):\n", + "\n", + " pass # TODO: fill it in!\n", + "\n", + "\n", + "n = 10\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " idx = argwhere_cumsum(tf.ones(n), tf.constant(float(n / 2)))\n", + " print(sess.run(idx))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "weKFXAb615Vp" + }, + "source": [ + "#### Uncollapse to see answer" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "1sjaFcL717Ig" + }, + "outputs": [], + "source": [ + "@autograph.convert()\n", + "def argwhere_cumsum(x, threshold):\n", + " current_sum = 0.0\n", + " idx = 0\n", + " for i in range(len(x)):\n", + " idx = i\n", + " if current_sum \u003e= threshold:\n", + " break\n", + " current_sum += x[i]\n", + " return idx\n", + "\n", + "n = 10\n", + "with tf.Graph().as_default(): \n", + " with tf.Session() as sess:\n", + " idx = argwhere_cumsum(tf.ones(n), tf.constant(float(n / 2)))\n", + " print(sess.run(idx))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "4LfnJjm0Bm0B" + }, + "source": [ + "# 3. Training MNIST in-graph\n", + "\n", + "Writing control flow in AutoGraph is easy, so running a training loop in a TensorFlow graph should be easy as well! \n", + "\n", + "Here, we show an example of training a simple Keras model on MNIST, where the entire training process -- loading batches, calculating gradients, updating parameters, calculating validation accuracy, and repeating until convergence -- is done in-graph." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Em5dzSUOtLRP" + }, + "source": [ + "#### Download data" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "xqoxumv0ssQW" + }, + "outputs": [], + "source": [ + "import gzip\n", + "import os\n", + "import shutil\n", + "\n", + "from six.moves import urllib\n", + "\n", + "\n", + "def download(directory, filename):\n", + " filepath = os.path.join(directory, filename)\n", + " if tf.gfile.Exists(filepath):\n", + " return filepath\n", + " if not tf.gfile.Exists(directory):\n", + " tf.gfile.MakeDirs(directory)\n", + " url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'\n", + " zipped_filepath = filepath + '.gz'\n", + " print('Downloading %s to %s' % (url, zipped_filepath))\n", + " urllib.request.urlretrieve(url, zipped_filepath)\n", + " with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:\n", + " shutil.copyfileobj(f_in, f_out)\n", + " os.remove(zipped_filepath)\n", + " return filepath\n", + "\n", + "\n", + "def dataset(directory, images_file, labels_file):\n", + " images_file = download(directory, images_file)\n", + " labels_file = download(directory, labels_file)\n", + "\n", + " def decode_image(image):\n", + " # Normalize from [0, 255] to [0.0, 1.0]\n", + " image = tf.decode_raw(image, tf.uint8)\n", + " image = tf.cast(image, tf.float32)\n", + " image = tf.reshape(image, [784])\n", + " return image / 255.0\n", + "\n", + " def decode_label(label):\n", + " label = tf.decode_raw(label, tf.uint8)\n", + " label = tf.reshape(label, [])\n", + " return tf.to_int32(label)\n", + "\n", + " images = tf.data.FixedLengthRecordDataset(\n", + " images_file, 28 * 28, header_bytes=16).map(decode_image)\n", + " labels = tf.data.FixedLengthRecordDataset(\n", + " labels_file, 1, header_bytes=8).map(decode_label)\n", + " return tf.data.Dataset.zip((images, labels))\n", + "\n", + "\n", + "def mnist_train(directory):\n", + " return dataset(directory, 'train-images-idx3-ubyte',\n", + " 'train-labels-idx1-ubyte')\n", + "\n", + "def mnist_test(directory):\n", + " return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "znmy4l8ntMvW" + }, + "source": [ + "#### Define the model" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "Pe-erWQdBoC5" + }, + "outputs": [], + "source": [ + "def mlp_model(input_shape):\n", + " model = tf.keras.Sequential((\n", + " tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),\n", + " tf.keras.layers.Dense(100, activation='relu'),\n", + " tf.keras.layers.Dense(10, activation='softmax')))\n", + " model.build()\n", + " return model\n", + "\n", + "\n", + "def predict(m, x, y):\n", + " y_p = m(x)\n", + " losses = tf.keras.losses.categorical_crossentropy(y, y_p)\n", + " l = tf.reduce_mean(losses)\n", + " accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n", + " accuracy = tf.reduce_mean(accuracies)\n", + " return l, accuracy\n", + "\n", + "\n", + "def fit(m, x, y, opt):\n", + " l, accuracy = predict(m, x, y)\n", + " opt.minimize(l)\n", + " return l, accuracy\n", + "\n", + "\n", + "def setup_mnist_data(is_training, hp, batch_size):\n", + " if is_training:\n", + " ds = mnist_train('/tmp/autograph_mnist_data')\n", + " ds = ds.shuffle(batch_size * 10)\n", + " else:\n", + " ds = mnist_test('/tmp/autograph_mnist_data')\n", + " ds = ds.repeat()\n", + " ds = ds.batch(batch_size)\n", + " return ds\n", + "\n", + "\n", + "def get_next_batch(ds):\n", + " itr = ds.make_one_shot_iterator()\n", + " image, label = itr.get_next()\n", + " x = tf.to_float(tf.reshape(image, (-1, 28 * 28)))\n", + " y = tf.one_hot(tf.squeeze(label), 10)\n", + " return x, y" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "oeYV6mKnJGMr" + }, + "source": [ + "#### Define the training loop" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "3xtg_MMhJETd" + }, + "outputs": [], + "source": [ + "def train(train_ds, test_ds, hp):\n", + " m = mlp_model((28 * 28,))\n", + " opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n", + "\n", + " # We'd like to save our losses to a list. In order for AutoGraph\n", + " # to convert these lists into their graph equivalent,\n", + " # we need to specify the element type of the lists.\n", + " train_losses = []\n", + " test_losses = []\n", + " train_accuracies = []\n", + " test_accuracies = []\n", + " autograph.set_element_type(train_losses, tf.float32)\n", + " autograph.set_element_type(test_losses, tf.float32)\n", + " autograph.set_element_type(train_accuracies, tf.float32)\n", + " autograph.set_element_type(test_accuracies, tf.float32)\n", + "\n", + " # This entire training loop will be run in-graph.\n", + " i = tf.constant(0)\n", + " while i \u003c hp.max_steps:\n", + " train_x, train_y = get_next_batch(train_ds)\n", + " test_x, test_y = get_next_batch(test_ds)\n", + "\n", + " step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)\n", + " step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n", + "\n", + " if i % (hp.max_steps // 10) == 0:\n", + " print('Step', i, 'train loss:', step_train_loss, 'test loss:',\n", + " step_test_loss, 'train accuracy:', step_train_accuracy,\n", + " 'test accuracy:', step_test_accuracy)\n", + "\n", + " train_losses.append(step_train_loss)\n", + " test_losses.append(step_test_loss)\n", + " train_accuracies.append(step_train_accuracy)\n", + " test_accuracies.append(step_test_accuracy)\n", + "\n", + " i += 1\n", + "\n", + " # We've recorded our loss values and accuracies\n", + " # to a list in a graph with AutoGraph's help.\n", + " # In order to return the values as a Tensor,\n", + " # we need to stack them before returning them.\n", + " return (\n", + " autograph.stack(train_losses),\n", + " autograph.stack(test_losses),\n", + " autograph.stack(train_accuracies),\n", + " autograph.stack(test_accuracies),\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "HYh6MSZyJOag" + }, + "outputs": [], + "source": [ + "with tf.Graph().as_default():\n", + " hp = tf.contrib.training.HParams(\n", + " learning_rate=0.05,\n", + " max_steps=500,\n", + " )\n", + " train_ds = setup_mnist_data(True, hp, 50)\n", + " test_ds = setup_mnist_data(False, hp, 1000)\n", + " tf_train = autograph.to_graph(train)\n", + " loss_tensors = tf_train(train_ds, test_ds, hp)\n", + "\n", + " with tf.Session() as sess:\n", + " sess.run(tf.global_variables_initializer())\n", + " (\n", + " train_losses,\n", + " test_losses,\n", + " train_accuracies,\n", + " test_accuracies\n", + " ) = sess.run(loss_tensors)\n", + "\n", + " plt.title('MNIST train/test losses')\n", + " plt.plot(train_losses, label='train loss')\n", + " plt.plot(test_losses, label='test loss')\n", + " plt.legend()\n", + " plt.xlabel('Training step')\n", + " plt.ylabel('Loss')\n", + " plt.show()\n", + " plt.title('MNIST train/test accuracies')\n", + " plt.plot(train_accuracies, label='train accuracy')\n", + " plt.plot(test_accuracies, label='test accuracy')\n", + " plt.legend(loc='lower right')\n", + " plt.xlabel('Training step')\n", + " plt.ylabel('Accuracy')\n", + " plt.show()" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "qqsjik-QyA9R", + "b9AXIkNLxp6J", + "L2psuzPI02S9", + "weKFXAb615Vp", + "Em5dzSUOtLRP" + ], + "default_view": {}, + "name": "AutoGraph Workshop.ipynb", + "provenance": [ + { + "file_id": "1kE2gz_zuwdYySL4K2HQSz13uLCYi-fYP", + "timestamp": 1530563781803 + } + ], + "version": "0.3.2", + "views": {} + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py index c7401c7df126b73ca22cdaf74a2f1fd6149d7545..4729c735c621e68df30acfd4738d89874c3c55ac 100644 --- a/tensorflow/contrib/autograph/impl/api.py +++ b/tensorflow/contrib/autograph/impl/api.py @@ -23,7 +23,6 @@ from functools import wraps from enum import Enum # pylint:disable=g-bad-import-order -import gast import six # pylint:enable=g-bad-import-order @@ -69,7 +68,8 @@ def convert(recursive=False, verbose=False, arg_types=None): @wraps(f) def wrapper(*args, **kwargs): - return converted_call(f, recursive, verbose, arg_types, *args, **kwargs) + return converted_call(f, recursive, verbose, True, arg_types, *args, + **kwargs) wrapper = tf_decorator.make_decorator(f, wrapper) @@ -99,6 +99,7 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None): Returns: A decorator that wraps the original function. """ + def decorator(f): """Decorator implementation.""" @@ -109,8 +110,7 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None): @wraps(f) def py_func_wrapper(*args, **kwargs): if kwargs: - raise NotImplementedError( - 'RunMode.PY_FUNC does not yet support kwargs') + raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs') # TODO(mdan): Add support for kwargs. return py_func.wrap_py_func( f, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes) @@ -130,12 +130,12 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None): return decorator -def converted_call(f, recursive, verbose, arg_types, *args, **kwargs): +def converted_call(f, recursive, verbose, force_conversion, arg_types, *args, + **kwargs): """Compiles a function call inline.""" # TODO(mdan): This needs cleanup. # In particular, we may want to avoid renaming functions altogether. - - if conversion.is_whitelisted_for_graph(f): + if not force_conversion and conversion.is_whitelisted_for_graph(f): return f(*args, **kwargs) unknown_arg_value = object() # Sentinel for arguments of unknown value @@ -231,7 +231,10 @@ def to_graph(e, Returns: A function with a signature identical to `o`, but which when executed it - creates TF a graph that has the same functionality as the original entity. + creates TF a graph that has the same functionality as the original entity. + Raises: + ValueError: If the converted function defines or refers to symbol names that + are reserved for AutoGraph. """ program_ctx = converter.ProgramContext( recursive=recursive, @@ -242,24 +245,41 @@ def to_graph(e, _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values, arg_types) - module = gast.Module([]) + nodes = [] for dep in reversed(program_ctx.dependency_cache.values()): - module.body.append(dep) - compiled_node, compiled_src = compiler.ast_to_object( - module, source_prefix=program_ctx.required_imports) + nodes.extend(dep) + compiled_module, compiled_src = compiler.ast_to_object( + nodes, + source_prefix=program_ctx.required_imports, + include_source_map=True) # The compiled code should see everything the entry entity saw. # TODO(mdan): This might not work well if the call tree spans modules? for key, val in namespace.items(): # Avoid overwriting entities that have been transformed. - if key not in compiled_node.__dict__: - compiled_node.__dict__[key] = val - compiled_fn = getattr(compiled_node, name) + if key not in compiled_module.__dict__: + compiled_module.__dict__[key] = val + compiled = getattr(compiled_module, name) + + # Need this so the source_mapping attribute is available for the context + # manager to access for runtime errors. + # + # Note that compiler.ast_to_object attaches the source map 'ag_source_map__' + # symbol to the compiled module. + # TODO(mdan): Record this statically in the generated code. + # TODO(mdan): Rename this attribute to 'autograph_info__' + source_map_attribute_name = 'ag_source_map' + if getattr(compiled, source_map_attribute_name, None) is not None: + raise ValueError('cannot convert %s because is has an attribute ' + '"%s", which is reserved for AutoGraph.' % + (compiled, source_map_attribute_name)) + setattr(compiled, source_map_attribute_name, + compiled_module.__dict__['ag_source_map__']) if verbose: logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src) - return compiled_fn + return compiled def to_code(e, diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py index 994309333209586001c9369322ec3ddeee0a508e..803fde9089b1c004d9bfc0dfefd3d6b422752f0a 100644 --- a/tensorflow/contrib/autograph/impl/api_test.py +++ b/tensorflow/contrib/autograph/impl/api_test.py @@ -183,8 +183,8 @@ class ApiTest(test.TestCase): @api.convert(recursive=True) def test_method(self, x, s, a): while tf.reduce_sum(x) > s: - x //= api.converted_call(self.called_member, False, False, {}, self, - a) + x //= api.converted_call(self.called_member, False, False, False, {}, + self, a) return x tc = TestClass() @@ -195,7 +195,7 @@ class ApiTest(test.TestCase): self.assertListEqual([0, 1], sess.run(x).tolist()) def test_converted_call_builtin(self): - x = api.converted_call(range, False, False, {}, 3) + x = api.converted_call(range, False, False, False, {}, 3) self.assertEqual((0, 1, 2), tuple(x)) def test_converted_call_function(self): @@ -206,8 +206,8 @@ class ApiTest(test.TestCase): return x with self.test_session() as sess: - x = api.converted_call( - test_fn, False, False, {}, constant_op.constant(-1)) + x = api.converted_call(test_fn, False, False, False, {}, + constant_op.constant(-1)) self.assertEqual(1, sess.run(x)) def test_converted_call_method(self): @@ -224,7 +224,7 @@ class ApiTest(test.TestCase): with self.test_session() as sess: tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(tc.test_method, False, False, {}, tc) + x = api.converted_call(tc.test_method, False, False, False, {}, tc) self.assertEqual(1, sess.run(x)) def test_converted_call_method_by_class(self): @@ -241,7 +241,7 @@ class ApiTest(test.TestCase): with self.test_session() as sess: tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(TestClass.test_method, False, False, {}, tc) + x = api.converted_call(TestClass.test_method, False, False, False, {}, tc) self.assertEqual(1, sess.run(x)) def test_converted_call_callable_object(self): @@ -258,7 +258,7 @@ class ApiTest(test.TestCase): with self.test_session() as sess: tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(tc, False, False, {}) + x = api.converted_call(tc, False, False, False, {}) self.assertEqual(1, sess.run(x)) def test_converted_call_constructor(self): @@ -274,12 +274,27 @@ class ApiTest(test.TestCase): return self.x with self.test_session() as sess: - tc = api.converted_call( - TestClass, False, False, {}, constant_op.constant(-1)) + tc = api.converted_call(TestClass, False, False, False, {}, + constant_op.constant(-1)) # tc is now a converted object. x = tc.test_method() self.assertEqual(1, sess.run(x)) + def test_converted_call_already_converted(self): + + def f(x): + return x == 0 + + with self.test_session() as sess: + x = api.converted_call(f, False, False, False, {}, + constant_op.constant(0)) + self.assertTrue(sess.run(x)) + + converted_f = api.to_graph(f) + x = api.converted_call(converted_f, False, False, False, {}, + constant_op.constant(0)) + self.assertTrue(sess.run(x)) + def test_to_graph_basic(self): def test_fn(x, s): @@ -305,6 +320,13 @@ class ApiTest(test.TestCase): # Just check that it is parseable Python code. self.assertIsNotNone(parser.parse_str(compiled_code)) + def test_source_map_attribute_present(self): + + def test_fn(y): + return y**2 + + self.assertTrue(hasattr(api.to_graph(test_fn), 'ag_source_map')) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py index 776d19f672ebbd6b88985dda157434f2046d87e7..fc8a976d3f3ecdc9c6339995dd0dfc776824b90d 100644 --- a/tensorflow/contrib/autograph/impl/conversion.py +++ b/tensorflow/contrib/autograph/impl/conversion.py @@ -28,26 +28,28 @@ from tensorflow.contrib.autograph.converters import asserts from tensorflow.contrib.autograph.converters import break_statements from tensorflow.contrib.autograph.converters import builtin_functions from tensorflow.contrib.autograph.converters import call_trees +from tensorflow.contrib.autograph.converters import conditional_expressions from tensorflow.contrib.autograph.converters import continue_statements from tensorflow.contrib.autograph.converters import control_flow from tensorflow.contrib.autograph.converters import decorators -from tensorflow.contrib.autograph.converters import ifexp +from tensorflow.contrib.autograph.converters import directives +from tensorflow.contrib.autograph.converters import error_handlers from tensorflow.contrib.autograph.converters import lists from tensorflow.contrib.autograph.converters import logical_expressions from tensorflow.contrib.autograph.converters import name_scopes +from tensorflow.contrib.autograph.converters import return_statements from tensorflow.contrib.autograph.converters import side_effect_guards -from tensorflow.contrib.autograph.converters import single_return from tensorflow.contrib.autograph.converters import slices from tensorflow.contrib.autograph.core import config from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.core import errors from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import inspect_utils +from tensorflow.contrib.autograph.pyct import origin_info from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import templates from tensorflow.contrib.autograph.pyct import transformer -from tensorflow.contrib.autograph.pyct.static_analysis import activity -from tensorflow.contrib.autograph.pyct.static_analysis import live_values -from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.util import tf_inspect @@ -69,6 +71,8 @@ def is_whitelisted_for_graph(o): for prefix, in config.DEFAULT_UNCOMPILED_MODULES: if m.__name__.startswith(prefix): return True + if hasattr(o, 'autograph_info__'): + return True return False @@ -114,12 +118,32 @@ def entity_to_graph(o, program_ctx, arg_values, arg_types): node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types) elif tf_inspect.ismethod(o): node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types) + # TODO(mdan,yashkatariya): Remove when object conversion is implemented. + elif hasattr(o, '__class__'): + raise NotImplementedError( + 'Object conversion is not yet supported. If you are ' + 'trying to convert code that uses an existing object, ' + 'try including the creation of that object in the ' + 'conversion. For example, instead of converting the method ' + 'of a class, try converting the entire class instead. ' + 'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/' + 'contrib/autograph/README.md#using-the-functional-api ' + 'for more information.') else: raise ValueError( 'Entity "%s" has unsupported type "%s". Only functions and classes are ' 'supported for now.' % (o, type(o))) + # TODO(mdan): This is temporary. it should be created using a converter. + # TODO(mdan): The attribute should be added with a helper, not directly. + # The helper can ensure there are no collisions. + template = ''' + entity.autograph_info__ = {} + ''' + node.extend(templates.replace(template, entity=name)) + program_ctx.add_to_cache(o, node) + if program_ctx.recursive: while True: candidate = None @@ -157,26 +181,27 @@ def class_to_graph(c, program_ctx): program_ctx=program_ctx, arg_values={}, arg_types={'self': (c.__name__, c)}, - owner_type=c) + owner_type=c, + rewrite_errors=False) if class_namespace is None: class_namespace = namespace else: class_namespace.update(namespace) - converted_members[m] = node + converted_members[m] = node[0] namer = program_ctx.new_namer(class_namespace) class_name = namer.compiled_class_name(c.__name__, c) # TODO(mdan): This needs to be explained more thoroughly. - # Process any base classes: if the sueprclass if of a whitelisted type, an + # Process any base classes: if the superclass if of a whitelisted type, an # absolute import line is generated. Otherwise, it is marked for conversion # (as a side effect of the call to namer.compiled_class_name() followed by # program_ctx.update_name_map(namer)). output_nodes = [] renames = {} - bases = [] + base_names = [] for base in c.__bases__: if isinstance(object, base): - bases.append('object') + base_names.append('object') continue if is_whitelisted_for_graph(base): alias = namer.new_symbol(base.__name__, ()) @@ -188,28 +213,28 @@ def class_to_graph(c, program_ctx): else: # This will trigger a conversion into a class with this name. alias = namer.compiled_class_name(base.__name__, base) - bases.append(alias) + base_names.append(alias) renames[qual_names.QN(base.__name__)] = qual_names.QN(alias) program_ctx.update_name_map(namer) # Generate the definition of the converted class. - output_nodes.append( - gast.ClassDef( - class_name, - bases=bases, - keywords=[], - body=list(converted_members.values()), - decorator_list=[])) - node = gast.Module(output_nodes) - + bases = [gast.Name(n, gast.Load(), None) for n in base_names] + class_def = gast.ClassDef( + class_name, + bases=bases, + keywords=[], + body=list(converted_members.values()), + decorator_list=[]) # Make a final pass to replace references to the class or its base classes. # Most commonly, this occurs when making super().__init__() calls. # TODO(mdan): Making direct references to superclass' superclass will fail. - node = qual_names.resolve(node) + class_def = qual_names.resolve(class_def) renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name) - node = ast_util.rename_symbols(node, renames) + class_def = ast_util.rename_symbols(class_def, renames) - return node, class_name, class_namespace + output_nodes.append(class_def) + + return output_nodes, class_name, class_namespace def _add_reserved_symbol(namespace, name, entity): @@ -231,6 +256,8 @@ def _add_self_references(namespace, autograph_module): ag_internal = imp.new_module('autograph') ag_internal.converted_call = autograph_module.converted_call ag_internal.utils = utils + ag_internal.rewrite_graph_construction_error = ( + errors.rewrite_graph_construction_error) # TODO(mdan): Add safeguards against name clashes. # We don't want to create a submodule because we want the operators to be # accessible as ag__. @@ -239,11 +266,17 @@ def _add_self_references(namespace, autograph_module): _add_reserved_symbol(namespace, 'ag__', ag_internal) -def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None): +def function_to_graph(f, + program_ctx, + arg_values, + arg_types, + owner_type=None, + rewrite_errors=True): """Specialization of `entity_to_graph` for callable functions.""" + node, source = parser.parse_entity(f) node = node.body[0] - + origin_info.resolve(node, source, f) namespace = inspect_utils.getnamespace(f) _add_self_references(namespace, program_ctx.autograph_module) namer = program_ctx.new_namer(namespace) @@ -256,38 +289,29 @@ def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None): arg_types=arg_types, owner_type=owner_type) context = converter.EntityContext(namer, entity_info, program_ctx) - node = node_to_graph(node, context) + node = node_to_graph(node, context, rewrite_errors=rewrite_errors) - # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py + # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type) if not did_rename: new_name = f.__name__ if node.name != f.__name__: raise NotImplementedError('Strange corner case. Send us offending code!') - node.name = new_name + program_ctx.update_name_map(namer) # TODO(mdan): Use this at compilation. - return node, new_name, namespace - - -def _apply_transformer(node, context, converter_module): - # TODO(mdan): Clear static analysis here. - node = qual_names.resolve(node) - node = activity.resolve(node, context.info, None) - node = live_values.resolve(node, context.info, config.PYTHON_LITERALS) - node = type_info.resolve(node, context.info) - node = converter_module.transform(node, context) - return node + return [node], new_name, namespace -def node_to_graph(node, context): +def node_to_graph(node, context, rewrite_errors=True): """Convert Python code to equivalent TF graph mode code. Args: node: AST, the code to convert. context: converter.EntityContext + rewrite_errors: Boolean, whether or not to rewrite the error traceback. Returns: A tuple (node, deps): @@ -295,28 +319,33 @@ def node_to_graph(node, context): * deps: A set of strings, the fully qualified names of entity dependencies that this node has. """ - # TODO(mdan): Verify arguments for correctness. + # TODO(mdan): Insert list_comprehensions somewhere. - node = _apply_transformer(node, context, ifexp) + node = converter.standard_analysis(node, context, is_initial=True) # Past this point, line numbers are no longer accurate so we ignore the # source. # TODO(mdan): Is it feasible to reconstruct intermediate source code? context.info.source_code = None - node = _apply_transformer(node, context, decorators) - node = _apply_transformer(node, context, break_statements) - node = _apply_transformer(node, context, asserts) + + node = converter.apply_(node, context, decorators) + node = converter.apply_(node, context, directives) + node = converter.apply_(node, context, break_statements) + node = converter.apply_(node, context, asserts) # Note: sequencing continue canonicalization before for loop one avoids # dealing with the extra loop increment operation that the for # canonicalization creates. - node = _apply_transformer(node, context, continue_statements) + node = converter.apply_(node, context, continue_statements) context.info.namespace['len'] = len - node = _apply_transformer(node, context, single_return) - node = _apply_transformer(node, context, lists) - node = _apply_transformer(node, context, slices) - node = _apply_transformer(node, context, builtin_functions) - node = _apply_transformer(node, context, call_trees) - node = _apply_transformer(node, context, control_flow) - node = _apply_transformer(node, context, logical_expressions) - node = _apply_transformer(node, context, side_effect_guards) - node = _apply_transformer(node, context, name_scopes) + node = converter.apply_(node, context, return_statements) + node = converter.apply_(node, context, lists) + node = converter.apply_(node, context, slices) + node = converter.apply_(node, context, builtin_functions) + node = converter.apply_(node, context, call_trees) + node = converter.apply_(node, context, control_flow) + node = converter.apply_(node, context, conditional_expressions) + node = converter.apply_(node, context, logical_expressions) + node = converter.apply_(node, context, side_effect_guards) + node = converter.apply_(node, context, name_scopes) + if rewrite_errors: + node = converter.apply_(node, context, error_handlers) return node diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py index f5279298afdcd406a9a6762e58367cea8ca63141..86432573a719ea3f2b163746996dbf3301785a91 100644 --- a/tensorflow/contrib/autograph/impl/conversion_test.py +++ b/tensorflow/contrib/autograph/impl/conversion_test.py @@ -50,7 +50,7 @@ class ConversionTest(test.TestCase): self.assertTrue(conversion.is_whitelisted_for_graph(constant_op.constant)) def test_entity_to_graph_unsupported_types(self): - with self.assertRaises(ValueError): + with self.assertRaises(NotImplementedError): program_ctx = self._simple_program_ctx() conversion.entity_to_graph('dummy', program_ctx, None, None) @@ -60,10 +60,11 @@ class ConversionTest(test.TestCase): return a + b program_ctx = self._simple_program_ctx() - ast, name, ns = conversion.entity_to_graph(f, program_ctx, None, None) - self.assertTrue(isinstance(ast, gast.FunctionDef), ast) + nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None) + fn_node, _ = nodes + self.assertIsInstance(fn_node, gast.FunctionDef) self.assertEqual('tf__f', name) - self.assertTrue(ns['b'] is b) + self.assertIs(ns['b'], b) def test_entity_to_graph_call_tree(self): @@ -78,12 +79,11 @@ class ConversionTest(test.TestCase): self.assertTrue(f in program_ctx.dependency_cache) self.assertTrue(g in program_ctx.dependency_cache) - self.assertEqual('tf__f', program_ctx.dependency_cache[f].name) - # need the extra .body[0] in order to step past the with tf.name_scope('f') - # that is added automatically - self.assertEqual( - 'tf__g', program_ctx.dependency_cache[f].body[0].body[0].value.func.id) - self.assertEqual('tf__g', program_ctx.dependency_cache[g].name) + f_node = program_ctx.dependency_cache[f][0] + g_node = program_ctx.dependency_cache[g][0] + self.assertEqual('tf__f', f_node.name) + self.assertEqual('tf__g', f_node.body[0].body[0].body[0].value.func.id) + self.assertEqual('tf__g', g_node.name) def test_entity_to_graph_class_hierarchy(self): @@ -115,10 +115,12 @@ class ConversionTest(test.TestCase): self.assertTrue(TestBase in program_ctx.dependency_cache) self.assertTrue(TestSubclass in program_ctx.dependency_cache) + # The returned nodes will include: + # , , self.assertEqual('TfTestBase', - program_ctx.dependency_cache[TestBase].body[-1].name) + program_ctx.dependency_cache[TestBase][-2].name) self.assertEqual('TfTestSubclass', - program_ctx.dependency_cache[TestSubclass].body[-1].name) + program_ctx.dependency_cache[TestSubclass][-2].name) def test_entity_to_graph_class_hierarchy_whitelisted(self): @@ -137,10 +139,11 @@ class ConversionTest(test.TestCase): self.assertTrue(TestSubclass in program_ctx.dependency_cache) self.assertFalse(training.Model in program_ctx.dependency_cache) self.assertEqual( - 'Model', - program_ctx.dependency_cache[TestSubclass].body[0].names[0].name) + 'Model', program_ctx.dependency_cache[TestSubclass][0].names[0].name) + # The returned nodes will include: + # , , self.assertEqual('TfTestSubclass', - program_ctx.dependency_cache[TestSubclass].body[-1].name) + program_ctx.dependency_cache[TestSubclass][-2].name) def test_entity_to_graph_lambda(self): f = lambda a: a diff --git a/tensorflow/contrib/autograph/lang/special_functions.py b/tensorflow/contrib/autograph/lang/special_functions.py index 11135295a7966bc5d693676fcc71fe43791f2e99..6149cbbd6c9214fb6989bdcae430286445b1db28 100644 --- a/tensorflow/contrib/autograph/lang/special_functions.py +++ b/tensorflow/contrib/autograph/lang/special_functions.py @@ -26,6 +26,43 @@ from __future__ import print_function from tensorflow.contrib.autograph.operators import data_structures +def tensor_list(elements, + element_dtype=None, + element_shape=None, + use_tensor_array=False): + """Creates an tensor list and populates it with the given elements. + + This function provides a more uniform access to tensor lists and tensor + arrays, and allows optional initialization. + + Note: this function is a simplified wrapper. If you need greater control, + it is recommended to use the underlying implementation directly. + + Args: + elements: Iterable[tf.Tensor, ...], the elements to initially fill the list + with + element_dtype: Optional[tf.DType], data type for the elements in the list; + required if the list is empty + element_shape: Optional[tf.TensorShape], shape for the elements in the list; + required if the list is empty + use_tensor_array: bool, whether to use the more compatible but restrictive + tf.TensorArray implementation + Returns: + Union[tf.Tensor, tf.TensorArray], the new list. + Raises: + ValueError: for invalid arguments + """ + if not (elements or (element_dtype and element_shape)): + raise ValueError( + 'element_dtype and element_shape are required for empty lists') + if use_tensor_array: + return data_structures.tf_tensor_array_new(elements, element_dtype, + element_shape) + else: + return data_structures.tf_tensor_list_new(elements, element_dtype, + element_shape) + + def stack(list_or_tensor, element_dtype=None, strict=True): """Stacks the input, if it admits the notion of stacking. diff --git a/tensorflow/contrib/autograph/lang/special_functions_test.py b/tensorflow/contrib/autograph/lang/special_functions_test.py index a49cb6407517b634e0f1259fccda03d4ed18e83f..db492cc5c689155bf7d426cbfee320130f4bda9f 100644 --- a/tensorflow/contrib/autograph/lang/special_functions_test.py +++ b/tensorflow/contrib/autograph/lang/special_functions_test.py @@ -28,7 +28,23 @@ from tensorflow.python.platform import test class SpecialFunctionsTest(test.TestCase): - def test_basic(self): + def test_tensor_list_from_elements(self): + elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])] + + l = special_functions.tensor_list(elements) + sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32) + with self.test_session() as sess: + self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]]) + + def test_tensor_list_array_from_elements(self): + elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])] + + l = special_functions.tensor_list(elements, use_tensor_array=True) + sl = l.stack() + with self.test_session() as sess: + self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]]) + + def test_stack(self): self.assertEqual(special_functions.stack(1, strict=False), 1) self.assertListEqual( special_functions.stack([1, 2, 3], strict=False), [1, 2, 3]) diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py index c900fd6af2ea5dfb419f731ee8d8822d68424b27..392cb60bcc44c0f554defcddc50c4afbdaa25067 100644 --- a/tensorflow/contrib/autograph/operators/__init__.py +++ b/tensorflow/contrib/autograph/operators/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""This module implements operators that we overload. +"""This module implements operators that AutoGraph overloads. Note that "operator" is used loosely here, and includes control structures like conditionals and loops, implemented in functional form, using for example diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py index 988df70157170ed0a9ece33976e871e6f7693bbc..9909e521644a7a901653dc09853222167828c75c 100644 --- a/tensorflow/contrib/autograph/operators/control_flow.py +++ b/tensorflow/contrib/autograph/operators/control_flow.py @@ -141,7 +141,7 @@ def _dataset_for_stmt(ds, extra_test, body, init_state): while_body, init_state=(epoch_number, iterate) + init_state, extra_deps=()) - # Dropping the epoch number and iterate because they are not not syntactically + # Dropping the epoch number and iterate because they are not syntactically # visible. results = results[2:] @@ -212,12 +212,12 @@ def if_stmt(cond, body, orelse): Tuple containing the statement outputs. """ if tensor_util.is_tensor(cond): - return _tf_if_stmt(cond, body, orelse) + return tf_if_stmt(cond, body, orelse) else: return _py_if_stmt(cond, body, orelse) -def _tf_if_stmt(cond, body, orelse): +def tf_if_stmt(cond, body, orelse): """Overload of if_stmt that stages a TF cond.""" return control_flow_ops.cond(cond, body, orelse) diff --git a/tensorflow/contrib/autograph/operators/data_structures.py b/tensorflow/contrib/autograph/operators/data_structures.py index 06d8727b0fcc30b532b3f11281cd1a83c51ac8bc..cc0a3c35448980945f2975f829f9d9421afdb65d 100644 --- a/tensorflow/contrib/autograph/operators/data_structures.py +++ b/tensorflow/contrib/autograph/operators/data_structures.py @@ -28,7 +28,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import list_ops from tensorflow.python.ops import tensor_array_ops -from tensorflow.python.ops import variables # TODO(mdan): Once control flow supports objects, repackage as a class. @@ -48,29 +47,101 @@ def new_list(iterable=None): else: elements = () - # TODO(mdan): Extend these criteria. - if any(isinstance(el, variables.Variable) for el in elements): + if elements: + # When the list contains elements, it is assumed to be a "Python" lvalue + # list. return _py_list_new(elements) - return _tf_tensor_list_new(elements) + return tf_tensor_list_new(elements) -def _tf_tensor_list_new(elements): +def tf_tensor_array_new(elements, element_dtype=None, element_shape=None): """Overload of new_list that stages a Tensor list creation.""" elements = tuple(ops.convert_to_tensor(el) for el in elements) + + all_dtypes = set(el.dtype for el in elements) + if len(all_dtypes) == 1: + inferred_dtype, = tuple(all_dtypes) + if element_dtype is not None and element_dtype != inferred_dtype: + raise ValueError( + 'incompatible dtype; specified: {}, inferred from {}: {}'.format( + element_dtype, elements, inferred_dtype)) + elif len(all_dtypes) > 1: + raise ValueError( + 'TensorArray requires all elements to have the same dtype:' + ' {}'.format(elements)) + else: + if element_dtype is None: + raise ValueError('dtype is required to create an empty TensorArray') + + all_shapes = set(tuple(el.shape.as_list()) for el in elements) + if len(all_shapes) == 1: + inferred_shape, = tuple(all_shapes) + if element_shape is not None and element_shape != inferred_shape: + raise ValueError( + 'incompatible shape; specified: {}, inferred from {}: {}'.format( + element_shape, elements, inferred_shape)) + elif len(all_shapes) > 1: + raise ValueError( + 'TensorArray requires all elements to have the same shape:' + ' {}'.format(elements)) + # TODO(mdan): We may want to allow different shapes with infer_shape=False. + else: + inferred_shape = None + + if element_dtype is None: + element_dtype = inferred_dtype + if element_shape is None: + element_shape = inferred_shape + + l = tensor_array_ops.TensorArray( + dtype=element_dtype, + size=len(elements), + dynamic_size=True, + infer_shape=(element_shape is None), + element_shape=element_shape) + for i, el in enumerate(elements): + l = l.write(i, el) + return l + + +def tf_tensor_list_new(elements, element_dtype=None, element_shape=None): + """Overload of new_list that stages a Tensor list creation.""" + elements = tuple(ops.convert_to_tensor(el) for el in elements) + all_dtypes = set(el.dtype for el in elements) if len(all_dtypes) == 1: - element_dtype = tuple(all_dtypes)[0] + inferred_dtype = tuple(all_dtypes)[0] + if element_dtype is not None and element_dtype != inferred_dtype: + raise ValueError( + 'incompatible dtype; specified: {}, inferred from {}: {}'.format( + element_dtype, elements, inferred_dtype)) else: # Heterogeneous lists are ok. - element_dtype = dtypes.variant + if element_dtype is not None: + raise ValueError( + 'specified dtype {} is inconsistent with that of elements {}'.format( + element_dtype, elements)) + inferred_dtype = dtypes.variant - # TODO(mdan): This may fail for elements of variable shapes. all_shapes = set(tuple(el.shape.as_list()) for el in elements) if len(all_shapes) == 1: - element_shape = array_ops.shape(elements[0]) + inferred_shape = array_ops.shape(elements[0]) + if element_shape is not None and element_shape != inferred_shape: + raise ValueError( + 'incompatible shape; specified: {}, inferred from {}: {}'.format( + element_shape, elements, inferred_shape)) else: # Heterogeneous lists are ok. - element_shape = constant_op.constant(-1) # unknown shape, by convention + if element_shape is not None: + raise ValueError( + 'specified shape {} is inconsistent with that of elements {}'.format( + element_shape, elements)) + inferred_shape = constant_op.constant(-1) # unknown shape, by convention + + if element_dtype is None: + element_dtype = inferred_dtype + if element_shape is None: + element_shape = inferred_shape l = list_ops.empty_tensor_list( element_shape=element_shape, element_dtype=element_dtype) diff --git a/tensorflow/contrib/autograph/operators/data_structures_test.py b/tensorflow/contrib/autograph/operators/data_structures_test.py index 8bbb52d6c10b241ec754c7dea599fa15a869595f..7ea11a839b6070f6c6dfdd8a8f7939923a7d9eaa 100644 --- a/tensorflow/contrib/autograph/operators/data_structures_test.py +++ b/tensorflow/contrib/autograph/operators/data_structures_test.py @@ -37,10 +37,51 @@ class ListTest(test.TestCase): def test_new_list_tensor(self): l = data_structures.new_list([3, 4, 5]) + self.assertAllEqual(l, [3, 4, 5]) + + def test_tf_tensor_list_new(self): + l = data_structures.tf_tensor_list_new([3, 4, 5]) t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32) with self.test_session() as sess: self.assertAllEqual(sess.run(t), [3, 4, 5]) + def test_tf_tensor_list_new_illegal_input(self): + with self.assertRaises(ValueError): + data_structures.tf_tensor_list_new([3, 4.0]) + # TODO(mdan): It might make more sense to type cast in this case. + with self.assertRaises(ValueError): + data_structures.tf_tensor_list_new([3, 4], element_dtype=dtypes.float32) + # Tensor lists do support heterogeneous lists. + self.assertIsNot(data_structures.tf_tensor_list_new([3, [4, 5]]), None) + with self.assertRaises(ValueError): + data_structures.tf_tensor_list_new([3, 4], element_shape=(2,)) + with self.assertRaises(ValueError): + data_structures.tf_tensor_list_new([], element_shape=(2,)) + with self.assertRaises(ValueError): + data_structures.tf_tensor_list_new([], element_dtype=dtypes.float32) + + def test_tf_tensor_array_new(self): + l = data_structures.tf_tensor_array_new([3, 4, 5]) + t = l.stack() + with self.test_session() as sess: + self.assertAllEqual(sess.run(t), [3, 4, 5]) + + def test_tf_tensor_array_new_illegal_input(self): + with self.assertRaises(ValueError): + data_structures.tf_tensor_array_new([3, 4.0]) + with self.assertRaises(ValueError): + data_structures.tf_tensor_array_new([3, 4], element_dtype=dtypes.float32) + with self.assertRaises(ValueError): + data_structures.tf_tensor_array_new([3, [4, 5]]) + with self.assertRaises(ValueError): + data_structures.tf_tensor_array_new([3, 4], element_shape=(2,)) + with self.assertRaises(ValueError): + data_structures.tf_tensor_array_new([], element_shape=(2,)) + # TAs can infer the shape. + self.assertIsNot( + data_structures.tf_tensor_array_new([], element_dtype=dtypes.float32), + None) + def test_append_tensor_list(self): l = data_structures.new_list() x = constant_op.constant([1, 2, 3]) diff --git a/tensorflow/contrib/autograph/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD index 8f09689fe9b33bec03dc8b5370633c3a953fa322..ddadc6b96e8eb5417bfa1676ae304f7cbdedd92b 100644 --- a/tensorflow/contrib/autograph/pyct/BUILD +++ b/tensorflow/contrib/autograph/pyct/BUILD @@ -22,8 +22,10 @@ py_library( "__init__.py", "anno.py", "ast_util.py", + "cfg.py", "compiler.py", "inspect_utils.py", + "origin_info.py", "parser.py", "pretty_printer.py", "qual_names.py", @@ -63,6 +65,17 @@ py_test( ], ) +py_test( + name = "cfg_test", + srcs = ["cfg_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) + py_test( name = "compiler_test", srcs = ["compiler_test.py"], @@ -86,6 +99,16 @@ py_test( ], ) +py_test( + name = "origin_info_test", + srcs = ["origin_info_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "parser_test", srcs = ["parser_test.py"], diff --git a/tensorflow/contrib/autograph/pyct/anno.py b/tensorflow/contrib/autograph/pyct/anno.py index ae861627fd65cca057e7bf1af41424e605d4b7a1..1a52110ef36bbc0888e03cc25b3717822cb75c16 100644 --- a/tensorflow/contrib/autograph/pyct/anno.py +++ b/tensorflow/contrib/autograph/pyct/anno.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Handling annotations on AST nodes. +"""AST node annotation support. Adapted from Tangent. """ @@ -21,37 +21,90 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from enum import Enum +import enum +# pylint:disable=g-bad-import-order +import gast +# pylint:enable=g-bad-import-order -class NoValue(Enum): + +# TODO(mdan): Shorten the names. +# These names are heavily used, and anno.blaa +# TODO(mdan): Replace the attr-dict mechanism with a more typed solution. + + +class NoValue(enum.Enum): def __repr__(self): return self.name class Basic(NoValue): - """Container for annotation keys. + """Container for basic annotation keys. The enum values are used strictly for documentation purposes. """ - QN = 'Qualified name, as it appeared in the code.' + QN = 'Qualified name, as it appeared in the code. See qual_names.py.' SKIP_PROCESSING = ( 'This node should be preserved as is and not processed any further.') INDENT_BLOCK_REMAINDER = ( - 'When a node is annotated with this, the remainder of the block should ' - 'be indented below it. The annotation contains a tuple ' - '(new_body, name_map), where `new_body` is the new indented block and ' - '`name_map` allows renaming symbols.') + 'When a node is annotated with this, the remainder of the block should' + ' be indented below it. The annotation contains a tuple' + ' (new_body, name_map), where `new_body` is the new indented block and' + ' `name_map` allows renaming symbols.') + ORIGIN = ('Information about the source code that converted code originated' + ' from. See origin_information.py.') + + +class Static(NoValue): + """Container for static analysis annotation keys. + + The enum values are used strictly for documentation purposes. + """ + + # Deprecated - use reaching definitions instead. + # Symbols + # These flags are boolean. + IS_LOCAL = 'Symbol is local to the function scope being analyzed.' + IS_PARAM = 'Symbol is a parameter to the function being analyzed.' + + # Scopes + # Scopes are represented by objects of type activity.Scope. + SCOPE = 'The scope for the annotated node. See activity.py.' + # TODO(mdan): Drop these in favor of accessing the child's SCOPE. + ARGS_SCOPE = 'The scope for the argument list of a function call.' + COND_SCOPE = 'The scope for the test node of a conditional statement.' + BODY_SCOPE = ( + 'The scope for the main body of a statement (True branch for if ' + 'statements, main body for loops).') + ORELSE_SCOPE = ( + 'The scope for the orelse body of a statement (False branch for if ' + 'statements, orelse body for loops).') + + # Static analysis annotations. + DEFINITIONS = ( + 'Reaching definition information. See reaching_definitions.py.') + ORIG_DEFINITIONS = ( + 'The value of DEFINITIONS that applied to the original code before any' + ' conversion.') + DEFINED_VARS_IN = ( + 'Symbols defined when entering the node. See reaching_definitions.py.') + LIVE_VARS_OUT = ('Symbols live when exiting the node. See liveness.py.') FAIL = object() +def keys(node, field_name='___pyct_anno'): + if not hasattr(node, field_name): + return frozenset() + return frozenset(getattr(node, field_name).keys()) + + def getanno(node, key, default=FAIL, field_name='___pyct_anno'): - if (default is FAIL or - (hasattr(node, field_name) and (key in getattr(node, field_name)))): + if (default is FAIL or (hasattr(node, field_name) and + (key in getattr(node, field_name)))): return getattr(node, field_name)[key] else: return default @@ -86,3 +139,19 @@ def copyanno(from_node, to_node, key, field_name='___pyct_anno'): key, getanno(from_node, key, field_name=field_name), field_name=field_name) + + +def dup(node, copy_map, field_name='___pyct_anno'): + """Recursively copies annotations in an AST tree. + + Args: + node: ast.AST + copy_map: Dict[Hashable, Hashable], maps a source anno key to a destination + key. All annotations with the source key will be copied to identical + annotations with the destination key. + field_name: str + """ + for n in gast.walk(node): + for k in copy_map: + if hasanno(n, k, field_name): + setanno(n, copy_map[k], getanno(n, k, field_name), field_name) diff --git a/tensorflow/contrib/autograph/pyct/anno_test.py b/tensorflow/contrib/autograph/pyct/anno_test.py index f2c0c8cf05ca4b3671eb653ce56f6da61de54aee..5ef4da61a3627f9c0bc615ce5cb56052a28c64d1 100644 --- a/tensorflow/contrib/autograph/pyct/anno_test.py +++ b/tensorflow/contrib/autograph/pyct/anno_test.py @@ -32,22 +32,27 @@ class AnnoTest(test.TestCase): def test_basic(self): node = ast.Name() + self.assertEqual(anno.keys(node), set()) self.assertFalse(anno.hasanno(node, 'foo')) with self.assertRaises(AttributeError): anno.getanno(node, 'foo') anno.setanno(node, 'foo', 3) + + self.assertEqual(anno.keys(node), {'foo'}) self.assertTrue(anno.hasanno(node, 'foo')) self.assertEqual(anno.getanno(node, 'foo'), 3) self.assertEqual(anno.getanno(node, 'bar', default=7), 7) anno.delanno(node, 'foo') + + self.assertEqual(anno.keys(node), set()) self.assertFalse(anno.hasanno(node, 'foo')) with self.assertRaises(AttributeError): anno.getanno(node, 'foo') self.assertIsNone(anno.getanno(node, 'foo', default=None)) - def test_copyanno(self): + def test_copy(self): node_1 = ast.Name() anno.setanno(node_1, 'foo', 3) @@ -58,6 +63,22 @@ class AnnoTest(test.TestCase): self.assertTrue(anno.hasanno(node_2, 'foo')) self.assertFalse(anno.hasanno(node_2, 'bar')) + def test_duplicate(self): + node = ast.If( + test=ast.Num(1), + body=[ast.Expr(ast.Name('bar', ast.Load()))], + orelse=[]) + anno.setanno(node, 'spam', 1) + anno.setanno(node, 'ham', 1) + anno.setanno(node.body[0], 'ham', 1) + + anno.dup(node, {'spam': 'eggs'}) + + self.assertTrue(anno.hasanno(node, 'spam')) + self.assertTrue(anno.hasanno(node, 'ham')) + self.assertTrue(anno.hasanno(node, 'eggs')) + self.assertFalse(anno.hasanno(node.body[0], 'eggs')) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/pyct/ast_util.py b/tensorflow/contrib/autograph/pyct/ast_util.py index c4f82d11708393a6029d3f17be428b47eb9342ff..d7453b078197cd6f1c0521b861e96dd28d287cab 100644 --- a/tensorflow/contrib/autograph/pyct/ast_util.py +++ b/tensorflow/contrib/autograph/pyct/ast_util.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Copy an AST tree, discarding annotations.""" +"""AST manipulation utilities.""" from __future__ import absolute_import from __future__ import division @@ -26,47 +26,53 @@ from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import parser -class CleanCopier(gast.NodeVisitor): - """Copies AST nodes. +class CleanCopier(object): + """NodeTransformer-like visitor that copies an AST.""" - The copied nodes will ignore almost all fields that are prefixed by '__'. - Exceptions make some annotations. - """ + def __init__(self, preserve_annos): + super(CleanCopier, self).__init__() + self.preserve_annos = preserve_annos - # TODO(mdan): Parametrize which annotations get carried over. + def copy(self, node): + """Returns a deep copy of node (excluding some fields, see copy_clean).""" + + if isinstance(node, list): + return [self.copy(n) for n in node] + elif isinstance(node, tuple): + return tuple(self.copy(n) for n in node) + elif not isinstance(node, (gast.AST, ast.AST)): + # Assuming everything that's not an AST, list or tuple is a value type + # and may simply be assigned. + return node + + assert isinstance(node, (gast.AST, ast.AST)) - def generic_visit(self, node): new_fields = {} for f in node._fields: - if f.startswith('__'): - continue - if not hasattr(node, f): - continue - v = getattr(node, f) - if isinstance(v, list): - v = [self.generic_visit(n) for n in v] - elif isinstance(v, tuple): - v = tuple(self.generic_visit(n) for n in v) - elif isinstance(v, (gast.AST, ast.AST)): - v = self.generic_visit(v) - else: - # Assume everything else is a value type. - pass - new_fields[f] = v + if not f.startswith('__') and hasattr(node, f): + new_fields[f] = self.copy(getattr(node, f)) new_node = type(node)(**new_fields) - if anno.hasanno(node, anno.Basic.SKIP_PROCESSING): - anno.setanno(new_node, anno.Basic.SKIP_PROCESSING, True) + + if self.preserve_annos: + for k in self.preserve_annos: + anno.copyanno(node, new_node, k) return new_node -def copy_clean(node): - copier = CleanCopier() - if isinstance(node, list): - return [copier.visit(n) for n in node] - elif isinstance(node, tuple): - return tuple(copier.visit(n) for n in node) - else: - return copier.visit(node) +def copy_clean(node, preserve_annos=None): + """Creates a deep copy of an AST. + + The copy will not include fields that are prefixed by '__', with the + exception of user-specified annotations. + + Args: + node: ast.AST + preserve_annos: Optional[Set[Hashable]], annotation keys to include in the + copy + Returns: + ast.AST + """ + return CleanCopier(preserve_annos).copy(node) class SymbolRenamer(gast.NodeTransformer): @@ -78,7 +84,11 @@ class SymbolRenamer(gast.NodeTransformer): def _process(self, node): qn = anno.getanno(node, anno.Basic.QN) if qn in self.name_map: - return gast.Name(str(self.name_map[qn]), node.ctx, None) + new_node = gast.Name(str(self.name_map[qn]), node.ctx, None) + # All annotations get carried over. + for k in anno.keys(node): + anno.copyanno(node, new_node, k) + return new_node return self.generic_visit(node) def visit_Name(self, node): @@ -92,6 +102,7 @@ class SymbolRenamer(gast.NodeTransformer): def rename_symbols(node, name_map): + """Renames symbols in an AST. Requires qual_names annotations.""" renamer = SymbolRenamer(name_map) if isinstance(node, list): return [renamer.visit(n) for n in node] @@ -101,6 +112,7 @@ def rename_symbols(node, name_map): def keywords_to_dict(keywords): + """Converts a list of ast.keyword objects to a dict.""" keys = [] values = [] for kw in keywords: @@ -110,10 +122,7 @@ def keywords_to_dict(keywords): class PatternMatcher(gast.NodeVisitor): - """Matches a node against a pattern represented by a node. - - The pattern may contain wildcards represented by the symbol '_'. - """ + """Matches a node against a pattern represented by a node.""" def __init__(self, pattern): self.pattern = pattern @@ -177,9 +186,128 @@ class PatternMatcher(gast.NodeVisitor): def matches(node, pattern): + """Basic pattern matcher for AST. + + The pattern may contain wildcards represented by the symbol '_'. A node + matches a pattern if for every node in the tree, either there is a node of + the same type in pattern, or a Name node with id='_'. + + Args: + node: ast.AST + pattern: ast.AST + Returns: + bool + """ if isinstance(pattern, str): pattern = parser.parse_expression(pattern) matcher = PatternMatcher(pattern) matcher.visit(node) return matcher.matches + +# TODO(mdan): Once we have error tracing, we may be able to just go to SSA. +def apply_to_single_assignments(targets, values, apply_fn): + """Applies a function to each individual assignment. + + This function can process a possibly-unpacked (e.g. a, b = c, d) assignment. + It tries to break down the unpacking if possible. In effect, it has the same + effect as passing the assigned values in SSA form to apply_fn. + + Examples: + + The following will result in apply_fn(a, c), apply_fn(b, d): + + a, b = c, d + + The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]): + + a, b = c + + The following will result in apply_fn(a, (b, c)): + + a = b, c + + It uses the visitor pattern to allow subclasses to process single + assignments individually. + + Args: + targets: Union[List[ast.AST, ...], Tuple[ast.AST, ...], ast.AST, should be + used with the targets field of an ast.Assign node + values: ast.AST + apply_fn: Callable[[ast.AST, ast.AST], None], called with the + respective nodes of each single assignment + """ + if not isinstance(targets, (list, tuple)): + targets = (targets,) + for target in targets: + if isinstance(target, (gast.Tuple, gast.List)): + for i in range(len(target.elts)): + target_el = target.elts[i] + if isinstance(values, (gast.Tuple, gast.List)): + value_el = values.elts[i] + else: + idx = parser.parse_expression(str(i)) + value_el = gast.Subscript(values, gast.Index(idx), ctx=gast.Load()) + apply_to_single_assignments(target_el, value_el, apply_fn) + else: + apply_fn(target, values) + + +def parallel_walk(node, other): + """Walks two ASTs in parallel. + + The two trees must have identical structure. + + Args: + node: Union[ast.AST, Iterable[ast.AST]] + other: Union[ast.AST, Iterable[ast.AST]] + Yields: + Tuple[ast.AST, ast.AST] + Raises: + ValueError: if the two trees don't have identical structure. + """ + if isinstance(node, (list, tuple)): + node_stack = list(node) + else: + node_stack = [node] + + if isinstance(other, (list, tuple)): + other_stack = list(other) + else: + other_stack = [other] + + while node_stack and other_stack: + assert len(node_stack) == len(other_stack) + n = node_stack.pop() + o = other_stack.pop() + + if (not isinstance(n, (ast.AST, gast.AST)) or + not isinstance(o, (ast.AST, gast.AST)) or + n.__class__.__name__ != o.__class__.__name__): + raise ValueError('inconsistent nodes: {} and {}'.format(n, o)) + + yield n, o + + for f in n._fields: + n_child = getattr(n, f, None) + o_child = getattr(o, f, None) + if f.startswith('__') or n_child is None or o_child is None: + continue + + if isinstance(n_child, (list, tuple)): + if (not isinstance(o_child, (list, tuple)) or + len(n_child) != len(o_child)): + raise ValueError( + 'inconsistent values for field {}: {} and {}'.format( + f, n_child, o_child)) + node_stack.extend(n_child) + other_stack.extend(o_child) + + elif isinstance(n_child, (gast.AST, ast.AST)): + node_stack.append(n_child) + other_stack.append(o_child) + + elif n_child != o_child: + raise ValueError( + 'inconsistent values for field {}: {} and {}'.format( + f, n_child, o_child)) diff --git a/tensorflow/contrib/autograph/pyct/ast_util_test.py b/tensorflow/contrib/autograph/pyct/ast_util_test.py index 3afa04a50685d19c90944c14ed39f9d3ad35e486..2293c89720a54f7495670c6f28b00f716cad70db 100644 --- a/tensorflow/contrib/autograph/pyct/ast_util_test.py +++ b/tensorflow/contrib/autograph/pyct/ast_util_test.py @@ -19,7 +19,10 @@ from __future__ import division from __future__ import print_function import ast +import collections +import textwrap +from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import compiler from tensorflow.contrib.autograph.pyct import parser @@ -29,62 +32,75 @@ from tensorflow.python.platform import test class AstUtilTest(test.TestCase): - def test_rename_symbols(self): - node = ast.Tuple([ - ast.Name('a', ast.Load()), - ast.Name('b', ast.Load()), - ast.Attribute(ast.Name('b', None), 'c', ast.Store()), - ast.Attribute( - ast.Attribute(ast.Name('b', None), 'c', ast.Load()), 'd', None) - ], None) + def setUp(self): + super(AstUtilTest, self).setUp() + self._invocation_counts = collections.defaultdict(lambda: 0) + + def test_rename_symbols_basic(self): + node = parser.parse_str('a + b') node = qual_names.resolve(node) + node = ast_util.rename_symbols( - node, { - qual_names.QN('a'): - qual_names.QN('renamed_a'), - qual_names.QN(qual_names.QN('b'), attr='c'): - qual_names.QN('renamed_b_c'), - }) - - self.assertEqual(node.elts[0].id, 'renamed_a') - self.assertTrue(isinstance(node.elts[0].ctx, ast.Load)) - self.assertEqual(node.elts[1].id, 'b') - self.assertEqual(node.elts[2].id, 'renamed_b_c') - self.assertTrue(isinstance(node.elts[2].ctx, ast.Store)) - self.assertEqual(node.elts[3].value.id, 'renamed_b_c') - self.assertTrue(isinstance(node.elts[3].value.ctx, ast.Load)) + node, {qual_names.QN('a'): qual_names.QN('renamed_a')}) + + self.assertIsInstance(node.body[0].value.left.id, str) + source = compiler.ast_to_source(node) + self.assertEqual(source.strip(), 'renamed_a + b') + + def test_rename_symbols_attributes(self): + node = parser.parse_str('b.c = b.c.d') + node = qual_names.resolve(node) + + node = ast_util.rename_symbols( + node, {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}) + + source = compiler.ast_to_source(node) + self.assertEqual(source.strip(), 'renamed_b_c = renamed_b_c.d') + + def test_rename_symbols_annotations(self): + node = parser.parse_str('a[i]') + node = qual_names.resolve(node) + anno.setanno(node, 'foo', 'bar') + orig_anno = anno.getanno(node, 'foo') + + node = ast_util.rename_symbols(node, + {qual_names.QN('a'): qual_names.QN('b')}) + + self.assertIs(anno.getanno(node, 'foo'), orig_anno) def test_copy_clean(self): - ret = ast.Return( - ast.BinOp( - op=ast.Add(), - left=ast.Name(id='a', ctx=ast.Load()), - right=ast.Num(1))) - setattr(ret, '__foo', 'bar') - node = ast.FunctionDef( - name='f', - args=ast.arguments( - args=[ast.Name(id='a', ctx=ast.Param())], - vararg=None, - kwarg=None, - defaults=[]), - body=[ret], - decorator_list=[], - returns=None) + node = parser.parse_str( + textwrap.dedent(""" + def f(a): + return a + 1 + """)) + setattr(node.body[0], '__foo', 'bar') new_node = ast_util.copy_clean(node) - self.assertFalse(node is new_node) - self.assertFalse(ret is new_node.body[0]) + self.assertIsNot(new_node, node) + self.assertIsNot(new_node.body[0], node.body[0]) self.assertFalse(hasattr(new_node.body[0], '__foo')) + def test_copy_clean_preserves_annotations(self): + node = parser.parse_str( + textwrap.dedent(""" + def f(a): + return a + 1 + """)) + anno.setanno(node.body[0], 'foo', 'bar') + anno.setanno(node.body[0], 'baz', 1) + new_node = ast_util.copy_clean(node, preserve_annos={'foo'}) + self.assertEqual(anno.getanno(new_node.body[0], 'foo'), 'bar') + self.assertFalse(anno.hasanno(new_node.body[0], 'baz')) + def test_keywords_to_dict(self): keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords d = ast_util.keywords_to_dict(keywords) # Make sure we generate a usable dict node by attaching it to a variable and # compiling everything. - output = parser.parse_str('b = 3') - output.body += (ast.Assign([ast.Name(id='d', ctx=ast.Store())], d),) - result, _ = compiler.ast_to_object(output) - self.assertDictEqual(result.d, {'a': 3, 'c': 1, 'd': 'e'}) + node = parser.parse_str('def f(b): pass').body[0] + node.body.append(ast.Return(d)) + result, _ = compiler.ast_to_object(node) + self.assertDictEqual(result.f(3), {'a': 3, 'c': 1, 'd': 'e'}) def assertMatch(self, target_str, pattern_str): node = parser.parse_expression(target_str) @@ -113,6 +129,68 @@ class AstUtilTest(test.TestCase): self.assertNoMatch('super(Foo, self).__init__()', 'super(Bar, _).__init__(_)') + def _mock_apply_fn(self, target, source): + target = compiler.ast_to_source(target) + source = compiler.ast_to_source(source) + self._invocation_counts[(target.strip(), source.strip())] += 1 + + def test_apply_to_single_assignments_dynamic_unpack(self): + node = parser.parse_str('a, b, c = d') + node = node.body[0] + ast_util.apply_to_single_assignments(node.targets, node.value, + self._mock_apply_fn) + self.assertDictEqual(self._invocation_counts, { + ('a', 'd[0]'): 1, + ('b', 'd[1]'): 1, + ('c', 'd[2]'): 1, + }) + + def test_apply_to_single_assignments_static_unpack(self): + node = parser.parse_str('a, b, c = d, e, f') + node = node.body[0] + ast_util.apply_to_single_assignments(node.targets, node.value, + self._mock_apply_fn) + self.assertDictEqual(self._invocation_counts, { + ('a', 'd'): 1, + ('b', 'e'): 1, + ('c', 'f'): 1, + }) + + def test_parallel_walk(self): + node = parser.parse_str( + textwrap.dedent(""" + def f(a): + return a + 1 + """)) + for child_a, child_b in ast_util.parallel_walk(node, node): + self.assertEqual(child_a, child_b) + + def test_parallel_walk_inconsistent_trees(self): + node_1 = parser.parse_str( + textwrap.dedent(""" + def f(a): + return a + 1 + """)) + node_2 = parser.parse_str( + textwrap.dedent(""" + def f(a): + return a + (a * 2) + """)) + node_3 = parser.parse_str( + textwrap.dedent(""" + def f(a): + return a + 2 + """)) + with self.assertRaises(ValueError): + for _ in ast_util.parallel_walk(node_1, node_2): + pass + # There is not particular reason to reject trees that differ only in the + # value of a constant. + # TODO(mdan): This should probably be allowed. + with self.assertRaises(ValueError): + for _ in ast_util.parallel_walk(node_1, node_3): + pass + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/pyct/cfg.py b/tensorflow/contrib/autograph/pyct/cfg.py new file mode 100644 index 0000000000000000000000000000000000000000..ba51dcf285036220e01b89e8beeb9aec8ffe36be --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/cfg.py @@ -0,0 +1,815 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Control flow graph (CFG) structure for Python AST representation. + +The CFG is a digraph with edges representing valid control flow. Each +node is associated with exactly one AST node, but not all AST nodes may have +a corresponding CFG counterpart. + +Once built, the CFG itself is immutable, but the values it holds need not be; +they are usually annotated with information extracted by walking the graph. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +from enum import Enum + +# pylint:disable=g-bad-import-order +import gast +# pylint:enable=g-bad-import-order + +from tensorflow.contrib.autograph.pyct import compiler + + +class Node(object): + """A node in the CFG. + + Although new instances of this class are mutable, the objects that a user + finds in the CFG are typically not. + + The nodes represent edges in the CFG graph, and maintain pointers to allow + efficient walking in both forward and reverse order. The following property + holds for all nodes: "child in node.next" iff "node in child.prev". + + Attributes: + next: FrozenSet[Node, ...], the nodes that follow this node, in control + flow order + prev: FrozenSet[Node, ...], the nodes that precede this node, in reverse + control flow order + ast_node: ast.AST, the AST node corresponding to this CFG node + """ + + def __init__(self, next_, prev, ast_node): + self.next = next_ + self.prev = prev + self.ast_node = ast_node + + def freeze(self): + self.next = frozenset(self.next) + self.prev = frozenset(self.prev) + + def __repr__(self): + if isinstance(self.ast_node, gast.FunctionDef): + return 'def %s' % self.ast_node.name + elif isinstance(self.ast_node, gast.withitem): + return compiler.ast_to_source(self.ast_node.context_expr).strip() + return compiler.ast_to_source(self.ast_node).strip() + + +class Graph( + collections.namedtuple( + 'Graph', + ['entry', 'exit', 'error', 'index', 'stmt_prev', 'stmt_next'])): + """A Control Flow Graph. + + The CFG maintains an index to allow looking up a CFG node by the AST node to + which it is associated. The index can also be enumerated in top-down, depth + first order. + + Walking the graph in forward or reverse order is supported by double + parent-child links. + + Note: the error nodes are not wired to their corresponding finally guards, + because these are shared, and wiring them would create a reverse path from + normal control flow into the error nodes, which we want to avoid. + + The graph also maintains edges corresponding to higher level statements + like for-else loops. A node is considered successor of a statement if there + is an edge from a node that is lexically a child of that statement to a node + that is not. Statement predecessors are analogously defined. + + Attributes: + entry: Node, the entry node + exit: FrozenSet[Node, ...], the exit nodes + error: FrozenSet[Node, ...], nodes that exit due to an explicitly raised + error (errors propagated from function calls are not accounted) + index: Dict[ast.Node, Node], mapping AST nodes to the respective CFG + node + stmt_prev: Dict[ast.Node, FrozenSet[Node, ...]], mapping statement AST + nodes to their predecessor CFG nodes + stmt_next: Dict[ast.Node, FrozenSet[Node, ...]], mapping statement AST + nodes to their successor CFG nodes + """ + + def __repr__(self): + result = 'digraph CFG {\n' + for node in self.index.values(): + result += ' %s [label="%s"];\n' % (id(node), node) + for node in self.index.values(): + for next_ in node.next: + result += ' %s -> %s;\n' % (id(node), id(next_)) + result += '}' + return result + + +class _WalkMode(Enum): + FORWARD = 1 + REVERSE = 2 + + +# TODO(mdan): Rename to DataFlowAnalyzer. +# TODO(mdan): Consider specializations that use gen/kill/transfer abstractions. +class GraphVisitor(object): + """Base class for a CFG visitors. + + This implementation is not thread safe. + + The visitor has some facilities to simplify dataflow analyses. In particular, + it allows revisiting the nodes at the decision of the subclass. This can be + used to visit the graph until the state reaches a fixed point. + + For more details on dataflow analysis, see + https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec02-Dataflow.pdf + + Note: the literature generally suggests visiting successor nodes only when the + state of the current node changed, regardless of whether that successor has + ever been visited. This implementation visits every successor at least once. + + Attributes: + graph: Graph + in_: Dict[Node, Any], stores node-keyed state during a visit + out: Dict[Node, Any], stores node-keyed state during a visit + """ + + def __init__(self, graph): + self.graph = graph + self.reset() + + def init_state(self, node): + """State initialization function. Optional to overload. + + An in/out state slot will be created for each node in the graph. Subclasses + must overload this to control what that is initialized to. + + Args: + node: Node + """ + raise NotImplementedError('Subclasses must implement this.') + + # TODO(mdan): Rename to flow? + def visit_node(self, node): + """Visitor function. + + Args: + node: Node + Returns: + bool, whether the node should be revisited; subclasses can visit every + reachable node exactly once by always returning False + """ + raise NotImplementedError('Subclasses must implement this.') + + def reset(self): + self.in_ = { + node: self.init_state(node) for node in self.graph.index.values() + } + self.out = { + node: self.init_state(node) for node in self.graph.index.values() + } + + def _visit_internal(self, mode): + """Visits the CFG, depth-first.""" + assert mode in (_WalkMode.FORWARD, _WalkMode.REVERSE) + if mode == _WalkMode.FORWARD: + open_ = [self.graph.entry] + elif mode == _WalkMode.REVERSE: + open_ = list(self.graph.exit) + closed = set() + + while open_: + node = open_.pop(0) + closed.add(node) + + should_revisit = self.visit_node(node) + + if mode == _WalkMode.FORWARD: + children = node.next + elif mode == _WalkMode.REVERSE: + children = node.prev + + for next_ in children: + if should_revisit or next_ not in closed: + open_.append(next_) + + def visit_forward(self): + self._visit_internal(_WalkMode.FORWARD) + + def visit_reverse(self): + self._visit_internal(_WalkMode.REVERSE) + + +class GraphBuilder(object): + """Builder that constructs a CFG from a given AST. + + This GraphBuilder facilitates constructing the DAG that forms the CFG when + nodes + are supplied in lexical order (i.e., top-down, depth first). Under these + conditions, it supports building patterns found in typical structured + programs. + + This builder ignores the flow generated by exceptions, which are assumed to + always be catastrophic and present purely for diagnostic purposes (e.g. to + print debug information). Statements like raise and try/catch sections are + allowed and will generate control flow edges, but ordinaty statements are + assumed not to raise exceptions. + + Finally sections are also correctly interleaved between break/continue/return + nodes and their subsequent statements. + + Important concepts: + * nodes - nodes refer refer to CFG nodes; AST nodes are qualified explicitly + * leaf set - since the graph is constructed gradually, a leaf set maintains + the CFG nodes that will precede the node that the builder expects to + receive next; when an ordinary node is added, it is connected to the + existing leaves and it in turn becomes the new leaf + * jump nodes - nodes that should generate edges other than what + ordinary nodes would; these correspond to break, continue and return + statements + * sections - logical delimiters for subgraphs that require special + edges; there are various types of nodes, each admitting various + types of jump nodes; sections are identified by their corresponding AST + node + """ + + # TODO(mdan): Perhaps detail this in a markdown doc. + # TODO(mdan): Add exception support. + + def __init__(self, parent_ast_node): + self.reset() + self.parent = parent_ast_node + + def reset(self): + """Resets the state of this factory.""" + self.head = None + self.errors = set() + self.node_index = collections.OrderedDict() + + # TODO(mdan): Too many primitives. Use classes. + self.leaves = set() + + # Note: This mechanism requires that nodes are added in lexical order (top + # to bottom, depth first). + self.active_stmts = set() + self.owners = {} # type: Set[any] + self.forward_edges = set() # type: Tuple[Node, Node] # (from, to) + + self.finally_sections = {} + # Dict values represent (entry, exits) + self.finally_section_subgraphs = { + } # type: Dict[ast.AST, Tuple[Node, Set[Node]]] + # Whether the guard section can be reached from the statement that precedes + # it. + self.finally_section_has_direct_flow = {} + # Finally sections that await their first node. + self.pending_finally_sections = set() + + # Exit jumps keyed by the section they affect. + self.exits = {} + + # The entry of loop sections, keyed by the section. + self.section_entry = {} + # Continue jumps keyed by the section they affect. + self.continues = {} + + # The entry of conditional sections, keyed by the section. + self.cond_entry = {} + # Lists of leaf nodes corresponding to each branch in the section. + self.cond_leaves = {} + + def _connect_nodes(self, first, second): + """Connects nodes to signify that control flows from first to second. + + Args: + first: Union[Set[Node, ...], Node] + second: Node + """ + if isinstance(first, Node): + first.next.add(second) + second.prev.add(first) + self.forward_edges.add((first, second)) + else: + for node in first: + self._connect_nodes(node, second) + + def _add_new_node(self, ast_node): + """Grows the graph by adding a CFG node following the current leaves.""" + if ast_node is self.node_index: + raise ValueError('%s added twice' % ast_node) + node = Node(next_=set(), prev=set(), ast_node=ast_node) + self.node_index[ast_node] = node + self.owners[node] = frozenset(self.active_stmts) + + if self.head is None: + self.head = node + + for leaf in self.leaves: + self._connect_nodes(leaf, node) + + # If any finally section awaits its first node, populate it. + for section_id in self.pending_finally_sections: + self.finally_section_subgraphs[section_id][0] = node + self.pending_finally_sections = set() + + return node + + def begin_statement(self, stmt): + """Marks the beginning of a statement. + + Args: + stmt: Hashable, a key by which the statement can be identified in + the CFG's stmt_prev and stmt_next attributes + """ + self.active_stmts.add(stmt) + + def end_statement(self, stmt): + """Marks the end of a statement. + + Args: + stmt: Hashable, a key by which the statement can be identified in + the CFG's stmt_prev and stmt_next attributes; must match a key + previously passed to begin_statement. + """ + self.active_stmts.remove(stmt) + + def add_ordinary_node(self, ast_node): + """Grows the graph by adding an ordinary CFG node. + + Ordinary nodes are followed by the next node, in lexical order, that is, + they become the new leaf set. + + Args: + ast_node: ast.AST + Returns: + Node + """ + node = self._add_new_node(ast_node) + self.leaves = set((node,)) + return node + + def _add_jump_node(self, ast_node, guards): + """Grows the graph by adding a jump node. + + Jump nodes are added to the current leaf set, and the leaf set becomes + empty. If the jump node is the last in a cond section, then it may be added + back to the leaf set by a separate mechanism. + + Args: + ast_node: ast.AST + guards: Tuple[ast.AST, ...], the finally sections active for this node + Returns: + Node + """ + node = self._add_new_node(ast_node) + self.leaves = set() + # The guards themselves may not yet be complete, and will be wired later. + self.finally_sections[node] = guards + return node + + def _connect_jump_to_finally_sections(self, node): + """Connects a jump node to the finally sections protecting it.""" + cursor = set((node,)) + for guard_section_id in self.finally_sections[node]: + guard_begin, guard_ends = self.finally_section_subgraphs[guard_section_id] + self._connect_nodes(cursor, guard_begin) + cursor = guard_ends + del self.finally_sections[node] + # TODO(mdan): Should garbage-collect finally_section_subgraphs. + return cursor + + def add_exit_node(self, ast_node, section_id, guards): + """Grows the graph by adding an exit node. + + This node becomes an exit for the current section. + + Args: + ast_node: ast.AST + section_id: Hashable, the node for which ast_node should be considered + to be an exit node + guards: Tuple[ast.AST, ...], the finally sections that guard ast_node + """ + node = self._add_jump_node(ast_node, guards) + self.exits[section_id].add(node) + + def add_continue_node(self, ast_node, section_id, guards): + """Grows the graph by adding a reentry node. + + This node causes control flow to go back to the loop section's entry. + + Args: + ast_node: ast.AST + section_id: Hashable, the node for which ast_node should be considered + to be an exit node + guards: Tuple[ast.AST, ...], the finally sections that guard ast_node + """ + node = self._add_jump_node(ast_node, guards) + self.continues[section_id].add(node) + + def add_error_node(self, ast_node, guards): + """Grows the graph by adding an error node. + + This node becomes an exit for the entire graph. + + Args: + ast_node: ast.AST + guards: Tuple[ast.AST, ...], the finally sections that guard ast_node + """ + node = self._add_jump_node(ast_node, guards) + self.errors.add(node) + self.leaves = set() + + def enter_section(self, section_id): + """Enters a regular section. + + Regular sections admit exit jumps, which end the section. + + Args: + section_id: Hashable, the same node that will be used in calls to the + ast_node arg passed to add_exit_node + """ + assert section_id not in self.exits + self.exits[section_id] = set() + + def exit_section(self, section_id): + """Exits a regular section.""" + + # Exits are jump nodes, which may be protected. + for exit_ in self.exits[section_id]: + self.leaves |= self._connect_jump_to_finally_sections(exit_) + + del self.exits[section_id] + + def enter_loop_section(self, section_id, entry_node): + """Enters a loop section. + + Loop sections define an entry node. The end of the section always flows back + to the entry node. These admit continue jump nodes which also flow to the + entry node. + + Args: + section_id: Hashable, the same node that will be used in calls to the + ast_node arg passed to add_continue_node + entry_node: ast.AST, the entry node into the loop (e.g. the test node + for while loops) + """ + assert section_id not in self.section_entry + assert section_id not in self.continues + self.continues[section_id] = set() + node = self.add_ordinary_node(entry_node) + self.section_entry[section_id] = node + + def exit_loop_section(self, section_id): + """Exits a loop section.""" + self._connect_nodes(self.leaves, self.section_entry[section_id]) + + # continues are jump nodes, which may be protected. + for reentry in self.continues[section_id]: + guard_ends = self._connect_jump_to_finally_sections(reentry) + self._connect_nodes(guard_ends, self.section_entry[section_id]) + + # Loop nodes always loop back. + self.leaves = set((self.section_entry[section_id],)) + + del self.continues[section_id] + del self.section_entry[section_id] + + def enter_cond_section(self, section_id): + """Enters a conditional section. + + Conditional sections define an entry node, and one or more branches. + + Args: + section_id: Hashable, the same node that will be used in calls to the + section_id arg passed to new_cond_branch + """ + + assert section_id not in self.cond_entry + assert section_id not in self.cond_leaves + self.cond_leaves[section_id] = [] + + def new_cond_branch(self, section_id): + """Begins a new branch in a cond section.""" + assert section_id in self.cond_leaves + + if section_id in self.cond_entry: + # Subsequent splits move back to the split point, and memorize the + # current leaves. + self.cond_leaves[section_id].append(self.leaves) + self.leaves = self.cond_entry[section_id] + else: + # If this is the first time we split a section, just remember the split + # point. + self.cond_entry[section_id] = self.leaves + + def exit_cond_section(self, section_id): + """Exits a conditional section.""" + for split in self.cond_leaves[section_id]: + self.leaves |= split + del self.cond_entry[section_id] + del self.cond_leaves[section_id] + + def enter_finally_section(self, section_id): + """Enters a finally section.""" + # TODO(mdan): This, not the caller, should track the active sections. + self.finally_section_subgraphs[section_id] = [None, None] + if self.leaves: + self.finally_section_has_direct_flow[section_id] = True + else: + self.finally_section_has_direct_flow[section_id] = False + self.pending_finally_sections.add(section_id) + + def exit_finally_section(self, section_id): + """Exits a finally section.""" + assert section_id not in self.pending_finally_sections, 'Empty finally?' + self.finally_section_subgraphs[section_id][1] = self.leaves + # If the guard can only be reached by a jump, then it will not flow + # into the statement that follows it. + if not self.finally_section_has_direct_flow[section_id]: + self.leaves = set() + del self.finally_section_has_direct_flow[section_id] + + def build(self): + """Returns the CFG accumulated so far and resets the builder. + + Returns: + Graph + """ + # Freeze the nodes. + for node in self.node_index.values(): + node.freeze() + + # Build the statement edges. + stmt_next = {} + stmt_prev = {} + for node, _ in self.forward_edges: + for stmt in self.owners[node]: + if stmt not in stmt_next: + stmt_next[stmt] = set() + if stmt not in stmt_prev: + stmt_prev[stmt] = set() + for first, second in self.forward_edges: + stmts_exited = self.owners[first] - self.owners[second] + for stmt in stmts_exited: + stmt_next[stmt].add(second) + stmts_entered = self.owners[second] - self.owners[first] + for stmt in stmts_entered: + stmt_prev[stmt].add(first) + for stmt in stmt_next: + stmt_next[stmt] = frozenset(stmt_next[stmt]) + for stmt in stmt_prev: + stmt_prev[stmt] = frozenset(stmt_prev[stmt]) + + # Construct the final graph object. + result = Graph( + entry=self.head, + exit=self.leaves, + error=self.errors, + index=self.node_index, + stmt_prev=stmt_prev, + stmt_next=stmt_next) + + # Reset the state. + self.reset() + + return result + + +class AstToCfg(gast.NodeVisitor): + """Converts an AST to CFGs. + + A separate CFG will be constructed for each function. + """ + + def __init__(self): + super(AstToCfg, self).__init__() + + self.builder_stack = [] + self.builder = None + self.cfgs = {} + + self.lexical_scopes = [] + + def _enter_lexical_scope(self, node): + self.lexical_scopes.append(node) + + def _exit_lexical_scope(self, node): + leaving_node = self.lexical_scopes.pop() + assert node == leaving_node + + def _get_enclosing_scopes(self, include, stop_at): + included = [] + for node in reversed(self.lexical_scopes): + if isinstance(node, include): + included.append(node) + if isinstance(node, stop_at): + return node, included + return None, included + + def _process_basic_statement(self, node): + self.generic_visit(node) + self.builder.add_ordinary_node(node) + + def _process_exit_statement(self, node, *exits_nodes_of_type): + # Note: this is safe because we process functions separately. + try_node, guards = self._get_enclosing_scopes( + include=(gast.Try,), + stop_at=tuple(exits_nodes_of_type), + ) + if try_node is None: + raise ValueError( + '%s that is not enclosed by any of %s' % (node, exits_nodes_of_type)) + self.builder.add_exit_node(node, try_node, guards) + + def _process_continue_statement(self, node, *loops_to_nodes_of_type): + # Note: this is safe because we process functions separately. + try_node, guards = self._get_enclosing_scopes( + include=(gast.Try,), + stop_at=tuple(loops_to_nodes_of_type), + ) + if try_node is None: + raise ValueError('%s that is not enclosed by any of %s' % + (node, loops_to_nodes_of_type)) + self.builder.add_continue_node(node, try_node, guards) + + def visit_FunctionDef(self, node): + # We also keep the FunctionDef node in the CFG. This allows us to determine + # things like reaching definitions via closure. Note that the function body + # will be stored in a separate graph, because function definitions are not + # the same as function calls. + if self.builder is not None: + self.builder.add_ordinary_node(node) + + self.builder_stack.append(self.builder) + self.builder = GraphBuilder(node) + + self._enter_lexical_scope(node) + self.builder.enter_section(node) + + self._process_basic_statement(node.args) + for stmt in node.body: + self.visit(stmt) + + self.builder.exit_section(node) + self._exit_lexical_scope(node) + + self.cfgs[node] = self.builder.build() + self.builder = self.builder_stack.pop() + + def visit_Lambda(self, node): + # TODO(mdan): Treat like FunctionDef? That would be a separate CFG. + raise NotImplementedError() + + def visit_Return(self, node): + self._process_exit_statement(node, gast.FunctionDef) + + def visit_Expr(self, node): + self._process_basic_statement(node) + + def visit_Assign(self, node): + self._process_basic_statement(node) + + def visit_AnnAssign(self, node): + self._process_basic_statement(node) + + def visit_AugAssign(self, node): + self._process_basic_statement(node) + + def visit_Print(self, node): + self._process_basic_statement(node) + + def visit_Raise(self, node): + try_node, guards = self._get_enclosing_scopes( + include=(gast.Try,), + stop_at=(gast.FunctionDef,), + ) + if try_node is None: + raise ValueError('%s that is not enclosed by any FunctionDef' % node) + self.builder.add_error_node(node, guards) + + def visit_Assert(self, node): + # Ignoring the effect of exceptions. + self._process_basic_statement(node) + + def visit_Delete(self, node): + self._process_basic_statement(node) + + def visit_If(self, node): + # No need to track ifs as lexical scopes, for now. + # Lexical scopes are generally tracked in order to be able to resolve the + # targets of jump statements like break/continue/etc. Since there is no + # statement that can interrupt a conditional, we don't need to track their + # lexical scope. That may change in the future. + self.builder.begin_statement(node) + + self.builder.enter_cond_section(node) + self._process_basic_statement(node.test) + + self.builder.new_cond_branch(node) + for stmt in node.body: + self.visit(stmt) + + self.builder.new_cond_branch(node) + for stmt in node.orelse: + self.visit(stmt) + + self.builder.exit_cond_section(node) + self.builder.end_statement(node) + + def visit_While(self, node): + self.builder.begin_statement(node) + self._enter_lexical_scope(node) + + self.builder.enter_section(node) + + self.builder.enter_loop_section(node, node.test) + for stmt in node.body: + self.visit(stmt) + self.builder.exit_loop_section(node) + + # Note: although the orelse is technically part of the loop node, + # the statements inside it don't affect the loop itself. For example, a + # break in the loop's orelse will not affect the loop itself. + self._exit_lexical_scope(node) + + for stmt in node.orelse: + self.visit(stmt) + + self.builder.exit_section(node) + self.builder.end_statement(node) + + def visit_For(self, node): + self.builder.begin_statement(node) + self._enter_lexical_scope(node) + + self.builder.enter_section(node) + + # TODO(mdan): Strictly speaking, this should be node.target + node.iter. + # A blind dataflow analysis would have to process both node.target and + # node.iter to properly process read and write access. + self.builder.enter_loop_section(node, node.iter) + for stmt in node.body: + self.visit(stmt) + self.builder.exit_loop_section(node) + + # Note: although the orelse is technically part of the loop node, + # they don't count as loop bodies. For example, a break in the loop's + # orelse will affect the parent loop, not the current one. + self._exit_lexical_scope(node) + + for stmt in node.orelse: + self.visit(stmt) + + self.builder.exit_section(node) + self.builder.end_statement(node) + + def visit_Break(self, node): + self._process_exit_statement(node, gast.While, gast.For) + + def visit_Continue(self, node): + self._process_continue_statement(node, gast.While, gast.For) + + def visit_Try(self, node): + self._enter_lexical_scope(node) + + for stmt in node.body: + self.visit(stmt) + # Unlike loops, the orelse is a simple continuation of the body. + for stmt in node.orelse: + self.visit(stmt) + + if node.handlers: + # TODO(mdan): Should we still support bare try/except? Might be confusing. + raise NotImplementedError('exceptions are not yet supported') + + self._exit_lexical_scope(node) + + self.builder.enter_finally_section(node) + for stmt in node.finalbody: + self.visit(stmt) + self.builder.exit_finally_section(node) + + def visit_With(self, node): + # TODO(mdan): Mark the context manager's exit call as exit guard. + for item in node.items: + self._process_basic_statement(item) + for stmt in node.body: + self.visit(stmt) + + +def build(node): + visitor = AstToCfg() + visitor.visit(node) + return visitor.cfgs diff --git a/tensorflow/contrib/autograph/pyct/cfg_test.py b/tensorflow/contrib/autograph/pyct/cfg_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9d0a85d615cc5a7dcebf405aebdbfe409be0b5cf --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/cfg_test.py @@ -0,0 +1,969 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for cfg module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import cfg +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.python.platform import test + + +class CountingVisitor(cfg.GraphVisitor): + + def __init__(self, graph): + super(CountingVisitor, self).__init__(graph) + self.counts = {} + + def init_state(self, _): + return None + + def visit_node(self, node): + self.counts[node.ast_node] = self.counts.get(node.ast_node, 0) + 1 + return False # visit only once + + +class GraphVisitorTest(test.TestCase): + + def _build_cfg(self, fn): + node, _ = parser.parse_entity(fn) + cfgs = cfg.build(node) + return cfgs, node + + def test_basic_coverage_forward(self): + + def test_fn(a): + while a > 0: + a = 1 + break + return a # pylint:disable=unreachable + a = 2 + + graphs, node = self._build_cfg(test_fn) + graph, = graphs.values() + visitor = CountingVisitor(graph) + visitor.visit_forward() + fn_node = node.body[0] + + self.assertEqual(visitor.counts[fn_node.args], 1) + self.assertEqual(visitor.counts[fn_node.body[0].test], 1) + self.assertEqual(visitor.counts[fn_node.body[0].body[0]], 1) + self.assertEqual(visitor.counts[fn_node.body[0].body[1]], 1) + # The return node should be unreachable in forward direction. + self.assertTrue(fn_node.body[0].body[2] not in visitor.counts) + self.assertEqual(visitor.counts[fn_node.body[1]], 1) + + def test_basic_coverage_reverse(self): + + def test_fn(a): + while a > 0: + a = 1 + break + return a # pylint:disable=unreachable + a = 2 + + graphs, node = self._build_cfg(test_fn) + graph, = graphs.values() + visitor = CountingVisitor(graph) + visitor.visit_reverse() + fn_node = node.body[0] + + self.assertEqual(visitor.counts[fn_node.args], 1) + self.assertEqual(visitor.counts[fn_node.body[0].test], 1) + self.assertEqual(visitor.counts[fn_node.body[0].body[0]], 1) + self.assertEqual(visitor.counts[fn_node.body[0].body[1]], 1) + self.assertTrue(visitor.counts[fn_node.body[0].body[2]], 1) + self.assertEqual(visitor.counts[fn_node.body[1]], 1) + + +class AstToCfgTest(test.TestCase): + + def _build_cfg(self, fn): + node, _ = parser.parse_entity(fn) + cfgs = cfg.build(node) + return cfgs + + def _repr_set(self, node_set): + return frozenset(repr(n) for n in node_set) + + def _as_set(self, elements): + if elements is None: + return frozenset() + elif isinstance(elements, str): + return frozenset((elements,)) + else: + return frozenset(elements) + + def assertGraphMatches(self, graph, edges): + """Tests whether the CFG contains the specified edges.""" + for prev, node_repr, next_ in edges: + matched = False + for cfg_node in graph.index.values(): + if repr(cfg_node) == node_repr: + if (self._as_set(prev) == frozenset(map(repr, cfg_node.prev)) and + self._as_set(next_) == frozenset(map(repr, cfg_node.next))): + matched = True + break + if not matched: + self.fail( + 'match failed for node "%s" in graph:\n%s' % (node_repr, graph)) + + def assertStatementEdges(self, graph, edges): + """Tests whether the CFG contains the specified statement edges.""" + for prev_node_reprs, node_repr, next_node_reprs in edges: + matched = False + partial_matches = [] + self.assertSetEqual( + frozenset(graph.stmt_next.keys()), frozenset(graph.stmt_prev.keys())) + for stmt_ast_node in graph.stmt_next: + ast_repr = '%s:%s' % (stmt_ast_node.__class__.__name__, + stmt_ast_node.lineno) + if ast_repr == node_repr: + actual_next = frozenset(map(repr, graph.stmt_next[stmt_ast_node])) + actual_prev = frozenset(map(repr, graph.stmt_prev[stmt_ast_node])) + partial_matches.append((actual_prev, node_repr, actual_next)) + if (self._as_set(prev_node_reprs) == actual_prev and + self._as_set(next_node_reprs) == actual_next): + matched = True + break + if not matched: + self.fail('edges mismatch for %s: %s' % (node_repr, partial_matches)) + + def test_straightline(self): + + def test_fn(a): + a += 1 + a = 2 + a = 3 + return + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', 'a += 1'), + ('a += 1', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', 'return'), + ('a = 3', 'return', None), + ), + ) + + def test_straightline_no_return(self): + + def test_fn(a, b): + a = b + 1 + a += max(a) + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a, b', 'a = b + 1'), + ('a = b + 1', 'a += max(a)', None), + ), + ) + + def test_unreachable_code(self): + + def test_fn(a): + return + a += 1 # pylint:disable=unreachable + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', 'return'), + ('a', 'return', None), + (None, 'a += 1', None), + ), + ) + + def test_if_straightline(self): + + def test_fn(a): + if a > 0: + a = 1 + else: + a += -1 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', '(a > 0)'), + ('(a > 0)', 'a = 1', None), + ('(a > 0)', 'a += -1', None), + ), + ) + self.assertStatementEdges( + graph, + (('a', 'If:2', None),), + ) + + def test_branch_nested(self): + + def test_fn(a): + if a > 0: + if a > 1: + a = 1 + else: + a = 2 + else: + if a > 2: + a = 3 + else: + a = 4 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', '(a > 0)'), + ('a', '(a > 0)', ('(a > 1)', '(a > 2)')), + ('(a > 0)', '(a > 1)', ('a = 1', 'a = 2')), + ('(a > 1)', 'a = 1', None), + ('(a > 1)', 'a = 2', None), + ('(a > 0)', '(a > 2)', ('a = 3', 'a = 4')), + ('(a > 2)', 'a = 3', None), + ('(a > 2)', 'a = 4', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'If:2', None), + ('(a > 0)', 'If:3', None), + ('(a > 0)', 'If:8', None), + ), + ) + + def test_branch_straightline_semi(self): + + def test_fn(a): + if a > 0: + a = 1 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (None, 'a', '(a > 0)'), + ('a', '(a > 0)', 'a = 1'), + ('(a > 0)', 'a = 1', None), + ), + ) + self.assertStatementEdges( + graph, + (('a', 'If:2', None),), + ) + + def test_branch_return(self): + + def test_fn(a): + if a > 0: + return + else: + a = 1 + a = 2 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', '(a > 0)', ('return', 'a = 1')), + ('(a > 0)', 'a = 1', 'a = 2'), + ('(a > 0)', 'return', None), + ('a = 1', 'a = 2', None), + ), + ) + self.assertStatementEdges( + graph, + (('a', 'If:2', 'a = 2'),), + ) + + def test_branch_return_minimal(self): + + def test_fn(a): + if a > 0: + return + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', '(a > 0)', 'return'), + ('(a > 0)', 'return', None), + ), + ) + self.assertStatementEdges( + graph, + (('a', 'If:2', None),), + ) + + def test_while_straightline(self): + + def test_fn(a): + while a > 0: + a = 1 + a = 2 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', ('a = 1', 'a = 2')), + ('(a > 0)', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', None), + ), + ) + self.assertStatementEdges( + graph, + (('a', 'While:2', 'a = 2'),), + ) + + def test_while_else_straightline(self): + + def test_fn(a): + while a > 0: + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', ('a = 1', 'a = 2')), + ('(a > 0)', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + (('a', 'While:2', 'a = 3'),), + ) + + def test_while_else_continue(self): + + def test_fn(a): + while a > 0: + if a > 1: + continue + else: + a = 0 + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'continue', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')), + ('(a > 0)', '(a > 1)', ('continue', 'a = 0')), + ('(a > 1)', 'continue', '(a > 0)'), + ('a = 0', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'While:2', 'a = 3'), + ('(a > 0)', 'If:3', ('a = 1', '(a > 0)')), + ), + ) + + def test_while_else_break(self): + + def test_fn(a): + while a > 0: + if a > 1: + break + a = 1 + else: + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')), + ('(a > 0)', '(a > 1)', ('break', 'a = 1')), + ('(a > 1)', 'break', 'a = 3'), + ('(a > 1)', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', 'a = 3'), + (('break', 'a = 2'), 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'While:2', 'a = 3'), + ('(a > 0)', 'If:3', ('a = 1', 'a = 3')), + ), + ) + + def test_while_else_return(self): + + def test_fn(a): + while a > 0: + if a > 1: + return + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')), + ('(a > 0)', '(a > 1)', ('return', 'a = 1')), + ('(a > 1)', 'return', None), + ('(a > 1)', 'a = 1', '(a > 0)'), + ('(a > 0)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'While:2', 'a = 3'), + ('(a > 0)', 'If:3', 'a = 1'), + ), + ) + + def test_while_nested_straightline(self): + + def test_fn(a): + while a > 0: + while a > 1: + a = 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')), + (('(a > 0)', 'a = 1'), '(a > 1)', ('a = 1', 'a = 2')), + ('(a > 1)', 'a = 1', '(a > 1)'), + ('(a > 1)', 'a = 2', '(a > 0)'), + ('(a > 0)', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'While:2', 'a = 3'), + ('(a > 0)', 'While:3', 'a = 2'), + ), + ) + + def test_while_nested_continue(self): + + def test_fn(a): + while a > 0: + while a > 1: + if a > 3: + continue + a = 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')), + (('(a > 0)', 'continue', 'a = 1'), '(a > 1)', ('(a > 3)', 'a = 2')), + ('(a > 1)', '(a > 3)', ('continue', 'a = 1')), + ('(a > 3)', 'continue', '(a > 1)'), + ('(a > 3)', 'a = 1', '(a > 1)'), + ('(a > 1)', 'a = 2', '(a > 0)'), + ('(a > 0)', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'While:2', 'a = 3'), + ('(a > 0)', 'While:3', 'a = 2'), + ('(a > 1)', 'If:4', ('a = 1', '(a > 1)')), + ), + ) + + def test_while_nested_break(self): + + def test_fn(a): + while a > 0: + while a > 1: + if a > 2: + break + a = 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches(graph, ( + (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')), + (('(a > 0)', 'a = 1'), '(a > 1)', ('(a > 2)', 'a = 2')), + ('(a > 1)', '(a > 2)', ('break', 'a = 1')), + ('(a > 2)', 'break', 'a = 2'), + ('(a > 2)', 'a = 1', '(a > 1)'), + (('(a > 1)', 'break'), 'a = 2', '(a > 0)'), + ('(a > 0)', 'a = 3', None), + )) + self.assertStatementEdges( + graph, + ( + ('a', 'While:2', 'a = 3'), + ('(a > 0)', 'While:3', 'a = 2'), + ('(a > 1)', 'If:4', ('a = 1', 'a = 2')), + ), + ) + + def test_for_straightline(self): + + def test_fn(a): + for a in range(0, a): + a = 1 + a = 2 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), 'range(0, a)', ('a = 1', 'a = 2')), + ('range(0, a)', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', None), + ), + ) + self.assertStatementEdges( + graph, + (('a', 'For:2', 'a = 2'),), + ) + + def test_for_else_straightline(self): + + def test_fn(a): + for a in range(0, a): + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), 'range(0, a)', ('a = 1', 'a = 2')), + ('range(0, a)', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + (('a', 'For:2', 'a = 3'),), + ) + + def test_for_else_continue(self): + + def test_fn(a): + for a in range(0, a): + if a > 1: + continue + else: + a = 0 + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'continue', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')), + ('range(0, a)', '(a > 1)', ('continue', 'a = 0')), + ('(a > 1)', 'continue', 'range(0, a)'), + ('(a > 1)', 'a = 0', 'a = 1'), + ('a = 0', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'For:2', 'a = 3'), + ('range(0, a)', 'If:3', ('a = 1', 'range(0, a)')), + ), + ) + + def test_for_else_break(self): + + def test_fn(a): + for a in range(0, a): + if a > 1: + break + a = 1 + else: + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')), + ('range(0, a)', '(a > 1)', ('break', 'a = 1')), + ('(a > 1)', 'break', 'a = 3'), + ('(a > 1)', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', 'a = 3'), + (('break', 'a = 2'), 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'For:2', 'a = 3'), + ('range(0, a)', 'If:3', ('a = 1', 'a = 3')), + ), + ) + + def test_for_else_return(self): + + def test_fn(a): + for a in range(0, a): + if a > 1: + return + a = 1 + else: # pylint:disable=useless-else-on-loop + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')), + ('range(0, a)', '(a > 1)', ('return', 'a = 1')), + ('(a > 1)', 'return', None), + ('(a > 1)', 'a = 1', 'range(0, a)'), + ('range(0, a)', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'For:2', 'a = 3'), + ('range(0, a)', 'If:3', 'a = 1'), + ), + ) + + def test_for_nested_straightline(self): + + def test_fn(a): + for a in range(0, a): + for b in range(1, a): + b += 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')), + (('range(0, a)', 'b += 1'), 'range(1, a)', ('b += 1', 'a = 2')), + ('range(1, a)', 'b += 1', 'range(1, a)'), + ('range(1, a)', 'a = 2', 'range(0, a)'), + ('range(0, a)', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'For:2', 'a = 3'), + ('range(0, a)', 'For:3', 'a = 2'), + ), + ) + + def test_for_nested_continue(self): + + def test_fn(a): + for a in range(0, a): + for b in range(1, a): + if a > 3: + continue + b += 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')), + (('range(0, a)', 'continue', 'b += 1'), 'range(1, a)', + ('(a > 3)', 'a = 2')), + ('range(1, a)', '(a > 3)', ('continue', 'b += 1')), + ('(a > 3)', 'continue', 'range(1, a)'), + ('(a > 3)', 'b += 1', 'range(1, a)'), + ('range(1, a)', 'a = 2', 'range(0, a)'), + ('range(0, a)', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'For:2', 'a = 3'), + ('range(0, a)', 'For:3', 'a = 2'), + ('range(1, a)', 'If:4', ('b += 1', 'range(1, a)')), + ), + ) + + def test_for_nested_break(self): + + def test_fn(a): + for a in range(0, a): + for b in range(1, a): + if a > 2: + break + b += 1 + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')), + (('range(0, a)', 'b += 1'), 'range(1, a)', ('(a > 2)', 'a = 2')), + ('range(1, a)', '(a > 2)', ('break', 'b += 1')), + ('(a > 2)', 'break', 'a = 2'), + ('(a > 2)', 'b += 1', 'range(1, a)'), + (('range(1, a)', 'break'), 'a = 2', 'range(0, a)'), + ('range(0, a)', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('a', 'For:2', 'a = 3'), + ('range(0, a)', 'For:3', 'a = 2'), + ('range(1, a)', 'If:4', ('b += 1', 'a = 2')), + ), + ) + + def test_complex(self): + + def test_fn(a): + b = 0 + while a > 0: + for b in range(0, a): + if a > 2: + break + if a > 3: + if a > 4: + continue + else: + max(a) + break + b += 1 + else: # for b in range(0, a): + return a + a = 2 + for a in range(1, a): + return b + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('b = 0', 'a = 2'), '(a > 0)', ('range(0, a)', 'range(1, a)')), + ( + ('(a > 0)', 'continue', 'b += 1'), + 'range(0, a)', + ('(a > 2)', 'return a'), + ), + ('range(0, a)', '(a > 2)', ('(a > 3)', 'break')), + ('(a > 2)', 'break', 'a = 2'), + ('(a > 2)', '(a > 3)', ('(a > 4)', 'b += 1')), + ('(a > 3)', '(a > 4)', ('continue', 'max(a)')), + ('(a > 4)', 'max(a)', 'break'), + ('max(a)', 'break', 'a = 2'), + ('(a > 4)', 'continue', 'range(0, a)'), + ('(a > 3)', 'b += 1', 'range(0, a)'), + ('range(0, a)', 'return a', None), + ('break', 'a = 2', '(a > 0)'), + ('(a > 0)', 'range(1, a)', ('return b', 'a = 3')), + ('range(1, a)', 'return b', None), + ('range(1, a)', 'a = 3', None), + ), + ) + self.assertStatementEdges( + graph, + ( + ('b = 0', 'While:3', 'range(1, a)'), + ('(a > 0)', 'For:4', 'a = 2'), + ('range(0, a)', 'If:5', ('(a > 3)', 'a = 2')), + ('(a > 2)', 'If:7', ('b += 1', 'a = 2', 'range(0, a)')), + ('(a > 3)', 'If:8', ('a = 2', 'range(0, a)')), + ('(a > 0)', 'For:17', 'a = 3'), + ), + ) + + def test_finally_straightline(self): + + def test_fn(a): + try: + a += 1 + finally: + a = 2 + a = 3 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', 'a += 1', 'a = 2'), + ('a += 1', 'a = 2', 'a = 3'), + ('a = 2', 'a = 3', None), + ), + ) + + def test_return_finally(self): + + def test_fn(a): + try: + return a + finally: + a = 1 + a = 2 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', 'return a', 'a = 1'), + ('return a', 'a = 1', None), + (None, 'a = 2', None), + ), + ) + + def test_break_finally(self): + + def test_fn(a): + while a > 0: + try: + break + finally: + a = 1 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', '(a > 0)', 'break'), + ('(a > 0)', 'break', 'a = 1'), + ('break', 'a = 1', None), + ), + ) + + def test_continue_finally(self): + + def test_fn(a): + while a > 0: + try: + continue + finally: + a = 1 + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + (('a', 'a = 1'), '(a > 0)', 'continue'), + ('(a > 0)', 'continue', 'a = 1'), + ('continue', 'a = 1', '(a > 0)'), + ), + ) + + def test_with_straightline(self): + + def test_fn(a): + with max(a) as b: + a = 0 + return b + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', 'max(a)', 'a = 0'), + ('max(a)', 'a = 0', 'return b'), + ('a = 0', 'return b', None), + ), + ) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD index ca1441cf6f8bb034c95b37fcdd9e8158d1db2e39..a0938b3e5f0e52532f63fea6fb4c3e478fc51d93 100644 --- a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD +++ b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD @@ -24,6 +24,7 @@ py_library( deps = [ "//tensorflow/contrib/autograph/pyct", "@gast_archive//:gast", + "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py index cc039986c219db1febfe610a5078e26eeb2d5a83..e42f679cfe31f919e10f7baf409247014b3cf386 100644 --- a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py +++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py @@ -12,12 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Conversion to A-normal form.""" +"""Conversion to A-normal form. + +The general idea of A-normal form is that every intermediate value is +explicitly named with a variable. For more, see +https://en.wikipedia.org/wiki/A-normal_form. + +The specific converters used here are based on Python AST semantics as +documented at https://greentreesnakes.readthedocs.io/en/latest/. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gast +import six + +from tensorflow.contrib.autograph.pyct import templates from tensorflow.contrib.autograph.pyct import transformer @@ -32,26 +44,375 @@ class DummyGensym(object): # * the symbols generated so far self._idx = 0 - def new_name(self, stem): + def new_name(self, stem='tmp'): self._idx += 1 return stem + '_' + str(1000 + self._idx) class AnfTransformer(transformer.Base): - """Performs the actual conversion.""" + """Performs the conversion to A-normal form (ANF).""" - # TODO(mdan): Link to a reference. - # TODO(mdan): Implement. + # The algorithm is a postorder recursive tree walk. Any given node A may, in + # general, require creation of a series B of Assign statements, which compute + # and explicitly name the intermediate values needed to compute the value of + # A. If A was already a statement, it can be replaced with the sequence B + + # [A]. If A was an expression, B needs to be propagated up the tree until a + # statement is encountered. Since the `ast.NodeTransformer` framework makes + # no provision for subtraversals returning side information, this class + # accumulates the sequence B in an instance variable. - def __init__(self, entity_info): - """Creates a transformer. + # The only other subtlety is that some Python statements (like `if`) have both + # expression fields (`test`) and statement list fields (`body` and `orelse`). + # Any additional assignments needed to name all the intermediate values in the + # `test` can be prepended to the `if` node, but assignments produced by + # processing the `body` and the `orelse` need to be kept together with them, + # and not accidentally lifted out of the `if`. + + def __init__(self, entity_info, gensym_source=None): + """Creates an ANF transformer. Args: entity_info: transformer.EntityInfo + gensym_source: An optional object with the same interface as `DummyGensym` + for generating unique names """ super(AnfTransformer, self).__init__(entity_info) - self._gensym = DummyGensym(entity_info) + if gensym_source is None: + self._gensym = DummyGensym(entity_info) + else: + self._gensym = gensym_source(entity_info) + self._pending_statements = [] + + def _consume_pending_statements(self): + ans = self._pending_statements + self._pending_statements = [] + return ans + + def _add_pending_statement(self, stmt): + self._pending_statements.append(stmt) + + _trivial_nodes = ( + # Non-nodes that show up as AST fields + bool, six.string_types, + # Leaf nodes that are already in A-normal form + gast.expr_context, gast.Name, gast.Num, gast.Str, gast.Bytes, + gast.NameConstant, gast.Ellipsis, + # Binary operators + gast.Add, gast.Sub, gast.Mult, gast.Div, gast.Mod, gast.Pow, gast.LShift, + gast.RShift, gast.BitOr, gast.BitXor, gast.BitAnd, gast.FloorDiv, + # Unary operators + gast.Invert, gast.Not, gast.UAdd, gast.USub, + # Comparison operators + gast.Eq, gast.NotEq, gast.Lt, gast.LtE, gast.Gt, gast.GtE, + gast.Is, gast.IsNot, gast.In, gast.NotIn, + ) + + def _is_node_trivial(self, node): + if node is None: + return True + elif isinstance(node, self._trivial_nodes): + return True + elif isinstance(node, gast.keyword): + return self._is_node_trivial(node.value) + elif isinstance(node, (gast.Starred, gast.withitem, gast.slice)): + return self._are_children_trivial(node) + return False + + def _are_children_trivial(self, node): + for field in node._fields: + if not field.startswith('__'): + if not self._is_node_trivial(getattr(node, field)): + return False + return True + + def _ensure_node_is_trivial(self, node): + if node is None: + return node + elif isinstance(node, self._trivial_nodes): + return node + elif isinstance(node, list): + # If something's field was actually a list, e.g., variadic arguments. + return [self._ensure_node_is_trivial(n) for n in node] + elif isinstance(node, gast.keyword): + node.value = self._ensure_node_is_trivial(node.value) + return node + elif isinstance(node, (gast.Starred, gast.withitem, gast.slice)): + return self._ensure_fields_trivial(node) + elif isinstance(node, gast.expr): + temp_name = self._gensym.new_name() + temp_assign = templates.replace( + 'temp_name = expr', temp_name=temp_name, expr=node)[0] + self._add_pending_statement(temp_assign) + answer = templates.replace('temp_name', temp_name=temp_name)[0] + return answer + else: + raise ValueError('Do not know how to treat {}'.format(node)) + + def _ensure_fields_trivial(self, node): + for field in node._fields: + if field.startswith('__'): + continue + setattr(node, field, self._ensure_node_is_trivial(getattr(node, field))) + return node + + def _visit_strict_statement(self, node, trivialize_children=True): + assert not self._pending_statements + node = self.generic_visit(node) + if trivialize_children: + self._ensure_fields_trivial(node) + results = self._consume_pending_statements() + results.append(node) + return results + + def _visit_strict_expression(self, node): + node = self.generic_visit(node) + self._ensure_fields_trivial(node) + return node + + # Note on code order: These are listed in the same order as the grammar + # elements on https://github.com/serge-sans-paille/gast + + # FunctionDef, AsyncFunctionDef, and ClassDef should be correct by default. + + def visit_Return(self, node): + return self._visit_strict_statement(node) + + def visit_Delete(self, node): + return self._visit_strict_statement(node, trivialize_children=False) + + def visit_Assign(self, node): + return self._visit_strict_statement(node, trivialize_children=False) + + def visit_AugAssign(self, node): + return self._visit_strict_statement(node, trivialize_children=False) + + def visit_Print(self, node): + return self._visit_strict_statement(node) + + def visit_For(self, node): + assert not self._pending_statements + # It's important to visit node.iter first, because any statements created + # thereby need to live outside the body. + self.visit(node.iter) + node.iter = self._ensure_node_is_trivial(node.iter) + iter_stmts = self._consume_pending_statements() + # This generic_visit will revisit node.iter, but that is both correct and + # cheap because by this point node.iter is trivial. + node = self.generic_visit(node) + assert not self._pending_statements + iter_stmts.append(node) + return iter_stmts + + def visit_AsyncFor(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial AsyncFor nodes not supported yet ' + '(need to think through the semantics).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_While(self, node): + if not self._is_node_trivial(node.test): + msg = ('While with nontrivial test not supported yet ' + '(need to avoid precomputing the test).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_If(self, node): + assert not self._pending_statements + # It's important to visit node.test first, because any statements created + # thereby need to live outside the body. + self.visit(node.test) + node.test = self._ensure_node_is_trivial(node.test) + condition_stmts = self._consume_pending_statements() + # This generic_visit will revisit node.test, but that is both correct and + # cheap because by this point node.test is trivial. + node = self.generic_visit(node) + assert not self._pending_statements + condition_stmts.append(node) + return condition_stmts + + def visit_With(self, node): + assert not self._pending_statements + # It's important to visit node.items first, because any statements created + # thereby need to live outside the body. + for item in node.items: + self.visit(item) + node.items = [self._ensure_node_is_trivial(n) for n in node.items] + contexts_stmts = self._consume_pending_statements() + # This generic_visit will revisit node.items, but that is both correct and + # cheap because by this point node.items is trivial. + node = self.generic_visit(node) + assert not self._pending_statements + contexts_stmts.append(node) + return contexts_stmts + + def visit_AsyncWith(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial AsyncWith nodes not supported yet ' + '(need to think through the semantics).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_Raise(self, node): + return self._visit_strict_statement(node) + + # Try should be correct by default. + + def visit_Assert(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial Assert nodes not supported yet ' + '(need to avoid computing the test when assertions are off, and ' + 'avoid computing the irritant when the assertion does not fire).') + raise ValueError(msg) + return self.generic_visit(node) + + # Import and ImportFrom should be correct by default. + + def visit_Exec(self, node): + return self._visit_strict_statement(node) + + # Global and Nonlocal should be correct by default. + + def visit_Expr(self, node): + return self._visit_strict_statement(node, trivialize_children=False) + + # Pass, Break, and Continue should be correct by default. + + def visit_BoolOp(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial BoolOp nodes not supported yet ' + '(need to preserve short-circuiting semantics).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_BinOp(self, node): + return self._visit_strict_expression(node) + + def visit_UnaryOp(self, node): + return self._visit_strict_expression(node) + + def visit_Lambda(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial Lambda nodes not supported ' + '(cannot insert statements into lambda bodies).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_IfExp(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial IfExp nodes not supported yet ' + '(need to convert to If statement, to evaluate branches lazily ' + 'and insert statements into them).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_Dict(self, node): + return self._visit_strict_expression(node) + + def visit_Set(self, node): + return self._visit_strict_expression(node) + + def visit_ListComp(self, node): + msg = ('ListComp nodes not supported ' + '(need to convert to a form that tolerates ' + 'assignment statements in clause bodies).') + raise ValueError(msg) + + def visit_SetComp(self, node): + msg = ('SetComp nodes not supported ' + '(need to convert to a form that tolerates ' + 'assignment statements in clause bodies).') + raise ValueError(msg) + + def visit_DictComp(self, node): + msg = ('DictComp nodes not supported ' + '(need to convert to a form that tolerates ' + 'assignment statements in clause bodies).') + raise ValueError(msg) + + def visit_GeneratorExp(self, node): + msg = ('GeneratorExp nodes not supported ' + '(need to convert to a form that tolerates ' + 'assignment statements in clause bodies).') + raise ValueError(msg) + + def visit_Await(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial Await nodes not supported yet ' + '(need to think through the semantics).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_Yield(self, node): + return self._visit_strict_expression(node) + + def visit_YieldFrom(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial YieldFrom nodes not supported yet ' + '(need to unit-test them in Python 2).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_Compare(self, node): + if len(node.ops) > 1: + msg = ('Multi-ary compare nodes not supported yet ' + '(need to preserve short-circuiting semantics).') + raise ValueError(msg) + return self._visit_strict_expression(node) + + def visit_Call(self, node): + return self._visit_strict_expression(node) + + def visit_Repr(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial Repr nodes not supported yet ' + '(need to research their syntax and semantics).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_FormattedValue(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial FormattedValue nodes not supported yet ' + '(need to unit-test them in Python 2).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_JoinedStr(self, node): + if not self._are_children_trivial(node): + msg = ('Nontrivial JoinedStr nodes not supported yet ' + '(need to unit-test them in Python 2).') + raise ValueError(msg) + return self.generic_visit(node) + + def visit_Attribute(self, node): + return self._visit_strict_expression(node) + + def visit_Subscript(self, node): + return self._visit_strict_expression(node) + + # Starred and Name are correct by default, because the right thing to do is to + # just recur. + + def visit_List(self, node): + return self._visit_strict_expression(node) + + def visit_Tuple(self, node): + return self._visit_strict_expression(node) + + +def transform(node, entity_info, gensym_source=None): + """Converts the given node to A-normal form (ANF). + + The general idea of A-normal form: https://en.wikipedia.org/wiki/A-normal_form + The specific converters used here are based on Python AST semantics as + documented at https://greentreesnakes.readthedocs.io/en/latest/. -def transform(node, entity_info): - return AnfTransformer(entity_info).visit(node) + Args: + node: The node to transform. + entity_info: transformer.EntityInfo. TODO(mdan): What information does this + argument provide? + gensym_source: An optional object with the same interface as `DummyGensym` + for generating unique names. + """ + return AnfTransformer(entity_info, gensym_source=gensym_source).visit(node) diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py index 81983a5ecb7b8c6216285409f854e27b7154a08b..951974820c784974cb5bb2320adbb2b07f9332df 100644 --- a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py +++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import textwrap + from tensorflow.contrib.autograph.pyct import compiler from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import transformer @@ -25,6 +27,22 @@ from tensorflow.contrib.autograph.pyct.common_transformers import anf from tensorflow.python.platform import test +class DummyGensym(object): + """A dumb gensym that suffixes a stem by sequential numbers from 1000.""" + + def __init__(self, entity_info): + del entity_info + # A proper implementation needs to account for: + # * entity_info.namespace + # * all the symbols defined in the AST + # * the symbols generated so far + self._idx = 0 + + def new_name(self, stem='tmp'): + self._idx += 1 + return stem + '_' + str(1000 + self._idx) + + class AnfTransformerTest(test.TestCase): def _simple_source_info(self): @@ -37,17 +55,349 @@ class AnfTransformerTest(test.TestCase): owner_type=None) def test_basic(self): - def test_function(): a = 0 return a - node, _ = parser.parse_entity(test_function) - node = anf.transform(node, self._simple_source_info()) + node = anf.transform(node.body[0], self._simple_source_info()) result, _ = compiler.ast_to_object(node) - self.assertEqual(test_function(), result.test_function()) + def assert_same_ast(self, expected_node, node, msg=None): + expected_source = compiler.ast_to_source(expected_node, indentation=' ') + expected_str = textwrap.dedent(expected_source).strip() + got_source = compiler.ast_to_source(node, indentation=' ') + got_str = textwrap.dedent(got_source).strip() + self.assertEqual(expected_str, got_str, msg=msg) + + def assert_body_anfs_as_expected(self, expected_fn, test_fn): + # Testing the code bodies only. Wrapping them in functions so the + # syntax highlights nicely, but Python doesn't try to execute the + # statements. + exp_node, _ = parser.parse_entity(expected_fn) + node, _ = parser.parse_entity(test_fn) + node = anf.transform( + node, self._simple_source_info(), gensym_source=DummyGensym) + exp_name = exp_node.body[0].name + # Ignoring the function names in the result because they can't be + # the same (because both functions have to exist in the same scope + # at the same time). + node.body[0].name = exp_name + self.assert_same_ast(exp_node, node) + # Check that ANF is idempotent + node_repeated = anf.transform( + node, self._simple_source_info(), gensym_source=DummyGensym) + self.assert_same_ast(node_repeated, node) + + def test_binop_basic(self): + + def test_function(x, y, z): + a = x + y + z + return a + + def expected_result(x, y, z): + tmp_1001 = x + y + a = tmp_1001 + z + return a + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_if_basic(self): + + def test_function(a, b, c, e, f, g): + if a + b + c: + d = e + f + g + return d + + def expected_result(a, b, c, e, f, g): + tmp_1001 = a + b + tmp_1002 = tmp_1001 + c + if tmp_1002: + tmp_1003 = e + f + d = tmp_1003 + g + return d + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_nested_binop_and_return(self): + + def test_function(b, c, d, e): + return (2 * b + c) + (d + e) + + def expected_result(b, c, d, e): + tmp_1001 = 2 * b + tmp_1002 = tmp_1001 + c + tmp_1003 = d + e + tmp_1004 = tmp_1002 + tmp_1003 + return tmp_1004 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_function_call_and_expr(self): + + def test_function(call_something, a, b, y, z, c, d, e, f, g, h, i): + call_something(a + b, y * z, kwarg=c + d, *(e + f), **(g + h + i)) + + def expected_result(call_something, a, b, y, z, c, d, e, f, g, h, i): + tmp_1001 = g + h + tmp_1002 = a + b + tmp_1003 = y * z + tmp_1004 = e + f + tmp_1005 = c + d + tmp_1006 = tmp_1001 + i + call_something(tmp_1002, tmp_1003, kwarg=tmp_1005, *tmp_1004, **tmp_1006) + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_with_and_print(self): + + def test_function(a, b, c): + with a + b + c as d: + print(2 * d + 1) + + def expected_result(a, b, c): + tmp_1001 = a + b + tmp_1002 = tmp_1001 + c + with tmp_1002 as d: + tmp_1003 = 2 * d + tmp_1004 = tmp_1003 + 1 + print(tmp_1004) + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_local_definition_and_binary_compare(self): + + def test_function(): + def foo(a, b): + return 2 * a < b + return foo + + def expected_result(): + def foo(a, b): + tmp_1001 = 2 * a + tmp_1002 = tmp_1001 < b + return tmp_1002 + return foo + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_list_literal(self): + + def test_function(a, b, c, d, e, f): + return [a + b, c + d, e + f] + + def expected_result(a, b, c, d, e, f): + tmp_1001 = a + b + tmp_1002 = c + d + tmp_1003 = e + f + tmp_1004 = [tmp_1001, tmp_1002, tmp_1003] + return tmp_1004 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_tuple_literal_and_unary(self): + + def test_function(a, b, c, d, e, f): + return (a + b, -(c + d), e + f) + + def expected_result(a, b, c, d, e, f): + tmp_1001 = c + d + tmp_1002 = a + b + tmp_1003 = -tmp_1001 + tmp_1004 = e + f + tmp_1005 = (tmp_1002, tmp_1003, tmp_1004) + return tmp_1005 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_set_literal(self): + + def test_function(a, b, c, d, e, f): + return set(a + b, c + d, e + f) + + def expected_result(a, b, c, d, e, f): + tmp_1001 = a + b + tmp_1002 = c + d + tmp_1003 = e + f + tmp_1004 = set(tmp_1001, tmp_1002, tmp_1003) + return tmp_1004 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_dict_literal_and_repr(self): + + def test_function(foo, bar, baz): + return repr({foo + bar + baz: 7 | 8}) + + def expected_result(foo, bar, baz): + tmp_1001 = foo + bar + tmp_1002 = tmp_1001 + baz + tmp_1003 = 7 | 8 + tmp_1004 = {tmp_1002: tmp_1003} + tmp_1005 = repr(tmp_1004) + return tmp_1005 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_field_read_and_write(self): + + def test_function(a, d): + a.b.c = d.e.f + 3 + + def expected_result(a, d): + tmp_1001 = a.b + tmp_1002 = d.e + tmp_1003 = tmp_1002.f + tmp_1001.c = tmp_1003 + 3 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_subscript_read_and_write(self): + + def test_function(a, b, c, d, e, f): + a[b][c] = d[e][f] + 3 + + def expected_result(a, b, c, d, e, f): + tmp_1001 = a[b] + tmp_1002 = d[e] + tmp_1003 = tmp_1002[f] + tmp_1001[c] = tmp_1003 + 3 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_augassign_and_delete(self): + + def test_function(a, x, y, z): + a += x + y + z + del a + del z[y][x] + + def expected_result(a, x, y, z): + tmp_1001 = x + y + a += tmp_1001 + z + del a + tmp_1002 = z[y] + del tmp_1002[x] + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_raise_yield_and_raise(self): + + def test_function(a, c, some_computed, exception): + yield a ** c + raise some_computed('complicated' + exception) + + def expected_result(a, c, some_computed, exception): + tmp_1001 = a ** c + yield tmp_1001 + tmp_1002 = 'complicated' + exception + tmp_1003 = some_computed(tmp_1002) + raise tmp_1003 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_with_and_if_with_expressions(self): + + def test_function(foo, bar, function, quux, quozzle, w, x, y, z): + with foo + bar: + function(x + y) + if quux + quozzle: + function(z / w) + + def expected_result(foo, bar, function, quux, quozzle, w, x, y, z): + tmp_1001 = foo + bar + with tmp_1001: + tmp_1002 = x + y + function(tmp_1002) + tmp_1003 = quux + quozzle + if tmp_1003: + tmp_1004 = z / w + function(tmp_1004) + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_exec(self): + + def test_function(): + # The point is to test A-normal form conversion of exec + # pylint: disable=exec-used + exec('computed' + 5 + 'stuff', globals(), locals()) + + def expected_result(): + # pylint: disable=exec-used + tmp_1001 = 'computed' + 5 + tmp_1002 = tmp_1001 + 'stuff' + tmp_1003 = globals() + tmp_1004 = locals() + exec(tmp_1002, tmp_1003, tmp_1004) + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_simple_while_and_assert(self): + + def test_function(foo, quux): + while foo: + assert quux + foo = foo + 1 * 3 + + def expected_result(foo, quux): + while foo: + assert quux + tmp_1001 = 1 * 3 + foo = foo + tmp_1001 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_for(self): + + def test_function(compute, something, complicated, foo): + for foo in compute(something + complicated): + bar = foo + 1 * 3 + return bar + + def expected_result(compute, something, complicated, foo): + tmp_1001 = something + complicated + tmp_1002 = compute(tmp_1001) + for foo in tmp_1002: + tmp_1003 = 1 * 3 + bar = foo + tmp_1003 + return bar + + self.assert_body_anfs_as_expected(expected_result, test_function) + + # This test collects several examples where the definition of A-normal form + # implemented by this transformer is questionable. Mostly it's here to spell + # out what the definition is in these cases. + def test_controversial(self): + + def test_function(b, c, d, f): + a = c + d + a.b = c + d + a[b] = c + d + a += c + d + a, b = c + a, b = c, d + a = f(c) + a = f(c + d) + a[b + d] = f.e(c + d) + + def expected_result(b, c, d, f): + a = c + d + a.b = c + d # Should be a.b = tmp? (Definitely not tmp = c + d) + a[b] = c + d # Should be a[b] = tmp? (Definitely not tmp = c + d) + a += c + d # Should be a += tmp? (Definitely not tmp = c + d) + a, b = c # Should be a = c[0], b = c[1]? Or not? + a, b = c, d # Should be a = c, b = d? Or not? + a = f(c) + tmp_1001 = c + d + a = f(tmp_1001) + tmp_1002 = b + d + tmp_1003 = f.e + tmp_1004 = c + d + a[tmp_1002] = tmp_1003(tmp_1004) # Or should be a[tmp1] = tmp2? + + self.assert_body_anfs_as_expected(expected_result, test_function) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/pyct/compiler.py b/tensorflow/contrib/autograph/pyct/compiler.py index 24c4517afa89147101f80af3ef60237132c1144c..f9cee109624dafd4da4a0981c5f8fda0a5d8a5e7 100644 --- a/tensorflow/contrib/autograph/pyct/compiler.py +++ b/tensorflow/contrib/autograph/pyct/compiler.py @@ -30,46 +30,112 @@ import tempfile import astor import gast +from tensorflow.contrib.autograph.pyct import origin_info + def ast_to_source(node, indentation=' '): - """Return the source code of given AST.""" - if isinstance(node, gast.AST): - node = gast.gast_to_ast(node) + """Return the source code of given AST. + + Args: + node: The code to compile, as an AST object. + indentation: The string to use for indentation. + + Returns: + code: The source code generated from the AST object + source_mapping: A mapping between the user and AutoGraph generated code. + """ + if not isinstance(node, (list, tuple)): + node = (node,) generator = astor.codegen.SourceGenerator(indentation, False, astor.string_repr.pretty_string) - generator.visit(node) - generator.result.append('\n') + + for n in node: + if isinstance(n, gast.AST): + n = gast.gast_to_ast(n) + generator.visit(n) + generator.result.append('\n') + # In some versions of Python, literals may appear as actual values. This # ensures everything is string. code = map(str, generator.result) - return astor.source_repr.pretty_source(code).lstrip() + code = astor.source_repr.pretty_source(code).lstrip() + return code -def ast_to_object( - node, indentation=' ', source_prefix=None, delete_on_exit=True): + +def ast_to_object(nodes, + indentation=' ', + include_source_map=False, + source_prefix=None, + delete_on_exit=True): """Return the Python objects represented by given AST. Compiling the AST code this way ensures that the source code is readable by e.g. `pdb` or `inspect`. Args: - node: The code to compile, as an AST object. - indentation: The string to use for indentation. - source_prefix: Optional string to print as-is into the source file. - delete_on_exit: Whether to delete the temporary file used for compilation - on exit. + nodes: Union[ast.AST, Iterable[ast.AST]], the code to compile, as an AST + object. + indentation: Text, the string to use for indentation. + include_source_map: bool, whether to attach a source map to the compiled + object. Also see origin_info.py. + source_prefix: Optional[Text], string to print as-is into the source file. + delete_on_exit: bool, whether to delete the temporary file used for + compilation on exit. Returns: - A module object containing the compiled source code. + compiled_nodes: A module object containing the compiled source code. + source: The source code of the compiled object + Raises: + ValueError: If ag_source_map__ is already in the namespace of the compiled + nodes. """ - source = ast_to_source(node, indentation) + if not isinstance(nodes, (list, tuple)): + nodes = (nodes,) + + source = ast_to_source(nodes, indentation=indentation) + + if source_prefix: + source = source_prefix + '\n' + source with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: module_name = os.path.basename(f.name[:-3]) - if source_prefix: - f.write(source_prefix) - f.write('\n') f.write(source) + + if isinstance(nodes, (list, tuple)): + indices = range(-len(nodes), 0) + else: + indices = (-1,) + + if include_source_map: + source_map = origin_info.source_map(nodes, source, f.name, indices) + + # TODO(mdan): Try flush() and delete=False instead. if delete_on_exit: atexit.register(lambda: os.remove(f.name)) - return imp.load_source(module_name, f.name), source + compiled_nodes = imp.load_source(module_name, f.name) + + # TODO(znado): Clean this up so we don't need to attach it to the namespace. + # TODO(znado): This does not work for classes because their methods share a + # namespace. + # This attaches the source map which is needed for error handling. Note that + # api.to_graph copies this source map into an attribute of the function. + # + # We need this so the ag_source_map__ variable is available to the call to + # rewrite_graph_construction_error in the except block inside each function + # that handles graph construction errors. + # + # We cannot get the rewritten function name until it is too late so templating + # is hard, and this cleanly fixes the + # issues encountered with nested functions because this is attached to the + # outermost one. + if include_source_map: + # TODO(mdan): This name should be decided by the caller. + source_map_name = 'ag_source_map__' + if source_map_name in compiled_nodes.__dict__: + raise ValueError('cannot convert %s because is has namespace attribute ' + '"%s", which is reserved for AutoGraph.' % + (compiled_nodes, source_map_name)) + compiled_nodes.__dict__[source_map_name] = source_map + + return compiled_nodes, source diff --git a/tensorflow/contrib/autograph/pyct/compiler_test.py b/tensorflow/contrib/autograph/pyct/compiler_test.py index 98cdc1506b6aced603df99662f1468687a55f92c..cf783da6a3e540c6901a5fe9a5e4afdb6b1cfc03 100644 --- a/tensorflow/contrib/autograph/pyct/compiler_test.py +++ b/tensorflow/contrib/autograph/pyct/compiler_test.py @@ -59,14 +59,14 @@ class CompilerTest(test.TestCase): value=gast.Str('c')) ]) + source = compiler.ast_to_source(node, indentation=' ') self.assertEqual( textwrap.dedent(""" if 1: a = b else: a = 'c' - """).strip(), - compiler.ast_to_source(node, indentation=' ').strip()) + """).strip(), source.strip()) def test_ast_to_object(self): node = gast.FunctionDef( diff --git a/tensorflow/contrib/autograph/pyct/origin_info.py b/tensorflow/contrib/autograph/pyct/origin_info.py new file mode 100644 index 0000000000000000000000000000000000000000..b60651a30e342dabe40cbcef1486826e16c2e2c7 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/origin_info.py @@ -0,0 +1,186 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Container for origin source code information before AutoGraph compilation.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import tokenize + +import gast +import six + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.python.util import tf_inspect + + +class LineLocation( + collections.namedtuple('LineLocation', ('filename', 'lineno'))): + """Similar to Location, but without column information. + + Attributes: + filename: Text + lineno: int, 1-based + """ + pass + + +class Location( + collections.namedtuple('Location', ('filename', 'lineno', 'col_offset'))): + """Encodes code location information. + + Attributes: + filename: Text + lineno: int, 1-based + col_offset: int + """ + + @property + def line_loc(self): + return LineLocation(self.filename, self.lineno) + + +class OriginInfo( + collections.namedtuple( + 'OriginInfo', + ('loc', 'function_name', 'source_code_line', 'comment'))): + """Container for information about the source code before conversion. + + Attributes: + loc: Location + function_name: Optional[Text] + source_code_line: Text + comment: Optional[Text] + """ + + def as_frame(self): + """Returns a 4-tuple consistent with the return of traceback.extract_tb.""" + return (self.loc.filename, self.loc.lineno, self.function_name, + self.source_code_line) + + +# TODO(mdan): This source map should be a class - easier to refer to. +def source_map(nodes, code, filename, indices_in_code): + """Creates a source map between an annotated AST and the code it compiles to. + + Args: + nodes: Iterable[ast.AST, ...] + code: Text + filename: Optional[Text] + indices_in_code: Union[int, Iterable[int, ...]], the positions at which + nodes appear in code. The parser always returns a module when parsing + code. This argument indicates the position in that module's body at + which the corresponding of node should appear. + + Returns: + Dict[CodeLocation, OriginInfo], mapping locations in code to locations + indicated by origin annotations in node. + """ + reparsed_nodes = parser.parse_str(code) + reparsed_nodes = [reparsed_nodes.body[i] for i in indices_in_code] + + resolve(reparsed_nodes, code) + result = {} + + for before, after in ast_util.parallel_walk(nodes, reparsed_nodes): + # Note: generated code might not be mapped back to its origin. + # TODO(mdan): Generated code should always be mapped to something. + origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None) + final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None) + if origin_info is None or final_info is None: + continue + + line_loc = LineLocation(filename, final_info.loc.lineno) + + existing_origin = result.get(line_loc) + if existing_origin is not None: + # Overlaps may exist because of child nodes, but almost never to + # different line locations. Exception make decorated functions, where + # both lines are mapped to the same line in the AST. + + # Line overlaps: keep bottom node. + if existing_origin.loc.line_loc == origin_info.loc.line_loc: + if existing_origin.loc.lineno >= origin_info.loc.lineno: + continue + + # In case of overlaps, keep the leftmost node. + if existing_origin.loc.col_offset <= origin_info.loc.col_offset: + continue + + result[line_loc] = origin_info + + return result + + +# TODO(znado): Consider refactoring this into a Visitor. +# TODO(mdan): Does this work correctly with inner functions? +def resolve(nodes, source, function=None): + """Adds an origin information to all nodes inside the body of function. + + Args: + nodes: Union[ast.AST, Iterable[ast.AST, ...]] + source: Text, the source code string for the function whose body nodes will + be annotated. + function: Callable, the function that will have all nodes inside of it + annotation with an OriginInfo annotation with key anno.Basic.ORIGIN. If + it is None then only the line numbers and column offset will be set in the + annotation, with the rest of the information being None. + + Returns: + A tuple of the AST node for function and a String containing its source + code. + """ + if not isinstance(nodes, (list, tuple)): + nodes = (nodes,) + + if function: + _, function_lineno = tf_inspect.getsourcelines(function) + function_filepath = tf_inspect.getsourcefile(function) + else: + function_lineno = None + function_filepath = None + + # TODO(mdan): Pull this to a separate utility. + code_reader = six.StringIO(source) + comment_map = {} + for token in tokenize.generate_tokens(code_reader.readline): + tok_type, tok_string, loc, _, _ = token + srow, _ = loc + if tok_type == tokenize.COMMENT: + comment_map[srow] = tok_string.strip()[1:].strip() + + source_lines = source.split('\n') + for node in nodes: + for n in gast.walk(node): + if not hasattr(n, 'lineno'): + continue + + lineno_in_body = n.lineno + + source_code_line = source_lines[lineno_in_body - 1] + if function: + source_lineno = function_lineno + lineno_in_body + function_name = function.__name__ + else: + source_lineno = lineno_in_body + function_name = None + + location = Location(function_filepath, source_lineno, n.col_offset) + origin = OriginInfo(location, function_name, + source_code_line, comment_map.get(source_lineno)) + anno.setanno(n, anno.Basic.ORIGIN, origin) diff --git a/tensorflow/contrib/autograph/pyct/origin_info_test.py b/tensorflow/contrib/autograph/pyct/origin_info_test.py new file mode 100644 index 0000000000000000000000000000000000000000..eeaa13007ea0ae331293c216a76352956c0ee9ec --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/origin_info_test.py @@ -0,0 +1,104 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for origin_info module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import origin_info +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.python.platform import test + + +class OriginInfoTest(test.TestCase): + + def test_source_map(self): + + def test_fn(x): + if x > 0: + x += 1 + return x + + node, source = parser.parse_entity(test_fn) + fn_node = node.body[0] + origin_info.resolve(fn_node, source) + + # Insert a traced line. + new_node = parser.parse_str('x = abs(x)').body[0] + anno.copyanno(fn_node.body[0], new_node, anno.Basic.ORIGIN) + fn_node.body.insert(0, new_node) + + # Insert an untraced line. + fn_node.body.insert(0, parser.parse_str('x = 0').body[0]) + + modified_source = compiler.ast_to_source(fn_node) + + source_map = origin_info.source_map(fn_node, modified_source, + 'test_filename', [0]) + + loc = origin_info.LineLocation('test_filename', 1) + origin = source_map[loc] + self.assertEqual(origin.source_code_line, 'def test_fn(x):') + self.assertEqual(origin.loc.lineno, 1) + + # The untraced line, inserted second. + loc = origin_info.LineLocation('test_filename', 2) + self.assertFalse(loc in source_map) + + # The traced line, inserted first. + loc = origin_info.LineLocation('test_filename', 3) + origin = source_map[loc] + self.assertEqual(origin.source_code_line, ' if x > 0:') + self.assertEqual(origin.loc.lineno, 2) + + loc = origin_info.LineLocation('test_filename', 4) + origin = source_map[loc] + self.assertEqual(origin.source_code_line, ' if x > 0:') + self.assertEqual(origin.loc.lineno, 2) + + def test_resolve(self): + + def test_fn(x): + """Docstring.""" + return x # comment + + node, source = parser.parse_entity(test_fn) + fn_node = node.body[0] + origin_info.resolve(fn_node, source) + + origin = anno.getanno(fn_node, anno.Basic.ORIGIN) + self.assertEqual(origin.loc.lineno, 1) + self.assertEqual(origin.loc.col_offset, 0) + self.assertEqual(origin.source_code_line, 'def test_fn(x):') + self.assertIsNone(origin.comment) + + origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN) + self.assertEqual(origin.loc.lineno, 2) + self.assertEqual(origin.loc.col_offset, 2) + self.assertEqual(origin.source_code_line, ' """Docstring."""') + self.assertIsNone(origin.comment) + + origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN) + self.assertEqual(origin.loc.lineno, 3) + self.assertEqual(origin.loc.col_offset, 2) + self.assertEqual(origin.source_code_line, ' return x # comment') + self.assertEqual(origin.comment, 'comment') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/parser.py b/tensorflow/contrib/autograph/pyct/parser.py index c961efa892df6a21804dae8f52ef64bf99cd409e..112ed46a1e487a7904e79267c1ce7db0ad914552 100644 --- a/tensorflow/contrib/autograph/pyct/parser.py +++ b/tensorflow/contrib/autograph/pyct/parser.py @@ -37,6 +37,7 @@ def parse_entity(entity): def parse_str(src): """Returns the AST of given piece of code.""" + # TODO(mdan): This should exclude the module things are autowrapped in. return gast.parse(src) diff --git a/tensorflow/contrib/autograph/pyct/qual_names.py b/tensorflow/contrib/autograph/pyct/qual_names.py index da07013cf4f4309c0e24adda3017575d942861b7..fb81404edc1994309f5108fc7e7ba368a1ea3ccb 100644 --- a/tensorflow/contrib/autograph/pyct/qual_names.py +++ b/tensorflow/contrib/autograph/pyct/qual_names.py @@ -30,6 +30,7 @@ import collections import gast from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser class Symbol(collections.namedtuple('Symbol', ['name'])): @@ -89,7 +90,8 @@ class QN(object): if not isinstance(base, (str, StringLiteral, NumberLiteral)): # TODO(mdan): Require Symbol instead of string. raise ValueError( - 'For simple QNs, base must be a string or a Literal object.') + 'for simple QNs, base must be a string or a Literal object;' + ' got instead "%s"' % type(base)) assert '.' not in base and '[' not in base and ']' not in base self._parent = None self.qn = (base,) @@ -112,6 +114,22 @@ class QN(object): raise ValueError('Cannot get parent of simple name "%s".' % self.qn[0]) return self._parent + @property + def owner_set(self): + """Returns all the symbols (simple or composite) that own this QN. + + In other words, if this symbol was modified, the symbols in the owner set + may also be affected. + + Examples: + 'a.b[c.d]' has two owners, 'a' and 'a.b' + """ + owners = set() + if self.has_attr() or self.has_subscript(): + owners.add(self.parent) + owners.update(self.parent.owner_set) + return owners + @property def support_set(self): """Returns the set of simple symbols that this QN relies on. @@ -122,7 +140,7 @@ class QN(object): Examples: 'a.b' has only one support symbol, 'a' - 'a[i]' has two roots, 'a' and 'i' + 'a[i]' has two support symbols, 'a' and 'i' """ # TODO(mdan): This might be the set of Name nodes in the AST. Track those? roots = set() @@ -231,3 +249,9 @@ class QnResolver(gast.NodeTransformer): def resolve(node): return QnResolver().visit(node) + + +def from_str(qn_str): + node = parser.parse_expression(qn_str) + node = resolve(node) + return anno.getanno(node, anno.Basic.QN) diff --git a/tensorflow/contrib/autograph/pyct/qual_names_test.py b/tensorflow/contrib/autograph/pyct/qual_names_test.py index 264afd508cdb847315c486806b531dc1483ef622..c793c2bb39df19f1af9b74f33323dbd4c985ee0d 100644 --- a/tensorflow/contrib/autograph/pyct/qual_names_test.py +++ b/tensorflow/contrib/autograph/pyct/qual_names_test.py @@ -30,6 +30,15 @@ from tensorflow.python.platform import test class QNTest(test.TestCase): + def test_from_str(self): + a = QN('a') + b = QN('b') + a_dot_b = QN(a, attr='b') + a_sub_b = QN(a, subscript=b) + self.assertEqual(qual_names.from_str('a.b'), a_dot_b) + self.assertEqual(qual_names.from_str('a'), a) + self.assertEqual(qual_names.from_str('a[b]'), a_sub_b) + def test_basic(self): a = QN('a') self.assertEqual(a.qn, ('a',)) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD index bcf2dacec2062704805f1d72ec27a243159d13c1..92eacba3fd53602ce238dfd7115ff0c3da9b1fc8 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD +++ b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD @@ -19,8 +19,9 @@ py_library( srcs = [ "activity.py", "annos.py", - "cfg.py", "live_values.py", + "liveness.py", + "reaching_definitions.py", "type_info.py", ], srcs_version = "PY2AND3", @@ -28,6 +29,7 @@ py_library( deps = [ "//tensorflow/contrib/autograph/pyct", "//tensorflow/contrib/autograph/utils", + "//tensorflow/python:util", "@gast_archive//:gast", ], ) @@ -46,23 +48,32 @@ py_test( ) py_test( - name = "cfg_test", - srcs = ["cfg_test.py"], + name = "live_values_test", + srcs = ["live_values_test.py"], srcs_version = "PY2AND3", tags = ["no_windows"], deps = [ ":static_analysis", "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", - "@gast_archive//:gast", ], ) py_test( - name = "live_values_test", - srcs = ["live_values_test.py"], + name = "liveness_test", + srcs = ["liveness_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":static_analysis", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "reaching_definitions_test", + srcs = ["reaching_definitions_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], deps = [ ":static_analysis", "//tensorflow/contrib/autograph/pyct", diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py b/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py index c325e19f28376da3be6db4b00b9f664eac047af2..9a82de735dc663f6a824488e4c5864943cecc3d4 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py @@ -18,10 +18,14 @@ This module contains utilities to help annotate AST nodes with as much runtime information as can be possibly extracted without actually executing the code, under that assumption that the context in which the code will run is known. -Note: It's a fair bet that this analysis cannot be reused across contexts -without re-running it. In most cases, the context usually means referenced -modules, which should be static enough to allow reuse, but that is not being -reliably verified. +Overall, the different analyses have the functions listed below: + + * activity: inventories symbols read, written to, params, etc. at different + levels + * liveness, reaching_definitions: dataflow analyses based on the program's CFG + and using the symbol information gathered by activity analysis + * live_values, type_info: type and value inference based on dataflow + analysis and context information """ from __future__ import absolute_import diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py index 4d7b0cbb7b8f6ee5bd64553644dc3ec9b8bca95b..a0182da9d132f50f290f4ba4896484815efb1286 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Activity analysis.""" +"""Activity analysis. + +Requires qualified name annotations (see qual_names.py). +""" from __future__ import absolute_import from __future__ import division @@ -59,9 +62,10 @@ class Scope(object): self.parent = parent self.add_unknown_symbols = add_unknown_symbols self.modified = set() + # TODO(mdan): Completely remove this. self.created = set() self.used = set() - self.params = set() + self.params = {} self.returned = set() # TODO(mdan): Rename to `locals` @@ -106,37 +110,23 @@ class Scope(object): self.modified |= other.modified self.created |= other.created self.used |= other.used - self.params |= other.params + self.params.update(other.params) self.returned |= other.returned def has(self, name): - if name in self.modified or name in self.params: + if name in self.modified: return True elif self.parent is not None: return self.parent.has(name) return False - def is_modified_since_entry(self, name): - if name in self.modified: - return True - elif self.parent is not None and not self.isolated: - return self.parent.is_modified_since_entry(name) - return False - - def is_param(self, name): - if name in self.params: - return True - elif self.parent is not None and not self.isolated: - return self.parent.is_param(name) - return False - def mark_read(self, name): self.used.add(name) if self.parent is not None and name not in self.created: self.parent.mark_read(name) - def mark_param(self, name): - self.params.add(name) + def mark_param(self, name, owner): + self.params[name] = owner def mark_creation(self, name, writes_create_symbol=False): """Mark a qualified name as created.""" @@ -226,37 +216,56 @@ class ActivityAnalyzer(transformer.Base): elif isinstance(node.ctx, gast.Param): # Param contexts appear in function defs, so they have the meaning of # defining a variable. - # TODO(mdan): This may be incorrect with nested functions. - # For nested functions, we'll have to add the notion of hiding args from - # the parent scope, not writing to them. - self.scope.mark_creation(qn) - self.scope.mark_param(qn) + self.scope.mark_write(qn) + self.scope.mark_param(qn, self.enclosing_entities[-1]) else: raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn)) anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn)) - anno.setanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY, - self.scope.is_modified_since_entry(qn)) - anno.setanno(node, NodeAnno.IS_PARAM, self.scope.is_param(qn)) if self._in_return_statement: self.scope.mark_returned(qn) + def _enter_scope(self, isolated): + self.scope = Scope(self.scope, isolated=isolated) + + def _exit_scope(self): + self.scope = self.scope.parent + + def _process_statement(self, node): + self._enter_scope(False) + node = self.generic_visit(node) + anno.setanno(node, anno.Static.SCOPE, self.scope) + self._exit_scope() + return node + + def visit_Expr(self, node): + return self._process_statement(node) + + def visit_Return(self, node): + self._in_return_statement = True + node = self._process_statement(node) + self._in_return_statement = False + return node + + def visit_Assign(self, node): + return self._process_statement(node) + def visit_AugAssign(self, node): # Special rules for AugAssign. In Assign, the target is only written, # but in AugAssig (e.g. a += b), the target is both read and written. self._in_aug_assign = True - self.generic_visit(node) + node = self._process_statement(node) self._in_aug_assign = False return node def visit_Name(self, node): - self.generic_visit(node) + node = self.generic_visit(node) self._track_symbol(node) return node def visit_Attribute(self, node): - self.generic_visit(node) + node = self.generic_visit(node) if self._in_constructor and self._node_sets_self_attribute(node): self._track_symbol( node, composite_writes_alter_parent=True, writes_create_symbol=True) @@ -265,44 +274,38 @@ class ActivityAnalyzer(transformer.Base): return node def visit_Subscript(self, node): - self.generic_visit(node) + node = self.generic_visit(node) # Subscript writes (e.g. a[b] = "value") are considered to modify # both the element itself (a[b]) and its parent (a). - self._track_symbol(node, composite_writes_alter_parent=True) + self._track_symbol(node) return node def visit_Print(self, node): - current_scope = self.scope - args_scope = Scope(current_scope) - self.scope = args_scope - for n in node.values: - self.visit(n) - anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope) - self.scope = current_scope + self._enter_scope(False) + node.values = self.visit_block(node.values) + anno.setanno(node, anno.Static.SCOPE, self.scope) + anno.setanno(node, NodeAnno.ARGS_SCOPE, self.scope) + self._exit_scope() return node + def visit_Assert(self, node): + return self._process_statement(node) + def visit_Call(self, node): - current_scope = self.scope - args_scope = Scope(current_scope, isolated=False) - self.scope = args_scope - for n in node.args: - self.visit(n) + self._enter_scope(False) + node.args = self.visit_block(node.args) + node.keywords = self.visit_block(node.keywords) # TODO(mdan): Account starargs, kwargs - for n in node.keywords: - self.visit(n) - anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope) - self.scope = current_scope - self.visit(node.func) + anno.setanno(node, NodeAnno.ARGS_SCOPE, self.scope) + self._exit_scope() + node.func = self.visit(node.func) return node def _process_block_node(self, node, block, scope_name): - current_scope = self.scope - block_scope = Scope(current_scope, isolated=False) - self.scope = block_scope - for n in block: - self.visit(n) - anno.setanno(node, scope_name, block_scope) - self.scope = current_scope + self._enter_scope(False) + block = self.visit_block(block) + anno.setanno(node, scope_name, self.scope) + self._exit_scope() return node def _process_parallel_blocks(self, parent, children): @@ -321,94 +324,75 @@ class ActivityAnalyzer(transformer.Base): self.scope.merge_from(after_child) return parent + def visit_arguments(self, node): + return self._process_statement(node) + def visit_FunctionDef(self, node): - if self.scope: - qn = qual_names.QN(node.name) - self.scope.mark_write(qn) - current_scope = self.scope - body_scope = Scope(current_scope, isolated=True) - self.scope = body_scope - self.generic_visit(node) - anno.setanno(node, NodeAnno.BODY_SCOPE, body_scope) - self.scope = current_scope + # The FunctionDef node itself has a Scope object that tracks the creation + # of its name, along with the usage of any decorator accompany it. + self._enter_scope(False) + node.decorator_list = self.visit_block(node.decorator_list) + self.scope.mark_write(qual_names.QN(node.name)) + anno.setanno(node, anno.Static.SCOPE, self.scope) + self._exit_scope() + + # A separate Scope tracks the actual function definition. + self._enter_scope(True) + node.args = self.visit(node.args) + + # Track the body separately. This is for compatibility reasons, it may not + # be strictly needed. + self._enter_scope(False) + node.body = self.visit_block(node.body) + anno.setanno(node, NodeAnno.BODY_SCOPE, self.scope) + self._exit_scope() + + self._exit_scope() return node def visit_With(self, node): - current_scope = self.scope - with_scope = Scope(current_scope, isolated=False) - self.scope = with_scope - self.generic_visit(node) - anno.setanno(node, NodeAnno.BODY_SCOPE, with_scope) - self.scope = current_scope + self._enter_scope(False) + node = self.generic_visit(node) + anno.setanno(node, NodeAnno.BODY_SCOPE, self.scope) + self._exit_scope() return node - def visit_If(self, node): - current_scope = self.scope - cond_scope = Scope(current_scope, isolated=False) - self.scope = cond_scope - self.visit(node.test) - anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope) - self.scope = current_scope + def visit_withitem(self, node): + return self._process_statement(node) + def visit_If(self, node): + self._enter_scope(False) + node.test = self.visit(node.test) + anno.setanno(node, NodeAnno.COND_SCOPE, self.scope) + anno.setanno(node.test, anno.Static.SCOPE, self.scope) + self._exit_scope() node = self._process_parallel_blocks(node, ((node.body, NodeAnno.BODY_SCOPE), (node.orelse, NodeAnno.ORELSE_SCOPE))) return node def visit_For(self, node): - self.visit(node.target) - self.visit(node.iter) + self._enter_scope(False) + node.target = self.visit(node.target) + node.iter = self.visit(node.iter) + anno.setanno(node.iter, anno.Static.SCOPE, self.scope) + self._exit_scope() node = self._process_parallel_blocks(node, ((node.body, NodeAnno.BODY_SCOPE), (node.orelse, NodeAnno.ORELSE_SCOPE))) return node def visit_While(self, node): - current_scope = self.scope - cond_scope = Scope(current_scope, isolated=False) - self.scope = cond_scope - self.visit(node.test) - anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope) - self.scope = current_scope - + self._enter_scope(False) + node.test = self.visit(node.test) + anno.setanno(node, NodeAnno.COND_SCOPE, self.scope) + anno.setanno(node.test, anno.Static.SCOPE, self.scope) + self._exit_scope() node = self._process_parallel_blocks(node, ((node.body, NodeAnno.BODY_SCOPE), (node.orelse, NodeAnno.ORELSE_SCOPE))) return node - def visit_Return(self, node): - self._in_return_statement = True - node = self.generic_visit(node) - self._in_return_statement = False - return node - - -def get_read(node, context): - """Return the variable names as QNs (qual_names.py) read by this statement.""" - analyzer = ActivityAnalyzer(context, None, True) - analyzer.visit(node) - return analyzer.scope.used - - -def get_updated(node, context): - """Return the variable names created or mutated by this statement. - - This function considers assign statements, augmented assign statements, and - the targets of for loops, as well as function arguments. - For example, `x[0] = 2` will return `x`, `x, y = 3, 4` will return `x` and - `y`, `for i in range(x)` will return `i`, etc. - Args: - node: An AST node - context: An EntityContext instance - - Returns: - A set of variable names (QNs, see qual_names.py) of all the variables - created or mutated. - """ - analyzer = ActivityAnalyzer(context, None, True) - analyzer.visit(node) - return analyzer.scope.created | analyzer.scope.modified - def resolve(node, context, parent_scope=None): return ActivityAnalyzer(context, parent_scope).visit(node) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py index bc22be0a270bbc9c361aea6d6d9c255ea51796e8..e940516190182a905f5747ffdd66533567bac76b 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py @@ -52,18 +52,18 @@ class ScopeTest(test.TestCase): other = activity.Scope(None) other.copy_from(scope) - self.assertTrue(QN('foo') in other.created) + self.assertTrue(QN('foo') in other.modified) scope.mark_write(QN('bar')) scope.copy_from(other) - self.assertFalse(QN('bar') in scope.created) + self.assertFalse(QN('bar') in scope.modified) scope.mark_write(QN('bar')) scope.merge_from(other) - self.assertTrue(QN('bar') in scope.created) - self.assertFalse(QN('bar') in other.created) + self.assertTrue(QN('bar') in scope.modified) + self.assertFalse(QN('bar') in other.modified) def test_copy_of(self): scope = activity.Scope(None) @@ -157,7 +157,8 @@ class ActivityAnalyzerTest(test.TestCase): """Assert the scope contains specific used, modified & created variables.""" self.assertSymbolSetsAre(used, scope.used, 'read') self.assertSymbolSetsAre(modified, scope.modified, 'modified') - self.assertSymbolSetsAre(created, scope.created, 'created') + # Created is deprecated, we're no longer verifying it. + # self.assertSymbolSetsAre(created, scope.created, 'created') def test_print_statement(self): @@ -215,12 +216,6 @@ class ActivityAnalyzerTest(test.TestCase): (), (), ) - self.assertScopeIsRmc( - anno.getanno(call_node, NodeAnno.ARGS_SCOPE).parent, - ('a', 'a.b', 'a.c', 'a.d', 'foo'), - ('a.c',), - ('a',), - ) def test_call_args_subscripts(self): @@ -241,12 +236,6 @@ class ActivityAnalyzerTest(test.TestCase): (), (), ) - self.assertScopeIsRmc( - anno.getanno(call_node, NodeAnno.ARGS_SCOPE).parent, - ('a', 'a[0]', 'a[b]', 'a[c]', 'b', 'c', 'foo'), - ('b', 'c'), - ('a', 'b', 'c'), - ) def test_while(self): @@ -362,20 +351,20 @@ class ActivityAnalyzerTest(test.TestCase): self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a', 'b', 'c', 'a[c]'), - ('a', 'a[b]', 'd'), + ('a[b]', 'd'), ('d',), ) # TODO(mdan): Should subscript writes (a[0] = 1) be considered to read "a"? self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('a', 'e'), - ('a', 'a[0]', 'd'), + ('a[0]', 'd'), ('d',), ) self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, ('a', 'b', 'c', 'd', 'e', 'a[c]'), - ('a', 'd', 'a[b]', 'a[0]'), + ('d', 'a[b]', 'a[0]'), ('a', 'b', 'c', 'd', 'e'), ) @@ -415,10 +404,6 @@ class ActivityAnalyzerTest(test.TestCase): node, _ = self._parse_and_analyze(test_fn) fn_def_node = node.body[0].body[0] - self.assertScopeIsRmc( - anno.getanno(fn_def_node, - NodeAnno.BODY_SCOPE).parent, ('b', 'i', 'f', 'c', 'a'), - ('f', 'b', 'c', 'i'), ('f', 'a', 'b', 'c', 'i')) self.assertScopeIsRmc( anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',), ( 'x', @@ -452,7 +437,7 @@ class ActivityAnalyzerTest(test.TestCase): self.assertScopeIsRmc( anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('a', 'a[0]'), - ('a', 'a[0]'), + ('a[0]',), ('a',), ) @@ -518,47 +503,6 @@ class ActivityAnalyzerTest(test.TestCase): anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('b',), (('')), (('a', 'b'))) - def test_get_read(self): - - def test_fn(x, y): - z = test_fn(x, y) - return z - - node, ctx = self._parse_and_analyze(test_fn) - node = node.body[0].body[0] - read_vars = activity.get_read(node, ctx) - self.assertEqual(read_vars, set(map(qual_names.QN, ('test_fn', 'x', 'y')))) - - def test_fn2(x, y, z): - z += test_fn2(x, y, z) - return z - - node, ctx = self._parse_and_analyze(test_fn2) - node = node.body[0].body[0] - read_vars = activity.get_read(node, ctx) - self.assertEqual(read_vars, - set(map(qual_names.QN, ('test_fn2', 'x', 'y', 'z')))) - - def test_get_updated(self): - - def test_fn(x, y): - z = test_fn(x, y) - return z - - node, ctx = self._parse_and_analyze(test_fn) - node = node.body[0].body[0] - updated_vars = activity.get_updated(node, ctx) - self.assertEqual(updated_vars, set(map(qual_names.QN, ('z')))) - - def test_fn2(x, y, z): - z += test_fn2(x, y, z) - return z - - node, ctx = self._parse_and_analyze(test_fn2) - node = node.body[0].body[0] - updated_vars = activity.get_updated(node, ctx) - self.assertEqual(updated_vars, set(map(qual_names.QN, ('z')))) - if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/annos.py b/tensorflow/contrib/autograph/pyct/static_analysis/annos.py index b929b35b79200b0968c9c4f26b10cda28763773a..5eefecf278992f73464817585a3498de4c031978 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/annos.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/annos.py @@ -21,6 +21,9 @@ from __future__ import print_function from enum import Enum +# TODO(mdan): Remove. + + class NoValue(Enum): def __repr__(self): @@ -50,10 +53,3 @@ class NodeAnno(NoValue): ORELSE_SCOPE = ( 'The scope for the orelse body of a statement (False branch for if ' 'statements, orelse body for loops).') - - # Type and Value annotations - # Type annotations are represented by objects of type type_info.Type. - STATIC_INFO = ( - 'The type or value information that should be asserted about the entity ' - 'referenced by the symbol holding this annotation, irrespective of the ' - 'execution context.') diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py deleted file mode 100644 index 39eca6e44441cc28e565d383759cc796d57d6438..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py +++ /dev/null @@ -1,446 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Control flow graph analysis. - -Given a Python AST we construct a control flow graph, with edges both to the -next and previous statements (so it can easily walk the graph both ways). Its -nodes contain the AST of the statements. It can then perform forward or backward -analysis on this CFG. -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from collections import namedtuple -import functools -import operator - -import gast - -from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct.static_analysis import activity - - -class CfgNode(object): - """A node in the CFG.""" - __slots__ = ['next', 'value', 'prev'] - - def __init__(self, value): - self.next = set() - self.prev = set() - self.value = value - - -class Cfg(namedtuple('Cfg', ['entry', 'exit'])): - """A Control Flow Graph. - - Each statement is represented as a node. For control flow statements such - as conditionals and loops the conditional itself is a node which either - branches or cycles, respectively. - Attributes: - entry: The entry node, which contains the `gast.arguments` node of the - function definition. - exit: The exit node. This node is special because it has no value (i.e. no - corresponding AST node). This is because Python functions can have - multiple return statements. - """ - pass - - -class CfgBuilder(gast.NodeVisitor): - """Construct a control flow graph. - - Construct a CFG starting from a FunctionDef node. - Usage: - cfg_obj = CfgBuilder().build_cfg(fndef_node) - """ - - def __init__(self): - # The current leaves of the CFG - self.current_leaves = [] - # TODO(alexbw): generalize to break, return, continue, yield, etc. - # A stack of lists, tracking continue statements - self.continue_ = [] - # A stack of lists tracking break nodes - self.break_ = [] - - def set_current_leaves(self, cfg_node): - """Link this cfg_node to the current leaves. - - This is the central function for building the CFG. It links the current - head cfg_nodes to the passed cfg_node. It then resets the head to the - passed cfg_node. - - Args: - cfg_node: A CfgNode instance. - """ - for head in self.current_leaves: - head.next.add(cfg_node) - # While we're linking the CFG forward, add backlinks - cfg_node.prev.add(head) - self.current_leaves = [cfg_node] - - def build_cfg(self, node): - """Build a CFG for a function. - - Implementation of building a CFG for dataflow analysis. See, e.g.: - https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec02-Dataflow.pdf - - Args: - node: A function definition the body of which to analyze. - Returns: - A CFG object. - Raises: - TypeError: If the input is not a function definition. - """ - if not isinstance(node, gast.FunctionDef): - raise TypeError('input must be a function definition') - entry_cfg_node = CfgNode(node.args) - self.current_leaves = [entry_cfg_node] - self.visit_statements(node.body) - exit_cfg_node = CfgNode(None) - self.set_current_leaves(exit_cfg_node) - return Cfg(entry_cfg_node, exit_cfg_node) - - def visit_statements(self, nodes): - for node in nodes: - # Check for control flow - if isinstance(node, (gast.For, gast.While, gast.If, gast.Try, gast.Break, - gast.Continue, gast.With)): - self.visit(node) - else: - expr = CfgNode(node) - self.set_current_leaves(expr) - - def generic_visit(self, node): - raise ValueError('unknown control flow') - - def visit_If(self, node): - # TODO(alexbw): change this to use immutable tuples instead of lists - # The current head will hold the conditional - test = CfgNode(node.test) - self.set_current_leaves(test) - # Handle the body - self.visit_statements(node.body) - body_exit = self.current_leaves - self.current_leaves = [test] - # Handle the orelse - self.visit_statements(node.orelse) - self.current_leaves.extend(body_exit) - - def visit_While(self, node): - test = CfgNode(node.test) - self.set_current_leaves(test) - # Start a new level of nesting - self.break_.append([]) - self.continue_.append([]) - # Handle the body - self.visit_statements(node.body) - body_exit = self.current_leaves - self.current_leaves.extend(self.continue_.pop()) - self.set_current_leaves(test) - # Handle the orelse - self.visit_statements(node.orelse) - # The break statements and the test go to the next node - self.current_leaves.extend(self.break_.pop()) - # Body and orelse statements can reach out of the loop - self.current_leaves.extend(body_exit) - - def visit_For(self, node): - iter_ = CfgNode(node.iter) - self.set_current_leaves(iter_) - self.break_.append([]) - self.continue_.append([]) - self.visit_statements(node.body) - body_exit = self.current_leaves - self.current_leaves.extend(self.continue_.pop()) - self.set_current_leaves(iter_) - # Handle the orelse - self.visit_statements(node.orelse) - # The break statements and the test go to the next node - self.current_leaves.extend(self.break_.pop()) - # Body and orelse statements can reach out of the loop - self.current_leaves.extend(body_exit) - - def visit_Break(self, node): - self.break_[-1].extend(self.current_leaves) - self.current_leaves[:] = [] - - def visit_Continue(self, node): - self.continue_[-1].extend(self.current_leaves) - self.current_leaves[:] = [] - - def visit_Try(self, node): - self.visit_statements(node.body) - body = self.current_leaves - handlers = [] - for handler in node.handlers: - self.current_leaves = body[:] - self.visit_statements(handler.body) - handlers.extend(self.current_leaves) - self.current_leaves = body - self.visit_statements(node.orelse) - self.current_leaves = handlers + self.current_leaves - self.visit_statements(node.finalbody) - - def visit_With(self, node): - for item in node.items: - self.set_current_leaves(CfgNode(item)) - self.visit_statements(node.body) - - -# TODO(alexbw): once CFG analysis occurs at a block level, -# this extra class will not be necessary -class PropagateAnalysis(gast.NodeVisitor): - """Port analysis annotations from statements to their enclosing blocks.""" - - def __init__(self, analysis): - self.transfer_fn = analysis.transfer_fn - self.in_label = analysis.in_label - self.out_label = analysis.out_label - super(PropagateAnalysis, self).__init__() - - def visit_If(self, node): - # Depth-first. - self.generic_visit(node) - incoming = anno.getanno(node.body[0], self.in_label) - incoming |= anno.getanno(node.test, self.in_label) - outgoing = anno.getanno(node.body[-1], self.out_label) - outgoing |= anno.getanno(node.test, self.out_label) - if node.orelse: - orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label) - outgoing = self.transfer_fn(outgoing, orelse_outgoing) - anno.setanno(node, self.in_label, incoming) - anno.setanno(node, self.out_label, outgoing) - - def visit_For(self, node): - self.generic_visit(node) - incoming = set(anno.getanno(node.body[0], self.in_label)) - incoming -= set((anno.getanno(node.target, anno.Basic.QN),)) - outgoing = anno.getanno(node.body[-1], self.out_label) - if node.orelse: - orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label) - outgoing = self.transfer_fn(outgoing, orelse_outgoing) - anno.setanno(node, self.in_label, frozenset(incoming)) - anno.setanno(node, self.out_label, outgoing) - - def visit_While(self, node): - self.generic_visit(node) - incoming = anno.getanno(node.body[0], self.in_label) - incoming |= anno.getanno(node.test, self.in_label) - outgoing = anno.getanno(node.body[-1], self.out_label) - if node.orelse: - orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label) - outgoing = self.transfer_fn(outgoing, orelse_outgoing) - anno.setanno(node, self.in_label, incoming) - anno.setanno(node, self.out_label, outgoing) - - def visit_With(self, node): - self.generic_visit(node) - incoming = anno.getanno(node.body[0], self.in_label) - for item in node.items: - incoming |= anno.getanno(item, self.in_label) - outgoing = anno.getanno(node.body[-1], self.out_label) - anno.setanno(node, self.in_label, incoming) - anno.setanno(node, self.out_label, outgoing) - - -# TODO(alexbw): Abstract the CFG walking machinery into a superclass -# which is parameterized on which fields it selects when walking. -# TODO(alexbw): Abstract the application of dataflow analysis -class Forward(object): - """Forward analysis on CFG. - - Args: - label: A name for this analysis e.g. 'active' for activity analysis. The AST - nodes in the CFG will be given annotations 'name_in', 'name_out', - 'name_gen' and 'name_kill' which contain the incoming values, outgoing - values, values generated by the statement, and values deleted by the - statement respectively. - transfer_fn: Either the AND or OR operator. If the AND operator is used it - turns into forward must analysis (i.e. a value will only be carried - forward if it appears on all incoming paths). The OR operator means that - forward may analysis is done (i.e. the union of incoming values will be - taken). - """ - - def __init__(self, label, source_info, transfer_fn=operator.or_): - self.transfer_fn = transfer_fn - self.source_info = source_info - self.out_label = label + '_out' - self.in_label = label + '_in' - self.gen_label = label + '_gen' - self.kill_label = label + '_kill' - - # TODO(alexbw): see if we can simplify by visiting breadth-first - def visit(self, node): - """Depth-first walking the CFG, applying dataflow information propagation.""" - # node.value is None only for the exit CfgNode. - if not node.value: - return - - if anno.hasanno(node.value, self.out_label): - before = hash(anno.getanno(node.value, self.out_label)) - else: - before = None - preds = [ - anno.getanno(pred.value, self.out_label) - for pred in node.prev - if anno.hasanno(pred.value, self.out_label) - ] - if preds: - incoming = functools.reduce(self.transfer_fn, preds[1:], preds[0]) - else: - incoming = frozenset() - anno.setanno(node.value, self.in_label, incoming) - gen, kill = self.get_gen_kill(node, incoming) - anno.setanno(node.value, self.gen_label, gen) - anno.setanno(node.value, self.kill_label, kill) - anno.setanno(node.value, self.out_label, (incoming - kill) | gen) - - if hash(anno.getanno(node.value, self.out_label)) != before: - for succ in node.next: - self.visit(succ) - - def get_gen_kill(self, cfg_node, incoming): - """Calculate Gen and Kill properties of a CFG node in dataflow analysis. - - A function which takes the CFG node as well as a set of incoming - values. It must return a set of newly generated values by the statement as - well as a set of deleted (killed) values. - - Args: - cfg_node: A CfgNode instance. - incoming: - """ - raise NotImplementedError() - - -class Backward(Forward): - """Backward analysis on CFG.""" - - def visit(self, cfg_node): - # cfg_node.value is None for the exit node, which will be visited only once - if not cfg_node.value: - for pred in cfg_node.prev: - self.visit(pred) - return - - if anno.hasanno(cfg_node.value, self.in_label): - before = hash(anno.getanno(cfg_node.value, self.in_label)) - else: - before = None - succs = [ - anno.getanno(succ.value, self.in_label) - for succ in cfg_node.next - if anno.hasanno(succ.value, self.in_label) - ] - if succs: - incoming = functools.reduce(self.transfer_fn, succs[1:], succs[0]) - else: - incoming = frozenset() - anno.setanno(cfg_node.value, self.out_label, incoming) - gen, kill = self.get_gen_kill(cfg_node, incoming) - anno.setanno(cfg_node.value, self.gen_label, gen) - anno.setanno(cfg_node.value, self.kill_label, kill) - anno.setanno(cfg_node.value, self.in_label, (incoming - kill) | gen) - if hash(anno.getanno(cfg_node.value, self.in_label)) != before: - for pred in cfg_node.prev: - self.visit(pred) - - -def run_analyses(node, analyses): - """Perform dataflow analysis on all functions within an AST. - - Args: - node: An AST node on which to run dataflow analysis. - analyses: Either an instance of the Forward or Backward dataflow analysis - class, or a list or tuple of them. - - Returns: - node: The node, but now with annotations on the AST nodes containing the - results of the dataflow analyses. - """ - if not isinstance(analyses, (tuple, list)): - analyses = (analyses,) - for analysis in analyses: - if not isinstance(analysis, (Forward, Backward)): - raise TypeError('not a valid forward analysis object') - - for child_node in gast.walk(node): - if isinstance(child_node, gast.FunctionDef): - cfg_obj = CfgBuilder().build_cfg(child_node) - for analysis in analyses: - if isinstance(analysis, Backward): - analysis.visit(cfg_obj.exit) - elif isinstance(analysis, Forward): - analysis.visit(cfg_obj.entry) - for analysis in analyses: - PropagateAnalysis(analysis).visit(node) - return node - - -class Liveness(Backward): - """Perform a liveness analysis. - - Each statement is annotated with a set of variables that may be used - later in the program. - """ - - def __init__(self, source_info): - super(Liveness, self).__init__('live', source_info) - - def get_gen_kill(self, node, _): - # A variable's parents are live if it is live - # e.g. x is live if x.y is live. This means gen needs to return - # all parents of a variable (if it's an Attribute or Subscript). - # This doesn't apply to kill (e.g. del x.y doesn't affect liveness of x) - gen = activity.get_read(node.value, self.source_info) - gen = functools.reduce(lambda left, right: left | right.support_set, gen, - gen) - kill = activity.get_updated(node.value, self.source_info) - return gen, kill - - -class ReachingDefinitions(Forward): - """Perform reaching definition analysis. - - Each statement is annotated with a set of (variable, definition) pairs. - """ - - def __init__(self, source_info): - super(ReachingDefinitions, self).__init__('definitions', source_info) - - def get_gen_kill(self, node, incoming): - definitions = activity.get_updated(node.value, self.source_info) - gen = frozenset((id_, node.value) for id_ in definitions) - kill = frozenset(def_ for def_ in incoming if def_[0] in definitions) - return gen, kill - - -class Defined(Forward): - """Perform defined variable analysis. - - Each statement is annotated with a set of variables which are guaranteed to - be defined at that point. - """ - - def __init__(self, source_info): - super(Defined, self).__init__( - 'defined', source_info, transfer_fn=operator.and_) - - def get_gen_kill(self, node, _): - gen = activity.get_updated(node.value, self.source_info) - return gen, frozenset() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py deleted file mode 100644 index 428ebbedca85f9b94b4b1db0f3b36a334126196b..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py +++ /dev/null @@ -1,303 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for cfg module.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools - -import gast - -from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import parser -from tensorflow.contrib.autograph.pyct import qual_names -from tensorflow.contrib.autograph.pyct import transformer -from tensorflow.contrib.autograph.pyct.static_analysis import cfg -from tensorflow.python.platform import test - - -class CFGTest(test.TestCase): - - def _parse_and_analyze(self, test_fn): - node, source = parser.parse_entity(test_fn) - entity_info = transformer.EntityInfo( - source_code=source, - source_file=None, - namespace={}, - arg_values=None, - arg_types=None, - owner_type=None) - node = qual_names.resolve(node) - return node, entity_info - - def _check_anno_matches(self, node, anno_name, var_names): - if isinstance(var_names, str): - var_names = (var_names,) - qual_vars = set() - for var_name in var_names: - if isinstance(var_name, str): - if '[' in var_name or ']' in var_name: - raise ValueError('Annotation matching not supported with subscript.') - if '.' not in var_name: - qual_vars.add(qual_names.QN(var_name)) - else: - attrs = var_name.split('.') - this_qn = functools.reduce(qual_names.QN, attrs[1:], - qual_names.QN(attrs[0])) - qual_vars.add(this_qn) - self.assertEqual(anno.getanno(node, anno_name), qual_vars) - - def test_reaching(self): - - def f(x): - print(x) - while True: - x = x - x = x - return x - - node, ctx = self._parse_and_analyze(f) - cfg.run_analyses(node, cfg.ReachingDefinitions(ctx)) - body = node.body[0].body - # Only the argument reaches the expression - def_in = anno.getanno(body[0], 'definitions_in') - # One element, x, from arguments - self.assertEqual(set(type(d[1]) for d in def_in), set((gast.arguments,))) - - while_body = body[1].body - def_in = anno.getanno(while_body[0], 'definitions_in') - # One definition, two possible sources. - # - One from an assignment (if the loop is entered) - # - The other from the arguments (if loop is not entered) - self.assertEqual( - set(type(d[1]) for d in def_in), set((gast.arguments, gast.Assign))) - - def_in = anno.getanno(while_body[1], 'definitions_in') - # If we've reached this line, the only reaching definition of x is the - # Assign node in previous line - self.assertEqual(set(type(d[1]) for d in def_in), set((gast.Assign,))) - - def_in = anno.getanno(body[2], 'definitions_in') - # Same situation as while_body[0] - self.assertEqual( - set(type(d[1]) for d in def_in), set((gast.arguments, gast.Assign))) - - def test_defined(self): - - def f(x): - if x: - y = 2 # pylint: disable=unused-variable - return x - - node, ctx = self._parse_and_analyze(f) - cfg.run_analyses(node, cfg.Defined(ctx)) - body = node.body[0].body - # only x is for sure defined at the end - self._check_anno_matches(body[1], 'defined_in', 'x') - # at the end of the if body both x and y are defined - if_body = body[0].body - self._check_anno_matches(if_body[0], 'defined_out', ('x', 'y')) - - def _get_live_annotated_fnbody(self, f): - node, ctx = self._parse_and_analyze(f) - cfg.run_analyses(node, cfg.Liveness(ctx)) - body = node.body[0].body - return body - - def test_live_straightline(self): - - def f1(x): - a = g(x) # pylint: disable=undefined-variable - b = h(a) # pylint: disable=undefined-variable, unused-variable - return x - - body = self._get_live_annotated_fnbody(f1) - self._check_anno_matches(body[1], 'live_in', ('a', 'h', 'x')) - self._check_anno_matches(body[2], 'live_in', ('x')) - self._check_anno_matches(body[0], 'live_in', ('g', 'h', 'x')) - self._check_anno_matches(body[2], 'live_out', ()) - - def test_live_stacked_conds_with_else(self): - - def f2(x, a): # pylint: disable=unused-argument - if a > 0: # x should not be live - x = 0 - if a > 1: - x = 1 - else: - x = 2 - - body = self._get_live_annotated_fnbody(f2) - self._check_anno_matches(body[0], 'live_in', ('a')) - self._check_anno_matches(body[1], 'live_in', ('a')) - - def test_live_stacked_conds(self): - - def f3(x, a): - if a > 0: # x and a should be live - x = 0 - if a > 1: # x and a should be live_in - x = 1 - return x # x should be live - - body = self._get_live_annotated_fnbody(f3) - self._check_anno_matches(body[0], 'live_in', ('a', 'x')) - self._check_anno_matches(body[1], 'live_in', ('a', 'x')) - self._check_anno_matches(body[2], 'live_in', ('x')) - - def test_live_possibly_unused_cond(self): - - def f4(x, a): - if a > 0: # x should be live - x = 0 - x += 1 - - body = self._get_live_annotated_fnbody(f4) - self._check_anno_matches(body[0], 'live_in', ('x', 'a')) - self._check_anno_matches(body[1], 'live_in', ('x')) - - def test_live_attribute_in_cond(self): - - def f5(x, a): - if a > 0: # x.y should be live - x.y = 0 - return x.y - - body = self._get_live_annotated_fnbody(f5) - self._check_anno_matches(body[0], 'live_in', ('x', 'x.y', 'a')) - - def test_live_noop(self): - - def f6(x): - return x # should this cause x.* to be live? - - body = self._get_live_annotated_fnbody(f6) - self._check_anno_matches(body[0], 'live_in', ('x')) - - def test_live_loop(self): - - def f7(x, n): - for i in range(n): - x += i - return x - - body = self._get_live_annotated_fnbody(f7) - self._check_anno_matches(body[0], 'live_in', ('x', 'n', 'range')) - self._check_anno_matches(body[1], 'live_in', ('x')) - - def test_live_context_manager(self): - - def f8(x, f): - with f: - x += 1 - - body = self._get_live_annotated_fnbody(f8) - self._check_anno_matches(body[0], 'live_in', ('f', 'x')) - - def test_node_equality(self): - node_a = gast.parse('y = x').body[0] - node_b = gast.parse('y = x').body[0] - self.assertNotEqual(node_a, node_b) - - def test_nested_functions_defined(self): - - def f(x): - y = x * 2 - - def g(z): - return z + y - - return g(x) - - node, ctx = self._parse_and_analyze(f) - cfg.run_analyses(node, cfg.Defined(ctx)) - - body = node.body[0].body - self.assertEqual( - anno.getanno(body[2], 'defined_in'), - frozenset(map(qual_names.QN, ('g', 'x', 'y')))) - - # TODO(alexbw): CFG analysis doesn't currently cross FunctionDef boundaries. - # NOTE: 'z' is easy to find, but 'y' is not identified as - # defined, because CFG analysis is applied with each function separately. - # fndef_body = body[1].body - # self.assertEqual( - # anno.getanno(fndef_body[0], 'defined_in'), - # frozenset(map(qual_names.QN, ('z', 'y')))) - - def test_nested_functions_dont_leak_definitions(self): - - def f(x): - print(x) - - def g(): - y = 2 - return y - - return g() # y is not defined here - - node, ctx = self._parse_and_analyze(f) - cfg.run_analyses(node, cfg.Defined(ctx)) - body = node.body[0].body - self.assertEqual( - anno.getanno(body[2], 'defined_in'), - frozenset(map(qual_names.QN, ('x', 'g')))) - - def test_loop_else(self): - - # Disabling useless-else-on-loop error, because 'break' and 'continue' - # canonicalization are a separate analysis pass, and here we test - # the CFG analysis in isolation. - def for_orelse(x): - y = 0 - for i in range(len(x)): - x += i - else: # pylint: disable=useless-else-on-loop - y = 1 - return x, y - - def while_orelse(x, i): - y = 0 - while x < 10: - x += i - else: # pylint: disable=useless-else-on-loop - y = 1 - return x, y - - for f in (for_orelse, while_orelse): - node, ctx = self._parse_and_analyze(f) - cfg.run_analyses(node, cfg.ReachingDefinitions(ctx)) - body = node.body[0].body - return_node = body[-1] - reaching_defs = anno.getanno(return_node, 'definitions_in') - - # Y could be defined by Assign(Num(0)) or Assign(Num(1)) - # X could be defined as an argument or an AugAssign. - y_defs = [node for var, node in reaching_defs if str(var) == 'y'] - x_defs = [node for var, node in reaching_defs if str(var) == 'x'] - - self.assertEqual(set((gast.Assign,)), set(type(def_) for def_ in y_defs)) - self.assertEqual(set((0, 1)), set(def_.value.n for def_ in y_defs)) - self.assertEqual(len(y_defs), 2) - self.assertEqual( - set((gast.arguments, gast.AugAssign)), - set(type(def_) for def_ in x_defs)) - self.assertEqual(len(x_defs), 2) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py index 9ccb98f79adbe5410a7554548ee75ab95345962d..2d8f922a4589e45ab7e4f20f800e0ffef3d7f0a5 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py @@ -16,7 +16,7 @@ Live values are extracted from the known execution context. -Requires activity analysis annotations. +Requires activity and reaching definitions analyses. """ from __future__ import absolute_import @@ -45,14 +45,12 @@ class LiveValueResolver(transformer.Base): def visit_Name(self, node): self.generic_visit(node) if isinstance(node.ctx, gast.Load): - assert anno.hasanno(node, NodeAnno.IS_LOCAL), node - symbol_is_local = anno.getanno(node, NodeAnno.IS_LOCAL) - assert anno.hasanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY), node - symbol_is_modified = anno.getanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY) - assert anno.hasanno(node, NodeAnno.IS_PARAM), node - symbol_is_param = anno.getanno(node, NodeAnno.IS_PARAM) - - if not symbol_is_local and not symbol_is_param: + defs = anno.getanno(node, anno.Static.DEFINITIONS, ()) + + is_defined = bool(defs) + has_single_def = len(defs) == 1 + + if not is_defined: if node.id in self.literals: anno.setanno(node, 'live_val', self.literals[node.id]) elif node.id in self.entity_info.namespace: @@ -79,11 +77,13 @@ class LiveValueResolver(transformer.Base): # TODO(mdan): Attempt to trace its value through the local chain. # TODO(mdan): Use type annotations as fallback. - if not symbol_is_modified: - if node.id in self.entity_info.arg_values: - obj = self.entity_info.arg_values[node.id] - anno.setanno(node, 'live_val', obj) - anno.setanno(node, 'fqn', (obj.__class__.__name__,)) + if has_single_def: + def_, = defs + if def_.param_of is self.enclosing_entities[0]: + if node.id in self.entity_info.arg_values: + obj = self.entity_info.arg_values[node.id] + anno.setanno(node, 'live_val', obj) + anno.setanno(node, 'fqn', (obj.__class__.__name__,)) return node def visit_Attribute(self, node): @@ -91,12 +91,20 @@ class LiveValueResolver(transformer.Base): if anno.hasanno(node.value, 'live_val'): assert anno.hasanno(node.value, 'fqn') parent_object = anno.getanno(node.value, 'live_val') - if not hasattr(parent_object, node.attr): - raise AttributeError('%s has no attribute %s' % (parent_object, - node.attr)) + anno.setanno(node, 'parent_type', type(parent_object)) - anno.setanno(node, 'live_val', getattr(parent_object, node.attr)) anno.setanno(node, 'fqn', anno.getanno(node.value, 'fqn') + (node.attr,)) + if hasattr(parent_object, node.attr): + # This can happen when the attribute's creation and use depend on the + # same static condition, for example: + # + # if cond: + # foo.bar = baz + # if cond: + # x = foo.bar + # + anno.setanno(node, 'live_val', getattr(parent_object, node.attr)) + # TODO(mdan): Investigate the role built-in annotations can play here. elif anno.hasanno(node.value, 'type'): parent_type = anno.getanno(node.value, 'type') diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py index 38af79277779f77ffe31c2f6e26ae88f3e1a7ae9..fe3051179cd93ddd2627802dd2536bb50f17fb7f 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py @@ -21,11 +21,13 @@ from __future__ import print_function import six from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import cfg from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import activity from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.framework import constant_op from tensorflow.python.platform import test @@ -48,7 +50,10 @@ class LiveValuesResolverTest(test.TestCase): arg_types=arg_types, owner_type=None) node = qual_names.resolve(node) + graphs = cfg.build(node) node = activity.resolve(node, entity_info) + node = reaching_definitions.resolve(node, entity_info, graphs, + reaching_definitions.Definition) node = live_values.resolve(node, entity_info, literals) node = type_info.resolve(node, entity_info) node = live_values.resolve(node, entity_info, literals) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/liveness.py b/tensorflow/contrib/autograph/pyct/static_analysis/liveness.py new file mode 100644 index 0000000000000000000000000000000000000000..bf29d868a2e4d2a4c7dd1057c0ed93e54d01d750 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/liveness.py @@ -0,0 +1,200 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Live variable analysis. + +This analysis attaches a set containing the live symbols that are live at the +exit of control flow statements. + +Requires activity analysis. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import cfg +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis import annos + + +class Analyzer(cfg.GraphVisitor): + """CFG visitor that performs liveness analysis at statement level.""" + + def __init__(self, graph): + super(Analyzer, self).__init__(graph) + # This allows communicating that nodes generate extra symbols, + # e.g. those that a function definition closes over. + self.extra_gen = {} + + def init_state(self, _): + return set() + + def visit_node(self, node): + prev_live_in = self.in_[node] + + if anno.hasanno(node.ast_node, anno.Static.SCOPE): + node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE) + + gen = node_scope.used | self.extra_gen.get(node.ast_node, frozenset()) + # TODO(mdan): verify whether composites' parents need to be added. + # E.g. if x.y is live whether x needs to be added. Theoretically the + # activity analysis should have both so that wouldn't be needed. + kill = node_scope.modified + + live_out = set() + for n in node.next: + live_out |= self.in_[n] + live_in = gen | (live_out - kill) + + else: + # Nodes that don't have a scope annotation are assumed not to touch any + # symbols. + # This Name node below is a literal name, e.g. False + assert isinstance(node.ast_node, + (gast.Name, gast.Continue, gast.Break)), type( + node.ast_node) + live_in = prev_live_in + live_out = live_in + + self.in_[node] = live_in + self.out[node] = live_out + + # TODO(mdan): Move this to the superclass? + return prev_live_in != live_in + + +class WholeTreeAnalyzer(transformer.Base): + """Runs liveness analysis on each of the functions defined in the AST. + + If a function defined other local functions, those will have separate CFGs. + However, dataflow analysis needs to tie up these CFGs to properly emulate the + effect of closures. In the case of liveness, the parent function's live + variables must account for the variables that are live at the entry of each + subfunction. For example: + + def foo(): + # baz is live here + def bar(): + print(baz) + + This analyzer runs liveness analysis on each individual function, accounting + for the effect above. + """ + + def __init__(self, source_info, graphs): + super(WholeTreeAnalyzer, self).__init__(source_info) + self.graphs = graphs + self.current_analyzer = None + self.analyzers = {} + + def visit_FunctionDef(self, node): + parent_analyzer = self.current_analyzer + subgraph = self.graphs[node] + + # Postorder tree processing makes this a bit complicated: + # 1. construct an analyzer object and put it on stack + # 2. recursively walk the subtree; this will initialize the analyzer's + # in_ state properly (done in a block below) + # 3. run the final analysis + analyzer = Analyzer(subgraph) + self.current_analyzer = analyzer + node = self.generic_visit(node) + analyzer.visit_reverse() + + if parent_analyzer is not None: + # Wire the state between the two subgraphs' analyzers. + child_in_state = analyzer.in_[subgraph.entry] + # Exception: symbols modified in the child function are local to it + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) + for qn in body_scope.modified: + # Note: a function modifying the symbol doesn't make that symbol + # live at the function's entry. In fact when that happens it is + # probably a case of undefined assignment, like this: + # + # bar = 0 + # def foo(): + # print(bar) # bar is undefined here! + # bar = 1 + # + # Hence we use discard and not remove below. + child_in_state.discard(qn) + parent_analyzer.extra_gen[node] = frozenset(child_in_state,) + + self.analyzers[node] = analyzer + self.current_analyzer = parent_analyzer + return node + + def visit_nonlocal(self, node): + raise NotImplementedError() + + def visit_global(self, node): + raise NotImplementedError() + + +class Annotator(transformer.Base): + """AST visitor that annotates each control flow block with live symbols.""" + + # Note: additional nodes may be added as needed. + + def __init__(self, source_info, cross_function_analyzer): + super(Annotator, self).__init__(source_info) + self.cross_function_analyzer = cross_function_analyzer + self.current_analyzer = None + + def visit_FunctionDef(self, node): + parent_analyzer = self.current_analyzer + self.current_analyzer = self.cross_function_analyzer.analyzers[node] + + node = self.generic_visit(node) + self.current_analyzer = parent_analyzer + return node + + def _aggregate_successors_live_in(self, node): + successors = self.current_analyzer.graph.stmt_next[node] + node_live_out = set() + for s in successors: + node_live_out.update(self.current_analyzer.in_[s]) + anno.setanno(node, anno.Static.LIVE_VARS_OUT, frozenset(node_live_out)) + node = self.generic_visit(node) + return node + + def visit_If(self, node): + return self._aggregate_successors_live_in(node) + + def visit_For(self, node): + return self._aggregate_successors_live_in(node) + + def visit_While(self, node): + return self._aggregate_successors_live_in(node) + + +def resolve(node, source_info, graphs): + """Resolves the live symbols at the exit of control flow statements. + + Args: + node: ast.AST + source_info: transformer.SourceInfo + graphs: Dict[ast.FunctionDef, cfg.Graph] + Returns: + ast.AST + """ + cross_function_analyzer = WholeTreeAnalyzer(source_info, graphs) + node = cross_function_analyzer.visit(node) + visitor = Annotator(source_info, cross_function_analyzer) + node = visitor.visit(node) + return node diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/liveness_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/liveness_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d53adb28af03f0de14f319f642ee82928a480e3a --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/liveness_test.py @@ -0,0 +1,149 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for liveness module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import cfg +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import liveness +from tensorflow.python.platform import test + + +class LivenessTest(test.TestCase): + + def _parse_and_analyze(self, test_fn): + node, source = parser.parse_entity(test_fn) + entity_info = transformer.EntityInfo( + source_code=source, + source_file=None, + namespace={}, + arg_values=None, + arg_types=None, + owner_type=None) + node = qual_names.resolve(node) + node = activity.resolve(node, entity_info) + graphs = cfg.build(node) + liveness.resolve(node, entity_info, graphs) + return node + + def assertHasLiveOut(self, node, expected): + live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT) + live_out_str = set(str(v) for v in live_out) + if not expected: + expected = () + if not isinstance(expected, tuple): + expected = (expected,) + self.assertSetEqual(live_out_str, set(expected)) + + def test_stacked_if(self): + + def test_fn(x, a): + if a > 0: + x = 0 + if a > 1: + x = 1 + return x + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasLiveOut(fn_body[0], ('a', 'x')) + self.assertHasLiveOut(fn_body[1], 'x') + + def test_stacked_if_else(self): + + def test_fn(x, a): + if a > 0: + x = 0 + if a > 1: + x = 1 + else: + x = 2 + return x + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasLiveOut(fn_body[0], 'a') + self.assertHasLiveOut(fn_body[1], 'x') + + def test_for_basic(self): + + def test_fn(x, a): + for i in range(a): + x += i + return x + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasLiveOut(fn_body[0], 'x') + + def test_attributes(self): + + def test_fn(x, a): + if a > 0: + x.y = 0 + return x.y + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasLiveOut(fn_body[0], ('x.y', 'x')) + + def test_nested_functions(self): + + def test_fn(a, b): + if b: + a = [] + + def foo(): + return a + + foo() + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasLiveOut(fn_body[0], 'a') + + def test_nested_functions_isolation(self): + + def test_fn(b): + if b: + a = 0 # pylint:disable=unused-variable + + def child(): + max(a) # pylint:disable=used-before-assignment + a = 1 + return a + + child() + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasLiveOut(fn_body[0], 'max') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py new file mode 100644 index 0000000000000000000000000000000000000000..7f2b379d3de236020f1ec2b8a4972cc67b10b060 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py @@ -0,0 +1,301 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reaching definition analysis. + +This analysis attaches a set of a Definition objects to each symbol, one +for each distinct definition that may reach it. The Definition objects are +mutable and may be used by subsequent analyses to further annotate data like +static type and value information. +The analysis also attaches the set of the symbols defined at the entry of +control flow statements. + +Requires activity analysis. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import cfg +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis import annos + + +class Definition(object): + """Definition objects describe a unique definition of a variable. + + Subclasses of this may be used by passing an appropriate factory function to + resolve. + + Attributes: + param_of: Optional[ast.AST] + """ + + def __init__(self): + self.param_of = None + + def __repr__(self): + return '%s[%d]' % (self.__class__.__name__, id(self)) + + +class _NodeState(object): + """Abstraction for the state of the CFG walk for reaching definition analysis. + + This is a value type. Only implements the strictly necessary operators. + + Attributes: + value: Dict[qual_names.QN, Set[Definition, ...]], the defined symbols and + their possible definitions + """ + + def __init__(self, init_from=None): + if init_from: + if isinstance(init_from, _NodeState): + self.value = { + s: set(other_infos) for s, other_infos in init_from.value.items() + } + elif isinstance(init_from, dict): + self.value = {s: set((init_from[s],)) for s in init_from} + else: + assert False, init_from + else: + self.value = {} + + def __eq__(self, other): + if frozenset(self.value.keys()) != frozenset(other.value.keys()): + return False + ret = all(self.value[s] == other.value[s] for s in self.value) + return ret + + def __ne__(self, other): + return not self.__eq__(other) + + def __or__(self, other): + assert isinstance(other, _NodeState) + result = _NodeState(self) + for s, other_infos in other.value.items(): + if s in result.value: + result.value[s].update(other_infos) + else: + result.value[s] = set(other_infos) + return result + + def __sub__(self, other): + assert isinstance(other, set) + result = _NodeState(self) + for s in other: + result.value.pop(s, None) + return result + + def __repr__(self): + return 'NodeState[%s]=%s' % (id(self), repr(self.value)) + + +class Analyzer(cfg.GraphVisitor): + """CFG visitor that determines reaching definitions at statement level.""" + + def __init__(self, graph, definition_factory): + self._definition_factory = definition_factory + super(Analyzer, self).__init__(graph) + # This allows communicating that nodes have extra reaching definitions, + # e.g. those that a function closes over. + self.extra_in = {} + + self.gen_map = {} + + def init_state(self, _): + return _NodeState() + + def visit_node(self, node): + prev_defs_out = self.out[node] + + defs_in = _NodeState(self.extra_in.get(node.ast_node, None)) + for n in node.prev: + defs_in |= self.out[n] + + if anno.hasanno(node.ast_node, anno.Static.SCOPE): + node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE) + # The definition objects created by each node must be singletons because + # their ids are used in equality checks. + if node not in self.gen_map: + node_symbols = {} + for s in node_scope.modified: + def_ = self._definition_factory() + if s in node_scope.params: + def_.param_of = node_scope.params[s] + node_symbols[s] = def_ + self.gen_map[node] = _NodeState(node_symbols) + + gen = self.gen_map[node] + kill = node_scope.modified + defs_out = gen | (defs_in - kill) + + else: + # Nodes that don't have a scope annotation are assumed not to touch any + # symbols. + # This Name node below is a literal name, e.g. False + # This can also happen if activity.py forgot to annotate the node with a + # scope object. + assert isinstance( + node.ast_node, + (gast.Name, gast.Break, gast.Continue, gast.Raise)), (node.ast_node, + node) + defs_out = defs_in + + self.in_[node] = defs_in + self.out[node] = defs_out + + # TODO(mdan): Move this to the superclass? + return prev_defs_out != defs_out + + +class TreeAnnotator(transformer.Base): + """AST visitor that annotates each symbol name with its reaching definitions. + + Simultaneously, the visitor runs the dataflow analysis on each function node, + accounting for the effect of closures. For example: + + def foo(): + bar = 1 + def baz(): + # bar = 1 reaches here + """ + + def __init__(self, source_info, graphs, definition_factory): + super(TreeAnnotator, self).__init__(source_info) + self.definition_factory = definition_factory + self.graphs = graphs + self.current_analyzer = None + self.current_cfg_node = None + + def visit_FunctionDef(self, node): + parent_analyzer = self.current_analyzer + subgraph = self.graphs[node] + + # Preorder tree processing: + # 1. if this is a child function, the parent was already analyzed and it + # has the proper state value for the subgraph's entry + # 2. analyze the current function body + # 2. recursively walk the subtree; child functions will be processed + analyzer = Analyzer(subgraph, self.definition_factory) + if parent_analyzer is not None: + # Wire the state between the two subgraphs' analyzers. + parent_out_state = parent_analyzer.out[parent_analyzer.graph.index[node]] + # Exception: symbols modified in the child function are local to it + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) + parent_out_state -= body_scope.modified + analyzer.extra_in[node.args] = parent_out_state + + # Complete the analysis for the local function and annotate its body. + analyzer.visit_forward() + + # Recursively process any remaining subfunctions. + self.current_analyzer = analyzer + # Note: not visiting name, decorator_list and returns because they don't + # apply to this anlysis. + # TODO(mdan): Should we still process the function name? + node.args = self.visit(node.args) + node.body = self.visit_block(node.body) + self.current_analyzer = parent_analyzer + + return node + + def visit_nonlocal(self, node): + raise NotImplementedError() + + def visit_global(self, node): + raise NotImplementedError() + + def visit_Name(self, node): + if self.current_analyzer is None: + # Names may appear outside function defs - for example in class + # definitions. + return node + + analyzer = self.current_analyzer + cfg_node = self.current_cfg_node + + assert cfg_node is not None, 'name node outside of any statement?' + + qn = anno.getanno(node, anno.Basic.QN) + if isinstance(node.ctx, gast.Load): + anno.setanno(node, anno.Static.DEFINITIONS, + tuple(analyzer.in_[cfg_node].value.get(qn, ()))) + else: + anno.setanno(node, anno.Static.DEFINITIONS, + tuple(analyzer.out[cfg_node].value.get(qn, ()))) + + return node + + def _aggregate_predecessors_defined_in(self, node): + preds = self.current_analyzer.graph.stmt_prev[node] + node_defined_in = set() + for p in preds: + node_defined_in |= set(self.current_analyzer.out[p].value.keys()) + anno.setanno(node, anno.Static.DEFINED_VARS_IN, frozenset(node_defined_in)) + + def visit_If(self, node): + self._aggregate_predecessors_defined_in(node) + return self.generic_visit(node) + + def visit_For(self, node): + self._aggregate_predecessors_defined_in(node) + + # Manually accounting for the shortcoming described in + # cfg.AstToCfg.visit_For. + parent = self.current_cfg_node + self.current_cfg_node = self.current_analyzer.graph.index[node.iter] + node.target = self.visit(node.target) + self.current_cfg_node = parent + + node.iter = self.visit(node.iter) + node.body = self.visit_block(node.body) + node.orelse = self.visit_block(node.orelse) + + return node + + def visit_While(self, node): + self._aggregate_predecessors_defined_in(node) + return self.generic_visit(node) + + def visit(self, node): + parent = self.current_cfg_node + + if (self.current_analyzer is not None and + node in self.current_analyzer.graph.index): + self.current_cfg_node = self.current_analyzer.graph.index[node] + node = super(TreeAnnotator, self).visit(node) + + self.current_cfg_node = parent + return node + + +def resolve(node, source_info, graphs, definition_factory): + """Resolves reaching definitions for each symbol. + + Args: + node: ast.AST + source_info: transformer.SourceInfo + graphs: Dict[ast.FunctionDef, cfg.Graph] + definition_factory: Callable[[], Definition] + Returns: + ast.AST + """ + visitor = TreeAnnotator(source_info, graphs, definition_factory) + node = visitor.visit(node) + return node diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py new file mode 100644 index 0000000000000000000000000000000000000000..243fe804b229686f33a4964b16c987c673a97c4b --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py @@ -0,0 +1,263 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for reaching_definitions module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import cfg +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions +from tensorflow.python.platform import test + + +class DefinitionInfoTest(test.TestCase): + + def _parse_and_analyze(self, test_fn): + node, source = parser.parse_entity(test_fn) + entity_info = transformer.EntityInfo( + source_code=source, + source_file=None, + namespace={}, + arg_values=None, + arg_types=None, + owner_type=None) + node = qual_names.resolve(node) + node = activity.resolve(node, entity_info) + graphs = cfg.build(node) + node = reaching_definitions.resolve(node, entity_info, graphs, + reaching_definitions.Definition) + return node + + def assertHasDefs(self, node, num): + defs = anno.getanno(node, anno.Static.DEFINITIONS) + self.assertEqual(len(defs), num) + for r in defs: + self.assertIsInstance(r, reaching_definitions.Definition) + + def assertHasDefinedIn(self, node, expected): + defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN) + defined_in_str = set(str(v) for v in defined_in) + if not expected: + expected = () + if not isinstance(expected, tuple): + expected = (expected,) + self.assertSetEqual(defined_in_str, set(expected)) + + def assertSameDef(self, first, second): + self.assertHasDefs(first, 1) + self.assertHasDefs(second, 1) + self.assertIs( + anno.getanno(first, anno.Static.DEFINITIONS)[0], + anno.getanno(second, anno.Static.DEFINITIONS)[0]) + + def assertNotSameDef(self, first, second): + self.assertHasDefs(first, 1) + self.assertHasDefs(second, 1) + self.assertIsNot( + anno.getanno(first, anno.Static.DEFINITIONS)[0], + anno.getanno(second, anno.Static.DEFINITIONS)[0]) + + def test_conditional(self): + + def test_fn(a, b): + a = [] + if b: + a = [] + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasDefs(fn_body[0].targets[0], 1) + self.assertHasDefs(fn_body[1].test, 1) + self.assertHasDefs(fn_body[1].body[0].targets[0], 1) + self.assertHasDefs(fn_body[2].value, 2) + + self.assertHasDefinedIn(fn_body[1], ('a', 'b')) + + def test_while(self): + + def test_fn(a): + max(a) + while True: + a = a + a = a + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasDefs(fn_body[0].value.args[0], 1) + self.assertHasDefs(fn_body[1].body[0].targets[0], 1) + self.assertHasDefs(fn_body[1].body[1].targets[0], 1) + self.assertHasDefs(fn_body[1].body[1].value, 1) + # The loop does have an invariant test, but the CFG doesn't know that. + self.assertHasDefs(fn_body[1].body[0].value, 2) + self.assertHasDefs(fn_body[2].value, 2) + + def test_while_else(self): + + def test_fn(x, i): + y = 0 + while x: + x += i + if i: + break + else: + y = 1 + return x, y + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasDefs(fn_body[0].targets[0], 1) + self.assertHasDefs(fn_body[1].test, 2) + self.assertHasDefs(fn_body[1].body[0].target, 1) + self.assertHasDefs(fn_body[1].body[1].test, 1) + self.assertHasDefs(fn_body[1].orelse[0].targets[0], 1) + self.assertHasDefs(fn_body[2].value.elts[0], 2) + self.assertHasDefs(fn_body[2].value.elts[1], 2) + + def test_for_else(self): + + def test_fn(x, i): + y = 0 + for i in x: + x += i + if i: + break + else: + continue + else: + y = 1 + return x, y + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasDefs(fn_body[0].targets[0], 1) + self.assertHasDefs(fn_body[1].target, 1) + self.assertHasDefs(fn_body[1].body[0].target, 1) + self.assertHasDefs(fn_body[1].body[1].test, 1) + self.assertHasDefs(fn_body[1].orelse[0].targets[0], 1) + self.assertHasDefs(fn_body[2].value.elts[0], 2) + self.assertHasDefs(fn_body[2].value.elts[1], 2) + + def test_nested_functions(self): + + def test_fn(a, b): + a = [] + if b: + a = [] + + def foo(): + return a + + foo() + + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + def_of_a_in_if = fn_body[1].body[0].targets[0] + + self.assertHasDefs(fn_body[0].targets[0], 1) + self.assertHasDefs(fn_body[1].test, 1) + self.assertHasDefs(def_of_a_in_if, 1) + self.assertHasDefs(fn_body[2].value, 2) + + inner_fn_body = fn_body[1].body[1].body + self.assertSameDef(inner_fn_body[0].value, def_of_a_in_if) + + def test_nested_functions_isolation(self): + + def test_fn(a): + a = 0 + + def child(): + a = 1 + return a + + child() + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + parent_return = fn_body[3] + child_return = fn_body[1].body[1] + # The assignment `a = 1` makes `a` local to `child`. + self.assertNotSameDef(parent_return.value, child_return.value) + + def test_function_call_in_with(self): + + def foo(_): + pass + + def test_fn(a): + with foo(a): + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + self.assertHasDefs(fn_body[0].items[0].context_expr.func, 0) + self.assertHasDefs(fn_body[0].items[0].context_expr.args[0], 1) + + def test_mutation_subscript(self): + + def test_fn(a): + l = [] + l[0] = a + return l + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + creation = fn_body[0].targets[0] + mutation = fn_body[1].targets[0].value + use = fn_body[2].value + self.assertSameDef(creation, mutation) + self.assertSameDef(creation, use) + + def test_replacement(self): + + def foo(a): + return a + + def test_fn(a): + a = foo(a) + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body[0].body + + param = node.body[0].args.args[0] + source = fn_body[0].value.args[0] + target = fn_body[0].targets[0] + retval = fn_body[1].value + self.assertSameDef(param, source) + self.assertNotSameDef(source, target) + self.assertSameDef(target, retval) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py index a229c288a83e516fc02f3af8df2046c5365e569c..835d5199fa1a5c145e29a413d4d23b4138a3c1cd 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py @@ -43,9 +43,8 @@ from __future__ import print_function import gast -from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect @@ -166,7 +165,6 @@ class TypeInfoResolver(transformer.Base): definition = self.scope.getval(qn) anno.copyanno(definition, node, 'type') anno.copyanno(definition, node, 'type_fqn') - anno.setanno(node, 'definition', definition) # TODO(mdan): Remove this when the directives module is in. anno.copyanno(definition, node, 'element_type') @@ -198,52 +196,18 @@ class TypeInfoResolver(transformer.Base): def visit_With(self, node): for item in node.items: if item.optional_vars is not None: - self.apply_to_single_assignments((item.optional_vars,), - item.context_expr, - self._process_variable_assignment) + ast_util.apply_to_single_assignments((item.optional_vars,), + item.context_expr, + self._process_variable_assignment) self.generic_visit(node) return node def visit_Assign(self, node): self.generic_visit(node) - self.apply_to_single_assignments( - node.targets, node.value, self._process_variable_assignment) + ast_util.apply_to_single_assignments(node.targets, node.value, + self._process_variable_assignment) return node - # TODO(mdan): Remove as soon as the new directives module is ready. - def visit_Call(self, node): - if anno.hasanno(node.func, 'live_val'): - # Symbols targeted by the "set_type" marker function are assigned the data - # type that it specified. - if anno.getanno(node.func, 'live_val') is utils.set_element_type: - - if len(node.args) < 2 or len(node.args) > 3: - raise ValueError('"%s" must have either two or three parameters' - % self.context.type_annotation_func) - if len(node.args) == 2: - target_arg, type_arg = node.args - shape_arg = parser.parse_expression('None') - else: - target_arg, type_arg, shape_arg = node.args - if not anno.hasanno(target_arg, anno.Basic.QN): - raise ValueError('the first argument of "%s" must by a symbol' % - utils.set_element_type) - # TODO(mdan): This is vulnerable to symbol renaming. - element_type = type_arg - element_shape = shape_arg - - target_symbol = anno.getanno(target_arg, anno.Basic.QN) - # Find the definition of this symbol and annotate it with the given - # data type. That in turn will cause future uses of the symbol - # to receive the same type annotation. - definition = self.scope.getval(target_symbol) - anno.setanno(node, 'element_type', element_type) - anno.setanno(node, 'element_shape', element_shape) - anno.setanno(definition, 'element_type', element_type) - anno.setanno(definition, 'element_shape', element_shape) - # TODO(mdan): Should we update references between definition and here? - return self.generic_visit(node) - def resolve(node, context): return TypeInfoResolver(context).visit(node) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py index 32b1148ab21809514bc09a31e26f0219017bd088..404311ba242cf0359cf5695dfe3eeaf9cb858eb8 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py @@ -19,11 +19,13 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import cfg from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import activity from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.client import session from tensorflow.python.platform import test @@ -69,7 +71,10 @@ class TypeInfoResolverTest(test.TestCase): arg_types=arg_types, owner_type=None) node = qual_names.resolve(node) + graphs = cfg.build(node) node = activity.resolve(node, entity_info) + node = reaching_definitions.resolve(node, entity_info, graphs, + reaching_definitions.Definition) node = live_values.resolve(node, entity_info, {}) node = type_info.resolve(node, entity_info) node = live_values.resolve(node, entity_info, {}) diff --git a/tensorflow/contrib/autograph/pyct/templates.py b/tensorflow/contrib/autograph/pyct/templates.py index 9c479ebc2fa83d27dc363ae306daedb556734a1f..5831d57ceb58d4b291a4f52bbf4282e107104219 100644 --- a/tensorflow/contrib/autograph/pyct/templates.py +++ b/tensorflow/contrib/autograph/pyct/templates.py @@ -26,6 +26,7 @@ import textwrap import gast +from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names @@ -43,39 +44,65 @@ class ReplaceTransformer(gast.NodeTransformer): """ self.replacements = replacements self.in_replacements = False + self.preserved_annos = { + anno.Basic.ORIGIN, + anno.Basic.SKIP_PROCESSING, + anno.Static.ORIG_DEFINITIONS, + } + + def _prepare_replacement(self, replaced, key): + """Prepares a replacement AST that's safe to swap in for a node. + + Args: + replaced: ast.AST, the node being replaced + key: Hashable, the key of the replacement AST + Returns: + ast.AST, the replacement AST + """ + repl = self.replacements[key] + + new_nodes = ast_util.copy_clean(repl, preserve_annos=self.preserved_annos) + if isinstance(new_nodes, gast.AST): + new_nodes = [new_nodes] + + return new_nodes def visit_Expr(self, node): - if (isinstance(node.value, gast.Name) and - node.value.id in self.replacements): - return self.visit(node.value) - self.generic_visit(node) - return node + # When replacing a placeholder with an entire statement, the replacement + # must stand on its own and not be wrapped in an Expr. + new_value = self.visit(node.value) + if new_value is node.value: + return node + return new_value def visit_keyword(self, node): - if node.arg in self.replacements: - repl = self.replacements[node.arg] - if isinstance(repl, gast.keyword): - return repl - elif (isinstance(repl, (list, tuple)) and repl and - all(isinstance(r, gast.keyword) for r in repl)): - return repl - # TODO(mdan): We may allow replacing with a string as well. - # For example, if one wanted to replace foo with bar in foo=baz, then - # we could allow changing just node arg, so that we end up with bar=baz. - raise ValueError( - 'a keyword argument may only be replaced by another keyword or a ' - 'non-empty list of keywords. Found: %s' % repl) - return self.generic_visit(node) + if node.arg not in self.replacements: + return self.generic_visit(node) + + repl = self._prepare_replacement(node, node.arg) + if isinstance(repl, gast.keyword): + return repl + elif (repl and isinstance(repl, (list, tuple)) and + all(isinstance(r, gast.keyword) for r in repl)): + return repl + # TODO(mdan): We may allow replacing with a string as well. + # For example, if one wanted to replace foo with bar in foo=baz, then + # we could allow changing just node arg, so that we end up with bar=baz. + raise ValueError( + 'a keyword argument may only be replaced by another keyword or a ' + 'non-empty list of keywords. Found: %s' % repl) def visit_FunctionDef(self, node): node = self.generic_visit(node) - if node.name in self.replacements: - repl = self.replacements[node.name] - if not isinstance(repl, (gast.Name, ast.Name)): - raise ValueError( - 'a function name can only be replaced by a Name node. Found: %s' % - repl) - node.name = repl.id + if node.name not in self.replacements: + return node + + repl = self.replacements[node.name] + if not isinstance(repl, (gast.Name, ast.Name)): + raise ValueError( + 'a function name can only be replaced by a Name node. Found: %s' % + repl) + node.name = repl.id return node def _check_has_context(self, node): @@ -113,8 +140,8 @@ class ReplaceTransformer(gast.NodeTransformer): def _set_inner_child_context(self, node, ctx): if isinstance(node, gast.Attribute): - self._set_inner_child_context(node.value, ctx) - node.ctx = gast.Load() + self._set_inner_child_context(node.value, gast.Load()) + node.ctx = ctx elif isinstance(node, gast.Tuple): for e in node.elts: self._set_inner_child_context(e, ctx) @@ -148,6 +175,7 @@ class ReplaceTransformer(gast.NodeTransformer): node = self.generic_visit(node) if node.attr not in self.replacements: return node + repl = self.replacements[node.attr] if not isinstance(repl, gast.Name): raise ValueError( @@ -159,9 +187,7 @@ class ReplaceTransformer(gast.NodeTransformer): if node.id not in self.replacements: return node - new_nodes = ast_util.copy_clean(self.replacements[node.id]) - if isinstance(new_nodes, gast.AST): - new_nodes = [new_nodes] + new_nodes = self._prepare_replacement(node, node.id) # Preserve the target context. for n in new_nodes: @@ -182,7 +208,7 @@ class ReplaceTransformer(gast.NodeTransformer): def _convert_to_ast(n): - """Convert from a known data type to AST.""" + """Converts from a known data type to AST.""" if isinstance(n, str): # Note: the node will receive the ctx value from the template, see # ReplaceTransformer.visit_Name. @@ -197,7 +223,7 @@ def _convert_to_ast(n): def replace(template, **replacements): - """Replace placeholders in a Python template. + """Replaces placeholders in a Python template. AST Name and Tuple nodes always receive the context that inferred from the template. However, when replacing more complex nodes (that can potentially diff --git a/tensorflow/contrib/autograph/pyct/templates_test.py b/tensorflow/contrib/autograph/pyct/templates_test.py index a01f8bf04c4faa6ec1779e0fb306155d99f5bd09..77e8ff62fd8665e095cfb410a2aa418e9f9bd52b 100644 --- a/tensorflow/contrib/autograph/pyct/templates_test.py +++ b/tensorflow/contrib/autograph/pyct/templates_test.py @@ -97,6 +97,19 @@ class TemplatesTest(test.TestCase): with self.assertRaises(ValueError): templates.replace(template, foo=1) + def test_replace_attribute_context(self): + template = """ + def test_fn(foo): + foo = 0 + """ + + node = templates.replace( + template, + foo=parser.parse_expression('a.b.c'))[0] + self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store) + self.assertIsInstance(node.body[0].targets[0].value.ctx, gast.Load) + self.assertIsInstance(node.body[0].targets[0].value.value.ctx, gast.Load) + def test_replace_call_keyword(self): template = """ def test_fn(): @@ -151,17 +164,13 @@ class TemplatesTest(test.TestCase): self.assertEqual(node.func.id, 'bar') self.assertEqual(node.func.args[0].id, 'baz') - def replace_as_expression_restrictions(self): + def test_replace_as_expression_restrictions(self): template = """ foo(a) bar(b) """ with self.assertRaises(ValueError): templates.replace_as_expression(template) - with self.assertRaises(ValueError): - templates.replace('') - with self.assertRaises(ValueError): - templates.replace('a = b') if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/pyct/testing/BUILD b/tensorflow/contrib/autograph/pyct/testing/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..957db356f7e1acf673ce5db7c8087208af43ac23 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/testing/BUILD @@ -0,0 +1,43 @@ +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "testing", + srcs = [ + "codegen.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/utils", + "@gast_archive//:gast", + ], +) + +py_test( + name = "codegen_test", + size = "large", + srcs = ["codegen_test.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + ":testing", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) diff --git a/tensorflow/contrib/autograph/pyct/testing/codegen.py b/tensorflow/contrib/autograph/pyct/testing/codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..279e7c09dc6449184e2029ad65fc3f71d94db8b4 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/testing/codegen.py @@ -0,0 +1,234 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Random code generation for testing/fuzzing.""" +# pylint: disable=invalid-name +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import string + +import gast +import numpy as np + +from tensorflow.contrib.autograph.pyct import templates + + +class NodeSampler(object): + sample_map = None + + def sample(self): + nodes, magnitudes = zip(*self.sample_map.items()) + return np.random.choice( + nodes, p=np.array(magnitudes, dtype='float32') / np.sum(magnitudes)) + + +class StatementSampler(NodeSampler): + sample_map = dict(( + (gast.Assign, 10), + (gast.Print, 1), + (gast.If, 2), + (gast.While, 2), + (gast.For, 0), + )) + + +class ExpressionSampler(NodeSampler): + sample_map = dict(( + (gast.UnaryOp, 1), + (gast.BinOp, 8), + (gast.Name, 1), + (gast.Call, 0), + )) + + +class CompareSampler(NodeSampler): + sample_map = dict(( + (gast.Eq, 1), + (gast.NotEq, 1), + (gast.Lt, 1), + (gast.LtE, 1), + (gast.Gt, 1), + (gast.GtE, 1), + (gast.Is, 1), + (gast.IsNot, 1), + )) + + +class BinaryOpSampler(NodeSampler): + sample_map = dict(( + (gast.Add, 1), + (gast.Sub, 1), + (gast.Mult, 1), + (gast.Div, 1), + (gast.FloorDiv, 1), + (gast.Mod, 1), + (gast.Pow, 1), + )) + + +class UnaryOpSampler(NodeSampler): + sample_map = dict(((gast.USub, 1), (gast.UAdd, 0))) + + +class NameSampler(NodeSampler): + sample_map = dict(( + ('new', 1), + ('existing', 1), + )) + + +N_CONTROLFLOW_STATEMENTS = 10 +N_FUNCTIONDEF_STATEMENTS = 10 + + +class CodeGenerator(object): + """Generate random syntactically-valid Python ASTs.""" + + def __init__(self, max_depth=3, depth=0): + self.max_depth = max_depth + self.depth = depth + + def generate_statement(self): + """Generate a statement node, dispatching to the correct class method.""" + desired_node = StatementSampler().sample() + self.depth += 1 + + # Enforce some constraints on generating statements. + # E.g., if statements need at least 3 readable variables. + # If we fail to satisfy our constraints, draw another sample. + if desired_node in (gast.While, gast.For, gast.If): + if self.depth > self.max_depth: + return self.generate_statement() + + # Go get the generator method and run it + method = 'generate_' + desired_node.__name__ + visitor = getattr(self, method) + node = visitor() + self.depth -= 1 + return node + + def sample_node_list(self, low, high, generator): + """Generate a list of statements of random length. + + Args: + low: Fewest number of statements to generate. + high: Highest number of statements to generate. + generator: Function to call to generate nodes. + + Returns: + A list of statements. + """ + statements = [] + for _ in range(np.random.randint(low, high)): + statements.append(generator()) + return statements + + def generate_Name(self, ctx=gast.Load()): + variable_name = '_' + ''.join( + random.choice(string.ascii_lowercase) for _ in range(4)) + return gast.Name(variable_name, ctx=ctx, annotation=None) + + def generate_BinOp(self): + # TODO(alexbw): convert to generate_expression when we get to limit + # expression depth. + op = BinaryOpSampler().sample()() + return gast.BinOp(self.generate_Name(), op, self.generate_Name()) + + def generate_Compare(self): + op = CompareSampler().sample()() + return gast.Compare(self.generate_Name(), [op], [self.generate_Name()]) + + def generate_UnaryOp(self): + operand = self.generate_Name() + op = UnaryOpSampler().sample()() + return gast.UnaryOp(op, operand) + + def generate_expression(self): + desired_node = ExpressionSampler().sample() + # Go get the generator method and run it + method = 'generate_' + desired_node.__name__ + generator = getattr(self, method) + return generator() + + def generate_Assign(self): + """Generate an Assign node.""" + # Generate left-hand side + target_node = self.generate_Name(gast.Store()) + # Generate right-hand side + value_node = self.generate_expression() + # Put it all together + node = gast.Assign(targets=[target_node], value=value_node) + return node + + def generate_If(self): + """Generate an If node.""" + test = self.generate_Compare() + + # Generate true branch statements + body = self.sample_node_list( + low=1, + high=N_CONTROLFLOW_STATEMENTS // 2, + generator=self.generate_statement) + + # Generate false branch statements + orelse = self.sample_node_list( + low=1, + high=N_CONTROLFLOW_STATEMENTS // 2, + generator=self.generate_statement) + + node = gast.If(test, body, orelse) + return node + + def generate_While(self): + """Generate a While node.""" + + test = self.generate_Compare() + body = self.sample_node_list( + low=1, high=N_CONTROLFLOW_STATEMENTS, generator=self.generate_statement) + orelse = [] # not generating else statements + + node = gast.While(test, body, orelse) + return node + + def generate_Call(self): + raise NotImplementedError + + def generate_Return(self): + return gast.Return(self.generate_expression()) + + def generate_Print(self): + return templates.replace('print(x)', x=self.generate_expression())[0] + + def generate_FunctionDef(self): + """Generate a FunctionDef node.""" + + # Generate the arguments, register them as available + arg_vars = self.sample_node_list( + low=2, high=10, generator=lambda: self.generate_Name(gast.Param())) + args = gast.arguments(arg_vars, None, [], [], None, []) + + # Generate the function body + body = self.sample_node_list( + low=1, high=N_FUNCTIONDEF_STATEMENTS, generator=self.generate_statement) + body.append(self.generate_Return()) + fn_name = self.generate_Name().id + node = gast.FunctionDef(fn_name, args, body, (), None) + return node + + +def generate_random_functiondef(): + return CodeGenerator().generate_FunctionDef() diff --git a/tensorflow/contrib/autograph/pyct/testing/codegen_test.py b/tensorflow/contrib/autograph/pyct/testing/codegen_test.py new file mode 100644 index 0000000000000000000000000000000000000000..255c3b2a2edc65ab978d8c32682fafd8ce00f5ac --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/testing/codegen_test.py @@ -0,0 +1,40 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for type_info module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct.testing import codegen +from tensorflow.python.platform import test + + +class CodeGenTest(test.TestCase): + + def test_codegen_gens(self): + np.random.seed(0) + for _ in range(1000): + node = codegen.generate_random_functiondef() + fn = compiler.ast_to_object(node) + self.assertIsNotNone( + fn, 'Generated invalid AST that could not convert to source.') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py index 76558118308c31a2c1a770cad814e96abd6a6063..969ca12244148b346ba3160fba124384a9641a05 100644 --- a/tensorflow/contrib/autograph/pyct/transformer.py +++ b/tensorflow/contrib/autograph/pyct/transformer.py @@ -59,6 +59,103 @@ class EntityInfo(object): self.owner_type = owner_type +class _StateStack(object): + """Typed stack abstraction. + + This class provides syntactic sugar for a stack of objects of known + type. It allows accessing attributes of the object at the top of the stack + directly against this object, which allows for very terse syntax. + + For example, this code: + + stack = _StateStack(Foo) + stack.enter() + stack.bar + + Is equivalent to: + + stack = [] + stack.append(Foo()) + foo = stack[-1] + foo.bar + + See _State for more on how this is used. + + Attributes: + type: Any, the type of objects that this stack holds + level: int, the current stack depth + value: Any, the instance of the object at the top of the stack + """ + + def __init__(self, type_): + # Because we override __setattr__, we need to attach these attributes using + # the superclass' setattr. + object.__setattr__(self, 'type', type_) + object.__setattr__(self, '_stack', []) + self.enter() + + def enter(self): + self._stack.append(self.type()) + + def exit(self): + return self._stack.pop() + + @property + def level(self): + return len(self._stack) + + @property + def value(self): + return self._stack[-1] + + def __getattr__(self, key): + return getattr(self._stack[-1], key) + + def __setattr__(self, key, value): + setattr(self._stack[-1], key, value) + + +class _State(object): + """Supporting class for nested scope variable space for converter.Base. + + This structure offers syntactic sugar over a dict of stacks of objects + of known type. These structures are useful to keep state during AST walks. + Multiple different scopes can be tracked in parallel. For example: + + s = _State() + + s[foo].enter() + s[bar].enter() # this will not affect s[foo] + + Element access has special semantics: + * keys are a data type + * element values are _StateStack(type=key) objects + * missing elements are automatically added, similarly to defaultdict + + For example, the following block : + + _State s + s[Foo] + + Is equivalent to: + + s = {} + if Foo not in s: + s[Foo] = Foo() + s[Foo] + + See Base for how it's used. + """ + + def __init__(self): + self._value = {} + + def __getitem__(self, key): + if key not in self._value: + self._value[key] = _StateStack(key) + return self._value[key] + + class Base(gast.NodeTransformer): """Base class for general-purpose code transformers transformers. @@ -71,6 +168,27 @@ class Base(gast.NodeTransformer): (possibly nested) scopes, use enter/exit_local_scope and set/get_local. You must call enter/exit_local_scope manually, but the transformer detects when they are not properly paired. + + The transformer allows keeping state across calls to visit_* that is local to + arbitrary nodes and their descendants, using the self.state attribute. + Multiple independent scopes are allowed and automatically constructed. + + For example, to keep track of the If node that encloses any Name node, one can + write: + + class FooType(object): + + def __init__(self): + self.foo_property = None + + class DummyTransformer(Base): + + def visit_If(self, node): + self.state[FooType].enter() + self.state[FooType].foo_property = node + + def visit_Name(self, node): + self.state[FooType].foo_property # will hold the innermost enclosing if """ # TODO(mdan): Document all extra features. @@ -92,6 +210,12 @@ class Base(gast.NodeTransformer): self._local_scope_state = [] self.enter_local_scope() + # Allows scoping of local variables to keep state across calls to visit_* + # methods. Multiple scope hierchies may exist and are keyed by tag. A scope + # is valid at one or more nodes and all its children. Scopes created in + # child nodes supersede their parent. Scopes are isolated from one another. + self.state = _State() + @property def enclosing_entities(self): return tuple(self._enclosing_entities) @@ -101,7 +225,9 @@ class Base(gast.NodeTransformer): return len(self._local_scope_state) def enter_local_scope(self, inherit=None): - """Marks entry into a new local scope. + """Deprecated. Use self.state instead. + + Marks entry into a new local scope. Args: inherit: Optional enumerable of variable names to copy from the @@ -116,7 +242,9 @@ class Base(gast.NodeTransformer): self._local_scope_state.append(scope_entered) def exit_local_scope(self, keep=None): - """Marks exit from the current local scope. + """Deprecated. Use self.state instead. + + Marks exit from the current local scope. Args: keep: Optional enumerable of variable names to copy into the @@ -133,9 +261,11 @@ class Base(gast.NodeTransformer): return scope_left def set_local(self, name, value): + """Deprecated. Use self.state instead.""" self._local_scope_state[-1][name] = value def get_local(self, name, default=None): + """Deprecated. Use self.state instead.""" return self._local_scope_state[-1].get(name, default) def debug_print(self, node): @@ -216,7 +346,7 @@ class Base(gast.NodeTransformer): node_destination = new_destination return results - # TODO(mdan): Once we have error tracing, we may be able to just go to SSA. + # TODO(mdan): Remove. def apply_to_single_assignments(self, targets, values, apply_fn): """Applies a function to each individual assignment. @@ -266,19 +396,38 @@ class Base(gast.NodeTransformer): def _get_source(self, node): try: - return compiler.ast_to_source(node) - except AssertionError: + source, _ = compiler.ast_to_source(node) + return source + # pylint: disable=broad-except + # This function is used for error reporting. If an exception occurs here, + # it should be suppressed, in favor of emitting as informative a message + # about the original error as possible. + except Exception: return '' def visit(self, node): + if not isinstance(node, gast.AST): + # This is not that uncommon a mistake: various node bodies are lists, for + # example, posing a land mine for transformers that need to recursively + # call `visit`. The error needs to be raised before the exception handler + # below is installed, because said handler will mess up if `node` is not, + # in fact, a node. + msg = ( + 'invalid value for "node": expected "ast.AST", got "{}"; to' + ' visit lists of nodes, use "visit_block" instead').format(type(node)) + raise ValueError(msg) + source_code = self.entity_info.source_code source_file = self.entity_info.source_file did_enter_function = False local_scope_size_at_entry = len(self._local_scope_state) + processing_expr_node = False try: if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)): did_enter_function = True + elif isinstance(node, gast.Expr): + processing_expr_node = True if did_enter_function: self._enclosing_entities.append(node) @@ -287,9 +436,23 @@ class Base(gast.NodeTransformer): self._lineno = node.lineno self._col_offset = node.col_offset + if processing_expr_node: + entry_expr_value = node.value + if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING): result = super(Base, self).visit(node) + # Adjust for consistency: replacing the value of an Expr with + # an Assign node removes the need for the Expr node. + if processing_expr_node: + if isinstance(result, gast.Expr) and result.value != entry_expr_value: + # When the replacement is a list, it is assumed that the list came + # from a template that contained a number of statements, which + # themselves are standalone and don't require an enclosing Expr. + if isinstance(result.value, + (list, tuple, gast.Assign, gast.AugAssign)): + result = result.value + # On exception, the local scope integrity is not guaranteed. if did_enter_function: self._enclosing_entities.pop() diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/contrib/autograph/pyct/transformer_test.py index baf04653ae862b0159fb50a1c67fa675ceb74b9a..a37e922a1de902106dd3a11f20a14ddde8f6675e 100644 --- a/tensorflow/contrib/autograph/pyct/transformer_test.py +++ b/tensorflow/contrib/autograph/pyct/transformer_test.py @@ -93,6 +93,83 @@ class TransformerTest(test.TestCase): inner_function, lambda_node), anno.getanno(lambda_expr, 'enclosing_entities')) + def assertSameAnno(self, first, second, key): + self.assertIs(anno.getanno(first, key), anno.getanno(second, key)) + + def assertDifferentAnno(self, first, second, key): + self.assertIsNot(anno.getanno(first, key), anno.getanno(second, key)) + + def test_state_tracking(self): + + class LoopState(object): + pass + + class CondState(object): + pass + + class TestTransformer(transformer.Base): + + def visit(self, node): + anno.setanno(node, 'loop_state', self.state[LoopState].value) + anno.setanno(node, 'cond_state', self.state[CondState].value) + return super(TestTransformer, self).visit(node) + + def visit_While(self, node): + self.state[LoopState].enter() + node = self.generic_visit(node) + self.state[LoopState].exit() + return node + + def visit_If(self, node): + self.state[CondState].enter() + node = self.generic_visit(node) + self.state[CondState].exit() + return node + + tr = TestTransformer(self._simple_source_info()) + + def test_function(a): + a = 1 + while a: + _ = 'a' + if a > 2: + _ = 'b' + while True: + raise '1' + if a > 3: + _ = 'c' + while True: + raise '1' + + node, _ = parser.parse_entity(test_function) + node = tr.visit(node) + + fn_body = node.body[0].body + outer_while_body = fn_body[1].body + self.assertSameAnno(fn_body[0], outer_while_body[0], 'cond_state') + self.assertDifferentAnno(fn_body[0], outer_while_body[0], 'loop_state') + + first_if_body = outer_while_body[1].body + self.assertDifferentAnno(outer_while_body[0], first_if_body[0], + 'cond_state') + self.assertSameAnno(outer_while_body[0], first_if_body[0], 'loop_state') + + first_inner_while_body = first_if_body[1].body + self.assertSameAnno(first_if_body[0], first_inner_while_body[0], + 'cond_state') + self.assertDifferentAnno(first_if_body[0], first_inner_while_body[0], + 'loop_state') + + second_if_body = outer_while_body[2].body + self.assertDifferentAnno(first_if_body[0], second_if_body[0], 'cond_state') + self.assertSameAnno(first_if_body[0], second_if_body[0], 'loop_state') + + second_inner_while_body = second_if_body[1].body + self.assertDifferentAnno(first_inner_while_body[0], + second_inner_while_body[0], 'cond_state') + self.assertDifferentAnno(first_inner_while_body[0], + second_inner_while_body[0], 'loop_state') + def test_local_scope_info_stack(self): class TestTransformer(transformer.Base): @@ -205,6 +282,88 @@ class TransformerTest(test.TestCase): self.assertTrue(isinstance(node.body[1].body[0], gast.Assign)) self.assertTrue(isinstance(node.body[1].body[1], gast.Return)) + def test_robust_error_on_list_visit(self): + + class BrokenTransformer(transformer.Base): + + def visit_If(self, node): + # This is broken because visit expects a single node, not a list, and + # the body of an if is a list. + # Importantly, the default error handling in visit also expects a single + # node. Therefore, mistakes like this need to trigger a type error + # before the visit called here installs its error handler. + # That type error can then be caught by the enclosing call to visit, + # and correctly blame the If node. + self.visit(node.body) + return node + + def test_function(x): + if x > 0: + return x + + tr = BrokenTransformer(self._simple_source_info()) + + node, _ = parser.parse_entity(test_function) + with self.assertRaises(transformer.AutographParseError) as cm: + node = tr.visit(node) + obtained_message = str(cm.exception) + expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"' + self.assertRegexpMatches(obtained_message, expected_message) + # The exception should point at the if statement, not any place else. Could + # also check the stack trace. + self.assertTrue( + 'Occurred at node:\nIf' in obtained_message, obtained_message) + self.assertTrue( + 'Occurred at node:\nFunctionDef' not in obtained_message, + obtained_message) + self.assertTrue( + 'Occurred at node:\nReturn' not in obtained_message, obtained_message) + + def test_robust_error_on_ast_corruption(self): + # A child class should not be able to be so broken that it causes the error + # handling in `transformer.Base` to raise an exception. Why not? Because + # then the original error location is dropped, and an error handler higher + # up in the call stack gives misleading information. + + # Here we test that the error handling in `visit` completes, and blames the + # correct original exception, even if the AST gets corrupted. + + class NotANode(object): + pass + + class BrokenTransformer(transformer.Base): + + def visit_If(self, node): + node.body = NotANode() + raise ValueError('I blew up') + + def test_function(x): + if x > 0: + return x + + tr = BrokenTransformer(self._simple_source_info()) + + node, _ = parser.parse_entity(test_function) + with self.assertRaises(transformer.AutographParseError) as cm: + node = tr.visit(node) + obtained_message = str(cm.exception) + # The message should reference the exception actually raised, not anything + # from the exception handler. + expected_substring = 'I blew up' + self.assertTrue(expected_substring in obtained_message, obtained_message) + # Expect the exception to have failed to parse the corrupted AST + self.assertTrue( + '' in obtained_message, + obtained_message) + # The exception should point at the if statement, not any place else. Could + # also check the stack trace. + self.assertTrue( + 'Occurred at node:\nIf' in obtained_message, obtained_message) + self.assertTrue( + 'Occurred at node:\nFunctionDef' not in obtained_message, + obtained_message) + self.assertTrue( + 'Occurred at node:\nReturn' not in obtained_message, obtained_message) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/utils/BUILD b/tensorflow/contrib/autograph/utils/BUILD index d82c17bf2afd01aedf4344f983b02c09abcb9bad..d2b399f19b63bfaa20d334df78ae60d50f6ca6e7 100644 --- a/tensorflow/contrib/autograph/utils/BUILD +++ b/tensorflow/contrib/autograph/utils/BUILD @@ -28,7 +28,6 @@ py_library( "tensor_list.py", "testing.py", "type_check.py", - "type_hints.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], diff --git a/tensorflow/contrib/autograph/utils/__init__.py b/tensorflow/contrib/autograph/utils/__init__.py index 817d4126d106487e1fea3e442712a69bbfccd7f3..57b5f747417613a5dd5bce08e4a9e9ef98442cf6 100644 --- a/tensorflow/contrib/autograph/utils/__init__.py +++ b/tensorflow/contrib/autograph/utils/__init__.py @@ -30,4 +30,3 @@ from tensorflow.contrib.autograph.utils.py_func import wrap_py_func from tensorflow.contrib.autograph.utils.tensor_list import dynamic_list_append from tensorflow.contrib.autograph.utils.testing import fake_tf from tensorflow.contrib.autograph.utils.type_check import is_tensor -from tensorflow.contrib.autograph.utils.type_hints import set_element_type diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py index 998087e056c2cd264399982220d6e0528aab9edb..ccbe5fc9541dfad561d8eab730e2b15f6250ceb2 100644 --- a/tensorflow/contrib/autograph/utils/builtins.py +++ b/tensorflow/contrib/autograph/utils/builtins.py @@ -27,6 +27,7 @@ from tensorflow.contrib.autograph.utils import type_check from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops @@ -50,15 +51,22 @@ def dynamic_builtin(f, *args, **kwargs): def dynamic_len(list_or_tensor): """Implementation of len using dynamic dispatch.""" - if tensor_util.is_tensor(list_or_tensor): + if _is_tensor_list(list_or_tensor): + return list_ops.tensor_list_length(list_or_tensor) + elif tensor_util.is_tensor(list_or_tensor): shape = list_or_tensor.shape - if not shape: + if not shape.ndims: raise ValueError( 'len requires non-zero rank for tensor "%s"' % list_or_tensor) return array_ops.shape(list_or_tensor)[0] return len(list_or_tensor) +def _is_tensor_list(list_or_tensor): + return (tensor_util.is_tensor(list_or_tensor) + and list_or_tensor.dtype == dtypes.variant) + + def dynamic_int(num_or_tensor, **kwargs): """Implementation of int() using dynamic dispatch.""" if tensor_util.is_tensor(num_or_tensor): diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py index 0c2312178a921037fa419818bf309d671c33914d..b4821f36fcab8c201956e366d394bababb9f02b6 100644 --- a/tensorflow/contrib/autograph/utils/builtins_test.py +++ b/tensorflow/contrib/autograph/utils/builtins_test.py @@ -33,7 +33,8 @@ class BuiltinsTest(test.TestCase): def test_dynamic_len_tf_scalar(self): a = constant_op.constant(1) - with self.assertRaises(ValueError): + with self.assertRaisesRegexp(ValueError, + 'len requires non-zero rank for tensor.*'): with self.test_session() as sess: sess.run(builtins.dynamic_builtin(len, a)) diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py index 47b80bdf4ad88ebce3603a14ea2aa3cbe5bd345f..55faad983f2bcf2f3fa633669bd371608e2e925b 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops.py @@ -58,8 +58,6 @@ def batch_function(num_batch_threads, max_batch_size, batch_timeout_micros, allowed_batch_sizes=None, - grad_timeout_micros=60 * 1000 * 1000, - unbatch_timeout_micros=60 * 1000 * 1000, max_enqueued_batches=10): """Batches the computation done by the decorated function. @@ -94,10 +92,6 @@ def batch_function(num_batch_threads, does nothing. Otherwise, supplies a list of batch sizes, causing the op to pad batches up to one of those sizes. The entries must increase monotonically, and the final entry must equal max_batch_size. - grad_timeout_micros: The timeout to use for the gradient. See the - documentation of the unbatch op for more details. Defaults to 60s. - unbatch_timeout_micros: The timeout to use for unbatching. See the - documentation of the unbatch op for more details. Defaults to 60s. max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10. Returns: diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py index 032b859d469ee5039e08e4af4c2f4ebf35c2ff19..68ead2f7609ca987180fe8973cf902f1e56b8388 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py @@ -192,7 +192,7 @@ def _logspace_mean(log_values): def expectation(f, samples, log_prob=None, use_reparametrization=True, axis=0, keep_dims=False, name=None): - """Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\). + r"""Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\). This function computes the Monte-Carlo approximation of an expectation, i.e., diff --git a/tensorflow/contrib/bigtable/BUILD b/tensorflow/contrib/bigtable/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..71538e0770dcb436c8ff1571c22e950336328357 --- /dev/null +++ b/tensorflow/contrib/bigtable/BUILD @@ -0,0 +1,213 @@ +# Cloud Bigtable client for TensorFlow + +package( + default_visibility = ["//tensorflow:internal"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load( + "//tensorflow:tensorflow.bzl", + "tf_copts", + "tf_custom_op_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_kernel_library", + "tf_cc_test", + "tf_py_test", +) + +tf_custom_op_py_library( + name = "bigtable", + srcs = ["__init__.py"] + glob(["python/ops/*.py"]), + dso = [ + ":python/ops/_bigtable.so", + ], + kernels = [ + ":bigtable_kernels", + ":bigtable_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":bigtable_ops", + "//tensorflow/contrib/data/python/ops:interleave_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:util", + "//tensorflow/python/data", + ], +) + +KERNEL_FILES = [ + "kernels/bigtable_kernels.cc", + "kernels/bigtable_lookup_dataset_op.cc", + "kernels/bigtable_prefix_key_dataset_op.cc", + "kernels/bigtable_range_key_dataset_op.cc", + "kernels/bigtable_sample_keys_dataset_op.cc", + "kernels/bigtable_sample_key_pairs_dataset_op.cc", + "kernels/bigtable_scan_dataset_op.cc", +] + +tf_custom_op_library( + name = "python/ops/_bigtable.so", + srcs = KERNEL_FILES + [ + "ops/bigtable_ops.cc", + ], + deps = [ + ":bigtable_lib_cc", + ":bigtable_range_helpers", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + ], +) + +tf_gen_op_wrapper_py( + name = "bigtable_ops", + deps = [":bigtable_ops_op_lib"], +) + +tf_gen_op_libs( + op_lib_names = [ + "bigtable_ops", + "bigtable_test_ops", + ], +) + +tf_kernel_library( + name = "bigtable_kernels", + srcs = KERNEL_FILES, + deps = [ + ":bigtable_lib_cc", + ":bigtable_range_helpers", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + ], +) + +# A library for use in the bigtable kernels. +cc_library( + name = "bigtable_lib_cc", + srcs = ["kernels/bigtable_lib.cc"], + hdrs = ["kernels/bigtable_lib.h"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + ], +) + +cc_library( + name = "bigtable_range_helpers", + srcs = ["kernels/bigtable_range_helpers.cc"], + hdrs = ["kernels/bigtable_range_helpers.h"], + deps = [ + "//tensorflow/core:framework_headers_lib", + ], +) + +cc_library( + name = "bigtable_test_client", + srcs = ["kernels/test_kernels/bigtable_test_client.cc"], + hdrs = ["kernels/test_kernels/bigtable_test_client.h"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "@com_github_googleapis_googleapis//:bigtable_protos", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + "@com_googlesource_code_re2//:re2", + ], +) + +tf_cc_test( + name = "bigtable_test_client_test", + srcs = ["kernels/test_kernels/bigtable_test_client_test.cc"], + tags = ["manual"], + deps = [ + ":bigtable_test_client", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client", + ], +) + +tf_cc_test( + name = "bigtable_range_helpers_test", + size = "small", + srcs = ["kernels/bigtable_range_helpers_test.cc"], + deps = [ + ":bigtable_range_helpers", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_gen_op_wrapper_py( + name = "bigtable_test_ops", + deps = [":bigtable_test_ops_op_lib"], +) + +tf_custom_op_library( + name = "python/kernel_tests/_bigtable_test.so", + srcs = [ + "kernels/test_kernels/bigtable_test_client_op.cc", + "ops/bigtable_test_ops.cc", + ], + deps = [ + ":bigtable_lib_cc", + ":bigtable_test_client", + "@com_googlesource_code_re2//:re2", + ], +) + +# Don't use tf_kernel_library because it prevents access to strings/stringprintf.h +cc_library( + name = "bigtable_test_kernels", + srcs = [ + "kernels/test_kernels/bigtable_test_client_op.cc", + ], + copts = tf_copts(), + linkstatic = 1, + deps = [ + ":bigtable_lib_cc", + ":bigtable_test_client", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@com_googlesource_code_re2//:re2", + ], + alwayslink = 1, +) + +tf_custom_op_py_library( + name = "bigtable_test_py", + dso = [ + ":python/kernel_tests/_bigtable_test.so", + ], + kernels = [ + ":bigtable_test_kernels", + ":bigtable_test_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":bigtable_test_ops", + ], +) + +tf_py_test( + name = "bigtable_ops_test", + size = "small", + srcs = ["python/kernel_tests/bigtable_ops_test.py"], + additional_deps = [ + ":bigtable", + ":bigtable_test_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:util", + ], + tags = ["manual"], +) diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md new file mode 100644 index 0000000000000000000000000000000000000000..88a3909de4f34c11ac7ac3f0a865d76b675d0d06 --- /dev/null +++ b/tensorflow/contrib/bigtable/README.md @@ -0,0 +1,347 @@ +# Bigtable # + +[Cloud Bigtable](https://cloud.google.com/bigtable/) is a high +performance storage system that can store and serve training data. This contrib +package contains an experimental integration with TensorFlow. + +> **Status: Highly experimental.** The current implementation is very much in +> flux. Please use at your own risk! :-) + +The TensorFlow integration with Cloud Bigtable is optimized for common +TensorFlow usage and workloads. It is currently optimized for reading from Cloud +Bigtable at high speed, in particular to feed modern accelerators. For +general-purpose Cloud Bigtable +APIs, see the [official Cloud Bigtable client library documentation][clientdoc]. + +[clientdoc]: https://cloud.google.com/bigtable/docs/reference/libraries + +## Sample Use + +There are three main reading styles supported by the `BigtableTable` class: + + 1. **Reading keys**: Read only the row keys in a table. Keys are returned in + sorted order from the table. Most key reading operations retrieve all keys + in a contiguous range, however the `sample_keys` operation skips keys, and + operates on the whole table (and not a contiguous subset). + 2. **Retrieving a row's values**: Given a row key, look up the data associated + with a defined set of columns. This operation takes advantage of Cloud + Bigtable's low-latency and excellent support for random access. + 3. **Scanning ranges**: Given a contiguous range of rows retrieve both the row + key and the data associated with a fixed set of columns. This operation + takes advantage of Cloud Bigtable's high throughput scans, and is the most + efficient way to read data. + +When using the Cloud Bigtable API, the workflow is: + + 1. Create a `BigtableClient` object. + 2. Use the `BigtableClient` to create `BigtableTable` objects corresponding to + each table in the Cloud Bigtable instance you would like to access. + 3. Call methods on the `BigtableTable` object to create `tf.data.Dataset`s to + retrieve data. + +The following is an example for how to read all row keys with the prefix +`train-`. + +```python +import tensorflow as tf + +GCP_PROJECT_ID = '' +BIGTABLE_INSTANCE_ID = '' +BIGTABLE_TABLE_NAME = '' +PREFIX = 'train-' + +def main(): + client = tf.contrib.cloud.BigtableClient(GCP_PROJECT_ID, BIGTABLE_INSTANCE_ID) + table = client.table(BIGTABLE_TABLE_NAME) + dataset = table.keys_by_prefix_dataset(PREFIX) + iterator = dataset.make_initializable_iterator() + get_next_op = iterator.get_next() + + with tf.Session() as sess: + print('Initializing the iterator.') + sess.run(iterator.initializer) + print('Retrieving rows:') + row_index = 0 + while True: + try: + row_key = sess.run(get_next_op) + print('Row key %d: %s' % (row_index, row_key)) + row_index += 1 + except tf.errors.OutOfRangeError: + print('Finished reading data!') + break + +if __name__ == '__main__': + main() + +``` + +### Reading row keys + +Read only the row keys in a table. Keys are returned in sorted order from the +table. Most key reading operations retrieve all keys in a contiguous range, +however the `sample_keys` operation skips keys, and operates on the whole table +(and not a contiguous subset). + +There are 3 methods to retrieve row keys: + + - `table.keys_by_range_dataset(start, end)`: Retrieve row keys starting with + `start`, and ending with `end`. The range is "half-open", and thus it + includes `start` if `start` is present in the table. It does not include + `end`. + - `table.keys_by_prefix_dataset(prefix)`: Retrieves all row keys that start + with `prefix`. It includes the row key `prefix` if present in the table. + - `table.sample_keys()`: Retrieves a sampling of keys from the underlying + table. This is often useful in conjunction with parallel scans. + +### Reading cell values given a row key + +Given a dataset producing row keys, you can use the `table.lookup_columns` +transformation to retrieve values. Example: + +```python +key_dataset = tf.data.Dataset.from_tensor_slices([ + 'row_key_1', + 'other_row_key', + 'final_row_key', +]) +values_dataset = key_dataset.apply( + table.lookup_columns(('my_column_family', 'column_name'), + ('other_cf', 'col'))) +training_data = values_dataset.map(my_parsing_function) # ... +``` + +### Scanning ranges +Given a contiguous range of rows retrieve both the row key and the data +associated with a fixed set of columns. Scanning is the most efficient way to +retrieve data from Cloud Bigtable and is thus a very common API for high +performance data pipelines. To construct a scanning `tf.data.Dataset` from a +`BigtableTable` object, call one of the following methods: + + - `table.scan_prefix(prefix, ...)` + - `table.scan_range(start, end, ...)` + - `table.parallel_scan_prefix(prefix, ...)` + - `table.parallel_scan_range(start, end, ...)` + +Aside from the specification of the contiguous range of rows, they all take the +following arguments: + + - `probability`: (Optional.) A float between 0 (exclusive) and 1 (inclusive). + A non-1 value indicates to probabilistically sample rows with the + provided probability. + - `columns`: The columns to read. (See below.) + - `**kwargs`: The columns to read. (See below.) + +In addition the two parallel operations accept the following optional argument: +`num_parallel_scans` which configures the number of parallel Cloud Bigtable scan +operations to run. A reasonable default is automatically chosen for small +Cloud Bigtable clusters. If you have a large cluster, or an extremely demanding +workload, you can tune this value to optimize performance. + +#### Specifying columns to read when scanning + +All of the scan operations allow you to specify the column family and columns +in the same ways. + +##### Using `columns` + +The first way to specify the data to read is via the `columns` parameter. The +value should be a tuple (or list of tuples) of strings. The first string in the +tuple is the column family, and the second string in the tuple is the column +qualifier. + +##### Using `**kwargs` + +The second way to specify the data to read is via the `**kwargs` parameter, +which you can use to specify keyword arguments corresponding to the columns that +you want to read. The keyword to use is the column family name, and the argument +value should be either a string, or a tuple of strings, specifying the column +qualifiers (column names). + +Although using `**kwargs` has the advantage of requiring less typing, it is not +future-proof in all cases. (If we add a new parameter to the scan functions that +has the same name as your column family, your code will break.) + +##### Examples + +Below are two equivalent snippets for how to specify which columns to read: + +```python +ds1 = table.scan_range("row_start", "row_end", columns=[("cfa", "c1"), + ("cfa", "c2"), + ("cfb", "c3")]) +ds2 = table.scan_range("row_start", "row_end", cfa=["c1", "c2"], cfb="c3") +``` + +In this example, we are reading 3 columns from a total of 2 column families. +From the `cfa` column family, we are reading columns `c1`, and `c2`. From the +second column family (`cfb`), we are reading `c3`. Both `ds1` and `ds2` will +output elements of the following types (`tf.string`, `tf.string`, `tf.string`, +`tf.string`). The first `tf.string` is the row key, the second `tf.string` is +the latest data in cell `cfa:c1`, the third corresponds to `cfa:c2`, and the +final one is `cfb:c3`. + +#### Determinism when scanning + +While the non-parallel scan operations are fully deterministic, the parallel +scan operations are not. If you would like to scan in parallel without losing +determinism, you can build up the `parallel_interleave` yourself. As an example, +say we wanted to scan all rows between `training_data_00000`, and +`training_data_90000`, we can use the following code snippet: + +```python +table = # ... +columns = [('cf1', 'col1'), ('cf1', 'col2')] +NUM_PARALLEL_READS = # ... +ds = tf.data.Dataset.range(9).shuffle(10) +def interleave_fn(index): + # Given a starting index, create 2 strings to be the start and end + start_idx = index + end_idx = index + 1 + start_idx_str = tf.as_string(start_idx * 10000, width=5, fill='0') + end_idx_str = tf.as_string(end_idx * 10000, width=5, fill='0') + start = tf.string_join(['training_data_', start_idx_str]) + end = tf.string_join(['training_data_', end_idx_str]) + return table.scan_range(start_idx, end_idx, columns=columns) +ds = ds.apply(tf.contrib.data.parallel_interleave( + interleave_fn, cycle_length=NUM_PARALLEL_READS, prefetch_input_elements=1)) +``` + +> Note: you should divide up the key range into more sub-ranges for increased +> parallelism. + +## Writing to Cloud Bigtable + +In order to simplify getting started, this package provides basic support for +writing data into Cloud Bigtable. + +> Note: The implementation is not optimized for performance! Please consider +> using alternative frameworks such as Apache Beam / Cloud Dataflow for +> production workloads. + +Below is an example for how to write a trivial dataset into Cloud Bigtable. + +```python +import tensorflow as tf + +GCP_PROJECT_ID = '' +BIGTABLE_INSTANCE_ID = '' +BIGTABLE_TABLE_NAME = '' +COLUMN_FAMILY = '' +COLUMN_QUALIFIER = '' + +def make_dataset(): + """Makes a dataset to write to Cloud Bigtable.""" + return tf.data.Dataset.from_tensor_slices([ + 'training_data_1', + 'training_data_2', + 'training_data_3', + ]) + +def make_row_key_dataset(): + """Makes a dataset of strings used for row keys. + + The strings are of the form: `fake-data-` followed by a sequential counter. + For example, this dataset would contain the following elements: + + - fake-data-00000001 + - fake-data-00000002 + - ... + - fake-data-23498103 + """ + counter_dataset = tf.contrib.data.Counter() + width = 8 + row_key_prefix = 'fake-data-' + ds = counter_dataset.map(lambda index: tf.as_string(index, + width=width, + fill='0')) + ds = ds.map(lambda idx_str: tf.string_join([row_key_prefix, idx_str])) + return ds + + +def main(): + client = tf.contrib.cloud.BigtableClient(GCP_PROJECT_ID, BIGTABLE_INSTANCE_ID) + table = client.table(BIGTABLE_TABLE_NAME) + dataset = make_dataset() + index_dataset = make_row_key_dataset() + aggregate_dataset = tf.data.Dataset.zip((index_dataset, dataset)) + write_op = table.write(aggregate_dataset, column_families=[COLUMN_FAMILY], + columns=[COLUMN_QUALIFIER]) + + with tf.Session() as sess: + print('Starting transfer.') + sess.run(write_op) + print('Transfer complete.') + +if __name__ == '__main__': + main() +``` + +## Sample applications and architectures + +While most machine learning applications are well suited by a high performance +distributed file system, there are certain applications where using Cloud +Bigtable works extremely well. + +### Perfect Shuffling + +Normally, training data is stored in flat files, and a combination of +(1) `tf.data.Dataset.interleave` (or `parallel_interleave`), (2) +`tf.data.Dataset.shuffle`, and (3) writing the data in an unsorted order in the +data files in the first place, provides enough randomization to ensure models +train efficiently. However, if you would like perfect shuffling, you can use +Cloud Bigtable's low-latency random access capabilities. Create a +`tf.data.Dataset` that generates the keys in a perfectly random order (or read +all the keys into memory and use a shuffle buffer sized to fit all of them for a +perfect random shuffle using `tf.data.Dataset.shuffle`), and then use +`lookup_columns` to retrieve the training data. + +### Distributed Reinforcement Learning + +Sophisticated reinforcement learning algorithms are commonly trained across a +distributed cluster. (See [IMPALA by DeepMind][impala].) One part of the cluster +runs self-play, while the other part of the cluster learns a new version of the +model based on the training data generated by self-play. The new model version +is then distributed to the self-play half of the cluster, and new training data +is generated to continue the cycle. + +In such a configuration, because there is value in training on the freshest +examples, a storage service like Cloud Bigtable can be used to store and +serve the generated training data. When using Cloud Bigtable, there is no need +to aggregate the examples into large batch files, but the examples can instead +be written as soon as they are generated, and then retrieved at high speed. + +[impala]: https://arxiv.org/abs/1802.01561 + +## Common Gotchas! + +### gRPC Certificates + +If you encounter a log line that includes the following: + +``` +"description":"Failed to load file", [...], +"filename":"/usr/share/grpc/roots.pem" +``` + +you likely need to copy the [gRPC `roots.pem` file][grpcPem] to +`/usr/share/grpc/roots.pem` on your local machine. + +[grpcPem]: https://github.com/grpc/grpc/blob/master/etc/roots.pem + +### Permission denied errors + +The TensorFlow Cloud Bigtable client will search for credentials to use in the +process's environment. It will use the first credentials it finds if multiple +are available. + + - **Compute Engine**: When running on Compute Engine, the client will often use + the service account from the virtual machine's metadata service. Be sure to + authorize your Compute Engine VM to have access to the Cloud Bigtable service + when creating your VM, or [update the VM's scopes][update-vm-scopes] on a + running VM if you run into this issue. + - **Cloud TPU**: Your Cloud TPUs run with the designated Cloud TPU service + account dedicated to your GCP project. Ensure the service account has been + authorized via the Cloud Console to access your Cloud Bigtable instances. + +[update-vm-scopes]: https://cloud.google.com/compute/docs/access/create-enable-service-accounts-for-instances#changeserviceaccountandscopes diff --git a/tensorflow/contrib/bigtable/__init__.py b/tensorflow/contrib/bigtable/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7d89c98420ab3ac1465bba718f8257ce2312467 --- /dev/null +++ b/tensorflow/contrib/bigtable/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Cloud Bigtable Client for TensorFlow. + +This contrib package allows TensorFlow to interface directly with Cloud Bigtable +for high-speed data loading. + +@@BigtableClient +@@BigtableTable + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableClient +from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableTable + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'BigtableClient', + 'BigtableTable', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc new file mode 100644 index 0000000000000000000000000000000000000000..a6755a3496f3e1720f1c8c67f75521f2380a9845 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc @@ -0,0 +1,355 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/threadpool.h" + +namespace tensorflow { + +namespace { + +class BigtableClientOp : public OpKernel { + public: + explicit BigtableClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("project_id", &project_id_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("instance_id", &instance_id_)); + OP_REQUIRES(ctx, !project_id_.empty(), + errors::InvalidArgument("project_id must be non-empty")); + OP_REQUIRES(ctx, !instance_id_.empty(), + errors::InvalidArgument("instance_id must be non-empty")); + + OP_REQUIRES_OK( + ctx, ctx->GetAttr("connection_pool_size", &connection_pool_size_)); + // If left unset by the client code, set it to a default of 100. Note: the + // cloud-cpp default of 4 concurrent connections is far too low for high + // performance streaming. + if (connection_pool_size_ == -1) { + connection_pool_size_ = 100; + } + + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_receive_message_size", + &max_receive_message_size_)); + // If left unset by the client code, set it to a default of 100. Note: the + // cloud-cpp default of 4 concurrent connections is far too low for high + // performance streaming. + if (max_receive_message_size_ == -1) { + max_receive_message_size_ = 1 << 24; // 16 MBytes + } + OP_REQUIRES(ctx, max_receive_message_size_ > 0, + errors::InvalidArgument("connection_pool_size must be > 0")); + } + + ~BigtableClientOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + + void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + if (!initialized_) { + ResourceMgr* mgr = ctx->resource_manager(); + OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); + BigtableClientResource* resource; + OP_REQUIRES_OK( + ctx, + mgr->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &resource, + [this, ctx]( + BigtableClientResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + auto client_options = + google::cloud::bigtable::ClientOptions() + .set_connection_pool_size(connection_pool_size_) + .set_data_endpoint("batch-bigtable.googleapis.com"); + auto channel_args = client_options.channel_arguments(); + channel_args.SetMaxReceiveMessageSize( + max_receive_message_size_); + channel_args.SetUserAgentPrefix("tensorflow"); + client_options.set_channel_arguments(channel_args); + std::shared_ptr client = + google::cloud::bigtable::CreateDefaultDataClient( + project_id_, instance_id_, std::move(client_options)); + *ret = new BigtableClientResource(project_id_, instance_id_, + std::move(client)); + return Status::OK(); + })); + core::ScopedUnref resource_cleanup(resource); + initialized_ = true; + } + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex())); + } + + private: + string project_id_; + string instance_id_; + int64 connection_pool_size_; + int32 max_receive_message_size_; + + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableClient").Device(DEVICE_CPU), + BigtableClientOp); + +class BigtableTableOp : public OpKernel { + public: + explicit BigtableTableOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name", &table_)); + OP_REQUIRES(ctx, !table_.empty(), + errors::InvalidArgument("table_name must be non-empty")); + } + + ~BigtableTableOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + + void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + if (!initialized_) { + ResourceMgr* mgr = ctx->resource_manager(); + OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); + + BigtableClientResource* client_resource; + OP_REQUIRES_OK( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource)); + core::ScopedUnref unref_client(client_resource); + + BigtableTableResource* resource; + OP_REQUIRES_OK( + ctx, mgr->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &resource, + [this, client_resource](BigtableTableResource** ret) { + *ret = new BigtableTableResource(client_resource, table_); + return Status::OK(); + })); + initialized_ = true; + } + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex())); + } + + private: + string table_; // Note: this is const after construction. + + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableTable").Device(DEVICE_CPU), + BigtableTableOp); + +class ToBigtableOp : public AsyncOpKernel { + public: + explicit ToBigtableOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx), + thread_pool_(new thread::ThreadPool( + ctx->env(), ThreadOptions(), + strings::StrCat("to_bigtable_op_", SanitizeThreadSuffix(name())), + /* num_threads = */ 1, /* low_latency_hint = */ false)) {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + // The call to `iterator->GetNext()` may block and depend on an + // inter-op thread pool thread, so we issue the call from the + // owned thread pool. + thread_pool_->Schedule([this, ctx, done]() { + const Tensor* column_families_tensor; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->input("column_families", &column_families_tensor), done); + OP_REQUIRES_ASYNC( + ctx, column_families_tensor->dims() == 1, + errors::InvalidArgument("`column_families` must be a vector."), done); + + const Tensor* columns_tensor; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input("columns", &columns_tensor), done); + OP_REQUIRES_ASYNC(ctx, columns_tensor->dims() == 1, + errors::InvalidArgument("`columns` must be a vector."), + done); + OP_REQUIRES_ASYNC( + ctx, + columns_tensor->NumElements() == + column_families_tensor->NumElements(), + errors::InvalidArgument("len(column_families) != len(columns)"), + done); + + std::vector column_families; + column_families.reserve(column_families_tensor->NumElements()); + std::vector columns; + columns.reserve(column_families_tensor->NumElements()); + for (uint64 i = 0; i < column_families_tensor->NumElements(); ++i) { + column_families.push_back(column_families_tensor->flat()(i)); + columns.push_back(columns_tensor->flat()(i)); + } + + DatasetBase* dataset; + OP_REQUIRES_OK_ASYNC( + ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset), done); + + IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx); + std::unique_ptr iterator; + OP_REQUIRES_OK_ASYNC( + ctx, + dataset->MakeIterator(&iter_ctx, "ToBigtableOpIterator", &iterator), + done); + + int64 timestamp_int; + OP_REQUIRES_OK_ASYNC( + ctx, ParseScalarArgument(ctx, "timestamp", ×tamp_int), + done); + OP_REQUIRES_ASYNC(ctx, timestamp_int >= -1, + errors::InvalidArgument("timestamp must be >= -1"), + done); + + BigtableTableResource* resource; + OP_REQUIRES_OK_ASYNC( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource), done); + core::ScopedUnref resource_cleanup(resource); + + std::vector components; + components.reserve(dataset->output_dtypes().size()); + bool end_of_sequence = false; + do { + ::google::cloud::bigtable::BulkMutation mutation; + // TODO(saeta): Make # of mutations configurable. + for (uint64 i = 0; i < 100 && !end_of_sequence; ++i) { + OP_REQUIRES_OK_ASYNC( + ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence), + done); + if (!end_of_sequence) { + OP_REQUIRES_OK_ASYNC( + ctx, + CreateMutation(std::move(components), column_families, columns, + timestamp_int, &mutation), + done); + } + components.clear(); + } + grpc::Status mutation_status; + std::vector<::google::cloud::bigtable::FailedMutation> failures = + resource->table().BulkApply(std::move(mutation), mutation_status); + if (!mutation_status.ok()) { + LOG(ERROR) << "Failure applying mutation: " + << mutation_status.error_code() << " - " + << mutation_status.error_message() << " (" + << mutation_status.error_details() << ")."; + } + if (!failures.empty()) { + for (const auto& failure : failures) { + LOG(ERROR) << "Failure applying mutation on row (" + << failure.original_index() + << "): " << failure.mutation().row_key() + << " - error: " << failure.status().error_message() + << " (Details: " << failure.status().error_details() + << ")."; + } + } + OP_REQUIRES_ASYNC( + ctx, failures.empty() && mutation_status.ok(), + errors::Unknown("Failure while writing to Cloud Bigtable: ", + mutation_status.error_code(), " - ", + mutation_status.error_message(), " (", + mutation_status.error_details(), + "), # of mutation failures: ", failures.size(), + ". See the log for the specific error details."), + done); + } while (!end_of_sequence); + done(); + }); + } + + private: + static string SanitizeThreadSuffix(string suffix) { + string clean; + for (int i = 0; i < suffix.size(); ++i) { + const char ch = suffix[i]; + if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || + (ch >= '0' && ch <= '9') || ch == '_' || ch == '-') { + clean += ch; + } else { + clean += '_'; + } + } + return clean; + } + + Status CreateMutation( + std::vector tensors, const std::vector& column_families, + const std::vector& columns, int64 timestamp_int, + ::google::cloud::bigtable::BulkMutation* bulk_mutation) { + if (tensors.size() != column_families.size() + 1) { + return errors::InvalidArgument( + "Iterator produced a set of Tensors shorter than expected"); + } + ::google::cloud::bigtable::SingleRowMutation mutation( + std::move(tensors[0].scalar()())); + std::chrono::milliseconds timestamp(timestamp_int); + for (size_t i = 1; i < tensors.size(); ++i) { + if (!TensorShapeUtils::IsScalar(tensors[i].shape())) { + return errors::Internal("Output tensor ", i, " was not a scalar"); + } + if (timestamp_int == -1) { + mutation.emplace_back(::google::cloud::bigtable::SetCell( + column_families[i - 1], columns[i - 1], + std::move(tensors[i].scalar()()))); + } else { + mutation.emplace_back(::google::cloud::bigtable::SetCell( + column_families[i - 1], columns[i - 1], timestamp, + std::move(tensors[i].scalar()()))); + } + } + bulk_mutation->emplace_back(std::move(mutation)); + return Status::OK(); + } + + template + Status ParseScalarArgument(OpKernelContext* ctx, + const StringPiece& argument_name, T* output) { + const Tensor* argument_t; + TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); + if (!TensorShapeUtils::IsScalar(argument_t->shape())) { + return errors::InvalidArgument(argument_name, " must be a scalar"); + } + *output = argument_t->scalar()(); + return Status::OK(); + } + + std::unique_ptr thread_pool_; +}; + +REGISTER_KERNEL_BUILDER(Name("DatasetToBigtable").Device(DEVICE_CPU), + ToBigtableOp); + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc new file mode 100644 index 0000000000000000000000000000000000000000..67bf14c17646cff81af707405b66c9fba2ded0bd --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" + +namespace tensorflow { + +Status GrpcStatusToTfStatus(const ::grpc::Status& status) { + if (status.ok()) { + return Status::OK(); + } + auto grpc_code = status.error_code(); + if (status.error_code() == ::grpc::StatusCode::ABORTED || + status.error_code() == ::grpc::StatusCode::UNAVAILABLE || + status.error_code() == ::grpc::StatusCode::OUT_OF_RANGE) { + grpc_code = ::grpc::StatusCode::INTERNAL; + } + return Status(static_cast<::tensorflow::error::Code>(status.error_code()), + strings::StrCat("Error reading from Cloud Bigtable: ", + status.error_message(), + " (Details: ", status.error_details(), ")")); +} + +string RegexFromStringSet(const std::vector& strs) { + CHECK(!strs.empty()) << "The list of strings to turn into a regex was empty."; + std::unordered_set uniq(strs.begin(), strs.end()); + if (uniq.size() == 1) { + return *uniq.begin(); + } + return str_util::Join(uniq, "|"); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h new file mode 100644 index 0000000000000000000000000000000000000000..a2a5df1037a00ccfdff1910dd950d7b012e684e2 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h @@ -0,0 +1,143 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_ +#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_ + +// Note: we use bigtable/client/internal/table.h as this is the no-exception API + +#include "google/cloud/bigtable/data_client.h" +#include "google/cloud/bigtable/internal/table.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/resource_mgr.h" + +namespace tensorflow { + +Status GrpcStatusToTfStatus(const ::grpc::Status& status); + +string RegexFromStringSet(const std::vector& strs); + +class BigtableClientResource : public ResourceBase { + public: + BigtableClientResource( + string project_id, string instance_id, + std::shared_ptr client) + : project_id_(std::move(project_id)), + instance_id_(std::move(instance_id)), + client_(std::move(client)) {} + + std::shared_ptr get_client() { + return client_; + } + + string DebugString() override { + return strings::StrCat("BigtableClientResource(project_id: ", project_id_, + ", instance_id: ", instance_id_, ")"); + } + + private: + const string project_id_; + const string instance_id_; + std::shared_ptr client_; +}; + +class BigtableTableResource : public ResourceBase { + public: + BigtableTableResource(BigtableClientResource* client, string table_name) + : client_(client), + table_name_(std::move(table_name)), + table_(client->get_client(), table_name_, + google::cloud::bigtable::AlwaysRetryMutationPolicy()) { + client_->Ref(); + } + + ~BigtableTableResource() override { client_->Unref(); } + + ::google::cloud::bigtable::noex::Table& table() { return table_; } + + string DebugString() override { + return strings::StrCat( + "BigtableTableResource(client: ", client_->DebugString(), + ", table: ", table_name_, ")"); + } + + private: + BigtableClientResource* client_; // Ownes one ref. + const string table_name_; + ::google::cloud::bigtable::noex::Table table_; +}; + +// BigtableReaderDatasetIterator is an abstract class for iterators from +// datasets that are "readers" (source datasets, not transformation datasets) +// that read from Bigtable. +template +class BigtableReaderDatasetIterator : public DatasetIterator { + public: + explicit BigtableReaderDatasetIterator( + const typename DatasetIterator::Params& params) + : DatasetIterator(params), iterator_(nullptr, false) {} + + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(EnsureIteratorInitialized()); + if (iterator_ == reader_->end()) { + grpc::Status status = reader_->Finish(); + if (status.ok()) { + *end_of_sequence = true; + return Status::OK(); + } + return GrpcStatusToTfStatus(status); + } + *end_of_sequence = false; + google::cloud::bigtable::Row& row = *iterator_; + Status s = ParseRow(ctx, row, out_tensors); + // Ensure we always advance. + ++iterator_; + return s; + } + + protected: + virtual ::google::cloud::bigtable::RowRange MakeRowRange() = 0; + virtual ::google::cloud::bigtable::Filter MakeFilter() = 0; + virtual Status ParseRow(IteratorContext* ctx, + const ::google::cloud::bigtable::Row& row, + std::vector* out_tensors) = 0; + + private: + Status EnsureIteratorInitialized() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (reader_) { + return Status::OK(); + } + + auto rows = MakeRowRange(); + auto filter = MakeFilter(); + + // Note: the this in `this->dataset()` below is necessary due to namespace + // name conflicts. + reader_.reset(new ::google::cloud::bigtable::RowReader( + this->dataset()->table()->table().ReadRows(rows, filter))); + iterator_ = reader_->begin(); + return Status::OK(); + } + + mutex mu_; + std::unique_ptr<::google::cloud::bigtable::RowReader> reader_ GUARDED_BY(mu_); + ::google::cloud::bigtable::RowReader::iterator iterator_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_ diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..9e49fa35db4b2cd2c8991100a28a5b9c55f01ffe --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc @@ -0,0 +1,221 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { + public: + using UnaryDatasetOpKernel::UnaryDatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + BigtableTableResource* table; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table)); + + std::vector column_families; + std::vector columns; + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "column_families", + &column_families)); + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "columns", &columns)); + OP_REQUIRES( + ctx, column_families.size() == columns.size(), + errors::InvalidArgument("len(columns) != len(column_families)")); + + const uint64 num_outputs = columns.size() + 1; + std::vector output_shapes; + output_shapes.reserve(num_outputs); + DataTypeVector output_types; + output_types.reserve(num_outputs); + for (uint64 i = 0; i < num_outputs; ++i) { + output_shapes.push_back({}); + output_types.push_back(DT_STRING); + } + + *output = + new Dataset(ctx, input, table, std::move(column_families), + std::move(columns), output_types, std::move(output_shapes)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, + BigtableTableResource* table, + std::vector column_families, + std::vector columns, + const DataTypeVector& output_types, + std::vector output_shapes) + : GraphDatasetBase(ctx), + input_(input), + table_(table), + column_families_(std::move(column_families)), + columns_(std::move(columns)), + output_types_(output_types), + output_shapes_(std::move(output_shapes)), + filter_(MakeFilter(column_families_, columns_)) { + table_->Ref(); + input_->Ref(); + } + + ~Dataset() override { + table_->Unref(); + input_->Unref(); + } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::BigtableLookupDataset")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "BigtableLookupDatasetOp::Dataset"; + } + + private: + static ::google::cloud::bigtable::Filter MakeFilter( + const std::vector& column_families, + const std::vector& columns) { + string column_family_regex = RegexFromStringSet(column_families); + string column_regex = RegexFromStringSet(columns); + + return ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1), + ::google::cloud::bigtable::Filter::FamilyRegex(column_family_regex), + ::google::cloud::bigtable::Filter::ColumnRegex(column_regex)); + } + + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); // Sequence requests. + std::vector input_tensors; + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, &input_tensors, end_of_sequence)); + if (*end_of_sequence) { + return Status::OK(); + } + if (input_tensors.size() != 1) { + return errors::InvalidArgument( + "Upstream iterator (", dataset()->input_->DebugString(), + ") did not produce a single `tf.string` `tf.Tensor`. It " + "produced ", + input_tensors.size(), " tensors."); + } + if (input_tensors[0].NumElements() == 0) { + return errors::InvalidArgument("Upstream iterator (", + dataset()->input_->DebugString(), + ") return an empty set of keys."); + } + if (input_tensors[0].NumElements() == 1) { + // Single key lookup. + ::grpc::Status status; + auto pair = dataset()->table_->table().ReadRow( + input_tensors[0].scalar()(), dataset()->filter_, status); + if (!status.ok()) { + return GrpcStatusToTfStatus(status); + } + if (!pair.first) { + return errors::DataLoss("Row key '", + input_tensors[0].scalar()(), + "' not found."); + } + TF_RETURN_IF_ERROR(ParseRow(ctx, pair.second, out_tensors)); + } else { + // Batched get. + return errors::Unimplemented( + "BigtableLookupDataset doesn't yet support batched retrieval."); + } + return Status::OK(); + } + + private: + Status ParseRow(IteratorContext* ctx, + const ::google::cloud::bigtable::Row& row, + std::vector* out_tensors) { + out_tensors->reserve(dataset()->columns_.size() + 1); + Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {}); + row_key_tensor.scalar()() = string(row.row_key()); + out_tensors->emplace_back(std::move(row_key_tensor)); + + if (row.cells().size() > 2 * dataset()->columns_.size()) { + LOG(WARNING) << "An excessive number of columns (" + << row.cells().size() + << ") were retrieved when reading row: " + << row.row_key(); + } + + for (uint64 i = 0; i < dataset()->columns_.size(); ++i) { + Tensor col_tensor(ctx->allocator({}), DT_STRING, {}); + bool found_column = false; + for (auto cell_itr = row.cells().begin(); + !found_column && cell_itr != row.cells().end(); ++cell_itr) { + if (cell_itr->family_name() == dataset()->column_families_[i] && + string(cell_itr->column_qualifier()) == + dataset()->columns_[i]) { + col_tensor.scalar()() = string(cell_itr->value()); + found_column = true; + } + } + if (!found_column) { + return errors::DataLoss("Column ", dataset()->column_families_[i], + ":", dataset()->columns_[i], + " not found in row: ", row.row_key()); + } + out_tensors->emplace_back(std::move(col_tensor)); + } + return Status::OK(); + } + + mutex mu_; + std::unique_ptr input_impl_ GUARDED_BY(mu_); + }; + + const DatasetBase* const input_; + BigtableTableResource* table_; + const std::vector column_families_; + const std::vector columns_; + const DataTypeVector output_types_; + const std::vector output_shapes_; + const ::google::cloud::bigtable::Filter filter_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableLookupDataset").Device(DEVICE_CPU), + BigtableLookupDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..e960719614a1c7c6c4af53ea924aef214a09b24d --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtablePrefixKeyDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + string prefix; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefix", &prefix)); + + BigtableTableResource* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + + *output = new Dataset(ctx, resource, std::move(prefix)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, + string prefix) + : GraphDatasetBase(ctx), table_(table), prefix_(std::move(prefix)) { + table_->Ref(); + } + + ~Dataset() override { table_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::BigtablePrefixKeyDataset")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); + return *dtypes; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}}); + return *shapes; + } + + string DebugString() const override { + return "BigtablePrefixKeyDatasetOp::Dataset"; + } + + BigtableTableResource* table() const { return table_; } + + private: + class Iterator : public BigtableReaderDatasetIterator { + public: + explicit Iterator(const Params& params) + : BigtableReaderDatasetIterator(params) {} + + ::google::cloud::bigtable::RowRange MakeRowRange() override { + return ::google::cloud::bigtable::RowRange::Prefix(dataset()->prefix_); + } + ::google::cloud::bigtable::Filter MakeFilter() override { + return ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::CellsRowLimit(1), + ::google::cloud::bigtable::Filter::StripValueTransformer()); + } + Status ParseRow(IteratorContext* ctx, + const ::google::cloud::bigtable::Row& row, + std::vector* out_tensors) override { + Tensor output_tensor(ctx->allocator({}), DT_STRING, {}); + output_tensor.scalar()() = string(row.row_key()); + out_tensors->emplace_back(std::move(output_tensor)); + return Status::OK(); + } + }; + + BigtableTableResource* const table_; + const string prefix_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtablePrefixKeyDataset").Device(DEVICE_CPU), + BigtablePrefixKeyDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.cc new file mode 100644 index 0000000000000000000000000000000000000000..51965f6214413c08453473e71c30eecbd8925a64 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.cc @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h" + +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +namespace { + +string MakePrefixEndKey(const string& prefix) { + string end = prefix; + while (true) { + if (end.empty()) { + return end; + } + ++end[end.size() - 1]; + if (end[end.size() - 1] == 0) { + // Handle wraparound case. + end = end.substr(0, end.size() - 1); + } else { + return end; + } + } +} + +} // namespace + +/* static */ MultiModeKeyRange MultiModeKeyRange::FromPrefix(string prefix) { + string end = MakePrefixEndKey(prefix); + VLOG(1) << "Creating MultiModeKeyRange from Prefix: " << prefix + << ", with end key: " << end; + return MultiModeKeyRange(std::move(prefix), std::move(end)); +} + +/* static */ MultiModeKeyRange MultiModeKeyRange::FromRange(string begin, + string end) { + return MultiModeKeyRange(std::move(begin), std::move(end)); +} + +const string& MultiModeKeyRange::begin_key() const { return begin_; } + +const string& MultiModeKeyRange::end_key() const { return end_; } + +bool MultiModeKeyRange::contains_key(StringPiece key) const { + if (StringPiece(begin_) > key) { + return false; + } + if (StringPiece(end_) <= key && !end_.empty()) { + return false; + } + return true; +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..44c628e366c26b88011642f1e8e8d8e74b4698fd --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h @@ -0,0 +1,67 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_ +#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_ + +#include + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Represents a continuous range of keys defined by either a prefix or a range. +// +// Ranges are represented as "half-open", where the beginning key is included +// in the range, and the end_key is the first excluded key after the range. +// +// The range of keys can be specified either by a key prefix, or by an explicit +// begin key and end key. All methods on this class are valid no matter which +// way the range was specified. +// +// Example: +// MultiModeKeyRange range = MultiModeKeyRange::FromPrefix("myPrefix"); +// if (range.contains_key("myPrefixedKey")) { +// LOG(INFO) << "range from " << range.begin_key() << " to " +// << range.end_key() << "contains \"myPrefixedKey\""; +// } +// if (!range.contains_key("randomKey")) { +// LOG(INFO) << "range does not contain \"randomKey\""; +// } +// range = MultiModeKeyRange::FromRange("a_start_key", "z_end_key"); +class MultiModeKeyRange { + public: + static MultiModeKeyRange FromPrefix(string prefix); + static MultiModeKeyRange FromRange(string begin, string end); + + // The first valid key in the range. + const string& begin_key() const; + // The first invalid key after the valid range. + const string& end_key() const; + // Returns true if the provided key is a part of the range, false otherwise. + bool contains_key(StringPiece key) const; + + private: + MultiModeKeyRange(string begin, string end) + : begin_(std::move(begin)), end_(std::move(end)) {} + + const string begin_; + const string end_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_ diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers_test.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1bfc547271d5e58a9145b73356b2b558dc1af9f1 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers_test.cc @@ -0,0 +1,107 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(MultiModeKeyRangeTest, SimplePrefix) { + MultiModeKeyRange r = MultiModeKeyRange::FromPrefix("prefix"); + EXPECT_EQ("prefix", r.begin_key()); + EXPECT_EQ("prefiy", r.end_key()); + EXPECT_TRUE(r.contains_key("prefixed_key")); + EXPECT_FALSE(r.contains_key("not-prefixed-key")); + EXPECT_FALSE(r.contains_key("prefi")); + EXPECT_FALSE(r.contains_key("prefiy")); + EXPECT_FALSE(r.contains_key("early")); + EXPECT_FALSE(r.contains_key("")); +} + +TEST(MultiModeKeyRangeTest, Range) { + MultiModeKeyRange r = MultiModeKeyRange::FromRange("a", "b"); + EXPECT_EQ("a", r.begin_key()); + EXPECT_EQ("b", r.end_key()); + EXPECT_TRUE(r.contains_key("a")); + EXPECT_TRUE(r.contains_key("ab")); + EXPECT_FALSE(r.contains_key("b")); + EXPECT_FALSE(r.contains_key("bc")); + EXPECT_FALSE(r.contains_key("A")); + EXPECT_FALSE(r.contains_key("B")); + EXPECT_FALSE(r.contains_key("")); +} + +TEST(MultiModeKeyRangeTest, InvertedRange) { + MultiModeKeyRange r = MultiModeKeyRange::FromRange("b", "a"); + EXPECT_FALSE(r.contains_key("a")); + EXPECT_FALSE(r.contains_key("b")); + EXPECT_FALSE(r.contains_key("")); +} + +TEST(MultiModeKeyRangeTest, EmptyPrefix) { + MultiModeKeyRange r = MultiModeKeyRange::FromPrefix(""); + EXPECT_EQ("", r.begin_key()); + EXPECT_EQ("", r.end_key()); + EXPECT_TRUE(r.contains_key("")); + EXPECT_TRUE(r.contains_key("a")); + EXPECT_TRUE(r.contains_key("z")); + EXPECT_TRUE(r.contains_key("A")); + EXPECT_TRUE(r.contains_key("ZZZZZZ")); +} + +TEST(MultiModeKeyRangeTest, HalfRange) { + MultiModeKeyRange r = MultiModeKeyRange::FromRange("start", ""); + EXPECT_EQ("start", r.begin_key()); + EXPECT_EQ("", r.end_key()); + EXPECT_TRUE(r.contains_key("start")); + EXPECT_TRUE(r.contains_key("starting")); + EXPECT_TRUE(r.contains_key("z-end")); + EXPECT_FALSE(r.contains_key("")); + EXPECT_FALSE(r.contains_key("early")); +} + +TEST(MultiModeKeyRangeTest, PrefixWrapAround) { + string prefix = "abc\xff"; + MultiModeKeyRange r = MultiModeKeyRange::FromPrefix(prefix); + EXPECT_EQ(prefix, r.begin_key()); + EXPECT_EQ("abd", r.end_key()); + + EXPECT_TRUE(r.contains_key("abc\xff\x07")); + EXPECT_TRUE(r.contains_key("abc\xff\x15")); + EXPECT_TRUE(r.contains_key("abc\xff\x61")); + EXPECT_TRUE(r.contains_key("abc\xff\xff")); + EXPECT_FALSE(r.contains_key("abc\0")); + EXPECT_FALSE(r.contains_key("abd")); +} + +TEST(MultiModeKeyRangeTest, PrefixSignedWrapAround) { + string prefix = "abc\x7f"; + MultiModeKeyRange r = MultiModeKeyRange::FromPrefix(prefix); + EXPECT_EQ(prefix, r.begin_key()); + EXPECT_EQ("abc\x80", r.end_key()); + + EXPECT_TRUE(r.contains_key("abc\x7f\x07")); + EXPECT_TRUE(r.contains_key("abc\x7f\x15")); + EXPECT_TRUE(r.contains_key("abc\x7f\x61")); + EXPECT_TRUE(r.contains_key("abc\x7f\xff")); + EXPECT_FALSE(r.contains_key("abc\0")); + EXPECT_FALSE(r.contains_key("abc\x01")); + EXPECT_FALSE(r.contains_key("abd")); + EXPECT_FALSE(r.contains_key("ab\x80")); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..96d3565d9b90e72f9e25e69e91f1931c982714cd --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtableRangeKeyDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + string start_key; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "start_key", &start_key)); + string end_key; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "end_key", &end_key)); + + BigtableTableResource* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + + *output = + new Dataset(ctx, resource, std::move(start_key), std::move(end_key)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, + string start_key, string end_key) + : GraphDatasetBase(ctx), + table_(table), + start_key_(std::move(start_key)), + end_key_(std::move(end_key)) { + table_->Ref(); + } + + ~Dataset() override { table_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::BigtableRangeKeyDataset")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); + return *dtypes; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}}); + return *shapes; + } + + string DebugString() const override { + return "BigtableRangeKeyDatasetOp::Dataset"; + } + + BigtableTableResource* table() const { return table_; } + + private: + class Iterator : public BigtableReaderDatasetIterator { + public: + explicit Iterator(const Params& params) + : BigtableReaderDatasetIterator(params) {} + + ::google::cloud::bigtable::RowRange MakeRowRange() override { + return ::google::cloud::bigtable::RowRange::Range(dataset()->start_key_, + dataset()->end_key_); + } + ::google::cloud::bigtable::Filter MakeFilter() override { + return ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::CellsRowLimit(1), + ::google::cloud::bigtable::Filter::StripValueTransformer()); + } + Status ParseRow(IteratorContext* ctx, + const ::google::cloud::bigtable::Row& row, + std::vector* out_tensors) override { + Tensor output_tensor(ctx->allocator({}), DT_STRING, {}); + output_tensor.scalar()() = string(row.row_key()); + out_tensors->emplace_back(std::move(output_tensor)); + return Status::OK(); + } + }; + + BigtableTableResource* const table_; + const string start_key_; + const string end_key_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableRangeKeyDataset").Device(DEVICE_CPU), + BigtableRangeKeyDatasetOp); +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a1a63a975afd62325e01586542006058fa2c83bc --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc @@ -0,0 +1,200 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + string prefix; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefix", &prefix)); + + string start_key; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "start_key", &start_key)); + string end_key; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "end_key", &end_key)); + + BigtableTableResource* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + + OP_REQUIRES(ctx, prefix.empty() || start_key.empty(), + errors::InvalidArgument( + "Only one of prefix and start_key can be provided")); + if (!prefix.empty()) { + OP_REQUIRES(ctx, end_key.empty(), + errors::InvalidArgument( + "If prefix is specified, end_key must be empty.")); + } + + *output = new Dataset(ctx, resource, std::move(prefix), + std::move(start_key), std::move(end_key)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, + string prefix, string start_key, string end_key) + : GraphDatasetBase(ctx), + table_(table), + key_range_(MakeMultiModeKeyRange( + std::move(prefix), std::move(start_key), std::move(end_key))) { + table_->Ref(); + } + + ~Dataset() override { table_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::BigtableSampleKeyPairsDataset")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = + new DataTypeVector({DT_STRING, DT_STRING}); + return *dtypes; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}, {}}); + return *shapes; + } + + string DebugString() const override { + return "BigtableSampleKeyPairsDatasetOp::Dataset"; + } + + private: + static MultiModeKeyRange MakeMultiModeKeyRange(string prefix, + string start_key, + string end_key) { + if (!start_key.empty()) { + return MultiModeKeyRange::FromRange(std::move(start_key), + std::move(end_key)); + } + return MultiModeKeyRange::FromPrefix(std::move(prefix)); + } + + BigtableTableResource& table() const { return *table_; } + + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + // Computes split points (`keys_`) to use when scanning the table. + // + // Initialize first retrieves the sample keys from the table (`row_keys`), + // as these often form good split points within the table. We then iterate + // over them, and copy them to `keys_` if they fall within the requested + // range to scan (`dataset()->key_range_`). Because the requested range + // might start between elements of the sampled keys list, care is taken to + // ensure we don't accidentally miss any subsets of the requested range by + // including `begin_key()` and `end_key()` as appropriate. + Status Initialize(IteratorContext* ctx) override { + grpc::Status status; + std::vector row_keys = + dataset()->table().table().SampleRows(status); + if (!status.ok()) { + return GrpcStatusToTfStatus(status); + } + + for (size_t i = 0; i < row_keys.size(); ++i) { + string row_key(row_keys[i].row_key); + if (dataset()->key_range_.contains_key(row_key)) { + // First key: check to see if we need to add the begin_key. + if (keys_.empty() && dataset()->key_range_.begin_key() != row_key) { + keys_.push_back(dataset()->key_range_.begin_key()); + } + keys_.push_back(std::move(row_key)); + } else if (!keys_.empty()) { + // If !keys_.empty(), then we have found at least one element of + // `row_keys` that is within our requested range + // (`dataset()->key_range_`). Because `row_keys` is sorted, if we + // have found an element that's not within our key range, then we + // are after our requested range (ranges are contiguous) and can end + // iteration early. + break; + } + } + + // Handle the case where we skip over the selected range entirely. + if (keys_.empty()) { + keys_.push_back(dataset()->key_range_.begin_key()); + } + + // Last key: check to see if we need to add the end_key. + if (keys_.back() != dataset()->key_range_.end_key()) { + keys_.push_back(dataset()->key_range_.end_key()); + } + return Status::OK(); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + if (index_ > keys_.size() - 2) { + *end_of_sequence = true; + return Status::OK(); + } + + *end_of_sequence = false; + out_tensors->emplace_back(ctx->allocator({}), DT_STRING, + TensorShape({})); + out_tensors->back().scalar()() = keys_[index_]; + + out_tensors->emplace_back(ctx->allocator({}), DT_STRING, + TensorShape({})); + out_tensors->back().scalar()() = keys_[index_ + 1]; + ++index_; + + return Status::OK(); + } + + private: + mutex mu_; + size_t index_ GUARDED_BY(mu_) = 0; + // Note: we store the keys_ on the iterator instead of the dataset + // because we want to re-sample the row keys in case there have been + // tablet rebalancing operations since the dataset was created. + // + // Note: keys_ is readonly after Initialize, and thus does not need a + // guarding lock. + std::vector keys_; + }; + + BigtableTableResource* const table_; + const MultiModeKeyRange key_range_; + }; +}; + +REGISTER_KERNEL_BUILDER( + Name("BigtableSampleKeyPairsDataset").Device(DEVICE_CPU), + BigtableSampleKeyPairsDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a5a47cfe2dcf7c4034e0d5bc7d9a73ef9c1dc94e --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc @@ -0,0 +1,113 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtableSampleKeysDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + BigtableTableResource* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + *output = new Dataset(ctx, resource); + } + + private: + class Dataset : public GraphDatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table) + : GraphDatasetBase(ctx), table_(table) { + table_->Ref(); + } + + ~Dataset() override { table_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::BigtableSampleKeysDataset")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); + return *dtypes; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}}); + return *shapes; + } + + string DebugString() const override { + return "BigtableRangeKeyDatasetOp::Dataset"; + } + + BigtableTableResource* table() const { return table_; } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + ::grpc::Status status; + row_keys_ = dataset()->table()->table().SampleRows(status); + if (!status.ok()) { + row_keys_.clear(); + return GrpcStatusToTfStatus(status); + } + return Status::OK(); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + if (index_ < row_keys_.size()) { + out_tensors->emplace_back(ctx->allocator({}), DT_STRING, + TensorShape({})); + out_tensors->back().scalar()() = + string(row_keys_[index_].row_key); + *end_of_sequence = false; + index_++; + } else { + *end_of_sequence = true; + } + return Status::OK(); + } + + private: + mutex mu_; + size_t index_ = 0; + std::vector<::google::cloud::bigtable::RowKeySample> row_keys_; + }; + + BigtableTableResource* const table_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableSampleKeysDataset").Device(DEVICE_CPU), + BigtableSampleKeysDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..13cb8681679ec1541b74a20474665f770790201f --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc @@ -0,0 +1,219 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BigtableScanDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + string prefix; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefix", &prefix)); + string start_key; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "start_key", &start_key)); + string end_key; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "end_key", &end_key)); + + OP_REQUIRES(ctx, !(prefix.empty() && start_key.empty()), + errors::InvalidArgument( + "Either prefix or start_key must be specified")); + OP_REQUIRES(ctx, prefix.empty() || start_key.empty(), + errors::InvalidArgument( + "Only one of prefix and start_key can be provided")); + if (!prefix.empty()) { + OP_REQUIRES(ctx, end_key.empty(), + errors::InvalidArgument( + "If prefix is specified, end_key must be empty.")); + } + + std::vector column_families; + std::vector columns; + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "column_families", + &column_families)); + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "columns", &columns)); + OP_REQUIRES( + ctx, column_families.size() == columns.size(), + errors::InvalidArgument("len(columns) != len(column_families)")); + OP_REQUIRES(ctx, !column_families.empty(), + errors::InvalidArgument("`column_families` is empty")); + + float probability = 0; + OP_REQUIRES_OK( + ctx, ParseScalarArgument(ctx, "probability", &probability)); + OP_REQUIRES( + ctx, probability > 0 && probability <= 1, + errors::InvalidArgument( + "Probability outside the range of (0, 1]. Got: ", probability)); + + BigtableTableResource* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + + const uint64 num_outputs = columns.size() + 1; + std::vector output_shapes; + output_shapes.reserve(num_outputs); + DataTypeVector output_types; + output_types.reserve(num_outputs); + for (uint64 i = 0; i < num_outputs; ++i) { + output_shapes.push_back({}); + output_types.push_back(DT_STRING); + } + + *output = new Dataset(ctx, resource, std::move(prefix), + std::move(start_key), std::move(end_key), + std::move(column_families), std::move(columns), + probability, output_types, std::move(output_shapes)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, + string prefix, string start_key, string end_key, + std::vector column_families, + std::vector columns, float probability, + const DataTypeVector& output_types, + std::vector output_shapes) + : GraphDatasetBase(ctx), + table_(table), + prefix_(std::move(prefix)), + start_key_(std::move(start_key)), + end_key_(std::move(end_key)), + column_families_(std::move(column_families)), + columns_(std::move(columns)), + column_family_regex_(RegexFromStringSet(column_families_)), + column_regex_(RegexFromStringSet(columns_)), + probability_(probability), + output_types_(output_types), + output_shapes_(std::move(output_shapes)) { + table_->Ref(); + } + + ~Dataset() override { table_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::BigtableScanDataset")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "BigtableScanDatasetOp::Dataset"; + } + + BigtableTableResource* table() const { return table_; } + + private: + class Iterator : public BigtableReaderDatasetIterator { + public: + explicit Iterator(const Params& params) + : BigtableReaderDatasetIterator(params) {} + + ::google::cloud::bigtable::RowRange MakeRowRange() override { + if (!dataset()->prefix_.empty()) { + DCHECK(dataset()->start_key_.empty()); + return ::google::cloud::bigtable::RowRange::Prefix( + dataset()->prefix_); + } else { + DCHECK(!dataset()->start_key_.empty()) + << "Both prefix and start_key were empty!"; + return ::google::cloud::bigtable::RowRange::Range( + dataset()->start_key_, dataset()->end_key_); + } + } + ::google::cloud::bigtable::Filter MakeFilter() override { + // TODO(saeta): Investigate optimal ordering here. + return ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1), + ::google::cloud::bigtable::Filter::FamilyRegex( + dataset()->column_family_regex_), + ::google::cloud::bigtable::Filter::ColumnRegex( + dataset()->column_regex_), + dataset()->probability_ != 1.0 + ? ::google::cloud::bigtable::Filter::RowSample( + dataset()->probability_) + : ::google::cloud::bigtable::Filter::PassAllFilter()); + } + Status ParseRow(IteratorContext* ctx, + const ::google::cloud::bigtable::Row& row, + std::vector* out_tensors) override { + out_tensors->reserve(dataset()->columns_.size() + 1); + Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {}); + row_key_tensor.scalar()() = string(row.row_key()); + out_tensors->emplace_back(std::move(row_key_tensor)); + + if (row.cells().size() > 2 * dataset()->columns_.size()) { + LOG(WARNING) << "An excessive number of columns (" + << row.cells().size() + << ") were retrieved when reading row: " + << row.row_key(); + } + + for (uint64 i = 0; i < dataset()->columns_.size(); ++i) { + Tensor col_tensor(ctx->allocator({}), DT_STRING, {}); + bool found_column = false; + for (auto cell_itr = row.cells().begin(); + !found_column && cell_itr != row.cells().end(); ++cell_itr) { + if (cell_itr->family_name() == dataset()->column_families_[i] && + string(cell_itr->column_qualifier()) == + dataset()->columns_[i]) { + col_tensor.scalar()() = string(cell_itr->value()); + found_column = true; + } + } + if (!found_column) { + return errors::InvalidArgument( + "Column ", dataset()->column_families_[i], ":", + dataset()->columns_[i], " not found in row: ", row.row_key()); + } + out_tensors->emplace_back(std::move(col_tensor)); + } + return Status::OK(); + } + }; + + BigtableTableResource* table_; + const string prefix_; + const string start_key_; + const string end_key_; + const std::vector column_families_; + const std::vector columns_; + const string column_family_regex_; + const string column_regex_; + const float probability_; + const DataTypeVector output_types_; + const std::vector output_shapes_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableScanDataset").Device(DEVICE_CPU), + BigtableScanDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc new file mode 100644 index 0000000000000000000000000000000000000000..f083ce6f44b3c2a83d9b5d3235056eb94c4be4a8 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc @@ -0,0 +1,374 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h" + +#include "google/bigtable/v2/data.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "re2/re2.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/util/ptr_util.h" +// #include "util/task/codes.pb.h" + +namespace tensorflow { +namespace { + +void UpdateRow(const ::google::bigtable::v2::Mutation& mut, + std::map* row) { + if (mut.has_set_cell()) { + CHECK(mut.set_cell().timestamp_micros() >= -1) + << "Timestamp_micros: " << mut.set_cell().timestamp_micros(); + auto col = + strings::Printf("%s:%s", mut.set_cell().family_name().c_str(), + string(mut.set_cell().column_qualifier()).c_str()); + (*row)[col] = string(mut.set_cell().value()); + } else if (mut.has_delete_from_column()) { + auto col = strings::Printf( + "%s:%s", mut.delete_from_column().family_name().c_str(), + string(mut.delete_from_column().column_qualifier()).c_str()); + row->erase(col); + } else if (mut.has_delete_from_family()) { + auto itr = row->lower_bound(mut.delete_from_family().family_name()); + auto prefix = + strings::Printf("%s:", mut.delete_from_family().family_name().c_str()); + while (itr != row->end() && itr->first.substr(0, prefix.size()) == prefix) { + row->erase(itr); + } + } else if (mut.has_delete_from_row()) { + row->clear(); + } else { + LOG(ERROR) << "Unknown mutation: " << mut.ShortDebugString(); + } +} + +} // namespace + +class SampleRowKeysResponse : public grpc::ClientReaderInterface< + google::bigtable::v2::SampleRowKeysResponse> { + public: + explicit SampleRowKeysResponse(BigtableTestClient* client) + : client_(client) {} + + bool NextMessageSize(uint32_t* sz) override { + mutex_lock l(mu_); + mutex_lock l2(client_->mu_); + if (num_messages_sent_ * 2 < client_->table_.rows.size()) { + *sz = 10000; // A sufficiently high enough value to not worry about. + return true; + } + return false; + } + + bool Read(google::bigtable::v2::SampleRowKeysResponse* resp) override { + // Send every other key from the table. + mutex_lock l(mu_); + mutex_lock l2(client_->mu_); + *resp = google::bigtable::v2::SampleRowKeysResponse(); + auto itr = client_->table_.rows.begin(); + for (uint64 i = 0; i < 2 * num_messages_sent_; ++i) { + ++itr; + if (itr == client_->table_.rows.end()) { + return false; + } + } + resp->set_row_key(itr->first); + resp->set_offset_bytes(100 * num_messages_sent_); + num_messages_sent_++; + return true; + } + + grpc::Status Finish() override { return grpc::Status::OK; } + + void WaitForInitialMetadata() override {} // Do nothing. + + private: + mutex mu_; + int64 num_messages_sent_ GUARDED_BY(mu_) = 0; + BigtableTestClient* client_; // Not owned. +}; + +class ReadRowsResponse : public grpc::ClientReaderInterface< + google::bigtable::v2::ReadRowsResponse> { + public: + ReadRowsResponse(BigtableTestClient* client, + google::bigtable::v2::ReadRowsRequest const& request) + : client_(client), request_(request) {} + + bool NextMessageSize(uint32_t* sz) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + *sz = 10000000; // A sufficiently high enough value to not worry about. + return true; + } + + bool Read(google::bigtable::v2::ReadRowsResponse* resp) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + sent_first_message_ = true; + RowFilter filter = MakeRowFilter(); + + mutex_lock l2(client_->mu_); + *resp = google::bigtable::v2::ReadRowsResponse(); + // Send all contents in first response. + for (auto itr = client_->table_.rows.begin(); + itr != client_->table_.rows.end(); ++itr) { + if (filter.AllowRow(itr->first)) { + ::google::bigtable::v2::ReadRowsResponse_CellChunk* chunk = nullptr; + bool sent_first = false; + for (auto col_itr = itr->second.columns.begin(); + col_itr != itr->second.columns.end(); ++col_itr) { + if (filter.AllowColumn(col_itr->first)) { + chunk = resp->add_chunks(); + if (!sent_first) { + sent_first = true; + chunk->set_row_key(itr->first); + } + auto colon_idx = col_itr->first.find(":"); + CHECK(colon_idx != string::npos) + << "No ':' found in: " << col_itr->first; + chunk->mutable_family_name()->set_value( + string(col_itr->first, 0, colon_idx)); + chunk->mutable_qualifier()->set_value( + string(col_itr->first, ++colon_idx)); + if (!filter.strip_values) { + chunk->set_value(col_itr->second); + } + if (filter.only_one_column) { + break; + } + } + } + if (sent_first) { + // We are sending this row, so set the commit flag on the last chunk. + chunk->set_commit_row(true); + } + } + } + return true; + } + + grpc::Status Finish() override { return grpc::Status::OK; } + + void WaitForInitialMetadata() override {} // Do nothing. + + private: + struct RowFilter { + std::set row_set; + std::vector> row_ranges; + double row_sample = 0.0; // Note: currently ignored. + std::unique_ptr col_filter; + bool strip_values = false; + bool only_one_column = false; + + bool AllowRow(const string& row) { + if (row_set.find(row) != row_set.end()) { + return true; + } + for (const auto& range : row_ranges) { + if (range.first <= row && range.second > row) { + return true; + } + } + return false; + } + + bool AllowColumn(const string& col) { + if (col_filter) { + return RE2::FullMatch(col, *col_filter); + } else { + return true; + } + } + }; + + RowFilter MakeRowFilter() { + RowFilter filter; + for (auto i = request_.rows().row_keys().begin(); + i != request_.rows().row_keys().end(); ++i) { + filter.row_set.insert(string(*i)); + } + for (auto i = request_.rows().row_ranges().begin(); + i != request_.rows().row_ranges().end(); ++i) { + if (i->start_key_case() != + google::bigtable::v2::RowRange::kStartKeyClosed || + i->end_key_case() != google::bigtable::v2::RowRange::kEndKeyOpen) { + LOG(WARNING) << "Skipping row range that cannot be processed: " + << i->ShortDebugString(); + continue; + } + filter.row_ranges.emplace_back(std::make_pair( + string(i->start_key_closed()), string(i->end_key_open()))); + } + if (request_.filter().has_chain()) { + string family_filter; + string qualifier_filter; + for (auto i = request_.filter().chain().filters().begin(); + i != request_.filter().chain().filters().end(); ++i) { + switch (i->filter_case()) { + case google::bigtable::v2::RowFilter::kFamilyNameRegexFilter: + family_filter = i->family_name_regex_filter(); + break; + case google::bigtable::v2::RowFilter::kColumnQualifierRegexFilter: + qualifier_filter = i->column_qualifier_regex_filter(); + break; + case google::bigtable::v2::RowFilter::kCellsPerColumnLimitFilter: + if (i->cells_per_column_limit_filter() != 1) { + LOG(ERROR) << "Unexpected cells_per_column_limit_filter: " + << i->cells_per_column_limit_filter(); + } + break; + case google::bigtable::v2::RowFilter::kStripValueTransformer: + filter.strip_values = i->strip_value_transformer(); + break; + case google::bigtable::v2::RowFilter::kRowSampleFilter: + LOG(INFO) << "Ignoring row sample directive."; + break; + case google::bigtable::v2::RowFilter::kPassAllFilter: + break; + case google::bigtable::v2::RowFilter::kCellsPerRowLimitFilter: + filter.only_one_column = true; + break; + default: + LOG(WARNING) << "Ignoring unknown filter type: " + << i->ShortDebugString(); + } + } + if (family_filter.empty() || qualifier_filter.empty()) { + LOG(WARNING) << "Missing regex!"; + } else { + string regex = strings::Printf("%s:%s", family_filter.c_str(), + qualifier_filter.c_str()); + filter.col_filter.reset(new RE2(regex)); + } + } else { + LOG(WARNING) << "Read request did not have a filter chain specified: " + << request_.filter().DebugString(); + } + return filter; + } + + mutex mu_; + bool sent_first_message_ GUARDED_BY(mu_) = false; + BigtableTestClient* client_; // Not owned. + const google::bigtable::v2::ReadRowsRequest request_; +}; + +class MutateRowsResponse : public grpc::ClientReaderInterface< + google::bigtable::v2::MutateRowsResponse> { + public: + explicit MutateRowsResponse(size_t num_successes) + : num_successes_(num_successes) {} + + bool NextMessageSize(uint32_t* sz) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + *sz = 10000000; // A sufficiently high enough value to not worry about. + return true; + } + + bool Read(google::bigtable::v2::MutateRowsResponse* resp) override { + mutex_lock l(mu_); + if (sent_first_message_) { + return false; + } + sent_first_message_ = true; + *resp = google::bigtable::v2::MutateRowsResponse(); + for (size_t i = 0; i < num_successes_; ++i) { + auto entry = resp->add_entries(); + entry->set_index(i); + } + return true; + } + + grpc::Status Finish() override { return grpc::Status::OK; } + + void WaitForInitialMetadata() override {} // Do nothing. + + private: + const size_t num_successes_; + + mutex mu_; + bool sent_first_message_ = false; +}; + +grpc::Status BigtableTestClient::MutateRow( + grpc::ClientContext* context, + google::bigtable::v2::MutateRowRequest const& request, + google::bigtable::v2::MutateRowResponse* response) { + mutex_lock l(mu_); + auto* row = &table_.rows[string(request.row_key())]; + for (int i = 0; i < request.mutations_size(); ++i) { + UpdateRow(request.mutations(i), &row->columns); + } + *response = google::bigtable::v2::MutateRowResponse(); + return grpc::Status::OK; +} +grpc::Status BigtableTestClient::CheckAndMutateRow( + grpc::ClientContext* context, + google::bigtable::v2::CheckAndMutateRowRequest const& request, + google::bigtable::v2::CheckAndMutateRowResponse* response) { + return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, + "CheckAndMutateRow not implemented."); +} +grpc::Status BigtableTestClient::ReadModifyWriteRow( + grpc::ClientContext* context, + google::bigtable::v2::ReadModifyWriteRowRequest const& request, + google::bigtable::v2::ReadModifyWriteRowResponse* response) { + return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, + "ReadModifyWriteRow not implemented."); +} +std::unique_ptr< + grpc::ClientReaderInterface> +BigtableTestClient::ReadRows( + grpc::ClientContext* context, + google::bigtable::v2::ReadRowsRequest const& request) { + return MakeUnique(this, request); +} + +std::unique_ptr< + grpc::ClientReaderInterface> +BigtableTestClient::SampleRowKeys( + grpc::ClientContext* context, + google::bigtable::v2::SampleRowKeysRequest const& request) { + return MakeUnique(this); +} +std::unique_ptr< + grpc::ClientReaderInterface> +BigtableTestClient::MutateRows( + grpc::ClientContext* context, + google::bigtable::v2::MutateRowsRequest const& request) { + mutex_lock l(mu_); + for (auto i = request.entries().begin(); i != request.entries().end(); ++i) { + auto* row = &table_.rows[string(i->row_key())]; + for (auto mut = i->mutations().begin(); mut != i->mutations().end(); + ++mut) { + UpdateRow(*mut, &row->columns); + } + } + return MakeUnique(request.entries_size()); +} + +std::shared_ptr BigtableTestClient::Channel() { + LOG(WARNING) << "Call to InMemoryDataClient::Channel(); this will likely " + "cause a crash!"; + return nullptr; +} +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h new file mode 100644 index 0000000000000000000000000000000000000000..dac2b16a216d26f02684c7401ed2ddaa4b7baddb --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h @@ -0,0 +1,87 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_ +#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_ + +#include "google/cloud/bigtable/data_client.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +class BigtableTestClient : public ::google::cloud::bigtable::DataClient { + public: + std::string const& project_id() const override { return project_id_; } + std::string const& instance_id() const override { return instance_id_; } + void reset() override { + mutex_lock l(mu_); + table_ = Table(); + } + + grpc::Status MutateRow( + grpc::ClientContext* context, + google::bigtable::v2::MutateRowRequest const& request, + google::bigtable::v2::MutateRowResponse* response) override; + + grpc::Status CheckAndMutateRow( + grpc::ClientContext* context, + google::bigtable::v2::CheckAndMutateRowRequest const& request, + google::bigtable::v2::CheckAndMutateRowResponse* response) override; + + grpc::Status ReadModifyWriteRow( + grpc::ClientContext* context, + google::bigtable::v2::ReadModifyWriteRowRequest const& request, + google::bigtable::v2::ReadModifyWriteRowResponse* response) override; + + std::unique_ptr< + grpc::ClientReaderInterface> + ReadRows(grpc::ClientContext* context, + google::bigtable::v2::ReadRowsRequest const& request) override; + std::unique_ptr< + grpc::ClientReaderInterface> + SampleRowKeys( + grpc::ClientContext* context, + google::bigtable::v2::SampleRowKeysRequest const& request) override; + + std::unique_ptr< + grpc::ClientReaderInterface> + MutateRows(grpc::ClientContext* context, + google::bigtable::v2::MutateRowsRequest const& request) override; + + std::shared_ptr Channel() override; + + private: + friend class SampleRowKeysResponse; + friend class ReadRowsResponse; + friend class MutateRowsResponse; + + struct Row { + string row_key; + std::map columns; + }; + struct Table { + std::map rows; + }; + + mutex mu_; + const std::string project_id_ = "testproject"; + const std::string instance_id_ = "testinstance"; + Table table_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_ diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..fa3e587b90147bd519586eef0cfb5e048b1b75be --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc @@ -0,0 +1,78 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h" +#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace tensorflow { + +namespace { + +class BigtableTestClientOp : public OpKernel { + public: + explicit BigtableTestClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + ~BigtableTestClientOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + if (!initialized_) { + ResourceMgr* mgr = ctx->resource_manager(); + OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); + BigtableClientResource* resource; + OP_REQUIRES_OK( + ctx, + mgr->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &resource, + [this, ctx](BigtableClientResource** ret) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::shared_ptr client( + new BigtableTestClient()); + // Note: must make explicit copies to sequence + // them before the move of client. + string project_id = client->project_id(); + string instance_id = client->instance_id(); + *ret = new BigtableClientResource(std::move(project_id), + std::move(instance_id), + std::move(client)); + return Status::OK(); + })); + initialized_ = true; + } + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex())); + } + + private: + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; +}; + +REGISTER_KERNEL_BUILDER(Name("BigtableTestClient").Device(DEVICE_CPU), + BigtableTestClientOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..32611e2590d9a81f46d0b9dfc09fe7e0068e9671 --- /dev/null +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc @@ -0,0 +1,345 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h" +#include "google/cloud/bigtable/internal/table.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +void WriteCell(const string& row, const string& family, const string& column, + const string& value, + ::google::cloud::bigtable::noex::Table* table) { + ::google::cloud::bigtable::SingleRowMutation mut(row); + mut.emplace_back(::google::cloud::bigtable::SetCell(family, column, value)); + table->Apply(std::move(mut)); +} + +TEST(BigtableTestClientTest, EmptyRowRead) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + ::google::cloud::bigtable::RowSet rowset; + rowset.Append("r1"); + auto filter = ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(std::move(rowset), filter); + EXPECT_EQ(rows.begin(), rows.end()) << "Some rows were returned in response!"; + EXPECT_TRUE(rows.Finish().ok()) << "Error reading rows."; +} + +TEST(BigtableTestClientTest, SingleRowWriteAndRead) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + + ::google::cloud::bigtable::RowSet rowset("r1"); + auto filter = ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(std::move(rowset), filter); + auto itr = rows.begin(); + EXPECT_NE(itr, rows.end()) << "No rows were returned in response!"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + EXPECT_EQ(itr, rows.end()); + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, MultiRowWriteAndSingleRowRead) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + ::google::cloud::bigtable::RowSet rowset("r1"); + auto filter = ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(std::move(rowset), filter); + auto itr = rows.begin(); + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, MultiRowWriteAndRead) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + ::google::cloud::bigtable::RowSet rowset("r1", "r2", "r3"); + auto filter = ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1)); + auto rows = table.ReadRows(std::move(rowset), filter); + auto itr = rows.begin(); + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r2"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v2"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r3"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v3"); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, MultiRowWriteAndPrefixRead) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + auto filter = ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1)); + auto rows = + table.ReadRows(::google::cloud::bigtable::RowRange::Prefix("r"), filter); + auto itr = rows.begin(); + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r2"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v2"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r3"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v3"); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, ColumnFiltering) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + // Extra cells + WriteCell("r1", "f2", "c1", "v1", &table); + WriteCell("r2", "f2", "c1", "v2", &table); + WriteCell("r3", "f1", "c2", "v3", &table); + + auto filter = ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1), + ::google::cloud::bigtable::Filter::FamilyRegex("f1"), + ::google::cloud::bigtable::Filter::ColumnRegex("c1")); + auto rows = + table.ReadRows(::google::cloud::bigtable::RowRange::Prefix("r"), filter); + auto itr = rows.begin(); + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v1"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r2"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v2"); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r3"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), "v3"); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, RowKeys) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + + // Extra cells + WriteCell("r1", "f2", "c1", "v1", &table); + WriteCell("r2", "f2", "c1", "v2", &table); + WriteCell("r3", "f1", "c2", "v3", &table); + + auto filter = ::google::cloud::bigtable::Filter::Chain( + ::google::cloud::bigtable::Filter::Latest(1), + ::google::cloud::bigtable::Filter::CellsRowLimit(1), + ::google::cloud::bigtable::Filter::StripValueTransformer()); + auto rows = + table.ReadRows(::google::cloud::bigtable::RowRange::Prefix("r"), filter); + auto itr = rows.begin(); + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r1"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), ""); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r2"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), ""); + + ++itr; + + EXPECT_NE(itr, rows.end()) << "Missing rows"; + EXPECT_EQ(itr->row_key(), "r3"); + EXPECT_EQ(itr->cells().size(), 1); + EXPECT_EQ(itr->cells()[0].family_name(), "f1"); + EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1"); + EXPECT_EQ(itr->cells()[0].value(), ""); + + ++itr; + EXPECT_EQ(itr, rows.end()) << "Extra rows in the response."; + EXPECT_TRUE(rows.Finish().ok()); +} + +TEST(BigtableTestClientTest, SampleKeys) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + WriteCell("r4", "f1", "c1", "v4", &table); + WriteCell("r5", "f1", "c1", "v5", &table); + + grpc::Status status; + auto resp = table.SampleRows(status); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(3, resp.size()); + EXPECT_EQ("r1", string(resp[0].row_key)); + EXPECT_EQ(0, resp[0].offset_bytes); + EXPECT_EQ("r3", string(resp[1].row_key)); + EXPECT_EQ(100, resp[1].offset_bytes); + EXPECT_EQ("r5", string(resp[2].row_key)); + EXPECT_EQ(200, resp[2].offset_bytes); +} + +TEST(BigtableTestClientTest, SampleKeysShort) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + + grpc::Status status; + auto resp = table.SampleRows(status); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(1, resp.size()); + EXPECT_EQ("r1", string(resp[0].row_key)); +} + +TEST(BigtableTestClientTest, SampleKeysEvenNumber) { + std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr = + std::make_shared(); + ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table"); + + WriteCell("r1", "f1", "c1", "v1", &table); + WriteCell("r2", "f1", "c1", "v2", &table); + WriteCell("r3", "f1", "c1", "v3", &table); + WriteCell("r4", "f1", "c1", "v4", &table); + + grpc::Status status; + auto resp = table.SampleRows(status); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(2, resp.size()); + EXPECT_EQ("r1", string(resp[0].row_key)); + EXPECT_EQ("r3", string(resp[1].row_key)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..416b719e30aa5f2504449d151a48e95c9105c68b --- /dev/null +++ b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc @@ -0,0 +1,107 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +// TODO(saeta): Add support for setting ClientOptions values. +REGISTER_OP("BigtableClient") + .Attr("project_id: string") + .Attr("instance_id: string") + .Attr("connection_pool_size: int") + .Attr("max_receive_message_size: int = -1") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Output("client: resource") + .SetShapeFn(shape_inference::ScalarShape); + +// TODO(saeta): Add support for Application Profiles. +// See https://cloud.google.com/bigtable/docs/app-profiles for more info. +REGISTER_OP("BigtableTable") + .Input("client: resource") + .Attr("table_name: string") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Output("table: resource") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("DatasetToBigtable") + .Input("table: resource") + .Input("input_dataset: variant") + .Input("column_families: string") + .Input("columns: string") + .Input("timestamp: int64") + .SetShapeFn(shape_inference::NoOutputs); + +REGISTER_OP("BigtableLookupDataset") + .Input("keys_dataset: variant") + .Input("table: resource") + .Input("column_families: string") + .Input("columns: string") + .Output("handle: variant") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("BigtablePrefixKeyDataset") + .Input("table: resource") + .Input("prefix: string") + .Output("handle: variant") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("BigtableRangeKeyDataset") + .Input("table: resource") + .Input("start_key: string") + .Input("end_key: string") + .Output("handle: variant") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("BigtableSampleKeysDataset") + .Input("table: resource") + .Output("handle: variant") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("BigtableSampleKeyPairsDataset") + .Input("table: resource") + .Input("prefix: string") + .Input("start_key: string") + .Input("end_key: string") + .Output("handle: variant") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape); + +// TODO(saeta): Support continuing despite bad data (e.g. empty string, or +// skip incomplete row.) +REGISTER_OP("BigtableScanDataset") + .Input("table: resource") + .Input("prefix: string") + .Input("start_key: string") + .Input("end_key: string") + .Input("column_families: string") + .Input("columns: string") + .Input("probability: float") + .Output("handle: variant") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape); + +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/ops/bigtable_test_ops.cc b/tensorflow/contrib/bigtable/ops/bigtable_test_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..f7d02458f63d547000f00b184b3d5e3c5007fb72 --- /dev/null +++ b/tensorflow/contrib/bigtable/ops/bigtable_test_ops.cc @@ -0,0 +1,27 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("BigtableTestClient") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Output("client: resource") + .SetShapeFn(shape_inference::ScalarShape); + +} // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/__init__.py b/tensorflow/contrib/bigtable/python/kernel_tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..292d8f4e51abbbd89d68b47febd86b7297bb8ed2 --- /dev/null +++ b/tensorflow/contrib/bigtable/python/kernel_tests/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""This module contains tests for the bigtable integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e36f7f32c61b50047c0d9137427f2a24462b1c9a --- /dev/null +++ b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py @@ -0,0 +1,272 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Bigtable Ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import bigtable +from tensorflow.contrib.bigtable.ops import gen_bigtable_ops +from tensorflow.contrib.bigtable.ops import gen_bigtable_test_ops +from tensorflow.contrib.bigtable.python.ops import bigtable_api +from tensorflow.contrib.util import loader +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test +from tensorflow.python.util import compat + +_bigtable_so = loader.load_op_library( + resource_loader.get_path_to_datafile("_bigtable_test.so")) + + +def _ListOfTuplesOfStringsToBytes(values): + return [(compat.as_bytes(i[0]), compat.as_bytes(i[1])) for i in values] + + +class BigtableOpsTest(test.TestCase): + COMMON_ROW_KEYS = ["r1", "r2", "r3"] + COMMON_VALUES = ["v1", "v2", "v3"] + + def setUp(self): + self._client = gen_bigtable_test_ops.bigtable_test_client() + table = gen_bigtable_ops.bigtable_table(self._client, "testtable") + self._table = bigtable.BigtableTable("testtable", None, table) + + def _makeSimpleDataset(self): + output_rows = dataset_ops.Dataset.from_tensor_slices(self.COMMON_ROW_KEYS) + output_values = dataset_ops.Dataset.from_tensor_slices(self.COMMON_VALUES) + return dataset_ops.Dataset.zip((output_rows, output_values)) + + def _writeCommonValues(self, sess): + output_ds = self._makeSimpleDataset() + write_op = self._table.write(output_ds, ["cf1"], ["c1"]) + sess.run(write_op) + + def runReadKeyTest(self, read_ds): + itr = read_ds.make_initializable_iterator() + n = itr.get_next() + expected = list(self.COMMON_ROW_KEYS) + expected.reverse() + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + for i in range(3): + output = sess.run(n) + want = expected.pop() + self.assertEqual( + compat.as_bytes(want), compat.as_bytes(output), + "Unequal at step %d: want: %s, got: %s" % (i, want, output)) + + def testReadPrefixKeys(self): + self.runReadKeyTest(self._table.keys_by_prefix_dataset("r")) + + def testReadRangeKeys(self): + self.runReadKeyTest(self._table.keys_by_range_dataset("r1", "r4")) + + def runScanTest(self, read_ds): + itr = read_ds.make_initializable_iterator() + n = itr.get_next() + expected_keys = list(self.COMMON_ROW_KEYS) + expected_keys.reverse() + expected_values = list(self.COMMON_VALUES) + expected_values.reverse() + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + for i in range(3): + output = sess.run(n) + want = expected_keys.pop() + self.assertEqual( + compat.as_bytes(want), compat.as_bytes(output[0]), + "Unequal keys at step %d: want: %s, got: %s" % (i, want, output[0])) + want = expected_values.pop() + self.assertEqual( + compat.as_bytes(want), compat.as_bytes(output[1]), + "Unequal values at step: %d: want: %s, got: %s" % (i, want, + output[1])) + + def testScanPrefixStringCol(self): + self.runScanTest(self._table.scan_prefix("r", cf1="c1")) + + def testScanPrefixListCol(self): + self.runScanTest(self._table.scan_prefix("r", cf1=["c1"])) + + def testScanPrefixTupleCol(self): + self.runScanTest(self._table.scan_prefix("r", columns=("cf1", "c1"))) + + def testScanRangeStringCol(self): + self.runScanTest(self._table.scan_range("r1", "r4", cf1="c1")) + + def testScanRangeListCol(self): + self.runScanTest(self._table.scan_range("r1", "r4", cf1=["c1"])) + + def testScanRangeTupleCol(self): + self.runScanTest(self._table.scan_range("r1", "r4", columns=("cf1", "c1"))) + + def testLookup(self): + ds = self._table.keys_by_prefix_dataset("r") + ds = ds.apply(self._table.lookup_columns(cf1="c1")) + itr = ds.make_initializable_iterator() + n = itr.get_next() + expected_keys = list(self.COMMON_ROW_KEYS) + expected_values = list(self.COMMON_VALUES) + expected_tuples = zip(expected_keys, expected_values) + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + for i, elem in enumerate(expected_tuples): + output = sess.run(n) + self.assertEqual( + compat.as_bytes(elem[0]), compat.as_bytes(output[0]), + "Unequal keys at step %d: want: %s, got: %s" % + (i, compat.as_bytes(elem[0]), compat.as_bytes(output[0]))) + self.assertEqual( + compat.as_bytes(elem[1]), compat.as_bytes(output[1]), + "Unequal values at step %d: want: %s, got: %s" % + (i, compat.as_bytes(elem[1]), compat.as_bytes(output[1]))) + + def testSampleKeys(self): + ds = self._table.sample_keys() + itr = ds.make_initializable_iterator() + n = itr.get_next() + expected_key = self.COMMON_ROW_KEYS[0] + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + output = sess.run(n) + self.assertEqual( + compat.as_bytes(self.COMMON_ROW_KEYS[0]), compat.as_bytes(output), + "Unequal keys: want: %s, got: %s" % (compat.as_bytes( + self.COMMON_ROW_KEYS[0]), compat.as_bytes(output))) + output = sess.run(n) + self.assertEqual( + compat.as_bytes(self.COMMON_ROW_KEYS[2]), compat.as_bytes(output), + "Unequal keys: want: %s, got: %s" % (compat.as_bytes( + self.COMMON_ROW_KEYS[2]), compat.as_bytes(output))) + with self.assertRaises(errors.OutOfRangeError): + sess.run(n) + + def runSampleKeyPairsTest(self, ds, expected_key_pairs): + itr = ds.make_initializable_iterator() + n = itr.get_next() + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + for i, elems in enumerate(expected_key_pairs): + output = sess.run(n) + self.assertEqual( + compat.as_bytes(elems[0]), compat.as_bytes(output[0]), + "Unequal key pair (first element) at step %d; want: %s, got %s" % + (i, compat.as_bytes(elems[0]), compat.as_bytes(output[0]))) + self.assertEqual( + compat.as_bytes(elems[1]), compat.as_bytes(output[1]), + "Unequal key pair (second element) at step %d; want: %s, got %s" % + (i, compat.as_bytes(elems[1]), compat.as_bytes(output[1]))) + with self.assertRaises(errors.OutOfRangeError): + sess.run(n) + + def testSampleKeyPairsSimplePrefix(self): + ds = bigtable_api._BigtableSampleKeyPairsDataset( + self._table, prefix="r", start="", end="") + expected_key_pairs = [("r", "r1"), ("r1", "r3"), ("r3", "s")] + self.runSampleKeyPairsTest(ds, expected_key_pairs) + + def testSampleKeyPairsSimpleRange(self): + ds = bigtable_api._BigtableSampleKeyPairsDataset( + self._table, prefix="", start="r1", end="r3") + expected_key_pairs = [("r1", "r3")] + self.runSampleKeyPairsTest(ds, expected_key_pairs) + + def testSampleKeyPairsSkipRangePrefix(self): + ds = bigtable_api._BigtableSampleKeyPairsDataset( + self._table, prefix="r2", start="", end="") + expected_key_pairs = [("r2", "r3")] + self.runSampleKeyPairsTest(ds, expected_key_pairs) + + def testSampleKeyPairsSkipRangeRange(self): + ds = bigtable_api._BigtableSampleKeyPairsDataset( + self._table, prefix="", start="r2", end="r3") + expected_key_pairs = [("r2", "r3")] + self.runSampleKeyPairsTest(ds, expected_key_pairs) + + def testSampleKeyPairsOffsetRanges(self): + ds = bigtable_api._BigtableSampleKeyPairsDataset( + self._table, prefix="", start="r2", end="r4") + expected_key_pairs = [("r2", "r3"), ("r3", "r4")] + self.runSampleKeyPairsTest(ds, expected_key_pairs) + + def testSampleKeyPairEverything(self): + ds = bigtable_api._BigtableSampleKeyPairsDataset( + self._table, prefix="", start="", end="") + expected_key_pairs = [("", "r1"), ("r1", "r3"), ("r3", "")] + self.runSampleKeyPairsTest(ds, expected_key_pairs) + + def testSampleKeyPairsPrefixAndStartKey(self): + ds = bigtable_api._BigtableSampleKeyPairsDataset( + self._table, prefix="r", start="r1", end="") + itr = ds.make_initializable_iterator() + with self.test_session() as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(itr.initializer) + + def testSampleKeyPairsPrefixAndEndKey(self): + ds = bigtable_api._BigtableSampleKeyPairsDataset( + self._table, prefix="r", start="", end="r3") + itr = ds.make_initializable_iterator() + with self.test_session() as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(itr.initializer) + + def testParallelScanPrefix(self): + ds = self._table.parallel_scan_prefix(prefix="r", cf1="c1") + itr = ds.make_initializable_iterator() + n = itr.get_next() + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES)) + actual_values = [] + for _ in range(len(expected_values)): + output = sess.run(n) + actual_values.append(output) + with self.assertRaises(errors.OutOfRangeError): + sess.run(n) + self.assertItemsEqual( + _ListOfTuplesOfStringsToBytes(expected_values), + _ListOfTuplesOfStringsToBytes(actual_values)) + + def testParallelScanRange(self): + ds = self._table.parallel_scan_range(start="r1", end="r4", cf1="c1") + itr = ds.make_initializable_iterator() + n = itr.get_next() + with self.test_session() as sess: + self._writeCommonValues(sess) + sess.run(itr.initializer) + expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES)) + actual_values = [] + for _ in range(len(expected_values)): + output = sess.run(n) + actual_values.append(output) + with self.assertRaises(errors.OutOfRangeError): + sess.run(n) + self.assertItemsEqual( + _ListOfTuplesOfStringsToBytes(expected_values), + _ListOfTuplesOfStringsToBytes(actual_values)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/bigtable/python/ops/__init__.py b/tensorflow/contrib/bigtable/python/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36d75b0d7068a650347a5e17f4727a5432d8752f --- /dev/null +++ b/tensorflow/contrib/bigtable/python/ops/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""This module contains the Python API for the Cloud Bigtable integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py new file mode 100644 index 0000000000000000000000000000000000000000..1102fb3c2dfb9ed71d286dc860a17b2079381eb0 --- /dev/null +++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py @@ -0,0 +1,746 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Python API for TensorFlow's Cloud Bigtable integration. + +TensorFlow has support for reading from and writing to Cloud Bigtable. To use +TensorFlow + Cloud Bigtable integration, first create a BigtableClient to +configure your connection to Cloud Bigtable, and then create a BigtableTable +object to allow you to create numerous @{tf.data.Dataset}s to read data, or +write a @{tf.data.Dataset} object to the underlying Cloud Bigtable table. + +For background on Cloud Bigtable, see: https://cloud.google.com/bigtable . +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six import iteritems +from six import string_types + +from tensorflow.contrib.bigtable.ops import gen_bigtable_ops +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.contrib.util import loader +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.platform import resource_loader + +_bigtable_so = loader.load_op_library( + resource_loader.get_path_to_datafile("_bigtable.so")) + + +class BigtableClient(object): + """BigtableClient is the entrypoint for interacting with Cloud Bigtable in TF. + + BigtableClient encapsulates a connection to Cloud Bigtable, and exposes the + `table` method to open a Bigtable table. + """ + + def __init__(self, + project_id, + instance_id, + connection_pool_size=None, + max_receive_message_size=None): + """Creates a BigtableClient that can be used to open connections to tables. + + Args: + project_id: A string representing the GCP project id to connect to. + instance_id: A string representing the Bigtable instance to connect to. + connection_pool_size: (Optional.) A number representing the number of + concurrent connections to the Cloud Bigtable service to make. + max_receive_message_size: (Optional.) The maximum bytes received in a + single gRPC response. + + Raises: + ValueError: if the arguments are invalid (e.g. wrong type, or out of + expected ranges (e.g. negative).) + """ + if not isinstance(project_id, str): + raise ValueError("`project_id` must be a string") + self._project_id = project_id + + if not isinstance(instance_id, str): + raise ValueError("`instance_id` must be a string") + self._instance_id = instance_id + + if connection_pool_size is None: + connection_pool_size = -1 + elif connection_pool_size < 1: + raise ValueError("`connection_pool_size` must be positive") + + if max_receive_message_size is None: + max_receive_message_size = -1 + elif max_receive_message_size < 1: + raise ValueError("`max_receive_message_size` must be positive") + + self._connection_pool_size = connection_pool_size + + self._resource = gen_bigtable_ops.bigtable_client( + project_id, instance_id, connection_pool_size, max_receive_message_size) + + def table(self, name, snapshot=None): + """Opens a table and returns a `tf.contrib.bigtable.BigtableTable` object. + + Args: + name: A `tf.string` `tf.Tensor` name of the table to open. + snapshot: Either a `tf.string` `tf.Tensor` snapshot id, or `True` to + request the creation of a snapshot. (Note: currently unimplemented.) + + Returns: + A `tf.contrib.bigtable.BigtableTable` Python object representing the + operations available on the table. + """ + # TODO(saeta): Implement snapshot functionality. + table = gen_bigtable_ops.bigtable_table(self._resource, name) + return BigtableTable(name, snapshot, table) + + +class BigtableTable(object): + """BigtableTable is the entrypoint for reading and writing data in Cloud + Bigtable. + + This BigtableTable class is the Python representation of the Cloud Bigtable + table within TensorFlow. Methods on this class allow data to be read from and + written to the Cloud Bigtable service in flexible and high performance + manners. + """ + + # TODO(saeta): Investigate implementing tf.contrib.lookup.LookupInterface. + # TODO(saeta): Consider variant tensors instead of resources (while supporting + # connection pooling). + + def __init__(self, name, snapshot, resource): + self._name = name + self._snapshot = snapshot + self._resource = resource + + def lookup_columns(self, *args, **kwargs): + """Retrieves the values of columns for a dataset of keys. + + Example usage: + + ```python + table = bigtable_client.table("my_table") + key_dataset = table.get_keys_prefix("imagenet") + images = key_dataset.apply(table.lookup_columns(("cf1", "image"), + ("cf2", "label"), + ("cf2", "boundingbox"))) + training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128) + ``` + + Alternatively, you can use keyword arguments to specify the columns to + capture. Example (same as above, rewritten): + + ```python + table = bigtable_client.table("my_table") + key_dataset = table.get_keys_prefix("imagenet") + images = key_dataset.apply(table.lookup_columns( + cf1="image", cf2=("label", "boundingbox"))) + training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128) + ``` + + Note: certain `kwargs` keys are reserved, and thus, some column families + cannot be identified using the `kwargs` syntax. Instead, please use the + `args` syntax. This list includes: + + - 'name' + + Note: this list can change at any time. + + Args: + *args: A list of tuples containing (column family, column name) pairs. + **kwargs: Column families (keys) and column qualifiers (values). + + Returns: + A function that can be passed to `tf.data.Dataset.apply` to retrieve the + values of columns for the rows. + """ + table = self # Capture self + normalized = args + if normalized is None: + normalized = [] + if isinstance(normalized, tuple): + normalized = list(normalized) + for key, value in iteritems(kwargs): + if key == "name": + continue + if isinstance(value, str): + normalized.append((key, value)) + continue + for col in value: + normalized.append((key, col)) + + def _apply_fn(dataset): + # TODO(saeta): Verify dataset's types are correct! + return _BigtableLookupDataset(dataset, table, normalized) + + return _apply_fn + + def keys_by_range_dataset(self, start, end): + """Retrieves all row keys between start and end. + + Note: it does NOT retrieve the values of columns. + + Args: + start: The start row key. The row keys for rows after start (inclusive) + will be retrieved. + end: (Optional.) The end row key. Rows up to (but not including) end will + be retrieved. If end is None, all subsequent row keys will be retrieved. + + Returns: + A @{tf.data.Dataset} containing `tf.string` Tensors corresponding to all + of the row keys between `start` and `end`. + """ + # TODO(saeta): Make inclusive / exclusive configurable? + if end is None: + end = "" + return _BigtableRangeKeyDataset(self, start, end) + + def keys_by_prefix_dataset(self, prefix): + """Retrieves the row keys matching a given prefix. + + Args: + prefix: All row keys that begin with `prefix` in the table will be + retrieved. + + Returns: + A @{tf.data.Dataset}. containing `tf.string` Tensors corresponding to all + of the row keys matching that prefix. + """ + return _BigtablePrefixKeyDataset(self, prefix) + + def sample_keys(self): + """Retrieves a sampling of row keys from the Bigtable table. + + This dataset is most often used in conjunction with + @{tf.contrib.data.parallel_interleave} to construct a set of ranges for + scanning in parallel. + + Returns: + A @{tf.data.Dataset} returning string row keys. + """ + return _BigtableSampleKeysDataset(self) + + def scan_prefix(self, prefix, probability=None, columns=None, **kwargs): + """Retrieves row (including values) from the Bigtable service. + + Rows with row-key prefixed by `prefix` will be retrieved. + + Specifying the columns to retrieve for each row is done by either using + kwargs or in the columns parameter. To retrieve values of the columns "c1", + and "c2" from the column family "cfa", and the value of the column "c3" + from column family "cfb", the following datasets (`ds1`, and `ds2`) are + equivalent: + + ``` + table = # ... + ds1 = table.scan_prefix("row_prefix", columns=[("cfa", "c1"), + ("cfa", "c2"), + ("cfb", "c3")]) + ds2 = table.scan_prefix("row_prefix", cfa=["c1", "c2"], cfb="c3") + ``` + + Note: only the latest value of a cell will be retrieved. + + Args: + prefix: The prefix all row keys must match to be retrieved for prefix- + based scans. + probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive). + A non-1 value indicates to probabilistically sample rows with the + provided probability. + columns: The columns to read. Note: most commonly, they are expressed as + kwargs. Use the columns value if you are using column families that are + reserved. The value of columns and kwargs are merged. Columns is a list + of tuples of strings ("column_family", "column_qualifier"). + **kwargs: The column families and columns to read. Keys are treated as + column_families, and values can be either lists of strings, or strings + that are treated as the column qualifier (column name). + + Returns: + A @{tf.data.Dataset} returning the row keys and the cell contents. + + Raises: + ValueError: If the configured probability is unexpected. + """ + probability = _normalize_probability(probability) + normalized = _normalize_columns(columns, kwargs) + return _BigtableScanDataset(self, prefix, "", "", normalized, probability) + + def scan_range(self, start, end, probability=None, columns=None, **kwargs): + """Retrieves rows (including values) from the Bigtable service. + + Rows with row-keys between `start` and `end` will be retrieved. + + Specifying the columns to retrieve for each row is done by either using + kwargs or in the columns parameter. To retrieve values of the columns "c1", + and "c2" from the column family "cfa", and the value of the column "c3" + from column family "cfb", the following datasets (`ds1`, and `ds2`) are + equivalent: + + ``` + table = # ... + ds1 = table.scan_range("row_start", "row_end", columns=[("cfa", "c1"), + ("cfa", "c2"), + ("cfb", "c3")]) + ds2 = table.scan_range("row_start", "row_end", cfa=["c1", "c2"], cfb="c3") + ``` + + Note: only the latest value of a cell will be retrieved. + + Args: + start: The start of the range when scanning by range. + end: (Optional.) The end of the range when scanning by range. + probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive). + A non-1 value indicates to probabilistically sample rows with the + provided probability. + columns: The columns to read. Note: most commonly, they are expressed as + kwargs. Use the columns value if you are using column families that are + reserved. The value of columns and kwargs are merged. Columns is a list + of tuples of strings ("column_family", "column_qualifier"). + **kwargs: The column families and columns to read. Keys are treated as + column_families, and values can be either lists of strings, or strings + that are treated as the column qualifier (column name). + + Returns: + A @{tf.data.Dataset} returning the row keys and the cell contents. + + Raises: + ValueError: If the configured probability is unexpected. + """ + probability = _normalize_probability(probability) + normalized = _normalize_columns(columns, kwargs) + return _BigtableScanDataset(self, "", start, end, normalized, probability) + + def parallel_scan_prefix(self, + prefix, + num_parallel_scans=None, + probability=None, + columns=None, + **kwargs): + """Retrieves row (including values) from the Bigtable service at high speed. + + Rows with row-key prefixed by `prefix` will be retrieved. This method is + similar to `scan_prefix`, but by contrast performs multiple sub-scans in + parallel in order to achieve higher performance. + + Note: The dataset produced by this method is not deterministic! + + Specifying the columns to retrieve for each row is done by either using + kwargs or in the columns parameter. To retrieve values of the columns "c1", + and "c2" from the column family "cfa", and the value of the column "c3" + from column family "cfb", the following datasets (`ds1`, and `ds2`) are + equivalent: + + ``` + table = # ... + ds1 = table.parallel_scan_prefix("row_prefix", columns=[("cfa", "c1"), + ("cfa", "c2"), + ("cfb", "c3")]) + ds2 = table.parallel_scan_prefix("row_prefix", cfa=["c1", "c2"], cfb="c3") + ``` + + Note: only the latest value of a cell will be retrieved. + + Args: + prefix: The prefix all row keys must match to be retrieved for prefix- + based scans. + num_parallel_scans: (Optional.) The number of concurrent scans against the + Cloud Bigtable instance. + probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive). + A non-1 value indicates to probabilistically sample rows with the + provided probability. + columns: The columns to read. Note: most commonly, they are expressed as + kwargs. Use the columns value if you are using column families that are + reserved. The value of columns and kwargs are merged. Columns is a list + of tuples of strings ("column_family", "column_qualifier"). + **kwargs: The column families and columns to read. Keys are treated as + column_families, and values can be either lists of strings, or strings + that are treated as the column qualifier (column name). + + Returns: + A @{tf.data.Dataset} returning the row keys and the cell contents. + + Raises: + ValueError: If the configured probability is unexpected. + """ + probability = _normalize_probability(probability) + normalized = _normalize_columns(columns, kwargs) + ds = _BigtableSampleKeyPairsDataset(self, prefix, "", "") + return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability, + normalized) + + def parallel_scan_range(self, + start, + end, + num_parallel_scans=None, + probability=None, + columns=None, + **kwargs): + """Retrieves rows (including values) from the Bigtable service. + + Rows with row-keys between `start` and `end` will be retrieved. This method + is similar to `scan_range`, but by contrast performs multiple sub-scans in + parallel in order to achieve higher performance. + + Note: The dataset produced by this method is not deterministic! + + Specifying the columns to retrieve for each row is done by either using + kwargs or in the columns parameter. To retrieve values of the columns "c1", + and "c2" from the column family "cfa", and the value of the column "c3" + from column family "cfb", the following datasets (`ds1`, and `ds2`) are + equivalent: + + ``` + table = # ... + ds1 = table.parallel_scan_range("row_start", + "row_end", + columns=[("cfa", "c1"), + ("cfa", "c2"), + ("cfb", "c3")]) + ds2 = table.parallel_scan_range("row_start", "row_end", + cfa=["c1", "c2"], cfb="c3") + ``` + + Note: only the latest value of a cell will be retrieved. + + Args: + start: The start of the range when scanning by range. + end: (Optional.) The end of the range when scanning by range. + num_parallel_scans: (Optional.) The number of concurrent scans against the + Cloud Bigtable instance. + probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive). + A non-1 value indicates to probabilistically sample rows with the + provided probability. + columns: The columns to read. Note: most commonly, they are expressed as + kwargs. Use the columns value if you are using column families that are + reserved. The value of columns and kwargs are merged. Columns is a list + of tuples of strings ("column_family", "column_qualifier"). + **kwargs: The column families and columns to read. Keys are treated as + column_families, and values can be either lists of strings, or strings + that are treated as the column qualifier (column name). + + Returns: + A @{tf.data.Dataset} returning the row keys and the cell contents. + + Raises: + ValueError: If the configured probability is unexpected. + """ + probability = _normalize_probability(probability) + normalized = _normalize_columns(columns, kwargs) + ds = _BigtableSampleKeyPairsDataset(self, "", start, end) + return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability, + normalized) + + def write(self, dataset, column_families, columns, timestamp=None): + """Writes a dataset to the table. + + Args: + dataset: A @{tf.data.Dataset} to be written to this table. It must produce + a list of number-of-columns+1 elements, all of which must be strings. + The first value will be used as the row key, and subsequent values will + be used as cell values for the corresponding columns from the + corresponding column_families and columns entries. + column_families: A @{tf.Tensor} of `tf.string`s corresponding to the + column names to store the dataset's elements into. + columns: A `tf.Tensor` of `tf.string`s corresponding to the column names + to store the dataset's elements into. + timestamp: (Optional.) An int64 timestamp to write all the values at. + Leave as None to use server-provided timestamps. + + Returns: + A @{tf.Operation} that can be run to perform the write. + + Raises: + ValueError: If there are unexpected or incompatible types, or if the + number of columns and column_families does not match the output of + `dataset`. + """ + if timestamp is None: + timestamp = -1 # Bigtable server provided timestamp. + for tensor_type in nest.flatten(dataset.output_types): + if tensor_type != dtypes.string: + raise ValueError("Not all elements of the dataset were `tf.string`") + for shape in nest.flatten(dataset.output_shapes): + if not shape.is_compatible_with(tensor_shape.scalar()): + raise ValueError("Not all elements of the dataset were scalars") + if len(column_families) != len(columns): + raise ValueError("len(column_families) != len(columns)") + if len(nest.flatten(dataset.output_types)) != len(columns) + 1: + raise ValueError("A column name must be specified for every component of " + "the dataset elements. (e.g.: len(columns) != " + "len(dataset.output_types))") + return gen_bigtable_ops.dataset_to_bigtable( + self._resource, + dataset._as_variant_tensor(), # pylint: disable=protected-access + column_families, + columns, + timestamp) + + def _make_parallel_scan_dataset(self, ds, num_parallel_scans, + normalized_probability, normalized_columns): + """Builds a parallel dataset from a given range. + + Args: + ds: A `_BigtableSampleKeyPairsDataset` returning ranges of keys to use. + num_parallel_scans: The number of concurrent parallel scans to use. + normalized_probability: A number between 0 and 1 for the keep probability. + normalized_columns: The column families and column qualifiers to retrieve. + + Returns: + A @{tf.data.Dataset} representing the result of the parallel scan. + """ + if num_parallel_scans is None: + num_parallel_scans = 50 + + ds = ds.shuffle(buffer_size=10000) # TODO(saeta): Make configurable. + + def _interleave_fn(start, end): + return _BigtableScanDataset( + self, + prefix="", + start=start, + end=end, + normalized=normalized_columns, + probability=normalized_probability) + + # Note prefetch_input_elements must be set in order to avoid rpc timeouts. + ds = ds.apply( + interleave_ops.parallel_interleave( + _interleave_fn, + cycle_length=num_parallel_scans, + sloppy=True, + prefetch_input_elements=1)) + return ds + + +def _normalize_probability(probability): + if probability is None: + probability = 1.0 + if isinstance(probability, float) and (probability <= 0.0 or + probability > 1.0): + raise ValueError("probability must be in the range (0, 1].") + return probability + + +def _normalize_columns(columns, provided_kwargs): + """Converts arguments (columns, and kwargs dict) to C++ representation. + + Args: + columns: a datastructure containing the column families and qualifier to + retrieve. Valid types include (1) None, (2) list of tuples, (3) a tuple of + strings. + provided_kwargs: a dictionary containing the column families and qualifiers + to retrieve + + Returns: + A list of pairs of column family+qualifier to retrieve. + + Raises: + ValueError: If there are no cells to retrieve or the columns are in an + incorrect format. + """ + normalized = columns + if normalized is None: + normalized = [] + if isinstance(normalized, tuple): + if len(normalized) == 2: + normalized = [normalized] + else: + raise ValueError("columns was a tuple of inappropriate length") + for key, value in iteritems(provided_kwargs): + if key == "name": + continue + if isinstance(value, string_types): + normalized.append((key, value)) + continue + for col in value: + normalized.append((key, col)) + if not normalized: + raise ValueError("At least one column + column family must be specified.") + return normalized + + +class _BigtableKeyDataset(dataset_ops.Dataset): + """_BigtableKeyDataset is an abstract class representing the keys of a table. + """ + + def __init__(self, table): + """Constructs a _BigtableKeyDataset. + + Args: + table: a Bigtable class. + """ + super(_BigtableKeyDataset, self).__init__() + self._table = table + + @property + def output_classes(self): + return ops.Tensor + + @property + def output_shapes(self): + return tensor_shape.TensorShape([]) + + @property + def output_types(self): + return dtypes.string + + +class _BigtablePrefixKeyDataset(_BigtableKeyDataset): + """_BigtablePrefixKeyDataset represents looking up keys by prefix. + """ + + def __init__(self, table, prefix): + super(_BigtablePrefixKeyDataset, self).__init__(table) + self._prefix = prefix + + def _as_variant_tensor(self): + return gen_bigtable_ops.bigtable_prefix_key_dataset( + table=self._table._resource, # pylint: disable=protected-access + prefix=self._prefix) + + +class _BigtableRangeKeyDataset(_BigtableKeyDataset): + """_BigtableRangeKeyDataset represents looking up keys by range. + """ + + def __init__(self, table, start, end): + super(_BigtableRangeKeyDataset, self).__init__(table) + self._start = start + self._end = end + + def _as_variant_tensor(self): + return gen_bigtable_ops.bigtable_range_key_dataset( + table=self._table._resource, # pylint: disable=protected-access + start_key=self._start, + end_key=self._end) + + +class _BigtableSampleKeysDataset(_BigtableKeyDataset): + """_BigtableSampleKeysDataset represents a sampling of row keys. + """ + + # TODO(saeta): Expose the data size offsets into the keys. + + def __init__(self, table): + super(_BigtableSampleKeysDataset, self).__init__(table) + + def _as_variant_tensor(self): + return gen_bigtable_ops.bigtable_sample_keys_dataset( + table=self._table._resource) # pylint: disable=protected-access + + +class _BigtableLookupDataset(dataset_ops.Dataset): + """_BigtableLookupDataset represents a dataset that retrieves values for keys. + """ + + def __init__(self, dataset, table, normalized): + self._num_outputs = len(normalized) + 1 # 1 for row key + self._dataset = dataset + self._table = table + self._normalized = normalized + self._column_families = [i[0] for i in normalized] + self._columns = [i[1] for i in normalized] + + @property + def output_classes(self): + return tuple([ops.Tensor] * self._num_outputs) + + @property + def output_shapes(self): + return tuple([tensor_shape.TensorShape([])] * self._num_outputs) + + @property + def output_types(self): + return tuple([dtypes.string] * self._num_outputs) + + def _as_variant_tensor(self): + # pylint: disable=protected-access + return gen_bigtable_ops.bigtable_lookup_dataset( + keys_dataset=self._dataset._as_variant_tensor(), + table=self._table._resource, + column_families=self._column_families, + columns=self._columns) + + +class _BigtableScanDataset(dataset_ops.Dataset): + """_BigtableScanDataset represents a dataset that retrieves keys and values. + """ + + def __init__(self, table, prefix, start, end, normalized, probability): + self._table = table + self._prefix = prefix + self._start = start + self._end = end + self._column_families = [i[0] for i in normalized] + self._columns = [i[1] for i in normalized] + self._probability = probability + self._num_outputs = len(normalized) + 1 # 1 for row key + + @property + def output_classes(self): + return tuple([ops.Tensor] * self._num_outputs) + + @property + def output_shapes(self): + return tuple([tensor_shape.TensorShape([])] * self._num_outputs) + + @property + def output_types(self): + return tuple([dtypes.string] * self._num_outputs) + + def _as_variant_tensor(self): + return gen_bigtable_ops.bigtable_scan_dataset( + table=self._table._resource, # pylint: disable=protected-access + prefix=self._prefix, + start_key=self._start, + end_key=self._end, + column_families=self._column_families, + columns=self._columns, + probability=self._probability) + + +class _BigtableSampleKeyPairsDataset(dataset_ops.Dataset): + """_BigtableSampleKeyPairsDataset returns key pairs from a Bigtable table. + """ + + def __init__(self, table, prefix, start, end): + self._table = table + self._prefix = prefix + self._start = start + self._end = end + + @property + def output_classes(self): + return (ops.Tensor, ops.Tensor) + + @property + def output_shapes(self): + return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) + + @property + def output_types(self): + return (dtypes.string, dtypes.string) + + def _as_variant_tensor(self): + # pylint: disable=protected-access + return gen_bigtable_ops.bigtable_sample_key_pairs_dataset( + table=self._table._resource, + prefix=self._prefix, + start_key=self._start, + end_key=self._end) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index ef0e80cd0997bc0e95cd0d150e87db144a2dde44..5fcb19a47aac492d49b0d8e99af5699bae2ad9f0 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -147,6 +147,7 @@ py_library( deps = [ ":distillation_loss", ":estimator_utils", + ":model", ":trainer_hooks", "//tensorflow/contrib/boosted_trees:gbdt_batch", "//tensorflow/contrib/boosted_trees:model_ops_py", @@ -190,7 +191,7 @@ py_test( py_test( name = "estimator_test", - size = "medium", + size = "large", srcs = ["estimator_test.py"], srcs_version = "PY2AND3", tags = [ diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index 62f1f4122b05b56a708823df4246d618bd3fa5d4..78232fa0a6e2311c13d4f35acffc3486a9a28803 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import ops from tensorflow.python.platform import gfile from tensorflow.python.saved_model import loader as saved_model_loader from tensorflow.python.saved_model import tag_constants +from tensorflow.python.util import compat _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE = "%s_%d" @@ -88,10 +89,12 @@ def make_custom_export_strategy(name, len(sparse_float_indices), len(sparse_int_indices)) sorted_by_importance = sorted( feature_importances.items(), key=lambda x: -x[1]) - assets_dir = os.path.join(result_dir, "assets.extra") + assets_dir = os.path.join( + compat.as_bytes(result_dir), compat.as_bytes("assets.extra")) gfile.MakeDirs(assets_dir) - with gfile.GFile(os.path.join(assets_dir, "feature_importances"), - "w") as f: + with gfile.GFile(os.path.join( + compat.as_bytes(assets_dir), + compat.as_bytes("feature_importances")), "w") as f: f.write("\n".join("%s, %f" % (k, v) for k, v in sorted_by_importance)) return result_dir diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index 7eb429b636a5193a124dd9b0c020dae6cac910cb..194a5c8754cb0ab2db299e3fb5c998c0f27f8435 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -26,6 +26,7 @@ from __future__ import print_function import six from tensorflow.contrib import layers +from tensorflow.contrib.boosted_trees.estimator_batch import model from tensorflow.contrib.boosted_trees.estimator_batch import distillation_loss from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks @@ -34,6 +35,7 @@ from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batc from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib +from tensorflow.python.estimator import estimator as core_estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.python.feature_column import feature_column as feature_column_lib from tensorflow.python.framework import ops @@ -62,27 +64,30 @@ def _add_hidden_layer_summary(value, tag): summary.histogram("%s_activation" % tag, value) -def _dnn_tree_combined_model_fn(features, - labels, - mode, - head, - dnn_hidden_units, - dnn_feature_columns, - tree_learner_config, - num_trees, - tree_examples_per_layer, - config=None, - dnn_optimizer="Adagrad", - dnn_activation_fn=nn.relu, - dnn_dropout=None, - dnn_input_layer_partitioner=None, - dnn_input_layer_to_tree=True, - dnn_steps_to_train=10000, - predict_with_tree_only=False, - tree_feature_columns=None, - tree_center_bias=False, - dnn_to_tree_distillation_param=None, - use_core_versions=False): +def _dnn_tree_combined_model_fn( + features, + labels, + mode, + head, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + config=None, + dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, + dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=10000, + predict_with_tree_only=False, + tree_feature_columns=None, + tree_center_bias=False, + dnn_to_tree_distillation_param=None, + use_core_versions=False, + output_type=model.ModelBuilderOutputType.MODEL_FN_OPS, + override_global_step_value=None): """DNN and GBDT combined model_fn. Args: @@ -131,6 +136,12 @@ def _dnn_tree_combined_model_fn(features, will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec + (new interface). + override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. Returns: A `ModelFnOps` object. @@ -156,6 +167,10 @@ def _dnn_tree_combined_model_fn(features, partitioned_variables.min_max_variable_partitioner( max_partitions=config.num_ps_replicas, min_slice_size=64 << 20)) + if (output_type == model.ModelBuilderOutputType.ESTIMATOR_SPEC and + not use_core_versions): + raise ValueError("You must use core versions with Estimator Spec") + with variable_scope.variable_scope( dnn_parent_scope, values=tuple(six.itervalues(features)), @@ -235,7 +250,8 @@ def _dnn_tree_combined_model_fn(features, learner_config=tree_learner_config, feature_columns=tree_feature_columns, logits_dimension=head.logits_dimension, - features=tree_features) + features=tree_features, + use_core_columns=use_core_versions) with ops.name_scope("gbdt"): predictions_dict = gbdt_model.predict(mode) @@ -284,63 +300,98 @@ def _dnn_tree_combined_model_fn(features, del loss return control_flow_ops.no_op() - if use_core_versions: - model_fn_ops = head.create_estimator_spec( - features=features, - mode=mode, - labels=labels, - train_op_fn=_no_train_op_fn, - logits=tree_train_logits) - dnn_train_op = head.create_estimator_spec( - features=features, - mode=mode, - labels=labels, - train_op_fn=_dnn_train_op_fn, - logits=dnn_logits) - dnn_train_op = estimator_utils.estimator_spec_to_model_fn_ops( - dnn_train_op).train_op + if tree_center_bias: + num_trees += 1 + finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() - tree_train_op = head.create_estimator_spec( - features=tree_features, - mode=mode, - labels=labels, - train_op_fn=_tree_train_op_fn, - logits=tree_train_logits) - tree_train_op = estimator_utils.estimator_spec_to_model_fn_ops( - tree_train_op).train_op + if output_type == model.ModelBuilderOutputType.MODEL_FN_OPS: + if use_core_versions: + model_fn_ops = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_no_train_op_fn, + logits=tree_train_logits) + dnn_train_op = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_dnn_train_op_fn, + logits=dnn_logits) + dnn_train_op = estimator_utils.estimator_spec_to_model_fn_ops( + dnn_train_op).train_op - model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops) - else: - model_fn_ops = head.create_model_fn_ops( + tree_train_op = head.create_estimator_spec( + features=tree_features, + mode=mode, + labels=labels, + train_op_fn=_tree_train_op_fn, + logits=tree_train_logits) + tree_train_op = estimator_utils.estimator_spec_to_model_fn_ops( + tree_train_op).train_op + + model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops( + model_fn_ops) + else: + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_no_train_op_fn, + logits=tree_train_logits) + dnn_train_op = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_dnn_train_op_fn, + logits=dnn_logits).train_op + tree_train_op = head.create_model_fn_ops( + features=tree_features, + mode=mode, + labels=labels, + train_op_fn=_tree_train_op_fn, + logits=tree_train_logits).train_op + + # Add the hooks + model_fn_ops.training_hooks.extend([ + trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train, + tree_train_op), + trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, + finalized_trees, + override_global_step_value) + ]) + return model_fn_ops + + elif output_type == model.ModelBuilderOutputType.ESTIMATOR_SPEC: + fusion_spec = head.create_estimator_spec( features=features, mode=mode, labels=labels, train_op_fn=_no_train_op_fn, logits=tree_train_logits) - dnn_train_op = head.create_model_fn_ops( + dnn_spec = head.create_estimator_spec( features=features, mode=mode, labels=labels, train_op_fn=_dnn_train_op_fn, - logits=dnn_logits).train_op - tree_train_op = head.create_model_fn_ops( + logits=dnn_logits) + tree_spec = head.create_estimator_spec( features=tree_features, mode=mode, labels=labels, train_op_fn=_tree_train_op_fn, - logits=tree_train_logits).train_op - - if tree_center_bias: - num_trees += 1 - finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() - - model_fn_ops.training_hooks.extend([ - trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train, - tree_train_op), - trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, finalized_trees) - ]) + logits=tree_train_logits) - return model_fn_ops + training_hooks = [ + trainer_hooks.SwitchTrainOp(dnn_spec.train_op, dnn_steps_to_train, + tree_spec.train_op), + trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, + finalized_trees, + override_global_step_value) + ] + fusion_spec = fusion_spec._replace(training_hooks=training_hooks + + list(fusion_spec.training_hooks)) + return fusion_spec class DNNBoostedTreeCombinedClassifier(estimator.Estimator): @@ -369,7 +420,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): tree_feature_columns=None, tree_center_bias=False, dnn_to_tree_distillation_param=None, - use_core_versions=False): + use_core_versions=False, + override_global_step_value=None): """Initializes a DNNBoostedTreeCombinedClassifier instance. Args: @@ -425,6 +477,10 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. """ head = head_lib.multi_class_head( n_classes=n_classes, @@ -455,7 +511,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): tree_feature_columns=tree_feature_columns, tree_center_bias=tree_center_bias, dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, - use_core_versions=use_core_versions) + use_core_versions=use_core_versions, + override_global_step_value=override_global_step_value) super(DNNBoostedTreeCombinedClassifier, self).__init__( model_fn=_model_fn, @@ -489,7 +546,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): tree_feature_columns=None, tree_center_bias=False, dnn_to_tree_distillation_param=None, - use_core_versions=False): + use_core_versions=False, + override_global_step_value=None): """Initializes a DNNBoostedTreeCombinedRegressor instance. Args: @@ -545,6 +603,10 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. """ head = head_lib.regression_head( label_name=label_name, @@ -580,7 +642,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): tree_feature_columns=tree_feature_columns, tree_center_bias=tree_center_bias, dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, - use_core_versions=use_core_versions) + use_core_versions=use_core_versions, + override_global_step_value=override_global_step_value) super(DNNBoostedTreeCombinedRegressor, self).__init__( model_fn=_model_fn, @@ -615,7 +678,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): tree_feature_columns=None, tree_center_bias=False, dnn_to_tree_distillation_param=None, - use_core_versions=False): + use_core_versions=False, + override_global_step_value=None): """Initializes a DNNBoostedTreeCombinedEstimator instance. Args: @@ -666,6 +730,10 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): will be set to True. use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. """ def _model_fn(features, labels, mode, config): @@ -690,10 +758,109 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): tree_feature_columns=tree_feature_columns, tree_center_bias=tree_center_bias, dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, - use_core_versions=use_core_versions) + use_core_versions=use_core_versions, + override_global_step_value=override_global_step_value) super(DNNBoostedTreeCombinedEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config, feature_engineering_fn=feature_engineering_fn) + + +class CoreDNNBoostedTreeCombinedEstimator(core_estimator.Estimator): + """Initializes a core version of DNNBoostedTreeCombinedEstimator. + + Args: + dnn_hidden_units: List of hidden units per layer for DNN. + dnn_feature_columns: An iterable containing all the feature columns + used by the model's DNN. + tree_learner_config: A config for the tree learner. + num_trees: Number of trees to grow model to after training DNN. + tree_examples_per_layer: Number of examples to accumulate before + growing the tree a layer. This value has a big impact on model + quality and should be set equal to the number of examples in + training dataset if possible. It can also be a function that computes + the number of examples based on the depth of the layer that's + being built. + head: `Head` instance. + model_dir: Directory for model exports. + config: `RunConfig` of the estimator. + dnn_optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training the DNN. If `None`, will use the Adagrad + optimizer with default learning rate. + dnn_activation_fn: Activation function applied to each layer of the DNN. + If `None`, will use `tf.nn.relu`. + dnn_dropout: When not `None`, the probability to drop out a given + unit in the DNN. + dnn_input_layer_partitioner: Partitioner for input layer of the DNN. + Defaults to `min_max_variable_partitioner` with `min_slice_size` + 64 << 20. + dnn_input_layer_to_tree: Whether to provide the DNN's input layer + as a feature to the tree. + dnn_steps_to_train: Number of steps to train dnn for before switching + to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. + tree_feature_columns: An iterable containing all the feature columns + used by the model's boosted trees. If dnn_input_layer_to_tree is + set to True, these features are in addition to dnn_feature_columns. + tree_center_bias: Whether a separate tree should be created for + first fitting the bias. + dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the + float defines the weight of the distillation loss, and the loss_fn, for + computing distillation loss, takes dnn_logits, tree_logits and weight + tensor. If the entire tuple is None, no distillation will be applied. If + only the loss_fn is None, we will take the sigmoid/softmax cross entropy + loss be default. When distillation is applied, `predict_with_tree_only` + will be set to True. + """ + + def __init__(self, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + head, + model_dir=None, + config=None, + dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, + dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=10000, + predict_with_tree_only=False, + tree_feature_columns=None, + tree_center_bias=False, + dnn_to_tree_distillation_param=None): + + def _model_fn(features, labels, mode, config): + return _dnn_tree_combined_model_fn( + features=features, + labels=labels, + mode=mode, + head=head, + dnn_hidden_units=dnn_hidden_units, + dnn_feature_columns=dnn_feature_columns, + tree_learner_config=tree_learner_config, + num_trees=num_trees, + tree_examples_per_layer=tree_examples_per_layer, + config=config, + dnn_optimizer=dnn_optimizer, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + dnn_input_layer_partitioner=dnn_input_layer_partitioner, + dnn_input_layer_to_tree=dnn_input_layer_to_tree, + dnn_steps_to_train=dnn_steps_to_train, + predict_with_tree_only=predict_with_tree_only, + tree_feature_columns=tree_feature_columns, + tree_center_bias=tree_center_bias, + dnn_to_tree_distillation_param=dnn_to_tree_distillation_param, + output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC, + use_core_versions=True, + override_global_step_value=None) + + super(CoreDNNBoostedTreeCombinedEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py index 9b7acfa664b0398216b5a7fb904960d8363929d6..839eedd3a87ccaa1faecd1966fe5907d682cac02 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py @@ -28,10 +28,11 @@ from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops.losses import losses from tensorflow.python.platform import googletest - +from tensorflow.python.training import checkpoint_utils def _train_input_fn(): features = { @@ -156,5 +157,72 @@ class DNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): classifier.evaluate(input_fn=_eval_input_fn, steps=1) +class CoreDNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): + + def _assert_checkpoint(self, model_dir, global_step): + reader = checkpoint_utils.load_checkpoint(model_dir) + self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)) + + def testTrainEvaluateInferDoesNotThrowErrorWithNoDnnInput(self): + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + est = estimator.CoreDNNBoostedTreeCombinedEstimator( + head=head_fn, + dnn_hidden_units=[1], + dnn_feature_columns=[core_feature_column.numeric_column("x")], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + model_dir=model_dir, + config=config, + dnn_steps_to_train=10, + dnn_input_layer_to_tree=False, + tree_feature_columns=[core_feature_column.numeric_column("x")]) + + # Train for a few steps. + est.train(input_fn=_train_input_fn, steps=1000) + # 10 steps for dnn, 3 for 1 tree of depth 3 + 1 after the tree finished + self._assert_checkpoint(est.model_dir, global_step=14) + res = est.evaluate(input_fn=_eval_input_fn, steps=1) + self.assertLess(0.5, res["auc"]) + est.predict(input_fn=_eval_input_fn) + + def testTrainEvaluateInferDoesNotThrowErrorWithDnnInput(self): + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + est = estimator.CoreDNNBoostedTreeCombinedEstimator( + head=head_fn, + dnn_hidden_units=[1], + dnn_feature_columns=[core_feature_column.numeric_column("x")], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + model_dir=model_dir, + config=config, + dnn_steps_to_train=10, + dnn_input_layer_to_tree=True, + tree_feature_columns=[]) + + # Train for a few steps. + est.train(input_fn=_train_input_fn, steps=1000) + res = est.evaluate(input_fn=_eval_input_fn, steps=1) + self.assertLess(0.5, res["auc"]) + est.predict(input_fn=_eval_input_fn) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index 9c36c302210185bc390751a0229a61f2f8cd91b8..870ce2442bb5e98db7615c43054c9c827b8c88f0 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -22,7 +22,16 @@ from tensorflow.contrib.boosted_trees.estimator_batch import model from tensorflow.contrib.boosted_trees.python.utils import losses from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib +from tensorflow.python.estimator.canned import head as core_head_lib +from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.ops import math_ops +from tensorflow.python.ops.losses import losses as core_losses + + +# ================== Old estimator interface=================================== +# The estimators below were designed for old feature columns and old estimator +# interface. They can be used with new feature columns and losses by setting +# use_core_libs = True. class GradientBoostedDecisionTreeClassifier(estimator.Estimator): @@ -42,7 +51,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): logits_modifier_function=None, center_bias=True, use_core_libs=False, - output_leaf_index=False): + output_leaf_index=False, + override_global_step_value=None): """Initializes a GradientBoostedDecisionTreeClassifier estimator instance. Args: @@ -76,6 +86,14 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): for result_dict in result_iter: # access leaf index list by result_dict["leaf_index"] # which contains one leaf index per tree + override_global_step_value: If after the training is done, global step + value must be reset to this value. This should be used to reset global + step to a number > number of steps used to train the current ensemble. + For example, the usual way is to train a number of trees and set a very + large number of training steps. When the training is done (number of + trees were trained), this parameter can be used to set the global step + to a large value, making it look like that number of training steps ran. + If None, no override of global step will happen. Raises: ValueError: If learner_config is not valid. @@ -116,6 +134,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): 'logits_modifier_function': logits_modifier_function, 'use_core_libs': use_core_libs, 'output_leaf_index': output_leaf_index, + 'override_global_step_value': override_global_step_value }, model_dir=model_dir, config=config, @@ -139,7 +158,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): logits_modifier_function=None, center_bias=True, use_core_libs=False, - output_leaf_index=False): + output_leaf_index=False, + override_global_step_value=None): """Initializes a GradientBoostedDecisionTreeRegressor estimator instance. Args: @@ -173,6 +193,14 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): for example_prediction_result in result_dict: # access leaf index list by example_prediction_result["leaf_index"] # which contains one leaf index per tree + override_global_step_value: If after the training is done, global step + value must be reset to this value. This should be used to reset global + step to a number > number of steps used to train the current ensemble. + For example, the usual way is to train a number of trees and set a very + large number of training steps. When the training is done (number of + trees were trained), this parameter can be used to set the global step + to a large value, making it look like that number of training steps ran. + If None, no override of global step will happen. """ head = head_lib.regression_head( label_name=label_name, @@ -196,6 +224,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): 'center_bias': center_bias, 'use_core_libs': use_core_libs, 'output_leaf_index': False, + 'override_global_step_value': override_global_step_value }, model_dir=model_dir, config=config, @@ -221,7 +250,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): logits_modifier_function=None, center_bias=True, use_core_libs=False, - output_leaf_index=False): + output_leaf_index=False, + override_global_step_value=None): """Initializes a GradientBoostedDecisionTreeEstimator estimator instance. Args: @@ -251,6 +281,14 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): for example_prediction_result in result_dict: # access leaf index list by example_prediction_result["leaf_index"] # which contains one leaf index per tree + override_global_step_value: If after the training is done, global step + value must be reset to this value. This should be used to reset global + step to a number > number of steps used to train the current ensemble. + For example, the usual way is to train a number of trees and set a very + large number of training steps. When the training is done (number of + trees were trained), this parameter can be used to set the global step + to a large value, making it look like that number of training steps ran. + If None, no override of global step will happen. """ super(GradientBoostedDecisionTreeEstimator, self).__init__( model_fn=model.model_builder, @@ -265,7 +303,282 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): 'center_bias': center_bias, 'use_core_libs': use_core_libs, 'output_leaf_index': False, + 'override_global_step_value': override_global_step_value + }, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) + + +class GradientBoostedDecisionTreeRanker(estimator.Estimator): + """A ranking estimator using gradient boosted decision trees.""" + + def __init__(self, + learner_config, + examples_per_layer, + head, + ranking_model_pair_keys, + num_trees=None, + feature_columns=None, + weight_column_name=None, + model_dir=None, + config=None, + label_keys=None, + feature_engineering_fn=None, + logits_modifier_function=None, + center_bias=False, + use_core_libs=False, + output_leaf_index=False, + override_global_step_value=None): + """Initializes a GradientBoostedDecisionTreeRanker instance. + + This is an estimator that can be trained off the pairwise data and can be + used for inference on non-paired data. This is essentially LambdaMart. + Args: + learner_config: A config for the learner. + examples_per_layer: Number of examples to accumulate before growing a + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. + head: `Head` instance. + ranking_model_pair_keys: Keys to distinguish between features + for left and right part of the training pairs for ranking. For example, + for an Example with features "a.f1" and "b.f1", the keys would be + ("a", "b"). + num_trees: An int, number of trees to build. + feature_columns: A list of feature columns. + weight_column_name: Name of the column for weights, or None if not + weighted. + model_dir: Directory for model exports, etc. + config: `RunConfig` object to configure the runtime settings. + label_keys: Optional list of strings with size `[n_classes]` defining the + label vocabulary. Only supported for `n_classes` > 2. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + logits_modifier_function: A modifier function for the logits. + center_bias: Whether a separate tree should be created for first fitting + the bias. + use_core_libs: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is + [batch_size, num_trees]. + For example, + result_iter = classifier.predict(...) + for result_dict in result_iter: + # access leaf index list by result_dict["leaf_index"] + # which contains one leaf index per tree + override_global_step_value: If after the training is done, global step + value must be reset to this value. This should be used to reset global + step to a number > number of steps used to train the current ensemble. + For example, the usual way is to train a number of trees and set a very + large number of training steps. When the training is done (number of + trees were trained), this parameter can be used to set the global step + to a large value, making it look like that number of training steps ran. + If None, no override of global step will happen. + Raises: + ValueError: If learner_config is not valid. + """ + super(GradientBoostedDecisionTreeRanker, self).__init__( + model_fn=model.ranking_model_builder, + params={ + 'head': head, + 'n_classes': 2, + 'feature_columns': feature_columns, + 'learner_config': learner_config, + 'num_trees': num_trees, + 'weight_column_name': weight_column_name, + 'examples_per_layer': examples_per_layer, + 'center_bias': center_bias, + 'logits_modifier_function': logits_modifier_function, + 'use_core_libs': use_core_libs, + 'output_leaf_index': output_leaf_index, + 'ranking_model_pair_keys': ranking_model_pair_keys, + 'override_global_step_value': override_global_step_value }, model_dir=model_dir, config=config, feature_engineering_fn=feature_engineering_fn) + +# ================== New Estimator interface=================================== +# The estimators below use new core Estimator interface and must be used with +# new feature columns and heads. + +# For multiclass classification, use the following head since it uses loss +# that is twice differentiable. +def core_multiclass_head(n_classes): + """Core head for multiclass problems.""" + + def loss_fn(labels, logits): + result = losses.per_example_maxent_loss( + labels=labels, logits=logits, weights=None, num_classes=n_classes) + return result[0] + + # pylint:disable=protected-access + head_fn = core_head_lib._multi_class_head_with_softmax_cross_entropy_loss( + n_classes=n_classes, + loss_fn=loss_fn, + loss_reduction=core_losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + # pylint:enable=protected-access + + return head_fn + + +class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): + """An estimator using gradient boosted decision trees. + + Useful for training with user specified `Head`. + """ + + def __init__(self, + learner_config, + examples_per_layer, + head, + num_trees=None, + feature_columns=None, + weight_column_name=None, + model_dir=None, + config=None, + label_keys=None, + feature_engineering_fn=None, + logits_modifier_function=None, + center_bias=True, + output_leaf_index=False): + """Initializes a core version of GradientBoostedDecisionTreeEstimator. + + Args: + learner_config: A config for the learner. + examples_per_layer: Number of examples to accumulate before growing a + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. + head: `Head` instance. + num_trees: An int, number of trees to build. + feature_columns: A list of feature columns. + weight_column_name: Name of the column for weights, or None if not + weighted. + model_dir: Directory for model exports, etc. + config: `RunConfig` object to configure the runtime settings. + label_keys: Optional list of strings with size `[n_classes]` defining the + label vocabulary. Only supported for `n_classes` > 2. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + logits_modifier_function: A modifier function for the logits. + center_bias: Whether a separate tree should be created for first fitting + the bias. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. For example, + result_dict = classifier.predict(...) + for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree + """ + + def _model_fn(features, labels, mode, config): + return model.model_builder( + features=features, + labels=labels, + mode=mode, + config=config, + params={ + 'head': head, + 'feature_columns': feature_columns, + 'learner_config': learner_config, + 'num_trees': num_trees, + 'weight_column_name': weight_column_name, + 'examples_per_layer': examples_per_layer, + 'center_bias': center_bias, + 'logits_modifier_function': logits_modifier_function, + 'use_core_libs': True, + 'output_leaf_index': output_leaf_index, + 'override_global_step_value': None + }, + output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC) + + super(CoreGradientBoostedDecisionTreeEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) + + +class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): + """A ranking estimator using gradient boosted decision trees.""" + + def __init__(self, + learner_config, + examples_per_layer, + head, + ranking_model_pair_keys, + num_trees=None, + feature_columns=None, + weight_column_name=None, + model_dir=None, + config=None, + label_keys=None, + logits_modifier_function=None, + center_bias=False, + output_leaf_index=False): + """Initializes a GradientBoostedDecisionTreeRanker instance. + + This is an estimator that can be trained off the pairwise data and can be + used for inference on non-paired data. This is essentially LambdaMart. + Args: + learner_config: A config for the learner. + examples_per_layer: Number of examples to accumulate before growing a + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. + head: `Head` instance. + ranking_model_pair_keys: Keys to distinguish between features + for left and right part of the training pairs for ranking. For example, + for an Example with features "a.f1" and "b.f1", the keys would be + ("a", "b"). + num_trees: An int, number of trees to build. + feature_columns: A list of feature columns. + weight_column_name: Name of the column for weights, or None if not + weighted. + model_dir: Directory for model exports, etc. + config: `RunConfig` object to configure the runtime settings. + label_keys: Optional list of strings with size `[n_classes]` defining the + label vocabulary. Only supported for `n_classes` > 2. + logits_modifier_function: A modifier function for the logits. + center_bias: Whether a separate tree should be created for first fitting + the bias. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is + [batch_size, num_trees]. + For example, + result_iter = classifier.predict(...) + for result_dict in result_iter: + # access leaf index list by result_dict["leaf_index"] + # which contains one leaf index per tree + + Raises: + ValueError: If learner_config is not valid. + """ + + def _model_fn(features, labels, mode, config): + return model.ranking_model_builder( + features=features, + labels=labels, + mode=mode, + config=config, + params={ + 'head': head, + 'n_classes': 2, + 'feature_columns': feature_columns, + 'learner_config': learner_config, + 'num_trees': num_trees, + 'weight_column_name': weight_column_name, + 'examples_per_layer': examples_per_layer, + 'center_bias': center_bias, + 'logits_modifier_function': logits_modifier_function, + 'use_core_libs': True, + 'output_leaf_index': output_leaf_index, + 'ranking_model_pair_keys': ranking_model_pair_keys, + 'override_global_step_value': None + }, + output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC) + + super(CoreGradientBoostedDecisionTreeRanker, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index 75ef1b050028b6462b255827c06e836e5c481844..68d710d713770a3a4a623b9447bb6a6b93569cac 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -25,10 +25,12 @@ from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import googletest +from tensorflow.python.training import checkpoint_utils def _train_input_fn(): @@ -37,18 +39,50 @@ def _train_input_fn(): return features, label +def _multiclass_train_input_fn(): + features = { + "x": constant_op.constant([[2.], [1.], [1.], [5.], [3.5], [4.6], [3.5]]) + } + label = constant_op.constant( + [[1], [0], [0], [2], [2], [0], [1]], dtype=dtypes.int32) + return features, label + + +def _ranking_train_input_fn(): + features = { + "a.f1": constant_op.constant([[3.], [0.3], [1.]]), + "a.f2": constant_op.constant([[0.1], [3.], [1.]]), + "b.f1": constant_op.constant([[13.], [0.4], [5.]]), + "b.f2": constant_op.constant([[1.], [3.], [0.01]]), + } + label = constant_op.constant([[0], [0], [1]], dtype=dtypes.int32) + return features, label + + def _eval_input_fn(): features = {"x": constant_op.constant([[1.], [2.], [2.]])} label = constant_op.constant([[0], [1], [1]], dtype=dtypes.int32) return features, label +def _infer_ranking_train_input_fn(): + features = { + "f1": constant_op.constant([[3.], [2], [1.]]), + "f2": constant_op.constant([[0.1], [3.], [1.]]) + } + return features, None + + class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): def setUp(self): self._export_dir_base = tempfile.mkdtemp() + "export/" gfile.MkDir(self._export_dir_base) + def _assert_checkpoint(self, model_dir, global_step): + reader = checkpoint_utils.load_checkpoint(model_dir) + self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)) + def testFitAndEvaluateDontThrowException(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 @@ -155,6 +189,290 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): regressor.evaluate(input_fn=_eval_input_fn, steps=1) regressor.export(self._export_dir_base) + def testRankingDontThrowExceptionForForEstimator(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + model = estimator.GradientBoostedDecisionTreeRanker( + head=head_fn, + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + use_core_libs=True, + feature_columns=[ + core_feature_column.numeric_column("f1"), + core_feature_column.numeric_column("f2") + ], + ranking_model_pair_keys=("a", "b")) + + model.fit(input_fn=_ranking_train_input_fn, steps=1000) + model.evaluate(input_fn=_ranking_train_input_fn, steps=1) + model.predict(input_fn=_infer_ranking_train_input_fn) + + def testDoesNotOverrideGlobalSteps(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 2 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[contrib_feature_column.real_valued_column("x")], + output_leaf_index=False) + + classifier.fit(input_fn=_train_input_fn, steps=15) + # When no override of global steps, 5 steps were used. + self._assert_checkpoint(classifier.model_dir, global_step=5) + + def testOverridesGlobalSteps(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 2 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[contrib_feature_column.real_valued_column("x")], + output_leaf_index=False, + override_global_step_value=10000000) + + classifier.fit(input_fn=_train_input_fn, steps=15) + self._assert_checkpoint(classifier.model_dir, global_step=10000000) + + def testFitAndEvaluateMultiClassTreePerClassDontThrowException(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 3 + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.TREE_PER_CLASS) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + n_classes=learner_config.num_classes, + num_trees=1, + examples_per_layer=7, + model_dir=model_dir, + config=config, + feature_columns=[contrib_feature_column.real_valued_column("x")]) + + classifier.fit(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + classifier.export(self._export_dir_base) + result_iter = classifier.predict(input_fn=_eval_input_fn) + for prediction_dict in result_iter: + self.assertTrue("classes" in prediction_dict) + + def testFitAndEvaluateMultiClassDiagonalDontThrowException(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 3 + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.DIAGONAL_HESSIAN) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + n_classes=learner_config.num_classes, + num_trees=1, + examples_per_layer=7, + model_dir=model_dir, + config=config, + center_bias=False, + feature_columns=[contrib_feature_column.real_valued_column("x")]) + + classifier.fit(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + classifier.export(self._export_dir_base) + result_iter = classifier.predict(input_fn=_eval_input_fn) + for prediction_dict in result_iter: + self.assertTrue("classes" in prediction_dict) + + def testFitAndEvaluateMultiClassFullDontThrowException(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 3 + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.FULL_HESSIAN) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + n_classes=learner_config.num_classes, + num_trees=1, + examples_per_layer=7, + model_dir=model_dir, + config=config, + center_bias=False, + feature_columns=[contrib_feature_column.real_valued_column("x")]) + + classifier.fit(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + classifier.export(self._export_dir_base) + result_iter = classifier.predict(input_fn=_eval_input_fn) + for prediction_dict in result_iter: + self.assertTrue("classes" in prediction_dict) + + +class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): + + def testTrainEvaluateInferDoesNotThrowError(self): + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + est = estimator.CoreGradientBoostedDecisionTreeEstimator( + head=head_fn, + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[core_feature_column.numeric_column("x")]) + + # Train for a few steps. + est.train(input_fn=_train_input_fn, steps=1000) + est.evaluate(input_fn=_eval_input_fn, steps=1) + est.predict(input_fn=_eval_input_fn) + + def testRankingDontThrowExceptionForForEstimator(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + est = estimator.CoreGradientBoostedDecisionTreeRanker( + head=head_fn, + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[ + core_feature_column.numeric_column("f1"), + core_feature_column.numeric_column("f2") + ], + ranking_model_pair_keys=("a", "b")) + + # Train for a few steps. + est.train(input_fn=_ranking_train_input_fn, steps=1000) + est.evaluate(input_fn=_ranking_train_input_fn, steps=1) + est.predict(input_fn=_infer_ranking_train_input_fn) + + def testFitAndEvaluateMultiClassTreePerClasssDontThrowException(self): + n_classes = 3 + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = n_classes + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.TREE_PER_CLASS) + + head_fn = estimator.core_multiclass_head(n_classes=n_classes) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.CoreGradientBoostedDecisionTreeEstimator( + learner_config=learner_config, + head=head_fn, + num_trees=1, + center_bias=False, + examples_per_layer=7, + model_dir=model_dir, + config=config, + feature_columns=[core_feature_column.numeric_column("x")]) + + classifier.train(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_multiclass_train_input_fn, steps=1) + classifier.predict(input_fn=_eval_input_fn) + + def testFitAndEvaluateMultiClassDiagonalDontThrowException(self): + n_classes = 3 + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = n_classes + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.DIAGONAL_HESSIAN) + + head_fn = estimator.core_multiclass_head(n_classes=n_classes) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.CoreGradientBoostedDecisionTreeEstimator( + learner_config=learner_config, + head=head_fn, + num_trees=1, + center_bias=False, + examples_per_layer=7, + model_dir=model_dir, + config=config, + feature_columns=[core_feature_column.numeric_column("x")]) + + classifier.train(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_multiclass_train_input_fn, steps=1) + classifier.predict(input_fn=_eval_input_fn) + + def testFitAndEvaluateMultiClassFullDontThrowException(self): + n_classes = 3 + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = n_classes + learner_config.constraints.max_tree_depth = 1 + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.FULL_HESSIAN) + + head_fn = estimator.core_multiclass_head(n_classes=n_classes) + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.CoreGradientBoostedDecisionTreeEstimator( + learner_config=learner_config, + head=head_fn, + num_trees=1, + center_bias=False, + examples_per_layer=7, + model_dir=model_dir, + config=config, + feature_columns=[core_feature_column.numeric_column("x")]) + + classifier.train(input_fn=_multiclass_train_input_fn, steps=100) + classifier.evaluate(input_fn=_multiclass_train_input_fn, steps=1) + classifier.predict(input_fn=_eval_input_fn) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index 1ee891198939e53fc5913104b2c2e65dc977823f..04b46c3483fa25286078b88c2776b76e4f3c0bcf 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -20,6 +20,7 @@ from __future__ import print_function import copy +from tensorflow.contrib import learn from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks from tensorflow.contrib.boosted_trees.python.ops import model_ops @@ -28,8 +29,17 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import state_ops from tensorflow.python.training import training_util +class ModelBuilderOutputType(object): + MODEL_FN_OPS = 0 + ESTIMATOR_SPEC = 1 -def model_builder(features, labels, mode, params, config): + +def model_builder(features, + labels, + mode, + params, + config, + output_type=ModelBuilderOutputType.MODEL_FN_OPS): """Multi-machine batch gradient descent tree model. Args: @@ -48,7 +58,13 @@ def model_builder(features, labels, mode, params, config): * weight_column_name: The name of weight column. * center_bias: Whether a separate tree should be created for first fitting the bias. + * override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. config: `RunConfig` of the estimator. + output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec + (new interface). Returns: A `ModelFnOps` object. @@ -64,6 +80,7 @@ def model_builder(features, labels, mode, params, config): use_core_libs = params["use_core_libs"] logits_modifier_function = params["logits_modifier_function"] output_leaf_index = params["output_leaf_index"] + override_global_step_value = params.get("override_global_step_value", None) if features is None: raise ValueError("At least one feature must be specified.") @@ -115,29 +132,271 @@ def model_builder(features, labels, mode, params, config): return update_op create_estimator_spec_op = getattr(head, "create_estimator_spec", None) - if use_core_libs and callable(create_estimator_spec_op): - model_fn_ops = head.create_estimator_spec( + + training_hooks = [] + if num_trees: + if center_bias: + num_trees += 1 + + finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() + training_hooks.append( + trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, + finalized_trees, + override_global_step_value)) + + if output_type == ModelBuilderOutputType.MODEL_FN_OPS: + if use_core_libs and callable(create_estimator_spec_op): + model_fn_ops = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops( + model_fn_ops) + else: + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + + if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict: + model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[ + gbdt_batch.LEAF_INDEX] + + model_fn_ops.training_hooks.extend(training_hooks) + return model_fn_ops + elif output_type == ModelBuilderOutputType.ESTIMATOR_SPEC: + assert callable(create_estimator_spec_op) + estimator_spec = head.create_estimator_spec( features=features, mode=mode, labels=labels, train_op_fn=_train_op_fn, logits=logits) - model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops) + + estimator_spec = estimator_spec._replace( + training_hooks=training_hooks + list(estimator_spec.training_hooks)) + return estimator_spec + + return model_fn_ops + + +def ranking_model_builder(features, + labels, + mode, + params, + config, + output_type=ModelBuilderOutputType.MODEL_FN_OPS): + """Multi-machine batch gradient descent tree model for ranking. + + Args: + features: `Tensor` or `dict` of `Tensor` objects. + labels: Labels used to train on. + mode: Mode we are in. (TRAIN/EVAL/INFER) + params: A dict of hyperparameters. + The following hyperparameters are expected: + * head: A `Head` instance. + * learner_config: A config for the learner. + * feature_columns: An iterable containing all the feature columns used by + the model. + * examples_per_layer: Number of examples to accumulate before growing a + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. + * weight_column_name: The name of weight column. + * center_bias: Whether a separate tree should be created for first fitting + the bias. + * ranking_model_pair_keys (Optional): Keys to distinguish between features + for left and right part of the training pairs for ranking. For example, + for an Example with features "a.f1" and "b.f1", the keys would be + ("a", "b"). + * override_global_step_value: If after the training is done, global step + value must be reset to this value. This is particularly useful for hyper + parameter tuning, which can't recognize early stopping due to the number + of trees. If None, no override of global step will happen. + config: `RunConfig` of the estimator. + output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec + (new interface). + + + Returns: + A `ModelFnOps` object. + Raises: + ValueError: if inputs are not valid. + """ + head = params["head"] + learner_config = params["learner_config"] + examples_per_layer = params["examples_per_layer"] + feature_columns = params["feature_columns"] + weight_column_name = params["weight_column_name"] + num_trees = params["num_trees"] + use_core_libs = params["use_core_libs"] + logits_modifier_function = params["logits_modifier_function"] + output_leaf_index = params["output_leaf_index"] + ranking_model_pair_keys = params["ranking_model_pair_keys"] + override_global_step_value = params.get("override_global_step_value", None) + + if features is None: + raise ValueError("At least one feature must be specified.") + + if config is None: + raise ValueError("Missing estimator RunConfig.") + + center_bias = params["center_bias"] + + if isinstance(features, ops.Tensor): + features = {features.name: features} + + # Make a shallow copy of features to ensure downstream usage + # is unaffected by modifications in the model function. + training_features = copy.copy(features) + training_features.pop(weight_column_name, None) + global_step = training_util.get_global_step() + with ops.device(global_step.device): + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config="", # Initialize an empty ensemble. + name="ensemble_model") + + # Extract the features. + if mode == learn.ModeKeys.TRAIN or mode == learn.ModeKeys.EVAL: + # For ranking pairwise training, we extract two sets of features. + if len(ranking_model_pair_keys) != 2: + raise ValueError("You must provide keys for ranking.") + left_pair_key = ranking_model_pair_keys[0] + right_pair_key = ranking_model_pair_keys[1] + if left_pair_key is None or right_pair_key is None: + raise ValueError("Both pair keys should be provided for ranking.") + + features_1 = {} + features_2 = {} + for name in training_features: + feature = training_features[name] + new_name = name[2:] + if name.startswith(left_pair_key + "."): + features_1[new_name] = feature + else: + assert name.startswith(right_pair_key + ".") + features_2[new_name] = feature + + main_features = features_1 + supplementary_features = features_2 else: - model_fn_ops = head.create_model_fn_ops( + # For non-ranking or inference ranking, we have only 1 set of features. + main_features = training_features + + # Create GBDT model. + gbdt_model_main = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=config.is_chief, + num_ps_replicas=config.num_ps_replicas, + ensemble_handle=ensemble_handle, + center_bias=center_bias, + examples_per_layer=examples_per_layer, + learner_config=learner_config, + feature_columns=feature_columns, + logits_dimension=head.logits_dimension, + features=main_features, + use_core_columns=use_core_libs, + output_leaf_index=output_leaf_index) + + with ops.name_scope("gbdt", "gbdt_optimizer"): + # Logits for inference. + if mode == learn.ModeKeys.INFER: + predictions_dict = gbdt_model_main.predict(mode) + logits = predictions_dict[gbdt_batch.PREDICTIONS] + if logits_modifier_function: + logits = logits_modifier_function(logits, features, mode) + else: + gbdt_model_supplementary = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=config.is_chief, + num_ps_replicas=config.num_ps_replicas, + ensemble_handle=ensemble_handle, + center_bias=center_bias, + examples_per_layer=examples_per_layer, + learner_config=learner_config, + feature_columns=feature_columns, + logits_dimension=head.logits_dimension, + features=supplementary_features, + use_core_columns=use_core_libs, + output_leaf_index=output_leaf_index) + + # Logits for train and eval. + if not supplementary_features: + raise ValueError("Features for ranking must be specified.") + + predictions_dict_1 = gbdt_model_main.predict(mode) + predictions_1 = predictions_dict_1[gbdt_batch.PREDICTIONS] + + predictions_dict_2 = gbdt_model_supplementary.predict(mode) + predictions_2 = predictions_dict_2[gbdt_batch.PREDICTIONS] + + logits = predictions_1 - predictions_2 + if logits_modifier_function: + logits = logits_modifier_function(logits, features, mode) + + predictions_dict = predictions_dict_1 + predictions_dict[gbdt_batch.PREDICTIONS] = logits + + def _train_op_fn(loss): + """Returns the op to optimize the loss.""" + update_op = gbdt_model_main.train(loss, predictions_dict, labels) + with ops.control_dependencies( + [update_op]), (ops.colocate_with(global_step)): + update_op = state_ops.assign_add(global_step, 1).op + return update_op + + create_estimator_spec_op = getattr(head, "create_estimator_spec", None) + + training_hooks = [] + if num_trees: + if center_bias: + num_trees += 1 + + finalized_trees, attempted_trees = ( + gbdt_model_main.get_number_of_trees_tensor()) + training_hooks.append( + trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, + finalized_trees, + override_global_step_value)) + + if output_type == ModelBuilderOutputType.MODEL_FN_OPS: + if use_core_libs and callable(create_estimator_spec_op): + model_fn_ops = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops( + model_fn_ops) + else: + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + + if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict: + model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[ + gbdt_batch.LEAF_INDEX] + + model_fn_ops.training_hooks.extend(training_hooks) + return model_fn_ops + + elif output_type == ModelBuilderOutputType.ESTIMATOR_SPEC: + assert callable(create_estimator_spec_op) + estimator_spec = head.create_estimator_spec( features=features, mode=mode, labels=labels, train_op_fn=_train_op_fn, logits=logits) - if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict: - model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[ - gbdt_batch.LEAF_INDEX] - if num_trees: - if center_bias: - num_trees += 1 - finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() - model_fn_ops.training_hooks.append( - trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, - finalized_trees)) + + estimator_spec = estimator_spec._replace( + training_hooks=training_hooks + list(estimator_spec.training_hooks)) + return estimator_spec + return model_fn_ops diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py b/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py index 2e4151cac40f770e2bece70d752122eb7f34dd40..f137ada35524bf2467314f4a284ea35a82f06825 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py @@ -25,6 +25,7 @@ from tensorflow.contrib.learn.python.learn.session_run_hook import SessionRunArg from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import training_util from tensorflow.python.training.summary_io import SummaryWriterCache @@ -150,12 +151,23 @@ class FeedFnHook(session_run_hook.SessionRunHook): class StopAfterNTrees(session_run_hook.SessionRunHook): """Stop training after building N full trees.""" - def __init__(self, n, num_attempted_trees_tensor, num_finalized_trees_tensor): + def __init__(self, n, num_attempted_trees_tensor, num_finalized_trees_tensor, + override_global_step_value=None): self._num_trees = n # num_attempted_trees_tensor and num_finalized_trees_tensor are both # tensors. self._num_attempted_trees_tensor = num_attempted_trees_tensor self._num_finalized_trees_tensor = num_finalized_trees_tensor + self._override_global_step_value = override_global_step_value + + def begin(self): + self._global_step_tensor = training_util.get_global_step() + if self._global_step_tensor is None: + raise RuntimeError("Global step should be created.") + + if self._override_global_step_value is not None: + self._override_global_step_op = state_ops.assign( + self._global_step_tensor, self._override_global_step_value) def before_run(self, run_context): del run_context # unused by StopTrainingAfterNTrees. @@ -175,6 +187,9 @@ class StopAfterNTrees(session_run_hook.SessionRunHook): num_attempted_trees > 2 * self._num_trees): logging.info("Requesting stop since we have reached %d trees.", num_finalized_trees) + if self._override_global_step_value is not None: + logging.info("Overriding global steps value.") + run_context.session.run(self._override_global_step_op) run_context.request_stop() diff --git a/tensorflow/contrib/boosted_trees/examples/boston.py b/tensorflow/contrib/boosted_trees/examples/boston.py index e9dbdb0fd784052eeb36ac1aa9342165ef2ac0a7..54c4ff059e3408d2cb8fc689a9ae877f57485f58 100644 --- a/tensorflow/contrib/boosted_trees/examples/boston.py +++ b/tensorflow/contrib/boosted_trees/examples/boston.py @@ -45,6 +45,7 @@ from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientB from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn import learn_runner +from tensorflow.python.util import compat _BOSTON_NUM_FEATURES = 13 @@ -79,7 +80,8 @@ def _convert_fn(dtec, sorted_feature_names, num_dense, num_sparse_float, num_sparse_int, export_dir, unused_eval_result): universal_format = custom_export_strategy.convert_to_universal_format( dtec, sorted_feature_names, num_dense, num_sparse_float, num_sparse_int) - with tf.gfile.GFile(os.path.join(export_dir, "tree_proto"), "w") as f: + with tf.gfile.GFile(os.path.join( + compat.as_bytes(export_dir), compat.as_bytes("tree_proto")), "w") as f: f.write(str(universal_format)) diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc index 0b28f81e7ca9a1228adc5bde19c429265e0aa9b8..1375fddf2bea1a8f856c35d756c38a8beb14a53f 100644 --- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc @@ -125,6 +125,8 @@ void QuantizeFeatures( auto flat_values = values_tensor.flat(); for (int64 instance = 0; instance < num_values; ++instance) { const float value = flat_values(instance); + CHECK(!buckets_vector.empty()) + << "Got empty buckets for feature " << feature_index; auto bucket_iter = std::lower_bound(buckets_vector.begin(), buckets_vector.end(), value); if (bucket_iter == buckets_vector.end()) { @@ -241,6 +243,11 @@ class CreateQuantileAccumulatorOp : public OpKernel { // other exceptions. If one already exists, it unrefs the new one. const Tensor* stamp_token_t; OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); + // An epsilon value of zero could cause perfoamance issues and is therefore, + // disallowed. + OP_REQUIRES( + context, epsilon_ > 0, + errors::InvalidArgument("An epsilon value of zero is not allowed.")); auto result = new QuantileStreamResource(epsilon_, num_quantiles_, max_elements_, generate_quantiles_, stamp_token_t->scalar()()); diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index 1bfeed306641111718984b2097512e5ec3fa8630..6d9a6ee5a0d05465459393c4339558f1ca38d417 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -372,12 +372,18 @@ class GrowTreeEnsembleOp : public OpKernel { return; } + // Get the max tree depth. + const Tensor* max_tree_depth_t; + OP_REQUIRES_OK(context, + context->input("max_tree_depth", &max_tree_depth_t)); + const int32 max_tree_depth = max_tree_depth_t->scalar()(); + // Update and retrieve the growable tree. // If the tree is fully built and dropout was applied, it also adjusts the // weights of dropped and the last tree. boosted_trees::trees::DecisionTreeConfig* const tree_config = UpdateAndRetrieveGrowableTree(ensemble_resource, learning_rate, - dropout_seed); + dropout_seed, max_tree_depth); // Split tree nodes. for (auto& split_entry : best_splits) { @@ -494,7 +500,8 @@ class GrowTreeEnsembleOp : public OpKernel { boosted_trees::trees::DecisionTreeConfig* UpdateAndRetrieveGrowableTree( boosted_trees::models::DecisionTreeEnsembleResource* const ensemble_resource, - const float learning_rate, const uint64 dropout_seed) { + const float learning_rate, const uint64 dropout_seed, + const int32 max_tree_depth) { const auto num_trees = ensemble_resource->num_trees(); if (num_trees <= 0 || ensemble_resource->LastTreeMetadata()->is_finalized()) { @@ -506,8 +513,7 @@ class GrowTreeEnsembleOp : public OpKernel { tree_config->add_nodes()->mutable_leaf(); boosted_trees::trees::DecisionTreeMetadata* const tree_metadata = ensemble_resource->LastTreeMetadata(); - tree_metadata->set_is_finalized( - learner_config_.constraints().max_tree_depth() <= 1); + tree_metadata->set_is_finalized(max_tree_depth <= 1); tree_metadata->set_num_tree_weight_updates(1); } else { // The growable tree is by definition the last tree in the ensemble. @@ -518,8 +524,7 @@ class GrowTreeEnsembleOp : public OpKernel { << num_trees - 1 << " of ensemble of " << num_trees << " trees."; // Update growable tree metadata. tree_metadata->set_num_layers_grown(new_num_layers); - tree_metadata->set_is_finalized( - new_num_layers >= learner_config_.constraints().max_tree_depth()); + tree_metadata->set_is_finalized(new_num_layers >= max_tree_depth); } UpdateTreeWeightsIfDropout(ensemble_resource, dropout_seed); return ensemble_resource->LastTree(); diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py index 1b7f59ea4218355a13f1df7264352bd68503bd19..5d4819b0f1cb598cfbe146f569aecd7883186339 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py @@ -131,6 +131,10 @@ class BaseSplitHandler(object): }, stamp_token, None) return control_flow_ops.group(update_1, *update_2[self]) + @abc.abstractmethod + def reset(self, stamp_token, next_stamp_token): + """Resets the state maintained by the handler.""" + @abc.abstractmethod def make_splits(self, stamp_token, next_stamp_token, class_id): """Create the best split using the accumulated stats and flush the state. diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py index bf686237ff696dadad9713d26bf784d7442b80d0..efe29216c2a7d8aa985da54cdbb839b9e6f69078 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py @@ -202,3 +202,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): # always return ready. are_splits_ready = constant_op.constant(True) return (are_splits_ready, partition_ids, gains, split_infos) + + def reset(self, stamp_token, next_stamp_token): + reset = self._stats_accumulator.flush(stamp_token, next_stamp_token) + return reset diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index df0bec1fe363e07bbff6b059e86076239bd605e9..2559fe9913f377ce38aa11dfa908cd25ec76dab4 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -79,6 +79,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops + _BIAS_FEATURE_ID = -1 # Pattern to remove all non alpha numeric from a string. _PATTERN = re.compile(r"[\W_]+") @@ -147,6 +148,11 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler): num_quantiles=num_quantiles, name="QuantileAccumulator/{}".format(self._name)) + def reset(self, stamp_token, next_stamp_token): + reset_1 = self._stats_accumulator.flush(stamp_token, next_stamp_token) + reset_2 = self._quantile_accumulator.flush(stamp_token, next_stamp_token) + return control_flow_ops.group([reset_1, reset_2]) + class DenseSplitHandler(InequalitySplitHandler): """Computes stats and finds the best inequality splits on dense columns.""" @@ -264,6 +270,7 @@ class DenseSplitHandler(InequalitySplitHandler): self._feature_column_group_id, self._l1_regularization, self._l2_regularization, self._tree_complexity_regularization, self._min_node_weight, self._loss_uses_sum_reduction)) + return are_splits_ready, partition_ids, gains, split_infos @@ -579,8 +586,10 @@ def dense_make_stats_update(is_active, are_buckets_ready, float_column, example_partition_ids, feature_ids, gradients, hessians = ( control_flow_ops.cond( - math_ops.logical_and(are_buckets_ready, is_active[0]), - ready_inputs_fn, not_ready_inputs_fn)) + math_ops.logical_and( + math_ops.logical_and(are_buckets_ready, + array_ops.size(quantile_buckets) > 0), + is_active[0]), ready_inputs_fn, not_ready_inputs_fn)) return (quantile_values, quantile_weights, example_partition_ids, feature_ids, gradients, hessians) @@ -674,8 +683,10 @@ def sparse_make_stats_update( lambda: handler_not_active)) example_partition_ids, feature_ids, gradients, hessians = ( - control_flow_ops.cond(are_buckets_ready, quantiles_ready, - quantiles_not_ready)) + control_flow_ops.cond( + math_ops.logical_and(are_buckets_ready, + array_ops.size(quantile_buckets) > 0), + quantiles_ready, quantiles_not_ready)) return (quantile_indices, quantile_values, quantile_shape, quantile_weights, example_partition_ids, feature_ids, gradients, hessians) diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index d59732cf92eb85e88732ac5a17dccf475ae5342f..5d82c4cae5dbe28c82fa8754a7c65db62a2e6814 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -1072,8 +1072,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self): with self.test_session() as sess: # Batch is 4, 2 classes - gradients = array_ops.constant( - [[0.2, 1.4], [-0.5, 0.1], [1.2, 3], [4.0, -3]]) + gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3], + [4.0, -3]]) # 2x2 matrix for each instance hessian_0 = [[0.12, 0.02], [0.3, 0.11]] hessian_1 = [[0.07, -0.2], [-0.5, 0.2]] @@ -1167,8 +1167,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self): with self.test_session() as sess: # Batch is 4, 2 classes - gradients = array_ops.constant( - [[0.2, 1.4], [-0.5, 0.1], [1.2, 3], [4.0, -3]]) + gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3], + [4.0, -3]]) # Each hessian is a diagonal from a full hessian matrix. hessian_0 = [0.12, 0.11] hessian_1 = [0.07, 0.2] @@ -1406,6 +1406,100 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(gains), 0) self.assertEqual(len(splits), 0) + def testEmptyBuckets(self): + """Test that reproduces the case when quantile buckets were empty.""" + with self.test_session() as sess: + sparse_column = array_ops.sparse_placeholder(dtypes.float32) + + # We have two batches - at first, a sparse feature is empty. + empty_indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2]) + empty_values = array_ops.constant([], dtype=dtypes.float32) + empty_sparse_column = sparse_tensor.SparseTensor(empty_indices, + empty_values, [4, 2]) + empty_sparse_column = empty_sparse_column.eval(session=sess) + + # For the second batch, the sparse feature is not empty. + non_empty_indices = array_ops.constant( + [[0, 0], [2, 1], [3, 2]], dtype=dtypes.int64, shape=[3, 2]) + non_empty_values = array_ops.constant( + [0.52, 0.3, 0.52], dtype=dtypes.float32) + non_empty_sparse_column = sparse_tensor.SparseTensor( + non_empty_indices, non_empty_values, [4, 2]) + non_empty_sparse_column = non_empty_sparse_column.eval(session=sess) + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + class_id = -1 + + split_handler = ordinal_split_handler.SparseSplitHandler( + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, + epsilon=0.01, + num_quantiles=2, + feature_column_group_id=0, + sparse_float_column=sparse_column, + init_stamp_token=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS) + resources.initialize_resources(resources.shared_resources()).run() + gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) + hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) + partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([4, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1]): + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + + # First, calculate quantiles and try to update on an empty data for a + # feature. + are_splits_ready = ( + sess.run( + are_splits_ready, + feed_dict={sparse_column: empty_sparse_column})) + self.assertFalse(are_splits_ready) + + update_2 = split_handler.update_stats_sync( + 1, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_2]): + are_splits_ready2, partitions, gains, splits = ( + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) + + # Now the feature in the second batch is not empty, but buckets + # calculated on the first batch are empty. + are_splits_ready2, partitions, gains, splits = ( + sess.run( + [are_splits_ready2, partitions, gains, splits], + feed_dict={sparse_column: non_empty_sparse_column})) + self.assertFalse(are_splits_ready) + self.assertTrue(are_splits_ready2) + # Since the buckets were empty, we can't calculate the splits. + self.assertEqual(len(partitions), 0) + self.assertEqual(len(gains), 0) + self.assertEqual(len(splits), 0) + def testDegenerativeCase(self): with self.test_session() as sess: # One data example only, one leaf and thus one quantile bucket.The same diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h index c120dd8a6c156ec9eb7ba0b6c552f5138bd21a16..f19e5116f5865777ab65e1add2777ac41105acc0 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h @@ -58,6 +58,8 @@ namespace quantiles { // Compute: O(n * log(1/eps * log(eps * n))). // Memory: O(1/eps * log^2(eps * n)) <- for one worker streaming through the // entire dataset. +// An epsilon value of zero would make the algorithm extremely inefficent and +// therefore, is disallowed. template > class WeightedQuantilesStream { @@ -69,6 +71,9 @@ class WeightedQuantilesStream { explicit WeightedQuantilesStream(double eps, int64 max_elements) : eps_(eps), buffer_(1LL, 2LL), finalized_(false) { + // See the class documentation. An epsilon value of zero could cause + // perfoamance issues. + QCHECK(eps > 0) << "An epsilon value of zero is not allowed."; std::tie(max_levels_, block_size_) = GetQuantileSpecs(eps, max_elements); buffer_ = Buffer(block_size_, max_elements); summary_levels_.reserve(max_levels_); diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h index a7e7bfc13cadcea4d29d33e0dbd955bdad6ffcb9..69bb8fd4ada861a42a0ccc3f287a47d91be5c879 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h @@ -51,7 +51,7 @@ class WeightedQuantilesSummary { SummaryEntry() { memset(this, 0, sizeof(*this)); - value = 0; + value = ValueType(); weight = 0; min_rank = 0; max_rank = 0; diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc index 35b059f3496dbc8fb2b3d4fe6ec6b55a9d73dd0c..4fab2b0b7deb6ff2e353d758dc068aa28d44d5ae 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc @@ -16,6 +16,7 @@ #include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h" #include "tensorflow/contrib/boosted_trees/lib/utils/macros.h" #include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { namespace boosted_trees { @@ -96,9 +97,11 @@ Status BatchFeatures::Initialize( "Sparse float feature shape incompatible with batch size.")); auto tensor_shape = TensorShape({shape_flat(0), shape_flat(1)}); auto order_dims = sparse::SparseTensor::VarDimArray({0, 1}); - sparse_float_feature_columns_.emplace_back(sparse_float_feature_indices, - sparse_float_feature_values, - tensor_shape, order_dims); + sparse::SparseTensor sparse_tensor; + TF_RETURN_IF_ERROR(sparse::SparseTensor::Create( + sparse_float_feature_indices, sparse_float_feature_values, tensor_shape, + order_dims, &sparse_tensor)); + sparse_float_feature_columns_.push_back(std::move(sparse_tensor)); } // Read sparse int features. @@ -136,9 +139,11 @@ Status BatchFeatures::Initialize( "Sparse int feature shape incompatible with batch size.")); auto tensor_shape = TensorShape({shape_flat(0), shape_flat(1)}); auto order_dims = sparse::SparseTensor::VarDimArray({0, 1}); - sparse_int_feature_columns_.emplace_back(sparse_int_feature_indices, - sparse_int_feature_values, - tensor_shape, order_dims); + sparse::SparseTensor sparse_tensor; + TF_RETURN_IF_ERROR(sparse::SparseTensor::Create( + sparse_int_feature_indices, sparse_int_feature_values, tensor_shape, + order_dims, &sparse_tensor)); + sparse_int_feature_columns_.push_back(std::move(sparse_tensor)); } return Status::OK(); } diff --git a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc index d8a608864834b17886313a368221fbf94e31c98e..30c37435fe16ef29a9e29202850501098e9ac7f8 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc @@ -43,27 +43,35 @@ TEST_F(ExamplesIterableTest, Iterate) { test::AsTensor({0, 0, 2, 0, 3, 0, 4, 0}, {4, 2}); auto sparse_float_values1 = test::AsTensor({-3.0f, 0.0f, 5.0f, 0.0f}); auto sparse_float_shape1 = TensorShape({8, 1}); - sparse::SparseTensor sparse_float_tensor1( - sparse_float_indices1, sparse_float_values1, sparse_float_shape1); + sparse::SparseTensor sparse_float_tensor1; + TF_ASSERT_OK( + sparse::SparseTensor::Create(sparse_float_indices1, sparse_float_values1, + sparse_float_shape1, &sparse_float_tensor1)); auto sparse_float_indices2 = test::AsTensor( {0, 1, 1, 0, 2, 1, 3, 0, 4, 1, 5, 0, 5, 1, 7, 0}, {8, 2}); auto sparse_float_values2 = test::AsTensor({1.f, 4.0f, 3.f, 7.0f, 4.3f, 9.0f, 0.8f, -4.0f}); auto sparse_float_shape2 = TensorShape({8, 2}); - sparse::SparseTensor sparse_float_tensor2( - sparse_float_indices2, sparse_float_values2, sparse_float_shape2); + sparse::SparseTensor sparse_float_tensor2; + TF_ASSERT_OK( + sparse::SparseTensor::Create(sparse_float_indices2, sparse_float_values2, + sparse_float_shape2, &sparse_float_tensor2)); auto sparse_int_indices1 = test::AsTensor({0, 0, 0, 1, 1, 0, 3, 0, 3, 1, 7, 0}, {6, 2}); auto sparse_int_values1 = test::AsTensor({1, 8, 0, 2, 0, 5}); auto sparse_int_shape1 = TensorShape({8, 2}); - sparse::SparseTensor sparse_int_tensor1( - sparse_int_indices1, sparse_int_values1, sparse_int_shape1); + sparse::SparseTensor sparse_int_tensor1; + TF_ASSERT_OK( + sparse::SparseTensor::Create(sparse_int_indices1, sparse_int_values1, + sparse_int_shape1, &sparse_int_tensor1)); auto sparse_int_indices2 = test::AsTensor({1, 0, 2, 0, 3, 0, 4, 0}, {4, 2}); auto sparse_int_values2 = test::AsTensor({7, 13, 4, 0}); auto sparse_int_shape2 = TensorShape({8, 1}); - sparse::SparseTensor sparse_int_tensor2( - sparse_int_indices2, sparse_int_values2, sparse_int_shape2); + sparse::SparseTensor sparse_int_tensor2; + TF_ASSERT_OK( + sparse::SparseTensor::Create(sparse_int_indices2, sparse_int_values2, + sparse_int_shape2, &sparse_int_tensor2)); auto validate_example_features = [](int64 example_idx, const Example& example) { diff --git a/tensorflow/contrib/boosted_trees/ops/training_ops.cc b/tensorflow/contrib/boosted_trees/ops/training_ops.cc index f63c199ad6146c23c22437ffe2287a77ee91ca44..22ac9edb72ea91ecef6fd1dff9f399b3c9020083 100644 --- a/tensorflow/contrib/boosted_trees/ops/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/training_ops.cc @@ -56,6 +56,7 @@ REGISTER_OP("GrowTreeEnsemble") .Input("next_stamp_token: int64") .Input("learning_rate: float") .Input("dropout_seed: int64") + .Input("max_tree_depth: int32") .Input("partition_ids: num_handlers * int32") .Input("gains: num_handlers * float") .Input("splits: num_handlers * string") @@ -67,6 +68,8 @@ REGISTER_OP("GrowTreeEnsemble") TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_input)); // Dropout seed. TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_input)); + // Maximum tree depth. + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_input)); return Status::OK(); }) .Doc(R"doc( diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py index 3e524efbeac74ff754d63cae92b3e194411cb2de..e39e1de8d1954c7f4dcab87d7727a64affa13c8c 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py @@ -296,7 +296,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE, # Dropout does not change anything here, tree is not finalized. - dropout_probability=0.5).SerializeToString() + dropout_probability=0.5) # Prepare handler inputs. # Note that handlers 1 & 3 have the same gain but different splits. @@ -321,9 +321,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the simpler split from handler 1 to be chosen. @@ -443,7 +444,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE, # Dropout does not change anything here - tree is not finalized. - dropout_probability=0.5).SerializeToString() + dropout_probability=0.5) # Prepare handler inputs. # Handler 1 only has a candidate for partition 1, handler 2 has candidates @@ -472,9 +473,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the split for partition 1 to be chosen from handler 1 and @@ -632,8 +634,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=1, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( - ) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) # Prepare handler inputs. handler1_partitions = np.array([0], dtype=np.int32) @@ -657,9 +658,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect a new tree to be added with the split from handler 1. @@ -773,8 +775,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=1, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( - ) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) # Prepare handler inputs. # All handlers have negative gain. @@ -794,9 +795,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): partition_ids=[handler1_partitions, handler2_partitions], gains=[handler1_gains, handler2_gains], splits=[handler1_split, handler2_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the ensemble to be empty. @@ -839,8 +841,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=1, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( - ) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) # Prepare handler inputs. # Note that handlers 1 & 3 have the same gain but different splits. @@ -865,9 +866,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the simpler split from handler 1 to be chosen. @@ -946,8 +948,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=2, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( - ) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) # Prepare handler inputs. # All handlers have negative gain. @@ -967,9 +968,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): partition_ids=[handler1_partitions, handler2_partitions], gains=[handler1_gains, handler2_gains], splits=[handler1_split, handler2_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the split from handler 2 to be chosen despite the negative gain. @@ -1048,9 +1050,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): partition_ids=[handler1_partitions], gains=[handler1_gains], splits=[handler1_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the ensemble to be empty as post-pruning will prune @@ -1094,8 +1097,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=2, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( - ) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) # Prepare handler inputs. # Second handler has positive gain. @@ -1115,9 +1117,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): partition_ids=[handler1_partitions, handler2_partitions], gains=[handler1_gains, handler2_gains], splits=[handler1_split, handler2_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the split from handler 2 to be chosen despite the negative gain. @@ -1194,9 +1197,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): partition_ids=[handler1_partitions], gains=[handler1_gains], splits=[handler1_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the negative gain split of partition 1 to be pruned and the @@ -1335,7 +1339,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER, # Dropout will have no effect, since the tree will not be fully grown. - dropout_probability=1.0).SerializeToString() + dropout_probability=1.0) # Prepare handler inputs. # Handler 1 only has a candidate for partition 1, handler 2 has candidates @@ -1364,9 +1368,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect the split for partition 1 to be chosen from handler 1 and @@ -1543,7 +1548,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE, - dropout_probability=1.0).SerializeToString() + dropout_probability=1.0) # Prepare handler inputs. handler1_partitions = np.array([0], dtype=np.int32) @@ -1567,9 +1572,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) # Expect a new tree to be added with the split from handler 1. @@ -1669,7 +1675,6 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) learner_config.constraints.max_number_of_unique_feature_columns = 3 - learner_config = learner_config.SerializeToString() # Prepare handler inputs. handler1_partitions = np.array([0], dtype=np.int32) handler1_gains = np.array([7.62], dtype=np.float32) @@ -1692,9 +1697,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): ], gains=[handler1_gains, handler2_gains, handler3_gains], splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), dropout_seed=123, - center_bias=True) + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth) session.run(grow_op) _, serialized = session.run( diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 1ee7f2395ea2ad71a7d380a1cc8f9a77bd4782b3..d0d1249bd6afc9cdbf6d88298c5024a4a54a5073 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -51,6 +51,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import device_setter + # Key names for prediction dict. ENSEMBLE_STAMP = "ensemble_stamp" PREDICTIONS = "predictions" @@ -287,7 +288,8 @@ class GradientBoostedDecisionTreeModel(object): loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS, feature_columns=None, use_core_columns=False, - output_leaf_index=False): + output_leaf_index=False, + output_leaf_index_modes=None): """Construct a new GradientBoostedDecisionTreeModel function. Args: @@ -307,6 +309,9 @@ class GradientBoostedDecisionTreeModel(object): used. output_leaf_index: A boolean variable indicating whether to output leaf index into predictions dictionary. + output_leaf_index_modes: A list of modes from (TRAIN, EVAL, INFER) which + dictates when leaf indices will be outputted. By default, leaf indices + are only outputted in INFER mode. Raises: ValueError: if inputs are not valid. @@ -349,6 +354,9 @@ class GradientBoostedDecisionTreeModel(object): self._gradient_shape = tensor_shape.scalar() self._hessian_shape = tensor_shape.scalar() else: + if center_bias: + raise ValueError("Center bias should be False for multiclass.") + self._gradient_shape = tensor_shape.TensorShape([logits_dimension]) if (learner_config.multi_class_strategy == learner_pb2.LearnerConfig.FULL_HESSIAN): @@ -376,6 +384,8 @@ class GradientBoostedDecisionTreeModel(object): self._learner_config = learner_config self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() + self._max_tree_depth = variables.Variable( + initial_value=self._learner_config.constraints.max_tree_depth) self._attempted_trees = variables.Variable( initial_value=array_ops.zeros([], dtypes.int64), trainable=False, @@ -404,7 +414,16 @@ class GradientBoostedDecisionTreeModel(object): self._learner_config.multi_class_strategy == learner_pb2.LearnerConfig.TREE_PER_CLASS and learner_config.num_classes == 2) + + if output_leaf_index_modes is None: + output_leaf_index_modes = [learn.ModeKeys.INFER] + elif not all( + mode in (learn.ModeKeys.TRAIN, learn.ModeKeys.EVAL, + learn.ModeKeys.INFER) for mode in output_leaf_index_modes): + raise ValueError("output_leaf_index_modes should only contain ModeKeys.") + self._output_leaf_index = output_leaf_index + self._output_leaf_index_modes = output_leaf_index_modes def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode): """Runs prediction and returns a dictionary of the prediction results. @@ -435,8 +454,7 @@ class GradientBoostedDecisionTreeModel(object): # the right stamp. with ops.control_dependencies(ensemble_stats): leaf_index = None - # Only used in infer (predict), not used in train and eval. - if self._output_leaf_index and mode == learn.ModeKeys.INFER: + if self._output_leaf_index and mode in self._output_leaf_index_modes: predictions, _, leaf_index = ( prediction_ops).gradient_trees_prediction_verbose( ensemble_handle, @@ -508,9 +526,6 @@ class GradientBoostedDecisionTreeModel(object): if not input_deps: raise ValueError("No input tensors for prediction.") - if any(i.device != input_deps[0].device for i in input_deps): - raise ValueError("All input tensors should be on the same device.") - # Get most current model stamp. ensemble_stamp = model_ops.tree_ensemble_stamp_token(self._ensemble_handle) @@ -884,7 +899,7 @@ class GradientBoostedDecisionTreeModel(object): reset_ops = [] for handler in handlers: - reset_ops.append(handler.make_splits(stamp_token, next_stamp_token, 0)) + reset_ops.append(handler.reset(stamp_token, next_stamp_token)) if self._center_bias: reset_ops.append( bias_stats_accumulator.flush(stamp_token, next_stamp_token)) @@ -1042,7 +1057,8 @@ class GradientBoostedDecisionTreeModel(object): splits=split_info_list, learner_config=self._learner_config_serialized, dropout_seed=dropout_seed, - center_bias=self._center_bias) + center_bias=self._center_bias, + max_tree_depth=self._max_tree_depth) def _grow_ensemble_not_ready_fn(): # Don't grow the ensemble, just update the stamp. @@ -1056,7 +1072,8 @@ class GradientBoostedDecisionTreeModel(object): splits=[], learner_config=self._learner_config_serialized, dropout_seed=dropout_seed, - center_bias=self._center_bias) + center_bias=self._center_bias, + max_tree_depth=self._max_tree_depth) def _grow_ensemble_fn(): # Conditionally grow an ensemble depending on whether the splits @@ -1096,6 +1113,9 @@ class GradientBoostedDecisionTreeModel(object): def get_number_of_trees_tensor(self): return self._finalized_trees, self._attempted_trees + def get_max_tree_depth(self): + return self._max_tree_depth + def train(self, loss, predictions_dict, labels): """Updates the accumalator stats and grows the ensemble. diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py index ab7ac2aba605db22a8ed370049b27d55cf1d413a..b5ebaf1999519f65110e8164fa20bace5ecc3ef6 100644 --- a/tensorflow/contrib/boosted_trees/python/utils/losses.py +++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py @@ -23,6 +23,12 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn +from tensorflow.python.ops.losses import losses + + +def per_example_squared_hinge_loss(labels, weights, predictions): + loss = losses.hinge_loss(labels=labels, logits=predictions, weights=weights) + return math_ops.square(loss), control_flow_ops.no_op() def per_example_logistic_loss(labels, weights, predictions): @@ -126,7 +132,7 @@ def per_example_squared_loss(labels, weights, predictions): def per_example_exp_loss(labels, weights, predictions, name=None, eps=0.1): - """Exponential loss given labels, example weights and predictions. + """Trimmed exponential loss given labels, example weights and predictions. Note that this is only for binary classification. If logistic loss tries to make sure that the classifier is certain of its @@ -211,3 +217,62 @@ def per_example_exp_loss(labels, weights, predictions, name=None, eps=0.1): unweighted_loss = exp_with_logits( name=name, eps=eps, labels=labels, logits=predictions) return unweighted_loss * weights, control_flow_ops.no_op() + + +def per_example_full_exp_loss(labels, weights, predictions, name=None): + """Full exponential loss given labels, example weights and predictions. + + Note that this is only for binary classification. + The loss returns is exp(-targets*logits), where targets are converted to -1 + and 1. + + Args: + labels: Rank 2 (N, D) tensor of per-example labels. + weights: Rank 2 (N, 1) tensor of per-example weights. + predictions: Rank 2 (N, D) tensor of per-example predictions. + name: A name for the operation (optional). + + Returns: + loss: A Rank 2 (N, 1) tensor of per-example exp loss + update_op: An update operation to update the loss's internal state. + """ + + def full_exp_with_logits(name, labels=None, logits=None): + """Computes exponential loss given `logits`. + + Args: + name: A name for the operation (optional). + labels: A `Tensor` of the same type and shape as `logits`. + logits: A `Tensor` of type `float32` or `float64`. + + Returns: + A `Tensor` of the same shape as `logits` with the componentwise + exponential losses. + + Raises: + ValueError: If `logits` and `labels` do not have the same shape. + """ + with ops.name_scope(name, "exp_loss", [logits, labels]) as name: + logits = ops.convert_to_tensor(logits, name="logits") + labels = ops.convert_to_tensor(labels, name="labels") + try: + labels.get_shape().merge_with(logits.get_shape()) + except ValueError: + raise ValueError("logits and labels must have the same shape (%s vs %s)" + % (logits.get_shape(), labels.get_shape())) + + # Default threshold of 0 to switch between classes + zeros = array_ops.zeros_like(logits, dtype=logits.dtype) + ones = array_ops.ones_like(logits, dtype=logits.dtype) + neg_ones = -array_ops.ones_like(logits, dtype=logits.dtype) + + # Convert labels to 1 and -1 + cond_labels = (labels > zeros) + labels_converted = array_ops.where(cond_labels, ones, neg_ones) + + return math_ops.exp(-1.0 * logits * labels_converted) + + labels = math_ops.to_float(labels) + unweighted_loss = full_exp_with_logits( + name=name, labels=labels, logits=predictions) + return unweighted_loss * weights, control_flow_ops.no_op() diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 8c1ce5c2a2d552e30d3b676e3ac8b5fc7c74a917..2fbaa31d5e19b58c335cd0a894e1db9af2c34d08 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -44,8 +44,8 @@ from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import Checkpointa from tensorflow.python.training.checkpointable.base import CheckpointableBase from tensorflow.python.training.checkpointable.data_structures import List from tensorflow.python.training.checkpointable.data_structures import Mapping +from tensorflow.python.training.checkpointable.data_structures import NoDependency from tensorflow.python.training.checkpointable.tracking import Checkpointable -from tensorflow.python.training.checkpointable.tracking import NoDependency from tensorflow.python.training.checkpointable.util import capture_dependencies from tensorflow.python.training.checkpointable.util import list_objects from tensorflow.python.training.checkpointable.util import object_metadata diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py index 4d3d5312993740636709cb732c0b8e3e2626262d..242c1e8ba45e0b2f6f9a1a51695b824546382666 100644 --- a/tensorflow/contrib/checkpoint/python/containers.py +++ b/tensorflow/contrib/checkpoint/python/containers.py @@ -35,9 +35,9 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure): self.slotdeps = tf.contrib.checkpoint.UniqueNameTracker() slotdeps = self.slotdeps slots = [] - slots.append(slotdeps.track(tfe.Variable(3.), "x")) # Named "x" - slots.append(slotdeps.track(tfe.Variable(4.), "y")) - slots.append(slotdeps.track(tfe.Variable(5.), "x")) # Named "x_1" + slots.append(slotdeps.track(tf.Variable(3.), "x")) # Named "x" + slots.append(slotdeps.track(tf.Variable(4.), "y")) + slots.append(slotdeps.track(tf.Variable(5.), "x")) # Named "x_1" ``` """ diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py index 64d056bd689a14c0c58d7a0f75c833c71b00a5c3..ac85c7be803cd4c2f8ba19d3ef887a3c65a15933 100644 --- a/tensorflow/contrib/checkpoint/python/containers_test.py +++ b/tensorflow/contrib/checkpoint/python/containers_test.py @@ -26,6 +26,7 @@ from tensorflow.python.keras import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import data_structures from tensorflow.python.training.checkpointable import tracking from tensorflow.python.training.checkpointable import util @@ -79,7 +80,7 @@ class UniqueNameTrackerTests(test.TestCase): resource_variable_ops.ResourceVariable(4.), "y")) slots.append(slotdeps.track( resource_variable_ops.ResourceVariable(5.), "x")) - self.slots = slots + self.slots = data_structures.NoDependency(slots) manager = SlotManager() self.evaluate([v.initializer for v in manager.slots]) diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index 1a7a3759baa4a5559b4b70ff4f7467c41da9111f..523a9efcf05f5d32589f6e1734f866bf8b4b9cdc 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -50,6 +50,7 @@ py_library( deps = [ ":gen_bigquery_reader_ops", ":gen_gcs_config_ops", + "//tensorflow/contrib/bigtable", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:io_ops", "//tensorflow/python:util", diff --git a/tensorflow/contrib/cloud/README.md b/tensorflow/contrib/cloud/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a80d8965f3b562cadaff8caad8d40c7b98afa78f --- /dev/null +++ b/tensorflow/contrib/cloud/README.md @@ -0,0 +1,18 @@ +# Cloud # + +## Cloud Bigtable ## + +[Google Cloud Bigtable](https://cloud.google.com/bigtable/) is a high +performance storage system that can store and serve training data. This contrib +package contains an experimental integration with TensorFlow. + +> **Status: Highly experimental.** The current implementation is very much in +> flux. Please use at your own risk! :-) + + + +## Cloud Storage (GCS) ## + +The Google Cloud Storage ops allow the user to configure the GCS File System. + + diff --git a/tensorflow/contrib/cloud/__init__.py b/tensorflow/contrib/cloud/__init__.py index ef7aa7624ce7b9b6480c4d088a2fb7678a7acc76..8efd259946b7696e66b83a3b0aa451543c107467 100644 --- a/tensorflow/contrib/cloud/__init__.py +++ b/tensorflow/contrib/cloud/__init__.py @@ -18,15 +18,24 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=line-too-long,wildcard-import +import os + +# pylint: disable=line-too-long,wildcard-import,g-import-not-at-top from tensorflow.contrib.cloud.python.ops.bigquery_reader_ops import * from tensorflow.contrib.cloud.python.ops.gcs_config_ops import * -# pylint: enable=line-too-long,wildcard-import + +if os.name != 'nt': + from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableClient + from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableTable + +del os from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'BigQueryReader', + 'BigtableClient', + 'BigtableTable', 'BlockCacheParams', 'configure_colab_session', 'configure_gcs', diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc index 1bfd27305d569668a0bd67d876e59eec082296b3..58fadffce32f9a8fec047d1e99f9f4eb5a710d91 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc @@ -85,7 +85,7 @@ Status BigQueryTableAccessor::New( int64 timestamp_millis, int64 row_buffer_size, const string& end_point, const std::vector& columns, const BigQueryTablePartition& partition, std::unique_ptr auth_provider, - std::unique_ptr http_request_factory, + std::shared_ptr http_request_factory, std::unique_ptr* accessor) { if (timestamp_millis <= 0) { return errors::InvalidArgument( @@ -94,29 +94,19 @@ Status BigQueryTableAccessor::New( const string& big_query_end_point = end_point.empty() ? kBigQueryEndPoint : end_point; if (auth_provider == nullptr && http_request_factory == nullptr) { - accessor->reset(new BigQueryTableAccessor( - project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, - big_query_end_point, columns, partition)); - } else { - accessor->reset(new BigQueryTableAccessor( - project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, - big_query_end_point, columns, partition, std::move(auth_provider), - std::move(http_request_factory))); + http_request_factory = std::make_shared(); + auto compute_engine_metadata_client = + std::make_shared(http_request_factory); + auth_provider = std::unique_ptr( + new GoogleAuthProvider(compute_engine_metadata_client)); } - return (*accessor)->ReadSchema(); -} -BigQueryTableAccessor::BigQueryTableAccessor( - const string& project_id, const string& dataset_id, const string& table_id, - int64 timestamp_millis, int64 row_buffer_size, const string& end_point, - const std::vector& columns, const BigQueryTablePartition& partition) - : BigQueryTableAccessor( - project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, - end_point, columns, partition, - std::unique_ptr(new GoogleAuthProvider()), - std::unique_ptr( - new CurlHttpRequest::Factory())) { - row_buffer_.resize(row_buffer_size); + accessor->reset(new BigQueryTableAccessor( + project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, + big_query_end_point, columns, partition, std::move(auth_provider), + std::move(http_request_factory))); + + return (*accessor)->ReadSchema(); } BigQueryTableAccessor::BigQueryTableAccessor( @@ -124,7 +114,7 @@ BigQueryTableAccessor::BigQueryTableAccessor( int64 timestamp_millis, int64 row_buffer_size, const string& end_point, const std::vector& columns, const BigQueryTablePartition& partition, std::unique_ptr auth_provider, - std::unique_ptr http_request_factory) + std::shared_ptr http_request_factory) : project_id_(project_id), dataset_id_(dataset_id), table_id_(table_id), diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h index b349063715c903c982cfe2fb116b6525e35ff63b..1af43a3e1070d466bb50019f12b22a060c1e6ab1 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h @@ -109,24 +109,17 @@ class BigQueryTableAccessor { const std::vector& columns, const BigQueryTablePartition& partition, std::unique_ptr auth_provider, - std::unique_ptr http_request_factory, + std::shared_ptr http_request_factory, std::unique_ptr* accessor); /// \brief Constructs an object for a given table and partition. - BigQueryTableAccessor(const string& project_id, const string& dataset_id, - const string& table_id, int64 timestamp_millis, - int64 row_buffer_size, const string& end_point, - const std::vector& columns, - const BigQueryTablePartition& partition); - - /// Used for unit testing. BigQueryTableAccessor( const string& project_id, const string& dataset_id, const string& table_id, int64 timestamp_millis, int64 row_buffer_size, const string& end_point, const std::vector& columns, const BigQueryTablePartition& partition, std::unique_ptr auth_provider, - std::unique_ptr http_request_factory); + std::shared_ptr http_request_factory); /// \brief Parses column values for a given row. Status ParseColumnValues(const Json::Value& value, @@ -199,7 +192,7 @@ class BigQueryTableAccessor { SchemaNode schema_root_; std::unique_ptr auth_provider_; - std::unique_ptr http_request_factory_; + std::shared_ptr http_request_factory_; TF_DISALLOW_COPY_AND_ASSIGN(BigQueryTableAccessor); }; diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD index c239e6f8f960910cee14e1df7c4678c643496f54..707f6211846ca0310bde297603928e9ec5bb471c 100644 --- a/tensorflow/contrib/cluster_resolver/BUILD +++ b/tensorflow/contrib/cluster_resolver/BUILD @@ -12,6 +12,15 @@ licenses(["notice"]) # Apache 2.0 py_library( name = "cluster_resolver_pip", + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":cluster_resolver_py", + ], +) + +py_library( + name = "cluster_resolver_py", srcs = [ "__init__.py", "python/training/__init__.py", @@ -19,7 +28,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - ":cluster_resolver_py", + ":base_cluster_resolver_py", ":gce_cluster_resolver_py", ":tpu_cluster_resolver_py", "//tensorflow/python:util", @@ -27,7 +36,7 @@ py_library( ) py_library( - name = "cluster_resolver_py", + name = "base_cluster_resolver_py", srcs = ["python/training/cluster_resolver.py"], srcs_version = "PY2AND3", deps = [ @@ -40,7 +49,7 @@ py_library( srcs = ["python/training/gce_cluster_resolver.py"], srcs_version = "PY2AND3", deps = [ - ":cluster_resolver_py", + ":base_cluster_resolver_py", "//tensorflow/python:training", ], ) @@ -50,13 +59,13 @@ py_library( srcs = ["python/training/tpu_cluster_resolver.py"], srcs_version = "PY2AND3", deps = [ - ":cluster_resolver_py", + ":base_cluster_resolver_py", "//tensorflow/python:training", ], ) tf_py_test( - name = "cluster_resolver_py_test", + name = "base_cluster_resolver_py_test", srcs = ["python/training/cluster_resolver_test.py"], additional_deps = [ ":cluster_resolver_py", diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 8f521ffee4d31e090c13bac98290656d6e1d330e..1ab150d74ac00c5f9acf3c9399880708b2f62b1e 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -148,6 +148,9 @@ class TPUClusterResolver(ClusterResolver): else: tpu = self._envVarFallback() + if tpu is None: + raise ValueError('Please provide a TPU Name to connect to.') + self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes self._job_name = job_name self._credentials = credentials @@ -259,11 +262,11 @@ class TPUClusterResolver(ClusterResolver): if 'state' in response and response['state'] != 'READY': raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' % - (self._tpu, response['state'])) + (compat.as_text(self._tpu), response['state'])) if 'health' in response and response['health'] != 'HEALTHY': - raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu, - response['health'])) + raise RuntimeError('TPU "%s" is unhealthy: "%s"' % + (compat.as_text(self._tpu), response['health'])) if 'networkEndpoints' in response: worker_list = [ diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index a0a5b0e00c1979ebf8850408785135b9ceac7d2a..f6c928e2be62e7292c6feaa3bb26fd463320158b 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -145,26 +145,41 @@ if(WIN32) # temporary fix for #18241 add_definitions(-DEIGEN_DEFAULT_DENSE_INDEX_TYPE=std::int64_t) endif() - add_definitions(-DNOMINMAX -D_WIN32_WINNT=0x0A00 -DLANG_CXX11) - add_definitions(-DWIN32 -DOS_WIN -D_MBCS -DWIN32_LEAN_AND_MEAN -DNOGDI -DPLATFORM_WINDOWS) + add_definitions(-DNOMINMAX -D_WIN32_WINNT=0x0A00) + add_definitions(-DWIN32_LEAN_AND_MEAN -DNOGDI -DPLATFORM_WINDOWS) add_definitions(-DTENSORFLOW_USE_EIGEN_THREADPOOL -DEIGEN_HAS_C99_MATH) add_definitions(-DTF_COMPILE_LIBRARY) - add_definitions(/bigobj /nologo /EHsc /GF /MP /Gm-) + add_compile_options(/bigobj /GF /MP /Gm-) # Suppress warnings to reduce build log size. - add_definitions(/wd4267 /wd4244 /wd4800 /wd4503 /wd4554 /wd4996 /wd4348 /wd4018) - add_definitions(/wd4099 /wd4146 /wd4267 /wd4305 /wd4307) - add_definitions(/wd4715 /wd4722 /wd4723 /wd4838 /wd4309 /wd4334) - add_definitions(/wd4003 /wd4244 /wd4267 /wd4503 /wd4506 /wd4800 /wd4996) + add_compile_options(/wd4267 /wd4244 /wd4800 /wd4503 /wd4554 /wd4996 /wd4348 /wd4018) + add_compile_options(/wd4099 /wd4146 /wd4267 /wd4305 /wd4307) + add_compile_options(/wd4715 /wd4722 /wd4723 /wd4838 /wd4309 /wd4334) + add_compile_options(/wd4003 /wd4244 /wd4267 /wd4503 /wd4506 /wd4800 /wd4996) # Suppress linker warnings. set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} /ignore:4049 /ignore:4197 /ignore:4217 /ignore:4221") set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} /ignore:4049 /ignore:4197 /ignore:4217 /ignore:4221") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /ignore:4049 /ignore:4197 /ignore:4217 /ignore:4221") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP") set(CMAKE_CXX_FLAGS_DEBUG "/D_DEBUG /MDd /Ob2") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /D_ITERATOR_DEBUG_LEVEL=0") set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} /D_ITERATOR_DEBUG_LEVEL=0") set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} /D_ITERATOR_DEBUG_LEVEL=0") + set(compiler_flags + CMAKE_CXX_FLAGS + CMAKE_CXX_FLAGS_DEBUG + CMAKE_CXX_FLAGS_RELEASE + CMAKE_C_FLAGS + CMAKE_C_FLAGS_DEBUG + CMAKE_C_FLAGS_RELEASE + ) + # No exception + foreach(flag ${compiler_flags}) + string(REPLACE "/EHsc" "/EHs-c-" ${flag} "${${flag}}") + endforeach() + add_definitions(/D_HAS_EXCEPTIONS=0) + # Suppress 'noexcept used with no exception handling mode specified' warning + add_compile_options(/wd4577) + # Try to avoid flaky failures due to failed generation of generate.stamp files. set(CMAKE_SUPPRESS_REGENERATION ON) endif() @@ -379,16 +394,20 @@ if (tensorflow_ENABLE_GPU) # by default we assume compute cabability 3.5 and 5.2. If you change this change it in # CUDA_NVCC_FLAGS and cuda_config.h below - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_30,code=\"sm_30,compute_30\";-gencode arch=compute_35,code=\"sm_35,compute_35\";-gencode arch=compute_52,code=\"sm_52,compute_52\") + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_37,code=\"sm_37,compute_37\") + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_52,code=\"sm_52,compute_52\") + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_60,code=\"sm_60,compute_60\") + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_61,code=\"sm_61,compute_61\") + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_70,code=\"sm_70,compute_70\") set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};--include-path ${PROJECT_BINARY_DIR}/$\{build_configuration\};--expt-relaxed-constexpr) set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-ftz=true) # Flush denormals to zero set(CUDA_INCLUDE ${CUDA_TOOLKIT_TARGET_DIR} ${CUDA_TOOLKIT_TARGET_DIR}/extras/CUPTI/include) include_directories(${CUDA_INCLUDE}) if (WIN32) - add_definitions(-DGOOGLE_CUDA=1 -DTF_EXTRA_CUDA_CAPABILITIES=3.0,3.5,5.2) + add_definitions(-DGOOGLE_CUDA=1 -DTF_EXTRA_CUDA_CAPABILITIES=3.7,5.2,6.0,6.1,7.0) else (WIN32) - # Without these double quotes, cmake in Linux makes it "-DTF_EXTRA_CUDA_CAPABILITIES=3.0, -D3.5, -D5.2" for cc, which incurs build breaks - add_definitions(-DGOOGLE_CUDA=1 -D"TF_EXTRA_CUDA_CAPABILITIES=3.0,3.5,5.2") + # Without these double quotes, cmake in Linux makes it "-DTF_EXTRA_CUDA_CAPABILITIES=3.7, -D5.2, ..." for cc, which incurs build breaks + add_definitions(-DGOOGLE_CUDA=1 -D"TF_EXTRA_CUDA_CAPABILITIES=3.7,5.2,6.0,6.1,7.0") endif (WIN32) if (WIN32) @@ -437,7 +456,7 @@ if (tensorflow_ENABLE_GPU) FILE(WRITE ${tensorflow_source_dir}/third_party/gpus/cuda/cuda_config.h "#ifndef CUDA_CUDA_CONFIG_H_\n" "#define CUDA_CUDA_CONFIG_H_\n" - "#define TF_CUDA_CAPABILITIES CudaVersion(\"3.0\"),CudaVersion(\"3.5\"),CudaVersion(\"5.2\")\n" + "#define TF_CUDA_CAPABILITIES CudaVersion(\"3.7\"),CudaVersion(\"5.2\"),CudaVersion(\"6.0\"),CudaVersion(\"6.1\"),CudaVersion(\"7.0\")\n" "#define TF_CUDA_VERSION \"64_${short_CUDA_VER}\"\n" "#define TF_CUDNN_VERSION \"64_${tensorflow_CUDNN_VERSION}\"\n" "#define TF_CUDA_TOOLKIT_PATH \"${CUDA_TOOLKIT_ROOT_DIR}\"\n" @@ -452,7 +471,6 @@ if (tensorflow_ENABLE_GPU) ${CUDA_TOOLKIT_TARGET_DIR}/include/cuComplex.h ${CUDA_TOOLKIT_TARGET_DIR}/include/cublas_v2.h ${CUDA_TOOLKIT_TARGET_DIR}/include/cusolverDn.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_fp16.h ${CUDA_TOOLKIT_TARGET_DIR}/include/device_functions.h ${CUDA_TOOLKIT_TARGET_DIR}/include/cufft.h ${CUDA_TOOLKIT_TARGET_DIR}/include/curand.h diff --git a/tensorflow/contrib/cmake/external/eigen.cmake b/tensorflow/contrib/cmake/external/eigen.cmake index 45a0096085cc2a6332c82e1ea284812acdd45152..33bb31148d2e5b7ca177d7c30b7781e8f620c3cb 100644 --- a/tensorflow/contrib/cmake/external/eigen.cmake +++ b/tensorflow/contrib/cmake/external/eigen.cmake @@ -19,6 +19,12 @@ # build_file = "eigen.BUILD", #) +option(eigen_PATCH_FILE "Patch file to apply to eigen" OFF) +set(eigen_PATCH_COMMAND "") +if(eigen_PATCH_FILE) + set(eigen_PATCH_COMMAND PATCH_COMMAND patch -p0 -i "${eigen_PATCH_FILE}") +endif(eigen_PATCH_FILE) + include (ExternalProject) # We parse the current Eigen version and archive hash from the bazel configuration @@ -45,6 +51,7 @@ ExternalProject_Add(eigen URL ${eigen_URL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" INSTALL_DIR "${eigen_INSTALL}" + ${eigen_PATCH_COMMAND} CMAKE_CACHE_ARGS -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF diff --git a/tensorflow/contrib/cmake/external/highwayhash.cmake b/tensorflow/contrib/cmake/external/highwayhash.cmake index a6e8a38d8c2ee3deb5453c264e0c5eb23248301f..7d260b85f21e7e56e153daf550c81155e4b68777 100644 --- a/tensorflow/contrib/cmake/external/highwayhash.cmake +++ b/tensorflow/contrib/cmake/external/highwayhash.cmake @@ -20,14 +20,6 @@ set(highwayhash_TAG be5edafc2e1a455768e260ccd68ae7317b6690ee) set(highwayhash_BUILD ${CMAKE_CURRENT_BINARY_DIR}/highwayhash/src/highwayhash) set(highwayhash_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/highwayhash/install) -# put highwayhash includes in the directory where they are expected -add_custom_target(highwayhash_create_destination_dir - COMMAND ${CMAKE_COMMAND} -E make_directory ${highwayhash_INCLUDE_DIR}/highwayhash - DEPENDS highwayhash) - -add_custom_target(highwayhash_copy_headers_to_destination - DEPENDS highwayhash_create_destination_dir) - if(WIN32) set(highwayhash_HEADERS "${highwayhash_BUILD}/highwayhash/*.h") set(highwayhash_STATIC_LIBRARIES ${highwayhash_INSTALL}/lib/highwayhash.lib) @@ -36,6 +28,20 @@ else() set(highwayhash_STATIC_LIBRARIES ${highwayhash_INSTALL}/lib/libhighwayhash.a) endif() +set(highwayhash_HEADERS + "${highwayhash_INSTALL}/include/code_annotation.h" + "${highwayhash_INSTALL}/include/highway_tree_hash.h" + "${highwayhash_INSTALL}/include/scalar_highway_tree_hash.h" + "${highwayhash_INSTALL}/include/scalar_sip_tree_hash.h" + "${highwayhash_INSTALL}/include/sip_hash.h" + "${highwayhash_INSTALL}/include/sip_tree_hash.h" + "${highwayhash_INSTALL}/include/sse41_highway_tree_hash.h" + "${highwayhash_INSTALL}/include/state_helpers.h" + "${highwayhash_INSTALL}/include/types.h" + "${highwayhash_INSTALL}/include/vec.h" + "${highwayhash_INSTALL}/include/vec2.h" +) + ExternalProject_Add(highwayhash PREFIX highwayhash GIT_REPOSITORY ${highwayhash_URL} @@ -50,5 +56,15 @@ ExternalProject_Add(highwayhash -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${highwayhash_INSTALL}) -add_custom_command(TARGET highwayhash_copy_headers_to_destination PRE_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_directory ${highwayhash_INSTALL}/include/ ${highwayhash_INCLUDE_DIR}/highwayhash) +# put highwayhash includes in the directory where they are expected +add_custom_target(highwayhash_create_destination_dir + COMMAND ${CMAKE_COMMAND} -E make_directory ${highwayhash_INCLUDE_DIR}/highwayhash + DEPENDS highwayhash) + +add_custom_target(highwayhash_copy_headers_to_destination + DEPENDS highwayhash_create_destination_dir) + +foreach(header_file ${highwayhash_HEADERS}) + add_custom_command(TARGET highwayhash_copy_headers_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${highwayhash_INCLUDE_DIR}/highwayhash/) +endforeach() diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index eba3bcfc79efe87d0a45c979c5accfa1b6511ed0..1d638e64023c7e2706d8d97ff8679677b6cd289d 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -20,14 +20,6 @@ set(nsync_TAG 1.20.0) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) -# put nsync includes in the directory where they are expected -add_custom_target(nsync_create_destination_dir - COMMAND ${CMAKE_COMMAND} -E make_directory ${nsync_INCLUDE_DIR} - DEPENDS nsync) - -add_custom_target(nsync_copy_headers_to_destination - DEPENDS nsync_create_destination_dir) - if(WIN32) set(nsync_HEADERS "${nsync_BUILD}/public/*.h") set(nsync_STATIC_LIBRARIES ${nsync_INSTALL}/lib/nsync.lib) @@ -49,7 +41,35 @@ ExternalProject_Add(nsync -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${nsync_INSTALL} - -DNSYNC_LANGUAGE:STRING=c++11) + -DNSYNC_LANGUAGE:STRING=c++11) + +set(nsync_HEADERS + "${nsync_INSTALL}/include/nsync.h" + "${nsync_INSTALL}/include/nsync_atomic.h" + "${nsync_INSTALL}/include/nsync_counter.h" + "${nsync_INSTALL}/include/nsync_cpp.h" + "${nsync_INSTALL}/include/nsync_cv.h" + "${nsync_INSTALL}/include/nsync_debug.h" + "${nsync_INSTALL}/include/nsync_mu.h" + "${nsync_INSTALL}/include/nsync_mu_wait.h" + "${nsync_INSTALL}/include/nsync_note.h" + "${nsync_INSTALL}/include/nsync_once.h" + "${nsync_INSTALL}/include/nsync_time.h" + "${nsync_INSTALL}/include/nsync_time_internal.h" + "${nsync_INSTALL}/include/nsync_waiter.h" +) + +# put nsync includes in the directory where they are expected +add_custom_target(nsync_create_destination_dir + COMMAND ${CMAKE_COMMAND} -E make_directory ${nsync_INCLUDE_DIR} + DEPENDS nsync) + +add_custom_target(nsync_copy_headers_to_destination + DEPENDS nsync_create_destination_dir) + +foreach(header_file ${nsync_HEADERS}) + add_custom_command(TARGET nsync_copy_headers_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${nsync_INCLUDE_DIR}/) +endforeach() + -add_custom_command(TARGET nsync_copy_headers_to_destination PRE_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_directory ${nsync_INSTALL}/include/ ${nsync_INCLUDE_DIR}/) diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index d530572e91825ed88d09c26a10693288878d09ed..9045290679b87a201df8b930df6ff9a4ec106dcf 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -14,6 +14,7 @@ tensorflow/examples/tutorials tensorflow/examples/tutorials/mnist tensorflow/python tensorflow/python/client +tensorflow/python/compat tensorflow/python/data tensorflow/python/data/ops tensorflow/python/data/util @@ -61,6 +62,8 @@ tensorflow/python/saved_model tensorflow/python/summary tensorflow/python/summary/writer tensorflow/python/tools +tensorflow/python/tools/api +tensorflow/python/tools/api/generator tensorflow/python/training tensorflow/python/training/checkpointable tensorflow/python/user_ops @@ -68,7 +71,6 @@ tensorflow/python/util tensorflow/python/util/protobuf tensorflow/tools tensorflow/tools/api -tensorflow/tools/api/generator tensorflow/tools/graph_transforms tensorflow/contrib tensorflow/contrib/all_reduce @@ -86,6 +88,8 @@ tensorflow/contrib/batching/python/ops tensorflow/contrib/bayesflow tensorflow/contrib/bayesflow/python tensorflow/contrib/bayesflow/python/ops +# tensorflow/contrib/bigtable/python +# tensorflow/contrib/bigtable/python/ops tensorflow/contrib/boosted_trees tensorflow/contrib/boosted_trees/estimator_batch tensorflow/contrib/boosted_trees/kernels @@ -111,7 +115,6 @@ tensorflow/contrib/coder tensorflow/contrib/coder/kernels tensorflow/contrib/coder/ops tensorflow/contrib/coder/python -tensorflow/contrib/coder/python/layers tensorflow/contrib/coder/python/ops tensorflow/contrib/compiler tensorflow/contrib/constrained_optimization @@ -238,6 +241,8 @@ tensorflow/contrib/keras/api/keras/wrappers/scikit_learn tensorflow/contrib/kernel_methods tensorflow/contrib/kernel_methods/python tensorflow/contrib/kernel_methods/python/mappers +tensorflow/contrib/kinesis/python +tensorflow/contrib/kinesis/python/ops tensorflow/contrib/kfac tensorflow/contrib/kfac/examples tensorflow/contrib/kfac/python diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 844f62649d970506f1b4b4c5718fab8d1f0856e1..7b892ba248bc43cd885f295288c677ac97efaa06 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -68,6 +68,7 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/csv_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc" diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index e3b59001bcb4f081eb2db3443ee9ad714c822ac8..6d86daf5f174a3238ab92e5bba6085c904766766 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -198,7 +198,7 @@ function(add_python_module MODULE_NAME) # so we currently add explicit commands to include those files # later on in this script. if (NOT "${script}" MATCHES "_test\.py$") - add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD + add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/${script} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/${script}) endif() endforeach() @@ -297,7 +297,7 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name) ) target_link_libraries(${tf_python_op_lib_name}_gen_python PRIVATE tf_protos_cc - tf_python_protos_cc + tf_python_protos_cc ${tensorflow_EXTERNAL_LIBRARIES} ) @@ -549,15 +549,15 @@ if(WIN32) ${NUMPY_INCLUDE_DIR} ) #target_link_libraries(pywrap_tensorflow_internal_static - # tf_protos_cc - # tf_python_protos_cc + # tf_protos_cc + # tf_python_protos_cc #) add_dependencies(pywrap_tensorflow_internal_static tf_protos_cc tf_python_protos_cc) set(pywrap_tensorflow_internal_static_dependencies $ $ $ - ${nsync_STATIC_LIBRARIES} + ${nsync_STATIC_LIBRARIES} ) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") @@ -736,8 +736,8 @@ endif() # Generate API __init__.py files. ######################################################## -# Parse tensorflow/tools/api/generator/BUILD to get list of generated files. -FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text) +# Parse tensorflow/python/tools/api/generator/BUILD to get list of generated files. +FILE(READ ${tensorflow_source_dir}/tensorflow/python/tools/api/generator/api_init_files.bzl api_generator_BUILD_text) STRING(REGEX MATCH "# BEGIN GENERATED FILES.*# END GENERATED FILES" api_init_files_text ${api_generator_BUILD_text}) string(REPLACE "# BEGIN GENERATED FILES" "" api_init_files_text ${api_init_files_text}) string(REPLACE "# END GENERATED FILES" "" api_init_files_text ${api_init_files_text}) @@ -763,57 +763,40 @@ file(WRITE "${api_init_list_file}" "${api_init_files}") # recongnize paths. As CUDA isn't built with MKL, the MKL built directory is the only path to this command to work around that issue. # To not override the CUDA and system path in other circumstances, `if-else` branch used here to handle this problem, # and should be removed if the path issue can be resolved. +# UPDATE: Below block appears to handle multiple items in PATH correctly, but risks command line limits if PATH is large. +# If you have issues, try `set(PY_RUNTIME_ENV "PATH=${mkl_BIN_DIRS}")` instead. ### -if (tensorflow_ENABLE_MKL_SUPPORT) +set(PY_RUNTIME_ENV "") +if(tensorflow_ENABLE_MKL_SUPPORT) # add mkl dist dlls to system path for python - # TODO: In current cmake version, PY_RUNTIME_ENV behaves strange with multiple paths, - # so we have to specify only one path in it to work around the issue. We need this if/else - # to protect overwriting CUDA environments - set(PY_RUNTIME_ENV ${mkl_BIN_DIRS}) - add_custom_command( - OUTPUT ${api_init_files} - DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops - - # tensorflow/__init__.py depends on files generated in this step. So, remove it while - # this step is running since the files aren't there yet. - COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py - - # Run create_python_api.py to generate API init files. - COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python PATH=${PY_RUNTIME_ENV} ${PYTHON_EXECUTABLE} - "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" - "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" - "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" - "--package=tensorflow.python" - "--apiname=tensorflow" - "${api_init_list_file}" - - COMMENT "Generating __init__.py files for Python API." - WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" - VERBATIM - ) -else (tensorflow_ENABLE_MKL_SUPPORT) - add_custom_command( - OUTPUT ${api_init_files} - DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops - - # tensorflow/__init__.py depends on files generated in this step. So, remove it while - # this step is running since the files aren't there yet. - COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py - - # Run create_python_api.py to generate API init files. - COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} - "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" - "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" - "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" - "--package=tensorflow.python" - "--apiname=tensorflow" - "${api_init_list_file}" - - COMMENT "Generating __init__.py files for Python API." - WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" - ) -endif (tensorflow_ENABLE_MKL_SUPPORT) + file(TO_CMAKE_PATH "$ENV{PATH}" PY_RUNTIME_ENV) + set(PY_RUNTIME_ENV ${mkl_BIN_DIRS} ${PY_RUNTIME_ENV}) + file(TO_NATIVE_PATH "${PY_RUNTIME_ENV}" PY_RUNTIME_ENV) + set(PY_RUNTIME_ENV "PATH=${PY_RUNTIME_ENV}") +endif(tensorflow_ENABLE_MKL_SUPPORT) + +add_custom_command( + OUTPUT ${api_init_files} + DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops + + # tensorflow/__init__.py depends on files generated in this step. So, remove it while + # this step is running since the files aren't there yet. + COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + + # Run create_python_api.py to generate API init files. + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python "${PY_RUNTIME_ENV}" ${PYTHON_EXECUTABLE} + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/tools/api/generator/create_python_api.py" + "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" + "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" + "--package=tensorflow.python" + "--apiname=tensorflow" + "${api_init_list_file}" + + COMMENT "Generating __init__.py files for Python API." + WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" + VERBATIM +) add_custom_target(tf_python_api SOURCES ${api_init_files}) add_dependencies(tf_python_api tf_python_ops) @@ -824,8 +807,8 @@ add_dependencies(tf_python_api tf_python_ops) # Generate API __init__.py files for tf.estimator. ######################################################## -# Parse tensorflow/tools/api/generator/BUILD to get list of generated files. -FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text) +# Parse tensorflow/python/tools/api/generator/BUILD to get list of generated files. +FILE(READ ${tensorflow_source_dir}/tensorflow/python/tools/api/generator/api_gen.bzl api_generator_BUILD_text) STRING(REGEX MATCH "# BEGIN GENERATED ESTIMATOR FILES.*# END GENERATED ESTIMATOR FILES" api_init_files_text ${api_generator_BUILD_text}) string(REPLACE "# BEGIN GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) string(REPLACE "# END GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) @@ -848,11 +831,12 @@ add_custom_command( DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops # Run create_python_api.py to generate API init files. - COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} - "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python "${PY_RUNTIME_ENV}" ${PYTHON_EXECUTABLE} + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/tools/api/generator/create_python_api.py" "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api" "--package=tensorflow.python.estimator" "--apiname=estimator" + "--output_package=tensorflow.python.estimator.api" "${estimator_api_init_list_file}" COMMENT "Generating __init__.py files for Python API." diff --git a/tensorflow/contrib/cmake/tf_stream_executor.cmake b/tensorflow/contrib/cmake/tf_stream_executor.cmake index 2f70e59d54d9315d1079b268ef8543599c168d9e..6d634cb1709910f366c7ca538d28bd802b2a7c63 100644 --- a/tensorflow/contrib/cmake/tf_stream_executor.cmake +++ b/tensorflow/contrib/cmake/tf_stream_executor.cmake @@ -64,8 +64,6 @@ file(GLOB tf_stream_executor_srcs if (tensorflow_ENABLE_GPU) file(GLOB tf_stream_executor_gpu_srcs "${tensorflow_source_dir}/tensorflow/stream_executor/cuda/*.cc" - "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.h" - "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.cc" ) if (NOT tensorflow_BUILD_CC_TESTS) file(GLOB tf_stream_executor_gpu_tests diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index eb9482dc25f2be8ce46cc38bf3dd28889b09a9d4..2c878c17167c662d10a8c7dabf41687efdbf65d8 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -122,6 +122,17 @@ function(AddPythonTests) endforeach() endfunction(AddPythonTests) +# +# ensure that every element is an existing file +# +function(CheckExists TYPE SOURCES) + foreach(source ${SOURCES}) + if(NOT EXISTS ${source}) + message(SEND_ERROR "${TYPE} not found: ${source}") + endif() + endforeach(source) +endfunction(CheckExists) + if (tensorflow_BUILD_PYTHON_TESTS) # # python tests. This assumes that the tensorflow wheel is @@ -145,7 +156,6 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/debug/wrappers/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/estimator/python/estimator/*_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/*.py" - "${tensorflow_source_dir}/tensorflow/python/meta_graph_transform/*_test.py" "${tensorflow_source_dir}/tensorflow/python/ops/quantized_conv_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/ops/quantized_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/platform/build_info_test.py" @@ -193,11 +203,11 @@ if (tensorflow_BUILD_PYTHON_TESTS) # flaky test "${tensorflow_source_dir}/tensorflow/python/profiler/internal/run_metadata_test.py" "${tensorflow_source_dir}/tensorflow/python/profiler/model_analyzer_test.py" + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/map_dataset_op_test.py" # Fails because uses data dependencies with bazel "${tensorflow_source_dir}/tensorflow/python/saved_model/saved_model_test.py" "${tensorflow_source_dir}/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py" # requires scipy - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/preprocessing/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py" "${tensorflow_source_dir}/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py" # Takes very long to run without sharding (defined in bazel build file). @@ -216,7 +226,8 @@ if (tensorflow_BUILD_PYTHON_TESTS) ${tensorflow_source_dir}/tensorflow/python/kernel_tests/duplicate_op_test.py ${tensorflow_source_dir}/tensorflow/python/kernel_tests/invalid_op_test.py ${tensorflow_source_dir}/tensorflow/python/kernel_tests/ackermann_test.py - + # Tests too large to run. + ${tensorflow_source_dir}/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py ) if (WIN32) set(tf_test_src_py_exclude @@ -254,10 +265,9 @@ if (tensorflow_BUILD_PYTHON_TESTS) # Flaky because of local cluster creation. "${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" - "${tensorflow_source_dir}tensorflow/python/training/localhost_cluster_performance_test.py" + "${tensorflow_source_dir}/tensorflow/python/training/localhost_cluster_performance_test.py" "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/functional_ops_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Type error in testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU. "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py" @@ -327,6 +337,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/utils/io_utils_test.py" # b/72894325 ) endif() + CheckExists(${tf_test_src_py_exclude}) list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude}) AddPythonTests( @@ -478,6 +489,7 @@ if (tensorflow_BUILD_CC_TESTS) "${tensorflow_source_dir}/tensorflow/cc/saved_model/*_test.cc" ) + CheckExists(${tf_test_src_simple_exclude}) list(REMOVE_ITEM tf_test_src_simple ${tf_test_src_simple_exclude} ${tf_cc_saved_model_test_srcs} @@ -492,6 +504,7 @@ if (tensorflow_BUILD_CC_TESTS) ${tf_core_profiler_test_srcs} ) + CheckExists(${tf_src_testlib}) set(tf_test_lib tf_test_lib) add_library(${tf_test_lib} STATIC ${tf_src_testlib}) diff --git a/tensorflow/contrib/coder/BUILD b/tensorflow/contrib/coder/BUILD index a2c6e413039ee3b5af3cb53d1af3325037536d36..855c824ead2f7de4c37db2d2a3648a9ee00fb9e9 100644 --- a/tensorflow/contrib/coder/BUILD +++ b/tensorflow/contrib/coder/BUILD @@ -1,5 +1,5 @@ # Description: -# Contains tools related to data compression. +# Contains ops related to data compression. package(default_visibility = [ "//learning/brain:__subpackages__", @@ -168,7 +168,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":coder_ops_py", - ":entropybottleneck_py", ], ) @@ -205,44 +204,3 @@ tf_py_test( ], main = "python/ops/coder_ops_test.py", ) - -py_library( - name = "entropybottleneck_py", - srcs = [ - "python/layers/entropybottleneck.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":coder_ops_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:functional_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn", - "//tensorflow/python:ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:variable_scope", - "//tensorflow/python/eager:context", - "//tensorflow/python/keras:engine", - "//third_party/py/numpy", - ], -) - -tf_py_test( - name = "entropybottleneck_py_test", - srcs = [ - "python/layers/entropybottleneck_test.py", - ], - additional_deps = [ - ":entropybottleneck_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:variables", - "//tensorflow/python:training", - ], - main = "python/layers/entropybottleneck_test.py", -) diff --git a/tensorflow/contrib/coder/README.md b/tensorflow/contrib/coder/README.md deleted file mode 100644 index c6c379c458893551b765327c0c1cbfff7f24f9c3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/coder/README.md +++ /dev/null @@ -1,73 +0,0 @@ -# Entropy coder - -This module contains range encoder and range decoder which can encode integer -data into string with cumulative distribution functions (CDF). - -## Data and CDF values - -The data to be encoded should be non-negative integers in half-open interval -`[0, m)`. Then a CDF is represented as an integral vector of length `m + 1` -where `CDF(i) = f(Pr(X < i) * 2^precision)` for i = 0,1,...,m, and `precision` -is an attribute in range `0 < precision <= 16`. The function `f` maps real -values into integers, e.g., round or floor. It is important that to encode a -number `i`, `CDF(i + 1) - CDF(i)` cannot be zero. - -Note that we used `Pr(X < i)` not `Pr(X <= i)`, and therefore CDF(0) = 0 always. - -## RangeEncode: data shapes and CDF shapes - -For each data element, its CDF has to be provided. Therefore if the shape of CDF -should be `data.shape + (m + 1,)` in NumPy-like notation. For example, if `data` -is a 2-D tensor of shape (10, 10) and its elements are in `[0, 64)`, then the -CDF tensor should have shape (10, 10, 65). - -This may make CDF tensor too large, and in many applications all data elements -may have the same probability distribution. To handle this, `RangeEncode` -supports limited broadcasting CDF into data. Broadcasting is limited in the -following sense: - -- All CDF axes but the last one is broadcasted into data but not the other way - around, -- The number of CDF axes does not extend, i.e., `CDF.ndim == data.ndim + 1`. - -In the previous example where data has shape (10, 10), the following are -acceptable CDF shapes: - -- (10, 10, 65) -- (1, 10, 65) -- (10, 1, 65) -- (1, 1, 65) - -## RangeDecode - -`RangeEncode` encodes neither data shape nor termination character. Therefore -the decoder should know how many characters are encoded into the string, and -`RangeDecode` takes the encoded data shape as the second argument. The same -shape restrictions as `RangeEncode` inputs apply here. - -## Example - -```python -data = tf.random_uniform((128, 128), 0, 10, dtype=tf.int32) - -histogram = tf.bincount(data, minlength=10, maxlength=10) -cdf = tf.cumsum(histogram, exclusive=False) -# CDF should have length m + 1. -cdf = tf.pad(cdf, [[1, 0]]) -# CDF axis count must be one more than data. -cdf = tf.reshape(cdf, [1, 1, -1]) - -# Note that data has 2^14 elements, and therefore the sum of CDF is 2^14. -data = tf.cast(data, tf.int16) -encoded = coder.range_encode(data, cdf, precision=14) -decoded = coder.range_decode(encoded, tf.shape(data), cdf, precision=14) - -# data and decoded should be the same. -sess = tf.Session() -x, y = sess.run((data, decoded)) -assert np.all(x == y) -``` - -## Authors -Sung Jin Hwang (github: [ssjhv](https://github.com/ssjhv)) and Nick Johnston -(github: [nmjohn](https://github.com/nmjohn)) diff --git a/tensorflow/contrib/coder/__init__.py b/tensorflow/contrib/coder/__init__.py index 99b8ac7595ec632b2918e6b7ca22c06dd7f0a8b3..8897312046c63c42d85e7fba5b62d2ed908dd6e9 100644 --- a/tensorflow/contrib/coder/__init__.py +++ b/tensorflow/contrib/coder/__init__.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Data compression tools.""" +"""Data compression ops.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import -from tensorflow.contrib.coder.python.layers.entropybottleneck import * from tensorflow.contrib.coder.python.ops.coder_ops import * # pylint: enable=wildcard-import diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck.py b/tensorflow/contrib/coder/python/layers/entropybottleneck.py deleted file mode 100644 index 0c997bd4fdfa4233117c9fec2c4397301b1c8cb9..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/coder/python/layers/entropybottleneck.py +++ /dev/null @@ -1,697 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Entropy bottleneck layer.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.coder.python.ops import coder_ops - -from tensorflow.python.eager import context -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras.engine import base_layer -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import functional_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.summary import summary - - -class EntropyBottleneck(base_layer.Layer): - """Entropy bottleneck layer. - - This layer can be used to model the entropy (the amount of information - conveyed) of the tensor passing through it. During training, this can be used - to impose a (soft) entropy constraint on its activations, limiting the amount - of information flowing through the layer. Note that this is distinct from - other types of bottlenecks, which reduce the dimensionality of the space, for - example. Dimensionality reduction does not limit the amount of information, - and does not enable efficient data compression per se. - - After training, this layer can be used to compress any input tensor to a - string, which may be written to a file, and to decompress a file which it - previously generated back to a reconstructed tensor (possibly on a different - machine having access to the same model checkpoint). The entropies estimated - during training or evaluation are approximately equal to the average length of - the strings in bits. - - The layer implements a flexible probability density model to estimate entropy, - which is described in the appendix of the paper (please cite the paper if you - use this code for scientific work): - - "Variational image compression with a scale hyperprior" - - Johannes Ballé, David Minnen, Saurabh Singh, Sung Jin Hwang, Nick Johnston - - https://arxiv.org/abs/1802.01436 - - The layer assumes that the input tensor is at least 2D, with a batch dimension - at the beginning and a channel dimension as specified by `data_format`. The - layer trains an independent probability density model for each channel, but - assumes that across all other dimensions, the inputs are i.i.d. (independent - and identically distributed). Because the entropy (and hence, average - codelength) is a function of the densities, this assumption may have a direct - effect on the compression performance. - - Because data compression always involves discretization, the outputs of the - layer are generally only approximations of its inputs. During training, - discretization is modeled using additive uniform noise to ensure - differentiability. The entropies computed during training are differential - entropies. During evaluation, the data is actually quantized, and the - entropies are discrete (Shannon entropies). To make sure the approximated - tensor values are good enough for practical purposes, the training phase must - be used to balance the quality of the approximation with the entropy, by - adding an entropy term to the training loss, as in the following example. - - Here, we use the entropy bottleneck to compress the latent representation of - an autoencoder. The data vectors `x` in this case are 4D tensors in - `'channels_last'` format (for example, 16x16 pixel grayscale images). - - The layer always produces exactly one auxiliary loss and one update op which - are only significant for compression and decompression. To use the compression - feature, the auxiliary loss must be minimized during or after training. After - that, the update op must be executed at least once. Here, we simply attach - them to the main training step. - - Training: - ``` - # Build autoencoder. - x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1]) - y = forward_transform(x) - entropy_bottleneck = EntropyBottleneck() - y_, likelihoods = entropy_bottleneck(y, training=True) - x_ = backward_transform(y_) - - # Information content (= predicted codelength) in bits of each batch element - # (note that taking the natural logarithm and dividing by `log(2)` is - # equivalent to taking base-2 logarithms): - bits = tf.reduce_sum(tf.log(likelihoods), axis=(1, 2, 3)) / -np.log(2) - - # Squared difference of each batch element: - squared_error = tf.reduce_sum(tf.squared_difference(x, x_), axis=(1, 2, 3)) - - # The loss is a weighted sum of mean squared error and entropy (average - # information content), where the weight controls the trade-off between - # approximation error and entropy. - main_loss = 0.5 * tf.reduce_mean(squared_error) + tf.reduce_mean(bits) - - # Minimize loss and auxiliary loss, and execute update op. - main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4) - main_step = optimizer.minimize(main_loss) - # 1e-2 is a good starting point for the learning rate of the auxiliary loss, - # assuming Adam is used. - aux_optimizer = tf.train.AdamOptimizer(learning_rate=1e-2) - aux_step = optimizer.minimize(entropy_bottleneck.losses[0]) - step = tf.group(main_step, aux_step, entropy_bottleneck.updates[0]) - ``` - - Evaluation: - ``` - # Build autoencoder. - x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1]) - y = forward_transform(x) - y_, likelihoods = EntropyBottleneck()(y, training=False) - x_ = backward_transform(y_) - - # Information content (= predicted codelength) in bits of each batch element: - bits = tf.reduce_sum(tf.log(likelihoods), axis=(1, 2, 3)) / -np.log(2) - - # Squared difference of each batch element: - squared_error = tf.reduce_sum(tf.squared_difference(x, x_), axis=(1, 2, 3)) - - # The loss is a weighted sum of mean squared error and entropy (average - # information content), where the weight controls the trade-off between - # approximation error and entropy. - loss = 0.5 * tf.reduce_mean(squared_error) + tf.reduce_mean(bits) - ``` - - To be able to compress the bottleneck tensor and decompress it in a different - session, or on a different machine, you need three items: - - The compressed representations stored as strings. - - The shape of the bottleneck for these string representations as a `Tensor`, - as well as the number of channels of the bottleneck at graph construction - time. - - The checkpoint of the trained model that was used for compression. Note: - It is crucial that the auxiliary loss produced by this layer is minimized - during or after training, and that the update op is run after training and - minimization of the auxiliary loss, but *before* the checkpoint is saved. - - Compression: - ``` - x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1]) - y = forward_transform(x) - strings = EntropyBottleneck().compress(y) - shape = tf.shape(y)[1:] - ``` - - Decompression: - ``` - strings = tf.placeholder(tf.string, shape=[None]) - shape = tf.placeholder(tf.int32, shape=[3]) - entropy_bottleneck = EntropyBottleneck(dtype=tf.float32) - y_ = entropy_bottleneck.decompress(strings, shape, channels=5) - x_ = backward_transform(y_) - ``` - Here, we assumed that the tensor produced by the forward transform has 5 - channels. - - The above four use cases can also be implemented within the same session (i.e. - on the same `EntropyBottleneck` instance), for testing purposes, etc., by - calling the object more than once. - - Arguments: - init_scale: Float. A scaling factor determining the initial width of the - probability densities. This should be chosen big enough so that the - range of values of the layer inputs roughly falls within the interval - [`-init_scale`, `init_scale`] at the beginning of training. - filters: An iterable of ints, giving the number of filters at each layer of - the density model. Generally, the more filters and layers, the more - expressive is the density model in terms of modeling more complicated - distributions of the layer inputs. For details, refer to the paper - referenced above. The default is `[3, 3, 3]`, which should be sufficient - for most practical purposes. - tail_mass: Float, between 0 and 1. The bottleneck layer automatically - determines the range of input values that should be represented based on - their frequency of occurrence. Values occurring in the tails of the - distributions will be clipped to that range during compression. - `tail_mass` determines the amount of probability mass in the tails which - is cut off in the worst case. For example, the default value of `1e-9` - means that at most 1 in a billion input samples will be clipped to the - range. - optimize_integer_offset: Boolean. Typically, the input values of this layer - are floats, which means that quantization during evaluation can be - performed with an arbitrary offset. By default, the layer determines that - offset automatically. In special situations, such as when it is known that - the layer will receive only full integer values during evaluation, it can - be desirable to set this argument to `False` instead, in order to always - quantize to full integer values. - likelihood_bound: Float. If positive, the returned likelihood values are - ensured to be greater than or equal to this value. This prevents very - large gradients with a typical entropy loss (defaults to 1e-9). - range_coder_precision: Integer, between 1 and 16. The precision of the range - coder used for compression and decompression. This trades off computation - speed with compression efficiency, where 16 is the slowest but most - efficient setting. Choosing lower values may increase the average - codelength slightly compared to the estimated entropies. - data_format: Either `'channels_first'` or `'channels_last'` (default). - trainable: Boolean. Whether the layer should be trained. - name: String. The name of the layer. - dtype: Default dtype of the layer's parameters (default of `None` means use - the type of the first input). - - Read-only properties: - init_scale: See above. - filters: See above. - tail_mass: See above. - optimize_integer_offset: See above. - likelihood_bound: See above. - range_coder_precision: See above. - data_format: See above. - name: String. See above. - dtype: See above. - trainable_variables: List of trainable variables. - non_trainable_variables: List of non-trainable variables. - variables: List of all variables of this layer, trainable and non-trainable. - updates: List of update ops of this layer. Always contains exactly one - update op, which must be run once after the last training step, before - `compress` or `decompress` is used. - losses: List of losses added by this layer. Always contains exactly one - auxiliary loss, which must be added to the training loss. - - Mutable properties: - trainable: Boolean. Whether the layer should be trained. - input_spec: Optional `InputSpec` object specifying the constraints on inputs - that can be accepted by the layer. - """ - - def __init__(self, init_scale=10, filters=(3, 3, 3), tail_mass=1e-9, - optimize_integer_offset=True, likelihood_bound=1e-9, - range_coder_precision=16, data_format="channels_last", **kwargs): - super(EntropyBottleneck, self).__init__(**kwargs) - self._init_scale = float(init_scale) - self._filters = tuple(int(f) for f in filters) - self._tail_mass = float(tail_mass) - if not 0 < self.tail_mass < 1: - raise ValueError( - "`tail_mass` must be between 0 and 1, got {}.".format(self.tail_mass)) - self._optimize_integer_offset = bool(optimize_integer_offset) - self._likelihood_bound = float(likelihood_bound) - self._range_coder_precision = int(range_coder_precision) - self._data_format = data_format - self._channel_axis(2) # trigger ValueError early - self.input_spec = base_layer.InputSpec(min_ndim=2) - - @property - def init_scale(self): - return self._init_scale - - @property - def filters(self): - return self._filters - - @property - def tail_mass(self): - return self._tail_mass - - @property - def optimize_integer_offset(self): - return self._optimize_integer_offset - - @property - def likelihood_bound(self): - return self._likelihood_bound - - @property - def range_coder_precision(self): - return self._range_coder_precision - - @property - def data_format(self): - return self._data_format - - def _channel_axis(self, ndim): - try: - return {"channels_first": 1, "channels_last": ndim - 1}[self.data_format] - except KeyError: - raise ValueError("Unsupported `data_format` for {} layer: {}.".format( - self.__class__.__name__, self.data_format)) - - def _logits_cumulative(self, inputs, stop_gradient): - """Evaluate logits of the cumulative densities. - - Args: - inputs: The values at which to evaluate the cumulative densities, expected - to be a `Tensor` of shape `(channels, 1, batch)`. - stop_gradient: Boolean. Whether to add `array_ops.stop_gradient` calls so - that the gradient of the output with respect to the density model - parameters is disconnected (the gradient with respect to `inputs` is - left untouched). - - Returns: - A `Tensor` of the same shape as `inputs`, containing the logits of the - cumulative densities evaluated at the given inputs. - """ - logits = inputs - - for i in range(len(self.filters) + 1): - matrix = self._matrices[i] - if stop_gradient: - matrix = array_ops.stop_gradient(matrix) - logits = math_ops.matmul(matrix, logits) - - bias = self._biases[i] - if stop_gradient: - bias = array_ops.stop_gradient(bias) - logits += bias - - if i < len(self._factors): - factor = self._factors[i] - if stop_gradient: - factor = array_ops.stop_gradient(factor) - logits += factor * math_ops.tanh(logits) - - return logits - - def build(self, input_shape): - """Builds the layer. - - Creates the variables for the network modeling the densities, creates the - auxiliary loss estimating the median and tail quantiles of the densities, - and then uses that to create the probability mass functions and the update - op that produces the discrete cumulative density functions used by the range - coder. - - Args: - input_shape: Shape of the input tensor, used to get the number of - channels. - - Raises: - ValueError: if `input_shape` doesn't specify the length of the channel - dimension. - """ - input_shape = tensor_shape.TensorShape(input_shape) - channel_axis = self._channel_axis(input_shape.ndims) - channels = input_shape[channel_axis].value - if channels is None: - raise ValueError("The channel dimension of the inputs must be defined.") - self.input_spec = base_layer.InputSpec( - ndim=input_shape.ndims, axes={channel_axis: channels}) - filters = (1,) + self.filters + (1,) - scale = self.init_scale ** (1 / (len(self.filters) + 1)) - - # Create variables. - self._matrices = [] - self._biases = [] - self._factors = [] - for i in range(len(self.filters) + 1): - init = np.log(np.expm1(1 / scale / filters[i + 1])) - matrix = self.add_variable( - "matrix_{}".format(i), dtype=self.dtype, - shape=(channels, filters[i + 1], filters[i]), - initializer=init_ops.Constant(init)) - matrix = nn.softplus(matrix) - self._matrices.append(matrix) - - bias = self.add_variable( - "bias_{}".format(i), dtype=self.dtype, - shape=(channels, filters[i + 1], 1), - initializer=init_ops.RandomUniform(-.5, .5)) - self._biases.append(bias) - - if i < len(self.filters): - factor = self.add_variable( - "factor_{}".format(i), dtype=self.dtype, - shape=(channels, filters[i + 1], 1), - initializer=init_ops.Zeros()) - factor = math_ops.tanh(factor) - self._factors.append(factor) - - # To figure out what range of the densities to sample, we need to compute - # the quantiles given by `tail_mass / 2` and `1 - tail_mass / 2`. Since we - # can't take inverses of the cumulative directly, we make it an optimization - # problem: - # `quantiles = argmin(|logit(cumulative) - target|)` - # where `target` is `logit(tail_mass / 2)` or `logit(1 - tail_mass / 2)`. - # Taking the logit (inverse of sigmoid) of the cumulative makes the - # representation of the right target more numerically stable. - - # Numerically stable way of computing logits of `tail_mass / 2` - # and `1 - tail_mass / 2`. - target = np.log(2 / self.tail_mass - 1) - # Compute lower and upper tail quantile as well as median. - target = constant_op.constant([-target, 0, target], dtype=self.dtype) - - def quantiles_initializer(shape, dtype=None, partition_info=None): - del partition_info # unused - assert tuple(shape[1:]) == (1, 3) - init = constant_op.constant( - [[[-self.init_scale, 0, self.init_scale]]], dtype=dtype) - return array_ops.tile(init, (shape[0], 1, 1)) - - quantiles = self.add_variable( - "quantiles", shape=(channels, 1, 3), dtype=self.dtype, - initializer=quantiles_initializer) - logits = self._logits_cumulative(quantiles, stop_gradient=True) - loss = math_ops.reduce_sum(abs(logits - target)) - self.add_loss(loss, inputs=None) - - # Save medians for `call`, `compress`, and `decompress`. - self._medians = quantiles[:, :, 1:2] - if not self.optimize_integer_offset: - self._medians = math_ops.round(self._medians) - - # Largest distance observed between lower tail quantile and median, - # or between median and upper tail quantile. - minima = math_ops.reduce_max(self._medians - quantiles[:, :, 0:1]) - maxima = math_ops.reduce_max(quantiles[:, :, 2:3] - self._medians) - minmax = math_ops.maximum(minima, maxima) - minmax = math_ops.ceil(minmax) - minmax = math_ops.maximum(minmax, 1) - - # Sample the density up to `minmax` around the median. - samples = math_ops.range(-minmax, minmax + 1, dtype=self.dtype) - samples += self._medians - - half = constant_op.constant(.5, dtype=self.dtype) - # We strip the sigmoid from the end here, so we can use the special rule - # below to only compute differences in the left tail of the sigmoid. - # This increases numerical stability (see explanation in `call`). - lower = self._logits_cumulative(samples - half, stop_gradient=True) - upper = self._logits_cumulative(samples + half, stop_gradient=True) - # Flip signs if we can move more towards the left tail of the sigmoid. - sign = -math_ops.sign(math_ops.add_n([lower, upper])) - pmf = abs(math_ops.sigmoid(sign * upper) - math_ops.sigmoid(sign * lower)) - # Add tail masses to first and last bin of pmf, as we clip values for - # compression, meaning that out-of-range values get mapped to these bins. - pmf = array_ops.concat([ - math_ops.add_n([pmf[:, 0, :1], math_ops.sigmoid(lower[:, 0, :1])]), - pmf[:, 0, 1:-1], - math_ops.add_n([pmf[:, 0, -1:], math_ops.sigmoid(-upper[:, 0, -1:])]), - ], axis=-1) - self._pmf = pmf - - cdf = coder_ops.pmf_to_quantized_cdf( - pmf, precision=self.range_coder_precision) - def cdf_getter(*args, **kwargs): - del args, kwargs # ignored - return variable_scope.get_variable( - "quantized_cdf", dtype=dtypes.int32, initializer=cdf, - trainable=False, validate_shape=False, collections=()) - # Need to provide a fake shape here since add_variable insists on it. - self._quantized_cdf = self.add_variable( - "quantized_cdf", shape=(channels, 1), dtype=dtypes.int32, - getter=cdf_getter, trainable=False) - - update_op = state_ops.assign( - self._quantized_cdf, cdf, validate_shape=False) - self.add_update(update_op, inputs=None) - - super(EntropyBottleneck, self).build(input_shape) - - def call(self, inputs, training): - """Pass a tensor through the bottleneck. - - Args: - inputs: The tensor to be passed through the bottleneck. - training: Boolean. If `True`, returns a differentiable approximation of - the inputs, and their likelihoods under the modeled probability - densities. If `False`, returns the quantized inputs and their - likelihoods under the corresponding probability mass function. These - quantities can't be used for training, as they are not differentiable, - but represent actual compression more closely. - - Returns: - values: `Tensor` with the same shape as `inputs` containing the perturbed - or quantized input values. - likelihood: `Tensor` with the same shape as `inputs` containing the - likelihood of `values` under the modeled probability distributions. - - Raises: - ValueError: if `inputs` has different `dtype` or number of channels than - a previous set of inputs the model was invoked with earlier. - """ - inputs = ops.convert_to_tensor(inputs) - ndim = self.input_spec.ndim - channel_axis = self._channel_axis(ndim) - half = constant_op.constant(.5, dtype=self.dtype) - - # Convert to (channels, 1, batch) format by commuting channels to front - # and then collapsing. - order = list(range(ndim)) - order.pop(channel_axis) - order.insert(0, channel_axis) - values = array_ops.transpose(inputs, order) - shape = array_ops.shape(values) - values = array_ops.reshape(values, (shape[0], 1, -1)) - - # Add noise or quantize. - if training: - noise = random_ops.random_uniform(array_ops.shape(values), -half, half) - values = math_ops.add_n([values, noise]) - elif self.optimize_integer_offset: - values = math_ops.round(values - self._medians) + self._medians - else: - values = math_ops.round(values) - - # Evaluate densities. - # We can use the special rule below to only compute differences in the left - # tail of the sigmoid. This increases numerical stability: sigmoid(x) is 1 - # for large x, 0 for small x. Subtracting two numbers close to 0 can be done - # with much higher precision than subtracting two numbers close to 1. - lower = self._logits_cumulative(values - half, stop_gradient=False) - upper = self._logits_cumulative(values + half, stop_gradient=False) - # Flip signs if we can move more towards the left tail of the sigmoid. - sign = -math_ops.sign(math_ops.add_n([lower, upper])) - sign = array_ops.stop_gradient(sign) - likelihood = abs( - math_ops.sigmoid(sign * upper) - math_ops.sigmoid(sign * lower)) - if self.likelihood_bound > 0: - likelihood_bound = constant_op.constant( - self.likelihood_bound, dtype=self.dtype) - # TODO(jballe): Override gradients. - likelihood = math_ops.maximum(likelihood, likelihood_bound) - - # Convert back to input tensor shape. - order = list(range(1, ndim)) - order.insert(channel_axis, 0) - values = array_ops.reshape(values, shape) - values = array_ops.transpose(values, order) - likelihood = array_ops.reshape(likelihood, shape) - likelihood = array_ops.transpose(likelihood, order) - - if not context.executing_eagerly(): - values_shape, likelihood_shape = self.compute_output_shape(inputs.shape) - values.set_shape(values_shape) - likelihood.set_shape(likelihood_shape) - - return values, likelihood - - def compress(self, inputs): - """Compress inputs and store their binary representations into strings. - - Args: - inputs: `Tensor` with values to be compressed. - - Returns: - String `Tensor` vector containing the compressed representation of each - batch element of `inputs`. - """ - with ops.name_scope(self._name_scope()): - inputs = ops.convert_to_tensor(inputs) - if not self.built: - # Check input assumptions set before layer building, e.g. input rank. - self._assert_input_compatibility(inputs) - if self.dtype is None: - self._dtype = inputs.dtype.base_dtype.name - self.build(inputs.shape) - - # Check input assumptions set after layer building, e.g. input shape. - if not context.executing_eagerly(): - self._assert_input_compatibility(inputs) - - ndim = self.input_spec.ndim - channel_axis = self._channel_axis(ndim) - # Tuple of slices for expanding dimensions of tensors below. - slices = ndim * [None] + [slice(None)] - slices[channel_axis] = slice(None) - slices = tuple(slices) - - # Expand dimensions of CDF to input dimensions, keeping the channels along - # the right dimension. - cdf = self._quantized_cdf[slices[1:]] - num_levels = array_ops.shape(cdf)[-1] - 1 - - # Bring inputs to the right range by centering the range on the medians. - half = constant_op.constant(.5, dtype=self.dtype) - medians = array_ops.squeeze(self._medians, [1, 2]) - offsets = (math_ops.cast(num_levels // 2, self.dtype) + half) - medians - # Expand offsets to input dimensions and add to inputs. - values = inputs + offsets[slices[:-1]] - - # Clip to range and cast to integers. Because we have added .5 above, and - # all values are positive, the cast effectively implements rounding. - values = math_ops.maximum(values, half) - values = math_ops.minimum( - values, math_ops.cast(num_levels, self.dtype) - half) - values = math_ops.cast(values, dtypes.int16) - - def loop_body(tensor): - return coder_ops.range_encode( - tensor, cdf, precision=self.range_coder_precision) - strings = functional_ops.map_fn( - loop_body, values, dtype=dtypes.string, back_prop=False) - - if not context.executing_eagerly(): - strings.set_shape(inputs.shape[:1]) - - return strings - - def decompress(self, strings, shape, channels=None): - """Decompress values from their compressed string representations. - - Args: - strings: A string `Tensor` vector containing the compressed data. - shape: A `Tensor` vector of int32 type. Contains the shape of the tensor - to be decompressed, excluding the batch dimension. - channels: Integer. Specifies the number of channels statically. Needs only - be set if the layer hasn't been built yet (i.e., this is the first input - it receives). - - Returns: - The decompressed `Tensor`. Its shape will be equal to `shape` prepended - with the batch dimension from `strings`. - - Raises: - ValueError: If the length of `shape` isn't available at graph construction - time. - """ - with ops.name_scope(self._name_scope()): - strings = ops.convert_to_tensor(strings) - shape = ops.convert_to_tensor(shape) - if self.built: - ndim = self.input_spec.ndim - channel_axis = self._channel_axis(ndim) - if channels is None: - channels = self.input_spec.axes[channel_axis] - else: - if not (shape.shape.is_fully_defined() and shape.shape.ndims == 1): - raise ValueError("`shape` must be a vector with known length.") - ndim = shape.shape[0].value + 1 - channel_axis = self._channel_axis(ndim) - input_shape = ndim * [None] - input_shape[channel_axis] = channels - self.build(input_shape) - - # Tuple of slices for expanding dimensions of tensors below. - slices = ndim * [None] + [slice(None)] - slices[channel_axis] = slice(None) - slices = tuple(slices) - - # Expand dimensions of CDF to input dimensions, keeping the channels along - # the right dimension. - cdf = self._quantized_cdf[slices[1:]] - num_levels = array_ops.shape(cdf)[-1] - 1 - - def loop_body(string): - return coder_ops.range_decode( - string, shape, cdf, precision=self.range_coder_precision) - outputs = functional_ops.map_fn( - loop_body, strings, dtype=dtypes.int16, back_prop=False) - outputs = math_ops.cast(outputs, self.dtype) - - medians = array_ops.squeeze(self._medians, [1, 2]) - offsets = math_ops.cast(num_levels // 2, self.dtype) - medians - outputs -= offsets[slices[:-1]] - - if not context.executing_eagerly(): - outputs_shape = ndim * [None] - outputs_shape[0] = strings.shape[0] - outputs_shape[channel_axis] = channels - outputs.set_shape(outputs_shape) - - return outputs - - def visualize(self): - """Multi-channel visualization of densities as images. - - Creates and returns an image summary visualizing the current probabilty - density estimates. The image contains one row for each channel. Within each - row, the pixel intensities are proportional to probability values, and each - row is centered on the median of the corresponding distribution. - - Returns: - The created image summary. - """ - with ops.name_scope(self._name_scope()): - image = self._pmf - image *= 255 / math_ops.reduce_max(image, axis=1, keepdims=True) - image = math_ops.cast(image + .5, dtypes.uint8) - image = image[None, :, :, None] - return summary.image("pmf", image, max_outputs=1) - - def compute_output_shape(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape) - return input_shape, input_shape diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py b/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py deleted file mode 100644 index 798b0234ebcce7df108a0da65d1305502ce0253a..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py +++ /dev/null @@ -1,315 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests of EntropyBottleneck class.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.coder.python.layers import entropybottleneck - -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import test -from tensorflow.python.training import gradient_descent - - -class EntropyBottleneckTest(test.TestCase): - - def test_noise(self): - # Tests that the noise added is uniform noise between -0.5 and 0.5. - inputs = array_ops.placeholder(dtypes.float32, (None, 1)) - layer = entropybottleneck.EntropyBottleneck() - noisy, _ = layer(inputs, training=True) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - values = np.linspace(-50, 50, 100)[:, None] - noisy, = sess.run([noisy], {inputs: values}) - self.assertFalse(np.allclose(values, noisy, rtol=0, atol=.49)) - self.assertAllClose(values, noisy, rtol=0, atol=.5) - - def test_quantization(self): - # Tests that inputs are quantized to full integer values, even after - # quantiles have been updated. - inputs = array_ops.placeholder(dtypes.float32, (None, 1)) - layer = entropybottleneck.EntropyBottleneck(optimize_integer_offset=False) - quantized, _ = layer(inputs, training=False) - opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) - self.assertTrue(len(layer.losses) == 1) - step = opt.minimize(layer.losses[0]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run(step) - values = np.linspace(-50, 50, 100)[:, None] - quantized, = sess.run([quantized], {inputs: values}) - self.assertAllClose(np.around(values), quantized, rtol=0, atol=1e-6) - - def test_quantization_optimized_offset(self): - # Tests that inputs are not quantized to full integer values after quantiles - # have been updated. However, the difference between input and output should - # be between -0.5 and 0.5, and the offset must be consistent. - inputs = array_ops.placeholder(dtypes.float32, (None, 1)) - layer = entropybottleneck.EntropyBottleneck(optimize_integer_offset=True) - quantized, _ = layer(inputs, training=False) - opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) - self.assertTrue(len(layer.losses) == 1) - step = opt.minimize(layer.losses[0]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run(step) - values = np.linspace(-50, 50, 100)[:, None] - quantized, = sess.run([quantized], {inputs: values}) - self.assertAllClose(values, quantized, rtol=0, atol=.5) - diff = np.ravel(np.around(values) - quantized) % 1 - self.assertAllClose(diff, np.full_like(diff, diff[0]), rtol=0, atol=5e-6) - self.assertNotEqual(diff[0], 0) - - def test_codec(self): - # Tests that inputs are compressed and decompressed correctly, and quantized - # to full integer values, even after quantiles have been updated. - inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_last", init_scale=60, - optimize_integer_offset=False) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) - self.assertTrue(len(layer.losses) == 1) - step = opt.minimize(layer.losses[0]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run(step) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = np.linspace(-50, 50, 100)[None, :, None] - decoded, = sess.run([decoded], {inputs: values}) - self.assertAllClose(np.around(values), decoded, rtol=0, atol=1e-6) - - def test_codec_optimized_offset(self): - # Tests that inputs are compressed and decompressed correctly, and not - # quantized to full integer values after quantiles have been updated. - # However, the difference between input and output should be between -0.5 - # and 0.5, and the offset must be consistent. - inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_last", init_scale=60, - optimize_integer_offset=True) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) - self.assertTrue(len(layer.losses) == 1) - step = opt.minimize(layer.losses[0]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run(step) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = np.linspace(-50, 50, 100)[None, :, None] - decoded, = sess.run([decoded], {inputs: values}) - self.assertAllClose(values, decoded, rtol=0, atol=.5) - diff = np.ravel(np.around(values) - decoded) % 1 - self.assertAllClose(diff, np.full_like(diff, diff[0]), rtol=0, atol=5e-6) - self.assertNotEqual(diff[0], 0) - - def test_codec_clipping(self): - # Tests that inputs are compressed and decompressed correctly, and clipped - # to the expected range. - inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_last", init_scale=40) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = np.linspace(-50, 50, 100)[None, :, None] - decoded, = sess.run([decoded], {inputs: values}) - expected = np.clip(np.around(values), -40, 40) - self.assertAllClose(expected, decoded, rtol=0, atol=1e-6) - - def test_channels_last(self): - # Test the layer with more than one channel and multiple input dimensions, - # with the channels in the last dimension. - inputs = array_ops.placeholder(dtypes.float32, (None, None, None, 2)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_last", init_scale=50) - noisy, _ = layer(inputs, training=True) - quantized, _ = layer(inputs, training=False) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = 5 * np.random.normal(size=(7, 5, 3, 2)) - noisy, quantized, decoded = sess.run( - [noisy, quantized, decoded], {inputs: values}) - self.assertAllClose(values, noisy, rtol=0, atol=.5) - self.assertAllClose(values, quantized, rtol=0, atol=.5) - self.assertAllClose(values, decoded, rtol=0, atol=.5) - - def test_channels_first(self): - # Test the layer with more than one channel and multiple input dimensions, - # with the channel dimension right after the batch dimension. - inputs = array_ops.placeholder(dtypes.float32, (None, 3, None, None)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_first", init_scale=50) - noisy, _ = layer(inputs, training=True) - quantized, _ = layer(inputs, training=False) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = 5 * np.random.normal(size=(2, 3, 5, 7)) - noisy, quantized, decoded = sess.run( - [noisy, quantized, decoded], {inputs: values}) - self.assertAllClose(values, noisy, rtol=0, atol=.5) - self.assertAllClose(values, quantized, rtol=0, atol=.5) - self.assertAllClose(values, decoded, rtol=0, atol=.5) - - def test_compress(self): - # Test compression and decompression, and produce test data for - # `test_decompress`. If you set the constant at the end to `True`, this test - # will fail and the log will contain the new test data. - inputs = array_ops.placeholder(dtypes.float32, (2, 3, 10)) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_first", filters=(), init_scale=2) - bitstrings = layer.compress(inputs) - decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - values = 5 * np.random.uniform(size=(2, 3, 10)) - 2.5 - bitstrings, quantized_cdf, decoded = sess.run( - [bitstrings, layer._quantized_cdf, decoded], {inputs: values}) - self.assertAllClose(values, decoded, rtol=0, atol=.5) - # Set this constant to `True` to log new test data for `test_decompress`. - if False: # pylint:disable=using-constant-test - assert False, (bitstrings, quantized_cdf, decoded) - - # Data generated by `test_compress`. - # pylint:disable=g-inconsistent-quotes,bad-whitespace - bitstrings = np.array([ - b'\x1e\xbag}\xc2\xdaN\x8b\xbd.', - b'\x8dF\xf0%\x1cv\xccllW' - ], dtype=object) - - quantized_cdf = np.array([ - [ 0, 15636, 22324, 30145, 38278, 65536], - [ 0, 19482, 26927, 35052, 42904, 65535], - [ 0, 21093, 28769, 36919, 44578, 65536] - ], dtype=np.int32) - - expected = np.array([ - [[-2., 1., 0., -2., -1., -2., -2., -2., 2., -1.], - [ 1., 2., 1., 0., -2., -2., 1., 2., 0., 1.], - [ 2., 0., -2., 2., 0., -1., -2., 0., 2., 0.]], - [[ 1., 2., 0., -1., 1., 2., 1., 1., 2., -2.], - [ 2., -1., -1., 0., -1., 2., 0., 2., -2., 2.], - [ 2., -2., -2., -1., -2., 1., -2., 0., 0., 0.]] - ], dtype=np.float32) - # pylint:enable=g-inconsistent-quotes,bad-whitespace - - def test_decompress(self): - # Test that decompression of values compressed with a previous version - # works, i.e. that the file format doesn't change across revisions. - bitstrings = array_ops.placeholder(dtypes.string) - input_shape = array_ops.placeholder(dtypes.int32) - quantized_cdf = array_ops.placeholder(dtypes.int32) - layer = entropybottleneck.EntropyBottleneck( - data_format="channels_first", filters=(), dtype=dtypes.float32) - layer.build(self.expected.shape) - layer._quantized_cdf = quantized_cdf - decoded = layer.decompress(bitstrings, input_shape[1:]) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - decoded, = sess.run([decoded], { - bitstrings: self.bitstrings, input_shape: self.expected.shape, - quantized_cdf: self.quantized_cdf}) - self.assertAllClose(self.expected, decoded, rtol=0, atol=1e-6) - - def test_build_decompress(self): - # Test that layer can be built when `decompress` is the first call to it. - bitstrings = array_ops.placeholder(dtypes.string) - input_shape = array_ops.placeholder(dtypes.int32, shape=[3]) - layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32) - layer.decompress(bitstrings, input_shape[1:], channels=5) - self.assertTrue(layer.built) - - def test_pmf_normalization(self): - # Test that probability mass functions are normalized correctly. - layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32) - layer.build((None, 10)) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - pmf, = sess.run([layer._pmf]) - self.assertAllClose(np.ones(10), np.sum(pmf, axis=-1), rtol=0, atol=1e-6) - - def test_visualize(self): - # Test that summary op can be constructed. - layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32) - layer.build((None, 10)) - summary = layer.visualize() - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run([summary]) - - def test_normalization(self): - # Test that densities are normalized correctly. - inputs = array_ops.placeholder(dtypes.float32, (None, 1)) - layer = entropybottleneck.EntropyBottleneck(filters=(2,)) - _, likelihood = layer(inputs, training=True) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - x = np.repeat(np.arange(-200, 201), 1000)[:, None] - likelihood, = sess.run([likelihood], {inputs: x}) - self.assertEqual(x.shape, likelihood.shape) - integral = np.sum(likelihood) * .001 - self.assertAllClose(1, integral, rtol=0, atol=1e-4) - - def test_entropy_estimates(self): - # Test that entropy estimates match actual range coding. - inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) - layer = entropybottleneck.EntropyBottleneck( - filters=(2, 3), data_format="channels_last") - _, likelihood = layer(inputs, training=True) - diff_entropy = math_ops.reduce_sum(math_ops.log(likelihood)) / -np.log(2) - _, likelihood = layer(inputs, training=False) - disc_entropy = math_ops.reduce_sum(math_ops.log(likelihood)) / -np.log(2) - bitstrings = layer.compress(inputs) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertTrue(len(layer.updates) == 1) - sess.run(layer.updates[0]) - diff_entropy, disc_entropy, bitstrings = sess.run( - [diff_entropy, disc_entropy, bitstrings], - {inputs: np.random.normal(size=(1, 10000, 1))}) - codelength = 8 * sum(len(bitstring) for bitstring in bitstrings) - self.assertAllClose(diff_entropy, disc_entropy, rtol=5e-3, atol=0) - self.assertAllClose(disc_entropy, codelength, rtol=5e-3, atol=0) - self.assertGreater(codelength, disc_entropy) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py index 3791dae8d7f6b03bc1115bca97811dfc4775c45b..ff846b191a34e3f3b4aa35671ca22b96b963db80 100644 --- a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py +++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py @@ -150,7 +150,7 @@ def _project_stochastic_matrix_wrt_euclidean_norm(matrix): "matrix must be two dimensional (instead is %d-dimensional)" % matrix_shape.ndims) if matrix_shape[0] != matrix_shape[1]: - raise ValueError("matrix must be be square (instead has shape (%d,%d))" % + raise ValueError("matrix must be square (instead has shape (%d,%d))" % (matrix_shape[0], matrix_shape[1])) dimension = matrix_shape[0].value if dimension is None: diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py index a0dd3881a86c19e47ccb65f84a2477a55626b81c..6c9ab6aeb87fd39b22ab4f28d69b432b15899a13 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py @@ -18,7 +18,7 @@ These functions allow for recursive copying of elements (ops and variables) from one graph to another. The copied elements are initialized inside a user-specified scope in the other graph. There are separate functions to copy ops and variables. -There is also a function to retrive the copied version of an op from the +There is also a function to retrieve the copied version of an op from the first graph inside a scope in the second graph. @@copy_op_to_graph @@ -77,7 +77,7 @@ def copy_variable_to_graph(org_instance, to_graph, scope=''): else: collections.append(scope + '/' + name) - #See if its trainable. + #See if it's trainable. trainable = ( org_instance in org_instance.graph.get_collection( ops.GraphKeys.TRAINABLE_VARIABLES)) @@ -162,7 +162,7 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''): if isinstance(org_instance, ops.Tensor): - #If its a Tensor, it is one of the outputs of the underlying + #If it's a Tensor, it is one of the outputs of the underlying #op. Therefore, copy the op itself and return the appropriate #output. op = org_instance.op @@ -219,8 +219,10 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''): op_def) #Use Graph's hidden methods to add the op to_graph._record_op_seen_by_control_dependencies(new_op) - for device_function in reversed(to_graph._device_function_stack): + # pylint: disable=protected-access + for device_function in to_graph._device_functions_outer_to_inner: new_op._set_device(device_function(new_op)) + # pylint: enable=protected-access return new_op diff --git a/tensorflow/contrib/crf/__init__.py b/tensorflow/contrib/crf/__init__.py index 046c509626bc2eb20a65c0b38495ff37c294e0e1..615e62b16f1906dafa22a12cc7275a2335e8df88 100644 --- a/tensorflow/contrib/crf/__init__.py +++ b/tensorflow/contrib/crf/__init__.py @@ -20,6 +20,7 @@ See the @{$python/contrib.crf} guide. @@crf_decode @@crf_log_likelihood @@crf_log_norm +@@crf_multitag_sequence_score @@crf_sequence_score @@crf_unary_score @@CrfDecodeBackwardRnnCell @@ -36,6 +37,7 @@ from tensorflow.contrib.crf.python.ops.crf import crf_binary_score from tensorflow.contrib.crf.python.ops.crf import crf_decode from tensorflow.contrib.crf.python.ops.crf import crf_log_likelihood from tensorflow.contrib.crf.python.ops.crf import crf_log_norm +from tensorflow.contrib.crf.python.ops.crf import crf_multitag_sequence_score from tensorflow.contrib.crf.python.ops.crf import crf_sequence_score from tensorflow.contrib.crf.python.ops.crf import crf_unary_score from tensorflow.contrib.crf.python.ops.crf import CrfDecodeBackwardRnnCell diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index 74f2ec22ffaab1654e5cd38169258fb87d307ad4..8cfe14205927bf7763cf36fa31012ab10fce995c 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -31,6 +31,15 @@ from tensorflow.python.platform import test class CrfTest(test.TestCase): + def calculateSequenceScore(self, inputs, transition_params, tag_indices, + sequence_lengths): + expected_unary_score = sum( + inputs[i][tag_indices[i]] for i in range(sequence_lengths)) + expected_binary_score = sum( + transition_params[tag_indices[i], tag_indices[i + 1]] + for i in range(sequence_lengths - 1)) + return expected_unary_score + expected_binary_score + def testCrfSequenceScore(self): transition_params = np.array( [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) @@ -60,14 +69,55 @@ class CrfTest(test.TestCase): transition_params=constant_op.constant(transition_params)) sequence_score = array_ops.squeeze(sequence_score, [0]) tf_sequence_score = sess.run(sequence_score) - expected_unary_score = sum(inputs[i][tag_indices[i]] - for i in range(sequence_lengths)) - expected_binary_score = sum( - transition_params[tag_indices[i], tag_indices[i + 1]] - for i in range(sequence_lengths - 1)) - expected_sequence_score = expected_unary_score + expected_binary_score + expected_sequence_score = self.calculateSequenceScore( + inputs, transition_params, tag_indices, sequence_lengths) self.assertAllClose(tf_sequence_score, expected_sequence_score) + def testCrfMultiTagSequenceScore(self): + transition_params = np.array( + [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[4, 5, -3]], + dtype=np.float32), + ] + tag_bitmap_list = [ + np.array( + [[True, True, False], [True, False, True], [False, True, True], + [True, False, True]], + dtype=np.bool), + np.array([[True, True, False]], dtype=np.bool) + ] + for sequence_lengths, inputs, tag_bitmap in zip( + sequence_lengths_list, inputs_list, tag_bitmap_list): + with self.test_session() as sess: + sequence_score = crf.crf_multitag_sequence_score( + inputs=array_ops.expand_dims(inputs, 0), + tag_bitmap=array_ops.expand_dims(tag_bitmap, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + transition_params=constant_op.constant(transition_params)) + sequence_score = array_ops.squeeze(sequence_score, [0]) + tf_sum_sequence_score = sess.run(sequence_score) + all_indices_list = [ + single_index_bitmap.nonzero()[0] + for single_index_bitmap in tag_bitmap[:sequence_lengths] + ] + expected_sequence_scores = [ + self.calculateSequenceScore(inputs, transition_params, indices, + sequence_lengths) + for indices in itertools.product(*all_indices_list) + ] + expected_log_sum_exp_sequence_scores = np.logaddexp.reduce( + expected_sequence_scores) + self.assertAllClose(tf_sum_sequence_score, + expected_log_sum_exp_sequence_scores) + def testCrfUnaryScore(self): inputs = np.array( [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) @@ -108,7 +158,7 @@ class CrfTest(test.TestCase): # Test both the length-1 and regular cases. sequence_lengths_list = [ np.array(3, dtype=np.int32), - np.array(1, dtype=np.int32) + np.array(1, dtype=np.int64) ] inputs_list = [ np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], @@ -241,7 +291,7 @@ class CrfTest(test.TestCase): # Test both the length-1 and regular cases. sequence_lengths_list = [ np.array(3, dtype=np.int32), - np.array(1, dtype=np.int32) + np.array(1, dtype=np.int64) ] inputs_list = [ np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 2d2cbdc1990ed9d8e58c0032cbc141a52271838f..2a91dcb63a80016e62d10d1310ca57e3e54434c5 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -67,7 +67,7 @@ __all__ = [ "crf_sequence_score", "crf_log_norm", "crf_log_likelihood", "crf_unary_score", "crf_binary_score", "CrfForwardRnnCell", "viterbi_decode", "crf_decode", "CrfDecodeForwardRnnCell", - "CrfDecodeBackwardRnnCell" + "CrfDecodeBackwardRnnCell", "crf_multitag_sequence_score" ] @@ -114,6 +114,56 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths, false_fn=_multi_seq_fn) +def crf_multitag_sequence_score(inputs, tag_bitmap, sequence_lengths, + transition_params): + """Computes the unnormalized score of all tag sequences matching tag_bitmap. + + tag_bitmap enables more than one tag to be considered correct at each time + step. This is useful when an observed output at a given time step is + consistent with more than one tag, and thus the log likelihood of that + observation must take into account all possible consistent tags. + + Using one-hot vectors in tag_bitmap gives results identical to + crf_sequence_score. + + Args: + inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials + to use as input to the CRF layer. + tag_bitmap: A [batch_size, max_seq_len, num_tags] boolean tensor + representing all active tags at each index for which to calculate the + unnormalized score. + sequence_lengths: A [batch_size] vector of true sequence lengths. + transition_params: A [num_tags, num_tags] transition matrix. + Returns: + sequence_scores: A [batch_size] vector of unnormalized sequence scores. + """ + + # If max_seq_len is 1, we skip the score calculation and simply gather the + # unary potentials of all active tags. + def _single_seq_fn(): + filtered_inputs = array_ops.where( + tag_bitmap, inputs, + array_ops.fill(array_ops.shape(inputs), float("-inf"))) + return math_ops.reduce_logsumexp( + filtered_inputs, axis=[1, 2], keepdims=False) + + def _multi_seq_fn(): + # Compute the logsumexp of all scores of sequences matching the given tags. + filtered_inputs = array_ops.where( + tag_bitmap, inputs, + array_ops.fill(array_ops.shape(inputs), float("-inf"))) + return crf_log_norm( + inputs=filtered_inputs, + sequence_lengths=sequence_lengths, + transition_params=transition_params) + + return utils.smart_cond( + pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1], + 1), + true_fn=_single_seq_fn, + false_fn=_multi_seq_fn) + + def crf_log_norm(inputs, sequence_lengths, transition_params): """Computes the normalization for a CRF. @@ -498,7 +548,9 @@ def crf_decode(potentials, transition_params, sequence_length): initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O] inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] # Sequence length is not allowed to be less than zero. - sequence_length_less_one = math_ops.maximum(0, sequence_length - 1) + sequence_length_less_one = math_ops.maximum( + constant_op.constant(0, dtype=sequence_length.dtype), + sequence_length - 1) backpointers, last_score = rnn.dynamic_rnn( # [B, T - 1, O], [B, O] crf_fwd_cell, inputs=inputs, diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 156538b4e01bf1a1ccca0fca1e309b1d37b6dbc0..7878e46e88b2ea8b0012768342c218baeda80eaa 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -34,6 +34,7 @@ See @{$guide/datasets$Importing Data} for an overview. @@batch_and_drop_remainder @@bucket_by_sequence_length @@choose_from_datasets +@@copy_to_device @@dense_to_sparse_batch @@enumerate_dataset @@ -51,6 +52,7 @@ See @{$guide/datasets$Importing Data} for an overview. @@prefetch_to_device @@read_batch_features @@rejection_resample +@@reduce_dataset @@sample_from_datasets @@scan @@shuffle_and_repeat @@ -76,6 +78,7 @@ from tensorflow.contrib.data.python.ops.counter import Counter from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset from tensorflow.contrib.data.python.ops.error_ops import ignore_errors from tensorflow.contrib.data.python.ops.get_single_element import get_single_element +from tensorflow.contrib.data.python.ops.get_single_element import reduce_dataset from tensorflow.contrib.data.python.ops.grouping import bucket_by_sequence_length from tensorflow.contrib.data.python.ops.grouping import group_by_reducer from tensorflow.contrib.data.python.ops.grouping import group_by_window @@ -86,6 +89,7 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datase from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator +from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device from tensorflow.contrib.data.python.ops.random_ops import RandomDataset from tensorflow.contrib.data.python.ops.readers import CsvDataset diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD index 7b69e10441eba3e38c979d5715c16699ac2710ed..2e249f5c14ab111ae412ff3288acc25de8d7aa11 100644 --- a/tensorflow/contrib/data/kernels/BUILD +++ b/tensorflow/contrib/data/kernels/BUILD @@ -37,6 +37,7 @@ cc_library( "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", ], + alwayslink = 1, ) cc_library( @@ -58,6 +59,7 @@ cc_library( "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", ], + alwayslink = 1, ) cc_library( @@ -68,11 +70,24 @@ cc_library( "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", ], + alwayslink = 1, +) + +cc_library( + name = "assert_next_dataset_op", + srcs = ["assert_next_dataset_op.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], + alwayslink = 1, ) cc_library( name = "dataset_kernels", deps = [ + ":assert_next_dataset_op", ":csv_dataset_op", ":directed_interleave_dataset_op", ":ignore_errors_dataset_op", diff --git a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..95b8e1f7fd487119d77a5f708de42b014c55f79d --- /dev/null +++ b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc @@ -0,0 +1,152 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace { + +// See documentation in ../ops/dataset_ops.cc for a high-level +// description of the following op. +class AssertNextDatasetOp : public UnaryDatasetOpKernel { + public: + explicit AssertNextDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + std::vector transformations; + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "transformations", + &transformations)); + *output = + new Dataset(ctx, input, transformations, output_types_, output_shapes_); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const std::vector& transformations, + const DataTypeVector& output_types, + const std::vector& output_shapes) + : GraphDatasetBase(ctx), + input_(input), + transformations_(transformations), + output_types_(output_types), + output_shapes_(output_shapes) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::Assert")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "AssertNextDatasetOp::Dataset"; + } + + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); + Node* transformations_node = nullptr; + TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node)); + TF_RETURN_IF_ERROR(b->AddDataset( + this, {input_graph_node, transformations_node}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + std::vector tokens = + str_util::Split(prefix(), ':', str_util::SkipEmpty()); + if (dataset()->transformations_.size() > tokens.size() - 2) { + return errors::InvalidArgument( + "Asserted next ", dataset()->transformations_.size(), + " transformations but encountered only ", tokens.size() - 2, "."); + } + int n = tokens.size(); + for (size_t i = 0; i < dataset()->transformations_.size(); ++i) { + if (dataset()->transformations_[i] != tokens[n - 2 - i]) { + return errors::InvalidArgument( + "Asserted ", dataset()->transformations_[i], + " transformation at offset ", i, " but encountered ", + tokens[n - 2 - i], " transformation instead."); + } + } + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + return input_impl_->GetNext(ctx, out_tensors, end_of_sequence); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + return Status::OK(); + } + + private: + std::unique_ptr input_impl_; + }; + + const DatasetBase* input_; + const std::vector transformations_; + const DataTypeVector output_types_; + const std::vector output_shapes_; + }; + + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU), + AssertNextDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc index 4657807785d58727d34f37172bd30c56a5b7cde6..f7e3ed886c6655cdc07e08bbe2fbe82e671a6802 100644 --- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -18,7 +18,10 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/io/inputstream_interface.h" #include "tensorflow/core/lib/io/random_inputstream.h" +#include "tensorflow/core/lib/io/zlib_compression_options.h" +#include "tensorflow/core/lib/io/zlib_inputstream.h" namespace tensorflow { namespace { @@ -37,6 +40,10 @@ class CSVDatasetOp : public DatasetOpKernel { ctx, filenames_tensor->dims() <= 1, errors::InvalidArgument("`filenames` must be a scalar or a vector.")); + string compression_type; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "compression_type", + &compression_type)); + OpInputList record_defaults_list; OP_REQUIRES_OK(ctx, ctx->input_list("record_defaults", &record_defaults_list)); @@ -86,6 +93,19 @@ class CSVDatasetOp : public DatasetOpKernel { filenames.push_back(filenames_tensor->flat()(i)); } + io::ZlibCompressionOptions zlib_compression_options = + io::ZlibCompressionOptions::DEFAULT(); + if (compression_type == "ZLIB") { + zlib_compression_options = io::ZlibCompressionOptions::DEFAULT(); + } else if (compression_type == "GZIP") { + zlib_compression_options = io::ZlibCompressionOptions::GZIP(); + } else { + OP_REQUIRES(ctx, compression_type.empty(), + errors::InvalidArgument( + "Unsupported compression_type: ", compression_type, ".")); + } + zlib_compression_options.input_buffer_size = buffer_size; + std::vector select_cols; select_cols.reserve(select_cols_tensor->NumElements()); for (int i = 0; i < select_cols_tensor->NumElements(); ++i) { @@ -103,7 +123,8 @@ class CSVDatasetOp : public DatasetOpKernel { ctx, select_cols.empty() || select_cols.front() >= 0, errors::InvalidArgument("select_cols should be non-negative indices")); - *output = new Dataset(ctx, std::move(filenames), header, buffer_size, + *output = new Dataset(ctx, std::move(filenames), header, + std::move(compression_type), zlib_compression_options, output_types_, output_shapes_, std::move(record_defaults), std::move(select_cols), use_quote_delim, delim[0], std::move(na_value)); @@ -113,21 +134,24 @@ class CSVDatasetOp : public DatasetOpKernel { class Dataset : public GraphDatasetBase { public: Dataset(OpKernelContext* ctx, std::vector filenames, bool header, - int64 buffer_size, const DataTypeVector& output_types, + string compression_type, io::ZlibCompressionOptions options, + const DataTypeVector& output_types, const std::vector& output_shapes, std::vector record_defaults, std::vector select_cols, bool use_quote_delim, char delim, string na_value) : GraphDatasetBase(ctx), filenames_(std::move(filenames)), header_(header), - buffer_size_(buffer_size), out_type_(output_types), output_shapes_(output_shapes), record_defaults_(std::move(record_defaults)), select_cols_(std::move(select_cols)), use_quote_delim_(use_quote_delim), delim_(delim), - na_value_(std::move(na_value)) {} + na_value_(std::move(na_value)), + use_compression_(!compression_type.empty()), + compression_type_(std::move(compression_type)), + options_(options) {} std::unique_ptr MakeIteratorInternal( const string& prefix) const override { @@ -146,10 +170,45 @@ class CSVDatasetOp : public DatasetOpKernel { protected: Status AsGraphDefInternal(DatasetGraphDefBuilder* b, Node** output) const override { - // TODO(rachelim): Implement this - std::vector input_tensors; - TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output)); - return errors::Unimplemented("CSVDataset: AsGraphDefInternal"); + Node* filenames = nullptr; + Node* compression_type = nullptr; + Node* buffer_size = nullptr; + Node* header = nullptr; + Node* delim = nullptr; + Node* use_quote_delim = nullptr; + Node* na_value = nullptr; + Node* select_cols = nullptr; + + std::vector record_defaults; + record_defaults.reserve(record_defaults_.size()); + for (const Tensor& t : record_defaults_) { + Node* node; + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + record_defaults.emplace_back(node); + } + + TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); + TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type)); + TF_RETURN_IF_ERROR( + b->AddScalar(options_.input_buffer_size, &buffer_size)); + TF_RETURN_IF_ERROR(b->AddScalar(header_, &header)); + + string delim_string(1, delim_); + TF_RETURN_IF_ERROR(b->AddScalar(delim_string, &delim)); + TF_RETURN_IF_ERROR(b->AddScalar(use_quote_delim_, &use_quote_delim)); + TF_RETURN_IF_ERROR(b->AddScalar(na_value_, &na_value)); + TF_RETURN_IF_ERROR(b->AddVector(select_cols_, &select_cols)); + + TF_RETURN_IF_ERROR(b->AddDataset( + this, + {std::make_pair(0, filenames), std::make_pair(1, compression_type), + std::make_pair(2, buffer_size), std::make_pair(3, header), + std::make_pair(4, delim), std::make_pair(5, use_quote_delim), + std::make_pair(6, na_value), + std::make_pair(7, select_cols)}, // Single tensor inputs + {std::make_pair(8, record_defaults)}, // Tensor list inputs + {}, output)); + return Status::OK(); } private: @@ -201,14 +260,58 @@ class CSVDatasetOp : public DatasetOpKernel { protected: Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); - // TODO(rachelim): Implement save - return errors::Unimplemented("CSVDataset: SaveInternal"); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"), + current_file_index_)); + // `input_stream_` is empty if + // 1. GetNext has not been called even once. + // 2. All files have been read and the iterator has been exhausted. + if (input_stream_ && num_buffer_reads_ > 0) { + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("pos"), pos_)); + // If num_buffer_reads_ == 0, the buffer hasn't been filled even once. + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_buffer_reads"), + num_buffer_reads_)); + } + return Status::OK(); } + Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { mutex_lock l(mu_); - // TODO(rachelim): Implement restore - return errors::Unimplemented("CSVDataset: RestoreInternal"); + ResetStreamsLocked(); + int64 current_file_index; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"), + ¤t_file_index)); + current_file_index_ = size_t(current_file_index); + // The keys "pos" and "num_buffer_reads" are written only if + // the iterator was saved with an open, partially read file. + if (reader->Contains(full_name("pos"))) { + int64 pos, num_buffer_reads; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("pos"), &pos)); + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_buffer_reads"), + &num_buffer_reads)); + + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + + num_buffer_reads_ = size_t(num_buffer_reads - 1); + + // Restores the most recently held buffer + Status s = input_stream_->SkipNBytes( + num_buffer_reads_ * dataset()->options_.input_buffer_size); + if (!s.ok() && !errors::IsOutOfRange(s)) { + // We might get out of range error here if the size of the file + // is not an exact multiple of the buffer size, and the last buffer + // read is < buffer_size. This is valid and we do not surface the + // error. + return s; + } + + Status s2 = FillBuffer(&buffer_); + if (!s2.ok() && !errors::IsOutOfRange(s2)) { + return s2; + } + pos_ = size_t(pos); + } + return Status::OK(); } private: @@ -510,7 +613,9 @@ class CSVDatasetOp : public DatasetOpKernel { Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) { result->clear(); - Status s = input_stream_->ReadNBytes(dataset()->buffer_size_, result); + ++num_buffer_reads_; + Status s = input_stream_->ReadNBytes( + dataset()->options_.input_buffer_size, result); if (errors::IsOutOfRange(s) && !result->empty()) { // Ignore OutOfRange error when ReadNBytes read < N bytes. @@ -675,10 +780,20 @@ class CSVDatasetOp : public DatasetOpKernel { // Actually move on to next file. TF_RETURN_IF_ERROR(env->NewRandomAccessFile( dataset()->filenames_[current_file_index_], &file_)); - input_stream_.reset( - new io::RandomAccessInputStream(file_.get(), false)); + random_access_input_stream_ = + std::make_shared(file_.get(), false); + + if (dataset()->use_compression_) { + input_stream_ = std::make_shared( + random_access_input_stream_.get(), + dataset()->options_.input_buffer_size, + dataset()->options_.input_buffer_size, dataset()->options_); + } else { + input_stream_ = random_access_input_stream_; + } buffer_.clear(); pos_ = 0; + num_buffer_reads_ = 0; if (dataset()->header_) { // Read one line, but don't include it. Pass nullptrs as dummy // pointers to objects that shouldn't be invoked anyway @@ -704,8 +819,10 @@ class CSVDatasetOp : public DatasetOpKernel { string buffer_ GUARDED_BY(mu_); // Maintain our own buffer size_t pos_ GUARDED_BY( mu_); // Index into the buffer must be maintained between iters - std::unique_ptr input_stream_ + size_t num_buffer_reads_ GUARDED_BY(mu_); + std::shared_ptr random_access_input_stream_ GUARDED_BY(mu_); + std::shared_ptr input_stream_ GUARDED_BY(mu_); size_t current_file_index_ GUARDED_BY(mu_) = 0; std::unique_ptr file_ GUARDED_BY(mu_); // must outlive input_stream_ @@ -713,7 +830,6 @@ class CSVDatasetOp : public DatasetOpKernel { const std::vector filenames_; const bool header_; - const int64 buffer_size_; const DataTypeVector out_type_; const std::vector output_shapes_; const std::vector record_defaults_; @@ -721,6 +837,9 @@ class CSVDatasetOp : public DatasetOpKernel { const bool use_quote_delim_; const char delim_; const string na_value_; + const bool use_compression_; + const string compression_type_; + const io::ZlibCompressionOptions options_; }; // class Dataset DataTypeVector output_types_; diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index b3d464d7165d53cf198072e06214f7d5e982073d..32f03ca68364e40c6fd6769f05d0566f50119240 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_op_kernel.h" @@ -23,6 +24,7 @@ limitations under the License. #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { +namespace { struct BufferElement { // The producer sets `status` if getting the input element fails. @@ -473,4 +475,466 @@ class IteratorGetDeviceOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("IteratorGetDevice").Device(DEVICE_CPU), IteratorGetDeviceOp); +Status VerifyTypesMatch(const DataTypeVector& expected, + const DataTypeVector& received) { + if (expected.size() != received.size()) { + return errors::InvalidArgument( + "Number of components does not match: expected ", expected.size(), + " types but got ", received.size(), "."); + } + for (size_t i = 0; i < expected.size(); ++i) { + if (expected[i] != received[i]) { + return errors::InvalidArgument("Data type mismatch at component ", i, + ": expected ", DataTypeString(expected[i]), + " but got ", DataTypeString(received[i]), + "."); + } + } + return Status::OK(); +} + +Status VerifyShapesCompatible(const std::vector& expected, + const std::vector& received) { + if (expected.size() != received.size()) { + return errors::InvalidArgument( + "Number of components does not match: expected ", expected.size(), + " shapes but got ", received.size(), "."); + } + for (size_t i = 0; i < expected.size(); ++i) { + if (!expected[i].IsCompatibleWith(received[i])) { + return errors::InvalidArgument("Incompatible shapes at component ", i, + ": expected ", expected[i].DebugString(), + " but got ", received[i].DebugString(), + "."); + } + } + + return Status::OK(); +} + +string SanitizeThreadSuffix(string suffix) { + string clean; + for (int i = 0; i < suffix.size(); ++i) { + const char ch = suffix[i]; + if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || + (ch >= '0' && ch <= '9') || ch == '_' || ch == '-') { + clean += ch; + } else { + clean += '_'; + } + } + return clean; +} + +class MultiDeviceIterator : public ResourceBase { + public: + MultiDeviceIterator(const DataTypeVector& output_types, + const std::vector& output_shapes, + const std::vector& devices, + std::unique_ptr flib_def, + std::unique_ptr pflr, + FunctionLibraryRuntime* lib) + : output_types_(output_types), + output_shapes_(output_shapes), + devices_(devices), + flib_def_(std::move(flib_def)), + pflr_(std::move(pflr)), + lib_(lib) { + buffer_.resize(devices_.size()); + } + + string DebugString() override { + return strings::StrCat("MultiDeviceIterator"); + } + + Status Init(std::unique_ptr iterator, int64* incarnation_id) { + mutex_lock l(mu_); + if (iterator) { + TF_RETURN_IF_ERROR( + VerifyTypesMatch(output_types_, iterator->output_dtypes())); + TF_RETURN_IF_ERROR( + VerifyShapesCompatible(output_shapes_, iterator->output_shapes())); + } + host_iterator_.reset(iterator.release()); + incarnation_id_++; + *incarnation_id = incarnation_id_; + max_buffer_size_ = 0; + num_elements_ = 0; + buffer_.clear(); + buffer_.resize(devices_.size()); + return Status::OK(); + } + + Status GetNextFromShard(IteratorContext* ctx, int shard_num, + int64 incarnation_id, + std::vector* out_tensors, + bool* end_of_sequence) { + // TODO(rohanj): This might potentially strand elements in other shards. + // Opportunity to do smarter locking semantics. + mutex_lock l(mu_); + // Make sure we're in the right incarnation. + if (incarnation_id != incarnation_id_) { + return errors::InvalidArgument( + "Current incarnation: ", incarnation_id_, + "; Supplied incarnation: ", incarnation_id); + } + // Then look it up in the buffer. + if (!buffer_[shard_num].empty()) { + const HostBufferElement& elem = buffer_[shard_num].front(); + *out_tensors = elem.value; + *end_of_sequence = elem.end_of_sequence; + Status s = elem.status; + buffer_[shard_num].pop_front(); + return s; + } + std::shared_ptr captured_iterator(host_iterator_); + if (captured_iterator) { + if (lib_ != nullptr) { + ctx->set_lib(lib_); + } + while (true) { + HostBufferElement elem; + elem.status = + captured_iterator->GetNext(ctx, &elem.value, &elem.end_of_sequence); + int buffer_index = num_elements_ % devices_.size(); + num_elements_++; + if (buffer_index == shard_num) { + out_tensors->swap(elem.value); + *end_of_sequence = elem.end_of_sequence; + return elem.status; + } else { + buffer_[buffer_index].push_back(std::move(elem)); + // TODO(rohanj): Put an upper bound to buffer size. + if (buffer_[buffer_index].size() > max_buffer_size_) { + max_buffer_size_ = buffer_[buffer_index].size(); + VLOG(1) << "MultiDeviceIterator: Max buffer size increased to: " + << max_buffer_size_; + } + } + } + } else { + return errors::FailedPrecondition("Iterator not initialized"); + } + return Status::OK(); + } + + const DataTypeVector& output_types() const { return output_types_; } + + const std::vector& output_shapes() const { + return output_shapes_; + } + + std::shared_ptr function_library() { + tf_shared_lock l(mu_); + return lib_def_; + } + + private: + struct HostBufferElement { + Status status; + bool end_of_sequence; + std::vector value; + }; + + mutex mu_; + const DataTypeVector output_types_; + const std::vector output_shapes_; + const std::vector devices_; + int64 num_elements_ GUARDED_BY(mu_) = 0; + int64 max_buffer_size_ GUARDED_BY(mu_) = 0; + int64 incarnation_id_ GUARDED_BY(mu_) = 0; + std::vector> buffer_ GUARDED_BY(mu_); + std::unique_ptr flib_def_; + std::unique_ptr pflr_; + FunctionLibraryRuntime* lib_ = nullptr; // not owned. + std::shared_ptr host_iterator_; + std::shared_ptr lib_def_ GUARDED_BY(mu_); +}; + +// Just creates a MultiDeviceIterator and returns it. +class MultiDeviceIteratorHandleOp : public OpKernel { + public: + explicit MultiDeviceIteratorHandleOp(OpKernelConstruction* ctx) + : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("devices", &devices_)); + } + + // The resource is deleted from the resource manager only when it is private + // to kernel. + ~MultiDeviceIteratorHandleOp() override { + if (resource_ != nullptr) { + resource_->Unref(); + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->template Delete(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + } + + void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) { + { + mutex_lock l(mu_); + if (resource_ == nullptr) { + FunctionLibraryRuntime* lib; + std::unique_ptr flib_def(nullptr); + std::unique_ptr pflr(nullptr); + OP_REQUIRES_OK(context, context->function_library()->Clone( + &flib_def, &pflr, &lib)); + ResourceMgr* mgr = context->resource_manager(); + OP_REQUIRES_OK(context, cinfo_.Init(mgr, def())); + + MultiDeviceIterator* resource; + OP_REQUIRES_OK( + context, + mgr->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &resource, + [this, lib, &flib_def, &pflr](MultiDeviceIterator** ret) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + *ret = new MultiDeviceIterator( + output_types_, output_shapes_, devices_, + std::move(flib_def), std::move(pflr), lib); + return Status::OK(); + })); + + Status s = VerifyResource(resource); + if (TF_PREDICT_FALSE(!s.ok())) { + resource->Unref(); + context->SetStatus(s); + return; + } + + resource_ = resource; + } + } + OP_REQUIRES_OK(context, MakeResourceHandleToOutput( + context, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex())); + } + + private: + // During the first Compute(), resource is either created or looked up using + // shared_name. In the latter case, the resource found should be verified if + // it is compatible with this op's configuration. The verification may fail in + // cases such as two graphs asking queues of the same shared name to have + // inconsistent capacities. + Status VerifyResource(MultiDeviceIterator* resource) { + TF_RETURN_IF_ERROR( + VerifyTypesMatch(output_types_, resource->output_types())); + TF_RETURN_IF_ERROR( + VerifyShapesCompatible(output_shapes_, resource->output_shapes())); + return Status::OK(); + } + + mutex mu_; + ContainerInfo cinfo_; // Written once under mu_ then constant afterwards. + MultiDeviceIterator* resource_ GUARDED_BY(mu_) = nullptr; + DataTypeVector output_types_; + std::vector output_shapes_; + const int graph_def_version_; + string name_; + string container_; + std::vector devices_; +}; + +REGISTER_KERNEL_BUILDER(Name("MultiDeviceIterator").Device(DEVICE_CPU), + MultiDeviceIteratorHandleOp); + +// Calls init on the MultiDeviceIterator. +class MultiDeviceIteratorInitOp : public OpKernel { + public: + explicit MultiDeviceIteratorInitOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + DatasetBase* dataset; + OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset)); + MultiDeviceIterator* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 1), &resource)); + core::ScopedUnref unref(resource); + + IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx); + std::unique_ptr iterator; + OP_REQUIRES_OK(ctx, + dataset->MakeIterator(&iter_ctx, "Iterator", &iterator)); + int64 incarnation_id; + OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), &incarnation_id)); + Tensor tensor_incarnation_id(DT_INT64, TensorShape({})); + tensor_incarnation_id.scalar()() = incarnation_id; + OP_REQUIRES_OK(ctx, + ctx->set_output("incarnation_id", tensor_incarnation_id)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("MultiDeviceIteratorInit").Device(DEVICE_CPU), + MultiDeviceIteratorInitOp); + +// Calls GetNextFromShard(shard) and returns a vector of Tensors as output. +// TODO(rohanj): Implement using BackgroundWorker that Derek built? +class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel { + public: + explicit MultiDeviceIteratorGetNextFromShardOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx), + thread_pool_(new thread::ThreadPool( + ctx->env(), ThreadOptions(), + strings::StrCat("multi_device_iterator_get_next_thread_", + SanitizeThreadSuffix(name())), + 1 /* num_threads */, false /* low_latency_hint */)) {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + const Tensor* tensor_shard_num; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input("shard_num", &tensor_shard_num), done); + int32 shard_num = tensor_shard_num->scalar()(); + + const Tensor* tensor_incarnation_id; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->input("incarnation_id", &tensor_incarnation_id), done); + int64 incarnation_id = tensor_incarnation_id->scalar()(); + + MultiDeviceIterator* iterator; + OP_REQUIRES_OK_ASYNC( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done); + thread_pool_->Schedule(std::bind( + [ctx, iterator, shard_num, incarnation_id](DoneCallback done) { + std::vector components; + bool end_of_sequence = false; + + IteratorContext::Params params; + params.env = ctx->env(); + params.runner = *(ctx->runner()); + params.function_library = iterator->function_library(); + DeviceBase* device = ctx->function_library()->device(); + params.allocator_getter = [device](AllocatorAttributes attrs) { + return device->GetAllocator(attrs); + }; + IteratorContext iter_ctx(std::move(params)); + + Status s = + iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id, + &components, &end_of_sequence); + iterator->Unref(); + + if (!s.ok()) { + ctx->SetStatus(s); + } else if (end_of_sequence) { + ctx->SetStatus(errors::OutOfRange("End of sequence")); + } else { + for (int i = 0; i < components.size(); ++i) { + // TODO(mrry): Check that the shapes match the shape attrs. + ctx->set_output(i, components[i]); + } + } + done(); + }, + std::move(done))); + } + + private: + std::unique_ptr thread_pool_; +}; + +REGISTER_KERNEL_BUILDER( + Name("MultiDeviceIteratorGetNextFromShard").Device(DEVICE_CPU), + MultiDeviceIteratorGetNextFromShardOp); + +class MultiDeviceIteratorToStringHandleOp : public OpKernel { + public: + explicit MultiDeviceIteratorToStringHandleOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& resource_handle_t = ctx->input(0); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()), + errors::InvalidArgument("resource_handle must be a scalar")); + + // Validate that the handle corresponds to a real resource, and + // that it is an MultiDeviceIterator. + MultiDeviceIterator* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + resource->Unref(); + + Tensor* string_handle_t; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, TensorShape({}), &string_handle_t)); + string_handle_t->scalar()() = + resource_handle_t.scalar()().SerializeAsString(); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("MultiDeviceIteratorToStringHandle").Device(DEVICE_CPU), + MultiDeviceIteratorToStringHandleOp); + +class MultiDeviceIteratorFromStringHandleOp : public OpKernel { + public: + explicit MultiDeviceIteratorFromStringHandleOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + OP_REQUIRES( + ctx, + output_types_.empty() || output_shapes_.empty() || + output_types_.size() == output_shapes_.size(), + errors::InvalidArgument("If both 'output_types' and 'output_shapes' " + "are set, they must have the same length.")); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& string_handle_t = ctx->input(0); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()), + errors::InvalidArgument("string_handle must be a scalar")); + + ResourceHandle resource_handle; + OP_REQUIRES( + ctx, + resource_handle.ParseFromString(string_handle_t.scalar()()), + errors::InvalidArgument( + "Could not parse string_handle as a valid ResourceHandle")); + + OP_REQUIRES( + ctx, resource_handle.device() == ctx->device()->attributes().name(), + errors::InvalidArgument("Attempted create an iterator on device \"", + ctx->device()->attributes().name(), + "\" from handle defined on device \"", + resource_handle.device(), "\"")); + + // Validate that the handle corresponds to a real resource, and + // that it is an MultiDeviceIterator. + MultiDeviceIterator* resource; + OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &resource)); + core::ScopedUnref unref_iterator(resource); + if (!output_types_.empty()) { + OP_REQUIRES_OK(ctx, + VerifyTypesMatch(output_types_, resource->output_types())); + } + if (!output_shapes_.empty()) { + OP_REQUIRES_OK(ctx, VerifyShapesCompatible(output_shapes_, + resource->output_shapes())); + } + + Tensor* resource_handle_t; + OP_REQUIRES_OK( + ctx, ctx->allocate_output(0, TensorShape({}), &resource_handle_t)); + resource_handle_t->scalar()() = resource_handle; + } + + private: + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +REGISTER_KERNEL_BUILDER( + Name("MultiDeviceIteratorFromStringHandle").Device(DEVICE_CPU), + MultiDeviceIteratorFromStringHandleOp); + +} // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc index 8413fcaf872f49f654c6a1327a14d5c44bdd815a..66a7c7fdcd5e0ab77596177c209470e17f63bc10 100644 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -36,6 +36,7 @@ data_input_datasets: `N` datasets with the same type that will be interleaved REGISTER_OP("CSVDataset") .Input("filenames: string") + .Input("compression_type: string") .Input("buffer_size: int64") .Input("header: bool") .Input("field_delim: string") @@ -52,17 +53,18 @@ REGISTER_OP("CSVDataset") shape_inference::ShapeHandle unused; // `filenames` must be a scalar or a vector. TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); - // `buffer_size`, `header`, `field_delim`, `use_quote_delim`, - // `na_value` must be scalars + // `compression_type`, `buffer_size`, `header`, `field_delim`, + // `use_quote_delim`, `na_value` must be scalars TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); // `select_cols` must be a vector - TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &unused)); - // `record_defaults` must be a list of scalars...? - for (size_t i = 7; i < c->num_inputs(); ++i) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused)); + // `record_defaults` must be lists of scalars + for (size_t i = 8; i < c->num_inputs(); ++i) { TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &unused)); } return shape_inference::ScalarShape(c); @@ -143,6 +145,80 @@ Resets the FunctionBufferingResource. function_buffer_resource: The FunctionBufferingResource handle. )doc"); +REGISTER_OP("MultiDeviceIterator") + .Output("handle: resource") + .Attr("devices: list(string) >= 1") + .Attr("shared_name: string") + .Attr("container: string") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .Doc(R"doc( +Creates a MultiDeviceIterator resource. + +handle: Handle to the resource created. +devices: A list of devices the iterator works across. +shared_name: If non-empty, this resource will be shared under the given name + across multiple sessions. +container: If non-empty, this resource is placed in the given container. + Otherwise, a default container is used. +output_types: The type list for the return values. +output_shapes: The list of shapes being produced. +)doc"); + +REGISTER_OP("MultiDeviceIteratorInit") + .Input("dataset: variant") + .Input("multi_device_iterator: resource") + .Output("incarnation_id: int64") + .Doc(R"doc( +Initializes the multi device iterator with the given dataset. +incarnation_id: An int64 indicating which incarnation of the MultiDeviceIterator + is running. +dataset: Dataset to be iterated upon. +multi_device_iterator: A MultiDeviceIteratorResource. +)doc"); + +REGISTER_OP("MultiDeviceIteratorGetNextFromShard") + .Input("multi_device_iterator: resource") + .Input("shard_num: int32") + .Input("incarnation_id: int64") + .Output("components: output_types") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .Doc(R"doc( +Gets next element for the provided shard number. + +multi_device_iterator: A MultiDeviceIterator resource. +shard_num: Integer representing which shard to fetch data for. +incarnation_id: Which incarnation of the MultiDeviceIterator is running. +components: Result of the get_next on the dataset. +output_types: The type list for the return values. +output_shapes: The list of shapes being produced. +)doc"); + +REGISTER_OP("MultiDeviceIteratorToStringHandle") + .Input("multi_device_iterator: resource") + .Output("string_handle: string") + .Doc(R"doc( +Produces a string handle for the given MultiDeviceIterator. + +multi_device_iterator: A MultiDeviceIterator resource. +string_handle: A string representing the resource. +)doc"); + +REGISTER_OP("MultiDeviceIteratorFromStringHandle") + .Input("string_handle: string") + .Output("multi_device_iterator: resource") + .Attr("output_types: list(type) >= 0 = []") + .Attr("output_shapes: list(shape) >= 0 = []") + .Doc(R"doc( +Generates a MultiDeviceIterator resource from its provided string handle. + +string_handle: String representing the resource. +multi_device_iterator: A MultiDeviceIterator resource. +output_types: The type list for the return values. +output_shapes: The list of shapes being produced. +)doc"); + REGISTER_OP("ThreadPoolDataset") .Input("input_dataset: variant") .Input("thread_pool: resource") @@ -175,4 +251,17 @@ display_name: A human-readable name for the threads that may be visible in some visualizations. )doc"); +REGISTER_OP("AssertNextDataset") + .Input("input_dataset: variant") + .Input("transformations: string") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // transformations should be a vector. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); + return shape_inference::ScalarShape(c); + }); + } // namespace tensorflow diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index d81654e039c53e5b9434288352ef1b2416a4b7e8..ea92191f3e20218d58bb2d0ba1ce7c1120361d45 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -60,7 +60,7 @@ py_test( py_test( name = "csv_dataset_op_test", - size = "small", + size = "medium", srcs = ["csv_dataset_op_test.py"], srcs_version = "PY2AND3", tags = ["no_pip"], @@ -121,6 +121,7 @@ py_test( srcs = ["get_single_element_test.py"], deps = [ "//tensorflow/contrib/data/python/ops:get_single_element", + "//tensorflow/contrib/data/python/ops:grouping", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -128,6 +129,7 @@ py_test( "//tensorflow/python:errors", "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", ], ) @@ -188,29 +190,56 @@ py_test( "optonly", ], deps = [ + "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/contrib/data/python/ops:error_ops", + "//tensorflow/contrib/data/python/ops:optimization", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:io_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) +py_test( + name = "map_defun_op_test", + size = "small", + srcs = ["map_defun_op_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow/contrib/data/python/ops:map_defun", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + ], +) + py_test( name = "optimize_dataset_op_test", size = "small", srcs = ["optimize_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ + ":stats_dataset_test_base", "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/core:protos_all_py", + "//tensorflow/contrib/data/python/ops:stats_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:math_ops", "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", ], ) @@ -228,9 +257,16 @@ cuda_py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:function", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/compat:compat", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", ], + tags = [ + "manual", + "no_oss", + "no_windows_gpu", + "notap", + ], ) py_test( @@ -377,6 +413,7 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) @@ -419,8 +456,8 @@ py_test( tags = ["no_pip"], deps = [ ":reader_dataset_ops_test_base", + ":stats_dataset_test_base", "//tensorflow/contrib/data/python/ops:stats_ops", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", @@ -430,6 +467,16 @@ py_test( ], ) +py_library( + name = "stats_dataset_test_base", + srcs = ["stats_dataset_test_base.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "threadpool_dataset_ops_test", size = "small", @@ -465,6 +512,28 @@ py_test( ], ) +py_test( + name = "window_dataset_op_test", + size = "medium", + srcs = ["window_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/contrib/data/python/ops:grouping", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + py_test( name = "writer_ops_test", size = "small", diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index af97fbf87aee5f7005f9d266ba9b1b6cf109a2ec..42adfd17f07e508f25d8b351c791fa519eca8bd9 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -293,7 +293,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): ph2: np.arange(8).astype(np.int32) }) with self.assertRaises(errors.InvalidArgumentError): - print(sess.run(next_element)) + sess.run(next_element) # No 0th dimension (i.e. scalar value) for one component. sess.run( @@ -303,7 +303,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): ph2: 7 }) with self.assertRaises(errors.InvalidArgumentError): - print(sess.run(next_element)) + sess.run(next_element) def testBatchAndDropRemainder(self): components = (np.arange(7), diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index 5fc7e51d814901985d33525b782434386c3ad18a..2022c1f2bdd09cdf43a993b3666335ce468a40ba 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -616,7 +616,44 @@ class BucketBySequenceLength(test.TestCase): batch_sizes = batch_sizes[:-1] self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) - self.assertEqual(sorted(boundaries), sorted(lengths_val)) + self.assertEqual([boundary - 1 for boundary in sorted(boundaries)], + sorted(lengths_val)) + + def testPadToBoundaryNoExtraneousPadding(self): + + boundaries = [3, 7, 11] + batch_sizes = [2, 2, 2, 2] + lengths = range(1, 11) + + def element_gen(): + for length in lengths: + yield ([1] * length,) + + element_len = lambda element: array_ops.shape(element)[0] + dataset = dataset_ops.Dataset.from_generator( + element_gen, (dtypes.int64,), ([None],)).apply( + grouping.bucket_by_sequence_length( + element_len, boundaries, batch_sizes, + pad_to_bucket_boundary=True)) + batch, = dataset.make_one_shot_iterator().get_next() + + with self.test_session() as sess: + batches = [] + for _ in range(5): + batches.append(sess.run(batch)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(batch) + + self.assertAllEqual(batches[0], [[1, 0], + [1, 1]]) + self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0]]) + self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1]]) + self.assertAllEqual(batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) def testTupleElements(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py index df115175f5046803ada036563be1ca802f7ad0cd..2a0e64caeb61c5a7d45669783ace4588746c19e3 100644 --- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -18,10 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gzip import os import string import tempfile import time +import zlib import numpy as np @@ -62,18 +64,29 @@ class CsvDatasetOpTest(test.TestCase): op2 = sess.run(next2) self.assertAllEqual(op1, op2) - def setup_files(self, inputs, linebreak='\n'): + def _setup_files(self, inputs, linebreak='\n', compression_type=None): filenames = [] for i, ip in enumerate(inputs): fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i) - with open(fn, 'wb') as f: - f.write(linebreak.join(ip).encode('utf-8')) + contents = linebreak.join(ip).encode('utf-8') + if compression_type is None: + with open(fn, 'wb') as f: + f.write(contents) + elif compression_type == 'GZIP': + with gzip.GzipFile(fn, 'wb') as f: + f.write(contents) + elif compression_type == 'ZLIB': + contents = zlib.compress(contents) + with open(fn, 'wb') as f: + f.write(contents) + else: + raise ValueError('Unsupported compression_type', compression_type) filenames.append(fn) return filenames def _make_test_datasets(self, inputs, **kwargs): # Test by comparing its output to what we could get with map->decode_csv - filenames = self.setup_files(inputs) + filenames = self._setup_files(inputs) dataset_expected = core_readers.TextLineDataset(filenames) dataset_expected = dataset_expected.map( lambda l: parsing_ops.decode_csv(l, **kwargs)) @@ -112,15 +125,18 @@ class CsvDatasetOpTest(test.TestCase): except errors.OutOfRangeError: break - def _test_dataset(self, - inputs, - expected_output=None, - expected_err_re=None, - linebreak='\n', - **kwargs): + def _test_dataset( + self, + inputs, + expected_output=None, + expected_err_re=None, + linebreak='\n', + compression_type=None, # Used for both setup and parsing + **kwargs): """Checks that elements produced by CsvDataset match expected output.""" # Convert str type because py3 tf strings are bytestrings - filenames = self.setup_files(inputs, linebreak) + filenames = self._setup_files(inputs, linebreak, compression_type) + kwargs['compression_type'] = compression_type with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = readers.CsvDataset(filenames, **kwargs) @@ -174,7 +190,7 @@ class CsvDatasetOpTest(test.TestCase): def testCsvDataset_ignoreErrWithUnescapedQuotes(self): record_defaults = [['']] * 3 inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']] - filenames = self.setup_files(inputs) + filenames = self._setup_files(inputs) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) @@ -184,7 +200,7 @@ class CsvDatasetOpTest(test.TestCase): def testCsvDataset_ignoreErrWithUnquotedQuotes(self): record_defaults = [['']] * 3 inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']] - filenames = self.setup_files(inputs) + filenames = self._setup_files(inputs) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) @@ -355,7 +371,7 @@ class CsvDatasetOpTest(test.TestCase): '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19', '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19' ]] - file_path = self.setup_files(data) + file_path = self._setup_files(data) with ops.Graph().as_default() as g: ds = readers.make_csv_dataset( @@ -432,14 +448,29 @@ class CsvDatasetOpTest(test.TestCase): record_defaults=record_defaults, buffer_size=0) - def testCsvDataset_withBufferSize(self): + def _test_dataset_on_buffer_sizes(self, + inputs, + expected, + linebreak, + record_defaults, + compression_type=None, + num_sizes_to_test=20): + # Testing reading with a range of buffer sizes that should all work. + for i in list(range(1, 1 + num_sizes_to_test)) + [None]: + self._test_dataset( + inputs, + expected, + linebreak=linebreak, + compression_type=compression_type, + record_defaults=record_defaults, + buffer_size=i) + + def testCsvDataset_withLF(self): record_defaults = [['NA']] * 3 inputs = [['abc,def,ghi', '0,1,2', ',,']] expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, expected, record_defaults=record_defaults, buffer_size=i + 1) + self._test_dataset_on_buffer_sizes( + inputs, expected, linebreak='\n', record_defaults=record_defaults) def testCsvDataset_withCR(self): # Test that when the line separator is '\r', parsing works with all buffer @@ -447,14 +478,8 @@ class CsvDatasetOpTest(test.TestCase): record_defaults = [['NA']] * 3 inputs = [['abc,def,ghi', '0,1,2', ',,']] expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, - expected, - linebreak='\r', - record_defaults=record_defaults, - buffer_size=i + 1) + self._test_dataset_on_buffer_sizes( + inputs, expected, linebreak='\r', record_defaults=record_defaults) def testCsvDataset_withCRLF(self): # Test that when the line separator is '\r\n', parsing works with all buffer @@ -462,29 +487,15 @@ class CsvDatasetOpTest(test.TestCase): record_defaults = [['NA']] * 3 inputs = [['abc,def,ghi', '0,1,2', ',,']] expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, - expected, - linebreak='\r\n', - record_defaults=record_defaults, - buffer_size=i + 1) + self._test_dataset_on_buffer_sizes( + inputs, expected, linebreak='\r\n', record_defaults=record_defaults) def testCsvDataset_withBufferSizeAndQuoted(self): record_defaults = [['NA']] * 3 inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, - expected, - linebreak='\n', - record_defaults=record_defaults, - buffer_size=i + 1) - self._test_dataset( + self._test_dataset_on_buffer_sizes( inputs, expected, linebreak='\n', record_defaults=record_defaults) def testCsvDataset_withCRAndQuoted(self): @@ -494,15 +505,7 @@ class CsvDatasetOpTest(test.TestCase): inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, - expected, - linebreak='\r', - record_defaults=record_defaults, - buffer_size=i + 1) - self._test_dataset( + self._test_dataset_on_buffer_sizes( inputs, expected, linebreak='\r', record_defaults=record_defaults) def testCsvDataset_withCRLFAndQuoted(self): @@ -512,17 +515,33 @@ class CsvDatasetOpTest(test.TestCase): inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, - expected, - linebreak='\r\n', - record_defaults=record_defaults, - buffer_size=i + 1) - self._test_dataset( + self._test_dataset_on_buffer_sizes( inputs, expected, linebreak='\r\n', record_defaults=record_defaults) + def testCsvDataset_withGzipCompressionType(self): + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + self._test_dataset_on_buffer_sizes( + inputs, + expected, + linebreak='\r\n', + compression_type='GZIP', + record_defaults=record_defaults) + + def testCsvDataset_withZlibCompressionType(self): + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + self._test_dataset_on_buffer_sizes( + inputs, + expected, + linebreak='\r\n', + compression_type='ZLIB', + record_defaults=record_defaults) + class CsvDatasetBenchmark(test.Benchmark): """Benchmarks for the various ways of creating a dataset from CSV files. diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py index 87b7c6ddb7afcbaaf8fe97cd8be87e6f5af8cd4d..e6883d53e02c0f96d966a52abfe2f9b4118f2e12 100644 --- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py @@ -17,9 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized +import numpy as np + from tensorflow.contrib.data.python.ops import get_single_element +from tensorflow.contrib.data.python.ops import grouping from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor @@ -27,40 +30,69 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class GetSingleElementTest(test.TestCase): +class GetSingleElementTest(test.TestCase, parameterized.TestCase): - def testGetSingleElement(self): - skip_value = array_ops.placeholder(dtypes.int64, shape=[]) - take_value = array_ops.placeholder_with_default( - constant_op.constant(1, dtype=dtypes.int64), shape=[]) + @parameterized.named_parameters( + ("Zero", 0, 1), + ("Five", 5, 1), + ("Ten", 10, 1), + ("Empty", 100, 1, errors.InvalidArgumentError, "Dataset was empty."), + ("MoreThanOne", 0, 2, errors.InvalidArgumentError, + "Dataset had more than one element."), + ) + def testGetSingleElement(self, skip, take, error=None, error_msg=None): + skip_t = array_ops.placeholder(dtypes.int64, shape=[]) + take_t = array_ops.placeholder(dtypes.int64, shape=[]) def make_sparse(x): x_1d = array_ops.reshape(x, [1]) x_2d = array_ops.reshape(x, [1, 1]) return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d) - dataset = (dataset_ops.Dataset.range(100) - .skip(skip_value) - .map(lambda x: (x * x, make_sparse(x))) - .take(take_value)) - + dataset = dataset_ops.Dataset.range(100).skip(skip_t).map( + lambda x: (x * x, make_sparse(x))).take(take_t) element = get_single_element.get_single_element(dataset) with self.test_session() as sess: - for x in [0, 5, 10]: - dense_val, sparse_val = sess.run(element, feed_dict={skip_value: x}) - self.assertEqual(x * x, dense_val) - self.assertAllEqual([[x]], sparse_val.indices) - self.assertAllEqual([x], sparse_val.values) - self.assertAllEqual([x], sparse_val.dense_shape) - - with self.assertRaisesRegexp(errors.InvalidArgumentError, - "Dataset was empty."): - sess.run(element, feed_dict={skip_value: 100}) - - with self.assertRaisesRegexp(errors.InvalidArgumentError, - "Dataset had more than one element."): - sess.run(element, feed_dict={skip_value: 0, take_value: 2}) + if error is None: + dense_val, sparse_val = sess.run( + element, feed_dict={ + skip_t: skip, + take_t: take + }) + self.assertEqual(skip * skip, dense_val) + self.assertAllEqual([[skip]], sparse_val.indices) + self.assertAllEqual([skip], sparse_val.values) + self.assertAllEqual([skip], sparse_val.dense_shape) + else: + with self.assertRaisesRegexp(error, error_msg): + sess.run(element, feed_dict={skip_t: skip, take_t: take}) + + @parameterized.named_parameters( + ("SumZero", 0), + ("SumOne", 1), + ("SumFive", 5), + ("SumTen", 10), + ) + def testReduceDataset(self, stop): + def init_fn(_): + return np.int64(0) + + def reduce_fn(state, value): + return state + value + + def finalize_fn(state): + return state + + sum_reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn) + + stop_t = array_ops.placeholder(dtypes.int64, shape=[]) + dataset = dataset_ops.Dataset.range(stop_t) + element = get_single_element.reduce_dataset(dataset, sum_reducer) + + with self.test_session() as sess: + value = sess.run(element, feed_dict={stop_t: stop}) + self.assertEqual(stop * (stop - 1) / 2, value) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py index 30a993b1f7056b9726f524b2279131339c80c5eb..77148aceec7fa90f927a9c009671c2939460877b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util @@ -55,7 +56,7 @@ class CheckpointInputPipelineHookTest(test.TestCase): def _read_vars(self, model_dir): """Returns (global_step, latest_feature).""" with ops.Graph().as_default() as g: - ckpt_path = saver_lib.latest_checkpoint(model_dir) + ckpt_path = checkpoint_management.latest_checkpoint(model_dir) meta_filename = ckpt_path + '.meta' saver_lib.import_meta_graph(meta_filename) saver = saver_lib.Saver() diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index 270a2297b4d7b4fc44e3d1fa0aea8c9dfa5f39d3..009e21a34c8df86af6abbb7599dbcfa23ddf90a7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -17,19 +17,29 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import hashlib +import itertools import os +import time import numpy as np +from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import error_ops +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import io_ops +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test from tensorflow.python.util import compat +_NUMPY_RANDOM_SEED = 42 + class MapDatasetTest(test.TestCase): @@ -70,18 +80,21 @@ class MapDatasetTest(test.TestCase): sess.run(get_next) def testReadFileIgnoreError(self): + def write_string_to_file(value, filename): with open(filename, "w") as f: f.write(value) - filenames = [os.path.join(self.get_temp_dir(), "file_%d.txt" % i) - for i in range(5)] + + filenames = [ + os.path.join(self.get_temp_dir(), "file_%d.txt" % i) for i in range(5) + ] for filename in filenames: write_string_to_file(filename, filename) dataset = ( dataset_ops.Dataset.from_tensor_slices(filenames).map( - io_ops.read_file, num_parallel_calls=2).prefetch(2).apply( - error_ops.ignore_errors())) + io_ops.read_file, + num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors())) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -135,5 +148,211 @@ class MapDatasetTest(test.TestCase): sess.run(get_next) +class MapDatasetBenchmark(test.Benchmark): + + # The purpose of this benchmark is to compare the performance of chaining vs + # fusing of the map and batch transformations across various configurations. + # + # NOTE: It is recommended to build the benchmark with + # `-c opt --copt=-mavx --copt=-mavx2 --copt=-mfma --copt=-gmlt` + # and execute it on a machine with at least 32 CPU cores. + def benchmarkMapAndBatch(self): + + # Sequential pipeline configurations. + seq_elem_size_series = itertools.product([1], [1], [1, 2, 4, 8], [16]) + seq_batch_size_series = itertools.product([1], [1], [1], [8, 16, 32, 64]) + + # Parallel pipeline configuration. + par_elem_size_series = itertools.product([32], [32], [1, 2, 4, 8], [256]) + par_batch_size_series = itertools.product([32], [32], [1], + [128, 256, 512, 1024]) + par_num_calls_series = itertools.product([8, 16, 32, 64], [32], [1], [512]) + par_inter_op_series = itertools.product([32], [8, 16, 32, 64], [1], [512]) + + def name(method, label, num_calls, inter_op, element_size, batch_size): + return ("%s_id_%s_num_calls_%d_inter_op_%d_elem_size_%d_batch_size_%d" % ( + method, + hashlib.sha1(label).hexdigest(), + num_calls, + inter_op, + element_size, + batch_size, + )) + + def benchmark(label, series): + + print("%s:" % label) + for num_calls, inter_op, element_size, batch_size in series: + + num_iters = 1024 // ( + (element_size * batch_size) // min(num_calls, inter_op)) + k = 1024 * 1024 + dataset = dataset_ops.Dataset.from_tensors((np.random.rand( + element_size, 4 * k), np.random.rand(4 * k, 1))).repeat() + + chained_dataset = dataset.map( + math_ops.matmul, + num_parallel_calls=num_calls).batch(batch_size=batch_size) + chained_iterator = chained_dataset.make_one_shot_iterator() + chained_get_next = chained_iterator.get_next() + + chained_deltas = [] + with session.Session( + config=config_pb2.ConfigProto( + inter_op_parallelism_threads=inter_op, + use_per_session_threads=True)) as sess: + for _ in range(5): + sess.run(chained_get_next.op) + for _ in range(num_iters): + start = time.time() + sess.run(chained_get_next.op) + end = time.time() + chained_deltas.append(end - start) + + fused_dataset = dataset = dataset.apply( + batching.map_and_batch( + math_ops.matmul, + num_parallel_calls=num_calls, + batch_size=batch_size)) + fused_iterator = fused_dataset.make_one_shot_iterator() + fused_get_next = fused_iterator.get_next() + + fused_deltas = [] + with session.Session( + config=config_pb2.ConfigProto( + inter_op_parallelism_threads=inter_op, + use_per_session_threads=True)) as sess: + + for _ in range(5): + sess.run(fused_get_next.op) + for _ in range(num_iters): + start = time.time() + sess.run(fused_get_next.op) + end = time.time() + fused_deltas.append(end - start) + + print( + "batch size: %d, num parallel calls: %d, inter-op parallelism: %d, " + "element size: %d, num iters: %d\nchained wall time: %f (median), " + "%f (mean), %f (stddev), %f (min), %f (max)\n fused wall time: " + "%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n " + "chained/fused: %.2fx (median), %.2fx (mean)" % + (batch_size, num_calls, inter_op, element_size, num_iters, + np.median(chained_deltas), np.mean(chained_deltas), + np.std(chained_deltas), np.min(chained_deltas), + np.max(chained_deltas), np.median(fused_deltas), + np.mean(fused_deltas), np.std(fused_deltas), np.min(fused_deltas), + np.max(fused_deltas), + np.median(chained_deltas) / np.median(fused_deltas), + np.mean(chained_deltas) / np.mean(fused_deltas))) + + self.report_benchmark( + iters=num_iters, + wall_time=np.median(chained_deltas), + name=name("chained", label, num_calls, inter_op, element_size, + batch_size)) + + self.report_benchmark( + iters=num_iters, + wall_time=np.median(fused_deltas), + name=name("fused", label, num_calls, inter_op, element_size, + batch_size)) + + print("") + + np.random.seed(_NUMPY_RANDOM_SEED) + benchmark("Sequential element size evaluation", seq_elem_size_series) + benchmark("Sequential batch size evaluation", seq_batch_size_series) + benchmark("Parallel element size evaluation", par_elem_size_series) + benchmark("Parallel batch size evaluation", par_batch_size_series) + benchmark("Transformation parallelism evaluation", par_num_calls_series) + benchmark("Threadpool size evaluation", par_inter_op_series) + + # This benchmark compares the performance of pipeline with multiple chained + # maps with and without map fusion. + def benchmarkChainOfMaps(self): + chain_lengths = [0, 1, 2, 5, 10, 20, 50] + for chain_length in chain_lengths: + self._benchmarkChainOfMaps(chain_length, False) + self._benchmarkChainOfMaps(chain_length, True) + + def _benchmarkChainOfMaps(self, chain_length, optimize_dataset): + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors(0).repeat(None) + for _ in range(chain_length): + dataset = dataset.map(lambda x: x) + if optimize_dataset: + dataset = dataset.apply(optimization.optimize(["map_fusion"])) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for _ in range(5): + sess.run(next_element.op) + deltas = [] + for _ in range(100): + start = time.time() + for _ in range(100): + sess.run(next_element.op) + end = time.time() + deltas.append(end - start) + + median_wall_time = np.median(deltas) / 100 + opt_mark = "opt" if optimize_dataset else "no-opt" + print("Map dataset {} chain length: {} Median wall time: {}".format( + opt_mark, chain_length, median_wall_time)) + self.report_benchmark( + iters=1000, + wall_time=median_wall_time, + name="benchmark_map_dataset_chain_latency_{}_{}".format( + opt_mark, chain_length)) + + +class MapAndFilterBenchmark(test.Benchmark): + + # This benchmark compares the performance of pipeline with multiple chained + # map + filter with and without map fusion. + def benchmarkMapAndFilter(self): + chain_lengths = [0, 1, 2, 5, 10, 20, 50] + for chain_length in chain_lengths: + self._benchmarkMapAndFilter(chain_length, False) + self._benchmarkMapAndFilter(chain_length, True) + + def _benchmarkMapAndFilter(self, chain_length, optimize_dataset): + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors(0).repeat(None) + for _ in range(chain_length): + dataset = dataset.map(lambda x: x + 5).filter( + lambda x: math_ops.greater_equal(x - 5, 0)) + if optimize_dataset: + dataset = dataset.apply( + optimization.optimize(["map_and_filter_fusion"])) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for _ in range(10): + sess.run(next_element.op) + deltas = [] + for _ in range(100): + start = time.time() + for _ in range(100): + sess.run(next_element.op) + end = time.time() + deltas.append(end - start) + + median_wall_time = np.median(deltas) / 100 + opt_mark = "opt" if optimize_dataset else "no-opt" + print("Map and filter dataset {} chain length: {} Median wall time: {}". + format(opt_mark, chain_length, median_wall_time)) + self.report_benchmark( + iters=1000, + wall_time=median_wall_time, + name="benchmark_map_and_filter_dataset_chain_latency_{}_{}".format( + opt_mark, chain_length)) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a711325daed12f45e4e533f18ee81adc7dec93be --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py @@ -0,0 +1,126 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MapDefunOp.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import map_defun +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class MapDefunTest(test.TestCase): + + def testMapDefun_Simple(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + return x * 2 + 3 + + with self.test_session(): + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0] + expected = elems * 2 + 3 + self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) + + def testMapDefun_MismatchedTypes(self): + + @function.Defun(dtypes.int32) + def fn(x): + return math_ops.cast(x, dtypes.float64) + + with self.test_session(): + nums = [1, 2, 3, 4, 5, 6] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0] + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate(r) + + def testMapDefun_MultipleOutputs(self): + + @function.Defun(dtypes.int32) + def fn(x): + return (x, math_ops.cast(x * 2 + 3, dtypes.float64)) + + with self.test_session(): + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64], + [(2,), (2,)]) + expected = [elems, elems * 2 + 3] + self.assertAllEqual(self.evaluate(r), self.evaluate(expected)) + + def testMapDefun_ShapeInference(self): + + @function.Defun(dtypes.int32) + def fn(x): + return x + + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])[0] + self.assertEqual(result.get_shape(), (3, 2)) + + def testMapDefun_PartialShapeInference(self): + + @function.Defun(dtypes.int32) + def fn(x): + return x + + elems = array_ops.placeholder(dtypes.int64, (None, 2)) + result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)]) + self.assertEqual(result[0].get_shape().as_list(), [None, 2]) + + def testMapDefun_RaisesErrorOnRuntimeShapeMismatch(self): + + @function.Defun(dtypes.int32, dtypes.int32) + def fn(x, y): + return x, y + + elems1 = array_ops.placeholder(dtypes.int32) + elems2 = array_ops.placeholder(dtypes.int32) + result = map_defun.map_defun(fn, [elems1, elems2], + [dtypes.int32, dtypes.int32], [(), ()]) + with self.test_session() as sess: + with self.assertRaisesWithPredicateMatch( + errors.InvalidArgumentError, + "All inputs must have the same dimension 0."): + sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]}) + + def testMapDefun_RaisesDefunError(self): + + @function.Defun(dtypes.int32) + def fn(x): + with ops.control_dependencies([check_ops.assert_equal(x, 0)]): + return array_ops.identity(x) + + elems = constant_op.constant([0, 0, 0, 37, 0]) + result = map_defun.map_defun(fn, [elems], [dtypes.int32], [()]) + with self.test_session(): + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate(result) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py index e35be8a23f3706bd170c09b967b4f419fc9a626e..ae147b4fa79c5fc8e63e1860f45036709ecc9777 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py @@ -17,60 +17,265 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + +from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base from tensorflow.contrib.data.python.ops import optimization -from tensorflow.core.framework import graph_pb2 +from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class OptimizeDatasetTest(test.TestCase): +class OptimizeDatasetTest(test.TestCase, parameterized.TestCase): + + def testAssertSuffix(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Map"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + self.assertEqual(0, sess.run(get_next)) + + def testAssertSuffixInvalid(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Whoops"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Asserted Whoops transformation at offset 0 but encountered " + "Map transformation instead."): + sess.run(get_next) + + def testAssertSuffixShort(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Map", "Whoops"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Asserted next 2 transformations but encountered only 1."): + sess.run(get_next) def testDefaultOptimizations(self): - dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( - 10).apply(optimization.optimize()) + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply( + optimization.optimize()) iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() with self.test_session() as sess: - graph = graph_pb2.GraphDef().FromString( - sess.run(dataset._as_serialized_graph())) - self.assertTrue( - all([node.op != "MapAndBatchDatasetV2" for node in graph.node])) self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) def testEmptyOptimizations(self): - dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( - 10).apply(optimization.optimize([])) + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply( + optimization.optimize([])) iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() with self.test_session() as sess: - graph = graph_pb2.GraphDef().FromString( - sess.run(dataset._as_serialized_graph())) - self.assertTrue( - all([node.op != "MapAndBatchDatasetV2" for node in graph.node])) self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) def testOptimization(self): - dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( - 10).apply(optimization.optimize(["map_and_batch_fusion"])) + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["MapAndBatch"])).map(lambda x: x * x).batch(10).apply( + optimization.optimize(["map_and_batch_fusion"])) iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() with self.test_session() as sess: - graph = graph_pb2.GraphDef().FromString( - sess.run(dataset._as_serialized_graph())) - self.assertTrue( - any([node.op == "MapAndBatchDatasetV2" for node in graph.node])) self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testFunctionLibraryDefinitionModification(self): + dataset = dataset_ops.Dataset.from_tensors(0).map(lambda x: x).apply( + optimization.optimize(["_test_only_function_rename"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + with self.assertRaisesRegexp(errors.NotFoundError, + "Function .* is not defined."): + sess.run(get_next) + + @staticmethod + def map_functions(): + identity = lambda x: x + increment = lambda x: x + 1 + + def increment_and_square(x): + y = x + 1 + return y * y + + functions = [identity, increment, increment_and_square] + tests = [] + for i, fun1 in enumerate(functions): + for j, fun2 in enumerate(functions): + tests.append(( + "test_{}_{}".format(i, j), + [fun1, fun2], + )) + for k, fun3 in enumerate(functions): + tests.append(( + "test_{}_{}_{}".format(i, j, k), + [fun1, fun2, fun3], + )) + + swap = lambda x, n: (n, x) + tests.append(( + "swap1", + [lambda x: (x, 42), swap], + )) + tests.append(( + "swap2", + [lambda x: (x, 42), swap, swap], + )) + return tuple(tests) + + @parameterized.named_parameters(*map_functions.__func__()) + def testMapFusion(self, functions): + dataset = dataset_ops.Dataset.range(5).apply( + optimization.assert_next(["Map", "Prefetch"])) + for function in functions: + dataset = dataset.map(function) + + dataset = dataset.prefetch(0).apply(optimization.optimize(["map_fusion"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + with self.test_session() as sess: + for x in range(5): + result = sess.run(get_next) + r = x + for function in functions: + if isinstance(r, tuple): + r = function(*r) # Pass tuple as multiple arguments. + else: + r = function(r) + self.assertAllEqual(r, result) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + @staticmethod + def map_and_filter_functions(): + identity = lambda x: x + increment = lambda x: x + 1 + minus_five = lambda x: x - 5 + + def increment_and_square(x): + y = x + 1 + return y * y + + take_all = lambda x: constant_op.constant(True) + is_zero = lambda x: math_ops.equal(x, 0) + is_odd = lambda x: math_ops.equal(x % 2, 0) + greater = lambda x: math_ops.greater(x + 5, 0) + + functions = [identity, increment, minus_five, increment_and_square] + filters = [take_all, is_zero, is_odd, greater] + tests = [] + + for x, fun in enumerate(functions): + for y, predicate in enumerate(filters): + tests.append(("mixed_{}_{}".format(x, y), fun, predicate)) + + # Multi output + tests.append(("multiOne", lambda x: (x, x), + lambda x, y: constant_op.constant(True))) + tests.append( + ("multiTwo", lambda x: (x, 2), + lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0))) + return tuple(tests) + + @parameterized.named_parameters(*map_and_filter_functions.__func__()) + def testMapFilterFusion(self, function, predicate): + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", + "FilterByLastComponent"])).map(function).filter(predicate).apply( + optimization.optimize(["map_and_filter_fusion"])) + self._testMapAndFilter(dataset, function, predicate) + + def _testMapAndFilter(self, dataset, function, predicate): + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + with self.test_session() as sess: + for x in range(10): + r = function(x) + if isinstance(r, tuple): + b = predicate(*r) # Pass tuple as multiple arguments. + else: + b = predicate(r) + if sess.run(b): + result = sess.run(get_next) + self.assertAllEqual(r, result) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testAdditionalInputs(self): + a = constant_op.constant(3, dtype=dtypes.int64) + b = constant_op.constant(4, dtype=dtypes.int64) + some_tensor = math_ops.mul(a, b) + function = lambda x: x * x + + def predicate(y): + return math_ops.less(math_ops.cast(y, dtypes.int64), some_tensor) + + # We are currently not supporting functions with additional inputs. + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", "Filter"])).map(function).filter(predicate).apply( + optimization.optimize(["map_and_filter_fusion"])) + + self._testMapAndFilter(dataset, function, predicate) + + +class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): + + def testLatencyStatsOptimization(self): + + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.from_tensors(1).apply( + optimization.assert_next( + ["LatencyStats", "Map", "LatencyStats", "Prefetch", + "LatencyStats"])).map(lambda x: x * x).prefetch(1).apply( + optimization.optimize(["latency_all_edges"])).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + get_next = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run(iterator.initializer) + self.assertEqual(1 * 1, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, + "record_latency_TensorDataset/_1", 1) + self._assertSummaryHasCount(summary_str, "record_latency_MapDataset/_4", + 1) + self._assertSummaryHasCount(summary_str, + "record_latency_PrefetchDataset/_6", 1) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py index 20ed6397505dbd77dbfe686147391c18b62c8718..d66305d7326f78d1e414b6076c1ca6a029baa2f7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -21,6 +21,7 @@ import threading from tensorflow.contrib.data.python.ops import prefetching_ops from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.compat import compat from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op @@ -30,6 +31,7 @@ from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test @@ -86,8 +88,7 @@ class PrefetchingKernelsOpsTest(test.TestCase): return (prefetch_op, reset_op, destroy_op) def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1): - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False) prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name, @@ -126,8 +127,7 @@ class PrefetchingKernelsOpsTest(test.TestCase): "/job:localhost/replica:0/task:0/gpu:0") def testReinitialization(self): - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) device0 = "/job:localhost/replica:0/task:0/cpu:0" device1 = "/job:localhost/replica:0/task:0/cpu:1" @@ -167,8 +167,7 @@ class PrefetchingKernelsOpsTest(test.TestCase): sess.run(destroy_op) def testReinitializationOutOfRange(self): - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) device0 = "/job:localhost/replica:0/task:0/cpu:0" device1 = "/job:localhost/replica:0/task:0/cpu:1" @@ -237,9 +236,9 @@ class PrefetchingKernelsOpsTest(test.TestCase): buffer_resource_handle, ignore_lookup_error=True) with self.test_session() as sess: - self.assertEqual(["a"], sess.run(prefetch_op)) - self.assertEqual(["b"], sess.run(prefetch_op)) - self.assertEqual(["c"], sess.run(prefetch_op)) + self.assertEqual([b"a"], sess.run(prefetch_op)) + self.assertEqual([b"b"], sess.run(prefetch_op)) + self.assertEqual([b"c"], sess.run(prefetch_op)) with self.assertRaises(errors.OutOfRangeError): sess.run(prefetch_op) @@ -271,8 +270,7 @@ class PrefetchToDeviceTest(test.TestCase): self.assertEqual(dtypes.int64, next_element.dtype) self.assertEqual([], next_element.shape) - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) with self.test_session(config=worker_config) as sess: for i in range(10): self.assertEqual(i, sess.run(next_element)) @@ -332,8 +330,7 @@ class PrefetchToDeviceTest(test.TestCase): self.assertEqual(dtypes.int64, next_element["a"].dtype) self.assertEqual([], next_element["a"].shape) - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) with self.test_session(config=worker_config) as sess: for i in range(10): self.assertEqual({"a": i}, sess.run(next_element)) @@ -366,8 +363,7 @@ class PrefetchToDeviceTest(test.TestCase): next_element = iterator.get_next() self.assertEqual(dtypes.int64, next_element.dtype) - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) with self.test_session(config=worker_config) as sess: for i in range(10): actual = sess.run(next_element) @@ -417,8 +413,7 @@ class PrefetchToDeviceTest(test.TestCase): self.assertEqual(dtypes.int64, next_element.dtype) self.assertEqual([], next_element.shape) - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) with self.test_session(config=worker_config) as sess: sess.run(iterator.initializer) for i in range(5): @@ -451,5 +446,653 @@ class PrefetchToDeviceTest(test.TestCase): sess.run(next_element) +class CopyToDeviceTest(test.TestCase): + + def testCopyToDevice(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceInt32(self): + host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int32, next_element.dtype) + self.assertEqual((4,), next_element.shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + self.assertAllEqual([0, 1, 2, 3], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToSameDevice(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:0")) + + with ops.device("/cpu:0"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceWithPrefetch(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")).prefetch(1) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyDictToDevice(self): + host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element["a"].dtype) + self.assertEqual([], next_element["a"].shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual({"a": i}, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyDictToDeviceWithPrefetch(self): + host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")).prefetch(1) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element["a"].dtype) + self.assertEqual([], next_element["a"].shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual({"a": i}, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopySparseTensorsToDevice(self): + + def make_tensor(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2]) + + host_dataset = dataset_ops.Dataset.range(10).map(make_tensor) + + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element.dtype) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + actual = sess.run(next_element) + self.assertAllEqual([i], actual.values) + self.assertAllEqual([[0, 0]], actual.indices) + self.assertAllEqual([2, 2], actual.dense_shape) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopySparseTensorsToDeviceWithPrefetch(self): + + def make_tensor(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2]) + + host_dataset = dataset_ops.Dataset.range(10).map(make_tensor) + + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")).prefetch(1) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element.dtype) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + for i in range(10): + actual = sess.run(next_element) + self.assertAllEqual([i], actual.values) + self.assertAllEqual([[0, 0]], actual.indices) + self.assertAllEqual([2, 2], actual.dense_shape) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpu(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpuWithPrefetch(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")).prefetch(1) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpuInt32(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + self.assertAllEqual([0, 1, 2, 3], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpuInt32AndPrefetch(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3]) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")).prefetch(1) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + self.assertAllEqual([0, 1, 2, 3], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpuStrings(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"]) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpuStringsAndPrefetch(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"]) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDevicePingPongCPUGPU(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + with compat.forward_compatibility_horizon(2018, 8, 4): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0", source_device="/cpu:0")) + back_to_cpu_dataset = device_dataset.apply( + prefetching_ops.copy_to_device("/cpu:0", source_device="/gpu:0")) + + with ops.device("/cpu:0"): + iterator = back_to_cpu_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceWithReInit(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceWithReInitAndPrefetch(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/cpu:1")).prefetch(1) + + with ops.device("/cpu:1"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=worker_config) as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpuWithReInit(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testCopyToDeviceGpuWithReInitAndPrefetch(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")).prefetch(1) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testIteratorGetNextAsOptionalOnGPU(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(3) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")) + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_elem = iterator_ops.get_next_as_optional(iterator) + elem_has_value_t = next_elem.has_value() + elem_value_t = next_elem.get_value() + + with self.test_session() as sess: + # Before initializing the iterator, evaluating the optional fails with + # a FailedPreconditionError. + with self.assertRaises(errors.FailedPreconditionError): + sess.run(elem_has_value_t) + with self.assertRaises(errors.FailedPreconditionError): + sess.run(elem_value_t) + + # For each element of the dataset, assert that the optional evaluates to + # the expected value. + sess.run(iterator.initializer) + for i in range(3): + elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t]) + self.assertTrue(elem_has_value) + self.assertEqual(i, elem_value) + + # After exhausting the iterator, `next_elem.has_value()` will evaluate to + # false, and attempting to get the value will fail. + for _ in range(2): + self.assertFalse(sess.run(elem_has_value_t)) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(elem_value_t) + + +class MultiDeviceIteratorTest(test.TestCase): + + def testBasic(self): + dataset = dataset_ops.Dataset.range(10) + multi_device_iterator = prefetching_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/cpu:2"]) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + + config = config_pb2.ConfigProto(device_count={"CPU": 3}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + for i in range(0, 10, 2): + self.assertEqual(i, sess.run(elem_on_1)) + self.assertEqual(i + 1, sess.run(elem_on_2)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elem_on_1) + sess.run(elem_on_2) + + def testOneOnSameDevice(self): + with ops.device("/cpu:0"): + dataset = dataset_ops.Dataset.range(10) + multi_device_iterator = prefetching_ops.MultiDeviceIterator( + dataset, ["/cpu:0", "/cpu:1"]) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + + config = config_pb2.ConfigProto(device_count={"CPU": 2}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + for i in range(0, 10, 2): + self.assertEqual(i, sess.run(elem_on_1)) + self.assertEqual(i + 1, sess.run(elem_on_2)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elem_on_1) + sess.run(elem_on_2) + + def testRepeatDevices(self): + with ops.device("/cpu:0"): + dataset = dataset_ops.Dataset.range(20) + multi_device_iterator = prefetching_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/cpu:2", "/cpu:1", "/cpu:2"]) + elements = multi_device_iterator.get_next() + elem_on_1, elem_on_2, elem_on_3, elem_on_4 = elements + + config = config_pb2.ConfigProto(device_count={"CPU": 3}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + for i in range(0, 20, 4): + self.assertEqual(i, sess.run(elem_on_1)) + self.assertEqual(i + 1, sess.run(elem_on_2)) + self.assertEqual(i + 2, sess.run(elem_on_3)) + self.assertEqual(i + 3, sess.run(elem_on_4)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elem_on_1) + sess.run(elem_on_2) + sess.run(elem_on_3) + sess.run(elem_on_4) + + def testNotFullyDivisible(self): + dataset = dataset_ops.Dataset.range(9) + multi_device_iterator = prefetching_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/cpu:2"]) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + + config = config_pb2.ConfigProto(device_count={"CPU": 3}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + for i in range(0, 8, 2): + self.assertEqual(i, sess.run(elem_on_1)) + self.assertEqual(i + 1, sess.run(elem_on_2)) + self.assertEqual(8, sess.run(elem_on_1)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elem_on_1) + sess.run(elem_on_2) + + def testUneven(self): + dataset = dataset_ops.Dataset.range(10) + multi_device_iterator = prefetching_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/cpu:2"]) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + + config = config_pb2.ConfigProto(device_count={"CPU": 3}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + for i in range(0, 10, 2): + self.assertEqual(i, sess.run(elem_on_1)) + for i in range(0, 10, 2): + self.assertEqual(i + 1, sess.run(elem_on_2)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elem_on_1) + sess.run(elem_on_2) + + def testMultipleInitializations(self): + with ops.device("/cpu:0"): + epoch = array_ops.placeholder(dtypes.int64, shape=[]) + dataset1 = dataset_ops.Dataset.from_tensors(epoch).repeat(1000) + dataset2 = dataset_ops.Dataset.range(1000) + dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) + multi_device_iterator = prefetching_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + init_op = multi_device_iterator.initializer + + config = config_pb2.ConfigProto(device_count={"CPU": 3}) + with self.test_session(config=config) as sess: + for i in range(1000): + sess.run(init_op, feed_dict={epoch: i}) + self.assertEqual([(i, 0), (i, 1)], sess.run([elem_on_1, elem_on_2])) + + def testBasicGpu(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + with compat.forward_compatibility_horizon(2018, 8, 4): + dataset = dataset_ops.Dataset.range(10) + multi_device_iterator = prefetching_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/gpu:0"]) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + + config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + for i in range(0, 10, 2): + self.assertEqual(i, sess.run(elem_on_1)) + self.assertEqual(i + 1, sess.run(elem_on_2)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elem_on_1) + sess.run(elem_on_2) + + def testUnevenGpu(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + with compat.forward_compatibility_horizon(2018, 8, 4): + dataset = dataset_ops.Dataset.range(10) + multi_device_iterator = prefetching_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/gpu:0"]) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + + config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1}) + with self.test_session(config=config) as sess: + sess.run(multi_device_iterator.initializer) + for i in range(0, 10, 2): + self.assertEqual(i, sess.run(elem_on_1)) + for i in range(0, 10, 2): + self.assertEqual(i + 1, sess.run(elem_on_2)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elem_on_1) + sess.run(elem_on_2) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 9df403ef50e459d94b8edf3f651c7c95baf3ec42..15b342d30f85a05b3827998565ba5f84021ac885 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -17,13 +17,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gzip import os +import zlib import numpy as np from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base from tensorflow.contrib.data.python.ops import readers from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.data.util import nest from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -170,276 +173,383 @@ class ReadBatchFeaturesTest( for num_epochs in [1, 10]: with ops.Graph().as_default(): # Basic test: read from file 0. - self.outputs = self.make_batch_feature( + outputs = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, drop_final_batch=True).make_one_shot_iterator().get_next() - for _, tensor in self.outputs.items(): + for _, tensor in outputs.items(): if isinstance(tensor, ops.Tensor): # Guard against SparseTensor. self.assertEqual(tensor.shape[0], batch_size) + def testIndefiniteRepeatShapeInference(self): + dataset = self.make_batch_feature( + filenames=self.test_filenames[0], num_epochs=None, batch_size=32) + for shape, clazz in zip(nest.flatten(dataset.output_shapes), + nest.flatten(dataset.output_classes)): + if issubclass(clazz, ops.Tensor): + self.assertEqual(32, shape[0]) + class MakeCsvDatasetTest(test.TestCase): - COLUMN_TYPES = [ - dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string - ] - COLUMNS = ["col%d" % i for i in range(len(COLUMN_TYPES))] - DEFAULT_VALS = [[], [], [], [], ["NULL"]] - DEFAULTS = [ - constant_op.constant([], dtype=dtypes.int32), - constant_op.constant([], dtype=dtypes.int64), - constant_op.constant([], dtype=dtypes.float32), - constant_op.constant([], dtype=dtypes.float64), - constant_op.constant(["NULL"], dtype=dtypes.string) - ] - LABEL = COLUMNS[0] - - def setUp(self): - super(MakeCsvDatasetTest, self).setUp() - self._num_files = 2 - self._num_records = 11 - self._test_filenames = self._create_files() - - def _csv_values(self, fileno, recordno): - return [ - fileno, - recordno, - fileno * recordno * 0.5, - fileno * recordno + 0.5, - "record %d" % recordno if recordno % 2 == 1 else "", - ] + def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs): + return readers.make_csv_dataset( + filenames, batch_size=batch_size, num_epochs=num_epochs, **kwargs) - def _write_file(self, filename, rows): - for i in range(len(rows)): - if isinstance(rows[i], list): - rows[i] = ",".join(str(v) if v is not None else "" for v in rows[i]) - fn = os.path.join(self.get_temp_dir(), filename) - f = open(fn, "w") - f.write("\n".join(rows)) - f.close() - return fn - - def _create_file(self, fileno, header=True): - rows = [] - if header: - rows.append(self.COLUMNS) - for recno in range(self._num_records): - rows.append(self._csv_values(fileno, recno)) - return self._write_file("csv_file%d.csv" % fileno, rows) - - def _create_files(self): + def _setup_files(self, inputs, linebreak="\n", compression_type=None): filenames = [] - for i in range(self._num_files): - filenames.append(self._create_file(i)) + for i, ip in enumerate(inputs): + fn = os.path.join(self.get_temp_dir(), "temp_%d.csv" % i) + contents = linebreak.join(ip).encode("utf-8") + if compression_type is None: + with open(fn, "wb") as f: + f.write(contents) + elif compression_type == "GZIP": + with gzip.GzipFile(fn, "wb") as f: + f.write(contents) + elif compression_type == "ZLIB": + contents = zlib.compress(contents) + with open(fn, "wb") as f: + f.write(contents) + else: + raise ValueError("Unsupported compression_type", compression_type) + filenames.append(fn) return filenames - def _make_csv_dataset( - self, - filenames, - defaults, - column_names=COLUMNS, - label_name=LABEL, - select_cols=None, - batch_size=1, - num_epochs=1, - shuffle=False, - shuffle_seed=None, - header=True, - na_value="", - ): - return readers.make_csv_dataset( - filenames, - batch_size=batch_size, - column_names=column_names, - column_defaults=defaults, - label_name=label_name, - num_epochs=num_epochs, - shuffle=shuffle, - shuffle_seed=shuffle_seed, - header=header, - na_value=na_value, - select_columns=select_cols, - ) - - def _next_actual_batch(self, file_indices, batch_size, num_epochs, defaults): - features = {col: list() for col in self.COLUMNS} + def _next_expected_batch(self, expected_output, expected_keys, batch_size, + num_epochs): + features = {k: [] for k in expected_keys} for _ in range(num_epochs): - for i in file_indices: - for j in range(self._num_records): - values = self._csv_values(i, j) - for n, v in enumerate(values): - if v == "": # pylint: disable=g-explicit-bool-comparison - values[n] = defaults[n][0] - values[-1] = values[-1].encode("utf-8") - - # Regroup lists by column instead of row - for n, col in enumerate(self.COLUMNS): - features[col].append(values[n]) - if len(list(features.values())[0]) == batch_size: - yield features - features = {col: list() for col in self.COLUMNS} - - def _run_actual_batch(self, outputs, sess): - features, labels = sess.run(outputs) - batch = [features[k] for k in self.COLUMNS if k != self.LABEL] - batch.append(labels) - return batch - - def _verify_records( + for values in expected_output: + for n, key in enumerate(expected_keys): + features[key].append(values[n]) + if len(features[expected_keys[0]]) == batch_size: + yield features + features = {k: [] for k in expected_keys} + if features[expected_keys[0]]: # Leftover from the last batch + yield features + + def _verify_output( self, sess, dataset, - file_indices, - defaults=tuple(DEFAULT_VALS), - label_name=LABEL, - batch_size=1, - num_epochs=1, + batch_size, + num_epochs, + label_name, + expected_output, + expected_keys, ): - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() + nxt = dataset.make_one_shot_iterator().get_next() - for expected_features in self._next_actual_batch(file_indices, batch_size, - num_epochs, defaults): - actual_features = sess.run(get_next) + for expected_features in self._next_expected_batch( + expected_output, + expected_keys, + batch_size, + num_epochs, + ): + actual_features = sess.run(nxt) if label_name is not None: expected_labels = expected_features.pop(label_name) - # Compare labels self.assertAllEqual(expected_labels, actual_features[1]) - actual_features = actual_features[0] # Extract features dict from tuple + actual_features = actual_features[0] for k in expected_features.keys(): # Compare features self.assertAllEqual(expected_features[k], actual_features[k]) with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testMakeCSVDataset(self): - defaults = self.DEFAULTS - - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - # Basic test: read from file 0. - dataset = self._make_csv_dataset(self._test_filenames[0], defaults) - self._verify_records(sess, dataset, [0]) - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - # Basic test: read from file 1. - dataset = self._make_csv_dataset(self._test_filenames[1], defaults) - self._verify_records(sess, dataset, [1]) + sess.run(nxt) + + def _test_dataset(self, + inputs, + expected_output, + expected_keys, + batch_size=1, + num_epochs=1, + label_name=None, + **kwargs): + """Checks that elements produced by CsvDataset match expected output.""" + # Convert str type because py3 tf strings are bytestrings + filenames = self._setup_files( + inputs, compression_type=kwargs.get("compression_type", None)) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: - # Read from both files. - dataset = self._make_csv_dataset(self._test_filenames, defaults) - self._verify_records(sess, dataset, range(self._num_files)) - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - # Read from both files. Exercise the `batch` and `num_epochs` parameters - # of make_csv_dataset and make sure they work. dataset = self._make_csv_dataset( - self._test_filenames, defaults, batch_size=2, num_epochs=10) - self._verify_records( - sess, dataset, range(self._num_files), batch_size=2, num_epochs=10) + filenames, + batch_size=batch_size, + num_epochs=num_epochs, + label_name=label_name, + **kwargs) + self._verify_output(sess, dataset, batch_size, num_epochs, label_name, + expected_output, expected_keys) + + def testMakeCSVDataset(self): + """Tests making a CSV dataset with keys and defaults provided.""" + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"], + [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]] + label = "col0" + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + label_name=label, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + column_defaults=record_defaults, + ) + + def testMakeCSVDataset_withBatchSizeAndEpochs(self): + """Tests making a CSV dataset with keys and defaults provided.""" + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"], + [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]] + label = "col0" + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + label_name=label, + batch_size=3, + num_epochs=10, + shuffle=False, + header=True, + column_defaults=record_defaults, + ) - def testMakeCSVDataset_withBadColumns(self): + def testMakeCSVDataset_withCompressionType(self): + """Tests `compression_type` argument.""" + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"], + [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]] + label = "col0" + + for compression_type in ("GZIP", "ZLIB"): + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + label_name=label, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + column_defaults=record_defaults, + compression_type=compression_type, + ) + + def testMakeCSVDataset_withBadInputs(self): """Tests that exception is raised when input is malformed. """ - dupe_columns = self.COLUMNS[:-1] + self.COLUMNS[:1] - defaults = self.DEFAULTS + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + filenames = self._setup_files(inputs) # Duplicate column names with self.assertRaises(ValueError): self._make_csv_dataset( - self._test_filenames, defaults, column_names=dupe_columns) + filenames, + batch_size=1, + column_defaults=record_defaults, + label_name="col0", + column_names=column_names * 2) # Label key not one of column names with self.assertRaises(ValueError): self._make_csv_dataset( - self._test_filenames, defaults, label_name="not_a_real_label") + filenames, + batch_size=1, + column_defaults=record_defaults, + label_name="not_a_real_label", + column_names=column_names) def testMakeCSVDataset_withNoLabel(self): - """Tests that CSV datasets can be created when no label is specified. - """ - defaults = self.DEFAULTS - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - # Read from both files. Make sure this works with no label key supplied. - dataset = self._make_csv_dataset( - self._test_filenames, - defaults, - batch_size=2, - num_epochs=10, - label_name=None) - self._verify_records( - sess, - dataset, - range(self._num_files), - batch_size=2, - num_epochs=10, - label_name=None) + """Tests making a CSV dataset with no label provided.""" + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"], + [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]] + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + column_defaults=record_defaults, + ) def testMakeCSVDataset_withNoHeader(self): """Tests that datasets can be created from CSV files with no header line. """ - defaults = self.DEFAULTS - file_without_header = self._create_file( - len(self._test_filenames), header=False) - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset( - file_without_header, - defaults, - batch_size=2, - num_epochs=10, - header=False, - ) - self._verify_records( - sess, - dataset, - [len(self._test_filenames)], - batch_size=2, - num_epochs=10, - ) + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [["0,1,2,3,4", "5,6,7,8,9"], ["10,11,12,13,14", "15,16,17,18,19"]] + expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"], + [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]] + label = "col0" + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + label_name=label, + batch_size=1, + num_epochs=1, + shuffle=False, + header=False, + column_defaults=record_defaults, + ) def testMakeCSVDataset_withTypes(self): """Tests that defaults can be a dtype instead of a Tensor for required vals. """ - defaults = [d for d in self.COLUMN_TYPES[:-1]] - defaults.append(constant_op.constant(["NULL"], dtype=dtypes.string)) - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset(self._test_filenames, defaults) - self._verify_records(sess, dataset, range(self._num_files)) + record_defaults = [ + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, + dtypes.string + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x[0] for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], + [ + ",".join(x[0] for x in column_names), "10,11,12,13,14", + "15,16,17,18,19" + ]] + expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"], + [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]] + label = "col0" + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + label_name=label, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + column_defaults=record_defaults, + ) def testMakeCSVDataset_withNoColNames(self): """Tests that datasets can be created when column names are not specified. In that case, we should infer the column names from the header lines. """ - defaults = self.DEFAULTS - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - # Read from both files. Exercise the `batch` and `num_epochs` parameters - # of make_csv_dataset and make sure they work. - dataset = self._make_csv_dataset( - self._test_filenames, - defaults, - column_names=None, - batch_size=2, - num_epochs=10) - self._verify_records( - sess, dataset, range(self._num_files), batch_size=2, num_epochs=10) + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"], + [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]] + label = "col0" + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + label_name=label, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + column_defaults=record_defaults, + ) def testMakeCSVDataset_withTypeInferenceMismatch(self): # Test that error is thrown when num fields doesn't match columns + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + filenames = self._setup_files(inputs) with self.assertRaises(ValueError): self._make_csv_dataset( - self._test_filenames, - column_names=self.COLUMNS + ["extra_name"], - defaults=None, + filenames, + column_names=column_names + ["extra_name"], + column_defaults=None, batch_size=2, num_epochs=10) @@ -448,197 +558,215 @@ class MakeCsvDatasetTest(test.TestCase): In that case, we should infer the types from the first N records. """ - # Test that it works with standard test files (with header, etc) - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset( - self._test_filenames, defaults=None, batch_size=2, num_epochs=10) - self._verify_records( - sess, - dataset, - range(self._num_files), - batch_size=2, - num_epochs=10, - defaults=[[], [], [], [], [""]]) - - def testMakeCSVDataset_withTypeInferenceTricky(self): - # Test on a deliberately tricky file (type changes as we read more rows, and - # there are null values) - fn = os.path.join(self.get_temp_dir(), "file.csv") - expected_dtypes = [ - dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float32, - dtypes.string, dtypes.string - ] - col_names = ["col%d" % i for i in range(len(expected_dtypes))] - rows = [[None, None, None, "NAN", "", - "a"], [1, 2**31 + 1, 2**64, 123, "NAN", ""], - ['"123"', 2, 2**64, 123.4, "NAN", '"cd,efg"']] - expected = [[0, 0, 0, 0, "", "a"], [1, 2**31 + 1, 2**64, 123, "", ""], - [123, 2, 2**64, 123.4, "", "cd,efg"]] - for row in expected: - row[-1] = row[-1].encode("utf-8") # py3 expects byte strings - row[-2] = row[-2].encode("utf-8") # py3 expects byte strings - self._write_file("file.csv", [col_names] + rows) + column_names = ["col%d" % i for i in range(5)] + str_int32_max = str(2**33) + inputs = [[ + ",".join(x for x in column_names), + "0,%s,2.0,3e50,rabbit" % str_int32_max + ]] + expected_output = [[0, 2**33, 2.0, 3e50, b"rabbit"]] + label = "col0" - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset( - fn, - defaults=None, - column_names=None, - label_name=None, - na_value="NAN", - ) - features = dataset.make_one_shot_iterator().get_next() - # Check that types match - for i in range(len(expected_dtypes)): - print(features["col%d" % i].dtype, expected_dtypes[i]) - assert features["col%d" % i].dtype == expected_dtypes[i] - for i in range(len(rows)): - assert sess.run(features) == dict(zip(col_names, expected[i])) - - def testMakeCSVDataset_withTypeInferenceAllTypes(self): - # Test that we make the correct inference for all types with fallthrough - fn = os.path.join(self.get_temp_dir(), "file.csv") - expected_dtypes = [ - dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, - dtypes.string, dtypes.string + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + label_name=label, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + ) + + def testMakeCSVDataset_withTypeInferenceFallthrough(self): + """Tests that datasets can be created when no defaults are specified. + + Tests on a deliberately tricky file. + """ + column_names = ["col%d" % i for i in range(5)] + str_int32_max = str(2**33) + inputs = [[ + ",".join(x for x in column_names), + ",,,,", + "0,0,0.0,0.0,0.0", + "0,%s,2.0,3e50,rabbit" % str_int32_max, + ",,,,", + ]] + expected_output = [[0, 0, 0, 0, b""], [0, 0, 0, 0, b"0.0"], + [0, 2**33, 2.0, 3e50, b"rabbit"], [0, 0, 0, 0, b""]] + label = "col0" + + self._test_dataset( + inputs, + expected_output=expected_output, + expected_keys=column_names, + column_names=column_names, + label_name=label, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + ) + + def testMakeCSVDataset_withSelectCols(self): + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) ] - col_names = ["col%d" % i for i in range(len(expected_dtypes))] - rows = [[1, 2**31 + 1, 1.0, 4e40, "abc", ""]] - expected = [[ - 1, 2**31 + 1, 1.0, 4e40, "abc".encode("utf-8"), "".encode("utf-8") + column_names = ["col%d" % i for i in range(5)] + str_int32_max = str(2**33) + inputs = [[ + ",".join(x for x in column_names), + "0,%s,2.0,3e50,rabbit" % str_int32_max ]] - self._write_file("file.csv", [col_names] + rows) + expected_output = [[0, 2**33, 2.0, 3e50, b"rabbit"]] - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset( - fn, - defaults=None, - column_names=None, - label_name=None, - na_value="NAN", - ) - features = dataset.make_one_shot_iterator().get_next() - # Check that types match - for i in range(len(expected_dtypes)): - self.assertAllEqual(features["col%d" % i].dtype, expected_dtypes[i]) - for i in range(len(rows)): - self.assertAllEqual( - sess.run(features), dict(zip(col_names, expected[i]))) + select_cols = [1, 3, 4] + self._test_dataset( + inputs, + expected_output=[[x[i] for i in select_cols] for x in expected_output], + expected_keys=[column_names[i] for i in select_cols], + column_names=column_names, + column_defaults=[record_defaults[i] for i in select_cols], + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + select_columns=select_cols, + ) + + # Can still do inference without provided defaults + self._test_dataset( + inputs, + expected_output=[[x[i] for i in select_cols] for x in expected_output], + expected_keys=[column_names[i] for i in select_cols], + column_names=column_names, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + select_columns=select_cols, + ) + + # Can still do column name inference + self._test_dataset( + inputs, + expected_output=[[x[i] for i in select_cols] for x in expected_output], + expected_keys=[column_names[i] for i in select_cols], + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + select_columns=select_cols, + ) + + # Can specify column names instead of indices + self._test_dataset( + inputs, + expected_output=[[x[i] for i in select_cols] for x in expected_output], + expected_keys=[column_names[i] for i in select_cols], + column_names=column_names, + batch_size=1, + num_epochs=1, + shuffle=False, + header=True, + select_columns=[column_names[i] for i in select_cols], + ) def testMakeCSVDataset_withSelectColsError(self): - data = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] - col_names = ["col%d" % i for i in range(5)] - fn = self._write_file("file.csv", [col_names] + data) + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + column_names = ["col%d" % i for i in range(5)] + str_int32_max = str(2**33) + inputs = [[ + ",".join(x for x in column_names), + "0,%s,2.0,3e50,rabbit" % str_int32_max + ]] + + select_cols = [1, 3, 4] + filenames = self._setup_files(inputs) + with self.assertRaises(ValueError): # Mismatch in number of defaults and number of columns selected, # should raise an error self._make_csv_dataset( - fn, - defaults=[[0]] * 5, - column_names=col_names, - label_name=None, - select_cols=[1, 3]) + filenames, + batch_size=1, + column_defaults=record_defaults, + column_names=column_names, + select_columns=select_cols) + with self.assertRaises(ValueError): # Invalid column name should raise an error self._make_csv_dataset( - fn, - defaults=[[0]], - column_names=col_names, + filenames, + batch_size=1, + column_defaults=[[0]], + column_names=column_names, label_name=None, - select_cols=["invalid_col_name"]) - - def testMakeCSVDataset_withSelectCols(self): - data = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] - col_names = ["col%d" % i for i in range(5)] - fn = self._write_file("file.csv", [col_names] + data) - # If select_cols is specified, should only yield a subset of columns - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset( - fn, - defaults=[[0], [0]], - column_names=col_names, - label_name=None, - select_cols=[1, 3]) - expected = [[1, 3], [6, 8]] - features = dataset.make_one_shot_iterator().get_next() - for i in range(len(data)): - self.assertAllEqual( - sess.run(features), - dict(zip([col_names[1], col_names[3]], expected[i]))) - # Can still do default inference with select_cols - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset( - fn, - defaults=None, - column_names=col_names, - label_name=None, - select_cols=[1, 3]) - expected = [[1, 3], [6, 8]] - features = dataset.make_one_shot_iterator().get_next() - for i in range(len(data)): - self.assertAllEqual( - sess.run(features), - dict(zip([col_names[1], col_names[3]], expected[i]))) - # Can still do column name inference - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset( - fn, - defaults=None, - column_names=None, - label_name=None, - select_cols=[1, 3]) - expected = [[1, 3], [6, 8]] - features = dataset.make_one_shot_iterator().get_next() - for i in range(len(data)): - self.assertAllEqual( - sess.run(features), - dict(zip([col_names[1], col_names[3]], expected[i]))) - # Can specify column names instead of indices - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset( - fn, - defaults=None, - column_names=None, - label_name=None, - select_cols=[col_names[1], col_names[3]]) - expected = [[1, 3], [6, 8]] - features = dataset.make_one_shot_iterator().get_next() - for i in range(len(data)): - self.assertAllEqual( - sess.run(features), - dict(zip([col_names[1], col_names[3]], expected[i]))) + select_columns=["invalid_col_name"]) def testMakeCSVDataset_withShuffle(self): - total_records = self._num_files * self._num_records - defaults = self.DEFAULTS + record_defaults = [ + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + constant_op.constant([], dtypes.float32), + constant_op.constant([], dtypes.float64), + constant_op.constant([], dtypes.string) + ] + + def str_series(st): + return ",".join(str(i) for i in range(st, st + 5)) + + column_names = ["col%d" % i for i in range(5)] + inputs = [ + [",".join(x for x in column_names) + ] + [str_series(5 * i) for i in range(15)], + [",".join(x for x in column_names)] + + [str_series(5 * i) for i in range(15, 20)], + ] + + filenames = self._setup_files(inputs) + + total_records = 20 for batch_size in [1, 2]: with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Test that shuffling with the same seed produces the same result dataset1 = self._make_csv_dataset( - self._test_filenames, - defaults, + filenames, + column_defaults=record_defaults, + column_names=column_names, batch_size=batch_size, + header=True, shuffle=True, - shuffle_seed=5) + shuffle_seed=5, + num_epochs=2, + ) dataset2 = self._make_csv_dataset( - self._test_filenames, - defaults, + filenames, + column_defaults=record_defaults, + column_names=column_names, batch_size=batch_size, + header=True, shuffle=True, - shuffle_seed=5) + shuffle_seed=5, + num_epochs=2, + ) outputs1 = dataset1.make_one_shot_iterator().get_next() outputs2 = dataset2.make_one_shot_iterator().get_next() for _ in range(total_records // batch_size): - batch1 = self._run_actual_batch(outputs1, sess) - batch2 = self._run_actual_batch(outputs2, sess) + batch1 = nest.flatten(sess.run(outputs1)) + batch2 = nest.flatten(sess.run(outputs2)) for i in range(len(batch1)): self.assertAllEqual(batch1[i], batch2[i]) @@ -646,27 +774,45 @@ class MakeCsvDatasetTest(test.TestCase): with self.test_session(graph=g) as sess: # Test that shuffling with a different seed produces different results dataset1 = self._make_csv_dataset( - self._test_filenames, - defaults, + filenames, + column_defaults=record_defaults, + column_names=column_names, batch_size=batch_size, + header=True, shuffle=True, - shuffle_seed=5) + shuffle_seed=5, + num_epochs=2, + ) dataset2 = self._make_csv_dataset( - self._test_filenames, - defaults, + filenames, + column_defaults=record_defaults, + column_names=column_names, batch_size=batch_size, + header=True, shuffle=True, - shuffle_seed=6) + shuffle_seed=6, + num_epochs=2, + ) outputs1 = dataset1.make_one_shot_iterator().get_next() outputs2 = dataset2.make_one_shot_iterator().get_next() all_equal = False for _ in range(total_records // batch_size): - batch1 = self._run_actual_batch(outputs1, sess) - batch2 = self._run_actual_batch(outputs2, sess) + batch1 = nest.flatten(sess.run(outputs1)) + batch2 = nest.flatten(sess.run(outputs2)) for i in range(len(batch1)): all_equal = all_equal and np.array_equal(batch1[i], batch2[i]) self.assertFalse(all_equal) + def testIndefiniteRepeatShapeInference(self): + column_names = ["col%d" % i for i in range(5)] + inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [ + ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19" + ]] + filenames = self._setup_files(inputs) + dataset = self._make_csv_dataset(filenames, batch_size=32, num_epochs=None) + for shape in nest.flatten(dataset.output_shapes): + self.assertEqual(32, shape[0]) + class MakeTFRecordDatasetTest( reader_dataset_ops_test_base.TFRecordDatasetTestBase): @@ -874,6 +1020,12 @@ class MakeTFRecordDatasetTest( self._shuffle_test(batch_size, num_epochs, num_parallel_reads, seed=21345) + def testIndefiniteRepeatShapeInference(self): + dataset = readers.make_tf_record_dataset( + file_pattern=self.test_filenames, num_epochs=None, batch_size=32) + for shape in nest.flatten(dataset.output_shapes): + self.assertEqual(32, shape[0]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD index 686788522acdf1c5e91132c38bdc81d10d2a0cc2..7b9ea191a4524891d1b589e1e228e29241fda7f8 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD @@ -56,6 +56,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", ], ) @@ -72,6 +73,20 @@ py_test( ], ) +py_test( + name = "csv_dataset_serialization_test", + size = "small", + srcs = ["csv_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + ], +) + py_test( name = "dataset_constructor_serialization_test", size = "medium", diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py index a0a1100893c7384b0e2bd9fcfdaa8d3698b95d28..1b6059ccbcc81937696e1b0ebb269f213adbb976 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py @@ -19,6 +19,8 @@ from __future__ import print_function import os +from absl.testing import parameterized + from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors @@ -26,7 +28,8 @@ from tensorflow.python.platform import test class CacheDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): + dataset_serialization_test_base.DatasetSerializationTestBase, + parameterized.TestCase): def setUp(self): self.range_size = 10 @@ -34,88 +37,123 @@ class CacheDatasetSerializationTest( self.num_outputs = self.range_size * self.num_repeats self.cache_file_prefix = 'test' - def ds_fn(self): - return dataset_ops.Dataset.range(self.range_size).cache( - os.path.join(self.get_temp_dir(), - self.cache_file_prefix)).repeat(self.num_repeats) + def make_dataset_fn(self, is_memory): + if is_memory: + filename = '' + else: + filename = os.path.join(self.get_temp_dir(), self.cache_file_prefix) + + def ds_fn(): + return dataset_ops.Dataset.range(self.range_size).cache(filename).repeat( + self.num_repeats) + + return ds_fn def expected_outputs(self): return list(range(self.range_size)) * self.num_repeats - def testCheckpointBeforeOneEpoch(self): + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointBeforeOneEpoch(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + # Generate 5 entries from iterator and save checkpoint. - outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False) self.assertSequenceEqual(outputs, range(5)) # Restore from checkpoint and produce the rest of the elements from the # iterator. outputs.extend( self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 5, ckpt_saved=True, verify_exhausted=False)) self.assertSequenceEqual(outputs, self.expected_outputs()) - def testCheckpointBeforeOneEpochThenRunFewSteps(self): - # Generate 8 entries from iterator but save checkpoint after producing - # 5. + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointBeforeOneEpochThenRunFewSteps(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + + # Generate 8 entries from iterator but save checkpoint after producing 5. outputs = self.gen_outputs( - self.ds_fn, [5], - 8, - verify_exhausted=False, - save_checkpoint_at_end=False) + ds_fn, [5], 8, verify_exhausted=False, save_checkpoint_at_end=False) self.assertSequenceEqual(outputs, range(8)) - # Restoring from checkpoint and running GetNext should return a - # `AlreadExistsError` now because the lockfile already exists. - with self.assertRaises(errors.AlreadyExistsError): - self.gen_outputs( - self.ds_fn, [], - self.num_outputs - 5, - ckpt_saved=True, - verify_exhausted=False) + if is_memory: + outputs = outputs[:5] + outputs.extend( + self.gen_outputs( + ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, self.expected_outputs()) + else: + # Restoring from checkpoint and running GetNext should return + # `AlreadExistsError` now because the lockfile already exists. + with self.assertRaises(errors.AlreadyExistsError): + self.gen_outputs( + ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False) + + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointAfterOneEpoch(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) - def testCheckpointAfterOneEpoch(self): # Generate 15 entries from iterator and save checkpoint. - outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) + list(range(5))) # Restore from checkpoint and produce the rest of the elements from the # iterator. outputs.extend( self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 15, ckpt_saved=True, verify_exhausted=False)) self.assertSequenceEqual(outputs, self.expected_outputs()) - def testCheckpointAfterOneEpochThenRunFewSteps(self): - # Generate 18 entries from iterator but save checkpoint after producing - # 15. + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointAfterOneEpochThenRunFewSteps(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + + # Generate 18 entries from iterator but save checkpoint after producing 15. outputs = self.gen_outputs( - self.ds_fn, [15], - 18, - verify_exhausted=False, - save_checkpoint_at_end=False) + ds_fn, [15], 18, verify_exhausted=False, save_checkpoint_at_end=False) self.assertSequenceEqual(outputs, list(range(10)) + list(range(8))) outputs = list(range(10)) + list(range(5)) + self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 15, ckpt_saved=True, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) * 3) - def testCheckpointBeforeOneEpochButRunCompleteEpoch(self): - # Generate 13 entries from iterator but save checkpoint after producing - # 5. + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointBeforeOneEpochButRunCompleteEpoch(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + + # Generate 13 entries from iterator but save checkpoint after producing 5. outputs = self.gen_outputs( - self.ds_fn, [5], - 13, - verify_exhausted=False, - save_checkpoint_at_end=False) + ds_fn, [5], 13, verify_exhausted=False, save_checkpoint_at_end=False) self.assertSequenceEqual(outputs, list(range(10)) + list(range(3))) # Since we ran for more than one epoch, the cache was completely written. @@ -124,65 +162,90 @@ class CacheDatasetSerializationTest( # been completely written. outputs = list(range(5)) + self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 5, ckpt_saved=True, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) * 3) - def testCheckpointUnusedWriterIterator(self): + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointUnusedWriterIterator(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + # Checkpoint before get_next is called even once. - outputs = self.gen_outputs(self.ds_fn, [], 0, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 0, verify_exhausted=False) self.assertSequenceEqual(outputs, []) outputs = self.gen_outputs( - self.ds_fn, [], - self.num_outputs, - ckpt_saved=True, - verify_exhausted=False) + ds_fn, [], self.num_outputs, ckpt_saved=True, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) * 3) - def testCheckpointUnusedMidwayWriterIterator(self): + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testCheckpointUnusedMidwayWriterIterator(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + # Produce 5 elements and checkpoint. - outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False) self.assertSequenceEqual(outputs, range(5)) # Restore from checkpoint, then produce no elements and checkpoint. outputs.extend( - self.gen_outputs( - self.ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False)) + self.gen_outputs(ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False)) self.assertSequenceEqual(outputs, range(5)) # Restore from checkpoint and produce rest of the elements. outputs.extend( self.gen_outputs( - self.ds_fn, [], + ds_fn, [], self.num_outputs - 5, ckpt_saved=True, verify_exhausted=False)) self.assertSequenceEqual(outputs, list(range(10)) * 3) - def testUnusedCheckpointError(self): + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testUnusedCheckpointError(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) + # Produce 5 elements and save ckpt. - outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False) self.assertSequenceEqual(outputs, range(5)) - # Since the complete cache has not been written, a new iterator which does - # not restore the checkpoint will throw an error since there is a partial - # cache shard. - with self.assertRaises(errors.AlreadyExistsError): + if is_memory: outputs = self.gen_outputs( - self.ds_fn, [], self.num_outputs, verify_exhausted=False) + ds_fn, [], self.num_outputs, verify_exhausted=False) + self.assertSequenceEqual(outputs, self.expected_outputs()) + else: + # Since the complete cache has not been written, a new iterator which does + # not restore the checkpoint will throw an error since there is a partial + # cache shard. + with self.assertRaises(errors.AlreadyExistsError): + outputs = self.gen_outputs( + ds_fn, [], self.num_outputs, verify_exhausted=False) + + @parameterized.named_parameters( + ('Memory', True), + ('File', False), + ) + def testIgnoreCheckpointIfCacheWritten(self, is_memory): + ds_fn = self.make_dataset_fn(is_memory) - def testIgnoreCheckpointIfCacheWritten(self): # Produce 15 elements and save ckpt. This will write the complete cache. - outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False) + outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) + list(range(5))) # Build the iterator again but do not restore from ckpt. Since the cache # has already been written we should be able to use it. outputs = self.gen_outputs( - self.ds_fn, [], self.num_outputs, verify_exhausted=False) + ds_fn, [], self.num_outputs, verify_exhausted=False) self.assertSequenceEqual(outputs, list(range(10)) * 3) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..247f2046ea313f97bdbda1674765f12406258509 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py @@ -0,0 +1,73 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the CsvDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.platform import test + + +class CsvDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self._num_cols = 7 + self._num_rows = 10 + self._num_epochs = 14 + self._num_outputs = self._num_rows * self._num_epochs + + inputs = [ + ",".join(str(self._num_cols * j + i) + for i in range(self._num_cols)) + for j in range(self._num_rows) + ] + contents = "\n".join(inputs).encode("utf-8") + + self._filename = os.path.join(self.get_temp_dir(), "file.csv") + self._compressed = os.path.join(self.get_temp_dir(), + "comp.csv") # GZip compressed + + with open(self._filename, "wb") as f: + f.write(contents) + with gzip.GzipFile(self._compressed, "wb") as f: + f.write(contents) + + def ds_func(self, **kwargs): + compression_type = kwargs.get("compression_type", None) + if compression_type == "GZIP": + filename = self._compressed + elif compression_type is None: + filename = self._filename + else: + raise ValueError("Invalid compression type:", compression_type) + + return readers.CsvDataset(filename, **kwargs).repeat(self._num_epochs) + + def testSerializationCore(self): + defs = [[0]] * self._num_cols + self.run_core_tests( + lambda: self.ds_func(record_defaults=defs, buffer_size=2), + lambda: self.ds_func(record_defaults=defs, buffer_size=12), + self._num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py index 393f08850b1865180a8b94e9209b2445b54c8b69..3ed4dfb7295ca77c78ce5318bf31e16a354e16a8 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import nest @@ -655,7 +656,7 @@ class DatasetSerializationTestBase(test.TestCase): return os.path.join(self.get_temp_dir(), "iterator") def _latest_ckpt(self): - return saver_lib.latest_checkpoint(self.get_temp_dir()) + return checkpoint_management.latest_checkpoint(self.get_temp_dir()) def _save(self, sess, saver): saver.save(sess, self._ckpt_path()) diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index 5590a4bf783d12b0d0710c0130b0b1df921c9baa..8b2f84649486e35e1067f5f9cbe4a7abec71e080 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.contrib.data.python.ops import sliding @@ -29,28 +30,45 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class SlideDatasetTest(test.TestCase): - - def testSlideDataset(self): - """Test an dataset that maps a TF function across its input elements.""" +class SlideDatasetTest(test.TestCase, parameterized.TestCase): + + @parameterized.parameters( + (20, 14, 7, 1), + (20, 17, 9, 1), + (20, 14, 14, 1), + (20, 10, 14, 1), + (20, 14, 19, 1), + (20, 4, 1, 2), + (20, 2, 1, 6), + (20, 4, 7, 2), + (20, 2, 7, 6), + (1, 10, 4, 1), + (0, 10, 4, 1), + ) + def testSlideDataset(self, count, window_size, window_shift, window_stride): + """Tests a dataset that slides a window its input elements.""" components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], np.array(37.0) * np.arange(7)) - count = array_ops.placeholder(dtypes.int64, shape=[]) - window_size = array_ops.placeholder(dtypes.int64, shape=[]) - stride = array_ops.placeholder(dtypes.int64, shape=[]) + count_t = array_ops.placeholder(dtypes.int64, shape=[]) + window_size_t = array_ops.placeholder(dtypes.int64, shape=[]) + window_shift_t = array_ops.placeholder(dtypes.int64, shape=[]) + window_stride_t = array_ops.placeholder(dtypes.int64, shape=[]) def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> - # RepeatDataset(count) -> _SlideDataset(window_size, stride). - iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .map(_map_fn) - .repeat(count) - .apply(sliding.sliding_window_batch(window_size, stride)) - .make_initializable_iterator()) + # RepeatDataset(count) -> + # _SlideDataset(window_size, window_shift, window_stride). + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(count).apply( + sliding.sliding_window_batch( + window_size=window_size_t, + window_shift=window_shift_t, + window_stride=window_stride_t)).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -58,90 +76,126 @@ class SlideDatasetTest(test.TestCase): [t.shape.as_list() for t in get_next]) with self.test_session() as sess: - # stride < window_size. - # Slide over a finite input, where the window_size divides the - # total number of elements. - sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 7}) - # Same formula with convolution layer. - num_batches = (20 * 7 - 14) // 7 + 1 - for i in range(num_batches): - result = sess.run(get_next) - for component, result_component in zip(components, result): - for j in range(14): - self.assertAllEqual(component[(i*7 + j) % 7]**2, - result_component[j]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - # Slide over a finite input, where the window_size does not - # divide the total number of elements. - sess.run(init_op, feed_dict={count: 20, window_size: 17, stride: 9}) - num_batches = (20 * 7 - 17) // 9 + 1 + sess.run( + init_op, + feed_dict={ + count_t: count, + window_size_t: window_size, + window_shift_t: window_shift, + window_stride_t: window_stride + }) + num_batches = (count * 7 - ( + (window_size - 1) * window_stride + 1)) // window_shift + 1 for i in range(num_batches): result = sess.run(get_next) for component, result_component in zip(components, result): - for j in range(17): - self.assertAllEqual(component[(i*9 + j) % 7]**2, - result_component[j]) + for j in range(window_size): + self.assertAllEqual( + component[(i * window_shift + j * window_stride) % 7]**2, + result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - # stride == window_size. - sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 14}) - num_batches = 20 * 7 // 14 - for i in range(num_batches): - result = sess.run(get_next) - for component, result_component in zip(components, result): - for j in range(14): - self.assertAllEqual(component[(i*14 + j) % 7]**2, - result_component[j]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + @parameterized.parameters( + (20, 14, 7, 1), + (20, 17, 9, 1), + (20, 14, 14, 1), + (20, 10, 14, 1), + (20, 14, 19, 1), + (20, 4, 1, 2), + (20, 2, 1, 6), + (20, 4, 7, 2), + (20, 2, 7, 6), + (1, 10, 4, 1), + (0, 10, 4, 1), + ) + def testSlideDatasetDeprecated(self, count, window_size, stride, + window_stride): + """Tests a dataset that slides a window its input elements.""" + components = (np.arange(7), + np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], + np.array(37.0) * np.arange(7)) - # stride > window_size. - sess.run(init_op, feed_dict={count: 20, window_size: 10, stride: 14}) - num_batches = 20 * 7 // 14 - for i in range(num_batches): - result = sess.run(get_next) - for component, result_component in zip(components, result): - for j in range(10): - self.assertAllEqual(component[(i*14 + j) % 7]**2, - result_component[j]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - # Drop the last batch which is smaller than window_size. - sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 19}) - num_batches = (20 * 7 - 7) // 19 # = 19 * 7 // 19 - for i in range(num_batches): - result = sess.run(get_next) - for component, result_component in zip(components, result): - for j in range(14): - self.assertAllEqual(component[(i*19 + j) % 7]**2, - result_component[j]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + count_t = array_ops.placeholder(dtypes.int64, shape=[]) + window_size_t = array_ops.placeholder(dtypes.int64, shape=[]) + stride_t = array_ops.placeholder(dtypes.int64, shape=[]) + window_stride_t = array_ops.placeholder(dtypes.int64, shape=[]) - # Slide over a finite input, which is less than window_size, - # should fail straight away. - sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 4}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) - sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 8}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> + # RepeatDataset(count) -> _SlideDataset(window_size, stride, window_stride). + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(count).apply( + sliding.sliding_window_batch( + window_size=window_size_t, + stride=stride_t, + window_stride=window_stride_t)).make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() - # Slide over an empty input should fail straight away. - sess.run(init_op, feed_dict={count: 0, window_size: 8, stride: 4}) + self.assertEqual([[None] + list(c.shape[1:]) for c in components], + [t.shape.as_list() for t in get_next]) + + with self.test_session() as sess: + sess.run( + init_op, + feed_dict={ + count_t: count, + window_size_t: window_size, + stride_t: stride, + window_stride_t: window_stride + }) + num_batches = (count * 7 - ( + (window_size - 1) * window_stride + 1)) // stride + 1 + for i in range(num_batches): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(window_size): + self.assertAllEqual( + component[(i * stride + j * window_stride) % 7]**2, + result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - # Empty window_size should be an initialization time error. - with self.assertRaises(errors.InvalidArgumentError): - sess.run(init_op, feed_dict={count: 14, window_size: 0, stride: 0}) + @parameterized.parameters( + (14, 0, 3, 1), + (14, 3, 0, 1), + (14, 3, 3, 0), + ) + def testSlideDatasetInvalid(self, count, window_size, window_shift, + window_stride): + count_t = array_ops.placeholder(dtypes.int64, shape=[]) + window_size_t = array_ops.placeholder(dtypes.int64, shape=[]) + window_shift_t = array_ops.placeholder(dtypes.int64, shape=[]) + window_stride_t = array_ops.placeholder(dtypes.int64, shape=[]) + + iterator = ( + dataset_ops.Dataset.range(10).map(lambda x: x).repeat(count_t).apply( + sliding.sliding_window_batch( + window_size=window_size_t, + window_shift=window_shift_t, + window_stride=window_stride_t)).make_initializable_iterator()) + init_op = iterator.initializer - # Invalid stride should be an initialization time error. + with self.test_session() as sess: with self.assertRaises(errors.InvalidArgumentError): - sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 0}) + sess.run( + init_op, + feed_dict={ + count_t: count, + window_size_t: window_size, + window_shift_t: window_shift, + window_stride_t: window_stride + }) + + def testSlideDatasetValueError(self): + with self.assertRaises(ValueError): + dataset_ops.Dataset.range(10).map(lambda x: x).apply( + sliding.sliding_window_batch( + window_size=1, stride=1, window_shift=1, window_stride=1)) def assertSparseValuesEqual(self, a, b): self.assertAllEqual(a.indices, b.indices) @@ -155,7 +209,8 @@ class SlideDatasetTest(test.TestCase): indices=[[0]], values=(i * [1]), dense_shape=[1]) iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( - sliding.sliding_window_batch(5, 3)).make_initializable_iterator() + sliding.sliding_window_batch( + window_size=5, window_shift=3)).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -183,7 +238,8 @@ class SlideDatasetTest(test.TestCase): dense_shape=[i]) iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( - sliding.sliding_window_batch(5, 3)).make_initializable_iterator() + sliding.sliding_window_batch( + window_size=5, window_shift=3)).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -213,11 +269,11 @@ class SlideDatasetTest(test.TestCase): return sparse_tensor.SparseTensorValue( indices=[[0]], values=(i * [1]), dense_shape=[1]) - iterator = (dataset_ops.Dataset.range(10) - .map(_sparse) - .apply(sliding.sliding_window_batch(4, 2)) - .apply(sliding.sliding_window_batch(3, 1)) - .make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.range(10).map(_sparse).apply( + sliding.sliding_window_batch(window_size=4, window_shift=2)).apply( + sliding.sliding_window_batch(window_size=3, window_shift=1)) + .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -226,9 +282,9 @@ class SlideDatasetTest(test.TestCase): # Slide: 1st batch. actual = sess.run(get_next) expected = sparse_tensor.SparseTensorValue( - indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], - [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], - [2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]], + indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0], + [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0], + [2, 2, 0], [2, 3, 0]], values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7], dense_shape=[3, 4, 1]) self.assertTrue(sparse_tensor.is_sparse(actual)) @@ -236,9 +292,9 @@ class SlideDatasetTest(test.TestCase): # Slide: 2nd batch. actual = sess.run(get_next) expected = sparse_tensor.SparseTensorValue( - indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], - [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], - [2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]], + indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0], + [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0], + [2, 2, 0], [2, 3, 0]], values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9], dense_shape=[3, 4, 1]) self.assertTrue(sparse_tensor.is_sparse(actual)) @@ -253,10 +309,11 @@ class SlideDatasetTest(test.TestCase): yield [4.0, 5.0, 6.0] yield [7.0, 8.0, 9.0, 10.0] - iterator = (dataset_ops.Dataset.from_generator(generator, dtypes.float32, - output_shapes=[None]) - .apply(sliding.sliding_window_batch(3, 1)) - .make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_generator( + generator, dtypes.float32, output_shapes=[None]).apply( + sliding.sliding_window_batch(window_size=3, window_shift=1)) + .make_initializable_iterator()) next_element = iterator.get_next() with self.test_session() as sess: diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index b4945685c1d1062bf416b73f1541f351adf45604..a41d21f8c14ed6bec7626599a5aa7f365765ce8b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base +from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base from tensorflow.contrib.data.python.ops import stats_ops -from tensorflow.core.framework import summary_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -29,28 +29,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class StatsDatasetTestBase(test.TestCase): - - def _assertSummaryHasCount(self, summary_str, tag, expected_value): - summary_proto = summary_pb2.Summary() - summary_proto.ParseFromString(summary_str) - for value in summary_proto.value: - if tag == value.tag: - self.assertEqual(expected_value, value.histo.num) - return - self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) - - def _assertSummaryHasSum(self, summary_str, tag, expected_value): - summary_proto = summary_pb2.Summary() - summary_proto.ParseFromString(summary_str) - for value in summary_proto.value: - if tag == value.tag: - self.assertEqual(expected_value, value.histo.sum) - return - self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) - - -class StatsDatasetTest(StatsDatasetTestBase): +class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): def testBytesProduced(self): stats_aggregator = stats_ops.StatsAggregator() @@ -197,7 +176,7 @@ class StatsDatasetTest(StatsDatasetTestBase): class FeatureStatsDatasetTest( - StatsDatasetTestBase, + stats_dataset_test_base.StatsDatasetTestBase, reader_dataset_ops_test_base.ReadBatchFeaturesTestBase): def testFeaturesStats(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..9a13acf8f0ac6690cad8847873768562da795496 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py @@ -0,0 +1,44 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for testing the input pipeline statistics gathering ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.core.framework import summary_pb2 +from tensorflow.python.platform import test + + +class StatsDatasetTestBase(test.TestCase): + """Base class for testing statistics gathered in `StatsAggregator`.""" + + def _assertSummaryHasCount(self, summary_str, tag, expected_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertEqual(expected_value, value.histo.num) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + + def _assertSummaryHasSum(self, summary_str, tag, expected_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertEqual(expected_value, value.histo.sum) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..33d95d67549e1c8d1d9af578fcebbb4f939c418a --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py @@ -0,0 +1,523 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +class WindowDatasetTest(test.TestCase, parameterized.TestCase): + + def _structuredDataset(self, structure, shape, dtype): + if structure is None: + return dataset_ops.Dataset.from_tensors( + array_ops.zeros(shape, dtype=dtype)) + else: + return dataset_ops.Dataset.zip( + tuple([ + self._structuredDataset(substructure, shape, dtype) + for substructure in structure + ])) + + def _structuredElement(self, structure, shape, dtype): + if structure is None: + return array_ops.zeros(shape, dtype=dtype) + else: + return tuple([ + self._structuredElement(substructure, shape, dtype) + for substructure in structure + ]) + + def _assertEqual(self, xs, ys): + self.assertEqual(type(xs), type(ys)) + if isinstance(xs, tuple) and isinstance(ys, tuple): + self.assertEqual(len(xs), len(ys)) + for x, y in zip(xs, ys): + self._assertEqual(x, y) + elif isinstance(xs, np.ndarray) and isinstance(ys, np.ndarray): + self.assertAllEqual(xs, ys) + else: + self.assertEqual(xs, ys) + + @parameterized.parameters( + (None, np.int32([]), dtypes.bool), + (None, np.int32([]), dtypes.int32), + (None, np.int32([]), dtypes.float32), + (None, np.int32([]), dtypes.string), + (None, np.int32([2]), dtypes.int32), + (None, np.int32([2, 2]), dtypes.int32), + ((None, None, None), np.int32([]), dtypes.int32), + ((None, (None, None)), np.int32([]), dtypes.int32), + ) + def testWindowDatasetFlatMap(self, structure, shape, dtype): + """Tests windowing by chaining it with flat map. + + Args: + structure: the input structure + shape: the input shape + dtype: the input data type + """ + + def fn(*args): + if len(args) == 1 and not isinstance(args[0], tuple): + return args[0] + return dataset_ops.Dataset.zip( + tuple([fn(*arg) if isinstance(arg, tuple) else arg for arg in args])) + + dataset = self._structuredDataset(structure, shape, dtype).apply( + grouping.window_dataset(5)).flat_map(fn) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + expected = sess.run(self._structuredElement(structure, shape, dtype)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + @parameterized.parameters( + (None, np.int32([]), dtypes.bool), + (None, np.int32([]), dtypes.int32), + (None, np.int32([]), dtypes.float32), + (None, np.int32([]), dtypes.string), + (None, np.int32([2]), dtypes.int32), + (None, np.int32([2, 2]), dtypes.int32), + ((None, None, None), np.int32([]), dtypes.int32), + ((None, (None, None)), np.int32([]), dtypes.int32), + ) + def testWindowDatasetBatchDense(self, structure, shape, dtype): + """Tests batching of dense tensor windows. + + Args: + structure: the input structure + shape: the input shape + dtype: the input data type + """ + + def fn(*args): + if len(args) == 1 and not isinstance(args[0], tuple): + return batching.batch_window(args[0]) + + return tuple([ + fn(*arg) if isinstance(arg, tuple) else batching.batch_window(arg) + for arg in args + ]) + + dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply( + grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn)) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + expected = sess.run( + self._structuredElement(structure, np.concatenate( + ([5], shape), axis=0), dtype)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + @parameterized.parameters( + (np.int32([]),), + (np.int32([1]),), + (np.int32([1, 2, 3]),), + ) + def testWindowDatasetBatchDenseDynamicShape(self, shape): + """Tests batching of dynamically shaped dense tensor windows. + + Args: + shape: the input shape + """ + + shape_t = array_ops.placeholder(dtypes.int32) + dataset = dataset_ops.Dataset.from_tensors( + array_ops.zeros(shape_t)).repeat(5).apply( + grouping.window_dataset(5)).apply( + grouping._map_x_dataset(batching.batch_window)) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op, {shape_t: shape}) + expected = sess.run( + self._structuredElement(None, np.concatenate(([5], shape), axis=0), + dtypes.int32)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + def _make_dense_to_sparse_fn(self, is_scalar): + + def dense_to_sparse_scalar(tensor): + indices = [[]] + values = array_ops.expand_dims(tensor, 0) + shape = [] + return sparse_tensor.SparseTensorValue(indices, values, shape) + + def dense_to_sparse_non_scalar(tensor): + indices = array_ops.where(array_ops.ones_like(tensor, dtype=dtypes.bool)) + values = array_ops.gather_nd(tensor, indices) + shape = array_ops.shape(tensor, out_type=dtypes.int64) + return sparse_tensor.SparseTensorValue(indices, values, shape) + + if is_scalar: + return dense_to_sparse_scalar + return dense_to_sparse_non_scalar + + def _structuredSparseDataset(self, structure, shape, dtype): + dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test + if structure is None: + return dataset_ops.Dataset.from_tensors( + dense_to_sparse(array_ops.zeros(shape, dtype=dtype))) + else: + return dataset_ops.Dataset.zip( + tuple([ + self._structuredSparseDataset(substructure, shape, dtype) + for substructure in structure + ])) + + def _structuredSparseElement(self, structure, shape, dtype): + dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test + if structure is None: + return dense_to_sparse(array_ops.zeros(shape, dtype=dtype)) + else: + return tuple([ + self._structuredSparseElement(substructure, shape, dtype) + for substructure in structure + ]) + + @parameterized.parameters( + (None, np.int32([]), dtypes.bool), + (None, np.int32([]), dtypes.int32), + (None, np.int32([]), dtypes.float32), + (None, np.int32([]), dtypes.string), + (None, np.int32([2]), dtypes.int32), + (None, np.int32([2, 2]), dtypes.int32), + ((None, None, None), np.int32([]), dtypes.int32), + ((None, (None, None)), np.int32([]), dtypes.int32), + ) + def testWindowDatasetBatchSparse(self, structure, shape, dtype): + """Tests batching of sparse tensor windows. + + Args: + structure: the input structure + shape: the input shape + dtype: the input data type + """ + + def fn(*args): + if len(args) == 1 and not isinstance(args[0], tuple): + return batching.batch_window(args[0]) + + return tuple([ + fn(*arg) if isinstance(arg, tuple) else batching.batch_window(arg) + for arg in args + ]) + + dataset = self._structuredSparseDataset( + structure, shape, dtype).repeat(5).apply( + grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn)) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + expected = sess.run( + self._structuredSparseElement(structure, + np.concatenate(([5], shape), axis=0), + dtype)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + @parameterized.parameters( + (np.int32([]),), + (np.int32([1]),), + (np.int32([1, 2, 3]),), + ) + def testWindowDatasetBatchSparseDynamicShape(self, shape): + """Tests batching of dynamically shaped sparse tensor windows. + + Args: + shape: the input shape + """ + + shape_t = array_ops.placeholder(dtypes.int32) + dataset = dataset_ops.Dataset.from_tensors(array_ops.zeros(shape_t)).map( + self._make_dense_to_sparse_fn(len(shape) == 0)).repeat(5).apply( # pylint: disable=g-explicit-length-test + grouping.window_dataset(5)).apply( + grouping._map_x_dataset(batching.batch_window)) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op, {shape_t: shape}) + expected = sess.run( + self._structuredSparseElement(None, + np.concatenate(([5], shape), axis=0), + dtypes.int32)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + def _structuredRaggedDataset(self, structure, shapes, dtype): + + if structure is None: + return dataset_ops.Dataset.from_tensor_slices(shapes).map( + lambda shape: array_ops.zeros(shape, dtype=dtype)) + else: + return dataset_ops.Dataset.zip( + tuple([ + self._structuredRaggedDataset(substructure, shapes, dtype) + for substructure in structure + ])) + + @parameterized.parameters( + (None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]), + (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + (None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]), + (None, np.int32([[1], [2], [3]]), dtypes.string, [-1]), + (None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), + (None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]), + ((None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + ((None, (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + (None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])), + ) + def testWindowDatasetPaddedBatchDense(self, structure, shapes, dtype, + padded_shape): + """Tests padded batching of dense tensor windows. + + Args: + structure: the input structure + shapes: the input shapes + dtype: the input data type + padded_shape: the shape to pad the output to + """ + + def fn(*args): + if len(args) == 1 and not isinstance(args[0], tuple): + return batching.padded_batch_window(args[0], padded_shape) + + return tuple([ + fn(*arg) if isinstance(arg, tuple) else batching.padded_batch_window( + arg, padded_shape) for arg in args + ]) + + dataset = self._structuredRaggedDataset(structure, shapes, dtype).apply( + grouping.window_dataset(len(shapes))).apply( + grouping._map_x_dataset(fn)) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) + expected = sess.run( + self._structuredElement( + structure, + np.concatenate((np.int32([len(shapes)]), expected_shape)), dtype)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + @parameterized.parameters( + (np.int32([[1], [2], [3]]), [-1]), + (np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]), + (np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), + ) + def testWindowDatasetPaddedBatchDenseDynamicShape(self, shapes, padded_shape): + """Tests padded batching of dynamically shaped dense tensor windows. + + Args: + shapes: the input shapes + padded_shape: the shape to pad the output to + """ + + shapes_t = array_ops.placeholder(dtypes.int32) + dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map( + lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply( + grouping.window_dataset(len(shapes))).apply( + grouping._map_x_dataset( + lambda x: batching.padded_batch_window(x, padded_shape))) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op, {shapes_t: shapes}) + expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) + expected = sess.run( + self._structuredElement( + None, np.concatenate((np.int32([len(shapes)]), expected_shape)), + dtypes.int32)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + @parameterized.parameters( + (np.int32([[1]]), np.int32([0])), + (np.int32([[10], [20]]), np.int32([15])), + ) + def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape): + """Tests invalid padded batching of dense tensor windows. + + Args: + shapes: the input shapes + padded_shape: the shape to pad the output to + """ + + dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map( + lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply( + grouping.window_dataset(len(shapes))).apply( + grouping._map_x_dataset( + lambda x: batching.padded_batch_window(x, padded_shape))) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(get_next) + + def _structuredRaggedSparseDataset(self, structure, shapes, dtype): + + def map_fn(shape): + dense_to_sparse = self._make_dense_to_sparse_fn(False) + return dense_to_sparse(array_ops.zeros(shape, dtype=dtype)) + + if structure is None: + return dataset_ops.Dataset.from_tensor_slices(shapes).map(map_fn) + else: + return dataset_ops.Dataset.zip( + tuple([ + self._structuredRaggedSparseDataset(substructure, shapes, dtype) + for substructure in structure + ])) + + def _structuredRaggedSparseElement(self, structure, shapes, dtype, + padded_shape): + if structure is None: + dense_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) + values = [] + for shape in shapes: + dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test + sparse = dense_to_sparse(array_ops.zeros(shape, dtype=dtype)) + padded_sparse = sparse_tensor.SparseTensor(sparse.indices, + sparse.values, dense_shape) + reshaped_sparse = sparse_ops.sparse_reshape( + padded_sparse, + array_ops.concat([np.array([1], dtype=np.int64), dense_shape], 0)) + values.append(reshaped_sparse) + return sparse_ops.sparse_concat(0, values) + else: + return tuple([ + self._structuredRaggedSparseElement(substructure, shapes, dtype, + padded_shape) + for substructure in structure + ]) + + @parameterized.parameters( + (None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]), + (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + (None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]), + (None, np.int64([[1], [2], [3]]), dtypes.string, [-1]), + (None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), + (None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]), + ((None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + ((None, (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + (None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])), + ) + def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype, + padded_shape): + """Tests padded batching of sparse tensor windows. + + Args: + structure: the input structure + shapes: the input shapes + dtype: the input data type + padded_shape: the shape to pad the output to + """ + + def fn(*args): + if len(args) == 1 and not isinstance(args[0], tuple): + return batching.padded_batch_window(args[0], padded_shape) + + return tuple([ + fn(*arg) if isinstance(arg, tuple) else batching.padded_batch_window( + arg, padded_shape) for arg in args + ]) + + dataset = self._structuredRaggedSparseDataset( + structure, shapes, dtype).apply(grouping.window_dataset( + len(shapes))).apply(grouping._map_x_dataset(fn)) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + expected = sess.run( + self._structuredRaggedSparseElement(structure, shapes, dtype, + padded_shape)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + @parameterized.parameters( + (np.int64([[1], [2], [3]]), [-1]), + (np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]), + (np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), + ) + def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes, + padded_shape): + """Tests padded batching of dynamically shaped sparse tensor windows. + + Args: + shapes: the input shapes + padded_shape: the shape to pad the output to + """ + + shapes_t = array_ops.placeholder(dtypes.int32) + dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map( + lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map( + self._make_dense_to_sparse_fn(False) + ).apply(grouping.window_dataset(len(shapes))).apply( + grouping._map_x_dataset( + lambda x: batching.padded_batch_window(x, padded_shape))) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op, {shapes_t: shapes}) + expected = sess.run( + self._structuredRaggedSparseElement(None, shapes, dtypes.int32, + padded_shape)) + actual = sess.run(get_next) + self._assertEqual(expected, actual) + + @parameterized.parameters( + (np.int64([[1]]), [0]), + (np.int64([[10], [20]]), [15]), + ) + def testWindowDatasetPaddedBatchSparseInvalid(self, shapes, padded_shape): + """Tests invalid padded batching of sparse tensor windows. + + Args: + shapes: the input shapes + padded_shape: the shape to pad the output to + """ + + dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map( + lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map( + self._make_dense_to_sparse_fn(False) + ).apply(grouping.window_dataset(len(shapes))).apply( + grouping._map_x_dataset( + lambda x: batching.padded_batch_window(x, padded_shape))) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 02408145625b7e751541e7b87dc4fd5da4f7cad9..ad9378dfb9d938c826f994da9bbb89101cfbd872 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -28,10 +28,12 @@ py_library( srcs = ["get_single_element.py"], srcs_version = "PY2AND3", deps = [ + ":grouping", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", + "//third_party/py/numpy", ], ) @@ -115,6 +117,8 @@ py_library( srcs = ["batching.py"], srcs_version = "PY2AND3", deps = [ + ":get_single_element", + ":grouping", "//tensorflow/contrib/framework:framework_py", "//tensorflow/python:array_ops", "//tensorflow/python:dataset_ops_gen", @@ -127,6 +131,7 @@ py_library( "//tensorflow/python/data/util:convert", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", + "//third_party/py/numpy", ], ) @@ -205,6 +210,17 @@ py_library( ], ) +py_library( + name = "map_defun", + srcs = ["map_defun.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:framework_ops", + "//tensorflow/python:tensor_shape", + ], +) + py_library( name = "resampling", srcs = ["resampling.py"], @@ -365,6 +381,7 @@ py_library( ":get_single_element", ":grouping", ":interleave_ops", + ":map_defun", ":optimization", ":prefetching_ops", ":readers", diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 7350d595f5f6b64d062dcc5ebc69d7e85d3f7b22..4835c4e5bd9efded57f19d6a382b145ae1b05e93 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -17,22 +17,134 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + +from tensorflow.contrib.data.python.ops import get_single_element +from tensorflow.contrib.data.python.ops import grouping from tensorflow.contrib.framework import with_shape from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.util import deprecation +def batch_window(dataset): + """Batches a window of tensors. + + Args: + dataset: the input dataset. + + Returns: + A `Tensor` representing the batch of the entire input dataset. + """ + if isinstance(dataset.output_classes, tuple): + raise TypeError("Input dataset expected to have a single component") + if dataset.output_classes is ops.Tensor: + return _batch_dense_window(dataset) + elif dataset.output_classes is sparse_tensor.SparseTensor: + return _batch_sparse_window(dataset) + else: + raise TypeError("Unsupported dataset type: %s" % dataset.output_classes) + + +def _batch_dense_window(dataset): + """Batches a window of dense tensors.""" + + def key_fn(_): + return np.int64(0) + + def shape_init_fn(_): + return array_ops.shape(first_element) + + def shape_reduce_fn(state, value): + check_ops.assert_equal(state, array_ops.shape(value)) + return state + + def finalize_fn(state): + return state + + if dataset.output_shapes.is_fully_defined(): + shape = dataset.output_shapes + else: + first_element = get_single_element.get_single_element(dataset.take(1)) + shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn, + finalize_fn) + shape = get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer))) + + def batch_init_fn(_): + batch_shape = array_ops.concat([[0], shape], 0) + return gen_array_ops.empty(batch_shape, dtype=dataset.output_types) + + def batch_reduce_fn(state, value): + return array_ops.concat([state, [value]], 0) + + batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) + return get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, batch_reducer))) + + +def _batch_sparse_window(dataset): + """Batches a window of sparse tensors.""" + + def key_fn(_): + return np.int64(0) + + def shape_init_fn(_): + return first_element.dense_shape + + def shape_reduce_fn(state, value): + check_ops.assert_equal(state, value.dense_shape) + return state + + def finalize_fn(state): + return state + + if dataset.output_shapes.is_fully_defined(): + shape = dataset.output_shapes + else: + first_element = get_single_element.get_single_element(dataset.take(1)) + shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn, + finalize_fn) + shape = get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer))) + + def batch_init_fn(_): + indices_shape = array_ops.concat([[0], [array_ops.size(shape) + 1]], 0) + return sparse_tensor.SparseTensor( + indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64), + values=constant_op.constant([], shape=[0], dtype=dataset.output_types), + dense_shape=array_ops.concat( + [np.array([0], dtype=np.int64), + math_ops.cast(shape, dtypes.int64)], 0)) + + def batch_reduce_fn(state, value): + return sparse_ops.sparse_concat(0, [state, value]) + + def reshape_fn(value): + return sparse_ops.sparse_reshape( + value, + array_ops.concat([np.array([1], dtype=np.int64), value.dense_shape], 0)) + + batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) + return get_single_element.get_single_element( + dataset.map(reshape_fn).apply( + grouping.group_by_reducer(key_fn, batch_reducer))) + + def dense_to_sparse_batch(batch_size, row_shape): """A transformation that batches ragged elements into `tf.SparseTensor`s. @@ -82,6 +194,157 @@ def dense_to_sparse_batch(batch_size, row_shape): return _apply_fn +def padded_batch_window(dataset, padded_shape, padding_value=None): + """Batches a window of tensors with padding. + + Args: + dataset: the input dataset. + padded_shape: (Optional.) `tf.TensorShape` or `tf.int64` vector tensor-like + object representing the shape to which the input elements should be padded + prior to batching. Any unknown dimensions (e.g. `tf.Dimension(None)` in a + `tf.TensorShape` or `-1` in a tensor-like object) will be padded to the + maximum size of that dimension in each batch. + padding_value: (Optional.) A scalar-shaped `tf.Tensor`, representing the + padding value to use. Defaults are `0` for numeric types and the empty + string for string types. If `dataset` contains `tf.SparseTensor`, this + value is ignored. + + Returns: + A `Tensor` representing the batch of the entire input dataset. + + Raises: + ValueError: if invalid arguments are provided. + """ + if not issubclass(dataset.output_classes, + (ops.Tensor, sparse_tensor.SparseTensor)): + raise TypeError("Input dataset expected to have a single tensor component") + if issubclass(dataset.output_classes, (ops.Tensor)): + return _padded_batch_dense_window(dataset, padded_shape, padding_value) + elif issubclass(dataset.output_classes, (sparse_tensor.SparseTensor)): + if padding_value is not None: + raise ValueError("Padding value not allowed for sparse tensors") + return _padded_batch_sparse_window(dataset, padded_shape) + else: + raise TypeError("Unsupported dataset type: %s" % dataset.output_classes) + + +def _padded_batch_dense_window(dataset, padded_shape, padding_value=None): + """Batches a window of dense tensors with padding.""" + + padded_shape = math_ops.cast( + convert.partial_shape_to_tensor(padded_shape), dtypes.int32) + + def key_fn(_): + return np.int64(0) + + def max_init_fn(_): + return padded_shape + + def max_reduce_fn(state, value): + """Computes the maximum shape to pad to.""" + condition = math_ops.reduce_all( + math_ops.logical_or( + math_ops.less_equal(array_ops.shape(value), padded_shape), + math_ops.equal(padded_shape, -1))) + assert_op = control_flow_ops.Assert(condition, [ + "Actual shape greater than padded shape: ", + array_ops.shape(value), padded_shape + ]) + with ops.control_dependencies([assert_op]): + return math_ops.maximum(state, array_ops.shape(value)) + + def finalize_fn(state): + return state + + # Compute the padded shape. + max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn) + padded_shape = get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, max_reducer))) + + if padding_value is None: + if dataset.output_types == dtypes.string: + padding_value = "" + elif dataset.output_types == dtypes.bool: + padding_value = False + elif dataset.output_types == dtypes.variant: + raise TypeError("Unable to create padding for field of type 'variant'") + else: + padding_value = 0 + + def batch_init_fn(_): + return array_ops.fill( + array_ops.concat([np.array([0], dtype=np.int32), padded_shape], 0), + constant_op.constant(padding_value, dtype=dataset.output_types)) + + def batch_reduce_fn(state, value): + return array_ops.concat([state, [value]], 0) + + def pad_fn(value): + shape = array_ops.shape(value) + left = array_ops.zeros_like(shape) + right = padded_shape - shape + return array_ops.pad( + value, array_ops.stack([left, right], 1), constant_values=padding_value) + + batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) + return get_single_element.get_single_element( + dataset.map(pad_fn).apply( + grouping.group_by_reducer(key_fn, batch_reducer))) + + +def _padded_batch_sparse_window(dataset, padded_shape): + """Batches a window of sparse tensors with padding.""" + + def key_fn(_): + return np.int64(0) + + def max_init_fn(_): + return convert.partial_shape_to_tensor(padded_shape) + + def max_reduce_fn(state, value): + """Computes the maximum shape to pad to.""" + condition = math_ops.reduce_all( + math_ops.logical_or( + math_ops.less_equal(value.dense_shape, padded_shape), + math_ops.equal(padded_shape, -1))) + assert_op = control_flow_ops.Assert(condition, [ + "Actual shape greater than padded shape: ", value.dense_shape, + padded_shape + ]) + with ops.control_dependencies([assert_op]): + return math_ops.maximum(state, value.dense_shape) + + def finalize_fn(state): + return state + + # Compute the padded shape. + max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn) + padded_shape = get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, max_reducer))) + + def batch_init_fn(_): + indices_shape = array_ops.concat([[0], [array_ops.size(padded_shape) + 1]], + 0) + return sparse_tensor.SparseTensor( + indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64), + values=constant_op.constant([], shape=[0], dtype=dataset.output_types), + dense_shape=array_ops.concat( + [np.array([0], dtype=np.int64), padded_shape], 0)) + + def batch_reduce_fn(state, value): + padded_value = sparse_tensor.SparseTensor( + indices=value.indices, values=value.values, dense_shape=padded_shape) + reshaped_value = sparse_ops.sparse_reshape( + padded_value, + array_ops.concat( + [np.array([1], dtype=np.int64), padded_value.dense_shape], 0)) + return sparse_ops.sparse_concat(0, [state, reshaped_value]) + + reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn) + return get_single_element.get_single_element( + dataset.apply(grouping.group_by_reducer(key_fn, reducer))) + + class _UnbatchDataset(dataset_ops.Dataset): """A dataset that splits the elements of its input into multiple elements.""" @@ -175,48 +438,6 @@ def unbatch(): return _apply_fn -def _filter_irregular_batches(batch_size): - """Transformation that filters out batches that are not of size batch_size.""" - - def _apply_fn(dataset): - """Function from `Dataset` to `Dataset` that applies the transformation.""" - tensor_batch_size = ops.convert_to_tensor( - batch_size, dtype=dtypes.int64, name="batch_size") - - flattened = _RestructuredDataset( - dataset, - tuple(nest.flatten(dataset.output_types)), - output_classes=tuple(nest.flatten(dataset.output_classes))) - - def _predicate(*xs): - """Return `True` if this element is a full batch.""" - # Extract the dynamic batch size from the first component of the flattened - # batched element. - first_component = xs[0] - first_component_batch_size = array_ops.shape( - first_component, out_type=dtypes.int64)[0] - - return math_ops.equal(first_component_batch_size, tensor_batch_size) - - filtered = flattened.filter(_predicate) - - maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size) - - def _set_first_dimension(shape): - return shape.merge_with( - tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:])) - - known_shapes = nest.map_structure(_set_first_dimension, - dataset.output_shapes) - return _RestructuredDataset( - filtered, - dataset.output_types, - known_shapes, - output_classes=dataset.output_classes) - - return _apply_fn - - @deprecation.deprecated( None, "Use `tf.data.Dataset.batch(..., drop_remainder=True)`.") def batch_and_drop_remainder(batch_size): @@ -251,10 +472,7 @@ def batch_and_drop_remainder(batch_size): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - # TODO(jsimsa): Switch to using `batch(..., drop_remainder=True)` any time - # after 6/30/2018. - batched = dataset.batch(batch_size) - return _filter_irregular_batches(batch_size)(batched) + return dataset.batch(batch_size, drop_remainder=True) return _apply_fn @@ -289,11 +507,9 @@ def padded_batch_and_drop_remainder(batch_size, def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - # TODO(jsimsa): Switch to using `padded_batch(..., drop_remainder=True)` - # any time after 6/30/2018. - batched = dataset.padded_batch( - batch_size, padded_shapes=padded_shapes, padding_values=padding_values) - return _filter_irregular_batches(batch_size)(batched) + return dataset.padded_batch( + batch_size, padded_shapes=padded_shapes, padding_values=padding_values, + drop_remainder=True) return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py index 0f4cd8e20c5727a5bcfa1dce4dadbfa8f90bd551..ef9284456eb35099db804e0680abfacd6384d503 100644 --- a/tensorflow/contrib/data/python/ops/get_single_element.py +++ b/tensorflow/contrib/data/python/ops/get_single_element.py @@ -17,6 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + +from tensorflow.contrib.data.python.ops import grouping from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse @@ -68,3 +71,30 @@ def get_single_element(dataset): return sparse.deserialize_sparse_tensors( nested_ret, dataset.output_types, dataset.output_shapes, dataset.output_classes) + + +def reduce_dataset(dataset, reducer): + """Returns the result of reducing the `dataset` using `reducer`. + + Args: + dataset: A @{tf.data.Dataset} object. + reducer: A @{tf.contrib.data.Reducer} object representing the reduce logic. + + Returns: + A nested structure of @{tf.Tensor} objects, corresponding to the result + of reducing `dataset` using `reducer`. + + Raises: + TypeError: if `dataset` is not a `tf.data.Dataset` object. + """ + if not isinstance(dataset, dataset_ops.Dataset): + raise TypeError("`dataset` must be a `tf.data.Dataset` object.") + + # The sentinel dataset is used in case the reduced dataset is empty. + sentinel_dataset = dataset_ops.Dataset.from_tensors( + reducer.finalize_func(reducer.init_func(np.int64(0)))) + reduced_dataset = dataset.apply( + grouping.group_by_reducer(lambda x: np.int64(0), reducer)) + + return get_single_element( + reduced_dataset.concatenate(sentinel_dataset).take(1)) diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index ca9540bf136a5028c4321319bdfacaf8a16484c7..bd8d398c58cc1825616c1ab5337cf6668c66697e 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -149,9 +149,9 @@ def bucket_by_sequence_length(element_length_func, @{tf.data.Dataset.padded_batch}. Defaults to padding with 0. pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown size to maximum length in batch. If `True`, will pad dimensions with - unknown size to bucket boundary, and caller must ensure that the source - `Dataset` does not contain any elements with length longer than - `max(bucket_boundaries)`. + unknown size to bucket boundary minus 1 (i.e., the maximum length in each + bucket), and caller must ensure that the source `Dataset` does not contain + any elements with length longer than `max(bucket_boundaries)`. Returns: A `Dataset` transformation function, which can be passed to @@ -203,7 +203,7 @@ def bucket_by_sequence_length(element_length_func, none_filler = None if pad_to_bucket_boundary: err_msg = ("When pad_to_bucket_boundary=True, elements must have " - "length <= max(bucket_boundaries).") + "length < max(bucket_boundaries).") check = check_ops.assert_less( bucket_id, constant_op.constant(len(bucket_batch_sizes) - 1, @@ -213,7 +213,7 @@ def bucket_by_sequence_length(element_length_func, boundaries = constant_op.constant(bucket_boundaries, dtype=dtypes.int64) bucket_boundary = boundaries[bucket_id] - none_filler = bucket_boundary + none_filler = bucket_boundary - 1 shapes = make_padded_shapes( padded_shapes or grouped_dataset.output_shapes, none_filler=none_filler) @@ -227,6 +227,50 @@ def bucket_by_sequence_length(element_length_func, return _apply_fn +def _map_x_dataset(map_func): + """A transformation that maps `map_func` across its input. + + This transformation is similar to `tf.data.Dataset.map`, but in addition to + supporting dense and sparse tensor inputs, it also supports dataset inputs. + + Args: + map_func: A function mapping a nested structure of tensors and/or datasets + (having shapes and types defined by `self.output_shapes` and + `self.output_types`) to another nested structure of tensors and/or + datasets. + + Returns: + Dataset: A `Dataset`. + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return _MapXDataset(dataset, map_func) + + return _apply_fn + + +def window_dataset(window_size): + """A transformation that creates window datasets from the input dataset. + + The resulting datasets will contain `window_size` elements (or + `N % window_size` for the last dataset if `window_size` does not divide the + number of input elements `N` evenly). + + Args: + window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of + consecutive elements of the input dataset to combine into a window. + + Returns: + Dataset: A `Dataset`. + """ + + def _apply_fn(dataset): + return _WindowDataset(dataset, window_size) + + return _apply_fn + + class _GroupByReducerDataset(dataset_ops.Dataset): """A `Dataset` that groups its input and performs a reduction.""" @@ -468,3 +512,85 @@ class Reducer(object): @property def finalize_func(self): return self._finalize_func + + +class _MapXDataset(dataset_ops.Dataset): + """A `Dataset` that maps a function over elements in its input.""" + + def __init__(self, input_dataset, map_func): + """See `map_x_dataset()` for details.""" + super(_MapXDataset, self).__init__() + self._input_dataset = input_dataset + + wrapped_func = dataset_ops.StructuredFunctionWrapper( + map_func, + "tf.contrib.data.map_x_dataset()", + input_dataset, + experimental_nested_dataset_support=True) + self._output_classes = wrapped_func.output_classes + self._output_shapes = wrapped_func.output_shapes + self._output_types = wrapped_func.output_types + self._map_func = wrapped_func.function + + def _as_variant_tensor(self): + input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access + return gen_dataset_ops.map_dataset( + input_t, + self._map_func.captured_inputs, + f=self._map_func, + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + +class _WindowDataset(dataset_ops.Dataset): + """A dataset that creates window datasets from the input elements.""" + + def __init__(self, input_dataset, window_size): + """See `window_dataset()` for more details.""" + super(_WindowDataset, self).__init__() + self._input_dataset = input_dataset + self._window_size = ops.convert_to_tensor( + window_size, dtype=dtypes.int64, name="window_size") + self._output_classes = nest.pack_sequence_as( + input_dataset.output_classes, + [ + dataset_ops._NestedDatasetComponent( # pylint: disable=protected-access + output_classes=output_class, + output_shapes=output_shape, + output_types=output_type) + for output_class, output_shape, output_type in zip( + nest.flatten(input_dataset.output_classes), + nest.flatten(input_dataset.output_shapes), + nest.flatten(input_dataset.output_types)) + ]) + self._output_shapes = self._output_classes + self._output_types = self._output_classes + + def _as_variant_tensor(self): + return gen_dataset_ops.window_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._window_size, + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py index 0d71be66018eeebe60de9deff24ceb6854d209d9..d2c1d0d3620f94f867395a8e2fff0d77a6dc0718 100644 --- a/tensorflow/contrib/data/python/ops/iterator_ops.py +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -20,6 +20,7 @@ from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import session_run_hook @@ -206,7 +207,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): # Check if there is an existing checkpoint. If so, restore from it. # pylint: disable=protected-access - latest_checkpoint_path = saver_lib.latest_checkpoint( + latest_checkpoint_path = checkpoint_management.latest_checkpoint( self._checkpoint_saver_hook._checkpoint_dir, latest_filename=self._latest_filename) if latest_checkpoint_path: diff --git a/tensorflow/contrib/data/python/ops/map_defun.py b/tensorflow/contrib/data/python/ops/map_defun.py new file mode 100644 index 0000000000000000000000000000000000000000..54d5cd6da068fa5471b7beafcc66d76b5972e7d5 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/map_defun.py @@ -0,0 +1,58 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental API for optimizing `tf.data` pipelines.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops + + +def map_defun(fn, elems, output_dtypes, output_shapes): + """Map a function on the list of tensors unpacked from `elems` on dimension 0. + + Args: + fn: A function (`function.Defun`) that takes a list of tensors and returns + another list of tensors. The output list has the same types as + output_dtypes. The elements of the output list have the same dimension 0 + as `elems`, and the remaining dimensions correspond to those of + `fn_output_shapes`. + elems: A list of tensors. + output_dtypes: A list of dtypes corresponding to the output types of the + function. + output_shapes: A list of `TensorShape`s corresponding to the output + shapes from each invocation of the function on slices of inputs. + + Raises: + ValueError: if any of the inputs are malformed. + + Returns: + A list of `Tensor` objects with the same types as `output_dtypes`. + """ + if not isinstance(elems, list): + raise ValueError("`elems` must be a list of tensors.") + if not isinstance(output_dtypes, list): + raise ValueError("`output_dtypes` must be a list of tensors.") + if not isinstance(output_shapes, list): + raise ValueError("`output_shapes` must be a list of tensors.") + + elems = [ops.convert_to_tensor(e) for e in elems] + output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes] + if not all(s.is_fully_defined() for s in output_shapes): + raise ValueError("All fn output shapes must be fully defined.") + return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn) diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py index cf896572262929add5ac34d4fc8e4192c1049da3..018c5115e1d5599e48bf99ccf832c7962794fc40 100644 --- a/tensorflow/contrib/data/python/ops/optimization.py +++ b/tensorflow/contrib/data/python/ops/optimization.py @@ -18,12 +18,34 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops +# TODO(jsimsa): Support RE matching for both individual transformation (e.g. to +# account for indexing) and transformation sequence. +def assert_next(transformations): + """A transformation that asserts which transformations happen next. + + Args: + transformations: A `tf.string` vector `tf.Tensor` identifying the + transformations that are expected to happen next. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return _AssertNextDataset(dataset, transformations) + + return _apply_fn + + def optimize(optimizations=None): """A transformation that applies optimizations. @@ -44,6 +66,37 @@ def optimize(optimizations=None): return _apply_fn +class _AssertNextDataset(dataset_ops.Dataset): + """A `Dataset` that asserts which transformations happen next.""" + + def __init__(self, input_dataset, transformations): + """See `assert_next()` for details.""" + super(_AssertNextDataset, self).__init__() + self._input_dataset = input_dataset + if transformations is None: + raise ValueError("At least one transformation should be specified") + self._transformations = ops.convert_to_tensor( + transformations, dtype=dtypes.string, name="transformations") + + def _as_variant_tensor(self): + return contrib_gen_dataset_ops.assert_next_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._transformations, + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + class _OptimizeDataset(dataset_ops.Dataset): """A `Dataset` that acts as an identity, and applies optimizations.""" diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py index 21fc17102e16a1f98f2c2e8aa0aeec89989edf67..0243c72c70716d31e3ab8b6a3da2270ee0bbc91b 100644 --- a/tensorflow/contrib/data/python/ops/prefetching_ops.py +++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py @@ -26,10 +26,15 @@ from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.eager import context +from tensorflow.python.framework import device as framework_device from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_dataset_ops as core_gen_dataset_ops +from tensorflow.python.ops import resource_variable_ops def function_buffering_resource(string_arg, @@ -345,3 +350,358 @@ def prefetch_to_device(device, buffer_size=None): return _PrefetchToDeviceDataset(dataset, device, buffer_size) return _apply_fn + + +def copy_to_device(target_device, source_device="/cpu:0"): + """A transformation that copies dataset elements to the given `target_device`. + + Args: + target_device: The name of a device to which elements will be copied. + source_device: The original device on which `input_dataset` will be placed. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return _CopyToDeviceDataset( + dataset, target_device=target_device, source_device=source_device) + + return _apply_fn + + +# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate +# all inputs to the Op are in host memory, thereby avoiding some unnecessary +# Sends and Recvs. +class _CopyToDeviceDataset(dataset_ops.Dataset): + """A `Dataset` that copies elements to another device.""" + + def __init__(self, input_dataset, target_device, source_device="/cpu:0"): + """Constructs a _CopyToDeviceDataset. + + Args: + input_dataset: `Dataset` to be copied + target_device: The name of the device to which elements would be copied. + source_device: Device where input_dataset would be placed. + """ + self._input_dataset = input_dataset + self._target_device = target_device + spec = framework_device.DeviceSpec().from_string(self._target_device) + self._is_gpu_target = (spec.device_type == "GPU") + self._source_device_string = source_device + self._source_device = ops.convert_to_tensor(source_device) + + self._flat_output_shapes = nest.flatten( + sparse.as_dense_shapes(self._input_dataset.output_shapes, + self._input_dataset.output_classes)) + self._flat_output_types = nest.flatten( + sparse.as_dense_types(self._input_dataset.output_types, + self._input_dataset.output_classes)) + + @function.Defun() + def _init_func(): + """Creates an iterator for the input dataset. + + Returns: + A `string` tensor that encapsulates the iterator created. + """ + # pylint: disable=protected-access + ds_variant = self._input_dataset._as_variant_tensor() + resource = core_gen_dataset_ops.anonymous_iterator( + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + with ops.control_dependencies( + [core_gen_dataset_ops.make_iterator(ds_variant, resource)]): + return core_gen_dataset_ops.iterator_to_string_handle(resource) + + @function.Defun() + def _remote_init_func(): + return functional_ops.remote_call( + target=self._source_device, + args=_init_func.captured_inputs, + Tout=[dtypes.string], + f=_init_func) + + self._init_func = _remote_init_func + self._init_captured_args = _remote_init_func.captured_inputs + + @function.Defun(dtypes.string) + def _next_func(string_handle): + """Calls get_next for created iterator. + + Args: + string_handle: An iterator string handle created by _init_func + Returns: + The elements generated from `input_dataset` + """ + with ops.device(self._source_device_string): + iterator = iterator_ops.Iterator.from_string_handle( + string_handle, self.output_types, self.output_shapes, + self.output_classes) + ret = iterator.get_next() + return nest.flatten(sparse.serialize_sparse_tensors(ret)) + + @function.Defun(dtypes.string) + def _remote_next_func(string_handle): + return functional_ops.remote_call( + target=self._source_device, + args=[string_handle] + _next_func.captured_inputs, + Tout=self._flat_output_types, + f=_next_func) + + self._next_func = _remote_next_func + self._next_captured_args = _remote_next_func.captured_inputs + + @function.Defun(dtypes.string) + def _finalize_func(string_handle): + """Destroys the iterator resource created. + + Args: + string_handle: An iterator string handle created by _init_func + Returns: + Tensor constant 0 + """ + iterator_resource = core_gen_dataset_ops.iterator_from_string_handle_v2( + string_handle, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + with ops.control_dependencies([ + resource_variable_ops.destroy_resource_op( + iterator_resource, ignore_lookup_error=True)]): + return array_ops.constant(0, dtypes.int64) + + @function.Defun(dtypes.string) + def _remote_finalize_func(string_handle): + return functional_ops.remote_call( + target=self._source_device, + args=[string_handle] + _finalize_func.captured_inputs, + Tout=[dtypes.int64], + f=_finalize_func) + + self._finalize_func = _remote_finalize_func + self._finalize_captured_args = _remote_finalize_func.captured_inputs + + g = ops.get_default_graph() + _remote_init_func.add_to_graph(g) + _remote_next_func.add_to_graph(g) + _remote_finalize_func.add_to_graph(g) + # pylint: enable=protected-scope + + # The one_shot_iterator implementation needs a 0 arg _make_dataset function + # that thereby captures all the inputs required to create the dataset. Since + # there are strings that are inputs to the GeneratorDataset which can't be + # placed on a GPU, this fails for the GPU case. Therefore, disabling it for + # GPU + def make_one_shot_iterator(self): + if self._is_gpu_target: + raise ValueError("Cannot create a one shot iterator when using " + "`tf.contrib.data.copy_to_device()` on GPU. Please use " + "`Dataset.make_initializable_iterator()` instead.") + else: + return super(_CopyToDeviceDataset, self).make_one_shot_iterator() + + def _as_variant_tensor(self): + with ops.device(self._target_device): + return core_gen_dataset_ops.generator_dataset( + self._init_captured_args, + self._next_captured_args, + self._finalize_captured_args, + init_func=self._init_func, + next_func=self._next_func, + finalize_func=self._finalize_func, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_classes(self): + return self._input_dataset.output_classes + + +class _PerDeviceGenerator(dataset_ops.Dataset): + """A `dummy` generator dataset.""" + + def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id, + source_device, target_device, output_shapes, output_types, + output_classes): + self._target_device = target_device + self._output_types = output_types + self._output_shapes = output_shapes + self._output_classes = output_classes + self._flat_output_shapes = nest.flatten( + sparse.as_dense_shapes(self._output_shapes, self._output_classes)) + self._flat_output_types = nest.flatten( + sparse.as_dense_types(self._output_types, self._output_classes)) + + multi_device_iterator_string_handle = ( + gen_dataset_ops.multi_device_iterator_to_string_handle( + multi_device_iterator_resource)) + + @function.Defun() + def _init_func(): + return multi_device_iterator_string_handle + + @function.Defun() + def _remote_init_func(): + return functional_ops.remote_call( + target=source_device, + args=_init_func.captured_inputs, + Tout=[dtypes.string], + f=_init_func) + + self._init_func = _remote_init_func + self._init_captured_args = _remote_init_func.captured_inputs + + @function.Defun(dtypes.string) + def _next_func(string_handle): + multi_device_iterator = ( + gen_dataset_ops.multi_device_iterator_from_string_handle( + string_handle=string_handle, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes)) + return gen_dataset_ops.multi_device_iterator_get_next_from_shard( + multi_device_iterator=multi_device_iterator, + shard_num=shard_num, + incarnation_id=incarnation_id, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + + @function.Defun(dtypes.string) + def _remote_next_func(string_handle): + return functional_ops.remote_call( + target=source_device, + args=[string_handle] + _next_func.captured_inputs, + Tout=self._flat_output_types, + f=_next_func) + + self._next_func = _remote_next_func + self._next_captured_args = _remote_next_func.captured_inputs + + @function.Defun(dtypes.string) + def _finalize_func(unused_string_handle): + return array_ops.constant(0, dtypes.int64) + + @function.Defun(dtypes.string) + def _remote_finalize_func(string_handle): + return functional_ops.remote_call( + target=source_device, + args=[string_handle] + _finalize_func.captured_inputs, + Tout=[dtypes.int64], + f=_finalize_func) + + self._finalize_func = _remote_finalize_func + self._finalize_captured_args = _remote_finalize_func.captured_inputs + + def _as_variant_tensor(self): + with ops.device(self._target_device): + return core_gen_dataset_ops.generator_dataset( + self._init_captured_args, + self._next_captured_args, + self._finalize_captured_args, + init_func=self._init_func, + next_func=self._next_func, + finalize_func=self._finalize_func, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + + @property + def output_types(self): + return self._output_types + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_classes(self): + return self._output_classes + + +class MultiDeviceIterator(object): + """An iterator over multiple devices.""" + + def __init__(self, + dataset, + devices, + prefetch_buffer_size=1, + source_device="/cpu:0"): + """Constructs a MultiDeviceIterator. + + Args: + dataset: The input dataset to be iterated over. + devices: The list of devices to fetch data to. + prefetch_buffer_size: if > 1, then we setup a buffer on each device + to prefetch into. + source_device: The host device to place the `dataset` on. + """ + self._dataset = dataset + self._devices = devices + self._source_device = source_device + self._source_device_tensor = ops.convert_to_tensor(source_device) + + self._flat_output_shapes = nest.flatten( + sparse.as_dense_shapes(self._dataset.output_shapes, + self._dataset.output_classes)) + self._flat_output_types = nest.flatten( + sparse.as_dense_types(self._dataset.output_types, + self._dataset.output_classes)) + + # Create the MultiDeviceIterator. + with ops.device(self._source_device): + self._multi_device_iterator_resource = ( + gen_dataset_ops.multi_device_iterator( + devices=self._devices, + shared_name="", + container="", + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes)) + + # The incarnation ID is used to ensure consistency between the per-device + # iterators and the multi-device iterator. + self._incarnation_id = gen_dataset_ops.multi_device_iterator_init( + self._dataset._as_variant_tensor(), # pylint: disable=protected-access + self._multi_device_iterator_resource) + + # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to + # initialize the device side of the pipeline. This would allow the + # MultiDeviceIterator to choose, for example, to move some transformations + # into the device side from its input. It might be useful in rewriting. + # Create the per device iterators. + self._device_iterators = [] + i = 0 + for device in self._devices: + ds = _PerDeviceGenerator( + i, self._multi_device_iterator_resource, self._incarnation_id, + self._source_device_tensor, device, self._dataset.output_shapes, + self._dataset.output_types, self._dataset.output_classes) + if prefetch_buffer_size > 0: + ds = ds.prefetch(prefetch_buffer_size) + with ops.device(device): + self._device_iterators.append(ds.make_initializable_iterator()) + i += 1 + + device_iterator_initializers = [ + iterator.initializer for iterator in self._device_iterators + ] + self._initializer = control_flow_ops.group(*device_iterator_initializers) + + def get_next(self): + result = [] + i = 0 + for device in self._devices: + with ops.device(device): + result.append(self._device_iterators[i].get_next()) + i += 1 + return result + + @property + def initializer(self): + return self._initializer diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 83095c7ba1c6465d18490e5197f71bf7f1fe2497..14d69f8d5b29d43649185e689c7a8e6604361bca 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -286,11 +286,14 @@ def make_tf_record_dataset( dataset = _maybe_shuffle_and_repeat( dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to + # improve the shape inference, because it makes the batch dimension static. + # It is safe to do this because in that case we are repeating the input + # indefinitely, and all batches will be full-sized. + drop_final_batch = drop_final_batch or num_epochs is None + if parser_fn is None: - if drop_final_batch: - dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) - else: - dataset = dataset.batch(batch_size) + dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch) else: # TODO(josh11b): if num_parallel_parser_calls is None, use some function # of num cores instead of map_and_batch's default behavior of one batch. @@ -326,6 +329,7 @@ def make_csv_dataset( num_parallel_parser_calls=2, sloppy=False, num_rows_for_inference=100, + compression_type=None, ): """Reads CSV files into a dataset. @@ -399,6 +403,8 @@ def make_csv_dataset( num_rows_for_inference: Number of rows of a file to use for type inference if record_defaults is not provided. If None, reads all the rows of all the files. Defaults to 100. + compression_type: (Optional.) A `tf.string` scalar evaluating to one of + `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no compression. Returns: A dataset, where each element is a (features, labels) tuple that corresponds @@ -461,7 +467,9 @@ def make_csv_dataset( use_quote_delim=use_quote_delim, na_value=na_value, select_cols=select_columns, - header=header) + header=header, + compression_type=compression_type, + ) def map_fn(*columns): """Organizes columns into a features dictionary. @@ -488,8 +496,13 @@ def make_csv_dataset( dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) # Apply batch before map for perf, because map has high overhead relative - # to the size of the computation in each map - dataset = dataset.batch(batch_size=batch_size) + # to the size of the computation in each map. + # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to + # improve the shape inference, because it makes the batch dimension static. + # It is safe to do this because in that case we are repeating the input + # indefinitely, and all batches will be full-sized. + dataset = dataset.batch(batch_size=batch_size, + drop_remainder=num_epochs is None) dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_parser_calls) dataset = dataset.prefetch(prefetch_buffer_size) @@ -505,6 +518,7 @@ class CsvDataset(dataset_ops.Dataset): def __init__(self, filenames, record_defaults, + compression_type=None, buffer_size=None, header=False, field_delim=",", @@ -540,11 +554,11 @@ class CsvDataset(dataset_ops.Dataset): The expected output of its iterations is: ```python - next = dataset.make_one_shot_iterator().get_next() + next_element = dataset.make_one_shot_iterator().get_next() with tf.Session() as sess: while True: try: - print(sess.run(nxt)) + print(sess.run(next_element)) except tf.errors.OutOfRangeError: break @@ -562,6 +576,9 @@ class CsvDataset(dataset_ops.Dataset): both this and `select_columns` are specified, these must have the same lengths, and `column_defaults` is assumed to be sorted in order of increasing column index. + compression_type: (Optional.) A `tf.string` scalar evaluating to one of + `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no + compression. buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes to buffer while reading files. Defaults to 4MB. header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s) @@ -581,6 +598,11 @@ class CsvDataset(dataset_ops.Dataset): super(CsvDataset, self).__init__() self._filenames = ops.convert_to_tensor( filenames, dtype=dtypes.string, name="filenames") + self._compression_type = convert.optional_param_to_tensor( + "compression_type", + compression_type, + argument_default="", + argument_dtype=dtypes.string) record_defaults = [ constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x for x in record_defaults @@ -621,6 +643,7 @@ class CsvDataset(dataset_ops.Dataset): use_quote_delim=self._use_quote_delim, na_value=self._na_value, select_cols=self._select_cols, + compression_type=self._compression_type, ) @property @@ -757,10 +780,12 @@ def make_batched_features_dataset(file_pattern, dataset = dataset.apply(stats_ops.feature_stats("record_stats")) - if drop_final_batch: - dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) - else: - dataset = dataset.batch(batch_size) + # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to + # improve the shape inference, because it makes the batch dimension static. + # It is safe to do this because in that case we are repeating the input + # indefinitely, and all batches will be full-sized. + dataset = dataset.batch( + batch_size, drop_remainder=drop_final_batch or num_epochs is None) # Parse `Example` tensors to a dictionary of `Feature` tensors. dataset = dataset.map( diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py index 3f3c5ca17cf6ae22a719ed1d593d98eec37413fb..e9dd74530ac64cd414d53eab5294eaa95c919131 100644 --- a/tensorflow/contrib/data/python/ops/sliding.py +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -23,25 +23,29 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.util import deprecation class _SlideDataset(dataset_ops.Dataset): """A `Dataset` that passes a sliding window over its input.""" - def __init__(self, input_dataset, window_size, stride=1): + def __init__(self, input_dataset, window_size, window_shift, window_stride): """See `sliding_window_batch` for details.""" super(_SlideDataset, self).__init__() self._input_dataset = input_dataset self._window_size = ops.convert_to_tensor( - window_size, dtype=dtypes.int64, name="window_size") - self._stride = ops.convert_to_tensor( - stride, dtype=dtypes.int64, name="stride") + window_size, dtype=dtypes.int64, name="window_stride") + self._window_stride = ops.convert_to_tensor( + window_stride, dtype=dtypes.int64, name="window_stride") + self._window_shift = ops.convert_to_tensor( + window_shift, dtype=dtypes.int64, name="window_shift") def _as_variant_tensor(self): return gen_dataset_ops.slide_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access window_size=self._window_size, - stride=self._stride, + window_shift=self._window_shift, + window_stride=self._window_stride, **dataset_ops.flat_structure(self)) @property @@ -61,38 +65,63 @@ class _SlideDataset(dataset_ops.Dataset): return self._input_dataset.output_types -def sliding_window_batch(window_size, stride=1): - """A sliding window with size of `window_size` and step of `stride`. +@deprecation.deprecated_args( + None, "stride is deprecated, use window_shift instead", "stride") +def sliding_window_batch(window_size, + stride=None, + window_shift=None, + window_stride=1): + """A sliding window over a dataset. - This transformation passes a sliding window over this dataset. The - window size is `window_size` and step size is `stride`. If the left - elements cannot fill up the sliding window, this transformation will - drop the final smaller element. For example: + This transformation passes a sliding window over this dataset. The window size + is `window_size`, the stride of the input elements is `window_stride`, and the + shift between consecutive windows is `window_shift`. If the remaining elements + cannot fill up the sliding window, this transformation will drop the final + smaller element. For example: ```python # NOTE: The following examples use `{ ... }` to represent the # contents of a dataset. a = { [1], [2], [3], [4], [5], [6] } - a.apply(tf.contrib.data.sliding_window_batch(window_size=3, stride=2)) == - { - [[1], [2], [3]], - [[3], [4], [5]], - } + a.apply(sliding_window_batch(window_size=3)) == + { [[1], [2], [3]], [[2], [3], [4]], [[3], [4], [5]], [[4], [5], [6]] } + + a.apply(sliding_window_batch(window_size=3, window_shift=2)) == + { [[1], [2], [3]], [[3], [4], [5]] } + + a.apply(sliding_window_batch(window_size=3, window_stride=2)) == + { [[1], [3], [5]], [[2], [4], [6]] } ``` Args: window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of - elements in the sliding window. + elements in the sliding window. It must be positive. stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the - steps moving the sliding window forward for one iteration. The default - is `1`. It must be positive. + forward shift of the sliding window in each iteration. The default is `1`. + It must be positive. Deprecated alias for `window_shift`. + window_shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + forward shift of the sliding window in each iteration. The default is `1`. + It must be positive. + window_stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + stride of the input elements in the sliding window. The default is `1`. + It must be positive. Returns: A `Dataset` transformation function, which can be passed to @{tf.data.Dataset.apply}. + + Raises: + ValueError: if invalid arguments are provided. """ + if stride is None and window_shift is None: + window_shift = 1 + elif stride is not None and window_shift is None: + window_shift = stride + elif stride is not None and window_shift is not None: + raise ValueError("Cannot specify both `stride` and `window_shift`") + def _apply_fn(dataset): - return _SlideDataset(dataset, window_size, stride) + return _SlideDataset(dataset, window_size, window_shift, window_stride) return _apply_fn diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD index 74b2cd90a187159fd2da8ce236c14e813cc43c49..d3628d480d31017f835b39f750df40cafa2cc0db 100644 --- a/tensorflow/contrib/distribute/BUILD +++ b/tensorflow/contrib/distribute/BUILD @@ -25,11 +25,15 @@ py_library( srcs = ["__init__.py"], visibility = ["//tensorflow:internal"], deps = [ + "//tensorflow/contrib/distribute/python:collective_all_reduce_strategy", "//tensorflow/contrib/distribute/python:cross_tower_ops", "//tensorflow/contrib/distribute/python:mirrored_strategy", "//tensorflow/contrib/distribute/python:monitor", + "//tensorflow/contrib/distribute/python:multi_worker_strategy", "//tensorflow/contrib/distribute/python:one_device_strategy", + "//tensorflow/contrib/distribute/python:parameter_server_strategy", "//tensorflow/contrib/distribute/python:step_fn", + "//tensorflow/contrib/distribute/python:tpu_strategy", "//tensorflow/python:training", "//tensorflow/python:util", ], diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index 44a4481021c380e72b535cf0aca39df2bf04d3b7..2f5dd10550d0771d0cd3c2501d0456dc95077386 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -116,8 +116,6 @@ in the input function gives a solid boost in performance. When using ## Caveats This feature is in early stages and there are a lot of improvements forthcoming: -* Metrics are not yet supported during distributed training. They are still -supported during the evaluation. * Summaries are only computed in the first tower in `MirroredStrategy`. * Evaluation is not yet distributed. * Eager support is in the works; performance can be more challenging with eager diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py index 76711baf3a11c8978fbb5770ec173ff74a153158..9123ca749b68a1d0066313c77914fa3fb8006a9e 100644 --- a/tensorflow/contrib/distribute/__init__.py +++ b/tensorflow/contrib/distribute/__init__.py @@ -19,11 +19,15 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.distribute.python.collective_all_reduce_strategy import CollectiveAllReduceStrategy from tensorflow.contrib.distribute.python.cross_tower_ops import * from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy +from tensorflow.contrib.distribute.python.multi_worker_strategy import MultiWorkerMirroredStrategy from tensorflow.contrib.distribute.python.monitor import Monitor from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy +from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy from tensorflow.contrib.distribute.python.step_fn import * +from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy from tensorflow.python.training.distribute import * from tensorflow.python.util.all_util import remove_undocumented @@ -31,16 +35,20 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'AllReduceCrossTowerOps', + 'CollectiveAllReduceStrategy', 'CrossTowerOps', 'DistributionStrategy', 'MirroredStrategy', + 'MultiWorkerMirroredStrategy', 'Monitor', 'OneDeviceStrategy', + 'ParameterServerStrategy', 'ReductionToOneDeviceCrossTowerOps', 'Step', 'StandardInputStep', 'StandardSingleLossStep', 'TowerContext', + 'TPUStrategy', 'get_cross_tower_context', 'get_distribution_strategy', 'get_loss_reduction', diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index eba0dd0ea330e29db0ea8e68ee14767fcb8ddad0..3159dd154aee2c00b24d36f841584e2582c99b19 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -100,6 +100,23 @@ py_library( ], ) +py_library( + name = "parameter_server_strategy", + srcs = ["parameter_server_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":cross_tower_ops", + ":mirrored_strategy", + ":values", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + ], +) + py_library( name = "one_device_strategy", srcs = ["one_device_strategy.py"], @@ -116,6 +133,24 @@ py_library( ], ) +py_library( + name = "collective_all_reduce_strategy", + srcs = ["collective_all_reduce_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":cross_tower_ops", + ":cross_tower_utils", + ":mirrored_strategy", + ":values", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:collective_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/eager:context", + ], +) + py_library( name = "strategy_test_lib", testonly = 1, @@ -152,6 +187,7 @@ py_library( ":multi_worker_strategy", ":one_device_strategy", ":tpu_strategy", + "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/optimizer_v2:training", "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", @@ -207,6 +243,35 @@ py_test( ], ) +py_test( + name = "parameter_server_strategy_test", + srcs = ["parameter_server_strategy_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":combinations", + ":multi_worker_test_base", + ":parameter_server_strategy", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:layers", + "//tensorflow/python:session", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/estimator:run_config", + "@absl_py//absl/testing:parameterized", + ], +) + cuda_py_test( name = "mirrored_strategy_multigpu_test", srcs = ["mirrored_strategy_multigpu_test.py"], @@ -247,11 +312,11 @@ py_library( ], deps = [ "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", "//tensorflow/python:distributed_framework_test_lib", - "//tensorflow/python:platform", "//tensorflow/python:session", - "//tensorflow/python:training", - "//tensorflow/python/eager:test", + "//tensorflow/python/estimator:run_config", + "//third_party/py/numpy", ], ) @@ -272,8 +337,7 @@ py_library( deps = [ ":one_device_strategy", ":values", - "//tensorflow/contrib/tpu", - "//tensorflow/contrib/tpu:tpu_py", + "//tensorflow/contrib/tpu:tpu_lib", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", @@ -281,6 +345,37 @@ py_library( ], ) +py_test( + name = "collective_all_reduce_strategy_test", + srcs = ["collective_all_reduce_strategy_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":collective_all_reduce_strategy", + ":combinations", + ":cross_tower_utils", + ":multi_worker_test_base", + ":strategy_test_lib", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/estimator:run_config", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + py_library( name = "minimize_loss_test_lib", testonly = 1, @@ -451,8 +546,11 @@ py_library( "//tensorflow/contrib/all_reduce:all_reduce_py", "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/python:array_ops", + "//tensorflow/python:collective_ops", + "//tensorflow/python:device", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", "//tensorflow/python:math_ops", ], ) @@ -487,7 +585,9 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", + "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", "@six_archive//:six", ], @@ -495,6 +595,7 @@ py_library( cuda_py_test( name = "cross_tower_ops_test", + size = "large", srcs = ["cross_tower_ops_test.py"], additional_deps = [ ":combinations", @@ -509,7 +610,6 @@ cuda_py_test( "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], - shard_count = 15, tags = [ "multi_and_single_gpu", "no_pip", @@ -587,6 +687,7 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", + "no_windows_gpu", "notsan", ], ) @@ -609,3 +710,40 @@ cuda_py_test( "no_pip", ], ) + +cuda_py_test( + name = "warm_starting_util_test", + size = "medium", + srcs = ["warm_starting_util_test.py"], + additional_deps = [ + ":combinations", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +cuda_py_test( + name = "checkpoint_utils_test", + size = "medium", + srcs = ["checkpoint_utils_test.py"], + additional_deps = [ + ":combinations", + "//tensorflow/python:client_testlib", + "//tensorflow/python:checkpoint_utils_test", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) diff --git a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..bcb977f64073b1d15ef5c872eb0d6b09d5307b54 --- /dev/null +++ b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py @@ -0,0 +1,78 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for checkpoint_utils.init_from_checkpoint with Distribution Strategy. + +These tests are located here instead of as part of +`python.training.CheckpointsTest` because they need access to distribution +strategies which are only present in contrib right now. +TODO(priyag): Move the tests to core `python.training.CheckpointsTest` when +distribution strategy moves out of contrib. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.python.framework import ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import checkpoint_utils_test + + +class CheckpointUtilsWithDistributionStrategyTest( + test.TestCase, parameterized.TestCase): + + @combinations.generate(combinations.combine( + distribution=[combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus], + in_tower_mode=[True, False], + mode=["graph"])) + def testInitFromCheckpoint(self, distribution, in_tower_mode): + checkpoint_dir = self.get_temp_dir() + with self.test_session() as session: + v1_value, v2_value, _, _ = checkpoint_utils_test._create_checkpoints( + session, checkpoint_dir) + + def init_and_verify(g): + v1 = variable_scope.get_variable("new_var1", [1, 10]) + v2 = variable_scope.get_variable( + "new_var2", [10, 10], + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.MEAN) + checkpoint_utils.init_from_checkpoint(checkpoint_dir, { + "var1": "new_var1", + "var2": "new_var2" + }) + with self.test_session(graph=g) as session: + session.run(variables.global_variables_initializer()) + self.assertAllEqual(v1_value, self.evaluate(v1)) + self.assertAllEqual(v2_value, self.evaluate(v2)) + + with ops.Graph().as_default() as g, distribution.scope(): + if in_tower_mode: + distribution.call_for_each_tower(init_and_verify, g) + else: + init_and_verify(g) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..9afcaecf78844b011a9dbc30bb95fa3bfeda8470 --- /dev/null +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -0,0 +1,205 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class CollectiveAllReduceStrategy implementing DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import os + +from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import cross_tower_utils +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import values +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import collective_ops +from tensorflow.python.training import server_lib + + +# TODO(yuefengz): move this function to a common util file. +def _normalize_cluster_spec(cluster_spec): + if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)): + return server_lib.ClusterSpec(cluster_spec) + elif not isinstance(cluster_spec, server_lib.ClusterSpec): + raise ValueError( + "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " + "`tf.train.ClusterDef` object") + return cluster_spec + + +# TODO(yuefengz): shard the dataset. +# TODO(yuefengz): support in-graph replication. +# TODO(yuefengz): it only works with a cluster without a chief node, maybe +# support chief node? +class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): + """Distribution strategy that uses collective ops for all-reduce. + + It is similar to the MirroredStrategy but it uses collective ops for + reduction. It currently only works for between-graph replication and its + reduction will reduce across all workers. + """ + + def __init__(self, + num_gpus_per_worker=0, + cluster_spec=None, + task_type="worker", + task_id=0): + """Initializes the object. + + Args: + num_gpus_per_worker: number of local GPUs or GPUs per worker. + cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the + cluster configurations. + task_type: the current task type, such as "worker". + task_id: the current task id. + + Raises: + ValueError: if `task_type` is not in the `cluster_spec`. + """ + self._num_gpus_per_worker = num_gpus_per_worker + self._initialize(cluster_spec, task_type, task_id) + + def _initialize(self, cluster_spec, task_type, task_id): + if task_type not in ["chief", "worker"]: + raise ValueError( + "Unrecognized task_type: %r, valid task types are: \"chief\", " + "\"worker\"." % task_type) + if cluster_spec: + self._cluster_spec = _normalize_cluster_spec(cluster_spec) + worker_device = "/job:%s/task:%d" % (task_type, task_id) + num_workers = len(self._cluster_spec.as_dict().get(task_type, [])) + if "chief" in self._cluster_spec.as_dict(): + num_workers += 1 + if not num_workers: + raise ValueError("`task_type` shoud be in `cluster_spec`.") + + # TODO(yuefengz): create a utility to infer chief. + if "chief" in self._cluster_spec.as_dict() and task_type == "chief": + assert task_id == 0 + self._is_chief = True + else: + assert task_type == "worker" + self._is_chief = task_id == 0 + else: + self._cluster_spec = None + self._is_chief = True + worker_device = "" + num_workers = 1 + self._num_workers = num_workers + + if self._num_gpus_per_worker: + local_devices = [ + "%s/device:GPU:%d" % (worker_device, i) + for i in range(self._num_gpus_per_worker) + ] + else: + local_devices = [worker_device] + + self._collective_keys = cross_tower_utils.CollectiveKeys() + super(CollectiveAllReduceStrategy, self).__init__( + devices=local_devices, + cross_tower_ops=cross_tower_ops_lib.CollectiveAllReduce( + num_workers=num_workers, + num_gpus_per_worker=self._num_gpus_per_worker, + collective_keys=self._collective_keys)) + + # Add a default device so that ops without specified devices will not end up + # on other workers. + if cluster_spec: + self._default_device = "/job:%s/replica:0/task:%d" % (task_type, task_id) + + def _create_variable(self, next_creator, *args, **kwargs): + colocate_with = kwargs.pop("colocate_with", None) + devices = self._get_devices_from(colocate_with) + group_size = len(devices) * self._num_workers + group_key = self._collective_keys.get_group_key(self._devices) + + def _real_mirrored_creator(devices, *args, **kwargs): + """Creates one MirroredVariable on the current worker.""" + index = {} + collective_instance_key = self._collective_keys.get_instance_key( + key_id=kwargs["name"]) + if "initial_value" not in kwargs: + raise ValueError("Initial value must be specified.") + initial_value = kwargs["initial_value"] + if callable(initial_value): + initial_value_fn = initial_value + else: + initial_value_fn = lambda: initial_value + + for i, d in enumerate(devices): + with ops.device(d): + if i > 0: + # Give replicas meaningful distinct names: + var0name = index[devices[0]].name.split(":")[0] + # We append a / to variable names created on towers with id > 0 to + # ensure that we ignore the name scope and instead use the given + # name as the absolute name of the variable. + kwargs["name"] = "%s/replica_%d/" % (var0name, i) + + # The initial value fn makes sure variables all initialized to + # same values. The first device of the chief worker will send their + # variable values to other devices and other workers. + def _overridden_initial_value_fn(device=d, index=i): # pylint: disable=g-missing-docstring + with ops.device(device): + initial_value = initial_value_fn() + assert not callable(initial_value) + initial_value = ops.convert_to_tensor(initial_value) + + if self._is_chief and index == 0: + bcast_send = collective_ops.broadcast_send( + initial_value, initial_value.shape, initial_value.dtype, + group_size, group_key, collective_instance_key) + with ops.control_dependencies([bcast_send]): + return array_ops.identity(initial_value) + else: + return collective_ops.broadcast_recv( + initial_value.shape, initial_value.dtype, group_size, + group_key, collective_instance_key) + + kwargs["initial_value"] = _overridden_initial_value_fn + + with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): + v = next_creator(*args, **kwargs) + + assert not isinstance(v, values.DistributedVariable) + index[d] = v + return index + + # pylint: disable=protected-access + return mirrored_strategy._create_mirrored_variable( + devices, _real_mirrored_creator, *args, **kwargs) + + def configure(self, session_config=None): + # Use TF_CONFIG to get the cluster spec and the current job. + if not self._cluster_spec: + tf_config = json.loads(os.environ.get("TF_CONFIG", "{}")) + cluster_spec = _normalize_cluster_spec(tf_config.get("cluster", {})) + + task_env = tf_config.get("task", {}) + if task_env: + task_type = task_env.get("type", "worker") + task_id = int(task_env.get("index", "0")) + else: + task_type = "worker" + task_id = 0 + + if cluster_spec: + self._initialize(cluster_spec, task_type, task_id) diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e54e3b7d7156e87731e6f79aa66262d127232c --- /dev/null +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -0,0 +1,217 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for CollectiveAllReduceStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.distribute.python import collective_all_reduce_strategy +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import cross_tower_utils +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.eager import context +from tensorflow.python.estimator import run_config +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class DistributedCollectiveAllReduceStrategyTest( + multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase): + + collective_key_base = 0 + + @classmethod + def setUpClass(cls): + """Create a local cluster with 2 workers.""" + cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=0) + cls._cluster_spec = { + run_config.TaskType.WORKER: [ + 'fake_worker_0', 'fake_worker_1', 'fake_worker_2' + ] + } + + def setUp(self): + self._run_options = config_pb2.RunOptions() + self._run_options.experimental.collective_graph_key = 6 + + self._sess_config = config_pb2.ConfigProto() + self._sess_config.experimental.collective_group_leader = ( + '/job:worker/replica:0/task:0') + + # We use a different key_base for each test so that collective keys won't be + # reused. + # TODO(yuefengz, tucker): enable it to reuse collective keys in different + # tests. + DistributedCollectiveAllReduceStrategyTest.collective_key_base += 100000 + super(DistributedCollectiveAllReduceStrategyTest, self).setUp() + + def _get_test_object(self, task_type, task_id, num_gpus=0): + distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=num_gpus, + cluster_spec=self._cluster_spec, + task_type=task_type, + task_id=task_id) + collective_keys = cross_tower_utils.CollectiveKeys( + group_key_start=10 * num_gpus + + DistributedCollectiveAllReduceStrategyTest.collective_key_base, + instance_key_start=num_gpus * 100 + + DistributedCollectiveAllReduceStrategyTest.collective_key_base, + instance_key_with_id_start=num_gpus * 10000 + + DistributedCollectiveAllReduceStrategyTest.collective_key_base) + distribution._collective_keys = collective_keys + distribution._cross_tower_ops._collective_keys = collective_keys + return distribution, self._workers[task_id].target + + def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): + d, master_target = self._get_test_object(task_type, task_id, num_gpus) + with ops.Graph().as_default(), \ + self.test_session(config=self._sess_config, + target=master_target) as sess, \ + d.scope(): + l = core.Dense(1, use_bias=False, name='gpu_%d' % d._num_gpus_per_worker) + + def loss_fn(x): + y = array_ops.reshape(l(x), []) - constant_op.constant(1.) + return y * y + + # TODO(yuefengz, apassos): eager.backprop.implicit_grad is not safe for + # multiple graphs (b/111216820). + def grad_fn(x): + loss = loss_fn(x) + var_list = ( + variables.trainable_variables() + ops.get_collection( + ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) + grads = gradients.gradients(loss, var_list) + ret = list(zip(grads, var_list)) + return ret + + def update(v, g): + return v.assign_sub(0.05 * g, use_locking=True) + + one = d.broadcast(constant_op.constant([[1.]])) + + def step(): + """Perform one optimization step.""" + # Run forward & backward to get gradients, variables list. + g_v = d.call_for_each_tower(grad_fn, one) + # Update the variables using the gradients and the update() function. + before_list = [] + after_list = [] + for g, v in g_v: + fetched = d.read_var(v) + before_list.append(fetched) + with ops.control_dependencies([fetched]): + # TODO(yuefengz): support non-Mirrored variable as destinations. + g = d.reduce( + variable_scope.VariableAggregation.SUM, g, destinations=v) + with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + after_list.append(d.read_var(v)) + return before_list, after_list + + before_out, after_out = step() + + if context.num_gpus() < d._num_gpus_per_worker: + return True + + sess.run( + variables.global_variables_initializer(), options=self._run_options) + + for i in range(10): + b, a = sess.run((before_out, after_out), options=self._run_options) + if i == 0: + before, = b + after, = a + + error_before = abs(before - 1) + error_after = abs(after - 1) + # Error should go down + self.assertLess(error_after, error_before) + return error_after < error_before + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testMinimizeLossGraph(self, num_gpus): + self._run_between_graph_clients(self._test_minimize_loss_graph, + self._cluster_spec, num_gpus) + + def _test_variable_initialization(self, task_type, task_id, num_gpus): + distribution, master_target = self._get_test_object(task_type, task_id, + num_gpus) + with ops.Graph().as_default(), \ + self.test_session(config=self._sess_config, + target=master_target) as sess, \ + distribution.scope(): + + def model_fn(): + x = variable_scope.get_variable( + 'x', + shape=(2, 3), + initializer=init_ops.random_uniform_initializer( + 1.0, 10.0, dtype=dtypes.float32)) + return array_ops.identity(x) + + x = distribution.call_for_each_tower(model_fn) + reduced_x = distribution.unwrap( + distribution.reduce( + variable_scope.VariableAggregation.MEAN, x, + destinations='/cpu:0'))[0] + + sess.run( + variables.global_variables_initializer(), options=self._run_options) + x_value, reduced_x_value = sess.run( + [x, reduced_x], options=self._run_options) + self.assertTrue(np.array_equal(x_value, reduced_x_value)) + return np.array_equal(x_value, reduced_x_value) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testVariableInitialization(self, num_gpus): + if context.num_gpus() < num_gpus: + return + self._run_between_graph_clients( + self._test_variable_initialization, + self._cluster_spec, + num_gpus=num_gpus) + + +class LocalCollectiveAllReduceStrategy(strategy_test_lib.DistributionTestBase, + parameterized.TestCase): + + def testMinimizeLossGraph(self, num_gpus=2): + # Collective ops doesn't support strategy with one device. + if context.num_gpus() < num_gpus: + return + distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=num_gpus) + self._test_minimize_loss_graph(distribution) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 9a8ea4aa48b8cf4c5906f18d8bddacc224e0b644..120349481ff11dd47d88154c72b37f81b2a1074f 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -46,6 +46,7 @@ import unittest from absl.testing import parameterized import six +from tensorflow.contrib.cluster_resolver import TPUClusterResolver from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib from tensorflow.contrib.distribute.python import multi_worker_strategy from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib @@ -144,7 +145,7 @@ def _augment_with_special_arguments(test_method): """A wrapped test method that treats some arguments in a special way.""" mode = kwargs.pop("mode", "graph") - distribution = kwargs.pop("distribution", None) + distribution = kwargs.get("distribution", None) required_tpu = kwargs.pop("required_tpu", False) required_gpus = kwargs.pop("required_gpus", None) @@ -153,7 +154,6 @@ def _augment_with_special_arguments(test_method): "Do not use `required_gpus` and `distribution` together.") assert required_tpu is False, ( "Do not use `required_tpu` and `distribution` together.") - kwargs["distribution"] = distribution.strategy required_gpus = distribution.required_gpus required_tpu = distribution.required_tpu @@ -189,9 +189,13 @@ def _augment_with_special_arguments(test_method): if mode == "eager": with ops.Graph().as_default(), context.eager_mode(): + if distribution: + kwargs_to_pass["distribution"] = distribution.strategy test_method(**kwargs_to_pass) elif mode == "graph": with ops.Graph().as_default(), context.graph_mode(): + if distribution: + kwargs_to_pass["distribution"] = distribution.strategy test_method(**kwargs_to_pass) else: raise ValueError( @@ -321,7 +325,9 @@ default_strategy = NamedDistribution( one_device_strategy = NamedDistribution( "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), required_gpus=None) -tpu_strategy = NamedDistribution("TPU", tpu_lib.TPUStrategy, required_tpu=True) +tpu_strategy = NamedDistribution( + "TPU", lambda: tpu_lib.TPUStrategy(TPUClusterResolver("")), + required_tpu=True) # Note that we disable prefetching for testing since prefetching makes # the input non-deterministic. mirrored_strategy_with_gpu_and_cpu = NamedDistribution( diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py index 0261ce43fa854d3b2ee38df19b8a8938cac3c8f3..9b5534393edf32145dd5328407d365ac0676879b 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py @@ -28,17 +28,37 @@ from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import device_util +def check_destinations(destinations): + """Checks whether `destinations` is not None and not empty. + + Args: + destinations: a DistributedValues, Variable, string or a list of strings. + + Returns: + Boolean indicating whether `destinations` is not None and not empty. + """ + # Calling bool() on a ResourceVariable is not allowed. + if isinstance(destinations, resource_variable_ops.ResourceVariable): + return bool(destinations.device) + return bool(destinations) + + def validate_destinations(destinations): - if not isinstance(destinations, - (value_lib.DistributedValues, six.string_types, list)): + if not isinstance( + destinations, + (value_lib.DistributedValues, resource_variable_ops.ResourceVariable, + six.string_types, list)): raise ValueError("destinations must be one of a `DistributedValues` object," - " a device string, a list of device strings or None") + " a tf.Variable object, a device string, a list of device " + "strings or None") - if not destinations: + if not check_destinations(destinations): raise ValueError("destinations can not be empty") @@ -58,6 +78,8 @@ def _validate_value_destination_pairs(value_destination_pairs): def get_devices_from(destinations): if isinstance(destinations, value_lib.DistributedValues): return list(destinations.devices) + elif isinstance(destinations, resource_variable_ops.ResourceVariable): + return [destinations.device] elif isinstance(destinations, six.string_types): return [device_util.resolve(destinations)] else: @@ -88,7 +110,7 @@ def _simple_broadcast(value, destinations): def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn, - method_string): + aggregation): # pylint: disable=g-missing-docstring all_values = [] count = 0 @@ -112,11 +134,12 @@ def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn, with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices( all_values, accumulation_fn) - if method_string == "mean": + if aggregation == vs.VariableAggregation.MEAN: reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices( reduced, count) - elif method_string != "sum": - raise ValueError("`method_string` must be 'sum' or 'mean'") + elif aggregation != vs.VariableAggregation.SUM: + raise ValueError("`aggregation` must be VariableAggregation.SUM " + "or VariableAggregation.MEAN.") return reduced @@ -126,14 +149,15 @@ class CrossTowerOps(object): def __init__(self): pass - def reduce(self, method_string, per_device_value, destinations=None): + def reduce(self, aggregation, per_device_value, destinations=None): """Reduce `per_device_value` to `destinations`. - It runs the reduction operation defined by `method_string` and put the + It runs the reduction operation defined by `aggregation` and put the result on `destinations`. Args: - method_string: either 'sum' or 'mean' specifying the reduction method. + aggregation: Indicates how a variable will be aggregated. Accepted values + are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}. per_device_value: a PerDevice object. destinations: the reduction destinations. @@ -147,16 +171,17 @@ class CrossTowerOps(object): raise ValueError("`per_device_value` must be a `PerDevice` object.") if destinations is not None: validate_destinations(destinations) - return self._reduce(method_string, per_device_value, destinations) + return self._reduce(aggregation, per_device_value, destinations) - def batch_reduce(self, method_string, value_destination_pairs): + def batch_reduce(self, aggregation, value_destination_pairs): """Reduce PerDevice objects in a batch. Reduce each first element in `value_destination_pairs` to each second element which indicates the destinations. Args: - method_string: either 'sum' or 'mean' specifying the reduction method. + aggregation: Indicates how a variable will be aggregated. Accepted values + are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}. value_destination_pairs: a list or a tuple of tuples of PerDevice objects and destinations. If a destination is None, then the destinations are set to match the devices of the input PerDevice object. @@ -175,7 +200,7 @@ class CrossTowerOps(object): if d is not None: validate_destinations(d) - return self._batch_reduce(method_string, value_destination_pairs) + return self._batch_reduce(aggregation, value_destination_pairs) def broadcast(self, tensor, destinations): """Broadcast the `tensor` to destinations. @@ -190,11 +215,11 @@ class CrossTowerOps(object): validate_destinations(destinations) return self._broadcast(tensor, destinations) - def _reduce(self, method_string, per_device_value, destinations): + def _reduce(self, aggregation, per_device_value, destinations): raise NotImplementedError( "_reduce method must be implemented in descendants.") - def _batch_reduce(self, method_string, value_destination_pairs): + def _batch_reduce(self, aggregation, value_destination_pairs): raise NotImplementedError( "_batch_reduce method must be implemented in descendants.") @@ -220,16 +245,21 @@ class ReductionToOneDeviceCrossTowerOps(CrossTowerOps): self.accumulation_fn = accumulation_fn super(ReductionToOneDeviceCrossTowerOps, self).__init__() - def _reduce(self, method_string, per_device_value, destinations): - devices = get_devices_from(destinations or per_device_value) + def _reduce(self, aggregation, per_device_value, destinations): + if check_destinations(destinations): + devices = get_devices_from(destinations) + else: + devices = get_devices_from(per_device_value) reduce_to_device = self.reduce_to_device or devices[0] reduced = _simple_reduce(per_device_value, reduce_to_device, - self.accumulation_fn, method_string) + self.accumulation_fn, aggregation) return self.broadcast(reduced, devices) - def _batch_reduce(self, method_string, value_destination_pairs): - return [self._reduce(method_string, t, destinations=v) - for t, v in value_destination_pairs] + def _batch_reduce(self, aggregation, value_destination_pairs): + return [ + self._reduce(aggregation, t, destinations=v) + for t, v in value_destination_pairs + ] def _group_value_by_device(per_device_values): @@ -237,9 +267,9 @@ def _group_value_by_device(per_device_values): This grouping is needed to call the all-reduce library because it expects a list of the following form: - [(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ... - (grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ... - (grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ... + [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...], + [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...], + [(grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ...], ... ] @@ -260,18 +290,24 @@ def _group_value_by_device(per_device_values): return grouped -def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string): +def _ungroup_and_make_mirrored(grouped_reduced, + destinations, + aggregation, + num_between_graph_workers=1): """Ungroup results from all-reduce and make Mirrored objects. Each all-reduce result will be divided by the number of destinations before - Mirrored objects are created if method_string is "mean". + Mirrored objects are created if aggregation is "mean". Args: grouped_reduced: a list of lists, each sublist has components for each device, paired with a None. It is the result from cross_tower_utils.aggregate_gradients_using*. destinations: a list of device strings for returned Mirrored objects. - method_string: "mean" or "sum". + aggregation: Indicates how a variable will be aggregated. Accepted values + are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}. + num_between_graph_workers: number of workers in the between-graph + replication. Returns: a list of Mirrored objects. @@ -279,8 +315,9 @@ def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string): index = [{} for _ in range(len(grouped_reduced[0]))] for d, per_device_reduced in enumerate(grouped_reduced): for i, (v, _) in enumerate(per_device_reduced): - if method_string == "mean": - index[i][destinations[d]] = v / len(destinations) + if aggregation == vs.VariableAggregation.MEAN: + index[i][destinations[d]] = v / ( + len(destinations) * num_between_graph_workers) else: index[i][destinations[d]] = v return [value_lib.Mirrored(v) for v in index] @@ -488,32 +525,35 @@ class AllReduceCrossTowerOps(CrossTowerOps): self._agg_small_grads_max_group = agg_small_grads_max_group super(AllReduceCrossTowerOps, self).__init__() - def _reduce(self, method_string, per_device_value, destinations): + def _reduce(self, aggregation, per_device_value, destinations): contains_indexed_slices = cross_tower_utils.contains_indexed_slices( per_device_value) if ((destinations is None or _devices_match(per_device_value, destinations)) and not context.executing_eagerly() and not contains_indexed_slices): - return self._batch_all_reduce(method_string, [per_device_value])[0] + return self._batch_all_reduce(aggregation, [per_device_value])[0] else: if contains_indexed_slices: logging.log_first_n( logging.WARN, "Efficient allreduce is not supported for IndexedSlices.", 10) - devices = get_devices_from(destinations or per_device_value) + if check_destinations(destinations): + devices = get_devices_from(destinations) + else: + devices = get_devices_from(per_device_value) reduce_to_device = devices[0] reduced = _simple_reduce(per_device_value, reduce_to_device, - math_ops.add_n, method_string) + math_ops.add_n, aggregation) return self.broadcast(reduced, devices) - def _batch_reduce(self, method_string, value_destination_pairs): + def _batch_reduce(self, aggregation, value_destination_pairs): all_devices_match = _all_devices_match(value_destination_pairs) contains_indexed_slices = cross_tower_utils.contains_indexed_slices( value_destination_pairs) if (all_devices_match and not context.executing_eagerly() and not contains_indexed_slices): - return self._batch_all_reduce(method_string, + return self._batch_all_reduce(aggregation, [v[0] for v in value_destination_pairs]) else: if not all_devices_match: @@ -521,18 +561,18 @@ class AllReduceCrossTowerOps(CrossTowerOps): "destinations are different.") return [ - self._reduce(method_string, t, destinations=v) + self._reduce(aggregation, t, destinations=v) for t, v in value_destination_pairs ] - def _batch_all_reduce(self, method_string, per_device_values): + def _batch_all_reduce(self, aggregation, per_device_values): """All reduce algorithm in a batch.""" - logging.info( - "batch_all_reduce invoked for batches size = %d with " + logging.log_first_n( + logging.INFO, "batch_all_reduce invoked for batches size = %d with " "algorithm = %s, num_packs = %d, agg_small_grads_max_bytes = %d and " - "agg_small_grads_max_group = %d", len(per_device_values), - self._all_reduce_alg, self._num_packs, self._agg_small_grads_max_bytes, - self._agg_small_grads_max_group) + "agg_small_grads_max_group = %d" % + (len(per_device_values), self._all_reduce_alg, self._num_packs, + self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10) destinations = per_device_values[0].devices grouped = _group_value_by_device(per_device_values) @@ -556,7 +596,7 @@ class AllReduceCrossTowerOps(CrossTowerOps): reduced = _unpack_tensors(reduced, tensor_packer) return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices, - method_string) + aggregation) AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple", @@ -635,14 +675,15 @@ class MultiWorkerAllReduce(AllReduceCrossTowerOps): validate_and_complete_spec(spec) for spec in all_reduce_spec ] - def _batch_all_reduce(self, method_string, per_device_values): + def _batch_all_reduce(self, aggregation, per_device_values): """All reduce algorithm in a batch.""" - logging.info( + logging.log_first_n( + logging.INFO, "distributed batch_all_reduce invoked for batches size = %d with " "allreduce_spec = %r, num_packs = %d, agg_small_grads_max_bytes = %d " - "and agg_small_grads_max_group = %d", len(per_device_values), - self._all_reduce_spec, self._num_packs, self._agg_small_grads_max_bytes, - self._agg_small_grads_max_group) + "and agg_small_grads_max_group = %d" % + (len(per_device_values), self._all_reduce_spec, self._num_packs, + self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10) destinations = sorted(per_device_values[0].devices) device_grads = _group_value_by_device(per_device_values) @@ -682,7 +723,103 @@ class MultiWorkerAllReduce(AllReduceCrossTowerOps): assert not remaining_grads return _ungroup_and_make_mirrored(aggregated_grads, destinations, - method_string) + aggregation) + + +# TODO(yuefengz): support in-graph collective all-reduce. +class CollectiveAllReduce(CrossTowerOps): + """All-reduce cross tower ops using collective ops. + + In the between-graph replicated training, it will still do all-reduces across + all workers and then put results on the right destinations. + """ + + def __init__(self, + num_workers=1, + num_gpus_per_worker=0, + all_reduce_merge_scope=1, + collective_keys=None): + """Initializes the object. + + Args: + num_workers: number of workers in the between-graph replicated training. + num_gpus_per_worker: number of GPUs per worker. + all_reduce_merge_scope: size of groups into which to partition consecutive + gradients grouped under a common 'allreduce' name scope. This is useful + for some optimization of collective ops. + collective_keys: an optional CollectiveKey object. + """ + self._num_workers = num_workers + self._num_gpus_per_worker = num_gpus_per_worker + self._all_reduce_merge_scope = all_reduce_merge_scope + self._collective_keys = collective_keys or cross_tower_utils.CollectiveKeys( + ) + super(CollectiveAllReduce, self).__init__() + + # TODO(yuefengz, tucker): is index slices supported by collective ops? + def _reduce(self, aggregation, per_device_value, destinations): + all_reduced = self._batch_all_reduce(aggregation, [per_device_value])[0] + if destinations is None or _devices_match(per_device_value, destinations): + return all_reduced + else: + index = {} + for d in get_devices_from(destinations): + # pylint: disable=protected-access + if d in all_reduced._index: + index[d] = all_reduced._index[d] + else: + with ops.device(d): + index[d] = array_ops.identity(list(all_reduced._index.values())[0]) + return value_lib.Mirrored(index) + + def _batch_reduce(self, aggregation, value_destination_pairs): + return [ + self._reduce(aggregation, t, destinations=v) + for t, v in value_destination_pairs + ] + + def _batch_all_reduce(self, aggregation, per_device_values): + """All-reduce across all workers in a batch.""" + if context.executing_eagerly(): + raise ValueError("Eager mode with collective ops is not supported yet.") + + logging.log_first_n( + logging.INFO, "Collective All-reduce invoked with batches size = %d, " + "num_workers = %d" % (len(per_device_values), self._num_workers), 10) + + grouped_by_tower = _group_value_by_device(per_device_values) + + grouped_by_var = list(zip(*grouped_by_tower)) + # grouped_by_var is grouped by variables and takes the following format: + # [((grad0_gpu0, v0_gpu0), (grad0_gpu1, v0_gpu1), (grad0_gpu2, v0_gpu2) ..), + # ((grad1_gpu0, v1_gpu0), (grad1_gpu1, v1_gpu1), (grad1_gpu0, v1_gpu2) ..), + # ((grad2_gpu0, v2_gpu0), (grad2_gpu1, v2_gpu1), (grad2_gpu0, v2_gpu2) ..), + # ... + # ] + chunked_gv = [ + grouped_by_var[x:x + self._all_reduce_merge_scope] + for x in range(0, len(grouped_by_var), self._all_reduce_merge_scope) + ] + + reduced_gv_list = [] + for chunk in chunked_gv: + with ops.name_scope("allreduce"): + for grad_and_vars in chunk: + scaled_grads = [g for g, _ in grad_and_vars] + collective_reduced = cross_tower_utils.build_collective_reduce( + scaled_grads, self._num_workers, self._collective_keys, "Add", + "Id") + result = [] + for (_, v), g in zip(grad_and_vars, collective_reduced): + result.append([g, v]) + reduced_gv_list.append(result) + + new_tower_grads = [list(x) for x in zip(*reduced_gv_list)] + return _ungroup_and_make_mirrored( + new_tower_grads, + per_device_values[0].devices, + aggregation, + num_between_graph_workers=self._num_workers) _dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py index c540ea0d232e31af51ef4c2a1530250669e49495..aec53b01d7a089fec08eec6ea43373a2cd8267d6 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py @@ -21,17 +21,22 @@ from __future__ import print_function import itertools from absl.testing import parameterized +import numpy as np from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import cross_tower_utils from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import context from tensorflow.python.eager import test +from tensorflow.python.estimator import run_config from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import device_util @@ -129,32 +134,45 @@ class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase): # test reduce() for destinations in all_destinations: self._assert_values_equal( - cross_tower_ops.reduce("mean", per_device, destinations=destinations), + cross_tower_ops.reduce( + vs.VariableAggregation.MEAN, + per_device, + destinations=destinations), _fake_mirrored(mean, destinations or per_device)) self._assert_values_equal( cross_tower_ops.reduce( - "mean", per_device_2, destinations=destinations), + vs.VariableAggregation.MEAN, + per_device_2, + destinations=destinations), _fake_mirrored(mean_2, destinations or per_device)) self._assert_values_equal( - cross_tower_ops.reduce("sum", per_device, destinations=destinations), + cross_tower_ops.reduce( + vs.VariableAggregation.SUM, per_device, + destinations=destinations), _fake_mirrored(mean * len(devices), destinations or per_device)) self._assert_values_equal( cross_tower_ops.reduce( - "sum", per_device_2, destinations=destinations), + vs.VariableAggregation.SUM, + per_device_2, + destinations=destinations), _fake_mirrored(mean_2 * len(devices), destinations or per_device)) # test batch_reduce() for d1, d2 in itertools.product(all_destinations, all_destinations): self._assert_values_equal( - cross_tower_ops.batch_reduce( - "mean", [(per_device, d1), (per_device_2, d2)]), - [_fake_mirrored(mean, d1 or per_device), - _fake_mirrored(mean_2, d2 or per_device_2)]) + cross_tower_ops.batch_reduce(vs.VariableAggregation.MEAN, + [(per_device, d1), (per_device_2, d2)]), + [ + _fake_mirrored(mean, d1 or per_device), + _fake_mirrored(mean_2, d2 or per_device_2) + ]) self._assert_values_equal( - cross_tower_ops.batch_reduce( - "sum", [(per_device, d1), (per_device_2, d2)]), - [_fake_mirrored(mean * len(devices), d1 or per_device), - _fake_mirrored(mean_2 * len(devices), d2 or per_device_2)]) + cross_tower_ops.batch_reduce(vs.VariableAggregation.SUM, + [(per_device, d1), (per_device_2, d2)]), + [ + _fake_mirrored(mean * len(devices), d1 or per_device), + _fake_mirrored(mean_2 * len(devices), d2 or per_device_2) + ]) # test broadcast() for destinations in all_destinations: @@ -255,8 +273,8 @@ class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase): t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0]) t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1]) per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) - result = cross_tower_ops_lib._simple_reduce(per_device, devices[0], - math_ops.add_n, "sum") + result = cross_tower_ops_lib._simple_reduce( + per_device, devices[0], math_ops.add_n, vs.VariableAggregation.SUM) # Test that the result is semantically equal to both the concatenated # IndexedSlices with and without duplicate indices. @@ -267,21 +285,22 @@ class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase): self._assert_indexed_slices_equal(total_with_dups, result) self._assert_indexed_slices_equal(total_without_dups, result) - @combinations.generate(combinations.combine( - cross_tower_ops_instance=[ - combinations.NamedObject( - "ReductionToOneDeviceCrossTowerOps", - cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()), - combinations.NamedObject( - "AllReduceCrossTowerOps", - cross_tower_ops_lib.AllReduceCrossTowerOps()) - ], - method_string=["sum", "mean"], - batch_reduce=[True, False], - mode=["graph", "eager"], - required_gpus=1)) - def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, - method_string, batch_reduce): + @combinations.generate( + combinations.combine( + cross_tower_ops_instance=[ + combinations.NamedObject( + "ReductionToOneDeviceCrossTowerOps", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()), + combinations.NamedObject( + "AllReduceCrossTowerOps", + cross_tower_ops_lib.AllReduceCrossTowerOps()) + ], + aggregation=[vs.VariableAggregation.SUM, vs.VariableAggregation.MEAN], + batch_reduce=[True, False], + mode=["graph", "eager"], + required_gpus=1)) + def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, aggregation, + batch_reduce): devices = ["/cpu:0", "/gpu:0"] dense_shape = [5, 2] t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0]) @@ -290,20 +309,19 @@ class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase): per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) if batch_reduce: - result = cross_tower_ops_instance.batch_reduce(method_string, + result = cross_tower_ops_instance.batch_reduce(aggregation, [(per_device, devices)]) else: - result = cross_tower_ops_instance.reduce(method_string, per_device, - devices) + result = cross_tower_ops_instance.reduce(aggregation, per_device, devices) total_indices_with_dups = [1, 1, 3] total_indices_without_dups = [1, 3] - if method_string == "sum": + if aggregation == vs.VariableAggregation.SUM: total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]] total_values_without_dups = [[4., 6.], [5., 6.]] else: - assert method_string == "mean" + assert aggregation == vs.VariableAggregation.MEAN total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]] total_values_without_dups = [[2., 3.], [2.5, 3.]] @@ -362,5 +380,166 @@ class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase, self._testReductionAndBroadcast(cross_tower_ops, distribution) +class MultiWorkerCollectiveAllReduceTest( + multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase): + + collective_key_base = 100000 + + @classmethod + def setUpClass(cls): + """Create a local cluster with 2 workers.""" + cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=0) + cls._cluster_spec = { + run_config.TaskType.WORKER: [ + "fake_worker_0", "fake_worker_1", "fake_worker_2" + ] + } + + def setUp(self): + super(MultiWorkerCollectiveAllReduceTest, self).setUp() + # Reusing keys are not supported well. So we have to give a different + # collective key base for different tests. + MultiWorkerCollectiveAllReduceTest.collective_key_base += 100000 + + def _get_test_objects(self, task_type, task_id, num_gpus=0, local_mode=False): + collective_keys = cross_tower_utils.CollectiveKeys( + group_key_start=10 * num_gpus + + MultiWorkerCollectiveAllReduceTest.collective_key_base, + instance_key_start=num_gpus * 100 + + MultiWorkerCollectiveAllReduceTest.collective_key_base, + instance_key_with_id_start=num_gpus * 10000 + + MultiWorkerCollectiveAllReduceTest.collective_key_base) + if local_mode: + collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce( + 1, num_gpus, collective_keys=collective_keys) + if num_gpus: + devices = ["/device:GPU:%d" % i for i in range(num_gpus)] + else: + devices = ["/device:CPU:0"] + return collective_all_reduce_ops, devices, "local" + else: + collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce( + 3, num_gpus, collective_keys=collective_keys) + if num_gpus: + devices = [ + "/job:%s/task:%d/device:GPU:%d" % (task_type, task_id, i) + for i in range(num_gpus) + ] + else: + devices = ["/job:%s/task:%d" % (task_type, task_id)] + return collective_all_reduce_ops, devices, self._workers[task_id].target + + def _assert_values_equal(self, left, right, sess): + if isinstance(left, list): + for l, r in zip(left, right): + self._assert_values_equal(l, r, sess) + else: + self.assertEqual(type(left), type(right)) + self.assertEqual(set(left.devices), set(right.devices)) + + run_options = config_pb2.RunOptions() + run_options.experimental.collective_graph_key = 6 + + left_values = np.array( + sess.run(list(left._index.values()), options=run_options)).flatten() + right_values = np.array(list(right._index.values())).flatten() + self.assertEqual(len(left_values), len(right_values)) + for l, r in zip(left_values, right_values): + self.assertEqual(l, r) + + def _test_reduction(self, task_type, task_id, num_gpus, local_mode=False): + collective_all_reduce, devices, master_target = self._get_test_objects( + task_type, task_id, num_gpus, local_mode=local_mode) + if local_mode: + num_workers = 1 + worker_device = None + else: + num_workers = len(self._workers) + worker_device = "/job:%s/task:%d" % (task_type, task_id) + with ops.Graph().as_default(), \ + ops.device(worker_device), \ + self.test_session(target=master_target) as sess: + # Collective ops doesn't support scalar tensors, so we have to construct + # 1-d tensors. + values = [constant_op.constant([float(d)]) for d in range(len(devices))] + per_device = _make_per_device(values, devices) + mean = np.array([(len(devices) - 1.) / 2.]) + + values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))] + per_device_2 = _make_per_device(values_2, devices) + mean_2 = np.array([mean[0] + 1.]) + + destination_mirrored = _fake_mirrored(1., devices) + destination_different = _fake_mirrored(1., _cpu_device) + destination_str = _cpu_device + destination_list = devices + + all_destinations = [ + None, destination_mirrored, destination_different, destination_str, + destination_list + ] + + # test reduce() + for destinations in all_destinations: + self._assert_values_equal( + collective_all_reduce.reduce( + vs.VariableAggregation.MEAN, + per_device, + destinations=destinations), + _fake_mirrored(mean, destinations or per_device), sess) + self._assert_values_equal( + collective_all_reduce.reduce( + vs.VariableAggregation.MEAN, + per_device_2, + destinations=destinations), + _fake_mirrored(mean_2, destinations or per_device), sess) + self._assert_values_equal( + collective_all_reduce.reduce( + vs.VariableAggregation.SUM, + per_device, + destinations=destinations), + _fake_mirrored(mean * len(devices) * num_workers, destinations or + per_device), sess) + self._assert_values_equal( + collective_all_reduce.reduce( + vs.VariableAggregation.SUM, + per_device_2, + destinations=destinations), + _fake_mirrored(mean_2 * len(devices) * num_workers, destinations or + per_device), sess) + + # test batch_reduce() + for d1, d2 in itertools.product(all_destinations, all_destinations): + self._assert_values_equal( + collective_all_reduce.batch_reduce(vs.VariableAggregation.MEAN, + [(per_device, d1), + (per_device_2, d2)]), + [ + _fake_mirrored(mean, d1 or per_device), + _fake_mirrored(mean_2, d2 or per_device_2) + ], sess) + self._assert_values_equal( + collective_all_reduce.batch_reduce(vs.VariableAggregation.SUM, + [(per_device, d1), + (per_device_2, d2)]), + [ + _fake_mirrored(mean * len(devices) * num_workers, d1 or + per_device), + _fake_mirrored(mean_2 * len(devices) * num_workers, d2 or + per_device_2) + ], sess) + + return True + + @combinations.generate( + combinations.combine(mode=["graph"], num_gpus=[0, 1, 2])) + def testReductionDistributed(self, num_gpus): + if context.num_gpus() < num_gpus: + return + self._run_between_graph_clients(self._test_reduction, self._cluster_spec, + num_gpus) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py index 2bb088e704c584598b863b1b836166af2a5bb12c..24cb08fb48f832572da5ae2113e6c224557c6a81 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils.py +++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py @@ -19,13 +19,16 @@ from __future__ import division from __future__ import print_function import collections as pycoll +import threading from tensorflow.contrib import nccl from tensorflow.contrib.all_reduce.python import all_reduce from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import collective_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops @@ -218,6 +221,146 @@ def split_grads_by_size(threshold_size, device_grads): return small_grads, large_grads +# threading.Lock() cannot be pickled and therefore cannot be a field of +# CollectiveKeys. +_lock = threading.Lock() + + +# TODO(yuefengz): use random key starts to avoid reusing keys? +class CollectiveKeys(object): + """Class that manages collective keys. + + We need to manage three different keys for collective: + + *Group key*: an integer key to identify the set of cooperative devices. + Collective ops work under the same set of devices must using the same group + key. + + *Instance key*: an integer key to identify the set of same counterpart of + tensors on different devices in a device group that need to be all-reduced. + + "Graph key": an integer key that is unique key graph. This is used to support + multiple graphs per client session. It must be non-zero and set in the + `config` argument of each call to `session.run`. + """ + + def __init__(self, + group_key_start=1, + instance_key_start=100, + instance_key_with_id_start=10000): + """Initializes the object. + + Args: + group_key_start: the starting integer of group key. + instance_key_start: the starting integer of instance key. + instance_key_with_id_start: the starting integer of instance key that is + recorded with an id. + """ + self._group_key = group_key_start + self._group_key_table = dict() + + # For instance keys with ids + self._instance_key_id_to_key_table = dict() + self._instance_key_with_id_counter = instance_key_with_id_start + + # For instance keys without ids + self._instance_key_start = instance_key_start + + self._thread_local = threading.local() + + def _get_thread_local_object(self): + # We make instance key without key ids thread local so that it will work + # with MirroredStrategy and distribute coordinator. + if not hasattr(self._thread_local, 'instance_key'): + self._thread_local.instance_key = self._instance_key_start + return self._thread_local + + def get_group_key(self, devices): + """Returns a group key for the set of devices. + + Args: + devices: list of strings naming devices in a collective group. + + Returns: + int key uniquely identifying the set of device names. + """ + parsed = [pydev.DeviceSpec.from_string(d) for d in devices] + # In the between-graph replicated training, different workers need to get + # the same device key. So we remove the task_type and task_id from the + # devices. + # TODO(yuefengz): in the in-graph replicated training, we need to include + # task_type and task_id. + names = sorted(['%s:%d' % (d.device_type, d.device_index) for d in parsed]) + key_id = ','.join(names) + with _lock: + if key_id not in self._group_key_table: + new_key = self._group_key + self._group_key += 1 + self._group_key_table[key_id] = new_key + return self._group_key_table[key_id] + + def get_instance_key(self, key_id=None): + """Returns a new instance key for use in defining a collective op. + + Args: + key_id: optional string. If set, key will be recorded and the same key + will be returned when the same key_id is provided. If not, an increasing + instance key will be returned. + """ + if key_id: + with _lock: + if key_id not in self._instance_key_id_to_key_table: + self._instance_key_with_id_counter += 1 + self._instance_key_id_to_key_table[key_id] = ( + self._instance_key_with_id_counter) + return self._instance_key_id_to_key_table[key_id] + else: + v = self._get_thread_local_object().instance_key + self._get_thread_local_object().instance_key += 1 + return v + + +def build_collective_reduce(input_tensors, + num_workers, + collective_keys, + reduction_op='Add', + unary_op='Id'): + """Build a subgraph that does one full all-reduce, using the collective Op. + + Args: + input_tensors: tensors within a single worker graph that are to be reduced + together; must be one per device. + num_workers: total number of workers with identical independent graphs that + will be doing this same reduction. The reduction will actually include + the corresponding tensors at all these workers. + collective_keys: a CollectiveKeys object. + reduction_op: string naming the reduction op. + unary_op: string naming the unary final op. + + Returns: + An array of final tensors, one per device, computed by the full reduction. + + Raises: + ValueError: There must be at least two tensors over all the workers. + """ + group_size = len(input_tensors) * num_workers + if group_size < 2: + raise ValueError('num_workers * len(input_tensors) must be 2 or greater') + devices = [t.device for t in input_tensors] + num_devices = len(devices) + group_key = collective_keys.get_group_key(devices) + instance_key = collective_keys.get_instance_key() + out_tensors = [] + subdiv_offsets = [0] # TODO(tucker): maybe support non-default subdiv spec + for d in range(num_devices): + with ops.device(devices[d]): + reduce_op = collective_ops.all_reduce( + input_tensors[d], group_size, group_key, instance_key, reduction_op, + unary_op, subdiv_offsets) + out_tensors.append(reduce_op) + return out_tensors + + def sum_grad_and_var_all_reduce(grad_and_vars, num_workers, alg, @@ -253,10 +396,10 @@ def sum_grad_and_var_all_reduce(grad_and_vars, else: raise ValueError('unsupported all_reduce alg: ', alg) - result = [] - for (_, v), g in zip(grad_and_vars, summed_grads): - result.append([g, v]) - return result + result = [] + for (_, v), g in zip(grad_and_vars, summed_grads): + result.append([g, v]) + return result def sum_gradients_all_reduce(dev_prefixes, tower_grads, num_workers, alg, diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py index 34410a6470185ac2821bc6a59de9230ff478aeb6..a0bb144b7c7e73b051e41ab93086cbb5b6a852cc 100644 --- a/tensorflow/contrib/distribute/python/estimator_integration_test.py +++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py @@ -96,7 +96,8 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, # TODO(isaprykin): Work around the colocate_with error. dnn_optimizer=adagrad.AdagradOptimizer(0.001), linear_optimizer=adagrad.AdagradOptimizer(0.001), - config=run_config.RunConfig(train_distribute=distribution)) + config=run_config.RunConfig( + train_distribute=distribution, eval_distribute=distribution)) num_steps = 10 estimator.train(train_input_fn, steps=num_steps) diff --git a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py index 00c25c7a2482a559c8b94ff3be86c4961dfb439f..44a69ed23a4e00ab81d5b51ae0c14550bd493f14 100644 --- a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py +++ b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py @@ -59,7 +59,8 @@ def build_model_fn_optimizer(): def main(_): distribution = tf.contrib.distribute.MirroredStrategy( ["/device:GPU:0", "/device:GPU:1"]) - config = tf.estimator.RunConfig(train_distribute=distribution) + config = tf.estimator.RunConfig(train_distribute=distribution, + eval_distribute=distribution) def input_fn(): features = tf.data.Dataset.from_tensors([[1.]]).repeat(10) @@ -70,7 +71,7 @@ def main(_): model_fn=build_model_fn_optimizer(), config=config) estimator.train(input_fn=input_fn, steps=10) - eval_result = estimator.evaluate(input_fn=input_fn) + eval_result = estimator.evaluate(input_fn=input_fn, steps=10) print("Eval result: {}".format(eval_result)) def predict_input_fn(): diff --git a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py index 2b05884b9b93470ef9a764cbedbc91bd3912c611..518ec9c4232465c3ecd0e4161f707dac499430c7 100644 --- a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py +++ b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py @@ -57,7 +57,8 @@ def main(args): # tf.Estimator that utilizes the DistributionStrategy. strategy = tf.contrib.distribute.MirroredStrategy( ['/device:GPU:0', '/device:GPU:1']) - config = tf.estimator.RunConfig(train_distribute=strategy) + config = tf.estimator.RunConfig( + train_distribute=strategy, eval_distribute=strategy) keras_estimator = tf.keras.estimator.model_to_estimator( keras_model=model, config=config, model_dir=model_dir) diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 75ecd90dcffa7a786b78238ef453c4c8e4346afa..ec0ca6879cffb9214adec15058cfb7293d347b25 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -12,33 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for Keras Sequential and Functional models.""" +"""Tests for tf.keras models using DistributionStrategy.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import os - import numpy as np from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import values from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import keras as keras_lib from tensorflow.python.estimator import run_config as run_config_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.engine import distributed_training_utils from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import gradient_descent from tensorflow.python.training import rmsprop + _RANDOM_SEED = 1337 _TRAIN_SIZE = 200 _INPUT_SIZE = (10,) _NUM_CLASS = 2 +# TODO(anjalisridhar): Add a decorator that will allow us to run these tests as +# part of the tf.keras unit tests suite. def simple_sequential_model(): model = keras.models.Sequential() model.add(keras.layers.Dense(16, activation='relu', input_shape=_INPUT_SIZE)) @@ -84,7 +91,7 @@ def get_ds_test_input_fn(): return dataset -class TestKerasDistributionStrategy(test_util.TensorFlowTestCase): +class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): def setUp(self): self._base_dir = os.path.join(self.get_temp_dir(), @@ -107,7 +114,8 @@ class TestKerasDistributionStrategy(test_util.TensorFlowTestCase): optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, - train_distribute=dist) + train_distribute=dist, + eval_distribute=dist) with self.test_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=config) @@ -144,5 +152,416 @@ class TestKerasDistributionStrategy(test_util.TensorFlowTestCase): writer_cache.FileWriterCache.clear() gfile.DeleteRecursively(self._config.model_dir) + def test_keras_optimizer_with_distribution_strategy(self): + dist = mirrored_strategy.MirroredStrategy( + devices=['/device:GPU:0', '/device:GPU:1']) + keras_model = simple_sequential_model() + keras_model.compile( + loss='categorical_crossentropy', + optimizer=keras.optimizers.rmsprop(lr=0.01)) + + config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, + model_dir=self._base_dir, + train_distribute=dist) + with self.test_session(): + est_keras = keras_lib.model_to_estimator(keras_model=keras_model, + config=config) + with self.assertRaisesRegexp(ValueError, + 'Only TensorFlow native optimizers are ' + 'supported with DistributionStrategy.'): + est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16) + + writer_cache.FileWriterCache.clear() + gfile.DeleteRecursively(self._config.model_dir) + + +class TestWithDistributionStrategy(test.TestCase): + + def test_validating_dataset_input_tensors_with_shape_mismatch(self): + with self.test_session(): + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + a = constant_op.constant([1, 2], shape=(1, 2)) + b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) + x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) + y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) + with strategy.scope(): + # Removed device and input tensor shape details from the error message + # since the order of the device and the corresponding input tensor shape + # is not deterministic over different runs. + with self.assertRaisesRegexp(ValueError, + 'Input tensor shapes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + distributed_training_utils.validate_distributed_dataset_inputs( + strategy, x, y) + + def test_validating_dataset_input_tensors_with_dtype_mismatch(self): + with self.test_session(): + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) + b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) + x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) + y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) + with strategy.scope(): + # Removed device and input tensor dtype details from the error message + # since the order of the device and the corresponding input tensor dtype + # is not deterministic over different runs. + with self.assertRaisesRegexp(ValueError, + 'Input tensor dtypes do not match for ' + 'distributed tensor inputs ' + 'DistributedValues:.+'): + distributed_training_utils.validate_distributed_dataset_inputs( + strategy, x, y) + + def test_calling_model_on_same_dataset(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + # Call fit with validation data + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + validation_data=dataset, validation_steps=2) + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + validation_data=dataset, validation_steps=2) + model.predict(dataset, steps=2) + + def test_fit_eval_and_predict_methods_on_dataset(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) + model.evaluate(dataset, steps=2, verbose=1) + model.predict(dataset, steps=2) + # Test with validation data + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + validation_data=dataset, validation_steps=2) + + def test_raise_error_for_stateful_metrics(self): + + class ExampleStatefulMetric(keras.layers.Layer): + + def __init__(self, name='true_positives', **kwargs): + super(ExampleStatefulMetric, self).__init__(name=name, **kwargs) + self.stateful = True + + def __call__(self, y_true, y_pred): + return y_pred - y_true + + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae', ExampleStatefulMetric()] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + with self.assertRaisesRegexp( + NotImplementedError, 'Stateful metrics are not supported with ' + 'DistributionStrategy.'): + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + def test_unsupported_features(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + # Test with validation split + with self.assertRaisesRegexp( + ValueError, '`validation_split` argument is not ' + 'supported when input `x` is a dataset or a ' + 'dataset iterator.+'): + model.fit(dataset, + epochs=1, steps_per_epoch=2, verbose=0, + validation_split=0.5, validation_steps=2) + + # Test with sample weight. + sample_weight = np.random.random((10,)) + with self.assertRaisesRegexp( + NotImplementedError, 'sample_weight is currently not supported when ' + 'using DistributionStrategy.'): + model.fit( + dataset, + epochs=1, + steps_per_epoch=2, + verbose=0, + sample_weight=sample_weight) + + # Test with not specifying the `steps` argument. + with self.assertRaisesRegexp( + ValueError, 'you should specify the `steps_per_epoch` argument'): + model.fit(dataset, epochs=1, verbose=0) + with self.assertRaisesRegexp(ValueError, + 'you should specify the `steps` argument'): + model.evaluate(dataset, verbose=0) + + with self.assertRaisesRegexp(ValueError, + 'you should specify the `steps` argument'): + model.predict(dataset, verbose=0) + + def test_calling_with_unsupported_predefined_callbacks(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + def schedule(_): + return 0.001 + with self.assertRaisesRegexp(ValueError, + 'LearningRateScheduler callback is not ' + 'supported with DistributionStrategy.'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) + + with self.assertRaisesRegexp(ValueError, + 'ReduceLROnPlateau callback is not ' + 'supported with DistributionStrategy.'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + callbacks=[keras.callbacks.ReduceLROnPlateau()]) + with self.assertRaisesRegexp(ValueError, + 'histogram_freq in the TensorBoard callback ' + 'is not supported when using ' + 'DistributionStrategy.'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + callbacks=[keras.callbacks.TensorBoard(histogram_freq=10)]) + + def test_dataset_input_shape_validation(self): + with self.test_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + + model.compile(optimizer, loss, distribute=strategy) + + # User forgets to batch the dataset + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + + with self.assertRaisesRegexp(ValueError, + 'expected input to have 2 dimensions'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + + # Wrong input shape + inputs = np.zeros((10, 5), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + with self.assertRaisesRegexp(ValueError, + 'expected input to have shape'): + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) + + def test_learning_phase_value(self): + # TODO(anjalisridhar): Modify this test to use Lambdas since we can compare + # meaningful values. Currently we don't pass the learning phase if the + # Lambda layer uses the learning phase. + with self.test_session(): + x = keras.layers.Input(shape=(16,), name='input') + y = keras.layers.Dense(16)(x) + z = keras.layers.Dropout(0.9999)(y) + model = keras.Model(x, z) + + optimizer = gradient_descent.GradientDescentOptimizer(0.005) + loss = 'mse' + metrics = ['acc'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', + '/device:CPU:0']) + + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.random.rand(10, 16) + targets = np.ones((10, 16), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(8) + + hist = model.fit(dataset, epochs=5, steps_per_epoch=20, verbose=1) + self.assertEqual(hist.history['acc'][0], 1) + + evaluate_output = model.evaluate(dataset, steps=20) + self.assertEqual(evaluate_output[1], 0) + + predict_output = model.predict(dataset, steps=1) + self.assertNotEqual(np.mean(predict_output), 0) + + +class LossMaskingWithDistributionStrategyTest(test.TestCase): + + def test_masking(self): + with self.test_session(): + np.random.seed(1337) + x = np.array([[[1], [1]], [[0], [0]]]) + model = keras.models.Sequential() + model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1))) + model.add( + keras.layers.TimeDistributed( + keras.layers.Dense(1, kernel_initializer='one'))) + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + + model.compile(loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01), + distribute=strategy) + y = np.array([[[1], [1]], [[1], [1]]]) + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + hist = model.fit(x=dataset, epochs=1, steps_per_epoch=2) + self.assertEqual(hist.history['loss'][0], 0) + + +class NormalizationLayerWithDistributionStrategyTest(test.TestCase): + + def test_batchnorm_correctness(self): + with self.test_session(): + model = keras.models.Sequential() + norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8) + model.add(norm) + strategy = mirrored_strategy.MirroredStrategy(['/device:CPU:0', + '/device:GPU:0']) + model.compile(loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01), + distribute=strategy) + + # centered on 5.0, variance 10.0 + x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10)) + dataset = dataset_ops.Dataset.from_tensor_slices((x, x)) + dataset = dataset.repeat(100) + dataset = dataset.batch(32) + + model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10) + out = model.predict(dataset, steps=2) + out -= keras.backend.eval(norm.beta) + out /= keras.backend.eval(norm.gamma) + np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1) + np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) + + +class CorrectnessWithDistributionStrategyTest(test.TestCase): + + def test_correctness(self): + with self.test_session(): + keras.backend.set_image_data_format('channels_last') + num_samples = 10000 + x_train = np.random.rand(num_samples, 1) + y_train = 3 * x_train + x_train = x_train.astype('float32') + y_train = y_train.astype('float32') + + model = keras.Sequential() + model.add(keras.layers.Dense(1, input_shape=(1,))) + + # With DistributionStrategy + dataset_with = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + dataset_with = dataset_with.batch(32) + strategy = mirrored_strategy.MirroredStrategy(devices=['/device:CPU:0', + '/device:GPU:0'], + prefetch_on_device=False) + + model.compile(loss=keras.losses.mean_squared_error, + optimizer=gradient_descent.GradientDescentOptimizer(0.5), + distribute=strategy) + model.fit(x=dataset_with, epochs=1, steps_per_epoch=310) + wts_with_ds = model.get_weights() + + x_predict = [[1], [2], [3], [4]] + predict_dataset_with = dataset_ops.Dataset.from_tensor_slices((x_predict, + x_predict)) + predict_dataset_with = predict_dataset_with.batch(2) + predict_with_ds = model.predict(predict_dataset_with, steps=1) + predict_with_ds = np.reshape(predict_with_ds, (4, 1)) + + # Without DistributionStrategy + dataset_without = dataset_ops.Dataset.from_tensor_slices((x_train, + y_train)) + dataset_without = dataset_without.batch(64) + + model.compile(loss=keras.losses.mean_squared_error, + optimizer=gradient_descent.GradientDescentOptimizer(0.5)) + model.fit(x=dataset_without, epochs=1, steps_per_epoch=310) + wts_without_ds = model.get_weights() + + x_predict = [[1], [2], [3], [4]] + predict_dataset_without = dataset_ops.Dataset.from_tensor_slices(( + x_predict, x_predict)) + predict_dataset_without = predict_dataset_without.batch(4) + predict_without_ds = model.predict(predict_dataset_without, steps=1) + + # Verify that the weights are the same within some limits of tolerance. + np.testing.assert_allclose(wts_with_ds[0], wts_without_ds[0], rtol=1e-3) + # Verify that the predicted outputs are the same within some limits of + # tolerance. + np.testing.assert_allclose(predict_with_ds, predict_without_ds, rtol=1e-3) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py index 6c6bf143098c1bba64d47efce1bfface7682683d..2f3d6bdd3f4e4bc7352d7b378ed40b930608ef08 100644 --- a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -19,7 +19,6 @@ from __future__ import print_function from absl.testing import parameterized -from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.distribute.python import combinations from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import test @@ -183,7 +182,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): def _dataset_fn(): dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float) # Want to produce a fixed, known shape, so drop remainder when batching. - dataset = dataset.apply(batching.batch_and_drop_remainder(4)) + dataset = dataset.batch(4, drop_remainder=True) return dataset def _expected_fn(num_batches): diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 8092d8efe04f7a242a7356e1c9650eee16dc9c67..e3376a06368e8ef5efcda5bb69de66b7ec3390e1 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -20,7 +20,6 @@ from __future__ import print_function import contextlib import threading -import six from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib from tensorflow.contrib.distribute.python import shared_variable_creator @@ -28,13 +27,17 @@ from tensorflow.contrib.distribute.python import values from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.eager import tape +from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import coordinator from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.util import nest # TODO(josh11b): Replace asserts in this file with if ...: raise ... @@ -60,6 +63,233 @@ class _RequestedStop(Exception): pass +# Make _call_for_each_tower and _reduce_non_distributed_value not members of +# MirroredStrategy so that they are generally not allowed to use anything +# specific to MirroredStrategy and thus can be shared with other distribution +# strategies. + + +# TODO(yuefengz): maybe create a common class for those who need to call this +# _call_for_each_tower. +def _call_for_each_tower(distribution, fn, *args, **kwargs): + """Run `fn` in separate threads, once per tower/worker device. + + Args: + distribution: the DistributionStrategy object. + fn: function to run (will be run once per device, each in its own thread). + *args: positional arguments for `fn` + **kwargs: keyword arguments for `fn`. + `"run_concurrently"`: Boolean indicating whether executions of `fn` + can be run concurrently (under eager execution only), defaults to + `True`. + + Returns: + Merged return value of `fn` across all towers. + + Raises: + RuntimeError: If fn() calls get_tower_context().merge_call() a different + number of times from the available devices. + """ + run_concurrently = kwargs.pop("run_concurrently", True) + if not context.executing_eagerly(): + # Lots of TF library code isn't thread-safe in graph mode, and + # there is little to be gained by turning on multithreading when + # constructing a graph. + run_concurrently = False + # Needed for per-thread device, etc. contexts in graph mode. + ops.get_default_graph().switch_to_thread_local() + elif run_concurrently is None: + run_concurrently = True + + coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,)) + + shared_variable_store = {} + + # TODO(isaprykin): Create these threads once instead of during every run() + # call. + threads = [] + for index, d in enumerate(distribution.worker_devices): + variable_creator_fn = shared_variable_creator.make_fn( + shared_variable_store, index) + t = MirroredStrategy._MirroredTowerThread( # pylint: disable=protected-access + distribution, coord, d, variable_creator_fn, fn, + *values.select_device(d, args), **values.select_device(d, kwargs)) + threads.append(t) + + for t in threads: + t.start() + + # When `fn` starts `should_run` event is set on _MirroredTowerThread + # (`MTT`) threads. The execution waits until + # `MTT.has_paused` is set, which indicates that either `fn` is + # complete or a `get_tower_context().merge_call()` is called. If `fn` is + # complete, then `MTT.done` is set to True. Otherwise, arguments + # of `get_tower_context().merge_call` from all paused threads are grouped + # and the `merge_fn` is performed. Results of the + # `get_tower_context().merge_call` are then set to `MTT.merge_result`. + # Each such `get_tower_context().merge_call` call returns the + # `MTT.merge_result` for that thread when `MTT.should_run` event + # is reset again. Execution of `fn` resumes. + + try: + with coord.stop_on_exception(): + all_done = False + while not all_done and not coord.should_stop(): + done = [] + if run_concurrently: + for t in threads: + t.should_run.set() + for t in threads: + t.has_paused.wait() + t.has_paused.clear() + if coord.should_stop(): + return None + done.append(t.done) + else: + for t in threads: + t.should_run.set() + t.has_paused.wait() + t.has_paused.clear() + if coord.should_stop(): + return None + done.append(t.done) + if coord.should_stop(): + return None + all_done = all(done) + if not all_done: + if any(done): + raise RuntimeError("Some towers made a different number of " + "tower_context().merge_call() calls.") + # get_tower_context().merge_call() case + merge_args = values.regroup({t.device: t.merge_args for t in threads}) + merge_kwargs = values.regroup( + {t.device: t.merge_kwargs for t in threads}) + # We capture the name_scope of the MTT when we call merge_fn + # to ensure that if we have opened a name scope in the MTT, + # it will be respected when executing the merge function. We only + # capture the name_scope from the first MTT and assume it is + # the same for all other MTTs. + mtt_captured_name_scope = threads[0].captured_name_scope + with ops.name_scope(mtt_captured_name_scope): + merge_result = threads[0].merge_fn(distribution, *merge_args, + **merge_kwargs) + for t in threads: + t.merge_result = values.select_device(t.device, merge_result) + finally: + for t in threads: + t.should_run.set() + coord.join(threads) + + return values.regroup({t.device: t.main_result for t in threads}) + + +def _reduce_non_distributed_value(distribution, aggregation, value, + destinations): + """Reduce a non-DistributedValue `value` to `destinations`.""" + if isinstance(value, values.DistributedValues): + raise ValueError("You are passing a `DistributedValue` to " + "`_reduce_non_distributed_value`, which is not allowed.") + + # If the same value is present on all towers then the PerDevice value will + # be a single value. We also handle the case when `value` is a single value + # and equal to 0. + if value == 0: + return 0 + # If the aggregation type is MEAN, then this essentially means that the same + # value should be on all destinations. + if aggregation == variable_scope.VariableAggregation.MEAN: + return distribution.broadcast(value, destinations) + + cross_tower_ops_lib.validate_destinations(destinations) + # We do not support an aggregation type of SUM if the value is the same across + # all towers. We call this as part of assign functions for MirroredVariables + # and summing up identical values across towers is not clearly defined. + if (len(distribution.worker_devices) != 1 or + not cross_tower_ops_lib.check_destinations(destinations)): + raise ValueError("A non-DistributedValues value cannot be reduced with the " + "given aggregation.") + # TODO(anjalisridhar): Moves these methods to a device utility file? + devices = cross_tower_ops_lib.get_devices_from(destinations) + if len(devices) == 1: + with ops.device(devices[0]): + return array_ops.identity(value) + else: + value_updates = {} + for d in devices: + with ops.device(d): + value_updates[d] = array_ops.identity(value) + return values.Mirrored(value_updates) + + +def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): # pylint: disable=g-missing-docstring + # Figure out what collections this variable should be added to. + # We'll add the MirroredVariable to those collections instead. + collections = kwargs.pop("collections", None) + if collections is None: + collections = [ops.GraphKeys.GLOBAL_VARIABLES] + kwargs["collections"] = [] + + # Get synchronization value + synchronization = kwargs.get("synchronization", + variable_scope.VariableSynchronization.ON_WRITE) + if synchronization == variable_scope.VariableSynchronization.NONE: + raise ValueError("`NONE` variable synchronization mode is not " + "supported with `Mirrored` distribution strategy. Please" + " change the `synchronization` for variable: " + + kwargs["name"]) + elif synchronization == variable_scope.VariableSynchronization.ON_READ: + # Variables that are to be synced on read are tower local. + is_tower_local = True + kwargs["trainable"] = False + elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or + synchronization == variable_scope.VariableSynchronization.AUTO): + # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`. + is_tower_local = False + else: + raise ValueError("Invalid variable synchronization mode: " + + synchronization + " for variable: " + kwargs["name"]) + + # Get aggregation value + aggregation = kwargs.pop("aggregation", + variable_scope.VariableAggregation.NONE) + if aggregation not in [ + variable_scope.VariableAggregation.NONE, + variable_scope.VariableAggregation.SUM, + variable_scope.VariableAggregation.MEAN + ]: + raise ValueError("Invalid variable aggregation mode: " + aggregation + + " for variable: " + kwargs["name"]) + + # Ignore user-specified caching device, not needed for mirrored variables. + kwargs.pop("caching_device", None) + + # TODO(josh11b,apassos): It would be better if variable initialization + # was never recorded on the tape instead of having to do this manually + # here. + with tape.stop_recording(): + index = real_mirrored_creator(devices, *args, **kwargs) + + if is_tower_local: + result = values.TowerLocalVariable(index, index[devices[0]], aggregation) + else: + result = values.MirroredVariable(index, index[devices[0]], aggregation) + + if not context.executing_eagerly(): + g = ops.get_default_graph() + # If "trainable" is True, next_creator() will add the member variables + # to the TRAINABLE_VARIABLES collection, so we manually remove + # them and replace with the MirroredVariable. We can't set + # "trainable" to False for next_creator() since that causes functions + # like implicit_gradients to skip those variables. + if kwargs.get("trainable", True): + collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) + l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) + for v in index.values(): + l.remove(v) + g.add_to_collections(collections, result) + return result + + class MirroredStrategy(distribute_lib.DistributionStrategy): """Mirrors vars to distribute across multiple devices on a single machine. @@ -94,27 +324,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def _create_variable(self, next_creator, *args, **kwargs): """Create a mirrored variable. See `DistributionStrategy.scope`.""" - # Figure out what collections this variable should be added to. - # We'll add the MirroredVariable to those collections instead. - collections = kwargs.pop("collections", None) - if collections is None: - collections = [ops.GraphKeys.GLOBAL_VARIABLES] - kwargs["collections"] = [] - colocate_with = kwargs.pop("colocate_with", None) devices = self._get_devices_from(colocate_with) - tower_local = kwargs.pop("tower_local_reduce_method", None) - if tower_local is not None: - kwargs["trainable"] = False - - # Ignore user-specified caching device, not needed for mirrored variables. - kwargs.pop("caching_device", None) - - # TODO(josh11b,apassos): It would be better if variable initialization - # was never recorded on the tape instead of having to do this manually - # here. - with tape.stop_recording(): + def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring index = {} for i, d in enumerate(devices): with ops.device(d): @@ -138,149 +351,71 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): v = next_creator(*args, **kwargs) assert not isinstance(v, values.DistributedVariable) index[d] = v + return index - if tower_local is None: - result = values.MirroredVariable(index, index[devices[0]]) - else: - result = values.TowerLocalVariable( - index, index[devices[0]], tower_local) - - if not context.executing_eagerly(): - g = ops.get_default_graph() - # If "trainable" is True, next_creator() will add the member variables - # to the TRAINABLE_VARIABLES collection, so we manually remove - # them and replace with the MirroredVariable. We can't set - # "trainable" to False for next_creator() since that causes functions - # like implicit_gradients to skip those variables. - if kwargs.get("trainable", True): - collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) - l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) - for v in index.values(): - l.remove(v) - g.add_to_collections(collections, result) - return result + return _create_mirrored_variable(devices, _real_mirrored_creator, *args, + **kwargs) def distribute_dataset(self, dataset_fn): return values.PerDeviceDataset( self._call_dataset_fn(dataset_fn), self._devices, self._prefetch_on_device) + # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. + def _run_steps_on_dataset(self, fn, iterator, iterations, + initial_loop_values=None): + if initial_loop_values is None: + initial_loop_values = {} + initial_loop_values = nest.flatten(initial_loop_values) + + ctx = values.MultiStepContext() + def body(i, *args): + """A wrapper around `fn` to create the while loop body.""" + del args + fn_result = fn(ctx, iterator.get_next()) + for (name, output) in ctx.last_step_outputs.items(): + # Convert all outputs to tensors, potentially from `DistributedValues`. + ctx.last_step_outputs[name] = self.unwrap(output) + flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) + with ops.control_dependencies([fn_result]): + return [i + 1] + flat_last_step_outputs + + cond = lambda i, *args: i < iterations + i = constant_op.constant(0) + loop_result = control_flow_ops.while_loop( + cond, body, [i] + initial_loop_values, name="", + parallel_iterations=1, back_prop=False, swap_memory=False, + return_same_structure=True) + + ctx.run_op = control_flow_ops.group(loop_result) + + # Convert the last_step_outputs from a list to the original dict structure + # of last_step_outputs. + last_step_tensor_outputs = loop_result[1:] + last_step_tensor_outputs_dict = nest.pack_sequence_as( + ctx.last_step_outputs, last_step_tensor_outputs) + + for (name, aggregation) in ctx._last_step_outputs_aggregations.items(): # pylint: disable=protected-access + output = last_step_tensor_outputs_dict[name] + # For outputs that have already been aggregated, wrap them in a Mirrored + # container, else in a PerDevice container. + if aggregation is variables_lib.VariableAggregation.NONE: + last_step_tensor_outputs_dict[name] = values.regroup( + {d: t for d, t in zip(self._devices, output)}, values.PerDevice) + else: + assert len(output) == 1 + last_step_tensor_outputs_dict[name] = output[0] + + ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access + return ctx + def _broadcast(self, tensor, destinations): # TODO(josh11b): In eager mode, use one thread per device, or async mode. return self._get_cross_tower_ops().broadcast(tensor, destinations or self._devices) def _call_for_each_tower(self, fn, *args, **kwargs): - """Run `fn` in separate threads, once per tower/worker device. - - Args: - fn: function to run (will be run once per device, each in its own thread). - *args: positional arguments for `fn` - **kwargs: keyword arguments for `fn`. - `"run_concurrently"`: Boolean indicating whether executions of `fn` - can be run concurrently (under eager execution only), defaults to - `True`. - - Returns: - Merged return value of `fn` across all towers. - - Raises: - RuntimeError: If fn() calls get_tower_context().merge_call() a different - number of times for when called for different devices. - """ - run_concurrently = kwargs.pop("run_concurrently", True) - if not context.executing_eagerly(): - # Lots of TF library code isn't thread-safe in graph mode, and - # there is little to be gained by turning on multithreading when - # constructing a graph. - run_concurrently = False - # Needed for per-thread device, etc. contexts in graph mode. - ops.get_default_graph().switch_to_thread_local() - elif run_concurrently is None: - run_concurrently = True - - coord = coordinator.Coordinator( - clean_stop_exception_types=(_RequestedStop,)) - - shared_variable_store = {} - - # TODO(isaprykin): Create these threads once instead of during every run() - # call. - threads = [] - for index, d in enumerate(self._devices): - variable_creator_fn = shared_variable_creator.make_fn( - shared_variable_store, index) - t = MirroredStrategy._MirroredTowerThread( - self, coord, d, variable_creator_fn, fn, - *values.select_device(d, args), **values.select_device(d, kwargs)) - threads.append(t) - - for t in threads: - t.start() - - # When `fn` starts `should_run` event is set on _MirroredTowerThread - # (`MTT`) threads. The execution waits until - # `MTT.has_paused` is set, which indicates that either `fn` is - # complete or a `get_tower_context().merge_call()` is called. If `fn` is - # complete, then `MTT.done` is set to True. Otherwise, arguments - # of `get_tower_context().merge_call` from all paused threads are grouped - # and the `merge_fn` is performed. Results of the - # `get_tower_context().merge_call` are then set to `MTT.merge_result`. - # Each such `get_tower_context().merge_call` call returns the - # `MTT.merge_result` for that thread when `MTT.should_run` event - # is reset again. Execution of `fn` resumes. - - try: - with coord.stop_on_exception(): - all_done = False - while not all_done and not coord.should_stop(): - done = [] - if run_concurrently: - for t in threads: - t.should_run.set() - for t in threads: - t.has_paused.wait() - t.has_paused.clear() - if coord.should_stop(): - return None - done.append(t.done) - else: - for t in threads: - t.should_run.set() - t.has_paused.wait() - t.has_paused.clear() - if coord.should_stop(): - return None - done.append(t.done) - if coord.should_stop(): - return None - all_done = all(done) - if not all_done: - if any(done): - raise RuntimeError("Some towers made a different number of " - "tower_context().merge_call() calls.") - # get_tower_context().merge_call() case - merge_args = values.regroup( - {t.device: t.merge_args for t in threads}) - merge_kwargs = values.regroup( - {t.device: t.merge_kwargs for t in threads}) - # We capture the name_scope of the MTT when we call merge_fn - # to ensure that if we have opened a name scope in the MTT, - # it will be respected when executing the merge function. We only - # capture the name_scope from the first MTT and assume it is - # the same for all other MTTs. - mtt_captured_name_scope = threads[0].captured_name_scope - with ops.name_scope(mtt_captured_name_scope): - merge_result = threads[0].merge_fn( - self, *merge_args, **merge_kwargs) - for t in threads: - t.merge_result = values.select_device(t.device, merge_result) - finally: - for t in threads: - t.should_run.set() - coord.join(threads) - - return values.regroup({t.device: t.main_result for t in threads}) + return _call_for_each_tower(self, fn, *args, **kwargs) def map(self, map_over, fn, *args, **kwargs): # TODO(josh11b): In eager mode, use one thread per device. @@ -308,36 +443,19 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()) return self._cross_tower_ops - def _reduce(self, method_string, value, destinations): + def _reduce(self, aggregation, value, destinations): assert not isinstance(value, values.Mirrored) - if not isinstance(value, values.PerDevice): - if value == 0: - return 0 - if method_string == "mean": - return self._broadcast(value, destinations) - - cross_tower_ops_lib.validate_destinations(destinations) - if len(self._devices) == 1: - if destinations: - # TODO(anjalisridhar): Moves these methods to a device utility file? - devices = cross_tower_ops_lib.get_devices_from(destinations) - if len(devices) == 1: - with ops.device(devices[0]): - return array_ops.identity(value) - else: - value_updates = {} - for d in devices: - with ops.device(d): - value_updates[d] = array_ops.identity(value) - return values.Mirrored(value_updates) - raise ValueError("A non PerDevice value cannot be reduced with the given " - "method_string.") - + if not isinstance(value, values.DistributedValues): + # This function handles reducing values that are not PerDevice or Mirrored + # values. For example, the same value could be present on all towers in + # which case `value` would be a single value or value could be 0. + return _reduce_non_distributed_value(self, aggregation, value, + destinations) return self._get_cross_tower_ops().reduce( - method_string, value, destinations=destinations) + aggregation, value, destinations=destinations) - def _batch_reduce(self, method_string, value_destination_pairs): - return self._get_cross_tower_ops().batch_reduce(method_string, + def _batch_reduce(self, aggregation, value_destination_pairs): + return self._get_cross_tower_ops().batch_reduce(aggregation, value_destination_pairs) def _update(self, var, fn, *args, **kwargs): @@ -379,6 +497,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): return [val.get(device=d) for d in sorted(val.devices)] return [val] + def value_container(self, val): + return values.value_container(val) + @property def is_single_tower(self): return len(self._devices) == 1 @@ -406,15 +527,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def _get_devices_from(self, colocate_with=None): if colocate_with is None: return self._devices - elif isinstance(colocate_with, values.DistributedValues): - # pylint: disable=protected-access - return list(colocate_with._index.keys()) - elif isinstance(colocate_with, six.string_types): - return [device_util.resolve(colocate_with)] - elif isinstance(colocate_with, list): - return [device_util.resolve(d) for d in colocate_with] else: - return colocate_with + return cross_tower_ops_lib.get_devices_from(colocate_with) class _MirroredTowerThread(threading.Thread): """A thread that runs() a function on a device.""" diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 8d474124b7e0a80d49ed646254269988f49d69e4..e064cfe37db40a51e18a16c532500415a8b74816 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -25,7 +25,9 @@ from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.contrib.distribute.python import values from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -37,6 +39,7 @@ from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib @@ -114,7 +117,10 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): dist = self._get_distribution_strategy() with dist.scope(): result = dist.call_for_each_tower(run_fn, dist.worker_device_index) - reduced = dist.reduce("sum", result, destinations="/device:CPU:0") + reduced = dist.reduce( + variable_scope.VariableAggregation.SUM, + result, + destinations="/device:CPU:0") unwrapped = dist.unwrap(reduced) self.assertEqual(1, len(unwrapped)) expected = sum(range(len(dist.worker_devices))) @@ -132,8 +138,10 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): dist = mirrored_strategy.MirroredStrategy(devices) with dist.scope(): - reduced = dist.reduce("sum", 1.0, destinations=["/device:CPU:0", - "/device:GPU:0"]) + reduced = dist.reduce( + variable_scope.VariableAggregation.SUM, + 1.0, + destinations=["/device:CPU:0", "/device:GPU:0"]) unwrapped = dist.unwrap(reduced) self.assertEqual(2, len(unwrapped)) self.assertEqual(1.0, self.evaluate(unwrapped[0])) @@ -283,19 +291,69 @@ class MirroredStrategyVariableCreationTest(test.TestCase): self.assertIsInstance(bias, values.MirroredVariable) self.assertEquals("common/dense" + suffix + "/bias:0", bias.name) + @test_util.run_in_graph_and_eager_modes(config=config) + def testWithVariableAndVariableScope(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + v0 = variable_scope.variable(1.0, name="var0", aggregation=None) + with variable_scope.variable_scope("common"): + v1 = variable_scope.variable(1.0, name="var1") + # This will pause the current thread, and execute the other thread. + distribute_lib.get_tower_context().merge_call(lambda _: _) + v2 = variable_scope.variable( + 1.0, + name="var2", + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + v3 = variable_scope.variable( + 1.0, + name="var3", + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation=variable_scope.VariableAggregation.MEAN) + + return v0, v1, v2, v3 + + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + v = variable_scope.variable(1.0, name="var-main0") + self.assertEquals("var-main0:0", v.name) + + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertEquals(4, len(result)) + v0, v1, v2, v3 = result + self.assertIsInstance(v0, values.MirroredVariable) + self.assertEquals("var0:0", v0.name) + self.assertIsInstance(v1, values.MirroredVariable) + self.assertEquals("common/var1:0", v1.name) + self.assertIsInstance(v2, values.TowerLocalVariable) + self.assertEquals("common/var2:0", v2.name) + self.assertEquals(variable_scope.VariableAggregation.SUM, v2.aggregation) + self.assertIsInstance(v3, values.MirroredVariable) + self.assertEquals("common/var3:0", v3.name) + self.assertEquals(variable_scope.VariableAggregation.MEAN, v3.aggregation) + @test_util.run_in_graph_and_eager_modes(config=config) def testWithGetVariableAndVariableScope(self): self._skip_eager_if_gpus_less_than(1) def model_fn(): - v0 = variable_scope.get_variable("var-thread0", [1]) + v0 = variable_scope.get_variable("var0", [1]) with variable_scope.variable_scope("common"): - v1 = variable_scope.get_variable("var-thread1", [1]) + v1 = variable_scope.get_variable("var1", [1]) # This will pause the current thread, and execute the other thread. distribute_lib.get_tower_context().merge_call(lambda _: _) - v2 = variable_scope.get_variable("var-thread2", [1]) + v2 = variable_scope.get_variable( + "var2", [1], + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + v3 = variable_scope.get_variable( + "var3", [1], + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation=variable_scope.VariableAggregation.MEAN) - return v0, v1, v2 + return v0, v1, v2, v3 devices = ["/device:CPU:0", "/device:GPU:0"] dist = mirrored_strategy.MirroredStrategy(devices) @@ -305,14 +363,89 @@ class MirroredStrategyVariableCreationTest(test.TestCase): self.assertEquals("main/var-main0:0", v.name) result = dist.call_for_each_tower(model_fn, run_concurrently=False) - self.assertEquals(3, len(result)) - v0, v1, v2 = result + self.assertEquals(4, len(result)) + v0, v1, v2, v3 = result self.assertIsInstance(v0, values.MirroredVariable) - self.assertEquals("main/var-thread0:0", v0.name) + self.assertEquals("main/var0:0", v0.name) self.assertIsInstance(v1, values.MirroredVariable) - self.assertEquals("main/common/var-thread1:0", v1.name) - self.assertIsInstance(v2, values.MirroredVariable) - self.assertEquals("main/common/var-thread2:0", v2.name) + self.assertEquals("main/common/var1:0", v1.name) + self.assertIsInstance(v2, values.TowerLocalVariable) + self.assertEquals("main/common/var2:0", v2.name) + self.assertEquals(variable_scope.VariableAggregation.SUM, + v2.aggregation) + self.assertIsInstance(v3, values.MirroredVariable) + self.assertEquals("main/common/var3:0", v3.name) + self.assertEquals(variable_scope.VariableAggregation.MEAN, + v3.aggregation) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testNoneSynchronizationWithGetVariable(self): + self._skip_eager_if_gpus_less_than(1) + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with self.assertRaisesRegexp( + ValueError, "`NONE` variable synchronization mode is not " + "supported with `Mirrored` distribution strategy. Please change " + "the `synchronization` for variable: v"): + variable_scope.get_variable( + "v", [1], + synchronization=variable_scope.VariableSynchronization.NONE) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testNoneSynchronizationWithVariable(self): + self._skip_eager_if_gpus_less_than(1) + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with self.assertRaisesRegexp( + ValueError, "`NONE` variable synchronization mode is not " + "supported with `Mirrored` distribution strategy. Please change " + "the `synchronization` for variable: v"): + variable_scope.variable( + 1.0, + name="v", + synchronization=variable_scope.VariableSynchronization.NONE) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testInvalidSynchronizationWithVariable(self): + self._skip_eager_if_gpus_less_than(1) + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with self.assertRaisesRegexp( + ValueError, "Invalid variable synchronization mode: Invalid for " + "variable: v"): + variable_scope.variable(1.0, name="v", synchronization="Invalid") + + @test_util.run_in_graph_and_eager_modes(config=config) + def testInvalidAggregationWithGetVariable(self): + self._skip_eager_if_gpus_less_than(1) + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with self.assertRaisesRegexp( + ValueError, "Invalid variable aggregation mode: invalid for " + "variable: v"): + variable_scope.get_variable( + "v", [1], + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation="invalid") + + @test_util.run_in_graph_and_eager_modes(config=config) + def testInvalidAggregationWithVariable(self): + self._skip_eager_if_gpus_less_than(1) + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with self.assertRaisesRegexp( + ValueError, "Invalid variable aggregation mode: invalid for " + "variable: v"): + variable_scope.variable( + 1.0, + name="v", + synchronization=variable_scope.VariableSynchronization.ON_WRITE, + aggregation="invalid") @test_util.run_in_graph_and_eager_modes(config=config) def testThreeDevices(self): @@ -361,11 +494,14 @@ class MirroredStrategyVariableCreationTest(test.TestCase): components_mean = {} def model_fn(device_id): - tower_context = distribute_lib.get_tower_context() - with tower_context.tower_local_var_scope("sum"): - v_sum = variable_scope.variable(1.0) - with tower_context.tower_local_var_scope("mean"): - v_mean = variable_scope.variable(4.0) + v_sum = variable_scope.variable( + 1.0, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + v_mean = variable_scope.variable( + 4.0, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.MEAN) self.assertTrue(isinstance(v_sum, values.TowerLocalVariable)) self.assertTrue(isinstance(v_mean, values.TowerLocalVariable)) updates = [v_sum.assign_add(2.0 + device_id), @@ -568,9 +704,10 @@ class MirroredStrategyVariableCreationTest(test.TestCase): with context.graph_mode(): def model_fn(): - tower_context = distribute_lib.get_tower_context() - with tower_context.tower_local_var_scope("sum"): - v_sum = variable_scope.variable(1.0) + v_sum = variable_scope.variable( + 1.0, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) self.assertTrue(isinstance(v_sum, values.TowerLocalVariable)) return v_sum @@ -642,7 +779,8 @@ class MirroredVariableUpdateTest(test.TestCase): # aggregation type. self._skip_eager_if_gpus_less_than(1) def var_fn(): - v = variable_scope.variable(1.0, name="foo") + v = variable_scope.variable( + 1.0, name="foo", aggregation=variable_scope.VariableAggregation.SUM) return v dist = mirrored_strategy.MirroredStrategy( @@ -650,9 +788,6 @@ class MirroredVariableUpdateTest(test.TestCase): with dist.scope(): mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) - # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the - # aggregation method. - mirrored_var._aggregation_method = "sum" self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) @@ -660,8 +795,8 @@ class MirroredVariableUpdateTest(test.TestCase): return mirrored_var.assign(5.0) with self.assertRaisesRegexp( - ValueError, "A non PerDevice value cannot be reduced with the given " - "method_string."): + ValueError, "A non-DistributedValues value cannot be reduced with " + "the given aggregation."): self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn))) @test_util.run_in_graph_and_eager_modes(config=config) @@ -685,16 +820,14 @@ class MirroredVariableUpdateTest(test.TestCase): def testAssignMirroredVarTowerContext(self): self._skip_eager_if_gpus_less_than(1) def var_fn(): - return variable_scope.variable(1.0, name="foo") + return variable_scope.variable( + 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) dist = mirrored_strategy.MirroredStrategy( ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) - # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the - # aggregation method. - mirrored_var._aggregation_method = "mean" self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(1.0, self.evaluate(mirrored_var)) @@ -708,6 +841,29 @@ class MirroredVariableUpdateTest(test.TestCase): model_fn, run_concurrently=False))) self.assertEquals(0.5, self.evaluate(mirrored_var)) + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignMirroredVarTowerContextWithSingleValue(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable( + 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(1.0, self.evaluate(mirrored_var)) + + def model_fn(): + return mirrored_var.assign(5.0) + + self.evaluate(dist.unwrap(dist.call_for_each_tower( + model_fn, run_concurrently=False))) + self.assertEquals(5.0, self.evaluate(mirrored_var)) + @test_util.run_in_graph_and_eager_modes(config=config) def testAssignAddMirroredVarCrossTowerContext(self): self._skip_eager_if_gpus_less_than(1) @@ -729,16 +885,14 @@ class MirroredVariableUpdateTest(test.TestCase): def testAssignAddMirroredVarTowerContext(self): self._skip_eager_if_gpus_less_than(1) def var_fn(): - return variable_scope.variable(1.0, name="foo") + return variable_scope.variable( + 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) dist = mirrored_strategy.MirroredStrategy( ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) - # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the - # aggregation method. - mirrored_var._aggregation_method = "mean" self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(1.0, self.evaluate(mirrored_var)) @@ -752,6 +906,29 @@ class MirroredVariableUpdateTest(test.TestCase): model_fn, run_concurrently=False))) self.assertEquals(1.5, self.evaluate(mirrored_var)) + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignAddMirroredVarTowerContextWithSingleValue(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable( + 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(1.0, self.evaluate(mirrored_var)) + + def model_fn(): + return mirrored_var.assign_add(5.0) + + self.evaluate(dist.unwrap(dist.call_for_each_tower( + model_fn, run_concurrently=False))) + self.assertEquals(6.0, self.evaluate(mirrored_var)) + @test_util.run_in_graph_and_eager_modes(config=config) def testAssignSubMirroredVarCrossTowerContext(self): self._skip_eager_if_gpus_less_than(1) @@ -773,16 +950,14 @@ class MirroredVariableUpdateTest(test.TestCase): def testAssignSubMirroredVarTowerContext(self): self._skip_eager_if_gpus_less_than(1) def var_fn(): - return variable_scope.variable(5.0, name="foo") + return variable_scope.variable( + 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) dist = mirrored_strategy.MirroredStrategy( ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) - # TODO(anjalisridhar): Use API introduced in cr/201463945 to set the - # aggregation method. - mirrored_var._aggregation_method = "mean" self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(5.0, self.evaluate(mirrored_var)) @@ -796,6 +971,268 @@ class MirroredVariableUpdateTest(test.TestCase): model_fn, run_concurrently=False))) self.assertEquals(4.5, self.evaluate(mirrored_var)) + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignSubMirroredVarTowerContextWithSingleValue(self): + self._skip_eager_if_gpus_less_than(1) + def var_fn(): + return variable_scope.variable( + 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.evaluate(variables.global_variables_initializer()) + self.assertEquals(5.0, self.evaluate(mirrored_var)) + + def model_fn(): + return mirrored_var.assign_sub(1.0) + + self.evaluate(dist.unwrap(dist.call_for_each_tower( + model_fn, run_concurrently=False))) + self.assertEquals(4.0, self.evaluate(mirrored_var)) + + +class MirroredAndTowerLocalVariableInitializerTest(test.TestCase): + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + def testAssignMirroredVarInitializer(self): + # This test is not eager compatible since in eager variables are initialized + # upon construction instead of once the initialization op is run. + with context.graph_mode(): + def var_fn(): + v = variable_scope.variable(1.0, name="foo") + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + mirrored_var = dist.call_for_each_tower(var_fn) + self.assertIsInstance(mirrored_var, values.MirroredVariable) + self.assertFalse(self.evaluate(mirrored_var.is_initialized())) + self.evaluate(mirrored_var.initializer) + self.assertTrue(self.evaluate(mirrored_var.is_initialized())) + + def testAssignTowerLocalVarInitializer(self): + # This test is not eager compatible since in eager variables are initialized + # upon construction instead of once the initialization op is run. + with context.graph_mode(): + def model_fn(): + v_sum = variable_scope.variable( + 1.0, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + self.assertTrue(isinstance(v_sum, values.TowerLocalVariable)) + return v_sum + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + tower_local_var = dist.call_for_each_tower(model_fn) + self.assertTrue(isinstance(tower_local_var, values.TowerLocalVariable)) + self.assertFalse(self.evaluate(tower_local_var.is_initialized())) + self.evaluate(tower_local_var.initializer) + self.assertTrue(self.evaluate(tower_local_var.is_initialized())) + + +class TowerLocalVariableAssignTest(test.TestCase): + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + def _skip_eager_if_gpus_less_than(self, num_gpus): + if context.num_gpus() < num_gpus and context.executing_eagerly(): + self.skipTest("Not enough GPUs available for this test in eager mode.") + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignTowerLocalVarSumAggregation(self): + self._skip_eager_if_gpus_less_than(1) + def model_fn(): + v_sum = variable_scope.variable( + 1.0, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + return v_sum + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + tower_local_var = dist.call_for_each_tower(model_fn, + run_concurrently=False) + self.assertTrue(isinstance(tower_local_var, values.TowerLocalVariable)) + self.evaluate(variables.global_variables_initializer()) + # Each tower has a value of 1.0 assigned to it in tower context. + # When we read the value using `read_var` we should see the SUM of each of + # values on each of the towers. + self.assertEqual(2.0, self.evaluate(dist.read_var(tower_local_var))) + # Assigning 6.0 in cross tower context will assign a value of + # 6.0/num_towers to each tower. + tlv_ops = tower_local_var.assign(6.0) + self.evaluate(tlv_ops) + # On reading the tower local var we should get the assigned value back. + # The value on all the towers are added before being returned by + # `read_var`. + self.assertEqual(6.0, self.evaluate(dist.read_var(tower_local_var))) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testAssignTowerLocalVarMeanAggregation(self): + self._skip_eager_if_gpus_less_than(1) + def model_fn(): + v_sum = variable_scope.variable( + 1.0, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.MEAN) + return v_sum + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + tower_local_var = dist.call_for_each_tower(model_fn, + run_concurrently=False) + self.assertTrue(isinstance(tower_local_var, values.TowerLocalVariable)) + self.evaluate(variables.global_variables_initializer()) + # Each tower has a value of 1.0 assigned to it in tower context. + # When we read the value using `read_var` we should see the MEAN of values + # on all towers which is the value assigned in tower context. + self.assertEqual(1.0, self.evaluate(dist.read_var(tower_local_var))) + tlv_ops = tower_local_var.assign(6.0) + self.evaluate(tlv_ops) + # On reading the tower local var we should get the MEAN of all values + # which is equal to the value assigned. + self.assertEqual(6.0, self.evaluate(dist.read_var(tower_local_var))) + + +class MockModel(object): + + def __init__(self, two_variables=False): + self.variables = [] + self.variables.append(variable_scope.variable(1.25, name="dummy_var1")) + if two_variables: + self.variables.append(variable_scope.variable(2.0, name="dummy_var2")) + + def __call__(self, factor=2): + x = factor * self.variables[0] + if len(self.variables) > 1: + x += self.variables[1] + return x + + +class MirroredStrategyDefunTest(test.TestCase): + + def _skip_eager_if_gpus_less_than(self, num_gpus): + if context.num_gpus() < num_gpus and context.executing_eagerly(): + self.skipTest("Not enough GPUs available for this test in eager mode.") + + def _call_and_check(self, model_fn, inputs, expected_result, defuns, + two_variables=False): + cpu_dev = device_util.canonicalize("CPU:0") + gpu_dev = device_util.canonicalize("GPU:0") + devices = [cpu_dev, gpu_dev] + dist = mirrored_strategy.MirroredStrategy(devices) + + with dist.scope(): + mock_model = MockModel(two_variables) + self.evaluate(variables.global_variables_initializer()) + + result = dist.call_for_each_tower(model_fn, mock_model, *inputs, + run_concurrently=False) + for device in devices: + device_result = values.select_device(device, result) + device_expected_result = values.select_device(device, expected_result) + self.assertAllClose(device_expected_result, + self.evaluate(device_result)) + + for defun in defuns: + self.assertEqual(set(mock_model.variables), set(defun.variables)) + + @test_util.run_in_graph_and_eager_modes() + def testVariableInDefun(self): + self._skip_eager_if_gpus_less_than(1) + + @function.defun + def times_two(mock_model): + return mock_model() + + def model_fn(mock_model): + return times_two(mock_model) + + self._call_and_check(model_fn, [], 2.5, [times_two]) + + @test_util.run_in_graph_and_eager_modes() + def testVariableInNestedDefun(self): + self._skip_eager_if_gpus_less_than(1) + + @function.defun + def times_two(mock_model): + return mock_model() + + @function.defun + def two_x_plus_one(mock_model): + return times_two(mock_model) + 1 + + def model_fn(mock_model): + return two_x_plus_one(mock_model) + + self._call_and_check(model_fn, [], 3.5, [times_two, two_x_plus_one]) + + @test_util.run_in_graph_and_eager_modes() + def testTwoVariablesInNestedDefun(self): + self._skip_eager_if_gpus_less_than(1) + + @function.defun + def fn1(mock_model): + return mock_model() + + @function.defun + def fn2(mock_model): + return fn1(mock_model) + 1 + + def model_fn(mock_model): + return fn2(mock_model) + + self._call_and_check(model_fn, [], 5.5, [fn1, fn2], two_variables=True) + + @test_util.run_in_graph_and_eager_modes() + def testGradientTapeOverNestedDefuns(self): + self._skip_eager_if_gpus_less_than(1) + + @function.defun + def fn1(mock_model): + return mock_model() + + @function.defun + def fn2(mock_model): + return fn1(mock_model) + 1 + + def model_fn(mock_model): + with backprop.GradientTape(persistent=True) as gtape: + result = fn2(mock_model) + grads = gtape.gradient(result, + [v.get() for v in mock_model.variables]) + return grads + + self._call_and_check(model_fn, [], [2.0, 1.0], [fn1, fn2], + two_variables=True) + + @test_util.run_in_graph_and_eager_modes() + def testPassPerDevice(self): + self._skip_eager_if_gpus_less_than(1) + + @function.defun + def fn1(mock_model, factor): + return mock_model(factor) + + factors = values.PerDevice({"CPU:0": 5.0, "GPU:0": 3.0}) + expected_result = values.PerDevice({"CPU:0": 5.0 * 1.25, + "GPU:0": 3.0 * 1.25}) + self._call_and_check(fn1, [factors], expected_result, [fn1]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy.py b/tensorflow/contrib/distribute/python/multi_worker_strategy.py index 0f21a427320510635279f80c11711e81715ec37c..cbfe5df61d1ee6fa1eb9275b715b0721d678a46f 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_strategy.py +++ b/tensorflow/contrib/distribute/python/multi_worker_strategy.py @@ -46,7 +46,7 @@ class MultiWorkerMirroredStrategy(MirroredStrategy): * **In-graph replication**: the `client` creates a single `tf.Graph` that specifies tasks for devices on all workers. The `client` then creates a client session which will talk to the `master` service of a `worker`. Then - the `master` will parition the graph and distribute the work to all + the `master` will partition the graph and distribute the work to all participating workers. * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one physical machine. We will have multiple `worker`s with different `task` diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index f659be5f42594b275af06435cb0c228e5d594ac9..249de01f0880b02d603687db99692088480f7136 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -20,35 +20,68 @@ from __future__ import print_function import contextlib import copy +import threading +import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session -from tensorflow.python.eager import test +from tensorflow.python.estimator import run_config +from tensorflow.python.platform import test from tensorflow.python.framework import test_util +def create_in_process_cluster(num_workers, num_ps): + """Create an in-process cluster that consists of only standard server.""" + # Leave some memory for cuda runtime. + gpu_mem_frac = 0.7 / num_workers + worker_config = config_pb2.ConfigProto() + worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac + + # Enable collective ops which has no impact on non-collective ops. + # TODO(yuefengz, tucker): removing this after we move the initialization of + # collective mgr to the session level. + worker_config.experimental.collective_group_leader = ( + '/job:worker/replica:0/task:0') + + ps_config = config_pb2.ConfigProto() + ps_config.device_count['GPU'] = 0 + + # Create in-process servers. Once an in-process tensorflow server is created, + # there is no way to terminate it. So we create one cluster per test process. + # We could've started the server in another process, we could then kill that + # process to terminate the server. The reasons why we don't want multiple + # processes are + # 1) it is more difficult to manage these processes; + # 2) there is something global in CUDA such that if we initialize CUDA in the + # parent process, the child process cannot initialize it again and thus cannot + # use GPUs (https://stackoverflow.com/questions/22950047). + return test_util.create_local_cluster( + num_workers, + num_ps=num_ps, + worker_config=worker_config, + ps_config=ps_config, + protocol='grpc') + + class MultiWorkerTestBase(test.TestCase): """Base class for testing multi node strategy and dataset.""" @classmethod def setUpClass(cls): """Create a local cluster with 2 workers.""" - num_workers = 2 - # Leave some memory for cuda runtime. - gpu_mem_frac = 0.7 / num_workers - default_config = config_pb2.ConfigProto() - default_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac - - # The local cluster takes some portion of the local GPUs and there is no way - # for the cluster to terminate unless using multiple processes. Therefore, - # we have to only create only one cluster throughout a test process. - workers, _ = test_util.create_local_cluster( - num_workers, num_ps=0, worker_config=default_config) - cls._master_target = workers[0].target + cls._workers, cls._ps = create_in_process_cluster(num_workers=2, num_ps=0) + + def setUp(self): + # We only cache the session in one test because another test may have a + # different session config or master target. + self._thread_local = threading.local() + self._thread_local.cached_session = None + self._result = 0 + self._lock = threading.Lock() @contextlib.contextmanager - def test_session(self, graph=None, config=None): + def test_session(self, graph=None, config=None, target=None): """Create a test session with master target set to the testing cluster. This overrides the base class' method, removes arguments that are not needed @@ -59,6 +92,7 @@ class MultiWorkerTestBase(test.TestCase): graph: Optional graph to use during the returned session. config: An optional config_pb2.ConfigProto to use to configure the session. + target: the target of session to connect to. Yields: A Session object that should be used as a context manager to surround @@ -78,13 +112,46 @@ class MultiWorkerTestBase(test.TestCase): rewriter_config_pb2.RewriterConfig.OFF) if graph is None: - if self._cached_session is None: # pylint: disable=access-member-before-definition - self._cached_session = session.Session( - graph=None, config=config, target=self._master_target) - sess = self._cached_session + if getattr(self._thread_local, 'cached_session', None) is None: + self._thread_local.cached_session = session.Session( + graph=None, config=config, target=target or self._workers[0].target) + sess = self._thread_local.cached_session with sess.graph.as_default(), sess.as_default(): yield sess else: with session.Session( - graph=graph, config=config, target=self._master_target) as sess: + graph=graph, config=config, target=target or + self._workers[0].target) as sess: yield sess + + def _run_client(self, client_fn, task_type, task_id, num_gpus, *args, + **kwargs): + result = client_fn(task_type, task_id, num_gpus, *args, **kwargs) + if np.all(result): + with self._lock: + self._result += 1 + + def _run_between_graph_clients(self, client_fn, cluster_spec, num_gpus, *args, + **kwargs): + """Runs several clients for between-graph replication. + + Args: + client_fn: a function that needs to accept `task_type`, `task_id`, + `num_gpus` and returns True if it succeeds. + cluster_spec: a dict specifying jobs in a cluster. + num_gpus: number of GPUs per worker. + *args: will be passed to `client_fn`. + **kwargs: will be passed to `client_fn`. + """ + threads = [] + for task_type in [run_config.TaskType.CHIEF, run_config.TaskType.WORKER]: + for task_id in range(len(cluster_spec.get(task_type, []))): + t = threading.Thread( + target=self._run_client, + args=(client_fn, task_type, task_id, num_gpus) + args, + kwargs=kwargs) + t.start() + threads.append(t) + for t in threads: + t.join() + self.assertEqual(self._result, len(threads)) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index a580dac96c5e6c6c8790aa6af7309988bf7a6477..016978cdb3a152bbba0a2e63df1dea4035e32789 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -21,10 +21,14 @@ from __future__ import print_function import six from tensorflow.contrib.distribute.python import values +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.util import nest # TODO(josh11b): Replace asserts in this file with if ...: raise ... @@ -43,11 +47,6 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): self._default_device = device def _create_variable(self, next_creator, *args, **kwargs): - # No need to distinguish tower-local variables when not mirroring, - # we just enforce that they are not trainable. - if kwargs.pop("tower_local_reduce_method", None) is not None: - kwargs["trainable"] = False - colocate_with = kwargs.pop("colocate_with", None) if colocate_with is None: with ops.device(self._device): @@ -70,6 +69,41 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): def _broadcast(self, tensor, destinations): return tensor + # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. + def _run_steps_on_dataset(self, fn, iterator, iterations, + initial_loop_values=None): + if initial_loop_values is None: + initial_loop_values = {} + initial_loop_values = nest.flatten(initial_loop_values) + + ctx = values.MultiStepContext() + def body(i, *args): + """A wrapper around `fn` to create the while loop body.""" + del args + fn_result = fn(ctx, iterator.get_next()) + flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) + with ops.control_dependencies([fn_result]): + return [i + 1] + flat_last_step_outputs + + cond = lambda i, *args: i < iterations + i = constant_op.constant(0) + # TODO(priyag): Use max_iterations instead of an explicit counter. + loop_result = control_flow_ops.while_loop( + cond, body, [i] + initial_loop_values, name="", + parallel_iterations=1, back_prop=False, swap_memory=False, + return_same_structure=True) + + ctx.run_op = control_flow_ops.group(loop_result) + + # Convert the last_step_outputs from a list to the original dict structure + # of last_step_outputs. + last_step_tensor_outputs = loop_result[1:] + last_step_tensor_outputs_dict = nest.pack_sequence_as( + ctx.last_step_outputs, last_step_tensor_outputs) + + ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access + return ctx + def _call_for_each_tower(self, fn, *args, **kwargs): # We don't run `fn` in multiple threads in OneDeviceStrategy. kwargs.pop("run_concurrently", None) @@ -80,15 +114,15 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): with ops.device(self._device): return values.MapOutput([fn(m, *args, **kwargs) for m in map_over]) - def _reduce(self, method_string, value, destinations): + def _reduce(self, aggregation, value, destinations): if not isinstance(value, values.MapOutput): return value l = value.get() assert l with ops.device(self._device): - if method_string == "sum": + if aggregation == vs.VariableAggregation.SUM: return math_ops.add_n(l) - elif method_string == "mean": + elif aggregation == vs.VariableAggregation.MEAN: return math_ops.add_n(l) / len(l) else: assert False @@ -109,6 +143,9 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): def _unwrap(self, value): return [value] + def value_container(self, value): + return value + @property def is_single_tower(self): return True diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..f2c7fd556adaf08926b6f1e327abd25b7c9a42e6 --- /dev/null +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -0,0 +1,358 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Classes implementing a multi-worker ps DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import os + +from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import values +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.python.framework import device as tf_device +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.training import device_setter +from tensorflow.python.training import device_util +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import server_lib +from tensorflow.python.util import nest + +_LOCAL_CPU = "/device:CPU:0" +_LOCAL_GPU_0 = "/device:GPU:0" + + +def _normalize_cluster_spec(cluster_spec): + """Makes `cluster_spec` into a `ClusterSpec` object.""" + if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)): + return server_lib.ClusterSpec(cluster_spec) + elif not isinstance(cluster_spec, server_lib.ClusterSpec): + raise ValueError( + "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " + "`tf.train.ClusterDef` object") + return cluster_spec + + +# TODO(yuefengz): maybe cache variables on local CPU. +# TODO(yuefengz): we may want to set session options to disallow communication +# between workers. +class ParameterServerStrategy(distribute_lib.DistributionStrategy): + """A parameter server DistributionStrategy. + + This strategy class works for both local training and between-graph replicated + training for multiple workers. If `cluster_spec` is specified, either passed + in to __init__() method or parsed from the + ["TF_CONFIG" environment + variable](https://www.tensorflow.org/api_docs/python/tf/estimator/RunConfig), + variables and updates to those variables are assigned to parameter servers and + other operations are assigned to workers. If `cluster_spec` is not set, it + becomes local training where variables are assigned to local CPU or the only + GPU. When each worker has more than one GPU, operations will be replicated on + these GPUs. In both cases, operations are replicated but variables are not and + these workers share a common view for which paramater server a variable is + assigned to. + + This class assumes between-graph replication will be used and works on a graph + for a particular worker. + + It is expected to call `call_for_each_tower(fn, *args, **kwargs)` for any + operations which potentially can be replicated across towers (i.e. multiple + GPUs) even if there is only CPU or one GPU. When defining the `fn`, extra + caution needs to be taken: + + 1) Always use @{tf.get_variable} instead of @{tf.Variable} which is not able + to refer to the same variable on different towers. + + 2) It is generally not recommended to open a device scope under the strategy's + scope. A device scope (i.e. calling @{tf.device}) will be merged with or + override the device for operations but will not change the device for + variables. + + 3) It is also not recommended to open a colocation scope (i.e. calling + @{tf.colocate_with}) under the strategy's scope. For colocating variables, + use `distribution.colocate_vars_with` instead. Colocation of ops will possibly + create conflicts of device assignement. + """ + + def __init__(self, + num_gpus_per_worker=0, + cluster_spec=None, + task_type=None, + task_id=None): + """Initiailizes this strategy. + + Args: + num_gpus_per_worker: number of local GPUs or GPUs per worker. + cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the + cluster configurations. + task_type: the current task type. + task_id: the current task id. + """ + super(ParameterServerStrategy, self).__init__() + self._num_gpus_per_worker = num_gpus_per_worker + if cluster_spec: + cluster_spec = _normalize_cluster_spec(cluster_spec) + self._cluster_spec = cluster_spec + + # We typically don't need to do all-reduce in this strategy. + self._cross_tower_ops = ( + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( + reduce_to_device=_LOCAL_CPU)) + + self._initialize_devices(num_gpus_per_worker, cluster_spec, task_type, + task_id) + + def _initialize_devices(self, num_gpus_per_worker, cluster_spec, task_type, + task_id): + """Initialize internal devices. + + It creates variable devices and compute devices. Variables and operations + will be assigned to them respectively. We have one compute device per tower. + The variable device is a device function or device string. The default + variable device assigns variables to parameter servers in a round-robin + fashion. + + Args: + num_gpus_per_worker: number of local GPUs or GPUs per worker. + cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the + cluster configurations. + task_type: the current task type. + task_id: the current task id. + + Raises: + ValueError: if the cluster_spec doesn't have ps jobs. + """ + self._task_type = task_type or "worker" + self._task_id = task_id or 0 + self._worker_device = "/job:%s/task:%d" % (self._task_type, self._task_id) + + # TODO(yuefengz): maybe clearer to split it into two classes, one for + # the distribuetd case and one for the local case, once we have the factory + # class/method. + + # Define compute devices which is a list of device strings and one for each + # tower. When there are GPUs, replicate operations on these GPUs. Otherwise, + # place operations on CPU. + if cluster_spec is None: + # Local mode. + if num_gpus_per_worker > 0: + self._compute_devices = list( + map("/device:GPU:{}".format, range(num_gpus_per_worker))) + else: + self._compute_devices = [_LOCAL_CPU] + else: + # Distributed mode. + if num_gpus_per_worker > 0: + self._compute_devices = [ + "%s/device:GPU:%d" % (self._worker_device, i) + for i in range(num_gpus_per_worker) + ] + else: + self._compute_devices = [self._worker_device] + + self._compute_devices = list( + map(device_util.resolve, self._compute_devices)) + self._canonical_compute_device_set = set(self._compute_devices) + + # Define variable device which is a device string in the local case and a + # device function in the distributed case. It is used to open a device scope + # where varibles are defined. + # The `_parameter_devices` is needed for the `parameter_devices` property + # and is a list of all variable devices. + if cluster_spec is None: + # Local mode. If there is only one GPU, put everything on that GPU. + # Otherwise, place variables on CPU. + if num_gpus_per_worker == 1: + assert len(list(self._compute_devices)) == 1 + self._variable_device = _LOCAL_GPU_0 + self._parameter_devices = [_LOCAL_GPU_0] + else: + self._variable_device = _LOCAL_CPU + self._parameter_devices = [_LOCAL_CPU] + else: + # Distributed mode. Place variables on ps jobs in a round-robin fashion. + # Note that devices returned from `replica_device_setter` are not + # canonical and therefore we don't canonicalize all variable devices to + # make them consistent. + # TODO(yuefengz): support passing a strategy object to control variable + # assignment. + # TODO(yuefengz): merge the logic of replica_device_setter into this + # class. + num_ps_replicas = len(cluster_spec.as_dict().get("ps", [])) + if num_ps_replicas == 0: + raise ValueError("The cluster spec needs to have `ps` jobs.") + self._variable_device = device_setter.replica_device_setter( + ps_tasks=num_ps_replicas, + worker_device=self._worker_device, + merge_devices=True, + cluster=cluster_spec) + + # Parameter devices are all tasks of the "ps" job. + self._parameter_devices = map("/job:ps/task:{}".format, + range(num_ps_replicas)) + + # Define the default device in cross-tower mode. In the distributed case, we + # set the default device to the corresponding worker to prevent these ops + # from being placed on other workers. + if cluster_spec is None: + self._default_device = None + else: + self._default_device = self._worker_device + + def distribute_dataset(self, dataset_fn): + """Distributes the dataset to each local GPU.""" + return values.PerDeviceDataset( + self._call_dataset_fn(dataset_fn), self._compute_devices, True) + + def _broadcast(self, tensor, destinations): + if not cross_tower_ops_lib.check_destinations(destinations): + destinations = self._compute_devices + return self._cross_tower_ops.broadcast(tensor, destinations) + + # TODO(yuefengz): not all ops in device_setter.STANDARD_PS_OPS will go through + # this creator, such as "MutableHashTable". + def _create_variable(self, next_creator, *args, **kwargs): + if "colocate_with" in kwargs: + with ops.device(None): + with ops.colocate_with(kwargs["colocate_with"]): + return next_creator(*args, **kwargs) + + with ops.colocate_with(None, ignore_existing=True): + with ops.device(self._variable_device): + return next_creator(*args, **kwargs) + + def _call_for_each_tower(self, fn, *args, **kwargs): + # pylint: disable=protected-access + return mirrored_strategy._call_for_each_tower(self, fn, *args, **kwargs) + + def _verify_destinations_not_different_worker(self, destinations): + if destinations is None: + return + for d in cross_tower_ops_lib.get_devices_from(destinations): + d_spec = tf_device.DeviceSpec.from_string(d) + if d_spec.job == self._task_type and d_spec.task != self._task_id: + raise ValueError( + "Cannot reduce to another worker: %r, current worker is %r" % + (d, self._worker_device)) + + def _reduce(self, aggregation, value, destinations): + self._verify_destinations_not_different_worker(destinations) + if not isinstance(value, values.DistributedValues): + # pylint: disable=protected-access + return mirrored_strategy._reduce_non_distributed_value( + self, aggregation, value, destinations) + + return self._cross_tower_ops.reduce( + aggregation, value, destinations=destinations) + + def _batch_reduce(self, aggregation, value_destination_pairs): + for _, destinations in value_destination_pairs: + self._verify_destinations_not_different_worker(destinations) + return self._cross_tower_ops.batch_reduce(aggregation, + value_destination_pairs) + + def _select_single_value(self, structured): + """Select any single values in `structured`.""" + + def _select_fn(x): # pylint: disable=g-missing-docstring + if isinstance(x, values.Mirrored): + if len(x.devices) == 1: + return list(x._index.values())[0] # pylint: disable=protected-access + else: + raise ValueError( + "You cannot update variable with a Mirrored object with multiple " + "components %r when using ParameterServerStrategy. You must " + "specify a single value or a Mirrored with a single value." % x) + elif isinstance(x, values.PerDevice): + raise ValueError( + "You cannot update variable with a PerDevice object %r when using " + "ParameterServerStrategy. You must specify a single value or a " + "Mirrored with a single value" % x) + else: + return x + + return nest.map_structure(_select_fn, structured) + + def _update(self, var, fn, *args, **kwargs): + if not isinstance(var, resource_variable_ops.ResourceVariable): + raise ValueError( + "You can not update `var` %r. It must be a Variable." % var) + with ops.colocate_with(var), distribute_lib.UpdateContext(var.device): + return fn(var, *self._select_single_value(args), + **self._select_single_value(kwargs)) + + # TODO(yuefengz): does it need to call _select_single_value? + def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + with ops.device( + colocate_with.device), distribute_lib.UpdateContext(colocate_with): + return fn(*args, **kwargs) + + def _unwrap(self, val): + if isinstance(val, values.DistributedValues): + # Return in a deterministic order. + if set(val.devices) == self._canonical_compute_device_set: + return [val.get(device=d) for d in self._compute_devices] + return [val.get(device=d) for d in sorted(val.devices)] + return [val] + + def value_container(self, val): + return values.value_container(val) + + def read_var(self, var): + # No need to distinguish between normal variables and tower-local variables. + return array_ops.identity(var) + + def configure(self, session_config=None): + del session_config + + # Use TF_CONFIG to get the cluster spec and the current job. + tf_config = json.loads(os.environ.get("TF_CONFIG", "{}")) + cluster_spec = _normalize_cluster_spec(tf_config.get("cluster", {})) + + task_env = tf_config.get("task", {}) + if task_env: + task_type = task_env.get("type", "worker") + task_id = int(task_env.get("index", "0")) + else: + task_type = "worker" + task_id = None + + # Set the devices if cluster_spec is defined in TF_CONFIG but not passed in + # the constructor. + if not self._cluster_spec and cluster_spec: + self._cluster_spec = cluster_spec + self._initialize_devices(self._num_gpus_per_worker, cluster_spec, + task_type, task_id) + + @property + def num_towers(self): + return len(self._compute_devices) + + @property + def worker_devices(self): + # Make a copy to prevent users from accidentally mutating our copy. + return list(self._compute_devices) + + @property + def parameter_devices(self): + return list(self._parameter_devices) + + def non_slot_devices(self, var_list): + return min(var_list, key=lambda x: x.name) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cf29c0ed91a14843ce15bf671dd363ca0f7073c0 --- /dev/null +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -0,0 +1,430 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ParameterServerStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import threading +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.contrib.distribute.python import parameter_server_strategy +from tensorflow.python.eager import context +from tensorflow.python.estimator import run_config +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import device_util +from tensorflow.python.training import distribute as distribute_lib + + +class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, + parameterized.TestCase): + + @classmethod + def setUpClass(cls): + cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=2) + cls._cluster_spec = { + run_config.TaskType.WORKER: [ + 'fake_worker_0', 'fake_worker_1', 'fake_worker_2' + ], + run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1'] + } + + def setUp(self): + self._result = 0 + self._lock = threading.Lock() + self._init_condition = threading.Condition() + self._init_reached = 0 + self._finish_condition = threading.Condition() + self._finish_reached = 0 + super(ParameterServerStrategyTest, self).setUp() + + def _get_test_objects(self, task_type, task_id, num_gpus): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=num_gpus) + if not task_type: + return distribution, '' + + tf_config = { + 'cluster': self._cluster_spec, + 'task': { + 'type': task_type, + 'index': task_id + } + } + with self._lock: + # Accessing environment variables should be protected by locks because + # environment variables are shared by all threads. + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + distribution.configure() + return distribution, self._workers[task_id].target + + def _test_device_assignment_distributed(self, task_type, task_id, num_gpus): + worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id) + d, _ = self._get_test_objects(task_type, task_id, num_gpus) + with ops.Graph().as_default(), \ + self.test_session(target=self._workers[0].target) as sess, \ + d.scope(): + + # Define a variable outside the call_for_each_tower scope. This is not + # recommended. + n = variable_scope.get_variable('n', initializer=10.0) + self.assertEqual(n.device, '/job:ps/task:0') + + def model_fn(): + if num_gpus == 0: + last_part_device = 'device:CPU:0' + else: + last_part_device = ( + 'device:GPU:%d' % distribute_lib.get_tower_context().tower_id) + + a = constant_op.constant(1.0) + b = constant_op.constant(2.0) + c = a + b + self.assertEqual(a.device, worker_device + '/' + last_part_device) + self.assertEqual(b.device, worker_device + '/' + last_part_device) + self.assertEqual(c.device, worker_device + '/' + last_part_device) + + # The device scope is ignored for variables but not for normal ops. + with ops.device('/job:worker/task:0'): + x = variable_scope.get_variable('x', initializer=10.0) + x_add = x.assign_add(c) + e = a + c + # The variable x is on the task 1 since the device_function has been + # called once before the model_fn. + self.assertEqual(x.device, '/job:ps/task:1') + self.assertEqual(x_add.device, x.device) + self.assertEqual(e.device, + '/job:worker/replica:0/task:0/%s' % last_part_device) + + # The colocate_vars_with can override the distribution's device. + with d.colocate_vars_with(x): + y = variable_scope.get_variable('y', initializer=20.0) + y_add = y.assign_add(x_add) + self.assertEqual(y.device, '/job:ps/task:1') + self.assertEqual(y_add.device, y.device) + self.assertEqual(y.device, x.device) + + z = variable_scope.get_variable('z', initializer=10.0) + self.assertEqual(z.device, '/job:ps/task:0') + self.assertNotEqual(z.device, x.device) + + with ops.control_dependencies([y_add]): + z_add = z.assign_add(y) + with ops.control_dependencies([z_add]): + f = z + c + self.assertEqual(f.device, worker_device + '/' + last_part_device) + + # The device scope would merge with the default worker device. + with ops.device('/CPU:1'): + g = e + 1.0 + self.assertEqual(g.device, worker_device + '/device:CPU:1') + + # Ths ops.colocate_with will be ignored when defining a variale but not + # for a normal tensor. + with ops.colocate_with(x): + u = variable_scope.get_variable('u', initializer=30.0) + v = variable_scope.get_variable('v', initializer=30.0) + h = f + 1.0 + self.assertIn('/job:ps/', u.device) + self.assertIn('/job:ps/', v.device) + # u and v are on different parameter servers. + self.assertTrue(u.device != x.device or v.device != x.device) + self.assertTrue(u.device == x.device or v.device == x.device) + # Here h is not on one worker. Note h.device is canonical while x.device + # is not but. + self.assertIn('/job:ps/', h.device) + return y_add, z_add, f + + y, z, f = d.call_for_each_tower(model_fn) + self.assertNotEqual(y, None) + self.assertNotEqual(z, None) + self.assertNotEqual(f, None) + + if context.num_gpus() >= 1 and num_gpus <= 1: + variables.global_variables_initializer().run() + y_val, z_val, f_val = sess.run([y, z, f]) + self.assertEqual(y_val, 33.0) + self.assertEqual(z_val, 43.0) + self.assertEqual(f_val, 46.0) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testDeviceAssignmentDistributed(self, num_gpus): + self._test_device_assignment_distributed('worker', 1, num_gpus) + + def _test_device_assignment_local(self, + d, + compute_device='CPU', + variable_device='CPU', + num_gpus=0): + with ops.Graph().as_default(), \ + self.test_session(target=self._workers[0].target) as sess, \ + d.scope(): + + def model_fn(): + if 'CPU' in compute_device: + tower_compute_device = '/device:CPU:0' + else: + tower_compute_device = ( + '/device:GPU:%d' % distribute_lib.get_tower_context().tower_id) + tower_compute_device = device_util.canonicalize(tower_compute_device) + + if 'CPU' in variable_device: + tower_variable_device = '/device:CPU:0' + else: + tower_variable_device = ( + '/device:GPU:%d' % distribute_lib.get_tower_context().tower_id) + tower_variable_device = device_util.canonicalize(tower_variable_device) + + a = constant_op.constant(1.0) + b = constant_op.constant(2.0) + c = a + b + self.assertEqual(a.device, tower_compute_device) + self.assertEqual(b.device, tower_compute_device) + self.assertEqual(c.device, tower_compute_device) + + # The device scope is ignored for variables but not for normal ops. + with ops.device('/device:GPU:2'): + x = variable_scope.get_variable('x', initializer=10.0) + x_add = x.assign_add(c) + e = a + c + self.assertEqual( + device_util.canonicalize(x.device), tower_variable_device) + self.assertEqual(x_add.device, x.device) + self.assertEqual(e.device, device_util.canonicalize('/device:GPU:2')) + + # The colocate_vars_with can override the distribution's device. + with d.colocate_vars_with(x): + y = variable_scope.get_variable('y', initializer=20.0) + y_add = y.assign_add(x_add) + self.assertEqual( + device_util.canonicalize(y.device), tower_variable_device) + self.assertEqual(y_add.device, y.device) + self.assertEqual(y.device, x.device) + + z = variable_scope.get_variable('z', initializer=10.0) + self.assertEqual( + device_util.canonicalize(z.device), tower_variable_device) + + with ops.control_dependencies([y_add]): + z_add = z.assign_add(y) + with ops.control_dependencies([z_add]): + f = z + c + self.assertEqual(f.device, tower_compute_device) + + # The device scope would merge with the default worker device. + with ops.device('/CPU:1'): + g = e + 1.0 + self.assertEqual(g.device, device_util.canonicalize('/device:CPU:1')) + + # Ths ops.colocate_with will be ignored when defining a variale but not + # for a normal tensor. + with ops.colocate_with(x): + u = variable_scope.get_variable('u', initializer=30.0) + h = f + 1.0 + self.assertEqual( + device_util.canonicalize(u.device), tower_variable_device) + self.assertEqual(device_util.canonicalize(x.device), h.device) + return y_add, z_add, f + + y, z, f = d.call_for_each_tower(model_fn) + self.assertNotEqual(y, None) + self.assertNotEqual(z, None) + self.assertNotEqual(f, None) + + if context.num_gpus() >= 1 and num_gpus <= 1: + variables.global_variables_initializer().run() + y_val, z_val, f_val = sess.run([y, z, f]) + self.assertEqual(y_val, 33.0) + self.assertEqual(z_val, 43.0) + self.assertEqual(f_val, 46.0) + + def testDeviceAssignmentLocalCPU(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=0) + self._test_device_assignment_local( + distribution, compute_device='CPU', variable_device='CPU', num_gpus=0) + + def testDeviceAssignmentLocalOneGPU(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=1) + self._test_device_assignment_local( + distribution, compute_device='GPU', variable_device='GPU', num_gpus=1) + + def testDeviceAssignmentLocalTwoGPUs(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_device_assignment_local( + distribution, compute_device='GPU', variable_device='CPU', num_gpus=2) + + def _test_simple_increment(self, task_type, task_id, num_gpus): + d, master_target = self._get_test_objects(task_type, task_id, num_gpus) + if hasattr(d, '_cluster_spec') and d._cluster_spec: + num_workers = len(d._cluster_spec.as_dict().get('worker', + ['dummy_worker'])) + else: + num_workers = 1 + with ops.Graph().as_default(), \ + self.test_session(target=master_target) as sess, \ + d.scope(): + + def model_fn(): + x = variable_scope.get_variable('x', initializer=10.0) + y = variable_scope.get_variable('y', initializer=20.0) + + x_add = x.assign_add(1.0, use_locking=True) + y_add = y.assign_add(1.0, use_locking=True) + + train_op = control_flow_ops.group([x_add, y_add]) + return x, y, train_op + + x, y, train_op = d.call_for_each_tower(model_fn) + train_op = d.group(d.unwrap(train_op)) + + if context.num_gpus() < d._num_gpus_per_worker: + return True + + if task_id == 0: + variables.global_variables_initializer().run() + + # Workers waiting for chief worker's initializing variables. + self._init_condition.acquire() + self._init_reached += 1 + while self._init_reached != num_workers: + self._init_condition.wait() + self._init_condition.notify_all() + self._init_condition.release() + + sess.run(train_op) + + # Wait for other workers to finish training. + self._finish_condition.acquire() + self._finish_reached += 1 + while self._finish_reached != num_workers: + self._finish_condition.wait() + self._finish_condition.notify_all() + self._finish_condition.release() + + x_val, y_val = sess.run([x, y]) + self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_towers) + self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_towers) + return (x_val == 10.0 + 1.0 * num_workers * d.num_towers and + y_val == 20.0 + 1.0 * num_workers * d.num_towers) + + def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): + d, master_target = self._get_test_objects(task_type, task_id, num_gpus) + with ops.Graph().as_default(), \ + self.test_session(target=master_target) as sess, \ + d.scope(): + l = core.Dense(1, use_bias=False) + + def loss_fn(x): + y = array_ops.reshape(l(x), []) - constant_op.constant(1.) + return y * y + + # TODO(yuefengz, apassos): eager.backprop.implicit_grad is not safe for + # multiple graphs (b/111216820). + def grad_fn(x): + loss = loss_fn(x) + var_list = ( + variables.trainable_variables() + ops.get_collection( + ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) + grads = gradients.gradients(loss, var_list) + ret = list(zip(grads, var_list)) + return ret + + def update(v, g): + return v.assign_sub(0.05 * g, use_locking=True) + + one = d.broadcast(constant_op.constant([[1.]])) + + def step(): + """Perform one optimization step.""" + # Run forward & backward to get gradients, variables list. + g_v = d.call_for_each_tower(grad_fn, one) + # Update the variables using the gradients and the update() function. + before_list = [] + after_list = [] + for g, v in g_v: + fetched = d.read_var(v) + before_list.append(fetched) + with ops.control_dependencies([fetched]): + # TODO(yuefengz): support non-Mirrored variable as destinations. + g = d.reduce( + variable_scope.VariableAggregation.SUM, g, destinations=v) + with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + after_list.append(d.read_var(v)) + return before_list, after_list + + before_out, after_out = step() + + if context.num_gpus() < d._num_gpus_per_worker: + return True + + if task_id == 0: + variables.global_variables_initializer().run() + + # Workers waiting for chief worker's initializing variables. + self._init_condition.acquire() + self._init_reached += 1 + while self._init_reached != 3: + self._init_condition.wait() + self._init_condition.notify_all() + self._init_condition.release() + + for i in range(10): + b, a = sess.run((before_out, after_out)) + if i == 0: + before, = b + after, = a + + error_before = abs(before - 1) + error_after = abs(after - 1) + # Error should go down + self.assertLess(error_after, error_before) + return error_after < error_before + + def testSimpleBetweenGraph(self): + self._run_between_graph_clients(self._test_simple_increment, + self._cluster_spec, 0) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testLocalSimpleIncrement(self, num_gpus): + self._test_simple_increment(None, 0, num_gpus) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testMinimizeLossGraph(self, num_gpus): + self._run_between_graph_clients(self._test_minimize_loss_graph, + self._cluster_spec, num_gpus) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index d2fe8b3b1efabf7b35c070a82d01595f3fa51bf9..baed0ebaae8a3f41c55f309d28203b363336dd16 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import optimizer @@ -110,7 +111,8 @@ class DistributionTestBase(test.TestCase): before_list.append(fetched) # control_dependencies irrelevant but harmless in eager execution with ops.control_dependencies([fetched]): - g = d.reduce("sum", g, destinations=v) + g = d.reduce( + variable_scope.VariableAggregation.SUM, g, destinations=v) with ops.control_dependencies(d.unwrap(d.update(v, update, g))): after_list.append(d.read_var(v)) return before_list, after_list @@ -162,7 +164,8 @@ class DistributionTestBase(test.TestCase): fetched = d.read_var(v) before_list.append(fetched) with ops.control_dependencies([fetched]): - g = d.reduce("sum", g, destinations=v) + g = d.reduce( + variable_scope.VariableAggregation.SUM, g, destinations=v) with ops.control_dependencies(d.unwrap(d.update(v, update, g))): after_list.append(d.read_var(v)) return before_list, after_list @@ -184,7 +187,7 @@ class DistributionTestBase(test.TestCase): with d.scope(): map_in = [constant_op.constant(i) for i in range(10)] map_out = d.map(map_in, lambda x, y: x * y, 2) - observed = d.reduce("sum", map_out) + observed = d.reduce(variable_scope.VariableAggregation.SUM, map_out) expected = 90 # 2 * (0 + 1 + ... + 9) self.assertEqual(expected, observed.numpy()) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 1ae12ae98aaecbb0ce46a944d8e61e051627ff51..83af37fc8175d56c8c4b3c75c63862fd07131184 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -21,39 +21,72 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import tpu +from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib from tensorflow.contrib.distribute.python import one_device_strategy from tensorflow.contrib.distribute.python import values from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import tpu +from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib +from tensorflow.contrib.tpu.python.tpu import training_loop +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.training import device_util +from tensorflow.python.training import server_lib from tensorflow.python.util import nest +def get_tpu_system_metadata(tpu_cluster_resolver): + """Retrieves TPU system metadata given a TPUClusterResolver.""" + master = tpu_cluster_resolver.master() + + # pylint: disable=protected-access + cluster_def = (tpu_cluster_resolver.cluster_spec() + or server_lib.ClusterSpec({})).as_cluster_def() + tpu_system_metadata = ( + tpu_system_metadata_lib._query_tpu_system_metadata( + master, + cluster_def=cluster_def, + query_topology=True)) + + return tpu_system_metadata + + class TPUStrategy(one_device_strategy.OneDeviceStrategy): """Experimental TPU distribution strategy implementation.""" - def __init__(self, num_cores_per_host=2): + def __init__(self, tpu_cluster_resolver): + """Initializes the TPUStrategy object. + + Args: + tpu_cluster_resolver: A tf.contrib.cluster_resolver.TPUClusterResolver, + which provides information about the TPU cluster. + """ # TODO(isaprykin): Generalize the defaults. They are currently tailored for # the unit test. - super(TPUStrategy, self).__init__('/cpu:0') - # TODO(isaprykin): Auto-detect number of cores and hosts. - self._num_cores_per_host = num_cores_per_host + super(TPUStrategy, self).__init__('/device:CPU:0') + + self._tpu_cluster_resolver = tpu_cluster_resolver + self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) + # TODO(priyag): This should not be hardcoded here. - self._host = '/task:0/device:CPU:0' + self._host = '/device:CPU:0' def distribute_dataset(self, dataset_fn): # TODO(priyag): Perhaps distribute across cores here. return self._call_dataset_fn(dataset_fn) - # TODO(priyag): Deal with OutOfRange errors. + # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have # a mechanism to infer the outputs of `fn`. Pending b/110550782. def _run_steps_on_dataset(self, fn, iterator, iterations, initial_loop_values=None): - # Enqueue ops + shapes = nest.flatten(iterator.output_shapes) if any([not s.is_fully_defined() for s in shapes]): raise ValueError( @@ -67,7 +100,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): control_deps = [] sharded_inputs = [] with ops.device(self._host): - for _ in range(self._num_cores_per_host): + for _ in range(self.num_towers): # Use control dependencies to ensure a deterministic ordering. with ops.control_dependencies(control_deps): inputs = nest.flatten(iterator.get_next()) @@ -92,58 +125,117 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): [constant_op.constant(0)], parallel_iterations=1) - # Dequeue ops def dequeue_fn(): - dequeued = tpu.infeed_dequeue_tuple(dtypes=types, shapes=shapes) + dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes) return nest.pack_sequence_as(iterator.output_shapes, dequeued) # Wrap `fn` for repeat. if initial_loop_values is None: - initial_loop_values = [] - ctx = values.MultiStepContext(initial_loop_values) + initial_loop_values = {} + initial_loop_values = nest.flatten(initial_loop_values) + ctx = values.MultiStepContext() def run_fn(*args, **kwargs): del args, kwargs fn_result = fn(ctx, dequeue_fn()) - if ctx.last_step_outputs is None: - ctx.last_step_outputs = [] - with ops.control_dependencies([fn_result]): - return array_ops.identity(ctx.last_step_outputs) + flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) + if flat_last_step_outputs: + with ops.control_dependencies([fn_result]): + return [array_ops.identity(f) for f in flat_last_step_outputs] + else: + return fn_result - # Repeat # TODO(sourabhbajaj): The input to while loop should be based on the output # type of the step_fn def iterate_on_tpu(): - return tpu.repeat(iterations, run_fn, [initial_loop_values]) - - # Re-write and distribute computation. - # TODO(sourabhbajaj): Convert the output to PerDevice variable and - # implement support for that in reduce. - last_step_tensor_outputs = tpu.batch_parallel( - iterate_on_tpu, [], num_shards=self._num_cores_per_host) - - # Take index [0] of last_step_tensor_outputs as we wrapped - # initial_loop_values in a list in the `repeat` call. - return (control_flow_ops.group(last_step_tensor_outputs, enqueue_ops), - last_step_tensor_outputs[0], ctx) + return training_loop.repeat(iterations, run_fn, initial_loop_values) + + replicate_inputs = [[]] * self.num_towers + replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs) + ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops) + + # Filter out any ops from the outputs, typically this would be the case + # when there were no tensor outputs. + last_step_tensor_outputs = [x for x in replicate_outputs + if not isinstance(x, ops.Operation)] + + # Outputs are currently of the structure (grouped by device) + # [[output0_device0, output1_device0, output2_device0], + # [output0_device1, output1_device1, output2_device1]] + # Convert this to the following structure instead: (grouped by output) + # [[output0_device0, output0_device1], + # [output1_device0, output1_device1], + # [output2_device0, output2_device1]] + last_step_tensor_outputs = [list(x) for x in zip(*last_step_tensor_outputs)] + + # Convert replicate_outputs to the original dict structure of + # last_step_outputs. + last_step_tensor_outputs_dict = nest.pack_sequence_as( + ctx.last_step_outputs, last_step_tensor_outputs) + + for (name, aggregation) in ctx._last_step_outputs_aggregations.items(): # pylint: disable=protected-access + output = last_step_tensor_outputs_dict[name] + # For outputs that have already been aggregated, take the first value + # from the list as each value should be the same. Else return the full + # list of values. + if aggregation is not variables_lib.VariableAggregation.NONE: + # TODO(priyag): Should this return the element or a list with 1 element + last_step_tensor_outputs_dict[name] = output[0] + ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access + + return ctx def _call_for_each_tower(self, fn, *args, **kwargs): kwargs.pop('run_concurrently', None) with one_device_strategy._OneDeviceTowerContext(self): # pylint: disable=protected-access return fn(*args, **kwargs) - def get_initialization_ops(self): - return [tpu.initialize_system()] - - def get_finalize_ops(self): - return [tpu.shutdown_system()] - - def _reduce(self, method_string, value, destinations): - del destinations # TPU is graph mode only. Rely on implicit Send/Recv. - if method_string == 'mean': - # TODO(jhseu): Revisit once we support model-parallelism. - value *= (1. / self._num_cores_per_host) - return tpu_ops.cross_replica_sum(value) + def initialize(self): + if context.executing_eagerly(): + # TODO(priyag): Add appopriate call here when eager is supported for TPUs. + raise NotImplementedError('Eager mode not supported in TPUStrategy.') + else: + return [tpu.initialize_system()] + + def finalize(self): + if context.executing_eagerly(): + # TODO(priyag): Add appopriate call here when eager is supported for TPUs. + raise NotImplementedError('Eager mode not supported in TPUStrategy.') + else: + return [tpu.shutdown_system()] + + def _reduce(self, aggregation, value, destinations): + graph = ops.get_default_graph() + cf_context = graph._get_control_flow_context() # pylint: disable=protected-access + # If we're inside the ReplicateContext, reduction should be done using + # CrossReplicaSum while outside we can directly use an add_n op. + while cf_context: + if isinstance(cf_context, tpu.TPUReplicateContext): + if aggregation == vs.VariableAggregation.MEAN: + # TODO(jhseu): Revisit once we support model-parallelism. + value *= (1. / self.num_towers) + return tpu_ops.cross_replica_sum(value) + cf_context = cf_context.outer_context + + # Validate that the destination is same as the host device + # Note we don't do this when in replicate context as the reduction is + # performed on the TPU device itself. + devices = cross_tower_ops_lib.get_devices_from(destinations) + if len(devices) == 1: + assert device_util.canonicalize(devices[0]) == device_util.canonicalize( + self._host) + else: + raise ValueError('Multiple devices are not supported for TPUStrategy') + + output = math_ops.add_n(value) + if aggregation == vs.VariableAggregation.MEAN: + return output * (1. / len(value)) + return output + + def _unwrap(self, value): + if isinstance(value, list): + return value + return [value] @property def num_towers(self): - return self._num_cores_per_host + return self._tpu_metadata.num_of_cores_per_host diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 95390041f45a6dc9111454f2318cdff5aff017ed..5fd4c9de696b715c3fb9b8a6ca64923b413a32e9 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -30,10 +30,12 @@ from tensorflow.contrib.distribute.python import prefetching_ops_v2 from tensorflow.python.eager import context from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import saver @@ -77,6 +79,13 @@ class DistributedValues(object): def devices(self): return list(self._index.keys()) + @property + def is_tensor_like(self): + for v in self._index.values(): + if not tensor_util.is_tensor(v): + return False + return True + def __str__(self): return "%s:%s" % (self.__class__.__name__, self._index) @@ -196,11 +205,54 @@ class DistributedVariable(DistributedDelegate): # to the container without introducing a reference cycle. for v in six.itervalues(index): v._distributed_container = weakref.ref(self) # pylint: disable=protected-access + # tf.keras keeps track of variables initialized using this attribute. When + # tf.keras gets the default session, it initializes all uninitialized vars. + # We need to make _keras_initialized a member of DistributedVariable because + # without this it will use `__getattr__` which will delegate to a component + # variable. + self._keras_initialized = False + # Typically, a `DistributedVariable`'s initializer is composed of the + # initializers of the components variables. However, in some cases, such as + # when restoring from a checkpoint, we may set the _initializer_op + # property on the entire `DistributedVariable`. + self._initializer_op = None super(DistributedVariable, self).__init__(index) + def is_initialized(self, name=None): + """Identifies if all the component variables are initialized. + + Args: + name: Name of the final `logical_and` op. + + Returns: + The op that evaluates to True or False depending on if all the + component variables are initialized. + """ + # We have to cast the self._index.values() to a `list` because when we + # use `model_to_estimator` to run tf.keras models, self._index.values() is + # of type `dict_values` and not `list`. + values_list = list(self._index.values()) + result = values_list[0].is_initialized() + # We iterate through the list of values except the last one to allow us to + # name the final `logical_and` op the same name that is passed by the user + # to the `is_initialized` op. For distributed variables, the + # `is_initialized` op is a `logical_and` op. + for v in values_list[1:-1]: + result = math_ops.logical_and(result, v.is_initialized()) + result = math_ops.logical_and(result, values_list[-1].is_initialized(), + name=name) + return result + @property def initializer(self): - return control_flow_ops.group([v.initializer for v in self._index.values()]) + if self._initializer_op: + init_op = self._initializer_op + else: + # return grouped ops of all the var initializations of component values of + # the mirrored variable + init_op = control_flow_ops.group( + [v.initializer for v in self._index.values()]) + return init_op @property def graph(self): @@ -243,6 +295,9 @@ class DistributedVariable(DistributedDelegate): self._primary_var.op.type) return self.get().op + def read_value(self): + return distribute_lib.get_distribution_strategy().read_var(self) + def _should_act_as_resource_variable(self): """Pass resource_variable_ops.is_resource_variable check.""" pass @@ -290,13 +345,13 @@ class MirroredVariable(DistributedVariable, Mirrored, checkpointable.CheckpointableBase): """Holds a map from device to variables whose values are kept in sync.""" - def __init__(self, index, primary_var, aggregation_method=None): + def __init__(self, index, primary_var, aggregation): # Use a weakref to make it easy to map from the contained values # to the container without introducing a reference cycle. for v in six.itervalues(index): v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access self._primary_var = primary_var - self._aggregation_method = aggregation_method + self._aggregation = aggregation super(MirroredVariable, self).__init__(index) # The arguments to update() are automatically unwrapped so the update() @@ -319,34 +374,42 @@ class MirroredVariable(DistributedVariable, Mirrored, return distribute_lib.get_distribution_strategy().update( self, f, *args, **kwargs) else: + _assert_tower_context() # We are calling an assign function on the mirrored variable in tower # context. # We reduce the value we want to assign/add/sub. More details about how we # handle the different use cases can be found in the _reduce method. # We call the function on each of the mirrored variables with the reduced # value. - if not self._aggregation_method: + if self._aggregation == vs.VariableAggregation.NONE: raise ValueError("You must specify an aggregation method to update a " "MirroredVariable in Tower Context.") - def merge_fn(strategy, value): - return strategy.update(self, - f, - strategy.reduce( - method_string=self._aggregation_method, - value=value, - destinations=self)) + def merge_fn(strategy, value, *other_args, **other_kwargs): + return strategy.update( + self, f, + strategy.reduce( + aggregation=self._aggregation, value=value, destinations=self), + *other_args, **other_kwargs) + return distribute_lib.get_tower_context().merge_call(merge_fn, *args, **kwargs) def assign_sub(self, *args, **kwargs): - return self._assign_func(f=state_ops.assign_sub, *args, **kwargs) + assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw) + return self._assign_func(f=assign_sub_fn, *args, **kwargs) def assign_add(self, *args, **kwargs): - return self._assign_func(f=state_ops.assign_add, *args, **kwargs) + assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw) + return self._assign_func(f=assign_add_fn, *args, **kwargs) def assign(self, *args, **kwargs): - return self._assign_func(f=state_ops.assign, *args, **kwargs) + assign_fn = lambda var, *a, **kw: var.assign(*a, **kw) + return self._assign_func(f=assign_fn, *args, **kwargs) + + @property + def aggregation(self): + return self._aggregation def _get_cross_tower(self): device = device_util.canonicalize(device_util.current()) @@ -408,14 +471,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): def restore(self, restored_tensors, restored_shapes): """Restore the same value into all variables.""" tensor, = restored_tensors - # To preserve the sum across save and restore, we have to divide the - # total across all devices when restoring a variable that was summed - # when saving. - if self._tower_local_variable.reduce_method == "sum": - tensor *= 1. / len(self._tower_local_variable.devices) - return control_flow_ops.group([ - _assign_on_device(d, v, tensor) - for d, v in six.iteritems(self._tower_local_variable._index)]) # pylint: disable=protected-access + return self._tower_local_variable.assign(tensor) def _assert_tower_context(): @@ -428,9 +484,9 @@ class TowerLocalVariable(DistributedVariable, PerDevice, checkpointable.CheckpointableBase): """Holds a map from device to variables whose values are reduced on save.""" - def __init__(self, index, primary_var, reduce_method): + def __init__(self, index, primary_var, aggregation): self._primary_var = primary_var - self._reduce_method = reduce_method + self._aggregation = aggregation super(TowerLocalVariable, self).__init__(index) def assign_sub(self, *args, **kwargs): @@ -442,18 +498,29 @@ class TowerLocalVariable(DistributedVariable, PerDevice, return self.get().assign_add(*args, **kwargs) def assign(self, *args, **kwargs): - _assert_tower_context() - return self.get().assign(*args, **kwargs) + if distribute_lib.get_cross_tower_context(): + # To preserve the sum across save and restore, we have to divide the + # total across all devices when restoring a variable that was summed + # when saving. + tensor = args[0] + if self._aggregation == vs.VariableAggregation.SUM: + tensor *= 1. / len(self.devices) + return control_flow_ops.group( + [_assign_on_device(d, v, tensor) + for d, v in six.iteritems(self._index)]) + else: + _assert_tower_context() + return self.get().assign(*args, **kwargs) @property - def reduce_method(self): - return self._reduce_method + def aggregation(self): + return self._aggregation def _get_cross_tower(self): all_components = tuple(self._index.values()) # TODO(josh11b): Use a strategy-specific method. total = math_ops.add_n(all_components) - if self._reduce_method == "mean": + if self._aggregation == vs.VariableAggregation.MEAN: return total * (1./ len(all_components)) return total @@ -868,65 +935,120 @@ class MultiStepContext(object): This context object is useful when running multiple steps at a time using the `run_steps_on_dataset` API. For e.g. it allows the user's step function to - specify which outputs to emit at what frequency. Currently it only supports - capturing output from the last step, but will soon be augmented to support - other use cases such as output each N steps. + specify which outputs to emit at what frequency. Currently it supports + capturing output from the last step, as well as capturing non tensor outputs. + In the future it will be augmented to support other use cases such as output + each N steps. """ - def __init__(self, initial_loop_values=None): + def __init__(self): """Initializes an output context. - Args: - initial_loop_values: Initial values passed to the run steps - while loop. The only purpose is to verify the shapes and types - when the actual output is set. This will be removed once we - automatically infer the output shapes and types (and do not need to - check for user error in specifying them manually). Returns: A context object. """ - self._last_step_outputs = None - self._non_tensor_outputs = None - self._initial_loop_values = initial_loop_values + self._last_step_outputs = {} + self._last_step_outputs_aggregations = {} + self._non_tensor_outputs = {} @property def last_step_outputs(self): - """Return the last step's outputs.""" + """A dictionary consisting of outputs to be captured on last step. + + Keys in the dictionary are names of tensors to be captured, as specified + when `set_last_step_output` is called. + Values in the dictionary are the tensors themselves. If + `set_last_step_output` was called with an `aggregation` for this output, + then the value is the aggregated value. + + Returns: + A dictionary with last step outputs. + """ return self._last_step_outputs - @last_step_outputs.setter - def last_step_outputs(self, outputs): - """Set the last step's outputs.""" - self._verify_structure_shapes_types(outputs, self._initial_loop_values) + def _set_last_step_outputs(self, outputs): + """Replace the entire dictionary of last step outputs.""" + if not isinstance(outputs, dict): + raise ValueError("Need a dictionary to set last_step_outputs.") self._last_step_outputs = outputs + def set_last_step_output(self, name, output, + aggregation=variables_lib.VariableAggregation.NONE): + """Set `output` with `name` to be outputted from the last step. + + Args: + name: String, name to identify the output. Doesn't need to match tensor + name. + output: The tensors that should be outputted with `name`. See below for + actual types supported. + aggregation: Aggregation method to use to aggregate outputs from multiple + towers. Required if `set_last_step_output` is called in a tower context. + Optional in cross_tower_context. + When present, the outputs from all the towers are aggregated using the + current distribution strategy's `reduce` method. Hence, the type of + `output` must be what's supported by the corresponding `reduce` method. + For e.g. if using MirroredStrategy and aggregation is set, output + must be a `PerDevice` value. + The aggregation method is also recorded in a dictionary + `_last_step_outputs_aggregations` for later interpreting of the + outputs as already reduced or not. + + """ + if distribute_lib.get_cross_tower_context(): + self._last_step_outputs_aggregations[name] = aggregation + if aggregation is variables_lib.VariableAggregation.NONE: + self._last_step_outputs[name] = output + else: + distribution = distribute_lib.get_distribution_strategy() + self._last_step_outputs[name] = distribution.reduce( + aggregation, output, destinations="/device:CPU:0") + else: + assert aggregation is not variables_lib.VariableAggregation.NONE + def merge_fn(distribution, value): + self._last_step_outputs[name] = distribution.reduce( + aggregation, value, destinations="/device:CPU:0") + # Setting this inside the `merge_fn` because all towers share the same + # context object, so it's more robust to set it only once (even if all + # the towers are trying to set the same value). + self._last_step_outputs_aggregations[name] = aggregation + distribute_lib.get_tower_context().merge_call(merge_fn, output) + @property def non_tensor_outputs(self): - """Return the non tensor outputs.""" + """A dictionary consisting of any non tensor outputs to be captured.""" return self._non_tensor_outputs - @non_tensor_outputs.setter - def non_tensor_outputs(self, outputs): - """Set any non tensor outputs.""" - self._non_tensor_outputs = outputs - - def _verify_structure_shapes_types(self, left, right): - """Verify that the structure, shapes and types of left are same as right.""" - nest.assert_same_structure(left, right) - flat_left = nest.flatten(left) - flat_right = nest.flatten(right) - assert len(flat_left) == len(flat_right), ( - "Length of left {} and right {} should be same.". - format(len(flat_left), len(flat_right))) - - for o, i in zip(flat_left, flat_right): - # TODO(priyag): Add checks for other types like IndexedSlices. - if isinstance(o, ops.Tensor): - assert isinstance(i, ops.Tensor) - assert o.shape == i.shape, ( - "Shape {} of left {} doesn't match shape {} of right {}.". - format(o.shape, o, i.shape, i)) - assert o.dtype == i.dtype, ( - "Dtype {} of left {} doesn't match dtype {} of right {}.". - format(o.dtype, o, i.dtype, i)) + def set_non_tensor_output(self, name, output): + """Set `output` with `name` to be captured as a non tensor output.""" + if distribute_lib.get_cross_tower_context(): + self._non_tensor_outputs[name] = output + else: + def merge_fn(distribution, value): + # NOTE(priyag): For non tensor outputs, we simply return all the values + # in a list as aggregation doesn't make sense on non tensors. + self._non_tensor_outputs[name] = distribution.unwrap(value) + distribute_lib.get_tower_context().merge_call(merge_fn, output) + + +def value_container(val): + """Returns the container that this per-device `value` belongs to. + Args: + val: A value returned by `call_for_each_tower()` or a variable + created in `scope()`. + + Returns: + A container that `value` belongs to. + If value does not belong to any container (including the case of + container having been destroyed), returns the value itself. + """ + # pylint: disable=protected-access + if (hasattr(val, "_distributed_container") and + # DistributedVariable has _distributed_container defined + # but we don't want to return it. + not isinstance(val, DistributedVariable)): + container = val._distributed_container() + # pylint: disable=protected-access + if container is not None: + return container + return val diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index c5b246e8041500e478478d1bb1527c3fe752b377..91a43d499933c77de846085e0f12abf3064b0499 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -32,6 +32,7 @@ from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops @@ -79,6 +80,30 @@ class DistributedValuesTest(test.TestCase): with self.assertRaises(AssertionError): v = values.DistributedValues({"/device:cpu:0": 42}) + def testIsTensorLike(self): + with context.graph_mode(), \ + ops.Graph().as_default(), \ + ops.device("/device:CPU:0"): + one = constant_op.constant(1) + two = constant_op.constant(2) + v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + self.assertEqual(two, v.get("/device:GPU:0")) + self.assertEqual(one, v.get()) + self.assertTrue(v.is_tensor_like) + self.assertTrue(tensor_util.is_tensor(v)) + + def testIsTensorLikeWithAConstant(self): + with context.graph_mode(), \ + ops.Graph().as_default(), \ + ops.device("/device:CPU:0"): + one = constant_op.constant(1) + two = 2.0 + v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + self.assertEqual(two, v.get("/device:GPU:0")) + self.assertEqual(one, v.get()) + self.assertFalse(v.is_tensor_like) + self.assertFalse(tensor_util.is_tensor(v)) + class DistributedDelegateTest(test.TestCase): @@ -158,7 +183,8 @@ def _make_mirrored(): v.append(variable_scope.get_variable( name=n, initializer=init, use_resource=True)) index[d] = v[-1] - mirrored = values.MirroredVariable(index, v[0]) + mirrored = values.MirroredVariable(index, v[0], + variable_scope.VariableAggregation.SUM) return v, devices, mirrored @@ -277,7 +303,8 @@ class RegroupAndSelectDeviceTest(test.TestCase): v = variable_scope.get_variable( name="v", initializer=1., use_resource=True) index = {d: v} - mirrored = values.MirroredVariable(index, v) + mirrored = values.MirroredVariable(index, v, + variable_scope.VariableAggregation.SUM) result = values.regroup(index) self.assertIs(mirrored, result) @@ -581,7 +608,8 @@ class MirroredVariableTest(test.TestCase): v = variable_scope.get_variable( name="v", initializer=[1.], use_resource=True) index = {"/job:foo/device:CPU:0": v} - mirrored = values.MirroredVariable(index, v) + mirrored = values.MirroredVariable(index, v, + variable_scope.VariableAggregation.MEAN) self.assertEquals(v.name, mirrored.name) self.assertEquals(v.dtype, mirrored.dtype) @@ -716,7 +744,9 @@ class MirroredVariableTest(test.TestCase): with ops.device("/device:GPU:0"): v = variable_scope.get_variable( name="v", initializer=1., use_resource=True) - mirrored = values.MirroredVariable({"/device:GPU:0": v}, v) + mirrored = values.MirroredVariable({ + "/device:GPU:0": v + }, v, variable_scope.VariableAggregation.MEAN) sess.run(variables_lib.global_variables_initializer()) sess.run({"complicated": mirrored}) @@ -746,24 +776,27 @@ class TowerLocalVariableTest(test.TestCase): if context.num_gpus() < 1 and context.executing_eagerly(): self.skipTest("A GPU is not available for this test in eager mode.") - v, tower_local = _make_tower_local("sum") + v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM) self.assertEquals(v[0].name, tower_local.name) self.assertEquals(v[0].dtype, tower_local.dtype) self.assertEquals(v[0].shape, tower_local.shape) - self.assertEquals("sum", tower_local.reduce_method) + self.assertEquals(variable_scope.VariableAggregation.SUM, + tower_local.aggregation) @test_util.run_in_graph_and_eager_modes(config=config) def testVariableOnAnotherDevice(self): v = variable_scope.get_variable( name="v", initializer=[1.], use_resource=True) index = {"/job:foo/device:CPU:0": v} - tower_local = values.TowerLocalVariable(index, v, "mean") + tower_local = values.TowerLocalVariable( + index, v, variable_scope.VariableAggregation.MEAN) self.assertEquals(v.name, tower_local.name) self.assertEquals(v.dtype, tower_local.dtype) self.assertEquals(v.shape, tower_local.shape) - self.assertEquals("mean", tower_local.reduce_method) + self.assertEquals(variable_scope.VariableAggregation.MEAN, + tower_local.aggregation) def _assign_tower_local(self, devices, v, new): for d, var, n in zip(devices, v, new): @@ -789,7 +822,7 @@ class TowerLocalVariableTest(test.TestCase): self.skipTest("A GPU is not available for this test in eager mode.") with self.test_session() as sess: - v, tower_local = _make_tower_local("sum") + v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM) # Overwrite the initial values. self._assign_tower_local(_devices, v, [3., 4.]) @@ -812,7 +845,8 @@ class TowerLocalVariableTest(test.TestCase): self.skipTest("A GPU is not available for this test in eager mode.") with self.test_session() as sess: - v, tower_local = _make_tower_local("mean") + v, tower_local = _make_tower_local( + variable_scope.VariableAggregation.MEAN) # Overwrite the initial values. self._assign_tower_local(_devices, v, [3., 4.]) @@ -831,7 +865,8 @@ class TowerLocalVariableTest(test.TestCase): def _save_tower_local_mean(self): """Save variables with mirroring, returns save_path.""" with self.test_session(graph=ops.Graph()) as sess: - v, tower_local = _make_tower_local("mean") + v, tower_local = _make_tower_local( + variable_scope.VariableAggregation.MEAN) # Overwrite the initial values. self._assign_tower_local(_devices, v, [3., 4.]) @@ -893,7 +928,8 @@ class TowerLocalVariableTest(test.TestCase): def _restore_tower_local_mean(self, save_path): """Restore to variables with mirroring in a fresh graph.""" with self.test_session(graph=ops.Graph()) as sess: - v, tower_local = _make_tower_local("mean") + v, tower_local = _make_tower_local( + variable_scope.VariableAggregation.MEAN) # Overwrite the initial values. self._assign_tower_local(_devices, v, [7., 8.]) @@ -907,7 +943,7 @@ class TowerLocalVariableTest(test.TestCase): def _restore_tower_local_sum(self, save_path): """Restore to variables with mirroring in a fresh graph.""" with self.test_session(graph=ops.Graph()) as sess: - v, tower_local = _make_tower_local("sum") + v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM) # Overwrite the initial values. self._assign_tower_local(_devices, v, [7., 8.]) @@ -968,7 +1004,7 @@ class TowerLocalVariableTest(test.TestCase): def testTensorConversion(self): with context.graph_mode(): - _, tower_local = _make_tower_local("sum") + _, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM) converted = ops.internal_convert_to_tensor(tower_local, as_ref=False) self.assertIsInstance(converted, ops.Tensor) self.assertEqual(converted.dtype, tower_local.dtype) diff --git a/tensorflow/contrib/distribute/python/warm_starting_util_test.py b/tensorflow/contrib/distribute/python/warm_starting_util_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d8bacdb338d93a169a26a55d8ee5f5f9f0d59fce --- /dev/null +++ b/tensorflow/contrib/distribute/python/warm_starting_util_test.py @@ -0,0 +1,97 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for warm_starting_util with Distribution Strategy. + +These tests are located here instead of as part of `WarmStartingUtilTest` +because they need access to distribution strategies which are only present in +contrib right now. +TODO(priyag): Move the tests to core `WarmStartingUtilTest` when distribution +strategy moves out of contrib. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.python.framework import ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import warm_starting_util as ws_util + + +class WarmStartingUtilWithDistributionStrategyTest( + test.TestCase, parameterized.TestCase): + + @combinations.generate(combinations.combine( + distribution=[combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus], + save_with_distribution=[True, False], + restore_with_distribution=[True, False], + mode=["graph"])) + def testWarmStart(self, distribution, save_with_distribution, + restore_with_distribution): + + var_name = "v" + original_value = [[1., 2.], [3., 4.]] + + # Create variable and save checkpoint from which to warm-start. + def create_var(g): + with self.test_session(graph=g) as sess: + var = variable_scope.get_variable(var_name, initializer=original_value) + sess.run(variables.global_variables_initializer()) + saver = saver_lib.Saver() + ckpt_prefix = os.path.join(self.get_temp_dir(), "model") + saver.save(sess, ckpt_prefix, global_step=0) + return var, sess.run(var) + + if save_with_distribution: + with ops.Graph().as_default() as g, distribution.scope(): + _, prev_init_val = create_var(g) + else: + with ops.Graph().as_default() as g: + _, prev_init_val = create_var(g) + + # Verify we initialized the values correctly. + self.assertAllEqual(original_value, prev_init_val) + + def warm_start(g): + with self.test_session(graph=g) as sess: + # Initialize with zeros. + var = variable_scope.get_variable( + var_name, initializer=[[0., 0.], [0., 0.]]) + ws_util.warm_start(self.get_temp_dir()) + sess.run(variables.global_variables_initializer()) + # Verify weights were correctly warm-started to previous values. + self.assertAllEqual(original_value, self.evaluate(var)) + + # Warm start in a new graph. + if restore_with_distribution: + with ops.Graph().as_default() as g, distribution.scope(): + warm_start(g) + else: + with ops.Graph().as_default() as g: + warm_start(g) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index ad00d1734dd14ed846522a33d888a5387cb25cc6..a8d0d493abcd7de540799f6b94c3cdb9ce9dafae 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -124,7 +124,7 @@ cuda_py_test( cuda_py_test( name = "conditional_distribution_test", - size = "small", + size = "medium", srcs = [ "python/kernel_tests/conditional_distribution_test.py", "python/kernel_tests/distribution_test.py", diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index b8f2a4b2c731bdaee78692c036fb9f2fba4e3760..296e66f2b24fecf2142066727b5b12ee5cbd0379 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -514,9 +514,8 @@ def masked_autoregressive_default_template( Masked Autoencoder for Distribution Estimation. In _International Conference on Machine Learning_, 2015. https://arxiv.org/abs/1502.03509 """ - - with ops.name_scope(name, "masked_autoregressive_default_template", - values=[log_scale_min_clip, log_scale_max_clip]): + name = name or "masked_autoregressive_default_template" + with ops.name_scope(name, values=[log_scale_min_clip, log_scale_max_clip]): def _fn(x): """MADE parameterized via `masked_autoregressive_default_template`.""" # TODO(b/67594795): Better support of dynamic shape. @@ -552,8 +551,7 @@ def masked_autoregressive_default_template( else _clip_by_value_preserve_grad) log_scale = which_clip(log_scale, log_scale_min_clip, log_scale_max_clip) return shift, log_scale - return template_ops.make_template( - "masked_autoregressive_default_template", _fn) + return template_ops.make_template(name, _fn) @deprecation.deprecated( diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py index ef3bdfa75fcaa8df17db1238ceadadf788601356..18a0f754e6e618f240db109f593a80dec57e200b 100644 --- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py @@ -326,6 +326,21 @@ class QuantizedDistribution(distributions.Distribution): graph_parents=graph_parents, name=name) + @property + def distribution(self): + """Base distribution, p(x).""" + return self._dist + + @property + def low(self): + """Lowest value that quantization returns.""" + return self._low + + @property + def high(self): + """Highest value that quantization returns.""" + return self._high + def _batch_shape_tensor(self): return self.distribution.batch_shape_tensor() @@ -569,8 +584,3 @@ class QuantizedDistribution(distributions.Distribution): dependencies = [distribution_util.assert_integer_form( value, message="value has non-integer components.")] return control_flow_ops.with_dependencies(dependencies, value) - - @property - def distribution(self): - """Base distribution, p(x).""" - return self._dist diff --git a/tensorflow/contrib/distributions/python/ops/sample_stats.py b/tensorflow/contrib/distributions/python/ops/sample_stats.py index f5aaa5cf34abde3ea4d25de1ecf3adaef3f2a770..aa680a92be64cf0f099acd335369f2a1610c5953 100644 --- a/tensorflow/contrib/distributions/python/ops/sample_stats.py +++ b/tensorflow/contrib/distributions/python/ops/sample_stats.py @@ -134,7 +134,7 @@ def auto_correlation( x_len = util.prefer_static_shape(x_rotated)[-1] # TODO(langmore) Investigate whether this zero padding helps or hurts. At - # the moment is is necessary so that all FFT implementations work. + # the moment is necessary so that all FFT implementations work. # Zero pad to the next power of 2 greater than 2 * x_len, which equals # 2**(ceil(Log_2(2 * x_len))). Note: Log_2(X) = Log_e(X) / Log_e(2). x_len_float64 = math_ops.cast(x_len, np.float64) @@ -198,7 +198,7 @@ def auto_correlation( # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]). The # other terms were zeros arising only due to zero padding. # `denominator = (N / 2 - m)` (defined below) is the proper term to - # divide by by to make this an unbiased estimate of the expectation + # divide by to make this an unbiased estimate of the expectation # E[X[n] Conj(X[n - m])]. x_len = math_ops.cast(x_len, dtype.real_dtype) max_lags = math_ops.cast(max_lags, dtype.real_dtype) diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 58c548d798178a2848006cbf301f7d5cb2143f24..16844e0d6885919118adcd3f5a7777eec57b1e9c 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -18,35 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import threading - from tensorflow.contrib.data.python.ops import prefetching_ops from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.eager import context -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.training.saver import BaseSaverBuilder - -_uid_counter = 0 -_uid_lock = threading.Lock() - - -def _generate_shared_name(prefix): - with _uid_lock: - global _uid_counter - uid = _uid_counter - _uid_counter += 1 - return "{}{}".format(prefix, uid) -class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): +class Iterator(iterator_ops.EagerIterator): """An iterator producing tf.Tensor objects from a tf.data.Dataset. NOTE: Unlike the iterator created by the @@ -80,38 +58,18 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): "`tf.contrib.eager.Iterator`. Use `for ... in dataset:` to iterate " "over the dataset instead.") - super(Iterator, self).__init__(dataset) if not context.context().device_spec.device_type: is_remote_device = False else: is_remote_device = context.context().device_spec.device_type != "CPU" - self._buffer_resource_handle = None if is_remote_device: - with ops.device("/device:CPU:0"): - iter_string_handle = gen_dataset_ops.iterator_to_string_handle( - self._resource) - - @function.Defun(dtypes.string) - def remote_fn(h): - remote_iterator = iterator_ops.Iterator.from_string_handle( - h, self.output_types, self.output_shapes, self.output_classes) - return remote_iterator.get_next() - - remote_fn.add_to_graph(None) - target = constant_op.constant("/device:CPU:0") - with ops.device(self._device): - self._buffer_resource_handle = prefetching_ops.function_buffering_resource( # pylint: disable=line-too-long - string_arg=iter_string_handle, - output_types=self._flat_output_types, - f=remote_fn, - target_device=target, - buffer_size=10, - container="", - shared_name=_generate_shared_name( - "contrib_eager_iterator_function_buffer_resource")) - self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter( # pylint: disable=line-too-long - handle=self._buffer_resource_handle, - handle_device=self._device) + with ops.device(None): + # Let the placer figure out where to place the various functions etc. + # created by the CopyToDeviceDataset. + dataset = dataset.apply(prefetching_ops.copy_to_device( + context.context().device_name)) + dataset = dataset.prefetch(1) + super(Iterator, self).__init__(dataset) def _next_internal(self): """Returns a nested structure of `tf.Tensor`s containing the next element. @@ -120,40 +78,4 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): # that there is no more data to iterate over. # TODO(b/77291417): Fix with context.execution_mode(context.SYNC): - if self._buffer_resource_handle is not None: - with ops.device(self._device): - ret = prefetching_ops.function_buffering_resource_get_next( - function_buffer_resource=self._buffer_resource_handle, - output_types=self._flat_output_types) - return sparse.deserialize_sparse_tensors( - nest.pack_sequence_as(self._output_types, ret), self._output_types, - self._output_shapes, self._output_classes) - else: - return super(Iterator, self)._next_internal() - - # TODO(shivaniagrawal): Expose checkpointable stateful objects from dataset - # attributes(potential). - - class _Saveable(BaseSaverBuilder.SaveableObject): - """SaveableObject for saving/restoring iterator state.""" - - def __init__(self, iterator_resource, name): - serialized_iterator = gen_dataset_ops.serialize_iterator( - iterator_resource) - specs = [ - BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE") - ] - # pylint: disable=protected-access - super(Iterator._Saveable, self).__init__(iterator_resource, specs, name) - - def restore(self, restored_tensors, restored_shapes): - with ops.colocate_with(self.op): - return gen_dataset_ops.deserialize_iterator(self.op, - restored_tensors[0]) - - def _gather_saveables_for_checkpoint(self): - - def _saveable_factory(name): - return self._Saveable(self._resource, name) - - return {"ITERATOR": _saveable_factory} + return super(Iterator, self)._next_internal() diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index 68bec9aee894edd60a025ac1cf87ca3e010db842..a753d77580758af9de8410de4a08f7ea278c4c79 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -37,6 +37,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops +from tensorflow.python.training import checkpoint_management from tensorflow.python.training.checkpointable import util as checkpointable_utils @@ -193,6 +194,20 @@ class IteratorTest(test.TestCase): x = math_ops.add(x, x) self.assertAllEqual([0., 2.], x.numpy()) + def testGpuTensor(self): + ds = Dataset.from_tensors([0., 1.]) + with ops.device(test.gpu_device_name()): + for x in ds: + y = math_ops.add(x, x) + self.assertAllEqual([0., 2.], y.numpy()) + + def testGpuDefinedDataset(self): + with ops.device(test.gpu_device_name()): + ds = Dataset.from_tensors([0., 1.]) + for x in ds: + y = math_ops.add(x, x) + self.assertAllEqual([0., 2.], y.numpy()) + def testTensorsExplicitPrefetchToDevice(self): ds = Dataset.from_tensor_slices([0., 1.]) ds = ds.apply(prefetching_ops.prefetch_to_device(test.gpu_device_name())) @@ -292,6 +307,19 @@ class IteratorTest(test.TestCase): checkpoint.restore(save_path) self.assertEqual(2, iterator.get_next().numpy()) + def testRestoreInReconstructedIterator(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + dataset = Dataset.range(10) + for i in range(5): + iterator = datasets.Iterator(dataset) + checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + checkpoint.restore(checkpoint_management.latest_checkpoint( + checkpoint_directory)) + for j in range(2): + self.assertEqual(i * 2 + j, iterator.get_next().numpy()) + checkpoint.save(file_prefix=checkpoint_prefix) + class DatasetConstructorBenchmark(test.Benchmark): diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index 12155a459c29c353c57679c407e7dda25047a35c..6f02c90368d966b8cf8d0dee09f9d2a5013c90c1 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -15,8 +15,6 @@ py_library( "//tensorflow/contrib/eager/python/examples/revnet:config", "//tensorflow/contrib/eager/python/examples/rnn_colorbot", "//tensorflow/contrib/eager/python/examples/rnn_ptb", - "//tensorflow/contrib/eager/python/examples/sagan", - "//tensorflow/contrib/eager/python/examples/sagan:config", "//tensorflow/contrib/eager/python/examples/spinn:data", ], ) diff --git a/tensorflow/contrib/eager/python/examples/densenet/BUILD b/tensorflow/contrib/eager/python/examples/densenet/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..2dc196f550a10367066730f6f042c4ed69533ec3 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/densenet/BUILD @@ -0,0 +1,48 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +py_binary( + name = "densenet", + srcs = ["densenet.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + ], +) + +cuda_py_test( + name = "densenet_test", + size = "large", + srcs = ["densenet_test.py"], + additional_deps = [ + ":densenet", + "//tensorflow/contrib/eager/python:tfe", + "//tensorflow:tensorflow_py", + ], + tags = [ + "no_pip", + "optonly", + ], +) + +cuda_py_test( + name = "densenet_graph_test", + size = "large", + srcs = ["densenet_graph_test.py"], + additional_deps = [ + ":densenet", + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + ], + tags = [ + "no_pip", + "noasan", + "nomsan", + "notsan", + "optonly", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet.py b/tensorflow/contrib/eager/python/examples/densenet/densenet.py new file mode 100644 index 0000000000000000000000000000000000000000..6de4e6940094849b5cf6f977e351aef525c77cc2 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/densenet/densenet.py @@ -0,0 +1,296 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Densely Connected Convolutional Networks. + +Reference [ +Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993) + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +l2 = tf.keras.regularizers.l2 + + +class ConvBlock(tf.keras.Model): + """Convolutional Block consisting of (batchnorm->relu->conv). + + Arguments: + num_filters: number of filters passed to a convolutional layer. + data_format: "channels_first" or "channels_last" + bottleneck: if True, then a 1x1 Conv is performed followed by 3x3 Conv. + weight_decay: weight decay + dropout_rate: dropout rate. + """ + + def __init__(self, num_filters, data_format, bottleneck, weight_decay=1e-4, + dropout_rate=0): + super(ConvBlock, self).__init__() + self.bottleneck = bottleneck + + axis = -1 if data_format == "channels_last" else 1 + inter_filter = num_filters * 4 + # don't forget to set use_bias=False when using batchnorm + self.conv2 = tf.keras.layers.Conv2D(num_filters, + (3, 3), + padding="same", + use_bias=False, + data_format=data_format, + kernel_initializer="he_normal", + kernel_regularizer=l2(weight_decay)) + self.batchnorm1 = tf.keras.layers.BatchNormalization(axis=axis) + self.dropout = tf.keras.layers.Dropout(dropout_rate) + + if self.bottleneck: + self.conv1 = tf.keras.layers.Conv2D(inter_filter, + (1, 1), + padding="same", + use_bias=False, + data_format=data_format, + kernel_initializer="he_normal", + kernel_regularizer=l2(weight_decay)) + self.batchnorm2 = tf.keras.layers.BatchNormalization(axis=axis) + + def call(self, x, training=True): + output = self.batchnorm1(x, training=training) + + if self.bottleneck: + output = self.conv1(tf.nn.relu(output)) + output = self.batchnorm2(output, training=training) + + output = self.conv2(tf.nn.relu(output)) + output = self.dropout(output, training=training) + + return output + + +class TransitionBlock(tf.keras.Model): + """Transition Block to reduce the number of features. + + Arguments: + num_filters: number of filters passed to a convolutional layer. + data_format: "channels_first" or "channels_last" + weight_decay: weight decay + dropout_rate: dropout rate. + """ + + def __init__(self, num_filters, data_format, + weight_decay=1e-4, dropout_rate=0): + super(TransitionBlock, self).__init__() + axis = -1 if data_format == "channels_last" else 1 + + self.batchnorm = tf.keras.layers.BatchNormalization(axis=axis) + self.conv = tf.keras.layers.Conv2D(num_filters, + (1, 1), + padding="same", + use_bias=False, + data_format=data_format, + kernel_initializer="he_normal", + kernel_regularizer=l2(weight_decay)) + self.avg_pool = tf.keras.layers.AveragePooling2D(data_format=data_format) + + def call(self, x, training=True): + output = self.batchnorm(x, training=training) + output = self.conv(tf.nn.relu(output)) + output = self.avg_pool(output) + return output + + +class DenseBlock(tf.keras.Model): + """Dense Block consisting of ConvBlocks where each block's + output is concatenated with its input. + + Arguments: + num_layers: Number of layers in each block. + growth_rate: number of filters to add per conv block. + data_format: "channels_first" or "channels_last" + bottleneck: boolean, that decides which part of ConvBlock to call. + weight_decay: weight decay + dropout_rate: dropout rate. + """ + + def __init__(self, num_layers, growth_rate, data_format, bottleneck, + weight_decay=1e-4, dropout_rate=0): + super(DenseBlock, self).__init__() + self.num_layers = num_layers + self.axis = -1 if data_format == "channels_last" else 1 + + self.blocks = [] + for _ in range(int(self.num_layers)): + self.blocks.append(ConvBlock(growth_rate, + data_format, + bottleneck, + weight_decay, + dropout_rate)) + + def call(self, x, training=True): + for i in range(int(self.num_layers)): + output = self.blocks[i](x, training=training) + x = tf.concat([x, output], axis=self.axis) + + return x + + +class DenseNet(tf.keras.Model): + """Creating the Densenet Architecture. + + Arguments: + depth_of_model: number of layers in the model. + growth_rate: number of filters to add per conv block. + num_of_blocks: number of dense blocks. + output_classes: number of output classes. + num_layers_in_each_block: number of layers in each block. + If -1, then we calculate this by (depth-3)/4. + If positive integer, then the it is used as the + number of layers per block. + If list or tuple, then this list is used directly. + data_format: "channels_first" or "channels_last" + bottleneck: boolean, to decide which part of conv block to call. + compression: reducing the number of inputs(filters) to the transition block. + weight_decay: weight decay + rate: dropout rate. + pool_initial: If True add a 7x7 conv with stride 2 followed by 3x3 maxpool + else, do a 3x3 conv with stride 1. + include_top: If true, GlobalAveragePooling Layer and Dense layer are + included. + """ + + def __init__(self, depth_of_model, growth_rate, num_of_blocks, + output_classes, num_layers_in_each_block, data_format, + bottleneck=True, compression=0.5, weight_decay=1e-4, + dropout_rate=0, pool_initial=False, include_top=True): + super(DenseNet, self).__init__() + self.depth_of_model = depth_of_model + self.growth_rate = growth_rate + self.num_of_blocks = num_of_blocks + self.output_classes = output_classes + self.num_layers_in_each_block = num_layers_in_each_block + self.data_format = data_format + self.bottleneck = bottleneck + self.compression = compression + self.weight_decay = weight_decay + self.dropout_rate = dropout_rate + self.pool_initial = pool_initial + self.include_top = include_top + + # deciding on number of layers in each block + if isinstance(self.num_layers_in_each_block, list) or isinstance( + self.num_layers_in_each_block, tuple): + self.num_layers_in_each_block = list(self.num_layers_in_each_block) + else: + if self.num_layers_in_each_block == -1: + if self.num_of_blocks != 3: + raise ValueError( + "Number of blocks must be 3 if num_layers_in_each_block is -1") + if (self.depth_of_model - 4) % 3 == 0: + num_layers = (self.depth_of_model - 4) / 3 + if self.bottleneck: + num_layers //= 2 + self.num_layers_in_each_block = [num_layers] * self.num_of_blocks + else: + raise ValueError("Depth must be 3N+4 if num_layer_in_each_block=-1") + else: + self.num_layers_in_each_block = [ + self.num_layers_in_each_block] * self.num_of_blocks + + axis = -1 if self.data_format == "channels_last" else 1 + + # setting the filters and stride of the initial covn layer. + if self.pool_initial: + init_filters = (7, 7) + stride = (2, 2) + else: + init_filters = (3, 3) + stride = (1, 1) + + self.num_filters = 2 * self.growth_rate + + # first conv and pool layer + self.conv1 = tf.keras.layers.Conv2D(self.num_filters, + init_filters, + strides=stride, + padding="same", + use_bias=False, + data_format=self.data_format, + kernel_initializer="he_normal", + kernel_regularizer=l2( + self.weight_decay)) + if self.pool_initial: + self.pool1 = tf.keras.layers.MaxPooling2D(pool_size=(3, 3), + strides=(2, 2), + padding="same", + data_format=self.data_format) + self.batchnorm1 = tf.keras.layers.BatchNormalization(axis=axis) + + self.batchnorm2 = tf.keras.layers.BatchNormalization(axis=axis) + + # last pooling and fc layer + if self.include_top: + self.last_pool = tf.keras.layers.GlobalAveragePooling2D( + data_format=self.data_format) + self.classifier = tf.keras.layers.Dense(self.output_classes) + + # calculating the number of filters after each block + num_filters_after_each_block = [self.num_filters] + for i in range(1, self.num_of_blocks): + temp_num_filters = num_filters_after_each_block[i-1] + ( + self.growth_rate * self.num_layers_in_each_block[i-1]) + # using compression to reduce the number of inputs to the + # transition block + temp_num_filters = int(temp_num_filters * compression) + num_filters_after_each_block.append(temp_num_filters) + + # dense block initialization + self.dense_blocks = [] + self.transition_blocks = [] + for i in range(self.num_of_blocks): + self.dense_blocks.append(DenseBlock(self.num_layers_in_each_block[i], + self.growth_rate, + self.data_format, + self.bottleneck, + self.weight_decay, + self.dropout_rate)) + if i+1 < self.num_of_blocks: + self.transition_blocks.append( + TransitionBlock(num_filters_after_each_block[i+1], + self.data_format, + self.weight_decay, + self.dropout_rate)) + + def call(self, x, training=True): + output = self.conv1(x) + + if self.pool_initial: + output = self.batchnorm1(output, training=training) + output = tf.nn.relu(output) + output = self.pool1(output) + + for i in range(self.num_of_blocks - 1): + output = self.dense_blocks[i](output, training=training) + output = self.transition_blocks[i](output, training=training) + + output = self.dense_blocks[ + self.num_of_blocks - 1](output, training=training) + output = self.batchnorm2(output, training=training) + output = tf.nn.relu(output) + + if self.include_top: + output = self.last_pool(output) + output = self.classifier(output) + + return output diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py b/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4b3cb624bc947a1d1956eff6accb6d4da3bf3b87 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py @@ -0,0 +1,151 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests and Benchmarks for Densenet model under graph execution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.eager.python.examples.densenet import densenet + + +def data_format(): + return 'channels_first' if tf.test.is_gpu_available() else 'channels_last' + + +def image_shape(batch_size): + if data_format() == 'channels_first': + return [batch_size, 3, 224, 224] + return [batch_size, 224, 224, 3] + + +def random_batch(batch_size): + images = np.random.rand(*image_shape(batch_size)).astype(np.float32) + num_classes = 1000 + labels = np.random.randint( + low=0, high=num_classes, size=[batch_size]).astype(np.int32) + one_hot = np.zeros((batch_size, num_classes)).astype(np.float32) + one_hot[np.arange(batch_size), labels] = 1. + return images, one_hot + + +class DensenetGraphTest(tf.test.TestCase): + + def testApply(self): + depth = 7 + growth_rate = 2 + num_blocks = 3 + output_classes = 10 + num_layers_in_each_block = -1 + batch_size = 1 + with tf.Graph().as_default(): + images = tf.placeholder(tf.float32, image_shape(None)) + model = densenet.DenseNet(depth, growth_rate, num_blocks, + output_classes, num_layers_in_each_block, + data_format(), bottleneck=True, compression=0.5, + weight_decay=1e-4, dropout_rate=0, + pool_initial=False, include_top=True) + predictions = model(images, training=False) + + init = tf.global_variables_initializer() + + with tf.Session() as sess: + sess.run(init) + np_images, _ = random_batch(batch_size) + out = sess.run(predictions, feed_dict={images: np_images}) + self.assertAllEqual([batch_size, output_classes], out.shape) + + +class DensenetBenchmark(tf.test.Benchmark): + + def __init__(self): + self.depth = 121 + self.growth_rate = 32 + self.num_blocks = 4 + self.output_classes = 1000 + self.num_layers_in_each_block = [6, 12, 24, 16] + + def _report(self, label, start, num_iters, batch_size): + avg_time = (time.time() - start) / num_iters + dev = 'gpu' if tf.test.is_gpu_available() else 'cpu' + name = 'graph_%s_%s_batch_%d_%s' % (label, dev, batch_size, data_format()) + extras = {'examples_per_sec': batch_size / avg_time} + self.report_benchmark( + iters=num_iters, wall_time=avg_time, name=name, extras=extras) + + def benchmark_graph_apply(self): + with tf.Graph().as_default(): + images = tf.placeholder(tf.float32, image_shape(None)) + model = densenet.DenseNet(self.depth, self.growth_rate, self.num_blocks, + self.output_classes, + self.num_layers_in_each_block, data_format(), + bottleneck=True, compression=0.5, + weight_decay=1e-4, dropout_rate=0, + pool_initial=True, include_top=True) + predictions = model(images, training=False) + + init = tf.global_variables_initializer() + + batch_size = 64 + with tf.Session() as sess: + sess.run(init) + np_images, _ = random_batch(batch_size) + num_burn, num_iters = (3, 30) + for _ in range(num_burn): + sess.run(predictions, feed_dict={images: np_images}) + start = time.time() + for _ in range(num_iters): + sess.run(predictions, feed_dict={images: np_images}) + self._report('apply', start, num_iters, batch_size) + + def benchmark_graph_train(self): + for batch_size in [16, 32, 64]: + with tf.Graph().as_default(): + np_images, np_labels = random_batch(batch_size) + dataset = tf.data.Dataset.from_tensors((np_images, np_labels)).repeat() + (images, labels) = dataset.make_one_shot_iterator().get_next() + + model = densenet.DenseNet(self.depth, self.growth_rate, self.num_blocks, + self.output_classes, + self.num_layers_in_each_block, data_format(), + bottleneck=True, compression=0.5, + weight_decay=1e-4, dropout_rate=0, + pool_initial=True, include_top=True) + logits = model(images, training=True) + cross_ent = tf.losses.softmax_cross_entropy( + logits=logits, onehot_labels=labels) + regularization = tf.add_n(model.losses) + loss = cross_ent + regularization + optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) + train_op = optimizer.minimize(loss) + + init = tf.global_variables_initializer() + with tf.Session() as sess: + sess.run(init) + (num_burn, num_iters) = (5, 10) + for _ in range(num_burn): + sess.run(train_op) + start = time.time() + for _ in range(num_iters): + sess.run(train_op) + self._report('train', start, num_iters, batch_size) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0736ed02b7437240e5da4dd529ad9ba9a5a15042 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py @@ -0,0 +1,350 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests and Benchmarks for Densenet model.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gc +import time +import tensorflow as tf +import tensorflow.contrib.eager as tfe + +from tensorflow.contrib.eager.python.examples.densenet import densenet +from tensorflow.python.client import device_lib + + +class DensenetTest(tf.test.TestCase): + + def test_bottleneck_true(self): + depth = 7 + growth_rate = 2 + num_blocks = 3 + output_classes = 10 + num_layers_in_each_block = -1 + batch_size = 1 + data_format = ('channels_first') if tf.test.is_gpu_available() else ( + 'channels_last') + + model = densenet.DenseNet(depth, growth_rate, num_blocks, + output_classes, num_layers_in_each_block, + data_format, bottleneck=True, compression=0.5, + weight_decay=1e-4, dropout_rate=0, + pool_initial=False, include_top=True) + + if data_format == 'channels_last': + rand_input = tf.random_uniform((batch_size, 32, 32, 3)) + else: + rand_input = tf.random_uniform((batch_size, 3, 32, 32)) + output_shape = model(rand_input).shape + self.assertEqual(output_shape, (batch_size, output_classes)) + + def test_bottleneck_false(self): + depth = 7 + growth_rate = 2 + num_blocks = 3 + output_classes = 10 + num_layers_in_each_block = -1 + batch_size = 1 + data_format = ('channels_first') if tf.test.is_gpu_available() else ( + 'channels_last') + + model = densenet.DenseNet(depth, growth_rate, num_blocks, + output_classes, num_layers_in_each_block, + data_format, bottleneck=False, compression=0.5, + weight_decay=1e-4, dropout_rate=0, + pool_initial=False, include_top=True) + + if data_format == 'channels_last': + rand_input = tf.random_uniform((batch_size, 32, 32, 3)) + else: + rand_input = tf.random_uniform((batch_size, 3, 32, 32)) + output_shape = model(rand_input).shape + self.assertEqual(output_shape, (batch_size, output_classes)) + + def test_pool_initial_true(self): + depth = 7 + growth_rate = 2 + num_blocks = 4 + output_classes = 10 + num_layers_in_each_block = [1, 2, 2, 1] + batch_size = 1 + data_format = ('channels_first') if tf.test.is_gpu_available() else ( + 'channels_last') + + model = densenet.DenseNet(depth, growth_rate, num_blocks, + output_classes, num_layers_in_each_block, + data_format, bottleneck=True, compression=0.5, + weight_decay=1e-4, dropout_rate=0, + pool_initial=True, include_top=True) + + if data_format == 'channels_last': + rand_input = tf.random_uniform((batch_size, 32, 32, 3)) + else: + rand_input = tf.random_uniform((batch_size, 3, 32, 32)) + output_shape = model(rand_input).shape + self.assertEqual(output_shape, (batch_size, output_classes)) + + def test_regularization(self): + if tf.test.is_gpu_available(): + rand_input = tf.random_uniform((10, 3, 32, 32)) + data_format = 'channels_first' + else: + rand_input = tf.random_uniform((10, 32, 32, 3)) + data_format = 'channels_last' + weight_decay = 1e-4 + + conv = tf.keras.layers.Conv2D( + 3, (3, 3), + padding='same', + use_bias=False, + data_format=data_format, + kernel_regularizer=tf.keras.regularizers.l2(weight_decay)) + optimizer = tf.train.GradientDescentOptimizer(0.1) + conv(rand_input) # Initialize the variables in the layer + + def compute_true_l2(vs, wd): + return tf.reduce_sum(tf.square(vs)) * wd + + true_l2 = compute_true_l2(conv.variables, weight_decay) + keras_l2 = tf.add_n(conv.losses) + self.assertAllClose(true_l2, keras_l2) + + with tf.GradientTape() as tape_true, tf.GradientTape() as tape_keras: + loss = tf.reduce_sum(conv(rand_input)) + loss_with_true_l2 = loss + compute_true_l2(conv.variables, weight_decay) + loss_with_keras_l2 = loss + tf.add_n(conv.losses) + + true_grads = tape_true.gradient(loss_with_true_l2, conv.variables) + keras_grads = tape_keras.gradient(loss_with_keras_l2, conv.variables) + self.assertAllClose(true_grads, keras_grads) + + optimizer.apply_gradients(zip(keras_grads, conv.variables)) + keras_l2_after_update = tf.add_n(conv.losses) + self.assertNotAllClose(keras_l2, keras_l2_after_update) + + +def compute_gradients(model, images, labels): + with tf.GradientTape() as tape: + logits = model(images, training=True) + cross_ent = tf.losses.softmax_cross_entropy( + logits=logits, onehot_labels=labels) + regularization = tf.add_n(model.losses) + loss = cross_ent + regularization + tf.contrib.summary.scalar(name='loss', tensor=loss) + return tape.gradient(loss, model.variables) + + +def apply_gradients(model, optimizer, gradients): + optimizer.apply_gradients(zip(gradients, model.variables)) + + +def device_and_data_format(): + return ('/gpu:0', + 'channels_first') if tf.test.is_gpu_available() else ('/cpu:0', + 'channels_last') + + +def random_batch(batch_size, data_format): + shape = (3, 224, 224) if data_format == 'channels_first' else (224, 224, 3) + shape = (batch_size,) + shape + + num_classes = 1000 + images = tf.random_uniform(shape) + labels = tf.random_uniform( + [batch_size], minval=0, maxval=num_classes, dtype=tf.int32) + one_hot = tf.one_hot(labels, num_classes) + + return images, one_hot + + +class MockIterator(object): + + def __init__(self, tensors): + self._tensors = [tf.identity(x) for x in tensors] + + def next(self): + return self._tensors + + +class DensenetBenchmark(tf.test.Benchmark): + + def __init__(self): + self.depth = 121 + self.growth_rate = 32 + self.num_blocks = 4 + self.output_classes = 1000 + self.num_layers_in_each_block = [6, 12, 24, 16] + + def _train_batch_sizes(self): + """Choose batch sizes based on GPU capability.""" + for device in device_lib.list_local_devices(): + if tf.DeviceSpec.from_string(device.name).device_type == 'GPU': + if 'K20' in device.physical_device_desc: + return (16,) + if 'P100' in device.physical_device_desc: + return (16, 32, 64) + + if tf.DeviceSpec.from_string(device.name).device_type == 'TPU': + return (32,) + return (16, 32) + + def _report(self, label, start, num_iters, device, batch_size, data_format): + avg_time = (time.time() - start) / num_iters + dev = tf.DeviceSpec.from_string(device).device_type.lower() + name = '%s_%s_batch_%d_%s' % (label, dev, batch_size, data_format) + extras = {'examples_per_sec': batch_size / avg_time} + self.report_benchmark( + iters=num_iters, wall_time=avg_time, name=name, extras=extras) + + def _force_device_sync(self): + # If this function is called in the context of a non-CPU device + # (e.g., inside a 'with tf.device("/gpu:0")' block) + # then this will force a copy from CPU->NON_CPU_DEVICE->CPU, + # which forces a sync. This is a roundabout way, yes. + tf.constant(1.).cpu() + + def _benchmark_eager_apply(self, label, device_and_format, defun=False, + execution_mode=None, compiled=False): + with tfe.execution_mode(execution_mode): + device, data_format = device_and_format + model = densenet.DenseNet(self.depth, self.growth_rate, self.num_blocks, + self.output_classes, + self.num_layers_in_each_block, data_format, + bottleneck=True, compression=0.5, + weight_decay=1e-4, dropout_rate=0, + pool_initial=True, include_top=True) + if defun: + model.call = tfe.defun(model.call, compiled=compiled) + batch_size = 64 + num_burn = 5 + num_iters = 30 + with tf.device(device): + images, _ = random_batch(batch_size, data_format) + for _ in xrange(num_burn): + model(images, training=False).cpu() + if execution_mode: + tfe.async_wait() + gc.collect() + start = time.time() + for _ in xrange(num_iters): + model(images, training=False).cpu() + if execution_mode: + tfe.async_wait() + self._report(label, start, num_iters, device, batch_size, data_format) + + def benchmark_eager_apply_sync(self): + self._benchmark_eager_apply('eager_apply', device_and_data_format(), + defun=False) + + def benchmark_eager_apply_async(self): + self._benchmark_eager_apply( + 'eager_apply_async', device_and_data_format(), defun=False, + execution_mode=tfe.ASYNC) + + def benchmark_eager_apply_with_defun(self): + self._benchmark_eager_apply('eager_apply_with_defun', + device_and_data_format(), defun=True) + + def _benchmark_eager_train(self, + label, + make_iterator, + device_and_format, + defun=False, + execution_mode=None, + compiled=False): + with tfe.execution_mode(execution_mode): + device, data_format = device_and_format + for batch_size in self._train_batch_sizes(): + (images, labels) = random_batch(batch_size, data_format) + model = densenet.DenseNet(self.depth, self.growth_rate, self.num_blocks, + self.output_classes, + self.num_layers_in_each_block, data_format, + bottleneck=True, compression=0.5, + weight_decay=1e-4, dropout_rate=0, + pool_initial=True, include_top=True) + optimizer = tf.train.GradientDescentOptimizer(0.1) + apply_grads = apply_gradients + if defun: + model.call = tfe.defun(model.call, compiled=compiled) + apply_grads = tfe.defun(apply_gradients, compiled=compiled) + + num_burn = 3 + num_iters = 10 + with tf.device(device): + iterator = make_iterator((images, labels)) + for _ in xrange(num_burn): + (images, labels) = iterator.next() + apply_grads(model, optimizer, + compute_gradients(model, images, labels)) + if execution_mode: + tfe.async_wait() + self._force_device_sync() + gc.collect() + + start = time.time() + for _ in xrange(num_iters): + (images, labels) = iterator.next() + apply_grads(model, optimizer, + compute_gradients(model, images, labels)) + if execution_mode: + tfe.async_wait() + self._force_device_sync() + self._report(label, start, num_iters, device, batch_size, data_format) + + def benchmark_eager_train_sync(self): + self._benchmark_eager_train('eager_train', MockIterator, + device_and_data_format(), defun=False) + + def benchmark_eager_train_async(self): + self._benchmark_eager_train( + 'eager_train_async', + MockIterator, + device_and_data_format(), + defun=False, + execution_mode=tfe.ASYNC) + + def benchmark_eager_train_with_defun(self): + self._benchmark_eager_train( + 'eager_train_with_defun', MockIterator, + device_and_data_format(), defun=True) + + def benchmark_eager_train_datasets(self): + + def make_iterator(tensors): + with tf.device('/device:CPU:0'): + ds = tf.data.Dataset.from_tensors(tensors).repeat() + return tfe.Iterator(ds) + + self._benchmark_eager_train( + 'eager_train_dataset', make_iterator, + device_and_data_format(), defun=False) + + def benchmark_eager_train_datasets_with_defun(self): + + def make_iterator(tensors): + with tf.device('/device:CPU:0'): + ds = tf.data.Dataset.from_tensors(tensors).repeat() + return tfe.Iterator(ds) + + self._benchmark_eager_train( + 'eager_train_dataset_with_defun', make_iterator, + device_and_data_format(), defun=True) + + +if __name__ == '__main__': + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py index cc9cf53410f641cc3303b4450e9eaa1301904a64..9a4217929916c258b7e8f2e5b3add2905d20d1da 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py @@ -29,7 +29,6 @@ import time import tensorflow as tf -import tensorflow.contrib.eager as tfe from tensorflow.examples.tutorials.mnist import input_data layers = tf.keras.layers @@ -214,7 +213,7 @@ def train_one_epoch(generator, discriminator, generator_optimizer, total_generator_loss = 0.0 total_discriminator_loss = 0.0 - for (batch_index, images) in enumerate(tfe.Iterator(dataset)): + for (batch_index, images) in enumerate(dataset): with tf.device('/cpu:0'): tf.assign_add(step_counter, 1) @@ -227,7 +226,10 @@ def train_one_epoch(generator, discriminator, generator_optimizer, maxval=1., seed=batch_index) - with tf.GradientTape(persistent=True) as g: + # we can use 2 tapes or a single persistent tape. + # Using two tapes is memory efficient since intermediate tensors can be + # released between the two .gradient() calls below + with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: generated_images = generator(noise) tf.contrib.summary.image( 'generated_images', @@ -243,9 +245,10 @@ def train_one_epoch(generator, discriminator, generator_optimizer, generator_loss_val = generator_loss(discriminator_gen_outputs) total_generator_loss += generator_loss_val - generator_grad = g.gradient(generator_loss_val, generator.variables) - discriminator_grad = g.gradient(discriminator_loss_val, - discriminator.variables) + generator_grad = gen_tape.gradient(generator_loss_val, + generator.variables) + discriminator_grad = disc_tape.gradient(discriminator_loss_val, + discriminator.variables) generator_optimizer.apply_gradients( zip(generator_grad, generator.variables)) @@ -261,7 +264,7 @@ def train_one_epoch(generator, discriminator, generator_optimizer, def main(_): (device, data_format) = ('/gpu:0', 'channels_first') - if FLAGS.no_gpu or tfe.num_gpus() <= 0: + if FLAGS.no_gpu or tf.contrib.eager.num_gpus() <= 0: (device, data_format) = ('/cpu:0', 'channels_last') print('Using device %s, and data format %s.' % (device, data_format)) @@ -287,7 +290,7 @@ def main(_): latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) if latest_cpkt: print('Using latest checkpoint at ' + latest_cpkt) - checkpoint = tfe.Checkpoint(**model_objects) + checkpoint = tf.train.Checkpoint(**model_objects) # Restore variables on creation if a checkpoint exists. checkpoint.restore(latest_cpkt) diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..f91ae374488b735549e5c8fc03a309b14f8ba4be --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb @@ -0,0 +1,634 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "0TD5ZrvEMbhZ" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n", + "\n", + "# Convolutional VAE: An example with tf.keras and eager\n", + "\n", + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ITZuApL56Mny" + }, + "source": [ + "This notebook demonstrates how to generate images of handwritten digits using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager) by training a Variational Autoencoder. (VAE, [[1]](https://arxiv.org/abs/1312.6114), [[2]](https://arxiv.org/abs/1401.4082)).\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "P-JuIu2N_SQf" + }, + "outputs": [], + "source": [ + "# to generate gifs\n", + "!pip install imageio" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "e1_Y75QXJS6h" + }, + "source": [ + "## Import TensorFlow and enable Eager execution" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "YfIk2es3hJEd" + }, + "outputs": [], + "source": [ + "from __future__ import absolute_import, division, print_function\n", + "\n", + "# Import TensorFlow \u003e= 1.9 and enable eager execution\n", + "import tensorflow as tf\n", + "tfe = tf.contrib.eager\n", + "tf.enable_eager_execution()\n", + "\n", + "import os\n", + "import time\n", + "import numpy as np\n", + "import glob\n", + "import matplotlib.pyplot as plt\n", + "import PIL\n", + "import imageio\n", + "from IPython import display" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "iYn4MdZnKCey" + }, + "source": [ + "## Load the MNIST dataset\n", + "Each MNIST image is originally a vector of 784 integers, each of which is between 0-255 and represents the intensity of a pixel. We model each pixel with a Bernoulli distribution in our model, and we statically binarize the dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "a4fYMGxGhrna" + }, + "outputs": [], + "source": [ + "(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "NFC2ghIdiZYE" + }, + "outputs": [], + "source": [ + "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n", + "test_images = test_images.reshape(test_images.shape[0], 28, 28, 1).astype('float32')\n", + "\n", + "# Normalizing the images to the range of [0., 1.]\n", + "train_images /= 255.\n", + "test_images /= 255.\n", + "\n", + "# Binarization\n", + "train_images[train_images \u003e= .5] = 1.\n", + "train_images[train_images \u003c .5] = 0.\n", + "test_images[test_images \u003e= .5] = 1.\n", + "test_images[test_images \u003c .5] = 0." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "S4PIDhoDLbsZ" + }, + "outputs": [], + "source": [ + "TRAIN_BUF = 60000\n", + "BATCH_SIZE = 100\n", + "\n", + "TEST_BUF = 10000" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PIGN6ouoQxt3" + }, + "source": [ + "## Use *tf.data* to create batches and shuffle the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "-yKCCQOoJ7cn" + }, + "outputs": [], + "source": [ + "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(TRAIN_BUF).batch(BATCH_SIZE)\n", + "test_dataset = tf.data.Dataset.from_tensor_slices(test_images).shuffle(TEST_BUF).batch(BATCH_SIZE)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "THY-sZMiQ4UV" + }, + "source": [ + "## Wire up the generative and inference network with *tf.keras.Sequential*\n", + "\n", + "In our VAE example, we use two small ConvNets for the generative and inference network. Since these neural nets are small, we use `tf.keras.Sequential` to simplify our code. Let $x$ and $z$ denote the observation and latent variable respectively in the following descriptions. \n", + "\n", + "### Generative Network\n", + "This defines the generative model which takes a latent encoding as input, and outputs the parameters for a conditional distribution of the observation, i.e. $p(x|z)$. Additionally, we use a unit Gaussian prior $p(z)$ for the latent variable.\n", + "\n", + "### Inference Network\n", + "This defines an approximate posterior distribution $q(z|x)$, which takes as input an observation and outputs a set of parameters for the conditional distribution of the latent representation. In this example, we simply model this distribution as a diagonal Gaussian. In this case, the inference network outputs the mean and log-variance parameters of a factorized Gaussian (log-variance instead of the variance directly is for numerical stability).\n", + "\n", + "### Reparameterization Trick\n", + "During optimization, we can sample from $q(z|x)$ by first sampling from a unit Gaussian, and then multiplying by the standard deviation and adding the mean. This ensures the gradients could pass through the sample to the inference network parameters.\n", + "\n", + "### Network architecture\n", + "For the inference network, we use two convolutional layers followed by a fully-connected layer. In the generative network, we mirror this architecture by using a fully-connected layer followed by three convolution transpose layers (a.k.a. deconvolutional layers in some contexts). Note, it's common practice to avoid using batch normalization when training VAEs, since the additional stochasticity due to using mini-batches may aggravate instability on top of the stochasticity from sampling." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "VGLbvBEmjK0a" + }, + "outputs": [], + "source": [ + "class CVAE(tf.keras.Model):\n", + " def __init__(self, latent_dim):\n", + " super(CVAE, self).__init__()\n", + " self.latent_dim = latent_dim\n", + " self.inference_net = tf.keras.Sequential(\n", + " [\n", + " tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),\n", + " tf.keras.layers.Conv2D(\n", + " filters=32, kernel_size=3, strides=(2, 2), activation=tf.nn.relu),\n", + " tf.keras.layers.Conv2D(\n", + " filters=64, kernel_size=3, strides=(2, 2), activation=tf.nn.relu),\n", + " tf.keras.layers.Flatten(),\n", + " # No activation\n", + " tf.keras.layers.Dense(latent_dim + latent_dim),\n", + " ]\n", + " )\n", + "\n", + " self.generative_net = tf.keras.Sequential(\n", + " [\n", + " tf.keras.layers.InputLayer(input_shape=(latent_dim,)),\n", + " tf.keras.layers.Dense(units=7*7*32, activation=tf.nn.relu),\n", + " tf.keras.layers.Reshape(target_shape=(7, 7, 32)),\n", + " tf.keras.layers.Conv2DTranspose(\n", + " filters=64,\n", + " kernel_size=3,\n", + " strides=(2, 2),\n", + " padding=\"SAME\",\n", + " activation=tf.nn.relu),\n", + " tf.keras.layers.Conv2DTranspose(\n", + " filters=32,\n", + " kernel_size=3,\n", + " strides=(2, 2),\n", + " padding=\"SAME\",\n", + " activation=tf.nn.relu),\n", + " # No activation\n", + " tf.keras.layers.Conv2DTranspose(\n", + " filters=1, kernel_size=3, strides=(1, 1), padding=\"SAME\"),\n", + " ]\n", + " )\n", + "\n", + " def sample(self, eps=None):\n", + " if eps is None:\n", + " eps = tf.random_normal(shape=(100, self.latent_dim))\n", + " return self.decode(eps, apply_sigmoid=True)\n", + "\n", + " def encode(self, x):\n", + " mean, logvar = tf.split(self.inference_net(x), num_or_size_splits=2, axis=1)\n", + " return mean, logvar\n", + "\n", + " def reparameterize(self, mean, logvar):\n", + " eps = tf.random_normal(shape=mean.shape)\n", + " return eps * tf.exp(logvar * .5) + mean\n", + "\n", + " def decode(self, z, apply_sigmoid=False):\n", + " logits = self.generative_net(z)\n", + " if apply_sigmoid:\n", + " probs = tf.sigmoid(logits)\n", + " return probs\n", + "\n", + " return logits" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "0FMYgY_mPfTi" + }, + "source": [ + "## Define the loss function and the optimizer\n", + "\n", + "VAEs train by maximizing the evidence lower bound (ELBO) on the marginal log-likelihood:\n", + "\n", + "$$\\log p(x) \\ge \\text{ELBO} = \\mathbb{E}_{q(z|x)}\\left[\\log \\frac{p(x, z)}{q(z|x)}\\right].$$\n", + "\n", + "In practice, we optimize the single sample Monte Carlo estimate of this expectation:\n", + "\n", + "$$\\log p(x| z) + \\log p(z) - \\log q(z|x),$$\n", + "where $z$ is sampled from $q(z|x)$.\n", + "\n", + "**Note**: we could also analytically compute the KL term, but here we incorporate all three terms in the Monte Carlo estimator for simplicity." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "iWCn_PVdEJZ7" + }, + "outputs": [], + "source": [ + "def log_normal_pdf(sample, mean, logvar, raxis=1):\n", + " log2pi = tf.log(2. * np.pi)\n", + " return tf.reduce_sum(\n", + " -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),\n", + " axis=raxis)\n", + "\n", + "def compute_loss(model, x):\n", + " mean, logvar = model.encode(x)\n", + " z = model.reparameterize(mean, logvar)\n", + " x_logit = model.decode(z)\n", + "\n", + " cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)\n", + " logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])\n", + " logpz = log_normal_pdf(z, 0., 0.)\n", + " logqz_x = log_normal_pdf(z, mean, logvar)\n", + " return -tf.reduce_mean(logpx_z + logpz - logqz_x)\n", + "\n", + "def compute_gradients(model, x):\n", + " with tf.GradientTape() as tape:\n", + " loss = compute_loss(model, x)\n", + " return tape.gradient(loss, model.trainable_variables), loss\n", + "\n", + "optimizer = tf.train.AdamOptimizer(1e-4)\n", + "def apply_gradients(optimizer, gradients, variables, global_step=None):\n", + " optimizer.apply_gradients(zip(gradients, variables), global_step=global_step)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Rw1fkAczTQYh" + }, + "source": [ + "## Training\n", + "\n", + "* We start by iterating over the dataset\n", + "* During each iteration, we pass the image to the encoder to obtain a set of mean and log-variance parameters of the approximate posterior $q(z|x)$\n", + "* We then apply the *reparameterization trick* to sample from $q(z|x)$\n", + "* Finally, we pass the reparameterized samples to the decoder to obtain the logits of the generative distribution $p(x|z)$\n", + "* **Note:** Since we use the dataset loaded by keras with 60k datapoints in the training set and 10k datapoints in the test set, our resulting ELBO on the test set is slightly higher than reported results in the literature which uses dynamic binarization of Larochelle's MNIST.\n", + "\n", + "## Generate Images\n", + "\n", + "* After training, it is time to generate some images\n", + "* We start by sampling a set of latent vectors from the unit Gaussian prior distribution $p(z)$\n", + "* The generator will then convert the latent sample $z$ to logits of the observation, giving a distribution $p(x|z)$\n", + "* Here we plot the probabilities of Bernoulli distributions\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "NS2GWywBbAWo" + }, + "outputs": [], + "source": [ + "epochs = 100\n", + "latent_dim = 50\n", + "num_examples_to_generate = 100\n", + "\n", + "# keeping the random vector constant for generation (prediction) so\n", + "# it will be easier to see the improvement.\n", + "random_vector_for_generation = tf.random_normal(\n", + " shape=[num_examples_to_generate, latent_dim])\n", + "model = CVAE(latent_dim)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "RmdVsmvhPxyy" + }, + "outputs": [], + "source": [ + "def generate_and_save_images(model, epoch, test_input):\n", + " predictions = model.sample(test_input)\n", + " fig = plt.figure(figsize=(10,10))\n", + "\n", + " for i in range(predictions.shape[0]):\n", + " plt.subplot(10, 10, i+1)\n", + " plt.imshow(predictions[i, :, :, 0], cmap='gray')\n", + " plt.axis('off')\n", + "\n", + " # tight_layout minimizes the overlap between 2 sub-plots\n", + " plt.tight_layout()\n", + " plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "2M7LmLtGEMQJ" + }, + "outputs": [], + "source": [ + "generate_and_save_images(model, 0, random_vector_for_generation)\n", + "\n", + "for epoch in range(1, epochs + 1):\n", + " start_time = time.time()\n", + " for train_x in train_dataset:\n", + " gradients, loss = compute_gradients(model, train_x)\n", + " apply_gradients(optimizer, gradients, model.trainable_variables)\n", + " end_time = time.time()\n", + "\n", + " if epoch % 5 == 0:\n", + " loss = tfe.metrics.Mean()\n", + " for test_x in test_dataset.make_one_shot_iterator():\n", + " loss(compute_loss(model, test_x))\n", + " elbo = -loss.result()\n", + " display.clear_output(wait=False)\n", + " print('Epoch: {}, Test set ELBO: {}, '\n", + " 'time elapse for current epoch {}'.format(epoch,\n", + " elbo,\n", + " end_time - start_time))\n", + " generate_and_save_images(\n", + " model, epoch, random_vector_for_generation)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "P4M_vIbUi7c0" + }, + "source": [ + "### Display an image using the epoch number" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "WfO5wCdclHGL" + }, + "outputs": [], + "source": [ + "def display_image(epoch_no):\n", + " plt.figure(figsize=(15,15))\n", + " plt.imshow(np.array(PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))))\n", + " plt.axis('off')" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "5x3q9_Oe5q0A" + }, + "outputs": [], + "source": [ + "display_image(epochs) # Display images" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "NywiH3nL8guF" + }, + "source": [ + "### Generate a GIF of all the saved images." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "IGKQgENQ8lEI" + }, + "outputs": [], + "source": [ + "with imageio.get_writer('cvae.gif', mode='I') as writer:\n", + " filenames = glob.glob('image*.png')\n", + " filenames = sorted(filenames)\n", + " for filename in filenames:\n", + " image = imageio.imread(filename)\n", + " writer.append_data(image)\n", + " # this is a hack to display the gif inside the notebook\n", + " os.system('mv cvae.gif cvae.gif.png')" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "uV0yiKpzNP1b" + }, + "outputs": [], + "source": [ + "display.Image(filename=\"cvae.gif.png\")" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "JGZBy7glUU2O" + }, + "outputs": [], + "source": [ + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "default_view": {}, + "last_runtime": { + "build_target": "//learning/brain/python/client:colab_notebook", + "kind": "private" + }, + "name": "cvae.ipynb", + "private_outputs": true, + "provenance": [ + { + "file_id": "1eb0NOTQapkYs3X0v-zL1x5_LFKgDISnp", + "timestamp": 1527173385672 + } + ], + "toc_visible": true, + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..44ff43a1112e771eb6c91c398286a003e17632e0 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb @@ -0,0 +1,733 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "0TD5ZrvEMbhZ" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n", + "\n", + "# DCGAN: An example with tf.keras and eager\n", + "\n", + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "ITZuApL56Mny" + }, + "source": [ + "This notebook demonstrates how to generate images of handwritten digits using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). To do so, we use Deep Convolutional Generative Adverserial Networks ([DCGAN](https://arxiv.org/pdf/1511.06434.pdf)).\n", + "\n", + "This model takes about ~30 seconds per epoch (using tf.contrib.eager.defun to create graph functions) to train on a single Tesla K80 on Colab, as of July 2018.\n", + "\n", + "Below is the output generated after training the generator and discriminator models for 150 epochs.\n", + "\n", + "![sample output](https://tensorflow.org/images/gan/dcgan.gif)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "u_2z-B3piVsw" + }, + "outputs": [], + "source": [ + "# to generate gifs\n", + "!pip install imageio" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "e1_Y75QXJS6h" + }, + "source": [ + "## Import TensorFlow and enable eager execution" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "YfIk2es3hJEd" + }, + "outputs": [], + "source": [ + "from __future__ import absolute_import, division, print_function\n", + "\n", + "# Import TensorFlow \u003e= 1.9 and enable eager execution\n", + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "\n", + "import os\n", + "import time\n", + "import numpy as np\n", + "import glob\n", + "import matplotlib.pyplot as plt\n", + "import PIL\n", + "import imageio\n", + "from IPython import display" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "iYn4MdZnKCey" + }, + "source": [ + "## Load the dataset\n", + "\n", + "We are going to use the MNIST dataset to train the generator and the discriminator. The generator will then generate handwritten digits." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "a4fYMGxGhrna" + }, + "outputs": [], + "source": [ + "(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "NFC2ghIdiZYE" + }, + "outputs": [], + "source": [ + "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n", + "# We are normalizing the images to the range of [-1, 1]\n", + "train_images = (train_images - 127.5) / 127.5" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "S4PIDhoDLbsZ" + }, + "outputs": [], + "source": [ + "BUFFER_SIZE = 60000\n", + "BATCH_SIZE = 256" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PIGN6ouoQxt3" + }, + "source": [ + "## Use tf.data to create batches and shuffle the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "-yKCCQOoJ7cn" + }, + "outputs": [], + "source": [ + "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "THY-sZMiQ4UV" + }, + "source": [ + "## Write the generator and discriminator models\n", + "\n", + "* **Generator** \n", + " * It is responsible for **creating convincing images that are good enough to fool the discriminator**.\n", + " * It consists of Conv2DTranspose (Upsampling) layers. We start with a fully connected layer and upsample the image 2 times so as to reach the desired image size (mnist image size) which is (28, 28, 1). \n", + " * We use **leaky relu** activation except for the **last layer** which uses **tanh** activation.\n", + " \n", + "* **Discriminator**\n", + " * **The discriminator is responsible for classifying the fake images from the real images.**\n", + " * In other words, the discriminator is given generated images (from the generator) and the real MNIST images. The job of the discriminator is to classify these images into fake (generated) and real (MNIST images).\n", + " * **Basically the generator should be good enough to fool the discriminator that the generated images are real**." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "VGLbvBEmjK0a" + }, + "outputs": [], + "source": [ + "class Generator(tf.keras.Model):\n", + " def __init__(self):\n", + " super(Generator, self).__init__()\n", + " self.fc1 = tf.keras.layers.Dense(7*7*64, use_bias=False)\n", + " self.batchnorm1 = tf.keras.layers.BatchNormalization()\n", + " \n", + " self.conv1 = tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(1, 1), padding='same', use_bias=False)\n", + " self.batchnorm2 = tf.keras.layers.BatchNormalization()\n", + " \n", + " self.conv2 = tf.keras.layers.Conv2DTranspose(32, (5, 5), strides=(2, 2), padding='same', use_bias=False)\n", + " self.batchnorm3 = tf.keras.layers.BatchNormalization()\n", + " \n", + " self.conv3 = tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False)\n", + "\n", + " def call(self, x, training=True):\n", + " x = self.fc1(x)\n", + " x = self.batchnorm1(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = tf.reshape(x, shape=(-1, 7, 7, 64))\n", + "\n", + " x = self.conv1(x)\n", + " x = self.batchnorm2(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2(x)\n", + " x = self.batchnorm3(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = tf.nn.tanh(self.conv3(x)) \n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "bkOfJxk5j5Hi" + }, + "outputs": [], + "source": [ + "class Discriminator(tf.keras.Model):\n", + " def __init__(self):\n", + " super(Discriminator, self).__init__()\n", + " self.conv1 = tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')\n", + " self.conv2 = tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')\n", + " self.dropout = tf.keras.layers.Dropout(0.3)\n", + " self.flatten = tf.keras.layers.Flatten()\n", + " self.fc1 = tf.keras.layers.Dense(1)\n", + "\n", + " def call(self, x, training=True):\n", + " x = tf.nn.leaky_relu(self.conv1(x))\n", + " x = self.dropout(x, training=training)\n", + " x = tf.nn.leaky_relu(self.conv2(x))\n", + " x = self.dropout(x, training=training)\n", + " x = self.flatten(x)\n", + " x = self.fc1(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "gDkA05NE6QMs" + }, + "outputs": [], + "source": [ + "generator = Generator()\n", + "discriminator = Discriminator()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "k1HpMSLImuRi" + }, + "outputs": [], + "source": [ + "# Defun gives 10 secs/epoch performance boost\n", + "generator.call = tf.contrib.eager.defun(generator.call)\n", + "discriminator.call = tf.contrib.eager.defun(discriminator.call)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "0FMYgY_mPfTi" + }, + "source": [ + "## Define the loss functions and the optimizer\n", + "\n", + "* **Discriminator loss**\n", + " * The discriminator loss function takes 2 inputs; **real images, generated images**\n", + " * real_loss is a sigmoid cross entropy loss of the **real images** and an **array of ones (since these are the real images)**\n", + " * generated_loss is a sigmoid cross entropy loss of the **generated images** and an **array of zeros (since these are the fake images)**\n", + " * Then the total_loss is the sum of real_loss and the generated_loss\n", + " \n", + "* **Generator loss**\n", + " * It is a sigmoid cross entropy loss of the generated images and an **array of ones**\n", + " \n", + "\n", + "* The discriminator and the generator optimizers are different since we will train them separately." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "wkMNfBWlT-PV" + }, + "outputs": [], + "source": [ + "def discriminator_loss(real_output, generated_output):\n", + " # [1,1,...,1] with real output since it is true and we want\n", + " # our generated examples to look like it\n", + " real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.ones_like(real_output), logits=real_output)\n", + "\n", + " # [0,0,...,0] with generated images since they are fake\n", + " generated_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.zeros_like(generated_output), logits=generated_output)\n", + "\n", + " total_loss = real_loss + generated_loss\n", + "\n", + " return total_loss" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "90BIcCKcDMxz" + }, + "outputs": [], + "source": [ + "def generator_loss(generated_output):\n", + " return tf.losses.sigmoid_cross_entropy(tf.ones_like(generated_output), generated_output)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "iWCn_PVdEJZ7" + }, + "outputs": [], + "source": [ + "discriminator_optimizer = tf.train.AdamOptimizer(1e-4)\n", + "generator_optimizer = tf.train.AdamOptimizer(1e-4)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Rw1fkAczTQYh" + }, + "source": [ + "## Training\n", + "\n", + "* We start by iterating over the dataset\n", + "* The generator is given **noise as an input** which when passed through the generator model will output a image looking like a handwritten digit\n", + "* The discriminator is given the **real MNIST images as well as the generated images (from the generator)**.\n", + "* Next, we calculate the generator and the discriminator loss.\n", + "* Then, we calculate the gradients of loss with respect to both the generator and the discriminator variables (inputs) and apply those to the optimizer.\n", + "\n", + "## Generate Images\n", + "\n", + "* After training, its time to generate some images!\n", + "* We start by creating noise array as an input to the generator\n", + "* The generator will then convert the noise into handwritten images.\n", + "* Last step is to plot the predictions and **voila!**" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "NS2GWywBbAWo" + }, + "outputs": [], + "source": [ + "EPOCHS = 150\n", + "noise_dim = 100\n", + "num_examples_to_generate = 100\n", + "\n", + "# keeping the random vector constant for generation (prediction) so\n", + "# it will be easier to see the improvement of the gan.\n", + "random_vector_for_generation = tf.random_normal([num_examples_to_generate,\n", + " noise_dim])" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "RmdVsmvhPxyy" + }, + "outputs": [], + "source": [ + "def generate_and_save_images(model, epoch, test_input):\n", + " # make sure the training parameter is set to False because we\n", + " # don't want to train the batchnorm layer when doing inference.\n", + " predictions = model(test_input, training=False)\n", + "\n", + " fig = plt.figure(figsize=(10,10))\n", + " \n", + " for i in range(predictions.shape[0]):\n", + " plt.subplot(10, 10, i+1)\n", + " plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')\n", + " plt.axis('off')\n", + " \n", + " # tight_layout minimizes the overlap between 2 sub-plots\n", + " plt.tight_layout()\n", + " plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "2M7LmLtGEMQJ" + }, + "outputs": [], + "source": [ + "def train(dataset, epochs, noise_dim): \n", + " for epoch in range(epochs):\n", + " start = time.time()\n", + " \n", + " for images in dataset:\n", + " # generating noise from a uniform distribution\n", + " noise = tf.random_normal([BATCH_SIZE, noise_dim])\n", + " \n", + " with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n", + " generated_images = generator(noise, training=True)\n", + " \n", + " real_output = discriminator(images, training=True)\n", + " generated_output = discriminator(generated_images, training=True)\n", + " \n", + " gen_loss = generator_loss(generated_output)\n", + " disc_loss = discriminator_loss(real_output, generated_output)\n", + " \n", + " gradients_of_generator = gen_tape.gradient(gen_loss, generator.variables)\n", + " gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.variables)\n", + " \n", + " generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.variables))\n", + " discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.variables))\n", + "\n", + " \n", + " if epoch % 10 == 0:\n", + " display.clear_output(wait=True)\n", + " generate_and_save_images(generator,\n", + " epoch + 1,\n", + " random_vector_for_generation)\n", + "\n", + " print ('Time taken for epoch {} is {} sec'.format(epoch + 1,\n", + " time.time()-start))\n", + " # generating after the final epoch\n", + " generate_and_save_images(generator,\n", + " epochs,\n", + " random_vector_for_generation)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "Ly3UN0SLLY2l" + }, + "outputs": [], + "source": [ + "train(train_dataset, EPOCHS, noise_dim)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "P4M_vIbUi7c0" + }, + "source": [ + "# Display an image using the epoch number" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "WfO5wCdclHGL" + }, + "outputs": [], + "source": [ + "def display_image(epoch_no):\n", + " plt.figure(figsize=(15,15))\n", + " plt.imshow(np.array(PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))))\n", + " plt.axis('off')" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "5x3q9_Oe5q0A" + }, + "outputs": [], + "source": [ + "display_image(EPOCHS)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "NywiH3nL8guF" + }, + "source": [ + "## Generate a GIF of all the saved images." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "xmO0Dmu2WICn" + }, + "source": [ + "\u003c!-- TODO(markdaoust): Remove the hack when Ipython version is updated --\u003e\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "IGKQgENQ8lEI" + }, + "outputs": [], + "source": [ + "with imageio.get_writer('dcgan.gif', mode='I') as writer:\n", + " filenames = glob.glob('image*.png')\n", + " filenames = sorted(filenames)\n", + " for filename in filenames:\n", + " image = imageio.imread(filename)\n", + " writer.append_data(image)\n", + " # this is a hack to display the gif inside the notebook\n", + " os.system('mv dcgan.gif dcgan.gif.png')" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "uV0yiKpzNP1b" + }, + "outputs": [], + "source": [ + "display.Image(filename=\"dcgan.gif.png\")" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "4UJjSnIMOzOJ" + }, + "outputs": [], + "source": [ + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "dcgan.ipynb", + "private_outputs": true, + "provenance": [ + { + "file_id": "1eb0NOTQapkYs3X0v-zL1x5_LFKgDISnp", + "timestamp": 1527173385672 + } + ], + "toc_visible": true, + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..1a5a186e7a3e456cc43f8091370d3eeb795d5e0e --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb @@ -0,0 +1,1184 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "image_captioning_with_attention.ipynb", + "version": "0.3.2", + "views": {}, + "default_view": {}, + "provenance": [ + { + "file_id": "1HI8OK2sMjcx9CTWVn0122QAHOuXaOaMg", + "timestamp": 1530222436922 + } + ], + "private_outputs": true, + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "metadata": { + "id": "K2s1A9eLRPEj", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n" + ] + }, + { + "metadata": { + "id": "Cffg2i257iMS", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Image Captioning with Attention\n", + "\n", + "
\n", + "\n", + " Run in Google Colab \n", + "\n", + "View source on GitHub
" + ] + }, + { + "metadata": { + "id": "QASbY_HGo4Lq", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Image captioning is the task of generating a caption for an image. Given an image like this:\n", + "\n", + "![Man Surfing](https://tensorflow.org/images/surf.jpg) \n", + "\n", + "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n", + "\n", + "Our goal is generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n", + "\n", + "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n", + "\n", + "This model architecture below is similar to [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044). \n", + "\n", + "The code uses [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager), which you can learn more about in the linked guides.\n", + "\n", + "This notebook is an end-to-end example. If you run it, it will download the [MS-COCO](http://cocodataset.org/#home) dataset, preprocess and cache a subset of the images using Inception V3, train an encoder-decoder model, and use it to generate captions on new images.\n", + "\n", + "The code requires TensorFlow version >=1.9. If you're running this in [Colab]()\n", + "\n", + "In this example, we're training on a relatively small amount of data as an example. On a single P100 GPU, this example will take about ~2 hours to train. We train on the first 30,000 captions (corresponding to about ~20,000 images depending on shuffling, as there are multiple captions per image in the dataset)\n" + ] + }, + { + "metadata": { + "id": "U8l4RJ0XRPEm", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# Import TensorFlow and enable eager execution\n", + "# This code requires TensorFlow version >=1.9\n", + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "\n", + "# We'll generate plots of attention in order to see which parts of an image\n", + "# our model focuses on during captioning\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Scikit-learn includes many helpful utilities\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.utils import shuffle\n", + "\n", + "import re\n", + "import numpy as np\n", + "import os\n", + "import time\n", + "import json\n", + "from glob import glob\n", + "from PIL import Image\n", + "import pickle" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "b6qbGw8MRPE5", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Download and prepare the MS-COCO dataset\n", + "\n", + "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code code below will download and extract the dataset automatically. \n", + "\n", + "**Caution: large download ahead**. We'll use the training set, it's a 13GB file." + ] + }, + { + "metadata": { + "id": "krQuPYTtRPE7", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "annotation_zip = tf.keras.utils.get_file('captions.zip', \n", + " cache_subdir=os.path.abspath('.'),\n", + " origin = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip',\n", + " extract = True)\n", + "annotation_file = os.path.dirname(annotation_zip)+'/annotations/captions_train2014.json'\n", + "\n", + "name_of_zip = 'train2014.zip'\n", + "if not os.path.exists(os.path.abspath('.') + '/' + name_of_zip):\n", + " image_zip = tf.keras.utils.get_file(name_of_zip, \n", + " cache_subdir=os.path.abspath('.'),\n", + " origin = 'http://images.cocodataset.org/zips/train2014.zip',\n", + " extract = True)\n", + " PATH = os.path.dirname(image_zip)+'/train2014/'\n", + "else:\n", + " PATH = os.path.abspath('.')+'/train2014/'" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "aANEzb5WwSzg", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Optionally, limit the size of the training set for faster training\n", + "For this example, we'll select a subset of 30,000 captions and use these and the corresponding images to train our model. As always, captioning quality will improve if you choose to use more data." + ] + }, + { + "metadata": { + "id": "4G3b8x8_RPFD", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# read the json file\n", + "with open(annotation_file, 'r') as f:\n", + " annotations = json.load(f)\n", + "\n", + "# storing the captions and the image name in vectors\n", + "all_captions = []\n", + "all_img_name_vector = []\n", + "\n", + "for annot in annotations['annotations']:\n", + " caption = ' ' + annot['caption'] + ' '\n", + " image_id = annot['image_id']\n", + " full_coco_image_path = PATH + 'COCO_train2014_' + '%012d.jpg' % (image_id)\n", + " \n", + " all_img_name_vector.append(full_coco_image_path)\n", + " all_captions.append(caption)\n", + "\n", + "# shuffling the captions and image_names together\n", + "# setting a random state\n", + "train_captions, img_name_vector = shuffle(all_captions,\n", + " all_img_name_vector,\n", + " random_state=1)\n", + "\n", + "# selecting the first 30000 captions from the shuffled set\n", + "num_examples = 30000\n", + "train_captions = train_captions[:num_examples]\n", + "img_name_vector = img_name_vector[:num_examples]" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "mPBMgK34RPFL", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "len(train_captions), len(all_captions)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "8cSW4u-ORPFQ", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Preprocess the images using InceptionV3\n", + "Next, we will use InceptionV3 (pretrained on Imagenet) to classify each image. We will extract features from the last convolutional layer. \n", + "\n", + "First, we will need to convert the images into the format inceptionV3 expects by:\n", + "* Resizing the image to (299, 299)\n", + "* Using the [preprocess_input](https://www.tensorflow.org/api_docs/python/tf/keras/applications/inception_v3/preprocess_input) method to place the pixels in the range of -1 to 1 (to match the format of the images used to train InceptionV3)." + ] + }, + { + "metadata": { + "id": "zXR0217aRPFR", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def load_image(image_path):\n", + " img = tf.read_file(image_path)\n", + " img = tf.image.decode_jpeg(img, channels=3)\n", + " img = tf.image.resize_images(img, (299, 299))\n", + " img = tf.keras.applications.inception_v3.preprocess_input(img)\n", + " return img, image_path" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "MDvIu4sXRPFV", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Initialize InceptionV3 and load the pretrained Imagenet weights\n", + "\n", + "To do so, we'll create a tf.keras model where the output layer is the last convolutional layer in the InceptionV3 architecture. \n", + "* Each image is forwarded through the network and the vector that we get at the end is stored in a dictionary (image_name --> feature_vector). \n", + "* We use the last convolutional layer because we are using attention in this example. The shape of the output of this layer is ```8x8x2048```. \n", + "* We avoid doing this during training so it does not become a bottleneck. \n", + "* After all the images are passed through the network, we pickle the dictionary and save it to disk." + ] + }, + { + "metadata": { + "id": "RD3vW4SsRPFW", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "image_model = tf.keras.applications.InceptionV3(include_top=False, \n", + " weights='imagenet')\n", + "new_input = image_model.input\n", + "hidden_layer = image_model.layers[-1].output\n", + "\n", + "image_features_extract_model = tf.keras.Model(new_input, hidden_layer)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "rERqlR3WRPGO", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Caching the features extracted from InceptionV3\n", + "\n", + "We will pre-process each image with InceptionV3 and cache the output to disk. Caching the output in RAM would be faster but memory intensive, requiring 8 \\* 8 \\* 2048 floats per image. At the time of writing, this would exceed the memory limitations of Colab (although these may change, an instance appears to have about 12GB of memory currently). \n", + "\n", + "Performance could be improved with a more sophisticated caching strategy (e.g., by sharding the images to reduce random access disk I/O) at the cost of more code.\n", + "\n", + "This will take about 10 minutes to run in Colab with a GPU. If you'd like to see a progress bar, you could: install [tqdm](https://github.com/tqdm/tqdm) (```!pip install tqdm```), then change this line: \n", + "\n", + "```for img, path in image_dataset:``` \n", + "\n", + "to:\n", + "\n", + "```for img, path in tqdm(image_dataset):```." + ] + }, + { + "metadata": { + "id": "Dx_fvbVgRPGQ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# getting the unique images\n", + "encode_train = sorted(set(img_name_vector))\n", + "\n", + "# feel free to change the batch_size according to your system configuration\n", + "image_dataset = tf.data.Dataset.from_tensor_slices(\n", + " encode_train).map(load_image).batch(16)\n", + "\n", + "for img, path in image_dataset:\n", + " batch_features = image_features_extract_model(img)\n", + " batch_features = tf.reshape(batch_features, \n", + " (batch_features.shape[0], -1, batch_features.shape[3]))\n", + "\n", + " for bf, p in zip(batch_features, path):\n", + " path_of_feature = p.numpy().decode(\"utf-8\")\n", + " np.save(path_of_feature, bf.numpy())" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "nyqH3zFwRPFi", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Preprocess and tokenize the captions\n", + "\n", + "* First, we'll tokenize the captions (e.g., by splitting on spaces). This will give us a vocabulary of all the unique words in the data (e.g., \"surfing\", \"football\", etc).\n", + "* Next, we'll limit the vocabulary size to the top 5,000 words to save memory. We'll replace all other words with the token \"UNK\" (for unknown).\n", + "* Finally, we create a word --> index mapping and vice-versa.\n", + "* We will then pad all sequences to the be same length as the longest one. " + ] + }, + { + "metadata": { + "id": "HZfK8RhQRPFj", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# This will find the maximum length of any caption in our dataset\n", + "def calc_max_length(tensor):\n", + " return max(len(t) for t in tensor)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "oJGE34aiRPFo", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# The steps above is a general process of dealing with text processing\n", + "\n", + "# choosing the top 5000 words from the vocabulary\n", + "top_k = 5000\n", + "tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=top_k, \n", + " oov_token=\"\", \n", + " filters='!\"#$%&()*+.,-/:;=?@[\\]^_`{|}~ ')\n", + "tokenizer.fit_on_texts(train_captions)\n", + "train_seqs = tokenizer.texts_to_sequences(train_captions)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "8Q44tNQVRPFt", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "tokenizer.word_index = {key:value for key, value in tokenizer.word_index.items() if value <= top_k}\n", + "# putting token in the word2idx dictionary\n", + "tokenizer.word_index[tokenizer.oov_token] = top_k + 1\n", + "tokenizer.word_index[''] = 0" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "0fpJb5ojRPFv", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# creating the tokenized vectors\n", + "train_seqs = tokenizer.texts_to_sequences(train_captions)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "olQArbgbRPF1", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# creating a reverse mapping (index -> word)\n", + "index_word = {value:key for key, value in tokenizer.word_index.items()}" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "AidglIZVRPF4", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# padding each vector to the max_length of the captions\n", + "# if the max_length parameter is not provided, pad_sequences calculates that automatically\n", + "cap_vector = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "gL0wkttkRPGA", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# calculating the max_length \n", + "# used to store the attention weights\n", + "max_length = calc_max_length(train_seqs)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "M3CD75nDpvTI", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Split the data into training and testing" + ] + }, + { + "metadata": { + "id": "iS7DDMszRPGF", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# Create training and validation sets using 80-20 split\n", + "img_name_train, img_name_val, cap_train, cap_val = train_test_split(img_name_vector, \n", + " cap_vector, \n", + " test_size=0.2, \n", + " random_state=0)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "XmViPkRFRPGH", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "len(img_name_train), len(cap_train), len(img_name_val), len(cap_val)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "uEWM9xrYcg45", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Our images and captions are ready! Next, let's create a tf.data dataset to use for training our model.\n", + "\n" + ] + }, + { + "metadata": { + "id": "Q3TnZ1ToRPGV", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# feel free to change these parameters according to your system's configuration\n", + "\n", + "BATCH_SIZE = 64\n", + "BUFFER_SIZE = 1000\n", + "embedding_dim = 256\n", + "units = 512\n", + "vocab_size = len(tokenizer.word_index)\n", + "# shape of the vector extracted from InceptionV3 is (64, 2048)\n", + "# these two variables represent that\n", + "features_shape = 2048\n", + "attention_features_shape = 64" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "SmZS2N0bXG3T", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# loading the numpy files \n", + "def map_func(img_name, cap):\n", + " img_tensor = np.load(img_name.decode('utf-8')+'.npy')\n", + " return img_tensor, cap" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "FDF_Nm3tRPGZ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "dataset = tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))\n", + "\n", + "# using map to load the numpy files in parallel\n", + "# NOTE: Be sure to set num_parallel_calls to the number of CPU cores you have\n", + "# https://www.tensorflow.org/api_docs/python/tf/py_func\n", + "dataset = dataset.map(lambda item1, item2: tf.py_func(\n", + " map_func, [item1, item2], [tf.float32, tf.int32]), num_parallel_calls=8)\n", + "\n", + "# shuffling and batching\n", + "dataset = dataset.shuffle(BUFFER_SIZE)\n", + "# https://www.tensorflow.org/api_docs/python/tf/contrib/data/batch_and_drop_remainder\n", + "dataset = dataset.batch(BATCH_SIZE)\n", + "dataset = dataset.prefetch(1)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "nrvoDphgRPGd", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Model\n", + "\n", + "Fun fact, the decoder below is identical to the one in the example for [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n", + "\n", + "The model architecture is inspired by the [Show, Attend and Tell](https://arxiv.org/pdf/1502.03044.pdf) paper.\n", + "\n", + "* In this example, we extract the features from the lower convolutional layer of InceptionV3 giving us a vector of shape (8, 8, 2048). \n", + "* We squash that to a shape of (64, 2048).\n", + "* This vector is then passed through the CNN Encoder(which consists of a single Fully connected layer).\n", + "* The RNN(here GRU) attends over the image to predict the next word." + ] + }, + { + "metadata": { + "id": "AAppCGLKRPGd", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def gru(units):\n", + " # If you have a GPU, we recommend using the CuDNNGRU layer (it provides a \n", + " # significant speedup).\n", + " if tf.test.is_gpu_available():\n", + " return tf.keras.layers.CuDNNGRU(units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_initializer='glorot_uniform')\n", + " else:\n", + " return tf.keras.layers.GRU(units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_activation='sigmoid', \n", + " recurrent_initializer='glorot_uniform')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "ja2LFTMSdeV3", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "class BahdanauAttention(tf.keras.Model):\n", + " def __init__(self, units):\n", + " super(BahdanauAttention, self).__init__()\n", + " self.W1 = tf.keras.layers.Dense(units)\n", + " self.W2 = tf.keras.layers.Dense(units)\n", + " self.V = tf.keras.layers.Dense(1)\n", + " \n", + " def call(self, features, hidden):\n", + " # features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)\n", + " \n", + " # hidden shape == (batch_size, hidden_size)\n", + " # hidden_with_time_axis shape == (batch_size, 1, hidden_size)\n", + " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n", + " \n", + " # score shape == (batch_size, 64, hidden_size)\n", + " score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))\n", + " \n", + " # attention_weights shape == (batch_size, 64, 1)\n", + " # we get 1 at the last axis because we are applying score to self.V\n", + " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n", + " \n", + " # context_vector shape after sum == (batch_size, hidden_size)\n", + " context_vector = attention_weights * features\n", + " context_vector = tf.reduce_sum(context_vector, axis=1)\n", + " \n", + " return context_vector, attention_weights" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "AZ7R1RxHRPGf", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "class CNN_Encoder(tf.keras.Model):\n", + " # Since we have already extracted the features and dumped it using pickle\n", + " # This encoder passes those features through a Fully connected layer\n", + " def __init__(self, embedding_dim):\n", + " super(CNN_Encoder, self).__init__()\n", + " # shape after fc == (batch_size, 64, embedding_dim)\n", + " self.fc = tf.keras.layers.Dense(embedding_dim)\n", + " \n", + " def call(self, x):\n", + " x = self.fc(x)\n", + " x = tf.nn.relu(x)\n", + " return x" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "V9UbGQmERPGi", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "class RNN_Decoder(tf.keras.Model):\n", + " def __init__(self, embedding_dim, units, vocab_size):\n", + " super(RNN_Decoder, self).__init__()\n", + " self.units = units\n", + "\n", + " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", + " self.gru = gru(self.units)\n", + " self.fc1 = tf.keras.layers.Dense(self.units)\n", + " self.fc2 = tf.keras.layers.Dense(vocab_size)\n", + " \n", + " self.attention = BahdanauAttention(self.units)\n", + " \n", + " def call(self, x, features, hidden):\n", + " # defining attention as a separate model\n", + " context_vector, attention_weights = self.attention(features, hidden)\n", + " \n", + " # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n", + " x = self.embedding(x)\n", + " \n", + " # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n", + " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", + " \n", + " # passing the concatenated vector to the GRU\n", + " output, state = self.gru(x)\n", + " \n", + " # shape == (batch_size, max_length, hidden_size)\n", + " x = self.fc1(output)\n", + " \n", + " # x shape == (batch_size * max_length, hidden_size)\n", + " x = tf.reshape(x, (-1, x.shape[2]))\n", + " \n", + " # output shape == (batch_size * max_length, vocab)\n", + " x = self.fc2(x)\n", + "\n", + " return x, state, attention_weights\n", + "\n", + " def reset_state(self, batch_size):\n", + " return tf.zeros((batch_size, self.units))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Qs_Sr03wRPGk", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "encoder = CNN_Encoder(embedding_dim)\n", + "decoder = RNN_Decoder(embedding_dim, units, vocab_size)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "-bYN7xA0RPGl", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "optimizer = tf.train.AdamOptimizer()\n", + "\n", + "# We are masking the loss calculated for padding\n", + "def loss_function(real, pred):\n", + " mask = 1 - np.equal(real, 0)\n", + " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", + " return tf.reduce_mean(loss_)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "PHod7t72RPGn", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Training\n", + "\n", + "* We extract the features stored in the respective `.npy` files and then pass those features through the encoder.\n", + "* The encoder output, hidden state(initialized to 0) and the decoder input (which is the start token) is passed to the decoder.\n", + "* The decoder returns the predictions and the decoder hidden state.\n", + "* The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n", + "* Use teacher forcing to decide the next input to the decoder.\n", + "* Teacher forcing is the technique where the target word is passed as the next input to the decoder.\n", + "* The final step is to calculate the gradients and apply it to the optimizer and backpropagate.\n" + ] + }, + { + "metadata": { + "id": "Vt4WZ5mhJE-E", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# adding this in a separate cell because if you run the training cell \n", + "# many times, the loss_plot array will be reset\n", + "loss_plot = []" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "UlA4VIQpRPGo", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "EPOCHS = 20\n", + "\n", + "for epoch in range(EPOCHS):\n", + " start = time.time()\n", + " total_loss = 0\n", + " \n", + " for (batch, (img_tensor, target)) in enumerate(dataset):\n", + " loss = 0\n", + " \n", + " # initializing the hidden state for each batch\n", + " # because the captions are not related from image to image\n", + " hidden = decoder.reset_state(batch_size=target.shape[0])\n", + "\n", + " dec_input = tf.expand_dims([tokenizer.word_index['']] * BATCH_SIZE, 1)\n", + " \n", + " with tf.GradientTape() as tape:\n", + " features = encoder(img_tensor)\n", + " \n", + " for i in range(1, target.shape[1]):\n", + " # passing the features through the decoder\n", + " predictions, hidden, _ = decoder(dec_input, features, hidden)\n", + "\n", + " loss += loss_function(target[:, i], predictions)\n", + " \n", + " # using teacher forcing\n", + " dec_input = tf.expand_dims(target[:, i], 1)\n", + " \n", + " total_loss += (loss / int(target.shape[1]))\n", + " \n", + " variables = encoder.variables + decoder.variables\n", + " \n", + " gradients = tape.gradient(loss, variables) \n", + " \n", + " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n", + " \n", + " if batch % 100 == 0:\n", + " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, \n", + " batch, \n", + " loss.numpy() / int(target.shape[1])))\n", + " # storing the epoch end loss value to plot later\n", + " loss_plot.append(total_loss / len(cap_vector))\n", + " \n", + " print ('Epoch {} Loss {:.6f}'.format(epoch + 1, \n", + " total_loss/len(cap_vector)))\n", + " print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "1Wm83G-ZBPcC", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "plt.plot(loss_plot)\n", + "plt.xlabel('Epochs')\n", + "plt.ylabel('Loss')\n", + "plt.title('Loss Plot')\n", + "plt.show()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "xGvOcLQKghXN", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Caption!\n", + "\n", + "* The evaluate function is similar to the training loop, except we don't use teacher forcing here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n", + "* Stop predicting when the model predicts the end token.\n", + "* And store the attention weights for every time step." + ] + }, + { + "metadata": { + "id": "RCWpDtyNRPGs", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def evaluate(image):\n", + " attention_plot = np.zeros((max_length, attention_features_shape))\n", + "\n", + " hidden = decoder.reset_state(batch_size=1)\n", + "\n", + " temp_input = tf.expand_dims(load_image(image)[0], 0)\n", + " img_tensor_val = image_features_extract_model(temp_input)\n", + " img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1, img_tensor_val.shape[3]))\n", + "\n", + " features = encoder(img_tensor_val)\n", + "\n", + " dec_input = tf.expand_dims([tokenizer.word_index['']], 0)\n", + " result = []\n", + "\n", + " for i in range(max_length):\n", + " predictions, hidden, attention_weights = decoder(dec_input, features, hidden)\n", + "\n", + " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n", + "\n", + " predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n", + " result.append(index_word[predicted_id])\n", + "\n", + " if index_word[predicted_id] == '':\n", + " return result, attention_plot\n", + "\n", + " dec_input = tf.expand_dims([predicted_id], 0)\n", + "\n", + " attention_plot = attention_plot[:len(result), :]\n", + " return result, attention_plot" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "fD_y7PD6RPGt", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def plot_attention(image, result, attention_plot):\n", + " temp_image = np.array(Image.open(image))\n", + "\n", + " fig = plt.figure(figsize=(10, 10))\n", + " \n", + " len_result = len(result)\n", + " for l in range(len_result):\n", + " temp_att = np.resize(attention_plot[l], (8, 8))\n", + " ax = fig.add_subplot(len_result//2, len_result//2, l+1)\n", + " ax.set_title(result[l])\n", + " img = ax.imshow(temp_image)\n", + " ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())\n", + "\n", + " plt.tight_layout()\n", + " plt.show()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "io7ws3ReRPGv", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# captions on the validation set\n", + "rid = np.random.randint(0, len(img_name_val))\n", + "image = img_name_val[rid]\n", + "real_caption = ' '.join([index_word[i] for i in cap_val[rid] if i not in [0]])\n", + "result, attention_plot = evaluate(image)\n", + "\n", + "print ('Real Caption:', real_caption)\n", + "print ('Prediction Caption:', ' '.join(result))\n", + "plot_attention(image, result, attention_plot)\n", + "# opening the image\n", + "Image.open(img_name_val[rid])" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Rprk3HEvZuxb", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Try it on your own images\n", + "For fun, below we've provided a method you can use to caption your own images with the model we've just trained. Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for weird results!)\n" + ] + }, + { + "metadata": { + "id": "9Psd1quzaAWg", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "image_url = 'https://tensorflow.org/images/surf.jpg'\n", + "image_extension = image_url[-4:]\n", + "image_path = tf.keras.utils.get_file('image'+image_extension, \n", + " origin=image_url)\n", + "\n", + "result, attention_plot = evaluate(image_path)\n", + "print ('Prediction Caption:', ' '.join(result))\n", + "plot_attention(image_path, result, attention_plot)\n", + "# opening the image\n", + "Image.open(image_path)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "VJZXyJco6uLO", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Next steps\n", + "\n", + "Congrats! You've just trained an image captioning model with attention. Next, we recommend taking a look at this example [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb). It uses a similar architecture to translate between Spanish and English sentences. You can also experiment with training the code in this notebook on a different dataset." + ] + } + ] +} diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..b173f856c641b4d7dca96adda113f904c97a25a7 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb @@ -0,0 +1,689 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "hcD2nPQvPOFM" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n", + "\n", + "# Text Generation using a RNN\n", + "\n", + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on Github\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BwpJ5IffzRG6" + }, + "source": [ + "This notebook demonstrates how to generate text using an RNN using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). If you like, you can write a similar [model](https://github.com/fchollet/deep-learning-with-python-notebooks/blob/master/8.1-text-generation-with-lstm.ipynb) using less code. Here, we show a lower-level impementation that's useful to understand as prework before diving in to deeper examples in a similar, like [Neural Machine Translation with Attention](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n", + "\n", + "This notebook is an end-to-end example. When you run it, it will download a dataset of Shakespeare's writing. We'll use a collection of plays, borrowed from Andrej Karpathy's excellent [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/). The notebook will train a model, and use it to generate sample output.\n", + " \n", + "Here is the output(with start string='w') after training a single layer GRU for 30 epochs with the default settings below:\n", + "\n", + "```\n", + "were to the death of him\n", + "And nothing of the field in the view of hell,\n", + "When I said, banish him, I will not burn thee that would live.\n", + "\n", + "HENRY BOLINGBROKE:\n", + "My gracious uncle--\n", + "\n", + "DUKE OF YORK:\n", + "As much disgraced to the court, the gods them speak,\n", + "And now in peace himself excuse thee in the world.\n", + "\n", + "HORTENSIO:\n", + "Madam, 'tis not the cause of the counterfeit of the earth,\n", + "And leave me to the sun that set them on the earth\n", + "And leave the world and are revenged for thee.\n", + "\n", + "GLOUCESTER:\n", + "I would they were talking with the very name of means\n", + "To make a puppet of a guest, and therefore, good Grumio,\n", + "Nor arm'd to prison, o' the clouds, of the whole field,\n", + "With the admire\n", + "With the feeding of thy chair, and we have heard it so,\n", + "I thank you, sir, he is a visor friendship with your silly your bed.\n", + "\n", + "SAMPSON:\n", + "I do desire to live, I pray: some stand of the minds, make thee remedies\n", + "With the enemies of my soul.\n", + "\n", + "MENENIUS:\n", + "I'll keep the cause of my mistress.\n", + "\n", + "POLIXENES:\n", + "My brother Marcius!\n", + "\n", + "Second Servant:\n", + "Will't ple\n", + "```\n", + "\n", + "Of course, while some of the sentences are grammatical, most do not make sense. But, consider:\n", + "\n", + "* Our model is character based (when we began training, it did not yet know how to spell a valid English word, or that words were even a unit of text).\n", + "\n", + "* The structure of the output resembles a play (blocks begin with a speaker name, in all caps similar to the original text). Sentences generally end with a period. If you look at the text from a distance (or don't read the invididual words too closely, it appears as if it's an excerpt from a play).\n", + "\n", + "As a next step, you can experiment training the model on a different dataset - any large text file(ASCII) will do, and you can modify a single line of code below to make that change. Have fun!\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "R3p22DBDsaCA" + }, + "source": [ + "## Install unidecode library\n", + "A helpful library to convert unicode to ASCII." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "wZ6LOM12wKGH" + }, + "outputs": [], + "source": [ + "!pip install unidecode" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "WGyKZj3bzf9p" + }, + "source": [ + "## Import tensorflow and enable eager execution." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "yG_n40gFzf9s" + }, + "outputs": [], + "source": [ + "# Import TensorFlow \u003e= 1.9 and enable eager execution\n", + "import tensorflow as tf\n", + "\n", + "# Note: Once you enable eager execution, it cannot be disabled. \n", + "tf.enable_eager_execution()\n", + "\n", + "import numpy as np\n", + "import re\n", + "import random\n", + "import unidecode\n", + "import time" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "EHDoRoc5PKWz" + }, + "source": [ + "## Download the dataset\n", + "\n", + "In this example, we will use the [shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt). You can use any other dataset that you like.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "pD_55cOxLkAb" + }, + "outputs": [], + "source": [ + "path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "UHjdCjDuSvX_" + }, + "source": [ + "## Read the dataset\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "-E5JvY3wzf94" + }, + "outputs": [], + "source": [ + "text = unidecode.unidecode(open(path_to_file).read())\n", + "# length of text is the number of characters in it\n", + "print (len(text))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Il9ww98izf-D" + }, + "source": [ + "Creating dictionaries to map from characters to their indices and vice-versa, which will be used to vectorize the inputs" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "IalZLbvOzf-F" + }, + "outputs": [], + "source": [ + "# unique contains all the unique characters in the file\n", + "unique = sorted(set(text))\n", + "\n", + "# creating a mapping from unique characters to indices\n", + "char2idx = {u:i for i, u in enumerate(unique)}\n", + "idx2char = {i:u for i, u in enumerate(unique)}" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "1v_qUYfAzf-I" + }, + "outputs": [], + "source": [ + "# setting the maximum length sentence we want for a single input in characters\n", + "max_length = 100\n", + "\n", + "# length of the vocabulary in chars\n", + "vocab_size = len(unique)\n", + "\n", + "# the embedding dimension \n", + "embedding_dim = 256\n", + "\n", + "# number of RNN (here GRU) units\n", + "units = 1024\n", + "\n", + "# batch size \n", + "BATCH_SIZE = 64\n", + "\n", + "# buffer size to shuffle our dataset\n", + "BUFFER_SIZE = 10000" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "LFjSVAlWzf-N" + }, + "source": [ + "## Creating the input and output tensors\n", + "\n", + "Vectorizing the input and the target text because our model cannot understand strings only numbers.\n", + "\n", + "But first, we need to create the input and output vectors.\n", + "Remember the max_length we set above, we will use it here. We are creating **max_length** chunks of input, where each input vector is all the characters in that chunk except the last and the target vector is all the characters in that chunk except the first.\n", + "\n", + "For example, consider that the string = 'tensorflow' and the max_length is 9\n", + "\n", + "So, the `input = 'tensorflo'` and `output = 'ensorflow'`\n", + "\n", + "After creating the vectors, we convert each character into numbers using the **char2idx** dictionary we created above." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "0UHJDA39zf-O" + }, + "outputs": [], + "source": [ + "input_text = []\n", + "target_text = []\n", + "\n", + "for f in range(0, len(text)-max_length, max_length):\n", + " inps = text[f:f+max_length]\n", + " targ = text[f+1:f+1+max_length]\n", + "\n", + " input_text.append([char2idx[i] for i in inps])\n", + " target_text.append([char2idx[t] for t in targ])\n", + " \n", + "print (np.array(input_text).shape)\n", + "print (np.array(target_text).shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "MJdfPmdqzf-R" + }, + "source": [ + "## Creating batches and shuffling them using tf.data" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "p2pGotuNzf-S" + }, + "outputs": [], + "source": [ + "dataset = tf.data.Dataset.from_tensor_slices((input_text, target_text)).shuffle(BUFFER_SIZE)\n", + "dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "m8gPwEjRzf-Z" + }, + "source": [ + "## Creating the model\n", + "\n", + "We use the Model Subclassing API which gives us full flexibility to create the model and change it however we like. We use 3 layers to define our model.\n", + "\n", + "* Embedding layer\n", + "* GRU layer (you can use an LSTM layer here)\n", + "* Fully connected layer" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "P3KTiiInzf-a" + }, + "outputs": [], + "source": [ + "class Model(tf.keras.Model):\n", + " def __init__(self, vocab_size, embedding_dim, units, batch_size):\n", + " super(Model, self).__init__()\n", + " self.units = units\n", + " self.batch_sz = batch_size\n", + "\n", + " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", + "\n", + " if tf.test.is_gpu_available():\n", + " self.gru = tf.keras.layers.CuDNNGRU(self.units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_initializer='glorot_uniform')\n", + " else:\n", + " self.gru = tf.keras.layers.GRU(self.units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_activation='sigmoid', \n", + " recurrent_initializer='glorot_uniform')\n", + "\n", + " self.fc = tf.keras.layers.Dense(vocab_size)\n", + " \n", + " def call(self, x, hidden):\n", + " x = self.embedding(x)\n", + "\n", + " # output shape == (batch_size, max_length, hidden_size) \n", + " # states shape == (batch_size, hidden_size)\n", + "\n", + " # states variable to preserve the state of the model\n", + " # this will be used to pass at every step to the model while training\n", + " output, states = self.gru(x, initial_state=hidden)\n", + "\n", + "\n", + " # reshaping the output so that we can pass it to the Dense layer\n", + " # after reshaping the shape is (batch_size * max_length, hidden_size)\n", + " output = tf.reshape(output, (-1, output.shape[2]))\n", + "\n", + " # The dense layer will output predictions for every time_steps(max_length)\n", + " # output shape after the dense layer == (max_length * batch_size, vocab_size)\n", + " x = self.fc(output)\n", + "\n", + " return x, states" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "trpqTWyvk0nr" + }, + "source": [ + "## Call the model and set the optimizer and the loss function" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "7t2XrzEOzf-e" + }, + "outputs": [], + "source": [ + "model = Model(vocab_size, embedding_dim, units, BATCH_SIZE)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "dkjWIATszf-h" + }, + "outputs": [], + "source": [ + "optimizer = tf.train.AdamOptimizer()\n", + "\n", + "# using sparse_softmax_cross_entropy so that we don't have to create one-hot vectors\n", + "def loss_function(real, preds):\n", + " return tf.losses.sparse_softmax_cross_entropy(labels=real, logits=preds)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "lPrP0XMUzf-p" + }, + "source": [ + "## Train the model\n", + "\n", + "Here we will use a custom training loop with the help of GradientTape()\n", + "\n", + "* We initialize the hidden state of the model with zeros and shape == (batch_size, number of rnn units). We do this by calling the function defined while creating the model.\n", + "\n", + "* Next, we iterate over the dataset(batch by batch) and calculate the **predictions and the hidden states** associated with that input.\n", + "\n", + "* There are a lot of interesting things happening here.\n", + " * The model gets hidden state(initialized with 0), lets call that **H0** and the first batch of input, lets call that **I0**.\n", + " * The model then returns the predictions **P1** and **H1**.\n", + " * For the next batch of input, the model receives **I1** and **H1**.\n", + " * The interesting thing here is that we pass **H1** to the model with **I1** which is how the model learns. The context learned from batch to batch is contained in the **hidden state**.\n", + " * We continue doing this until the dataset is exhausted and then we start a new epoch and repeat this.\n", + "\n", + "* After calculating the predictions, we calculate the **loss** using the loss function defined above. Then we calculate the gradients of the loss with respect to the model variables(input)\n", + "\n", + "* Finally, we take a step in that direction with the help of the optimizer using the apply_gradients function.\n", + "\n", + "Note:- If you are running this notebook in Colab which has a **Tesla K80 GPU** it takes about 23 seconds per epoch.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "d4tSNwymzf-q" + }, + "outputs": [], + "source": [ + "# Training step\n", + "\n", + "EPOCHS = 30\n", + "\n", + "for epoch in range(EPOCHS):\n", + " start = time.time()\n", + " \n", + " # initializing the hidden state at the start of every epoch\n", + " hidden = model.reset_states()\n", + " \n", + " for (batch, (inp, target)) in enumerate(dataset):\n", + " with tf.GradientTape() as tape:\n", + " # feeding the hidden state back into the model\n", + " # This is the interesting step\n", + " predictions, hidden = model(inp, hidden)\n", + " \n", + " # reshaping the target because that's how the \n", + " # loss function expects it\n", + " target = tf.reshape(target, (-1,))\n", + " loss = loss_function(target, predictions)\n", + " \n", + " grads = tape.gradient(loss, model.variables)\n", + " optimizer.apply_gradients(zip(grads, model.variables), global_step=tf.train.get_or_create_global_step())\n", + "\n", + " if batch % 100 == 0:\n", + " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch+1,\n", + " batch,\n", + " loss))\n", + " \n", + " print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))\n", + " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "DjGz1tDkzf-u" + }, + "source": [ + "## Predicting using our trained model\n", + "\n", + "The below code block is used to generated the text\n", + "\n", + "* We start by choosing a start string and initializing the hidden state and setting the number of characters we want to generate.\n", + "\n", + "* We get predictions using the start_string and the hidden state\n", + "\n", + "* Then we use a multinomial distribution to calculate the index of the predicted word. **We use this predicted word as our next input to the model**\n", + "\n", + "* **The hidden state returned by the model is fed back into the model so that it now has more context rather than just one word.** After we predict the next word, the modified hidden states are again fed back into the model, which is how it learns as it gets more context from the previously predicted words.\n", + "\n", + "* If you see the predictions, the model knows when to capitalize, make paragraphs and the text follows a shakespeare style of writing which is pretty awesome!" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "WvuwZBX5Ogfd" + }, + "outputs": [], + "source": [ + "# Evaluation step(generating text using the model learned)\n", + "\n", + "# number of characters to generate\n", + "num_generate = 1000\n", + "\n", + "# You can change the start string to experiment\n", + "start_string = 'Q'\n", + "# converting our start string to numbers(vectorizing!) \n", + "input_eval = [char2idx[s] for s in start_string]\n", + "input_eval = tf.expand_dims(input_eval, 0)\n", + "\n", + "# empty string to store our results\n", + "text_generated = ''\n", + "\n", + "# low temperatures results in more predictable text.\n", + "# higher temperatures results in more surprising text\n", + "# experiment to find the best setting\n", + "temperature = 1.0\n", + "\n", + "# hidden state shape == (batch_size, number of rnn units); here batch size == 1\n", + "hidden = [tf.zeros((1, units))]\n", + "for i in range(num_generate):\n", + " predictions, hidden = model(input_eval, hidden)\n", + "\n", + " # using a multinomial distribution to predict the word returned by the model\n", + " predictions = predictions / temperature\n", + " predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n", + " \n", + " # We pass the predicted word as the next input to the model\n", + " # along with the previous hidden state\n", + " input_eval = tf.expand_dims([predicted_id], 0)\n", + " \n", + " text_generated += idx2char[predicted_id]\n", + "\n", + "print (start_string + text_generated)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "AM2Uma_-yVIq" + }, + "source": [ + "## Next steps\n", + "\n", + "* Change the start string to a different character, or the start of a sentence.\n", + "* Experiment with training on a different, or with different parameters. [Project Gutenberg](http://www.gutenberg.org/ebooks/100), for example, contains a large collection of books.\n", + "* Experiment with the temperature parameter.\n", + "* Add another RNN layer.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "gtEd86sX5cB2" + }, + "outputs": [], + "source": [ + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "text_generation.ipynb", + "private_outputs": true, + "provenance": [], + "toc_visible": true, + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/README.md b/tensorflow/contrib/eager/python/examples/l2hmc/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f171806e379da7213b6ee33e0d454056068fe7a5 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/README.md @@ -0,0 +1,53 @@ +# L2HMC with TensorFlow eager execution + +This folder contains an implementation of [L2HMC](https://arxiv.org/pdf/1711.09268.pdf) adapted from the released implementation by the authors. The presented implementation runs in both eager and graph mode. +With eager execution enabled, longer sample chains can be handled compared to graph mode, since no graph is explicitly stored. Moreover, with eager execution enabled, there is no need to use a `tf.while_loop`. + +## What is L2HMC? +L2HMC is an adaptive Markov Chain Monte Carlo (MCMC) algorithm that learns a non-volume preserving transformation +for a Hamiltonian Monte Carlo (HMC) sampling algorithm. More specifically, the non-volume preserving +transformation is learned with neural nets instantiated within Normalizing Flows +(real-NVPs). + +## Content + +- `l2hmc.py`: Dynamics definitions and example energy functions, +including the 2D strongly correlated Gaussian and the rough well energy function, +- `l2hmc_test.py`: Unit tests and benchmarks for training a sampler on the energy functions in both eager and graph mode. +- `neural_nets.py`: The neural net for learning the kernel on the 2D strongly correlated example. +- `main.py`: Run to train a samplers on 2D energy landscapes. + +## To run +- Make sure you have installed TensorFlow 1.9+ or the latest `tf-nightly` or `tf-nightly-gpu` pip package. +- Execute the command + +```bash +python main.py --train_dir ${PWD}/dump --use_defun +``` + +Specifying the optional argument `train_dir` will store event files for +tensorboard and a plot of sampled chain from the trained sampler. + +Specifying the optional argument `use_defun` will let the program use compiled +graphs when running specific sections and improve the overall speed. + +## Boosting Performance with `tfe.defun` +Currently, some models may experience increased overhead with eager execution enabled. +To improve performance, we could wrap certain functions with the decorator `@tfe.defun`. +For example, we could wrap the function that does the sampling step: + +```python +@tfe.defun +def apply_transition(old_sample): + new_sample = ... + return new_sample +``` + +We could also explicitly wrap the desired function with `tfe.defun`: + +```python +apply_transition = tfe.defun(apply_transition) +``` + +## Reference +Generalizing Hamiltonian Monte Carlo with Neural Networks. Levy, Daniel, Hoffman, Matthew D, and Sohl-Dickstein, Jascha. International Conference on Learning Representations (ICLR), 2018. diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py index 729d8525fab31ee214178ca1bcb18dbd069f767a..14b8324e488a864cb23ff2507fab1c53c0583bc0 100644 --- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py +++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py @@ -32,20 +32,28 @@ from tensorflow.contrib.eager.python.examples.l2hmc import neural_nets class Dynamics(tf.keras.Model): - """Dynamics engine of naive L2HMC sampler. - - Args: - x_dim: dimensionality of observed data - loglikelihood_fn: log-likelihood function of conditional probability - n_steps: number of leapfrog steps within each transition - eps: initial value learnable scale of step size - """ - - def __init__(self, x_dim, loglikelihood_fn, n_steps=25, eps=.1): + """Dynamics engine of naive L2HMC sampler.""" + + def __init__(self, + x_dim, + minus_loglikelihood_fn, + n_steps=25, + eps=.1, + np_seed=1): + """Initialization. + + Args: + x_dim: dimensionality of observed data + minus_loglikelihood_fn: log-likelihood function of conditional probability + n_steps: number of leapfrog steps within each transition + eps: initial value learnable scale of step size + np_seed: Random seed for numpy; used to control sampled masks. + """ super(Dynamics, self).__init__() + npr.seed(np_seed) self.x_dim = x_dim - self.potential = loglikelihood_fn + self.potential = minus_loglikelihood_fn self.n_steps = n_steps self._construct_time() @@ -54,7 +62,7 @@ class Dynamics(tf.keras.Model): self.position_fn = neural_nets.GenericNet(x_dim, factor=2.) self.momentum_fn = neural_nets.GenericNet(x_dim, factor=1.) - self.eps = tfe.Variable( + self.eps = tf.Variable( initial_value=eps, name="eps", dtype=tf.float32, trainable=True) def apply_transition(self, position): @@ -68,8 +76,8 @@ class Dynamics(tf.keras.Model): position, forward=False) # Decide direction uniformly - forward_mask = tf.cast( - tf.random_uniform(shape=[tf.shape(position)[0]]) > .5, tf.float32) + batch_size = tf.shape(position)[0] + forward_mask = tf.cast(tf.random_uniform((batch_size,)) > .5, tf.float32) backward_mask = 1. - forward_mask # Obtain proposed states @@ -108,7 +116,6 @@ class Dynamics(tf.keras.Model): position_post, momentum_post, logdet = lf_fn(position_post, momentum_post, i) sumlogdet += logdet - accept_prob = self._compute_accept_prob(position, momentum, position_post, momentum_post, sumlogdet) @@ -125,17 +132,17 @@ class Dynamics(tf.keras.Model): sumlogdet += logdet position, logdet = self._update_position_forward(position, momentum, t, - mask) + mask, mask_inv) sumlogdet += logdet position, logdet = self._update_position_forward(position, momentum, t, - mask_inv) + mask_inv, mask) sumlogdet += logdet momentum, logdet = self._update_momentum_forward(position, momentum, t) sumlogdet += logdet - return position, momentum, tf.reduce_sum(sumlogdet, axis=1) + return position, momentum, sumlogdet def _backward_lf(self, position, momentum, i): """One backward augmented leapfrog step. See Appendix A in paper.""" @@ -149,17 +156,17 @@ class Dynamics(tf.keras.Model): sumlogdet += logdet position, logdet = self._update_position_backward(position, momentum, t, - mask) + mask_inv, mask) sumlogdet += logdet position, logdet = self._update_position_backward(position, momentum, t, - mask_inv) + mask, mask_inv) sumlogdet += logdet momentum, logdet = self._update_momentum_backward(position, momentum, t) sumlogdet += logdet - return position, momentum, tf.reduce_sum(sumlogdet, axis=1) + return position, momentum, sumlogdet def _update_momentum_forward(self, position, momentum, t): """Update v in the forward leapfrog step.""" @@ -172,12 +179,11 @@ class Dynamics(tf.keras.Model): momentum * tf.exp(scale) - .5 * self.eps * (tf.exp(transformed) * grad - translation)) - return momentum, scale + return momentum, tf.reduce_sum(scale, axis=1) - def _update_position_forward(self, position, momentum, t, mask): + def _update_position_forward(self, position, momentum, t, mask, mask_inv): """Update x in the forward leapfrog step.""" - mask_inv = 1. - mask scale, translation, transformed = self.position_fn( [momentum, mask * position, t]) scale *= self.eps @@ -186,8 +192,7 @@ class Dynamics(tf.keras.Model): mask * position + mask_inv * (position * tf.exp(scale) + self.eps * (tf.exp(transformed) * momentum + translation))) - - return position, mask_inv * scale + return position, tf.reduce_sum(mask_inv * scale, axis=1) def _update_momentum_backward(self, position, momentum, t): """Update v in the backward leapfrog step. Inverting the forward update.""" @@ -200,21 +205,20 @@ class Dynamics(tf.keras.Model): tf.exp(scale) * (momentum + .5 * self.eps * (tf.exp(transformed) * grad - translation))) - return momentum, scale + return momentum, tf.reduce_sum(scale, axis=1) - def _update_position_backward(self, position, momentum, t, mask): + def _update_position_backward(self, position, momentum, t, mask, mask_inv): """Update x in the backward leapfrog step. Inverting the forward update.""" - mask_inv = 1. - mask scale, translation, transformed = self.position_fn( - [momentum, mask_inv * position, t]) + [momentum, mask * position, t]) scale *= -self.eps transformed *= self.eps position = ( - mask_inv * position + mask * tf.exp(scale) * - (position - self.eps * tf.exp(transformed) * momentum + translation)) + mask * position + mask_inv * tf.exp(scale) * + (position - self.eps * (tf.exp(transformed) * momentum + translation))) - return position, mask * scale + return position, tf.reduce_sum(mask_inv * scale, axis=1) def _compute_accept_prob(self, position, momentum, position_post, momentum_post, sumlogdet): @@ -222,8 +226,10 @@ class Dynamics(tf.keras.Model): old_hamil = self.hamiltonian(position, momentum) new_hamil = self.hamiltonian(position_post, momentum_post) + prob = tf.exp(tf.minimum(old_hamil - new_hamil + sumlogdet, 0.)) - return tf.exp(tf.minimum(old_hamil - new_hamil + sumlogdet, 0.)) + # Ensure numerical stability as well as correct gradients + return tf.where(tf.is_finite(prob), prob, tf.zeros_like(prob)) def _construct_time(self): """Convert leapfrog step index into sinusoidal time.""" @@ -248,6 +254,8 @@ class Dynamics(tf.keras.Model): self.masks = [] for _ in range(self.n_steps): + # Need to use npr here because tf would generated different random + # values across different `sess.run` idx = npr.permutation(np.arange(self.x_dim))[:self.x_dim // 2] mask = np.zeros((self.x_dim,)) mask[idx] = 1. @@ -273,19 +281,15 @@ class Dynamics(tf.keras.Model): def grad_potential(self, position, check_numerics=True): """Get gradient of potential function at current location.""" - if not tf.executing_eagerly(): - # TODO(lxuechen): Change this to tfe.gradients_function when it works - grad = tf.gradients(self.potential(position), position)[0] - else: + if tf.executing_eagerly(): grad = tfe.gradients_function(self.potential)(position)[0] - - if check_numerics: - return tf.check_numerics(grad, message="gradient of potential") + else: + grad = tf.gradients(self.potential(position), position)[0] return grad -# Examples of unnormalized log density/probabilities +# Examples of unnormalized log densities def get_scg_energy_fn(): """Get energy function for 2d strongly correlated Gaussian.""" @@ -295,32 +299,53 @@ def get_scg_energy_fn(): sigma_inv = tf.matrix_inverse(sigma) def energy(x): - """Unnormalized log density/energy of 2d strongly correlated Gaussian.""" + """Unnormalized minus log density of 2d strongly correlated Gaussian.""" xmmu = x - mu return .5 * tf.diag_part( tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu))) - return energy + return energy, mu, sigma -def get_multivariate_gaussian_energy_fn(x_dim=2): - """Get energy function for 2d strongly correlated Gaussian.""" - - mu = tf.random_normal(shape=[x_dim]) - # Lower triangularize and positive diagonal - l = tf.sigmoid( - tf.matrix_band_part(tf.random_normal(shape=[x_dim, x_dim]), -1, 0)) - # Exploit Cholesky decomposition - sigma = tf.matmul(l, tf.transpose(l)) - sigma *= 100. # Small covariance causes extreme numerical instability - sigma_inv = tf.matrix_inverse(sigma) +def get_rw_energy_fn(): + """Get energy function for rough well distribution.""" + # For small eta, the density underlying the rough-well energy is very close to + # a unit Gaussian; however, the gradient is greatly affected by the small + # cosine perturbations + eta = 1e-2 + mu = tf.constant([0., 0.]) + sigma = tf.constant([[1., 0.], [0., 1.]]) def energy(x): - """Unnormalized log density/energy of 2d strongly correlated Gaussian.""" + ip = tf.reduce_sum(x**2., axis=1) + return .5 * ip + eta * tf.reduce_sum(tf.cos(x / eta), axis=1) - xmmu = x - mu - return .5 * tf.diag_part( - tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu))) + return energy, mu, sigma + + +# Loss function +def compute_loss(dynamics, x, scale=.1, eps=1e-4): + """Compute loss defined in equation (8).""" + + z = tf.random_normal(tf.shape(x)) # Auxiliary variable + x_, _, x_accept_prob, x_out = dynamics.apply_transition(x) + z_, _, z_accept_prob, _ = dynamics.apply_transition(z) + + # Add eps for numerical stability; following released impl + x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps + z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps + + loss = tf.reduce_mean( + (1. / x_loss + 1. / z_loss) * scale - (x_loss + z_loss) / scale, axis=0) + + return loss, x_out, x_accept_prob + + +def loss_and_grads(dynamics, x, loss_fn=compute_loss): + """Obtain loss value and gradients.""" + with tf.GradientTape() as tape: + loss_val, out, accept_prob = loss_fn(dynamics, x) + grads = tape.gradient(loss_val, dynamics.trainable_variables) - return energy + return loss_val, grads, out, accept_prob diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py index e33b4cae4c73388dfd78542c9907953f137ad710..955747988536bd21d52df66a35af4aa31b3f7688 100644 --- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py +++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py @@ -37,63 +37,37 @@ def get_default_hparams(): n_warmup_iters=3) -# Relevant functions for benchmarking -def compute_loss(dynamics, x, scale=.1, eps=1e-4): - """Compute loss defined in equation (8).""" - - z = tf.random_normal(tf.shape(x)) - x_, _, x_accept_prob, x_out = dynamics.apply_transition(x) - z_, _, z_accept_prob, _ = dynamics.apply_transition(z) - - # Add eps for numerical stability; following released impl - x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps - z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps - - loss = tf.reduce_mean( - (1. / x_loss + 1. / z_loss) * scale - (x_loss + z_loss) / scale, axis=0) - - return loss, x_out - - -def loss_and_grads(dynamics, x, loss_fn=compute_loss): - """Obtain loss value and gradients.""" - - with tf.GradientTape() as tape: - loss_val, x_out = loss_fn(dynamics, x) - grads = tape.gradient(loss_val, dynamics.variables) - - return loss_val, grads, x_out - - -def warmup(dynamics, optimizer, n_iters=1, n_samples=200, loss_fn=compute_loss): +def warmup(dynamics, + optimizer, + n_iters=1, + n_samples=200, + loss_fn=l2hmc.compute_loss): """Warmup optimization to reduce overhead.""" samples = tf.random_normal( shape=[n_samples, dynamics.x_dim], dtype=tf.float32) for _ in range(n_iters): - _, grads, samples = loss_and_grads(dynamics, samples, loss_fn=loss_fn) + _, grads, samples, _ = l2hmc.loss_and_grads( + dynamics, samples, loss_fn=loss_fn) optimizer.apply_gradients(zip(grads, dynamics.variables)) def fit(dynamics, samples, optimizer, - loss_fn=compute_loss, + loss_fn=l2hmc.compute_loss, n_iters=5000, verbose=True, - logdir=None, - decay_lr=True): + logdir=None): """Fit L2HMC sampler with given log-likelihood function.""" if logdir: summary_writer = tf.contrib.summary.create_file_writer(logdir) for i in range(n_iters): - loss, grads, samples = loss_and_grads(dynamics, samples, loss_fn=loss_fn) - # TODO(lxuechen): Proper learning rate decay - if decay_lr: - grads = [grad * .96**(i // 1000) for grad in grads] + loss, grads, samples, _ = l2hmc.loss_and_grads( + dynamics, samples, loss_fn=loss_fn) optimizer.apply_gradients(zip(grads, dynamics.variables)) if verbose: print("Iteration %d: loss %.4f" % (i, loss)) @@ -112,9 +86,10 @@ class L2hmcTest(tf.test.TestCase): # Eager mode testing hparams = get_default_hparams() + energy_fn, _, _ = l2hmc.get_scg_energy_fn() dynamics = l2hmc.Dynamics( x_dim=hparams.x_dim, - loglikelihood_fn=l2hmc.get_scg_energy_fn(), + minus_loglikelihood_fn=energy_fn, n_steps=hparams.n_steps, eps=hparams.eps) samples = tf.random_normal(shape=[hparams.n_samples, hparams.x_dim]) @@ -127,9 +102,10 @@ class L2hmcTest(tf.test.TestCase): # Graph mode testing with tf.Graph().as_default(): + energy_fn, _, _ = l2hmc.get_scg_energy_fn() dynamics = l2hmc.Dynamics( x_dim=hparams.x_dim, - loglikelihood_fn=l2hmc.get_scg_energy_fn(), + minus_loglikelihood_fn=energy_fn, n_steps=hparams.n_steps, eps=hparams.eps) x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim]) @@ -150,32 +126,20 @@ class L2hmcTest(tf.test.TestCase): class L2hmcBenchmark(tf.test.Benchmark): """Eager and graph benchmarks for l2hmc.""" - def _get_energy_fn(self): - """Get specific energy function according to FLAGS.""" - - if FLAGS.energy_fn == "scg": - energy_fn = l2hmc.get_scg_energy_fn() - elif FLAGS.energy_fn == "multivariate_gaussian": - energy_fn = l2hmc.get_multivariate_gaussian_energy_fn(x_dim=FLAGS.x_dim) - else: - raise ValueError("No such energy function %s" % FLAGS.energy_fn) - - return energy_fn - def benchmark_graph(self): """Benchmark Graph performance.""" hparams = get_default_hparams() tf.reset_default_graph() with tf.Graph().as_default(): - energy_fn = self._get_energy_fn() + energy_fn, _, _ = l2hmc.get_scg_energy_fn() dynamics = l2hmc.Dynamics( x_dim=hparams.x_dim, - loglikelihood_fn=energy_fn, + minus_loglikelihood_fn=energy_fn, n_steps=hparams.n_steps, eps=hparams.eps) x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim]) - loss, x_out = compute_loss(dynamics, x) + loss, x_out, _ = l2hmc.compute_loss(dynamics, x) global_step = tf.Variable(0., name="global_step", trainable=False) learning_rate = tf.train.exponential_decay( @@ -183,7 +147,11 @@ class L2hmcBenchmark(tf.test.Benchmark): optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) train_op = optimizer.minimize(loss, global_step=global_step) - with tf.Session() as sess: + # Single thread; fairer comparison against eager + session_conf = tf.ConfigProto( + intra_op_parallelism_threads=1, inter_op_parallelism_threads=1) + + with tf.Session(config=session_conf) as sess: sess.run(tf.global_variables_initializer()) # Warmup to reduce initialization effect when timing @@ -218,14 +186,14 @@ class L2hmcBenchmark(tf.test.Benchmark): """Benchmark Eager performance.""" hparams = get_default_hparams() - energy_fn = self._get_energy_fn() + energy_fn, _, _ = l2hmc.get_scg_energy_fn() dynamics = l2hmc.Dynamics( x_dim=hparams.x_dim, - loglikelihood_fn=energy_fn, + minus_loglikelihood_fn=energy_fn, n_steps=hparams.n_steps, eps=hparams.eps) optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate) - loss_fn = tfe.defun(compute_loss) if defun else compute_loss + loss_fn = tfe.defun(l2hmc.compute_loss) if defun else l2hmc.compute_loss # Warmup to reduce initialization effect when timing warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters, loss_fn=loss_fn) @@ -234,12 +202,7 @@ class L2hmcBenchmark(tf.test.Benchmark): samples = tf.random_normal( shape=[hparams.n_samples, hparams.x_dim], dtype=tf.float32) start_time = time.time() - fit(dynamics, - samples, - optimizer, - loss_fn=loss_fn, - n_iters=hparams.n_iters, - decay_lr=True) + fit(dynamics, samples, optimizer, loss_fn=loss_fn, n_iters=hparams.n_iters) wall_time = time.time() - start_time examples_per_sec = hparams.n_samples / wall_time @@ -251,14 +214,8 @@ class L2hmcBenchmark(tf.test.Benchmark): wall_time=wall_time) del dynamics - del loss_fn if __name__ == "__main__": - tf.flags.DEFINE_string("energy_fn", "scg", - ("The energy function/unnormalized log-probability. " - "Either be `scg` or `multivariate_gaussian`")) - tf.flags.DEFINE_integer("x_dim", 2, "Dimensionality of observation space.") - FLAGS = tf.flags.FLAGS tf.enable_eager_execution() tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/main.py b/tensorflow/contrib/eager/python/examples/l2hmc/main.py new file mode 100644 index 0000000000000000000000000000000000000000..45e1f98429f48749d374c2aefd8874690c3830ad --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/main.py @@ -0,0 +1,235 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""L2HMC on simple Gaussian mixture model with TensorFlow eager.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +from absl import flags +import numpy as np +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.l2hmc import l2hmc +try: + import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top + HAS_MATPLOTLIB = True +except ImportError: + HAS_MATPLOTLIB = False +tfe = tf.contrib.eager + + +def main(_): + tf.enable_eager_execution() + global_step = tf.train.get_or_create_global_step() + global_step.assign(1) + + energy_fn, mean, covar = { + "scg": l2hmc.get_scg_energy_fn(), + "rw": l2hmc.get_rw_energy_fn() + }[FLAGS.energy_fn] + + x_dim = 2 + train_iters = 5000 + eval_iters = 2000 + eps = 0.1 + n_steps = 10 # Chain length + n_samples = 200 + record_loss_every = 100 + + dynamics = l2hmc.Dynamics( + x_dim=x_dim, minus_loglikelihood_fn=energy_fn, n_steps=n_steps, eps=eps) + learning_rate = tf.train.exponential_decay( + 1e-3, global_step, 1000, 0.96, staircase=True) + optimizer = tf.train.AdamOptimizer(learning_rate) + checkpointer = tf.train.Checkpoint( + optimizer=optimizer, dynamics=dynamics, global_step=global_step) + + if FLAGS.train_dir: + summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir) + if FLAGS.restore: + latest_path = tf.train.latest_checkpoint(FLAGS.train_dir) + checkpointer.restore(latest_path) + print("Restored latest checkpoint at path:\"{}\" ".format(latest_path)) + sys.stdout.flush() + + if not FLAGS.restore: + # Training + if FLAGS.use_defun: + # Use `tfe.deun` to boost performance when there are lots of small ops + loss_fn = tfe.defun(l2hmc.compute_loss) + else: + loss_fn = l2hmc.compute_loss + + samples = tf.random_normal(shape=[n_samples, x_dim]) + for i in range(1, train_iters + 1): + loss, samples, accept_prob = train_one_iter( + dynamics, + samples, + optimizer, + loss_fn=loss_fn, + global_step=global_step) + + if i % record_loss_every == 0: + print("Iteration {}, loss {:.4f}, x_accept_prob {:.4f}".format( + i, loss.numpy(), + accept_prob.numpy().mean())) + if FLAGS.train_dir: + with summary_writer.as_default(): + with tf.contrib.summary.always_record_summaries(): + tf.contrib.summary.scalar("Training loss", loss, step=global_step) + print("Training complete.") + sys.stdout.flush() + + if FLAGS.train_dir: + saved_path = checkpointer.save( + file_prefix=os.path.join(FLAGS.train_dir, "ckpt")) + print("Saved checkpoint at path: \"{}\" ".format(saved_path)) + sys.stdout.flush() + + # Evaluation + if FLAGS.use_defun: + # Use tfe.deun to boost performance when there are lots of small ops + apply_transition = tfe.defun(dynamics.apply_transition) + else: + apply_transition = dynamics.apply_transition + + samples = tf.random_normal(shape=[n_samples, x_dim]) + samples_history = [] + for i in range(eval_iters): + samples_history.append(samples.numpy()) + _, _, _, samples = apply_transition(samples) + samples_history = np.array(samples_history) + print("Sampling complete.") + sys.stdout.flush() + + # Mean and covariance of target distribution + mean = mean.numpy() + covar = covar.numpy() + ac_spectrum = compute_ac_spectrum(samples_history, mean, covar) + print("First 25 entries of the auto-correlation spectrum: {}".format( + ac_spectrum[:25])) + ess = compute_ess(ac_spectrum) + print("Effective sample size per Metropolis-Hastings step: {}".format(ess)) + sys.stdout.flush() + + if FLAGS.train_dir: + # Plot autocorrelation spectrum in tensorboard + plot_step = tfe.Variable(1, trainable=False, dtype=tf.int64) + + for ac in ac_spectrum: + with summary_writer.as_default(): + with tf.contrib.summary.always_record_summaries(): + tf.contrib.summary.scalar("Autocorrelation", ac, step=plot_step) + plot_step.assign(plot_step + n_steps) + + if HAS_MATPLOTLIB: + # Choose a single chain and plot the trajectory + single_chain = samples_history[:, 0, :] + xs = single_chain[:100, 0] + ys = single_chain[:100, 1] + plt.figure() + plt.plot(xs, ys, color="orange", marker="o", alpha=0.6) # Trained chain + plt.savefig(os.path.join(FLAGS.train_dir, "single_chain.png")) + + +def train_one_iter(dynamics, + x, + optimizer, + loss_fn=l2hmc.compute_loss, + global_step=None): + """Train the sampler for one iteration.""" + loss, grads, out, accept_prob = l2hmc.loss_and_grads( + dynamics, x, loss_fn=loss_fn) + optimizer.apply_gradients( + zip(grads, dynamics.trainable_variables), global_step=global_step) + + return loss, out, accept_prob + + +def compute_ac_spectrum(samples_history, target_mean, target_covar): + """Compute autocorrelation spectrum. + + Follows equation 15 from the L2HMC paper. + + Args: + samples_history: Numpy array of shape [T, B, D], where T is the total + number of time steps, B is the batch size, and D is the dimensionality + of sample space. + target_mean: 1D Numpy array of the mean of target(true) distribution. + target_covar: 2D Numpy array representing a symmetric matrix for variance. + Returns: + Autocorrelation spectrum, Numpy array of shape [T-1]. + """ + + # Using numpy here since eager is a bit slow due to the loop + time_steps = samples_history.shape[0] + trace = np.trace(target_covar) + + rhos = [] + for t in range(time_steps - 1): + rho_t = 0. + for tau in range(time_steps - t): + v_tau = samples_history[tau, :, :] - target_mean + v_tau_plus_t = samples_history[tau + t, :, :] - target_mean + # Take dot product over observation dims and take mean over batch dims + rho_t += np.mean(np.sum(v_tau * v_tau_plus_t, axis=1)) + + rho_t /= trace * (time_steps - t) + rhos.append(rho_t) + + return np.array(rhos) + + +def compute_ess(ac_spectrum): + """Compute the effective sample size based on autocorrelation spectrum. + + This follows equation 16 from the L2HMC paper. + + Args: + ac_spectrum: Autocorrelation spectrum + Returns: + The effective sample size + """ + # Cutoff from the first value less than 0.05 + cutoff = np.argmax(ac_spectrum[1:] < .05) + if cutoff == 0: + cutoff = len(ac_spectrum) + ess = 1. / (1. + 2. * np.sum(ac_spectrum[1:cutoff])) + return ess + + +if __name__ == "__main__": + flags.DEFINE_string( + "train_dir", + default=None, + help="[Optional] Directory to store the training information") + flags.DEFINE_boolean( + "restore", + default=False, + help="[Optional] Restore the latest checkpoint from `train_dir` if True") + flags.DEFINE_boolean( + "use_defun", + default=False, + help="[Optional] Use `tfe.defun` to boost performance") + flags.DEFINE_string( + "energy_fn", + default="scg", + help="[Optional] The energy function used for experimentation" + "Other options include `rw`") + FLAGS = flags.FLAGS + tf.app.run(main) diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py index e230ad5e259df5b450897bd815e901e3934cd293..68e0bc31239007e3b1b8451cf1d6e7592c6ca030 100644 --- a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py +++ b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py @@ -25,7 +25,6 @@ from __future__ import division from __future__ import print_function import tensorflow as tf -import tensorflow.contrib.eager as tfe class GenericNet(tf.keras.Model): @@ -47,13 +46,13 @@ class GenericNet(tf.keras.Model): # Scale self.scale_layer = _custom_dense(x_dim, .001) - self.coeff_scale = tfe.Variable( + self.coeff_scale = tf.Variable( initial_value=tf.zeros([1, x_dim]), name='coeff_scale', trainable=True) # Translation self.translation_layer = _custom_dense(x_dim, factor=.001) # Transformation self.transformation_layer = _custom_dense(x_dim, .001) - self.coeff_transformation = tfe.Variable( + self.coeff_transformation = tf.Variable( initial_value=tf.zeros([1, x_dim]), name='coeff_transformation', trainable=True) diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb index 54ebcad8e929c3195099121a290dd7c0651e5c9f..1ab1b71bd0549e06a1d86611c21faef1f182d740 100644 --- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -41,11 +41,11 @@ "\n", "# Neural Machine Translation with Attention\n", "\n", - "
\n", - "\n", + "
\n", + "\n", " Run in Google Colab \n", "\n", - "View source on Github
" + "
View source on GitHub
" ] }, { @@ -383,6 +383,7 @@ "source": [ "BUFFER_SIZE = len(input_tensor_train)\n", "BATCH_SIZE = 64\n", + "N_BATCH = BUFFER_SIZE//BATCH_SIZE\n", "embedding_dim = 256\n", "units = 1024\n", "vocab_inp_size = len(inp_lang.word2idx)\n", @@ -677,21 +678,23 @@ " # using teacher forcing\n", " dec_input = tf.expand_dims(targ[:, t], 1)\n", " \n", - " total_loss += (loss / int(targ.shape[1]))\n", + " batch_loss = (loss / int(targ.shape[1]))\n", + " \n", + " total_loss += batch_loss\n", " \n", " variables = encoder.variables + decoder.variables\n", " \n", " gradients = tape.gradient(loss, variables)\n", - " \n", + " \n", " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n", - "\n", + " \n", " if batch % 100 == 0:\n", " print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n", " batch,\n", - " loss.numpy() / int(targ.shape[1])))\n", + " batch_loss.numpy()))\n", " \n", " print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n", - " total_loss/len(input_tensor)))\n", + " total_loss / N_BATCH))\n", " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" ], "execution_count": 0, @@ -906,4 +909,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb deleted file mode 100644 index 9c1af9c2084bac7ae6369babeaa13720e6199097..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb +++ /dev/null @@ -1,323 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "vDJ4XzMqodTy" - }, - "source": [ - "# Automatic Differentiation\n", - "\n", - "In the previous tutorial we introduced `Tensor`s and operations on them. In this tutorial we will cover [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation), a key technique for optimizing machine learning models." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "GQJysDM__Qb0" - }, - "source": [ - "## Setup\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "OiMPZStlibBv" - }, - "outputs": [], - "source": [ - "import tensorflow as tf\n", - "tf.enable_eager_execution()\n", - "\n", - "tfe = tf.contrib.eager # Shorthand for some symbols" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "1CLWJl0QliB0" - }, - "source": [ - "## Derivatives of a function\n", - "\n", - "TensorFlow provides APIs for automatic differentiation - computing the derivative of a function. The way that more closely mimics the math is to encapsulate the computation in a Python function, say `f`, and use `tfe.gradients_function` to create a function that computes the derivatives of `f` with respect to its arguments. If you're familiar with [autograd](https://github.com/HIPS/autograd) for differentiating numpy functions, this will be familiar. For example: " - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "9FViq92UX7P8" - }, - "outputs": [], - "source": [ - "from math import pi\n", - "\n", - "def f(x):\n", - " return tf.square(tf.sin(x))\n", - "\n", - "assert f(pi/2).numpy() == 1.0\n", - "\n", - "\n", - "# grad_f will return a list of derivatives of f\n", - "# with respect to its arguments. Since f() has a single argument,\n", - "# grad_f will return a list with a single element.\n", - "grad_f = tfe.gradients_function(f)\n", - "assert tf.abs(grad_f(pi/2)[0]).numpy() \u003c 1e-7" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "v9fPs8RyopCf" - }, - "source": [ - "### Higher-order gradients\n", - "\n", - "The same API can be used to differentiate as many times as you like:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 276 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 730, - "status": "ok", - "timestamp": 1527005655565, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "3D0ZvnGYo0rW", - "outputId": "e23f8cc6-6813-4944-f20f-825b8a03c2ff" - }, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAEDCAYAAAAhsS8XAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsnXd0HNX5sJ/ZXrTq3ZLV3IvcDdgGGwOm2WCbHhJa6C2B\nUBISQioBfoQPkjhACA4QCIQSDITQbGMbsHHvVbZ6s7q0vc18f4xmJVltJa0q+5zDOXhn9s7dqzvv\nfe/briBJkkSYMGHChBkxqAa7A2HChAkTJrSEBXuYMGHCjDDCgj1MmDBhRhhhwR4mTJgwI4ywYA8T\nJkyYEUZYsIcJEybMCCNkgl0URVasWMHtt98eqibDhAkTJkwvCJlgf+2118jJyQlVc2HChAkTppeE\nRLBXVlayceNGrrjiilA0FyZMmDBh+kBIBPvjjz/OQw89hCAIoWguTJgwYcL0gT4L9g0bNhAfH8/E\niRMJVycIEyZMmMFH6GutmGeeeYYPP/wQtVqN2+3Gbrdz3nnn8dRTT3X6HUmSwtp9CKittvH8UxsQ\nxZY/4aXXTGfa7PRB7NXAU1dj5y9PrIfmYUgeFcnya2aQmBI5uB0bYE5WNPHS/9uE6JcHYukVucw8\nPWOQezXw7NhcyCfvH0Bqfi+uumkO4ycnD3KvBpY+C/bWbNu2jdWrV/PCCy90e291tTVUj+03EhIs\nQ7qfWzfls2tzMTNPH01UrJEv/3eU5LRIVnx/5mB3rUP6azw3fnaMQ7vLOX1RNrVVNvIOVZGeFcPS\nq6YNmT6GmlP7KYoi/3ltF9WVNhacO4btXxfi9fi5+Mpc0jJjhkw/+5t9O0r5Zu1xDEYtpy/KZuOn\nR4mOM3HlTbNRqTo3UAynv3swhOPYhymSJJF3sAqtTs35l05mQm4K6VkxVJY2UVdtH+zuDRgOu4ej\n+yqIjDYwbW4a514yiYTkCMqKGnC7vIPdvQFjz9YSqittjJuSxNTZaVywcgoAX3xwCL9PHOTeDRyH\ndpej0ai47PqZTJyWwoTcFOprHBzdf3KwuzaghFSwz507NyhtPUzfOVnehLXRRdbYeLQ6DQATp6UC\ncGhv+WB2bUA5sLMMv19i2pz0gEaWNS4BUZQoOlE3yL0bGDxuHzu+LsRk1jH/nDEApI6OZtL0VFxO\nLyfLmwa5hwNDU4OT+loHozJiiIw2AjB7QSYajYrtXxXg9foHuYcDR1hjH6bkHawCYOzkxMBnmWPj\nMJq1HDtwEt93YBJ7PT4O7CrDYNQwPrfFhpo1Lh6AgmPVg9W1AaWyrBG/X2JCbjIGozbweXqWbIIp\nLawfrK4NKMX58kI+Oic28FmERc/UOWnYbR7yDn53tPawYB+GiKLI8SNVGEzaNvZTtVrFhKkpuF0+\n8o+OfKGWd7gKt8vHlJmj0GrVgc9j4kxExRopzq/7Tixw5cWNAKSkR7f5PHV0NIIApUXfEcHevEMb\nnR3b5vPxU5IAqChpHPA+DRZhwT4MKS2sx+XwMmZCYjuH0MRpKQAc3lsxGF0bUJQXNWdiYpvPBUEg\ne1w8Pq9IyXdAW60oaUAQ5Gig1uj0GhJTIqkqb8Lj9g1S7wYGn89PWXE90XGmgBlGITrWhN6gobIs\nLNjDDGE6MsMoRMUYSUiOoLKsacQ7zaoqrGh1amLiTO2uZY1LAKDgWM1Ad2tA8Xr9VFVYSUi2oNNr\n2l1Py4xBkqC8pGEQejdwVJQ04vOKZJyirYO80CeNiqSpwYXD5h6E3g08YcE+zJAkiZKCOswWHUmp\nHcdpJyRbEEWJupqRGx3jdvloqHWQmGLpMCciMcWCOUJH0fEaRHHkLnBV5U2IokRKelSH10dlyOaZ\nkW5nD5hhcuI6vJ48Sh6fyrLvhiM5LNiHGQ67B6fDS2JyZKdJXgnJcqxrdeXQj8vtLcpv6ywJSRAE\nMsfF43L6RvTLXF4sa+Kn2tcVkkdFodGoKCvqu8b+zjtv8f3vX8Fvf/ton9sKNUX5tWi0KlLSOl7g\nFDPVSJ4LrWm/dwszpKk5aQMgLimi03u+C4JdCeFLTOk8YSMlLYqDu8qpOWkjtRPBN9wpb/YzpHai\nsas1KlLSoygpqMdhc2OK0Pf6WWvWvMsf//hnkpNTet1Gf9BY76Sxzknm2DjUmo51VXlnBye/I3b2\nsMY+zKitkgV7fGLngj02wYxKLVBdaRuobg04VRWyYO/MHAUQlyCPkTJmIw2/X+RkeRNxCWb0Bm2n\n943KaA577IPW/vTTf6C8vIyHH76ft99+s9ft9AeKUzQto/MMW61OQ1xiBFWV1hHve4Kwxj7sUDT2\n+C40drVaRVyCmdpqG36/iFo9stZvSZKoKrditugwWzrXQKNijahUwojLxH17/XF25VXj8fhx+Hzo\nGh1s/+vmTu8X/SJ2RA59egTDxhMd3jNnQiJXLh7TaRsPPPAztm79lj//+UUiI4dWDZ76Zl9SXBfK\nDshmqZqTNqpPWgM295HKyHrjvwPUVNnQGzRERHa9pU5ItiD6pREn1ADsVjcOu6fbIl9qtYqYeBN1\nNfYRWXlU0Ty7W7hVzddFf181VYlApbUhhDLHYxPMXd6XnCbPl5PfATt7WGMfRng9PhrrnM2JJ11X\nx5Tt7BVUn7QGbO4jhZPliuO0+98VlxBBbZWdpgYnUTHtwyKHI1cuHsNdV83g1b9upuhELdf/cG63\ntvN/v7ydpgYnN99xxoirrFpX48Bk1rXJuu2IlsiYRqYxsiughjX2YURts2bSlRlGocWBOvLsy4p9\nPZiyvLGJshZXWzXydi71tXYMJm1QDtHYeBM+r4itaWTFcXs9PqyNLmLiu1+0IyL1mCN0VJY2jcgd\nXGvCgn0YEbCvd2NLBIiNN6NSCdSMwMiYqoqeaeww8hyoPq9fFmixwe1CouPkBa6+ti8L3NDT9Otq\nHED3ZhiQQ2ATUyNx2D3YbZ7+7tqgEhbsw4hgHKcKao2K2AQztVWyA3WkIIoS1ZVWYuJNHWZankpc\n8wtfO8J8DbLfAKI7yLrtiNhmjba+WRD2hnfe+YDIyKHldAzY1+O7F+xAIEu5sa734zAcCAv2YURt\nlQ2VWgj6ZU5ItuD3S4GogZFAU4MTr8dPQlJwfgNThA6DUTPinMi11fIi31E5hY5Q5kx97cgSaMrc\nDkZjB7luDEBDnbPf+jQUCAv2YYIoitRW24mNNwcdvjgS7eyN9fILGR1r7OZOGUEQiE2IoLFeXhBG\nCjXNpqXoYE0xMSYEoa+mmKGHUjYjJi44wR7VPG/CGnuYIUFDnRO/TwzKDKOg3DuS7MuNzZpWVJAC\nDVrMMSOpdk5AsAepsas1KiJjjNTXOEaU47Cuxo7ZokdvCC7AL6yxhxlS9MRxqqBotU0NI2cSN9bL\nmlZUTHAaO7Qkrijmi5GAYpazRBmC/k5snBm3y4fTMTKODHS7vNitnqDNMAAGoxaDUUNDWGMPMxRQ\nbMTdZde1Rm/QojdoaGxw9Ve3BhzFFNMTwa68+HUjJORRkiRqquxExciZtcESHXCgjoxxCETEBBHq\n2JqoWBNNDc4RFVRwKn0W7B6PhyuuuILly5ezbNky/vKXv4SiX2FOQdG6eyLQlPubGpyI4sjYfjfW\nOzGatEFFxCgoERMjJTLGYfPgcfuCdpwqxI4wB2rAcRpkRIxCdIwRSQJr48hReE6lz4Jdp9Px2muv\nsWbNGtasWcOmTZvYt29fKPoWphVNDU7UGhWmCF2PvhcZbUT0S9itwz8xxe8XsTa6Ag6wYNHq1ETF\nGKkbIaYYRTAHa19XiGkWgL0NeWxdtvebb77ijTdeDfq7lZUVfPHFp0Hd+/jjv2bjxvXd3te6lMCa\nNe/x2Wf/C6r9qICdXR6HTz75L9XVLUdJPvnk7ykqKgyqraFKSEoKGI3yi+bxePD5RvYRXINFU4OL\nyChDj9PBFQ2/sd7ZI3vsUMTa6EKS6FVpgMgYIyX5TjxuX4+0/aGIIpCCTU5SUByHvY2MObVs7/z5\nZ7a7x+/3o1ar231eXl7GF198xnnnXdCrZ3eE4gyPjDawfPllQX9PGQfFEf+//33EzJlTSUrKAODh\nh38esj4OFiGZ4aIosnLlSoqLi7n22mvJzc0NRbNhmnG7vLhdvnZnWgZDZLQszGVTTudlTYcDgYiY\nHpqjACKjlHFw9SiyaCjS0EuNXatTY4nU98oU07ps78UXX4LFYuHIkUPcd99DPP74r7FYIsnLO8r4\n8ROZP/9MnnvuaQRBQKvV8OyzL/Dii6soKirkppuu5YILlnLllde0af+ZZ55k9+6dpKSktonaOXr0\nCH/+8zO4XC6ioqL5+c8fIzY2jnvuuQ3RFUt1XSGxH5Zht9sxmUycccYCfve7x3jpJXk3UVlZwcMP\n38+rr77JK6/8nW+++QqHw4lWSmTS9HvYsGEdR44c5sEHH0Sj0fL886t54IF7ufvu+zh8+ADl5eXc\neee9gKzZHz16hB//+AE+//wT3nnnLfx+H5MmTeEnP/npkKrBExLBrlKpWLNmDTabjTvvvJPjx48z\nZkznJUDD9IymZufnqYf0BoMiBEdCZExDLyJiFJQFztroHPaC/Vv315ROK+RP+ZsRCnomTJxjPfh8\nInnfbGgjiGYkTmXlmKWdfu/Usr2ffPLfNt8vLS3mT396AYCHH76Pn/zkp0yZkktEhIamJg+33343\nb731Ok8++f/atb1x45eUlpbwz3++TU1NDd///hUsXXopPp+PZ599iieeeIaoqGjWrfuCF19cxc9+\n9kskScLpsHPdlT9l6VXTWL36bwBkZGTi9/uoqCgnJSWVdes+55xzzgPgssuu4oYbbsbn9XPj9+9k\n566t3P/Idbz33tv88pe/ICGhbWGwRYvO5fbbbwwI9nXrPuf6639IUVEh69Z9zgsvrEatVvPHPz7J\n559/wvnnX9Sjv0V/EtI9aUREBHPnzuWrr77qVrAnJAyPioNDoZ/VzdUMU9OiO+1PZ58b9HLFO5fD\nNyR+S1/64HHKCUaZ2fE9bidttLxb8fukbr87FMapK9xuH6oIAU0v6uyrNSp8PhEkUKtbBLPJqOv2\nd6tUEBdnJjragsViwNj8HYNBy8KFSwPfP/30uTz//HMsW7aMJUuWkJSURHS0CZ1O0+Ezjh07wIoV\nl5KQYCEhwcK8eWcQGWnEZquhoCCfBx+8F0mSEEWRxMREEhIsCAhkpE4nMTmShAQLZrMes9lAQoKF\npUsvZuvWTdxyyy1s2rSeZ599loQEC7t2bebll1/G6XRSVV9FWXkaCQkWtFo1ktQyL7RaNTExJsaO\nTSczM4OKigJGjx5NeXkpixcv4I033uD48WPccceNSJKE2+0mLS15SM2bPgv2uro6tFotFosFl8vF\nli1buPXWW7v9XnX10C9OlZBgGRL9LC2WDyJWaYQO+9NVPyVJQqNVUV1pHfTf0tfxPFkhn5QjIva4\nHalZhlWUNnb53aHyN+8Mr9dPbN5YZo45gwvPn9Lj7x/aU87GT49x9sUTmDA1uc217n63KErU1trw\netVYrS6cTg/V1VZcLi8+X8vcXLHiGqZNm8uWLV9z5ZVX8swzq2hocODx+Dp8htPpwWZzB6653V6a\nmpzU1dnIysrm+edXt+uny+VFY9Gh0amorrZit7uRJDXV1VZOO+0sHn30p8yaNQ+/X8JojKGsrJZf\n/erXrF79OvHxCTx8/2+wNjgpL6vH6/W3+f1er5/6egfV1VYWLDibd99dQ0ZGJvPnL6S62orV6mTJ\nkou47ba7ejR+oSDYxaPPUTHV1dVcd911XHrppVxxxRUsWLCAhQsX9rXZMK1QzCi9McUIgkBktJHG\nBuewzzhsqHNiMut65fxsbYoZziip8PGJPQvxU1Ac6LZ+DPUrKyslOzuHa6+9nilTplBcXIjJZMZu\n79hpO23aTNau/RxRFKmpqWHXrp0AjB6dSX19AwcO7AfA5/NRUJAPtBwy0tE7MWpUGmq1ilde+TuL\nF8tmGI/HgyBAZGQUDoeD44W7geY5ZTJhs3UcMbVw4WK++mpDG5POrFlz2bBhHfX1ssLV1NREZWVl\nr8aqv+izxj5+/Hjef//9UPQlTCcoSTmW6N5FtcihfnacDi8mc8/CJYcKfr+IrcnV6yPN9AY59r1p\nmMcuK6nwSjninqIIdmtTb8YhOHv+O++8ya5dO1Cr1YwfP47TT58PgFqt4cYbv8eFFy5r4zxduPBs\ndu3azvXXX016egYzZswCQKPR8LvfPcmzz/4fNpsNUfRz5ZXXkJWVjd8vtfk9p7J48RKef/5P3HLL\nnYBsJl62bAXXXXcVKSmpZGeNw14vv1sXXbSMxx57DK1Wx/PPr27jO7BYLGRmZlNcXMiECZMAyMzM\n4pZb7uT+++9CFCW0Wi333/8QycnJHfZlMBCkQVLjhvJ2V2GobMtff/5b/H6R6++e1+H17vq5ef0J\n9m4rYcX3Z5CcNnhlV/synvW1dt56aTsTpiZz9sUTetXGO//YQUOtg5t/cmanEQxD5W/eGbu/Lebb\nDflcddOcwCEiPcHn8/PS018xKiOaS66Z3g89bEt/jeen/zlAwbEarr9nXq+UleL8Wj5+ez9zFmQy\ne0HmkP+7KwyYKSZM/6JoqpG91NahVSz7MI6MCZQS6GFyUmssUQZ8PhGnffgesqBkS0b38pg/jUaN\nyawb9lmX1kYXGo0Ko6nr4/A6I1AMrH5kZOGeSliwD3FsTW4kCSKjei/QomLkRUERjsORvsSwKyj2\n2OFsjrE1m1D6Mg4RUfrmeTV8fS7WRheWXiTsKUREGlCpBJrqh+9c6IqwYB/iBBynoRBoI0Fj78OB\n1C3JWsP3ZbY2udHp1d0e3NwVkVEGRFEatsfDuV0+3C5fnzKpVSoBc4QOm3X4zoWuCAv2IU5LclLv\nJ7GinQxrjT0g2Hs/DgHH4TBd4CRJwtroIiKyb6UhlO/3Z2RMf6LsWvpaIiMi0oDd6hmRVR7Dgn2I\n05dQRwWVSsASbRjW205rkwuDSYtW1/tAroDGPkwFmsftw+vxY4nU96kdRSAO13FQ+t3bKDGFiCh5\nHEdCgbxTCQv2IU6LYO/bJI6KNuJyyjVnhhuSJGFvchNhCZFAG6amGGujLIAi+qipBmLZexXyOPhY\nlV1sCDR2kP1YI42wYB/iNDXI3v++xp8PZzu72+XD5xOJ6KOmqtGoMUcM34gQJfbc0kdTjPL94TYO\nu3fv5KGH7gv0uzNTzD333MbRo0e6bU/Z+diaXPzpT39i587tverX22+/idvdsjg89NCPsdsHt0R0\nWLAPYSRJoqnBiSW6995/BWXbaRuG207lRY6w9L3ssCXagK3JNSztqoqG3dcFztI8F4abYAcQBLoV\n7MGiaOyNDU7uvfdeZs2a06t23nnnTdzulrF86qlnMZsHt9Dc8C5MPcJxu3x43H5S0ntvX1dQzBjD\n0Z6oLEZ9FWggh41WljZht7r75LcYDBRTTF8FmlanwWDU9Eiwu1wufvnLn1JdXYUoilx//c0sXnxu\np2V1y8pK+b//exybrQlJEvjtb58gNXUUq1Y9x9atmxEEFddddxPnnHMeu3fvZPXqvxEVFU1BwQkm\nTJjIo4/+FoBvv93Mn//8DNHRMYwdOx6ApkYnGq0qEBnkdrt5/PFfU1RUSEZGBh5PS7TP9u3f8vLL\nf8Pr9TJqVBqPPPIYBoOBK664hLMXXcAXm7/Er1vGV9veZNas09HrDfzvfx/xm9/8AZB3Cf/+9xs8\n8cQzPP30Exw9egi3282iRedw00238u67b1FTU80999xOdHQ0zz33PFdccQkvv/xP3njjNZKTU1ix\n4nIAVq/+G2azmauuupZ//euffPnlF3i9Ps46axE33dR9fa2eEBbsQxjlxeurLRFaBPtwtCfam0In\n2C2tQh6Hm2BXNHb/hv+yY9WePu065tg8iH6J/IffBsAyew4JV1zd6f1bt24mPj6Bp556FgCHw95l\nWd1f//oXXHfdjaxYsZTy8jpEUWTjxvWcOJHHa6/9m/r6Om6++TpmzJgJQF7eMV5//R3i4uK4444f\nsn//XsaPn8hTT/2eP//5RUaNSuOXv/wZ0D6Gfc2adzEajbzyyr84ceI4N910LQCNjQ28+upqnnvu\nr+j1Bt5441Xeeut1brjhZvk3R5pZMu8uRqfHcrSkVB6XOafx9NN/wO12odcbWLfuCxYvXgLAbbfd\nhcViQRRFfvSjO8jPP87ll1/Nv//9ZqCcsYzcr3PPXcJzz/0xINjXr1/LM8/8me3bv6W0tJiXXnoN\nSZJ4+OH72bt3D9OmhS4TOCzYhzC2EAo087DW2BUTRN8XuMCBG43D7+ARa5MLlUpAq1XTVxe4oBKQ\n/CKSJJs3uiM7ewyrVj3HCy/8hTPOWMC0adPJzz9Bfv4J7rvvruayuhLx8Qk4HA5qaqpZsEAuBqjV\nypr1vn17OPfc8wGIiYllxoxZHD58CJPJxKRJk4mPjwdgzJhxVFRUYDAYSU0dxahRaQAsWXIhH3zw\nH3kXm9ayKO/Zs5srmhelnJwxjBkzDoCDBw9QWJjPHXf8EEmS8Pl8TJkyLfC9JUvO57//ymuj7KjV\nak477Qy+/vorFi1azJYtX3PXXT8CYN26z/jwwzX4/X7q6mopKCggO3sMIDX/pyD//9ix42loaKC2\ntob6+noiIyNJTEzinXfeYvv2bdx007VyXXmni9LS4rBg/66gCGFzH6NBWrcxHCMhrMoCF4JxaHEi\nD79xsDW6MVv0JF55NQl33dKn2ibfrDvOvu2lrLxuJkmp3Z/MlZ4+mpdffp0tW77hxRf/wty5p3PW\nWYvIzs5pV1bX4ei4iuOpma6t/60IfwC1WoXf3/HS5fPKu5RTzVGtfVBKu5IkMWfO6Tz22O86bMto\nNBIRaWj3TixefB7/+c/bREZamDhxMkajkYqKct566w1efvmfmM0RPP74r/F4uleSzj77HL78ci21\ntbWcc86SQL9+8IMbuOSSFd1+v7eEnadDmIBtOQQCTa2WD8Iejs5TW5MbQQCzpe+VKZXdj32YmaT8\nPhGH3ROyc2t7GvJYU1ODXq9nyZILuOaa73Ps2NFOy+qaTGYSE5P46qsNAHi9XtxuF9OmzWTdui8Q\nRZH6+nr27dvDpEmTO31mRkYmlZUVlJeXAbB27Wf4mmuntx6H6dNn8PnnnwCQn3+cEyfyAJg8eSr7\n9++lrEw2s7jdLkpKits8IyJSj8ftlw8faWbGjFkcO3aUDz9cEyjVa7fbMRqNmExm6upq+fbbzYH7\nuypJvHjxeaxb9zkbN67n7LPPAeC0007n448/xOl0No9tdaAEcKgIa+xDmFBq7CAvEDVVNiRJGlLn\nM3aHvcmFKUKHStV3PSSwcxlmC5xijuprcpKCEvIYbJJSfv5xVq16DpVKQKPR8sADP+uyrO4vfvFr\n/u//HueVV15CENT89rdPsHDh2Rw8uI8bbrgGQVBx5533EhMTS2FhQZtnKXNTp9Px4IOP8OCDPyI6\nOobc3OmcrKiT+99KsC9ffjmPP/5rbrjhe4wdO45Jk+QDSKKjo3nkkcf41a8ewePxIggCt9xyB+np\no1Hs4Ip5T1kwQD7qc968BXzyycf84he/BmDMmLGMHTueH/zgKlJTR5Gb22LSueSS5TzwwL3Exyfw\n3HPP07q8cVZWNg6Hg4SEJGJj4wCYM+d0iooKuf32GwEwmUw8+uhviYkJnWkwXLa3Cwa7lOcH/9pD\neXEDtz54FuoujkELtp99LXXaV3oznqIo8dLTm0hIsbDyBzND0o9X/vQNOr2G7912Wkj6OBCUFtbz\n0Vt7mTUvg7lnZfW5nzUnrbzzj51MmZnKmUvGhbCnbQn1eH79RR77d5Zx+Q2zSEju+1F0u7YUsXVj\nAVf/cC4xCb2vQzRQhMv2jgDsVjdGs7ZLod4ThmPIo8PuQRSlkJijFMwWPXbr8KpuGKr6KAqBujnD\nLJY9lKGvcjtKlNTwS9zrirBgH6JIkoTN2vc0+tZERA6/kMdQJeW0JsKix+cTh1V5hUCSVojGQT5R\nSh1wTA8X7FY3KrXQp+qWrVHGczifVdARYcE+RHG7fPh9Ysjs69CqNsYwKlVqDziQQ6OpwvAM/VQW\n41Bp7CDb2a2NrmG1c7Fb3Zgj9CHzEQV8DcO48mlH9FmwV1ZWct1113HRRRexbNkyXnvttVD06zuP\nLYQhfgrDWaCFUmMfjg5UpU5MSOdDpB6vx4/X4+/+5iGAKMqRQaGIjlIwRegQhJGnsfc5KkatVvOz\nn/2MiRMnYrfbWblyJfPnzycnJycU/fvOEuqIGBie2afWfjDFBBY42zAah0YXRpMWjVYdsjbNES0L\nvU4/9APkHHYvktTS71AghwHrh/VZBR3RZ409ISGBiRMnAmA2m8nJyaGqqqrPHfuuE8oYdgVThKzp\nDCfB3qKxh84EEXAiD5NxCPhbQjgGMPwWuP5QdkAOIW1qdCGKw8ck1R0htbGXlpZy5MgRcnNzQ9ls\nv2I/sA9nfv5gd6Md/TGJ1WpV83Fg7V9kT2UFjsOHQvasUKE4y3p7aHFHKFv51uMgSRKi14vo9SL5\nhpZT1enwIvqlkO5aAMzNC73d2lI0S3S7se3Zjd/WtuyszWbj/fffDfxbKaHbEU8++XuKigq7fX5X\nbbRGKcMbeCeC0NhffvnFoMvwRkQakEQJR/MC9/bbb+JyOrHu2IansmJIlOHtKSHbf9ntdu69914e\neeQRzGZzt/cHG4/Z3xT+8xU8dfWkLL2IjB9ci1rfdtIMVj+V1OnRmXHExoduPKNiTVSWNRIfFwGS\nSPlHH1P15QYchUUApF9zFaOvvrL3HQ9RPxUcNg9R0UYSE7tPew+WSItcVsDr8ZOQYEH0ejn8uz/Q\nsGcvxwFUKjK+/z3SLuu/lO+eUOlpBCA+IaLN+PV1bqamRcv/I8lt+Z1ODj3zJE2HDiOo1URNyyV1\n6UXEzJqJ293IRx/9h1tvlZNqoqNN6PWaDvvw9NNPtPm3co8oim2SzLpqozVarZqYGBP2etlhmjIq\nqsvviKLIT3/6QPcD0ExisoXjh6vQqNUkJFh49+03mJF/HOn4CZIvWMI//vFy0G0NFUIi2H0+H/fe\ney+XXnoGQVm+AAAgAElEQVQp5557blDfGSpJIEm33U3l6r9R8dHH1Gzbwagf/wRdQiIwuMkqNVXy\nc90eb7d96Ek/DUYNol+iuKgW19frqHn3bVCrMU+bjqesjJI3/43D4SFu2aV9/g196SfIafQ2q5vU\n9KiQ/x10ejX1tQ6qq61UvfUGDXv2ohuVhikhDmt+AUX/fANfXDLmyVNC+tzeUFoip5urNEJgHEIx\nN33N1SGrKps4WVpD2XPP4Dx2FOOEiYgOBw27dtOwZy8Zj/2Gx//2V4qLi1m27BJmzz6NM86YT0ND\nE7fddme7Urv33HMbd999H+PHT2DJkrO46qpr2bbtW+6++8fY7fY2ZXg9Hl+733FqGV673Ul9vYP6\nCh8V1cf42aOrEVRSuzK8F198Cdu3b2XlyivZunUz8+efGVQZ3sYGG3GWCZQUTeTdF/8fVSdP8ugX\nnxEdHcOq85exaNHZg16GVyHYxTwkgv2RRx5hzJgxXH/99aFobkAxZmeT8cvfUPOfd2hY+wU1775N\n6h13D3a3sFvdGIyhdZZBS9hgQ1k19o8+QB1hIePXv0MTFYW3tpbS/3uC2g/eR9DpiD3/wpA+u6co\ntt9Q25ahJUnJumsnDWu/QJeSyuhHHiUpLZ6SbXspfuL3VP79RTIe+y2a6OiQP78nKONgajZBbF5/\ngsK8GsQ+HhaimJSP7KvEsXsHWXlHiZg9h5RbbkdQq7Ht3kX5qj9R9cY/uf32uykszGf16jcAWUB2\nVGp36tRpbZ7hdDrJyRnDD394Gx6Ph6uvXtGuDO+pdFaGt7qqhgN5a3nxpb+RkBTdrgyvTqdn1aqX\nALnMMARXhvfEkZM89PC9HDt8mNOLi3lPq+WPj/6G1IVnN4dVDn4Z3p7SZxv7zp07+eijj/j2229Z\nvnw5K1asYNOmTaHo24Ch0ulIuOp76DOzsO3cgfuUQkEDTX8kJykoNvvyz9Yjud3EX34lmqgoALRx\ncaQ9+DDqqGhqP3i/nZ11oOmPUEeFCIset8tH+auvIOh0pNx+F6pmM5whK5uEK67Cb7VS8dILSOLg\nnrak2MAVm3ioUELBRb+Ir64Oc+40Um6+DUEtKxMRM2ZinjET57Gj2Hbvavd9pdSuIAiBUrunotFo\nWLhwMQBFRYXtyvB2xJ49uwPXWpfhPX7iCI22kzz007u48cbv8emnH3Py5MnA95SCXa1pXYbX7/ez\nZcvXnHmmXE543brPuOmm7/PL395Do/Ukx3ftQLTZEIwmLDNntYqVb1+G9/jxvEAZ3m3btgbK8N50\n07UUFxdRWjq4MqTPGvusWbM4fPhwKPoyqAiCQPzyFZQ9+wy1H35A6l33DFpfPG4fPm9ok5MUApl2\nReUk5owhct78Nte1cfHELDmfmnf+TeNXm4i98KKQ9yFY+iPrVEFxwDk9kPW9a9GPGtXmevQ55+E4\nchj7nt3Y9+4hYkZo6tT0BsWpp8yHeYtzuPSq6SExT73+/Ld4GxsZW7uDhB8/jqBpKxISr7qGwoMH\nqPv4w3YLXDCldnU6Xa+SiToqw+tyekhLnsA//vFih98xGjs+OKW7MrySX8Odt/4Ya1UtKrMZVSft\nwOCV4e0p4czTVpgmT8WQnYNt905cxUWD1g8lWsPcLwJN1vpcmggSr/0BQgcVE6POPAtBr6dh/dpB\njRCx2xRNNfTjYDLJWqkvbhSR889sd10QBOIvXQlA49eDuwPtL40dwKgRcUtajJNz0aWktruujU8g\n9qKlaB0ObLW1PW6/dVZrR2V4O6KzMrwW4yiq6gq6LMPbEd2V4XW6rZRXH8EraIg9/0LM5oghV4a3\np4QFeysEQSDuUnnVrf1wzaD1w94PMewK2qZqAPyJ6RhGZ3R4j9pkJmr+Anz1dR1uwQeK/opbBlDX\nyMJFGJ/b4eIGoE9PR5+ZhX3/PnwNDSHvQ7DYbW40GlW/JBFprDVIggrDWZ0HPcScfwGRlkhydDqu\nu+5q/vrXP7W7p7WG3dn/63Q6Hnro5zz44I+4665bSOlgIQG5DK/D4eCGG77Hm2++zqRJU/B6fKgF\nI8vOv5lf/eoRrr/+Gm677SaKAwpY57sCpQzv1q1bmDdPXsRbl+F96onfkBSdgU+tJ3rxOYEyvD/6\n0R3t2u6sDO95553P7bffyPXXX82jjz6M0+notD8DQbhs7ylIkkTJE7/HdeI4s/72PFbVwJ+LeWhv\nORs/OcbZF09gwtTkbu/vSYTEybfe5D8FSSTGaLns9vaaqoLnZCWFP/8phpwxjP7ZL4Lue6j6CfDZ\n+wfJP1rNdXefEXKtffsfVrFDmMycuUnMXjyx0z42bFhP1euvEb/ycmIvWhrSPgTLK3/+Bp2ubZnh\nkETFNNTz6ZNvURI1kcuun0liSuchpZWvrKbp602kPfAwpgkTO73vVEIVWVZfY+etv29n4rQUFl04\nvs/ttWl7/Vo+3tSI0xTLzQ8uGtJnFYTL9vYSQRCInLcAgLqt2walD/Z+qBMDIIki9p3b0IsunGLX\n2p8uKRlz7jRcJ47jzD8R0n4Ei8Mun5xkNIXWBOEuL0dVIv8mp6/rqCPL3NMRtFoav/lqUIpl+f0i\nTrs3kDUcShq+XI/eKzvIFbNXZ0SedjoA1m1bQ96PYLDb+m/3Ztu1E53fgU8Uhk3dnO4IC/YOiJg+\nAwSB2i3fDsrzbf1kgnDmHcNXX49Rr8Jh93QrqKIXy9tza6tjwAYSu9WDyaxDpQqtBtX09Sb0Pnvg\nGV2hNpmImDUb78mTOPOOhbQfweC094+fQZIkmrZuwaCSfSjdFYYzjp+AOioK687tg+J3USKkQlkA\nDMBvteI8dpSICNkRPFzKK3RHWLB3gCYqCuOYsTQdPoKvqWnAn99iYw/tJLZukxeqiPhI/H4Jj7vr\nF9Q0YSIqgwH7/n0Drq1KkoTD7gm5pir5fDRt+Qa9SYtaLQRV4TFqwVkANH61MaR9CYYWB3Jox8FT\nUY6vpoao0SmAnOHbFYJKhWX2XES7HfuhgyHtSzD0lyPdtncPiCIxaXJSYncL/XAhLNg7IWLGTJAk\n7Ht2D/izbVY3Or0arS50zjLJ58O6cwfqqCgik2SnT3eTWNBoME2egre6Gm9l+xjl/sTjluvRm0L8\nIjvzjuG3Womae7qcpBSEhmYcPwFNfDz23bsGXFvtLweyfd9eAGInjmnznK6wzJVt/IqCMJD0V0CB\nbfdOABInZAItoaXDnbBg74SIGbOAlj/8QOKweUKumdgPHUS02bDMnoup+eVw2LufxObmTEJbsyAY\nKPorxM9+8IDcbm4uZoseh82Dv5sMTkEQME/JRXS5cBUMbME4RZMO9c7Fvm8vCALxM+SSCcEscIbs\nHLTxCdh270Z0D6wA7I8FTnS5cBw8gG5UGjHpSfJzutm5DBfCgr0TtAkJmLMycRw+hN85cLWa/c1H\ntoX6Rbbtkhcoy9zTWqr6BTGJzVOnyvfu3xfS/nSHsuiEWmN3HDyAoNFgHDs+ICQUO3ZXmCdPBhhw\nM0TAaRjCcfA77DiP52HIysIQG41OrwlqLgiCgGXuaUhu14BXArXb3KjVAnpD6Hax9gP7kXw+ImbM\nDJykFLaxfweIPf00JJ8P+/6B01Zb6oKEVrA7jxxGZTJhyMoOtN2dXRVAExWNPjNLNmEM4ALXHxq7\nr6kJd0kxxrHjUOn1LQePBGGGMI6fCCoVjmaNf6DoD03VcfAgiGJgN2a26II+VcvUXBTNcWSABbvV\ng9kSuiPxoGU3HjFzVuDIwWDeieFAWLB3QdzpcwGwD2CSjqNZezSZQ/cie2uq8dZUYxw3HkGlajk5\nJ0jtxDw1F/x+HIcGTqgFxiGEgt1xWNa2TZNk4WQyB7/AqU0mDNk5uAry8XeSldgf2PvBFKPY1825\nzYI9Qq6b4/N2H+pnyM5B0GpxHDkSsv50h9/ffCReCHctkt+Pfd9eNHFx6NNHN5+jGjbFfCcwZWSg\njo7GcfTIgEWFOPohCkJ5CZXEkp4INICIZgFg3zdw5pieHKoQLIq2bWo2q/Rk5wLIJXwlaUC1VbtN\nPrZOG6Iqn5IoYj+wD3VUNPrmzOOWk5S6HweVVotxzFg8pSX4rAMTMRYI+QzhrsVdXITodGKeMhVB\nEFCpBExmXdh5+l1AEARM48bjb2rC26qKXH/SH84yx1G5SJsi2I1mbY+0E31GJmpLJPb9ewes0mGo\nw/wkScJ+8CBqiwV9Wnpz280CLQgnMoBpkrwgOAbQzu6whfbwZldhAX6rFfPU3IBZo+UkpeDGwdg8\nj5xHj4asX13RktcRwnfimNx347gJgc9MEXrstu7zO4YDYcHeDcaxcvqy89jATGJFyChadV+RJAnn\nkcOoLRZ0qXIFQ5VKhdEUvHYiqFSYp0zF39SEp6wsJP3qDiXr1BCirFNPeRn+xgZMkyYHasP0VGM3\nZGahMhqxHzwwIC+/z+vH7fKFdtfSvCgpTnHo+dmnioLgODIwVV0d/RDDrrzPxrHjAp+ZI3T4fWK3\n+R3DgbBg7wbjOFmwO/IGRrAHJnGItp3eqpNytun4CW2KXZkidEFlnyooL4DzeF5I+tUdoc46DZhh\nJrWciNRTk5SgVmOaOAlfTQ3eATiwvT+Sk5S/n6KwyO03C/Ygk3MMGZkIegPOgRLsIfa3SKKIM+8Y\n2oQEtLGxgc+VMOCRkKQUFuzdoEtJQRURMWAae8AUEyKNXdGqTi3cZI7Q4fOKeNzB1cYwjh0LDIxg\n74+sUyVMUQlbBNDpNWi0qh5FQgSiQgbAkRyIkArRIi+JIq4Tx9EmJaGJbCn4pZg4gtXY5XDRcXgq\nK/A19H952lC/E56yMkSHo83iBq1MUiPAzh4W7N0gqFQYx47DV1uLt7am35/nsHnQaENXotXZiWBX\n4sODSVIC0CY3L3An+l+whzrrVBJFXMfz0CYno4mOaXPNZNYFbWMHMI1vti/n9f84hNqR7ikvQ3Q6\nMeaMbfO5orH3xHFomiDbph1H+z865tSjAfuKsvtWduMKph7kdwx1woI9CEwBO3v/F4Gy290hsyVK\nkoTjyBHU0dFok9qW/+2xGUIQMOaMwVdT0+9aWqhj2D3lZYguF8bsMe2umSL0uBxeRDE4k5Q2KUle\n4PKPh6RvXRHqyKCAGWZMW8FuNMsFsHq0c5kwSf7OAJye5rSHVmMP2NfHnaqx93yBG6qEBXsQKBPA\n2c92dlFsLtEaqi1nRTl+axOm8RPbJXa0bL+Df5kVgdDf5phQZ50qZYcNOe0FuzlChySB09GDBS47\nR17gGvv38I1Ql6pV/m6GUwS77EzXYg8iA1dBP3o0KpMJ59H+F+x2m6f5oJG+h3xKkoTz2FFZ2UlI\naHOtJToorLED8MgjjzBv3jyWLVsWiuaGHPr0dFQGQyBEqr9w2r1A6JxErhOyVqnYx1ujJED1RDsZ\nKMEeao3ddUIW7MacnHbXerpzATlJB8DVz3XqQ+08dR0/jspsRpfc/vAWU4QuqNIKCoJKhTFnDN7q\n6n6vgKr4W0KRdeo9eRJ/UxOmcePbtWfqYeLeUCYkgn3lypW8/PLLoWhqSCKo1RjGjMVbWYmvsbHf\nnhNq779SsEoRRK3paagfgD4zE0GjwXm8f80QIR+H/BOoDIZAuGdrejMOxmbN33mifwW70idjCHZw\nvoYGOfs4Z0yHRwGazDo8bj/eILJPFQxZ2QD9WhhNFCWcdk/ozTCnOE4BjCYtKpUwIsoKhESwz549\nm8jIzo/VGgmYAuaY/rOzh7rgkzM/H0GnQz8qrd21nhQCU1BpdegzMuWsvX6s7hdK27LfbsdTUY4h\nK7tjgdbDJCUAQ1YWCEJgR9RfOOweDEYtanXfX9PO7OsKyjj0RGs3ZCuCvf8WOJfTiySFbpFX3l/j\nuHHtrgmCIIcBjwCNPfSn445QAtvvgnwss+cAYPPa2V65G5WgIjMyndSIFLSq4IZUkiRKqmwcKqzH\n6/Oj16pxVcs1SEKhnYhuN56yUoxjxiKo29smjQETRM8msXHMGFwnjuMqyA9E2tS7GjhQewS7186M\nxFySTAndtNKCy+OjqNJKQYUVm9NLXJSB6pPyGZmhMEEoQqejXUvrZ/RES1MZjOhSR+EqKkTy+RA0\nGlw+F2W2SmpddVg9NqbGTySxB+NQ3eCk+KSVmkYXjXYPybEmbFY3lsj+ta8rKHPObvMQGR3cOb+G\nTEWwFwQ+a/JYOVhzBLVKjU6tY5pxLALB/4ayGjvHShpweXx4vCKG5jyLUNVOchacQGU0ouvkIG1T\nhI6aShuSJA3ps0+7Y9AEe7CHsg42Sj995qmUCgL+smI0ESJrDn/G+vxvcPtbBIJereOHs65mUdYZ\nnbbncvt4d30ea7cXU9voanMtFRiFig0HK7GkRzNtbPCC4dTxbDxYDJJEzKTxnY61KUKH2+Xr0d9C\nNWsa9Z99iqqiGPuUFJ7f/k8K6ksC1z/K/4zxcdlcPuVipiVP6rSftY1O3vz8KGu3FeM/JSJlAgIR\nCHyxr4JLzxpDQkzvDxR3VpYCkDRzKrEd/E7RKz9b8kuBvgUzHo1TJnLys1JM9joKLV6e3foyTW5b\n4PoHJ/7HOTkLuHzyxUQbOt7NSpLEoYI63t9wnG2HKmmdKyYAs1FRWu9k8+EqLpqXiVbTdoHuyd+t\nvCgfQaMhbfZU1Pr2QjIxSW5Lq1YF326ChbKUZNyFBcTGGvmycAtv7H0fu7elCqjmoIZrc5dz4biz\nUQkd7zx8fpEvthbxxbZi8kraOqQjgfGo2Ha8moRJSSyYntprgeuz2zlWWUlU7lQSk6La/5wECzGx\nZqrKrUSY9CEvGT2QDJpgD8XJ5f3NqSes65JTsB4/zs8+e4I6dwMx+miWZi3BrDVT2FTCjpO7+eu2\n1yisKueirPPaTEBJktiTV8O/1h6jtsmN2aDh9MlJ5GbHYTHpcHv9HPy2GFu5lX2FdWx9YTNn5qZw\n9TljMXYT097RSfB1u+WEHCk5vdOxNpq0NDW4evS38CXIduqSndt5Ub0Zl8/FxNhxTImbiElrZGvF\nTo7WHucPm1Zx69TrmBrfItwTEixUVDby4TcFfLatBK9PJCnWxPQxcWSlRBJl1lHX5GbfF3l4PH4+\n2JTPf78uYMVZ2Vxw2mhUvXiha/fLBbs8cakd/k63V3ZY19bYqa62djiWHZI6GoAvv3iP12MLUQkq\nFqbNJ9mUiFqlYm3RRj4/vomvCrfx4xm3k2ZpqyHaXV5Wf3yY3XlybkRWSiRzJiQSH2Ug0qyjsKSB\nE5sK8UgSf//gAO9/mcfV54xj1viEwFgG+3cT3W5s+QUYMjKpa/IA7XcnIvKqUlHeSHxK8AuGdnQW\nrq1b+MM7v2efUIlBrWdZ9gVEaE04vE6+LPuKV/e8y9aivVw/+WoidW3bLqux8/f/HqKo0oogQG5O\nHLPGJWAx6dBpVRzeW0HV4WqqrW6een0H//06hmvPG0dKnDnoPiooNeRVqe3fCWU8NVp58SkuriMu\nIaLHz+hvgl10QybYR0LhnO7QjE7HU1GOVFXDBdPO56LMc1GrZC3qtJRZLEybx1/3ruZ/hWupdzdy\n7YTLEQQBUZR4Y+0xvtxVhlolcPEZGSw9IxO9rq0GdnJfJTbgnqum8eaXJ/hqXwWHCuu5fflkclLb\naxhdoURsKHbQjjBF6KmtsuP1+II+hk9jiUSKjcZZkI97ZgLXTbqKuckzA9fnJs8krz6fv+59mb/v\n/ye35t7A5DjZP9Fk9/D/3t7L4aJ6Yix6li/IYt7UZNStbN+SJLH/02MkJ0bww9mjeG/jCd7dcIJj\nJQ3cvHQSEUZt0GMgiSKu/BNok5JQR3T8khqMvXOYGZtNO1VH9hC5KI2bp/6A7KjMwPXTk2ezsWwz\n7+V9xPP7/sFDs+8hSi9r7gUVTTy/5gA1jS7GpUez8qxsxqZFtVEEIlUCJyhk/oxR5Khh3c4yVr2/\nn0vmZ3LJgqwe9dVdUgx+f9dzQTHN9cDGDqDPysK6dQuewgJy58zmqvHLida3zNWLpy7iua//wcHa\nI/xt32vcN/P2wDuzbmcp/15/HJ9fZP6UZFYuzCHmlNBOZ7mVqsPVfO+C8aw7Ws3+/FoeW72dW5dN\nYvaExB711VUom4wMWZ2PnzIOTrsHgt8wDzlC4jz9yU9+wtVXX01BQQGLFi3ivffeC0WzQwqv38s2\nXSUAC8VMlmYtCUxQhWRzIg/MvovRllFsqdjOloodeLx+Vr2/ny93lZGWEMGvb5rLZQtz2gl1kF8q\ntVpgfGYsj14/m6XzMqizunj6zT0cKqzrUX9dBfmoLZFoYuM6vcds7rkDtdZZzwmLG6Nb5Oa0S9oI\ndYWxMdncnnsjgiDw0v5XKWgspqzGzk+e28jhonpmjI3ndzefxpnTUtsIdWjJOjVb9MyfmsKvbpzL\n5KxY9p2o5TevbKemIfjDPjyVFXKmZQeJSQqCIGDsRbnWfK0Nl05gVK3Iw3N+3EaoA6hVahann8ml\n2RfS4G7kxX2v4vF72Hm0isf/uZPaRheXzM/koWtmMC49up15QVlooqMMXLV4LI/dMJuEaAMfflPI\nX98/gKsHhapcRYUAGDK6EGi98DUA7NLLO46JNjO3TP1BG6EOEG2I5I7cG5mdNJ2CpiI+yP8ESZJ4\nb+MJ3vjiGEa9mrtXTuWHSye1E+rQ4sxNSbLw4ytyuXP5FNRqgefXHGDtjpJ293dFQLBndqXs9G4c\nhhohEex//OMf+frrrzlw4AAbNmzgsssuC0WzQ4qPC77ggFGO1811xXRq54vUWbhl6nUY1HrezfuQ\nJ975ht15NUzMiOGn184kNb7zLaTdJod1CYKARq1i5Vk53L1iKn5R5Nl39rEnL7iSBr6GBnx1dRiy\ns7u0R5osPZvEkiTx5tH3qIyRp022tXPn5vjYMdw85Qd4RR+vHnybJ/+1g8paB8vmZXLXyqmdmpdO\nTaOPNOu478ppLJ2XSU2ji6fe3E1NY3DCPbBr6SB+vTXmCB32HhREs3nsvHbk31TGa7FYvZjdnX/v\nvIxFnJ48myJrCX/Z9i9e+OAgGo2K+66axvIzszstcnZqyOeohAgevX4OE0ZHs+tYNb//xza8vuBC\nE92FhYBcfrkzeqOx76s+yIeu3fhVkNOk79SGLggC14xfSaIpnnXFm1i1di0fbykiMcbIo9fPZua4\nzlXj1rH8giAwe0IiP/3eTCLNOv61No/3NgYfkeMqKGhWdmI7vUcJKuhJstZQJJx5GgQn7VWsL/kK\nX1IcqFS4iwq6vD/WEMOKMctw+92UGzczd1Ii9105DVMX5zVKUnO87ikOmxnjEvjRFdNQqWDV+/vZ\nd6K22/4G4tezOtdMAMzmniVkbKvcxeG6YwGNx11U1OX9U+InMit+FtWuKlxRedxxWS4rzsru0lau\nCJbWsdsqQWDlWdmsODNLFu7/Ck64K9Ea3Y2DyaxD9Eu4Xd1rwZIk8caRd2n0WIkeI0cFKZpgRwiC\nwDUTVhKvTeaE8xCaqHruv3IaU7I630lBx4WvIoxa7r9qOtPHxLMnr5oXPjiIr5uDuEHW2AW9ocPE\nJIWeFkRz+py8ceRdVFodmrQ0vKWliN7Ov2vQGPjh5O+jktQckjaQkqziZ9fOJD6qa8d4R+WbM5It\n/PwHs0iKMfLxliI+3VrcbX99TU346moxZGV1qewoCoUzrLGPbCRJ4p28D/FLflZOvBR9Wjru4mIk\nX+dCQJQkDuw04a9PQB1Vx4QZTWi6iUV2OeV6JR3F607OjOX+K6ejUgk8/8EBik927TQLVrD3ZNvZ\n5LHybt6H6NU6zjvjGvk5XQg0gEa7h6Nbk5G8OvTp+czO7d7x4+iiLsiy+Vksbxbu/+/tvTi6EcTu\n4iJQqzuM429NT8Zhb81B9tUcZGx0NhNzF7Y8pwsKym1U7pPNIElTCsgZ1X3OR2dJWhq1ijuWT2ba\n2Hh259Xwj/8d7nKnIbrdchz/6NEdxvG3xmTWBa2xf160AZvXzgWZ5xA1Zjz4/biLuxawhw77cBWN\nQ9B4mTCnhqggok4cNg9GU/vyzfHRRh64egbRETre/vI4Ww5UdtmOMle72rVA730NQ42wYO+GvTUH\nOVx3jImx45iWMAVDZhaSz4e7vPMDJ9798gTbDlUxyn0GBrWeTwq/aBMW2RHdnZw0Lj2aW5ZOwuPx\n8+w7e6lrcnV4H7QW7F072QICLYhJvOb4/3D4nFyacxEJcaloE5PkOO5OhIrXJ7Lq/f1U1/qZrJ+P\niI+Xd/272+d0V6L1kvlZLJmTTkWtg+c/OIC/kxOdJJ8Pd0kx+lFpCJquHcPBVroUJZGP8j9DQDYt\nKDZrxYbdETUNTv7yn/34bdGMi5hMtfsk31bs6PI50PU4aDVqfn7jaeSkRrLl4En+u7nz57uL5bBX\nfWb3DldThB6n3dNtQbQ6Vz1flnxFtD6KxekLMGQpOR6dL/Q7j1bx7/XHMTtyiNPHsa1qBycd1V0+\np7vyzXFRBu6/ajomvYbV/zvMwS78UO4gHKcARlNYsI94vH4v7+V9hFpQc8XYSxAEAUPzC9LZJN5y\nsJJPtxWTEmfivhWncXb6mdi8djaVbu7yWQFbYhfJSbMnJHLl4jE02Dw89+4+3B2kf0uShKuwAG1S\nMmpT1yFhwdZJqXLUsK1yF6nmZM4cdToAhsxMRLsdX017u78kSbzxxVGOlzYyd2Iid5x1PuNixrC7\n4gDHG7rW8oMpJ3Dl2WOYlhPHwYI63lrbcfanp6ICyefDkJnZ5fOgbXJOV+w4uYdK+0lOS5lFkjkR\nTXQ06sjITk1STreP597bh9Xh5drzxnL9tOXoVFo+PPEpTl/nCzO0ONI7K99s1Gu457Jc4iL1vP9V\nAbuPdSwkXUWKwzCzy+eBPA6SJO8eu+Kj/M/wij4uyb4AnVrXUlqgsOPSAkWVVv720SF0WjX3XT6D\n5WMvDCySXeH1+PF5xS7nQlpCBPdenosgwAtrDlDdiXM9GMcpgFqjQm/QhJ2nI5mvirZT56rnrLQz\nSN51YmQAACAASURBVDLLoVXKit/RJC6qtPLqJ0cw6tXcc1kuEUYti9PPxKgxsLZ4Iy5f5xqhI8ia\n00vmpLNoeiolVTZe+7T9IdvemmpEpxNDN1tOCH7b+VnReiQkLsg8J+AgU7a0rg78Det3lbFpbwUZ\nSRZuvGgiKpWKZdnny20Vru/yWcEcqqBSCdx6yWRGJZhZt6uUTXvL292jaNHKgc1dEczOxS/6+Tj/\nc9SCmosyzwVk+7l+dCa+ulr81rbmMUmS+Pt/D1FWbeecWWmcPTONaH0USzIWY/XaWF+8qcs+OZr9\nLV3ZgyPNOu65LBedVsXf/nuI0mpbu3sCAi2I+dCShdv5PC2xlrG9cjdpEanMSZ4BgDYxEUFv6NAU\nY3V4WPX+frw+kdsumUxGsoUZCVPJsKSzu2ofRU2dR7Z0ZZZrzbj0aK49bxx2l49V/9nfTuGRJAlX\nQQGa2Lg2B4x0hnK62HAmLNg7QZREPjwiv8jnjl4Y+FyXOgpBpwts7RRsTi+r3t+Pxydy89JJJMea\nADBpjUFp7cEWvhIEgWvOHUd28zb8y91tTUKK9qgfPbrb36jRqtHp1V1O4hpnHdsqd5FkSmRGYss5\nmYqgcDVHXCjklTbw5to8Ik1a7rlsKnqtHNaZHZXB5MRxHKo7SnFTaafPC/ZlNuo1/OiyXMwGDa9/\nfoyiyraC1V0s90s/OrPLdiC4sgJbKrZT46pjfuppxBlboioMGfLC4TrFzv7p1uJANNTV57SEWy4e\nfSZmjYmNZZvxdGKeCzjSgygtMTrJwg8vnoTb42fVf/bjPCUM0l1UhMpgQJuY1G1bxiAW+k8L5UV+\n+ZiLAou8oFJhGD0aT0V5mxpCoiTx9Bs7qWkO7Zw+Nl6+XxC4NOdCAD488Wmnz+rJwe4Lp49i4fRU\niqtsvHqKwuOrq8NvberWDKNgMssZ2X7fwBzc3h+EBXsn7Ks5RLn1JHOTZ7aJzRXUavTpo3GXlSF6\n5IknShIvfXQoMIFnnFIKYHH6AowaY7PW3vEWvCfHf2k1Ku5cPoUIo5Y31+ZxpJVtUXHkBaOhKc/r\n6kX+vOhLREnkgszFbcLZFE3Y3cq+3OTw8MIHB5GQuGP5FGIjDW3aWjHxAgA+K/qy0+c57B50ejUa\nbfe1t+Ojjdy8dBI+v8hf1+zH4WoxIbiKikClQp/eteMUujdJ+UU/nxauR6vSckHm4jbXlJ1L63E4\nWlzPuxtPEB2h47ZLJreJ1derdZyZdgZ2r6NTW3tXjvSOmDMhkQtPG83Jeif/+KRFqIkuJ57KCvSj\nM7p1nEL3C1yNs5a91QcYbRnFhJi2NWf0ozNAknCXtSzaH35dwK4jVUzJjm2XVDU+dgzjonM4Up9H\nqbX9jgtaFcULsk7M984dR05qJN8ePMmGVgpPixkmSMHeA9/TUCUs2DtAkiQ+L/oSAaGNtq5gyMgA\nUcRdKk/iT74tYn9+LZOz2k9gAKPGyDnpZ2L3Odhcsb3DZ/a0VG1spIHbL52MKEk8+c8d2Jrtoorm\nqE/vXmMHWai5HF78HYTN1bsa+LZiB4nGeGYlTmtzTW0yoU1KDjhQlcWt3upm5VnZjB8d0669qUkT\nyLCks7f6AJX2kx32x9HDEq3TxsSzdF4G1Q0u/v5fOUJEEkXcJcXoUkeh0nbfVncF0fbWHKTe3cAZ\nKbMD2aMKp2rsjTY3z39wEAGB2y+dQmQHv2Vh2jw0Kg3rSr5ClNqPe7C7ltasaM5e3XGkivW7ypr7\nJDtOgxVoxm58DV+WfI2ExOL0s9qZiJQdorJjPFhQx0ffFJIYY+TWZZM7DHFdPPpMud3Srzt8Xkeh\nr12h1ai4Q1F41uUFdnGKshOMWQ5GRmRMWLB3QF7DCYqaSpgzahrJ5vZpywFttaSIYyUNvL+pgBiL\nnluWTeo0RvvMUWegUWnYVLq5y5fZaAo+ZX5SZizLF2RR0+Dk7/89hF8UcRcVoYmL6zSF/lSUhcTl\naO8w21S2Bb/k57yMRe2ybEHeFYgOB97qaj7eXMjBgjpyc+K48PSOXyBBEDg/82wkJD4v2tDuut8v\n4nL0/ASp5QuymZgRw57jNXy2rQRPZQWSxxP0rkWtVmHo4gShDSXfALAwbX67a5rYOFQREbiLChFF\niRc/PEiT3cPli3IYlx7dYXuROgunJc9s1oAPtrvem8ObNWoVt186BYtJy1vr8sgvbwoqMak1gRju\nDsbB4ZWVkmh9FDMTc9tdNzSbvNwlRdRb3fzto4OoVAIPXzen0zIQk+MmkGiMZ0flbqye9v6B3pz5\nGhtpaN7FSc27OF/LLjZowa4cQhMW7COKdc2OrUsnLunwuiLYrScKeOED+bT62y6ZTKSp8wkYoTMz\nO3E61c5aDte1P4HIYfc0F/rv2Z/k4jMymT4ugX0nalm34SB+a1PQmgl0blf1ij42l2/DrDExO2lG\nh99VIi3yt+9nzdcFxEbquXlp54sbwNT4SSQa49lZtReb197mmtOhnCDVs6p6ijM1yqzjvY0nKN4j\nH9emzwh+HMzmjk8QKrGWcaKxgImx4zpc5AVBwDA6A+//Z++9oyS560PfT3WOk3ty3JyjNiqsJAQS\nCiRjHgbDRRhjHDg8Xb/jc1+wr6/TxX6PCxiuMRgso4vBZIQQKGu1knalzTnvTs6xezqHqvdHdfX0\nzHRPV3XXzG6P+nMO54jpqq7f/vpX39/3942jo/zqlYtc7pli++oaHtzdsuDz3tVyDwICL/W8Ns8B\nnm+jkUq3lc8+thFRlPjGL87jV8JeVUTEwMLRQW8OHCWaiHJv850ZN3lLQ4Ncvri7i2/98gLTwRgf\nuX8VazKc3BQMgoF7W+4iLiV4vf/IvM+12NjT2bKymkf2yae4J399iXBvD6bKKoxudQW0SqaYZch4\naIIL41foKGtldXXmI6y1sQmMRgbOX2HKH+VDB1Zk1c7SOdCyH4BDfW/O+yzfLjEGg8Cffmwn5S4L\nxw+eBtRrJpDdvnxq5Cz+WIC9jXdgMWbWuJQN5PQbZzAIAn/4/k05i3QZBAN3Ne0lLsbn2ZgLaVpc\n7rTw2ffJpqmzb+QxD65kB6HobOfjweRvdW8GbV1BmYdTh85QU27j04/M7zE7lzpnLZtq1tPl66Fr\nTmRIPqYYhY0dVTx2ZzvjvjCjF6/JjlOPumJZNocFQZgv0BJigoN9b2I1WrizcU/GewWTCUtzC6He\nPq71TLBzjYcHdub2b+yp34ndZONQ/xFi4uy5L2QePnB3B2tbKrh0sYfE1JSqYAIFRw7TXDFQEuxz\nODxwFAmJu5Lx2pkQTCZC5R5c02NsX1HFQ3vULZpWdzMdZW1cGL/CaHCmNEAsliAaSeTdJabCbeVz\n79tIXVj+zkRt5iYCmZhJzpn9Mh/qO4KAwN2N2WvLm5plrbQiMMZv37eKlU3qKlDubbgDs8HE6/1v\nzTJL5auhKaxvq+T9d3VQ4RtBQsDctLDWnI5ycvGnNTKejvo5PnyaWnsNG6rnt1JTiNfKpYwbouP8\n4Qc24bSpM6cdaJI3+jcH3p7190Ln4X13drCp2YUjMEmgok6V4xRkJcHumF8Q7ezYRaYiXvY27MJh\nzl4CIFBei0FMsMYa5vGH16mqm24zWdnfuJvpqJ+Tw2dmfabFkT4Xo8HAH7x/Ix2CbGcPVOSOClIo\n2diXGQkxwZuDR7Gb7OyY4yxM5/zNca7HnJilBJ/YVampTviB5v1ISBzqnwl9DGl0EmVibWslO8rk\nF/IHF0JZMzLnkmkR90730+nrZn31GjyO7DVNfnFsiCmTi6b4FA/snN9PNBtOs4OdtdsYC41zZWIm\nwUirsywTj+xppSE2yZiljGeOZY62yIQjJdhnopYODxwlLsY50Hxn1gJXsbjIDy7K9+yqiNHRoL5F\n5NqqVVTbqjgxfJpQfCaxphBNFWQB/cntZRiQuBiyc6l7UvW9mWK4lY3nrizaOsDIZJBDI7IA/vB6\nKw6VmxvIG5yAwBsDb836e9BfWK/TCpeVRzrkMT3fk8AXVCeoS6aYZcaZsQtMR/3srd+Z1fwwPBHk\nn5++wIhdFniG4eylBTKxvXYzZRa3XNI3IduU83GWZaJ8eoSIxcGZ4Rg/Oaiu6l0mU8yhPtneqWiU\nmXjr4hDPvd2D112DNRpE1Nip/u5m+USUblstVKABJMZGMSViTLk8PHO4S3VFzJR9OdlvVZREDg8c\nxWIws6dhfmlihe+/dJXzkxA3WamYHtE0VoNg4M7G3UTFGMeGTqX+rkcTa9PYIAAjtiq59rvKcscO\np4V4TCSajIcfD01weeIaK8rbaHRlLiIWiSX4p5+fp9con9hck5kjnrJRba9iXdVqbnq7GUxGSyUS\nIuFQrOAuRmU+OSP3pljGN35+XlXRNKvNJNfoLwn25cGb/UnNpCmzZhIMx/jqT84SjMTZcfc2gJyF\nj+ZiMpjY23AHoXiI06Pn5O/N01mWTsLvJz4+TvmqFdRVO3n+aC+vn82tsc7VTkLxMMeHT1Ftq8xq\nfugemubffn0Zm8XI6js2AvMTdHLR5m6hxd3E2bGLTIbldmip7NsCBFqkV/491u/ZjNlk4F9+dYHB\n8UCOu2bmwZ8U7NcmbzIWnmB77Rbspszmh4On+3nt9ACtdW6cHe3ERoY1N/ne27ALg2DgjYG3U05U\nPZpYK+ty54Ht+EMxvp4hIzMTc9fD4cFjSEhZbeuiJPHtZy7SM+Jn3a4NcvXTXm3vBMD+xt3y8waO\nAoX5W9KJ9HZjcDhZtbGdK71TfO+FqznLM880tS4J9qJnJDjG5clrrKrooN453x4nihL//MsLDE0E\neXB3C3fcK0eKaBVoAPsa5GbYR5LOQz00VeVlcrS384UPyxmZTz13Jecx3GY3z3KYnRw5Q1SMsa9h\nd0bzw4QvzNd/dpZoXOSzj22kZu2qWc9XiyAI3N20FwmJI8nYfj02OGUc9RtW86n3riMUSfA/fniG\nqRyOsJQpxidfd3hQFjCKwJnL2RtjfO/5q7jsZrm+fFurnKDTp635Q7nVzZaajfT7B1NO1KA/e+Er\ntUR65cqW++/blsrI/PavLuYs8JV+gkuICY4MHMNusmUMcQT46cEbnLg6yrrWCn7noY1Y6hsI9/Qg\nqTQFKmyp2YDL7OTtoRPExLgu74QYDhEbGcHa2spnHt1Ia62LQ2cGePlE9sxnBSVxr1g7w5UEexJF\nuNzVON9pKkoS//aby5y/OcHmFdX89r2rMNrtmGvr5BK+Gn/8WkcNqyo6uDp5nbHQuC6mmHBaEkZ9\nlYM/+ZCc/v8/f3aOgbHsGqviMFM0pCMDxxEQ2Nuwc961/lCM//GjM4z7IvzWgRVsW12TSoTKVbo2\nEztrt2IxmHlr8ASiJBIMROXa2xra380lPUFr38Z6Pnh3B+O+MF/58Zl56fbppNvYg7Egp0fPU+fw\nsHJOZySQW9v90y/OYzQKfOHDW/BU2LG2JHMbNJ7gYMZ2/cbAW8RjCaKReEFrQUomz1kbGzGYzXz8\n3WtY21LBiSujPPX8lQXXa7rGfmH8Mt6oj11127EY54/ntdP9/ObtHuqqHPzRBzdjMhqwtrUhRcLE\nRrSZY0wGE3sadhKIBTk7ekGnTb5PTtBqacVqkes3lTkt/ODlaxy9tPD4lBr9UQ2dqm4nSoId2Z76\n9uAJ7CYbWz2bZn0mSRLfe+Eqb5wbpL3ezR+8b2OqNrS1pQUxGCA+kbv5xVz2N8ia4JHB4/os4jkZ\np2tbK/nUe9cRjMT5hx+con8B4a5oJ0OBYTp93ayrWk2lbXb4Zjga58s/OsPAWID37Grh4WQSkqmq\nCoPTSaRXm6YKcvOFHbVbGQ9PcH3qplx72zm/9rYWIr09mKpmErQe3d/OPVsb6Rn28z9/fo5INLM5\nIt0Uc3T4FHExzr6GXfMiO/pH/Xz1x2eIxUU+976NqUggm5J5mYcZQnaiVnJy5CyTPjlRpxDBHh0a\nQopGU2vBZDTw+d/aQmudrLH+7FDmKozpzw0GoryZNItkMsO8fnaAp567gtNm4n//7S2pMFdbARuc\n8k4cHjiqi8Ye7p1dN6m63MYXPrwFq9nIvzxzMWtFTCj+FnklwQ5cmriKN+pjZ922WU5TUZT4wUvX\nOHiqn5Zal1z7Oa0LUioDNQ9tdXvtZmxGK28NHk/VAS/UFCPHLM/UqblzcwMff/cafIEo//D9k/SN\nzM/uA7C7LMSiCd7slU1DiqlIwReM8qUfnqZz0Medm+r5yP2rUgJPEASsLa3ERoZJhNT3I1XY23AH\nMLPBFTIHce8UCa93VsyyIAh84sE1bFtVw8WuSf6/H55KlV9Ix2I1YTAK+H0RDg8cxSAY2DPn1HKj\n38sX//0kvmCM333PWrantXSzNDSC0ZiXYDcIBvbU7ySaiHKm/zKgjzkqPVHNYTPxnz+yLdV16Iev\nXEPMoLkr8z/pnebixBVa3U00u2eHzx483c+Tv76Mw2bi//joduoqHanPlLkP5zEP9c5aVpZ3cHny\nGmNTXnk8eig7afPQ0VDGEx/Zislo4J9+cT6rcz1XeYXbHV0E+6FDh3jooYd48MEH+da3vqXHVy4p\niq17X1LIAATCMf76X9/mpRN9NNY4+dOPbpuXfKMkwITz0E4sRgs767YxFfEy4Z1esPZ2LhKRCNHB\nQawt87vkvGtnM598cC3TwRh///2TnL0xfyErNeBP9VzAaXKwxbMx9dnAWIC/+e5xbvT72Luxjk89\nvG5eeKcyD1GN9mWAVRUdeOzVnB68KNfeLmhzk58/t06O0WDgjz64ib0b67jR7+Pv//0k497ZxdgE\nQcDhtOD1Buj3D7K5ej1llplMxbM3xvh//+MUoUiC33tkPfdtnx3eKZhMWBubiPT1IiXU9SJNZ09y\n7V3ol7OSC5qHLPWCypwW/vSj22iodvD80V65xO2cE4wiSPvGhxElkb1pm3xCFPnF6zd56rkruB1m\n/uxjO2irn53NaU3mNuSzwQHsa5Sf1z0qO/4Lm4ceBLMZS33DrL+vbq7gCx/egtEg8LWfnuVXh7vm\n+R6cRR7yWLBgF0WRv/7rv+Y73/kOv/rVr3j22We5cUN9g9lbTSAW5NzoBeqddbS55UV5Y8DLX/3b\nMY5fGmZjRxX/5eM7MpYLSBU+ykNjB9ifXMTT06FUE+t8CPb0yl1yWjIn5Ny7vYnfe2Q9kViCr/z4\nLN9/8eqsRsj25CKOhBLsqt+O2WAiIYocPNXP3/6vmbKrv//ohlmVChUUAZKPI1kQBFlrj8jfq4um\nmqEAmslo4DOPbuCBO5rpHwvwF//6NgdP9c/SWh1Oi6yhSTMCJhiO893nLvOVH59FkuBPPrSZOzc3\nzPt+5blSLEZ0WJt9GaDGXsWaipWMTckRQos1DzXldv6vT+xkXWsFp66N8bf/6wRXe6dSnyuCdGzK\ni0kwckfdtuT/D/H3/36KX77ZRXWZjT/72A5aaufXIzK6XJiqqvMW7Ns9m7EYLYwm5yHfkE8pHic6\n0I+lqRnBOD/BaV1bJX/2sR1UuK387NBN/vt3j+JNc7Ar85CpzEQxkJ+KmMbZs2dpa2ujqUnWYB55\n5BFefvllVuboDH+7cGz4FHEpwd76ndwY8PGrw12phtEfeWAN79nRlNXmayqvwFhenpd9GeSQv3pH\nHVLEgLUy/58ikOzmtFBFxzs3N9BS6+Kbv7zASyf6OHVtjPt2NHHXlobUIjbFrGyr3s6JKyM8/UYn\nfaMBrBYjv//oBvZtyt4I2VqAfRnktPJXzsjOaz00VVuWeTAIAr/zrtU0e1z88JXrPPX8FY5cGOL+\nHc1sXVWN3WkGUaDcUEGdqY3nj/bw/NEepvxRmj1OHn94/YIJSNbWVjgsR6RYG9Vn/yrsbbiD586f\nBPKfB0mSiPT2YK7xYHQ4Ml7jtMlNsb//4lUOnh7gi/9+kl3rannPrhbaG9wYTQKJMGz2bGRiUuSn\np65w5PwQkViC3etr+eSDaxdMQLK2thI4fYq4dwo86uqzKNhMVnZ4tjB8TijIkR4dHJA7aC1QVmJF\nYxn/9VO7+Oenz/PW+SFOXh7hwLYmHtzdoqo2/e1MwYJ9eHiYhoYZDaauro5z584V+rVLxuHXbuK2\n1fLTX0SIhk4AsKa5nA/cvYK772hldHThxtHWllaC58+R8PtVV1RUEASBXVU76ZQgYgzm/W8I3OyS\nx5KjNkprnZu/+NQufn7oJgdP9/OTgzf4+aGbNDsEagFLqIIvfvs6kgQCcNeWBj50zwoqciSJWOrl\nAlD5OMwAKm0VtFrlscfN+dfnCPf2YLDbMdXUZL1GEATu2drI5hXVfO+FK5y6Nsa1Pi9mk4GV9ghu\nrIhDTfyXf5ZzGkxGgQ/e3cF797blbEg+43PpgT3ZSzFkY1vtZl6NyzZ2m4Yqn+nEp6ZITE9jX71m\nwetMRgOffGgd+zc38IOXrnHs8gjHLo9gtRhZb4hgilk5fzzB4SHZgVpdZuV337OG/Zvqc54srS2y\nYI/09sIq9WUdFPY27OTZ2GWwJPJ2pCvm0VzlqxXz1MkbE/zwxSu8eLyXF4/3Umkzsgq41jfIPopD\nSU2nYMGeb5ynR+NOvliU9zVisVThrK6mbUMZ79ndxuZVM4Ih1zgDa1cRPH8O2/QYFR2Zj+gLsT+4\nk05OMMl43nMyeLMTwWikactaDJbcmt7nP7qDx9+/mVeO9/DayT68kWvgr0OYqmJ9exWbV9Zw59ZG\nOhrV1X4BGGxvI9DVTXWFDYM5u1DK9m9cX76Wq/gYZgCPJ3Ps+EIkwmGuDg9TtnEDtbW50/o9Hjd/\n9bkauod8vHlmgDfODBCMD+CmkcRoLVtX13Dn1ib2b26gXGX2Y9yxnj5AGh7I+7f0mGqJAn7HOOs8\nC6+nTM+Y6L4KQNW61arG4PG42bOliWMXhzhxeYSzN4eJ+idwBMqxhl3csb6Sh/a2cceGeowqhaxh\n01omngHT+FDWcS5Edc0WXox3EbIFcFeYsZltuW+aw/SY/Oy6LesoU/H8h+vKeffuVl4+1suxi8Pc\nnOwkOi4QIXbbyCotFCzY6+vrGRiYyXAcHh6mtjZ3NblcmvBS4XY5KBMcfOKTM45TZWwejzvnOMUa\n+eUbOXeZWEO75uf7RuQ42QlxnDOd17KmbWdDEkUC3d2Y6xsY90YA9RrvvnW17F3r4YsHj8BwHfva\nV/LgYzPt77T8RoaGJqTrNxg4dzWrlrTQfNojZYCPs5PnGRrOXP99IUI3roMkYahv1DRuh1Hg3Tua\n2L3RzZd+fhTGG/n9d+1gzUY5SS0aijIaUn8cN9d4mL5xk5ERX14+E1vcSVgI8XLnm7Q5s5/Ass3l\n+DlZ449X1WmahxV1Lvl/66Z58RdhhEAl//UTu1MmoYnxzBFVmYiVy9FCE5ev0Yz2dz0WjSMkjMRM\nYV64eDjl79DC1JVrIAiEnFVEVDzf43EzNRlk56pqdq6q5t8vXeTwwDH+eNunbxtZBeo3yYKdp5s3\nb6anp4f+/n6i0SjPPvss73rXuwr92iXD6ZKTc/I9eaQch3nalxUbXswS4a2hzK3SFiI2MoIYDmsq\nS5pOr7+fgbhc7yYRzj/LrpAIIYBIQN7gvMIklyauar9/AYehGo4NnyJqliNlCnGYWVtaSUxPk/BO\n5b44A2JIQLLEOTt2nmBMe/io1m5BczkyeCxlDss3httUXYPBbi/4nYibI6mINS2k/Ax1dRhs2rX9\nSCLKyZEzVNrKWVe1OvcNtyEFC3aj0cif//mf8+lPf5pHH32URx55pGgcpyB73RMFZJjJHdqteduX\nlZfHZIWjQydJiNpC5RSBls1hmIu3Bk8gGuIYjIU5itK7SuVD+sv81tAJzfcXItglSeLI4HEkc2zW\nWPIhFcedx3qQJEmO5XdZiIlxToyc1vwdkd4ejC43psrsDS6yMRme4vLENdxuuTZOvvOQym0YHiYR\nztzjdyGUd6LM7eCGt5ORYPZEokzEx8cQQ6G834nTI+cIJyLsadiZtarn7Y4uo77nnnt4/vnneeGF\nF/jsZz+rx1cuGWo61C+EYDBgbW6RO7THtH+H8vKsbmhjOurn4sQVTfcXItBiYpzjQ6dwW1w4XbaC\nsuysTc0gCPlvcIEoJrOB2rIazo1eIBDT5kyO9PSA0Sg3QdFIl6+HocAwq+rlzamgeSigxEIkHEcU\nJWoqKhAQNGuriWSbQmtLa15moLcGTyAhsaJWbpBR8AYnSQS7ta8H5bntHvm31DoPhZ7elAYwe+vv\nyHHl7Utxbkc6okdYk7W1FUSRaL/6+t8KyrH/jhbZtq2kcatFrfc/E+fGLhKIB9lVvx2nq7CiRwab\nDXNdHZFe7bVzYKb29r6GO4hLCY4Pq9dWpUSCSF8v1sYmBJN2t5FSUXBfm1yeVw+NPZ/QT+W55W4H\nG6rX0u3rZcA/pPr+mYxT7WtBlETeGjyGxWBmbf0KoHCTFID/Zqfme5WNdVVdG3aTnbcHj2s6yabe\niTzMUWOhCa5O3ZAT5xboRXC7844X7Hp0S0lpaXmYIZTnrqxrodXdxIXxy0xFvKrvj/T2YPXUaA61\nhJkyxXc27sbutCBJEM6Qbq8WW2sbYihEbEzb0VkUJULBKA6XlV11OzAIBt5MK2Obi+jQEFIslteL\nHI6HOT5yhipbJRtqV2O1mQpaC6bKKowud14ae3oxOKWsw9z2gQtRiGC/PtWZKlNcWe6aNZ58UHwu\ngc4uzfcq819WZmdX3Ta80WlNJ9lCNPa3FW29QbvD9naiJNiz9PzUQiGOQ7n9lwmTycj+xt2pgmRq\niHu9JLxTODsy92ZdiLHQOJcnr7GyvJ16Z50uRY9mzBDa5iEciiFJ8m9RbnWzuWYD/f5BeqZzl1eV\nn9clP19D82qFkyNniSai7G24A4NgwOW2FiTYBUHA2tpKbHSURDB3Hfh00ovBba5Zj9Ps4O2hE6q1\n1ZlSAtrnQaluuq9hly6nWKV2TiAfjT2tAJgSEXNk4Jjq+yM9PRjLyzGVqw/XBbmD2uHBY1iNw/az\nyAAAIABJREFUFrZ7Nue+4TamJNh1qAlhaWzKu8FAeu3tO+q2YTaYOTxwdFYv0GwoJwRnR7vm5x5O\nvihK5T5dTi55OlDnNthQxvRG/1tZ70lH2VBteQi0wwPHEBBSdYJcZTbCwRgJFZ12sjErUUkDMxq7\nFZPBxO76HfhjAc6MXVB1f7inB8FiwVKvLWQ2FA9xauQcHns1qyo6sCeTowra4JK1c4Ld3Zpr56QL\n9hZXE02uBs6NX8IXzR12mPD7iU+M56WtX5y4wlTEy676HdhMhXVuutWUBLsOGrvBYsFS30Ckt1dT\ng4FU+6/kGOReq1sYC09wbTJ7aVWFcHdSsK9coWm8CTHBW4PHsJvsbE82ULiVGvvcssXrq1ZTZavk\n+PBpQvHcURWRnm4QhKy1crIxmFamuMomR5G43PILHQ7mb5KaqSFU2DwovQFe7zuS9R4FMRYjOjiA\ntblZdfNqhbcGTxATY+xv2I0gCBiNBmx287ym1lqxtrUhRqNEhwY13Rf0y450s8WIIAjsb1B/kk1F\nieVhlns9qUjcnaEnQ7FREuw6VXGztrbKDQZG1duXQ0nh4XDOZGoq2qrSwWchlKO3a4U2wX5+/BLe\n6DS767enyhTrobGbysowVlRoPrnMbTSS3gv0+PCphW6diVmu1R6zrPgY0rskKYK9kHlImeY0nlzm\ntoOrd9aypmIlV6duMBRYuLBYdKAfEgnNZhhJkni9/wgmwTgrEShTU2utKPMQ6dY+D+lF8XbXb8ds\nMPN6/5GcJ9l87eujgXEujl+ho6x1XpniYuQdL9hNJiMWa2EOM8jPgTpjgpg59q0ob6PeUcupkXN4\nIwsfPSM93Rhdbiw12rz3mRooOJNp84U2FrC1thGfnCQ+rb65daZGI/uUXqD9CztR42NjiMFgqtGF\nWsLxMEcGj1NucbOlZkPq704dBLu5tg7BastbY7enbfR3N8s1Z17PYZbK13F6ZfI6w8FRttduxW2Z\nccA7nBaikQRxFX1Ss2FtawcgnPSBqCEVy59WBM1hdrC7fjvj4UkujF9e8P5wlpLFuXj55htyb9em\n4tfWoSTYAXRpXGtTFnFXl+p7UppqmkATBIEDzXeSkBK80Z/9CJ4IBOSY5bY2TTHLI8HRlGbS5Jqp\nRTLjMCvw+J2HGSJTa8ByaxmbazbQ5x9I9QLNhCI0tEbEvDV0gnAizN1N+zAZZkIkXW7brDHlg2Aw\nYG1J5jZE1X9PwB9JOdIVttZspMzi5u2hE0QS2b8rX8fpoeQaO9A8u2iZLj6X5hbZ96RBY1cc6XPL\n9R5ovhOAg71vLnh/pLtbDr1VUdZEISEmePnmYewmOzuz9HYtNkqCHXkRh0M6Ocw0LOJsLfH2NOzE\nbrJzqP8IsURmW2+mLjlqeLVX1kzua7l71t9TDrMCN7h8en9mm4d7mmRh82rv61nvjeQRsyxKIq/1\nvYlJMHLXHA3NVVa4xg7JVnnJ3qNqCQWiqYQ5BaNBjpYKxcOcWCC2P9zTI/sZmptVP28yPMXZ0Qu0\nuBppL5ut4ephojRYrdibGjU1t86k7AA0uRpYVSF3VxoKjGS8VwyHiQ4NYm1t0+RnOD16Dm/Yx976\nnRl7uxYjJcGOPkX1jQ4H5to6wt1dquOvszWxthot3NW4B38skDVRJ9zdBYBNQ4ifPxbgyOBxqmyV\nbJvT29VoNGBzmAno4GsArSappAliTqnatZWraHY1cnLkLGOhiYz3ztRGUX/0vjRxjZHgGDvrts0y\nP0CaYC/UcagxQkh2pMczNpa4q3EPAgIH+97MuLYkUSTS24uloUFVdU+FN/rfQkLinub98059ejWa\ncK1ckWxunVkYz2WhXqeK1n4oy0k20tsjN5xJnp7VIEkSL/a8hoDAPc3aSy3frpQEO/o5UG1tbXJz\n67HMfRTnEsiiqQIcaN6PQTDwat8bGV/mGYHWrnp8b/S/TUyMcV/znRmrJzqdloJfZHONB4PDkYrY\nUUMwEMXuNGOYo2UJgsC7Wu9BQuKVLFp7uKcHU2UVJnfuUr0KB/veAODepKBIJ2WKKXiD09YPN7TA\nWqi0VbCzbiv9/kHOj1+a93lsZBgpEtZkhgnHw7ze/xYOkz3VJSkdvRpNOJOOfbV29oUE+9aajZRb\nynh78HjGaKl8lJ0rk9fpne5nT/N2ah2e3DcUCSXBjj72REhzFiUXWC5CSU3VmaHed6Wtgu2ezfT7\nB7k2Nb/VYKS7G4PdPqt59ULExDiv9b2JzWhjX2PmeucOl+wwixXgMBMEAVtbO7HhIRJBdfVeFmpi\nvbN2K5XWCo4MHMUfm53wIzevntKkrQ/4h7g4foUV5e20ls03WzidFgRBB5NUY5Pc3FqlSWohgQbw\nnrb7AHi+65V5G324S04CsmlIVDvUf4RAPMj9LXdnND/oEQYMssYO6k2UC82D0WDkQPN+wolIRlv7\njGBvVz2+F7pfBeD969+j+p5ioCTY0U+w2zQK9kAggsEgYLVlrm9yX8tdAPxmzssshsNEh4dkW6JK\nx+nx4dP4otNy+QBT5rBAvY7fyganRluNRePEogkcWZpZGA1G7mu5i6gY4/W+2ZEh+djXn+18AYD3\ntN2b8XPBIMz0Pi0AwWTC2tSsurl1LsHe5Gpgc80GOn098zZ6xWFva1Mn2COJKC/3HMJmtKXMG3PR\n6xSrJM+p3uCy2NgVDjTvx2ly8HLvoXlljSPd3QhWG+Y6dQla3b5erkxeZ23lKlZW5Vfm+HalJNjR\nJzkH0h2oXaquV7JOswnnjvI2NlSt5erkdS5PXEv9PdIrN69Wm4QRS8T4deeLmAQj97ZkfpFhZh4K\nFWqK5hjuzJ1OHgwosfzZbcPKZnSw7w3CaUfwlIamUmPv8fVxevQ87WWtbKpen/U6R4EF0RSsrW1y\nc+vB3MXhsvlb0nmw7X4Anu96ddbfw12dsuNU5Ty80f8W/liA+1ruxGG2Z7xGL43d5HTKvqcedb6n\nXPNgM9l4oPUAoXiIV5MmNQAxEiE6OICttVW14/TF7oPAzGloOVES7OinsRudTswejyoHaqZ43Uy8\nb+V7AXj6xq9TyRlhjbVRXus/zER4kgPNd6YyLDOhxNMXHPrZnhTs3SoE+5xyAhm/z2Tj/pa78ccC\nPN89I9RmTBDqErSeufk8AI+teHDBk47DaSURF4lG8jdJyePqmDXOhcgWGZROR3kraytXcXnyGlfH\n5MxkKZEg0tONpbEJgzV3Gnw0EePFnoNYjZZ5kVHpWG0mDEZBl2bO1tY2xECA+MR4zmuV9ZDJiaxw\nT/N+XGYnr/a+ntLatTpOu329nB49T6u7ibWVq1TdU0yUBDv6aewgmyHEQID4+MIO1Eg4jpiQcgr2\nFncjd9Rto9c/wMmRs/K93eodp/5YgOe6XsZhsvNQ+/0LXjtz/C4sIsRUVS1XOFQR069GoAE80HqA\nSmsFr/S+zlhoAkmSCHfexFhRgakid1OJ61OdXJy4wpqKlTm74ug1D6kNrjN3eYhcphiFhzveDcC/\nnvwhoiQSHRpEikZTz8rFq72vMx31c6D5TpxmR9brBEE2Sekh2BVnphqHeiAQxeYwY1ygcbjNZE1q\n7WFe6T2U/O6u5LPacz5DlER+dPVpJCQ+uOqRvGrX3+6UBDtgs5sRhMJty6Dezp7LlpjOYysexCgY\neebm88TFOOGebtXFnp7rfJlQPMx729+FY4EXGfQ7uQiCgLW9ndjYKAn/wr0y1ZggACxGCx9Y9TBx\nMc7Prz9LfHKShNerSlsXJZFfXP81AI+tfDDn9XqZIaxNzQgmkzqTlMr1sKqig931O7g52cOhviMz\np5b29pzPGA6O8uuul3CbXTzQeiDn9Ypg18MkBRBRc3LxR3HmWAsga+1ui4sXe15jKDCcMn+q0djf\nGjxOl6+HnbVbWbMMtXUoCXZgRjsp1LYMaY7DHNrJjKaa+/hcY6/m7qa9jIXGefbys0T7+7C1tee0\nJfb7BznUf4QaWxV3N+/P+Rw9Ty6KoMm5wanUVEGOkFlR3s7p0XN0npdjme0qBPvzXa/S6etmR+0W\nVpS357xeL1+DYDJhbWsn0tebMwM1FIgiCLKSkYsPrXoUp8XBMzefw3dD7g9rzeE4FSWRf7/0E+Ji\nnI+s/cCC2rqCw2VBTEhEwvm1jVRIKTs5NrhYNJF0pOdeC1ajhY+u+SBxMc5TF39EuKsLwWrNqewE\nY0GevvEbLEYLH1z1iOp/Q7FREuxJHAU2tVZIFYDKqbHnti2n89iKB6m113DhzKuy4zRH4a9ALMi3\nzn6XhJTgw2veh9mQu7OQXho7gK09Gb+cQ0tTa4oBeQP+8OrHEBA4d+ol+Tk5BHunt5tfd71IhbWc\nj679kJqhF9wuMR1be4ecgZqjMJrib1FjFnBbXHx8ywcJJyIMXz0jtwTMUdnyzYG3ueHtZKtnk+pa\n44rSESgwWcvocmGuqyfcdXPBDFTF9KVG2QHYVruZXXU76J/sITI4ILcEXEDZkSSJH1/7Jf5YgIfb\nH6DSVqHtH1JEFCTYn3vuOR599FHWr1/PhQvqakbfrjicFuJxkVi0MIeZ0eXCVFOT04G6UHJSJmwm\nG7+36XdpnJDHF2/OrpmIksi/XfgBY+EJHmq7n81pRa4WwmwxYjIb9NXYcwl2laYYhbayFt634iEq\nRmQTj9CcvRJfOB7m3y78AEmS+E8bPqpKS4UZwVKojR3SI4Sy29klSSLojy7oMJzL/Sv2s8rVimPE\nR6imbMGWgDemuvj59Wexm2z8b2s+oNqmrOcGZ1+xEjEUWrCEb0CDeVLhI2veR7vfgiBJRBqqFrz2\nmZvPc3ToJK3uplQo8XKlIMG+Zs0avv71r7NrV3G3kQL9Mu1A1lZFv3/BEr4zyUnqF3Gzu5GdIbmS\n43+E3s7YvT0hJvjZtV9xceIKG6rX8sgK9YkXejrMTBWVGMsrcjpQlSbWFqv6XqUPtNxDw6TERJmR\n73U9k7HD0Hhokq+c+iZj4Qne3XYvaypXqv5+XU8uyRPFQoI9GkkQj4ua1oJBMPCJqvswiXDdHeLZ\nzhczXnd54hpfP/0vxMQ4v7vutym3qs/Q1cskBaROmOGb2edB2UDU2NgVHGYHDxnWAfBi4irnxi5m\nvO5g75s83/0KHns1f7T192YVfluOFCTYV6xYQXt7e8Hmi9sBPe3L9lWyQyZ841rWawIabMsKkiTh\nGJwk6rJxnXH++7Gv8ubA28QSMSRJotPbzd8f/0de7XsDj72axzf8DgZB20+smKREsfDf1NbeTnxy\ngrh3Kus1akI+5xIfGcYUjROsr+T06Dn+7uiXOTt6QY6UiYc5N3aRvz/+VXqn+9nXsItHO7RlFerl\nPAW5hK/B4Vjw5JIyy6k0QSiYBuSNPVBXwW+6XuKpiz+k0ys3E58IT/JKzyG+ceZfEZH47OZPsq1W\nW7s3p1OfujkAthXyxhq+OT+LWkFLQEE65UNyiejBWgvfPPtdXuh+lfHQJJIk0Tc9wJMXvs9Prv0S\nt8XFn2z7zLz6QMuR5b1taSC1iHXQ0uwrZcEeun6dsn2ZE4JSha80CLX4xAQJr5eqHTt5fOPd/MeV\nn/H9yz/l+5d/ikEwpOLc72zczftXPpwzCiYTDqc11dRaq8Cdi629g8CZ04S7unBtnV+PRBQlQoEo\ndU3qtUiYMe9s2v4uhhoDHB44xjfPfReb0Uo4IQsho2Dkd9Z+iDsb92gOZzOaDHJTax0EuyAI2No7\nCF68QMLvz9h0PJDH6Q1InYYevPPjdE78hreHTvD20AncZhfTMdlUZTGY+YMtn8oZ4pkJXcOAm5oR\nzGbCndkFeyCPDU6SJELXr2EsL+f37v5j/vncv/H0jd/w9I3f4DQ5CMTlshaNznr+04aPUmPX1rug\nWMkp2B9//HHGMhS1euKJJ7j//oXjohfC43Hnfe9iUN8oCxdBmj22fMYpVmygz2Ih1n0z6/3RcByH\n00J9vfqGu2NXzwFQvXkDWzfdza6Ojfzowq+YCE4RSUSxGM18eOPDrPdof4kVqmuc3LwyitVsKvg3\nMm3byPjTP8cw1IvnATkZJv07/dMRJAkqq5yanjU9JJfCbb1jO19Ys5rf8j3ED889Q59vkFpnNR5H\nNfd27GNVdXte4/Z43JRV2Jn2hnVZp8GN6whevIB1apjKjoZ5nw/2eAGoayjT9LxY900MFgsb9uzm\nHw17OTt8iYOdR7gwcpXtDRvZ0bCZXU1bqXLk5yS0W+UInXhMLGgelHuHV6/Cd/kKVS4TRvv8jFcx\nLp8SW1orqax2qvruyOgoiakpqvbuYf2qjaxs/L95vfsoNya6uTnZTXtVM4+tfTfbGzbm3OBvN5lU\nCDkF+5NPPrkoDx4dzd2YdimJJ731I8PTqbF5PO68x2ltayd4/RpDPSMZF7HPG8JVZtP0/aOnzgOQ\nqGtO3mfmtzs+OG+chcytYJQXf3/fJEZLYUFTiepGEATGz5zH8eD0vHGODcv/bTQZNI158uIVMBoJ\nuqoJj05jxcUn1/zO7IvE/OZBGaPVZmJ0KMbgwBQm8/xKmFoQa5sAGD59gXjzfFv/0IAs2EVJUj3m\nSrtAsKcX+9p1jE/K2ZdNplY+vroV0vb1RABGA/mtB1GUEASYnAjkvabSf3NjcxtcvETf8XM41s0v\n6TAxLhd5C0diqp83ffQMAIaW9uQ9RvbX7GN/zewSvGNjC+dTFPKuLyVqNx/dwh2L3c7u1Cm0S8G2\nchUksyPnEo8liEYSmk0doZs3wGDQVL1OK3ral40OB9bmFsKdNxFj8xuGaAl1VJDicSK9PVhbWjGY\nc8d858tSOlDzsS37Ll8BScK+Kv/TWS4MSkG06cLnANLs7FnmIdVBSsNGGrpxHZgxf5aQKUiwv/TS\nSxw4cIAzZ87wuc99js985jN6jWvJ0VOgAakXLpxceOnkLdB6urE2NauqCZIvelX1U7CvXo0Ui2Us\njKY11BFk+7oUj2PX2MBbK3rOg6miAlNVNaEb1zPGcSvKRKbyzdnwXZCjP+yr1xQ8voXQqyAazETG\nhLI4UIP++R2kchG6cT2ZCLa8qjMWSkHO0wceeIAHHnhAr7HcUowmAza7WZfQLgDbSlk7CV2fHxmT\nj0CL9PUixWIprWex0H2DW72WqVdeJnTtKuzbMeuzlNPQrX4eglfkZsb2Net0GV829HQcAtjXrmX6\nyGGiA/1yL9A0gn456zS9iXUufJcugyBgX7nI68FlZXTITzQSx2or7IRkqqzCWFFB+OYNJEmaZfNO\nxEUi4Tg1deojVsRIhEhPN7aOFRjMy6OlnV6UMk/TcLosuoR2AZjcZZjr6uRFPEdLyycRQ9FycmWc\nFopTx9hlkDV2QBbsc8hHUw1dvSJ/75q1OowuO8qY9BLsjrXyRhRMjj+dgD+C3WGZ10EqG2Isiv/a\ndaytbRhsmcvu6oWe60EQBOwdK0l4vfMqPeZzig13dYIolswwGSgJ9jQcbqvcQShaWG0MBfuKVXK2\n3eDsbDslo1GTQFM01UW0qQLYHMkOQjpkXYKcqGT2eAhdn2+GCE5re5mleJzQtatYGpswlWkLkdSK\ncnIJ6DQP9qRgV35HBSXrVJNA60yao1Yv7lqAxTjBJTf6K7M3uFSoo1P9O6GYOW2LfGopRkqCPQ29\ntVVbMlEpNCdRSUvhK5CbFQcvX8JUVY25tk6XsWXDYBCwOy26vcgA9lVrEIMBgr19s/4e8EcwGAVV\nha9Arr8jRaPY1y6utg76m2LMNR5MlVWErlyZZa/OJ+s0nDTvLbZ9HcDp1i9JCcCxXi5vEbw0O0M0\nmEcsf8lxmp2SYE9D7+O3suDC1+YIdo2mmEhPD2IggGPDhiWpHe10yZUu9Yp0UgSQ7+Lslzngj+J0\nWVX/mxRtVzFrLCZ6a6qCIGBfu5aEf5rowExHpXyyToNXZbOWfdXiC/bUyUWnebA0NWN0uwlcujBr\nfWk1xUiiSOjGdUzV1arq8b/TKAn2NGZqY+ijnVgam+RFfPH87EWs0XmqaDeKtrPYOF3WlDNLD5Tj\nt+/ijBlCFCWC/ogmDW2pHKdAMuzOoFt0EMxsSKErl1J/05p1Koki4RvXsDU2YCpXn9yWL3qfXASD\nAce69SSmpoilFQTT+k5EeroR/f4leyeKjZJgTyNlitEpblcwGHBs3ETC6yXa15v6e9AvF74yW9TF\n6wYvyZUzHeuWSLAnj9+BaX02OHN9A0aXG9/FGYEWDkaRJPWaqhSPE7p+DUtD46Lb1xWcLqu+Jqk1\n8x2oWjX2aH8fYihE2frsPVv1xKljpUsFx/qNAATSzDFaywkEzstZ2M5N2urfvFMoCfY0UuVaddLY\nYWbhKQsRwO+PqDZBiLGoLNCampdEQwP9fQ2CIGBfs4bo2BjRoaFZ36021DHc3YUUiaSckEuBw2kh\nFNSnIBqAubYWU2XlLDu7Vo1dOb2VbVwawa6EYOql7EBmO7tWG3vg/DkQhNQmUWI2JcGeht4CDcCx\ncRMIQkqwJ+Ii4WAspRXnInzjBlI0uqRHTr01dgDnlq0A+M+ckr9bY6ijEua4FPZ1BYfLgiRBKKjn\nBreOxLQvFSml1d/iP30KBIHKnTtyX6wDBoMBu9Osq0nK7PHIkVKXLyEl5JLLWk6xiUCA8I3r2Fas\nxOhUV1PmnUZJsKeRqsmuo8ZucpdhbWsndP0aYjg0I9BUaqpLbV8H/TrnpOPcvFXe4M6clr97WqOm\nelk24yx2/Ho6etuXIS2e/bL8u2rZ4BJ+P6Hr17CtWImlYum6/zhdVgL+iK5lQxzrNyCGQqkG14FA\nRHUHqeClCyBJODdv0W08y42SYE/DaJS1Ez01dkiaYxIJgpcupR291WmqwUsXwGDAsQQhfgrKpqPn\nPJjKy3GvWU3o+jUSfr8mm2oiECB4+RLW1rYlM0dBWv0gHU8ujqRpzn/yhPzdGrJOA+fPgihmLIG8\nmDhcFuKxwruLzfrOpAkleOkCoigSCsRK9nUdKQn2OSyGdpJuZ1eEhBpTTCIYINzZKadML3KGYTqu\nRTDFAFTt3gWiSODc2Rmbqop58J8+CYkE7juWtlNXyiSl48nFXFWFbeUqQlcuE/f5CGrIOvWflk87\nzq3bdRuPGmYK5OnoSF6XPLlcukgoEEs+J/fpTZIkAufPYXS5sbaW6sNkoyTY5+BcBO3E1rECg8NB\n4MI5/Elh6VIj0E6evCVHTovVhNFk0NUkBVC1+w5AtrPPmCByv8z+48cAcO1cWsGu/EZ+nTc4985d\nIElMnzyhOutUiscJnj+L2ePB0pi9z+ti4FgkE6WtYwWhq1fwDcvlBdTMQ7S/j8TUFI6NmxZsXP1O\npzQzc1gM+7JgNOLYsJH42BjTQxOAOk3Vd+RNAMr27Mtxpb4IgiAnKekYCQFgb2nB7PEQPH+OgC+C\n2WLM2es0EQwQuHgBa0srlrrFzbqdy4wTWd95cN0hb3BTx0+qzjoNXrmMGA7j3Lp9SZLU0tGz92k6\n7n37QRQZPymH86oxTwbOlcwwaigJ9jnoHcuu4Eoen729Q7Oek43Y+BihK5exr1mL2ePRdSxqcLqt\nBANREon5ZWbzRRAEnFu3I4bDBLxBVRqa/9QpSCRwLbEZBtJ8DTpr7OaqamwrVjJ1U85tUGNbDiSj\niVzbltYMA/pnZCuU7doDRiMTV+X67K6yhedBkiR8bx0GoxHHpk26jmW5URLsc9C7NoaCa+cdGBxO\npsenEYTcx07fW0cAKNu7X9dxqEV5mUM6hrmBLJhEDISjkioNzX9CNsMstX0dwGQyYrObdDfFgLwe\nIkbZb5Jrk5dEEf/p0xgcjkUvApeJmdr0+s6D0e3GuXkLfp86v1P4+jWi/X24tu/E5F6aJLVipSTY\n57BYx06DxULZnXcRESzYzCzoLJMkCd+RNxFMpluiqcLiRMaAXJ0yXlkLgMO+cMxyIhggcOE81pYW\nLHX1uo5DLU63VXeNHeSNShHsuTT2wNkzxCfGcW3fiWBa+v7zi3WKBSjbt5+ISW66nsvvNHXwFQAq\n7r1P93EsN0qCfQ56t8hLp/yee4kYnViiC/dfjHR1EhsawrV9B0aHQ/dxqGExQv0ABJMJy54DAJgm\nBhe8dvrYMdkMs8RO03ScbiuxaIJoRJ+6OQrm6hrE2mYArGSfY0mSmPj1rwCofM9Duo5BLQ6XFUEA\n/3RY9+92btlGxCr38XQ4sod8xqd9+E8cx9LQuKTZx8VKSbDPYTGSUhSkihpEgxGzf4LIQH/W6xSn\nqXvfrTHDwOKE+ikIa2THl3TzEmI08zyLkQgTv3oawWymbP9duo9BLYsV+gkgJRtbx46+kfWa0LWr\nhG/ewLltO9amJt3HoAaDQcDhshLw6T8HBrOZmKMKczxE5NrlrNf53ngdKR6n/MB9S+48LkZKgn0O\n9mSjCb1NEEDKlmiNB/AefDXjNdGhQbyvH8JYXoFzw61zEC3m8Tuc/EqzfwLf4cxCbfKlF4hPTlL5\n7gcxV1XpPga1KCeXxbCzx9zVAMRPHSHS25PxmolfPwtA1Xsf0f35WnCVWQn49auboyBJEiEs2OIB\nxn72E6T4/JORJIp4XzuIYLFQtv/WKTvFREGC/R/+4R9473vfy/vf/34+//nP4/cvbGIoBpTO7Ho7\nT2FG+7WbJbxvHCLS2zvrc0kUGXryO0ixGLUf+/gtsacqLKbGrmyaNqJMPv+bVL0QhbjPx+RvnsXo\nclP50MO6P18Li1E3R8E/HUEQwBIPMfqTH837PNLbQ/D8Wexr1t7yZhIutxVRlHR3pkfCcRIJCWeZ\njUh3FxPP/XreNVMvv0hsbBT37r0YHaXaMGooSLDfddddPPvsszz99NO0tbXxzW9+U69x3VIcLquu\njSYUFOHg2b0NKRql/2tfJj41lfp88sXnCd+4jnvXbjmJ5RaSciIvgkBTNovqbZuIjY4y+sMfzGqb\nN/7M04jhMFXve/8t8zEoLKZgD/giuMpsODdsIHjh/KwKoHHvFEPffRK49do6LF6yljJpmH0fAAAa\nGklEQVSvVWtXYKyoYPyZp4mklbj2nznN6I/+A2N5OdXve7+uz17OFCTY9+/fn4ru2LZtG0PJkqzF\njtNlIREXCQVjun6vsohrNq+j5kMfJj4xQf/XvkLg/Dkmnv8N47/4GUa3m9qPfULX5+aDEuq3GCYp\nZR4aH3svloZGpl55iYGvfYXg1Sv0f+0reF99GXNdHRX33Kv7s7WSEmg6z0MiIRLwR3G5rdR8+CMg\nCAz809cY/fF/ELx0kZ6//SsiXZ2U7b8zVV/mVuJMxpj7dbazKxuFu8pJ3Sc/BYkEg//yTbyHXmP6\nxDEGv/UNBLOZpj/5Auaqal2fvZzR7az/k5/8hEceufWahR64ymwA+KZCGC36uSHSC4BVvPcRosPD\n+N58nf6vfEm+QBCo/cSnMLrduj2zEBwuK36f/pEQQX8Um92EraaKlv/z/2Hwm/9E4NxZAufOAnIr\nvdqPfeKWmqIUUmGfOgs0xTnvKrNia22j/vd+n7Gf/pjJ559j8vnnAKj+4G9R9fCjt4WzcLGcyOm1\nk1ybtlF+zwG8h15j+KknU9c0/OGfYOtYoetzlzs535zHH3+csbGxeX9/4oknuP/++wH4xje+gdls\n5rHHHlP9YI/n9hBemahvLOP8yX68UyHWbtQvfjoWkW3JbR3VWG1map74Y3qb5DR5Z1srrlUrsdXn\n97zFmM/KagcTowHKy+w5U//V4vG4CQailFfak2N2U/fXf0H3975P4GYnTR98P+Vbt9xSYZY+l5Ik\nYbYYiYRius5xKOmU9tSV4fG48Tz2IB0P3sfwiy8x+tobNH3wfVTv26t6nItNJCg7NRNxUfNzF7pe\nTMjmzqaWSjweNzX/+fP4H32IYE8Pwd4+XCtX4jlwd/4D12mcxUbOt/XJJ59c8POf//znvPbaazz1\n1FOaHjw6Oq3p+qVEMMpCxTcZ0nWck+MBzBYjvukwJGOCHe95FAAJmAam83iex+NelPlUmh50dY5T\nWV24rdvjcTPQP0UkHMdqM80as/PhD+AEYsDY2K1zwmeaS4fLwtSUvmuht2cSAKNZmPW9pt1307D7\nbkQWfkcW6zfPRizp4B4dntb03FzjHB2SP4snEjPXVTVgqGrAtW2PfM0S/DuXej7zRe3mU5Cd4dCh\nQ3z729/mG9/4BhaL+qbEtzvKsdM7FdL1ewMamzffapyL0CpwOmnaUcxdxYDLbSUcjJGI61c3J6Ch\nyuftgMNpwWAQdDfF+DWUsS6hnoLO13/zN39DLBbj05/+NABbt27lL//yL/UY1y1FKUbkndRPsMfj\nCcKhONW1Lt2+c7FZjIgQxWbvLi8ewZ6ejVxWoU9dfH9qgysOgSYnKVkWJSrGZjdhNqtr7F5CHQUJ\n9hdeeEGvcdxWKCnUemrsWhpL3C4ojkM9X+Zpb1JTLRKBBmkRIdN6CnZlHopng3O5rQwP+BBFCYNB\nHx+IPKfFMwfFQinzNAMGg4DLbcWno8ZejEdOd1Lo6Bnipphi3MUk0Bahbo5/OoLJbMBqu/WRP2px\nlVmRJHRrbB2NxIlFE0VjjiomSoI9C64yG9O+sG71yFM2VZV9HW8HFG1y2qtfyGNRmmIWySTlcltv\ni1BGtSjzoFcIbDEqO8VCSbBnwVWe1E50SkyZidctHuep1WbCYjWltGw9mPZGVNWjv53Q2yQVi8n+\nlmIywwC43PJ49drgis2BXEyUBHsWUtqqTkJN0XqLSVMFKCu3Me0N61Zewe8L43RbMRqLZ+nNJOfo\nu8kXm0Bz6Zx9qrbBRgntFM/btcS4dV7ExSrYXeVW4jGRcKjw8gpiQiQwHSk6TdWuc6hfSqAVkQMZ\n0kwxemvsRTYPxUBJsGfBlXIc6qOx+7xhLFYjVlv2ZgK3I8pGpMcG5/OGkaSZTbNYUJp769Vowl+E\nDmSYEcC6bXAlG/uiURLsWdDz2ClJEtPeMGXl+oTKLSXuVN2cwoWaEj7qKrJTC8gbXGA6qkuS0kyo\nY3EJNLtDPrnodYpN+Z2KKKCgWCgJ9iwojiI9NPZwKEY8JhadGQbSNXYdBHsyfLTYNHYgFb+uh8/F\nX6Q2doNB55PLdASL1ahbHaISM5QEexasNhNWm4lpHbSTYrWvw8yY9Qh59E4GgeJKylFwV+h3cim2\nrNN0nGU2gv4ooljYyUWSJDnkswjXQjFQEuwLUF5h10VTTQn2Isyw01ewh2Z9ZzFRVj5TyrlQ/NMR\nrDYTZkvxaaoutz5hwJFwnGgkUco6XSRKgn0ByirtRCMJIuHCOtQrWl4xCjRZABl1MUEUsynGrZhi\nCtzgZE01UnRmGAW9fE/KWtCrREOJ2ZQE+wKUJxddoTZFRRiUFaFgFwQBV5lVN429WDXVMp1MMak0\n+iLc3CDNmV7gelBOPiWNfXEoCfYFKK9MCvYCtZNitrGDvCEVenKRJAnvVKho58DhtGA0GZj2FmaK\nmYlhL855KEu+E4XWUVI2yJLGvjiUBPsCpDT2As0QPm84lZ5fjLh0iIwJh2JFrakKgoC73Fawxp4K\ndSxSU0x5pbwWCq18OqOxlwT7YlAS7AugaCeFRMYoMezFqqmCPsdvRaAVW1JOOmXlNiLheEEnF8W2\nrJwGiw1XmQ1B0FFjL+L34namJNgXQA+NPZTsvFPMtsRULHsBgl0xRxVzeJvyGxZijil2wW40GnCX\n23TR2F1lVoymkghaDEqzugDu8qR2UsDxu9jt66BPyGOqDnt5cZogANzJzOFC1oMSy1+sgh3ksYcC\nMaKR/E4uibiI3xcpaeuLSEmwL4DRaKCswo53In/tRLElLgvBXsDJxZ/snFTM86BHZIx3MoTdaS5a\nfwvM2MXznQdlHZXs64tHSbDnoKLKTjgUy7u64XLQ2O0OczIiJH9fQzE2sZ7LzMklv40+kRCZ9oYp\nr3ToOawlRzlt5NsTuBTquPgUpDZ89atf5eWXX8ZgMFBdXc0Xv/hFPB6PXmO7LSivcsCNCbyTIWx2\n7ZUZZ2LYi1c7EQQBd4Gx7FMTQSxWE3ZHcVW3TCelqeY5D9PJ6pbFbIaBdI09T8E+mXwninwebmcK\n0tg/85nP8Mtf/pJf/OIX3HvvvXz961/Xa1y3DRVV8uKbGg/mdf+Mxl68tmWQtVUlZFEroijhnQxR\nU+sqqlZwc0nVD8rTBKGY9IpesCshjwVr7MU9D7czBQl2p9OZ+u9QKITBsPwsOxVV8rF5ajI/we7z\nhrHZzUWZbZmOklKfz8s87Q0jJiSqa525L77NcZfbknXltXeUUtaQoiwUKwVr7KnkpJIpZrEoWNp8\n+ctf5umnn8btdvPUU0/pMabbivKkYM/HgSpJEn5vmOpal97DWnIqq+V5mBwPUFOn7d+jnHZqlsE8\nlFXYGBv2EwxENdcR9y2T+ihmsxGny5J3LLtvKoTZYszLtFlCHTkF++OPP87Y2Ni8vz/xxBPcf//9\nPPHEEzzxxBN861vf4nvf+x6f//znVT3Y43FrH+0toL2jGrPFiN8X0TxmnzdEIiFRU+ta9H/vYn9/\n+4oa3uQ60VBC87OuXxgBWJJ50IOFxljXUM7NK2MYMWj+twT9sgN+5eparLbCT3C3ci6ra130dE5Q\nWenAZDIueG36OCVJwucNU1XjpLa2bLGHqYliWJtqybm6nnzySVVf9Oijj/IHf/AHqgX76Oi0qutu\nJR6Pm7ExP+UVdsZH/YyM+DTZiHs7JwCwuyyL+u/1eNyLPp8Gs/zv7uuZ1Pysvp5JAKprF3+chZJr\nLk0W2dzY0zWOzaVN4xwdnsbhtOCbDkGB07AUv/lCOJwWkODm9bHUaS4Tc8cZDESJRRM4Fvmd0Mqt\nnk+1qN18CjKKd3d3p/775ZdfZsWKFYV83W1LeZWdeEzU3OtxYjQAQLWn+G3LTpcFs8XI5HhA871T\n40EEAapqijvMD/KPZU8kRPy+cNE7ThXyLQZWcpwuDQWdB7/0pS/R2dmJwWCgsbGR//bf/pte47qt\nSDlQJ0Ka4rAnxmQhWFlT/IJdEAQqaxyMDfkRRVGTo3xyIoi73JbzyF4MKGtB6wbnm1oeoY4KqVh2\njQ7UkuN0aShIsP/jP/6jXuO4rSlXQh4ngjS3V6q+b2IsgMEgLJuXubLaycjANN7J8ILH73TCoRjh\nYIy6huVhv3SX2zBbjIyPaBPsqVICRR4Ro1Be0thva5ZffOIiUJFHZIwkSUyOBamodmA0Lo9pTkXG\njKkXalMTyRA/lRvB7Y4gCFTXOpmaCBKPq4/pXy4x7AqKxq1VY1fWTrGHfN7uLA+Js8ikkpQ0xLL7\nfRFi0QRVy8AMo1BZo5gh1M+DEuqobI7LgSqPC0mCyTH186AIwOUi2K02M1abSXNew9iwH4vVVNQl\nNoqBkmBXgdVmxuYwa9LYFcfpcnAYKlRWy5uUFvuysgksF40dZpzhym+shuWmsYN8gvNNhojH1J1c\nYtEEUxMhauqKOwO5GCgJdpVUVNnxTYVIJERV1yuO06plEBGj4C63YTQZNGmqisau1iZfDCgJZ+Oj\nftX3eCdDOFyWos9ATsdT70aSYFzlBqfM13JIVLvdKQl2lVRUOpAk9WFuyykiRsFgEKiosjM1HlSd\nUj81EcRqMy2rLEPFvKbWgRoJx5n2qnc4Fws19bJDfHRIXfz32LAs2Ks1Zi6X0E5JsKskPTJGDROj\nAYwmw7Lz/lfWOInHRVWVHhMJEd9UmIpqx7I6elttJtxlVtWmGEXw1TbcXpmWheKplwW0WsE+PlLS\n2JeKkmBXiWJSGVOxiEVRYmo8SGW1A4Nh+Qg0SK8Zk3uD802FEEWJymXkOFWoqnURDEQJBaM5rx0Z\n9AFQu0xCPhUqqx2YTAZNGrvBIKSc8CUWj5JgV0ldo6xtDfX7cl477Q0Rj4vLKiJGIeVAVWFnV65Z\nTo5TBcWBqsYcMzwgrxllDS0XDAYD1XUuJsdyh36Kosj4aICqGueyCf+9nSnNsErsDgsVVXaGB3yI\n4sL25YlRWaAtJ8epwkzIo3qBphzZlxNqHaiSJDEyMI3TbcHpLu6a/Jnw1LkRRSnnBuedCJGIiyX7\n+hJREuwaqG8qJxZN5EzQmXGcLj9NtbzSjsEgpOylCzHQO4XBIFDXWL4EI1taUiGPOQRaYDpCMBBd\ndvZ1BbV29jHFvl4S7EtCSbBroK5ZMcd4F7wuFeq4DE0xRqOB2sYyxob9RMLZ+8DGonHGhvx46t2Y\nLcVfI2Yu5VV2jEYhZ6jfyKDiOF1e9nUFj8rIGCUipuQ4XRpKgl0D9U2y5jnUl93OLkkSw33eZZ1d\n19xWgSTBQM9U1muG+mWTVUPL8tPWQbYvV9Y4mRgLLGiaW672dYXKGtmBOja08AkuFepYEuxLQkmw\na6Cy2oHFalpQY58cDzLti9C6onJZhfiloxRC6+uazHrNYK88R40tFUsypltBtcdJIi4u6G9QNHZF\ns11uGAwGqmtdTIwFsjpQJUlibMSPu9ymS4ORErkpCXYNCIJAfVMZvqkwwUDmMLeeG+MAtKyoXsqh\nLSm1jWWYzAb6urNr7AO98mf1zctTYwdobJM3uO7r4xk/F0WJ0aFpKmtkhWC54ql3IYpS1rj+oD9K\nOBgr2deXkJJg10h9k3ykHs4S9thzU+6a1LqiasnGtNQYjQYaWyuYGg/iz9B8JB5PMDLgo6bOtaw1\ntPZV1QgCdF6d3zoS5HIKsWhi2TpOFXLZ2ZV3Yrmao25HSoJdI3WKnT2DOSYaiTPY68VT75Jbhy1j\nmpPaan8Gc8zIwDSJxPK1ryvY7GYaWysYGZzOuMEp9vXl6jhV8CT/fdlMc9cuDgOwan3tko3pnU5J\nsGukrtGNIGROVOrrmkQUJVqXsRlGoSkp2Pu657/Mg0kzzHK2ryt0rKkBoCuD1t6f7PW63DXVqhon\nVR4nXdfG55kofd4Q/d1T1DeXL9tggtuRkmDXiNliorrWxeigj3BodrhfygyzcvmaYRSqa53YHGb6\nuybnFQQbSDpOl7vGDtCxWhbsN6+Ozvq7fzrCjUujVFTZl71tWRAENmxrQBQlrpwbmvXZhdMDAKze\nUNLWlxJdBPt3vvMd1q1bx9RUdmfacmLNxjoSCYmTR2aaeUuSRM/NcWx207K3qYL8Mje3VRDwR2cV\nRgsGogz1e6msdmB3LG9zFICrzEZtg5uBnqlZG/25432IosTWPS3LNjoqnTUb6zCaDFw6Mzhroz9/\nsh+DQWDlOs8tHN07j4IF+9DQEIcPH6axsVGP8RQFm3Y04S6zcu5Ef6rK4cRogMB0lJaOqmVX+Csb\nTcmwx7PH+1N/e/2Fa8RjIhu3v3PWQ8eaGiRpJjomEo5z8fQADqeFNRvrbvHolgarzcyqdR68k7Lp\nBeTQ38E+Ly0dle+ITf52omDB/nd/93f82Z/9mR5jKRqMJgO77ulATEgcfb0T72SIF56+CEB78mj+\nTmD1+lqqPE4unhrg3Ik+blwe4eaVUeqby9m0s+lWD2/J6Fgja6PnT8kb/cUzA0QjCTbf0YTJtPyy\nbrOxYZu8mV86M5A0ywwCsGrDO2Nzu50oKBbtlVdeoaGhgbVr1+o1nqJhzca6/7+9u4tpMkvjAP6v\ntIDDOKaK06DD6CwOG4gFRhPdgURtbeSjVlFRboymDUZvrCB+hKJGA8aAqJekxAjRZDTK2myI0Wym\nWiEIIsYFN6Q6bHAcjAVRMhSj9OvZC9dO2NJqzOgp5fndnSYn+acfT09P3/c56Or4DY/+PYBfe19g\n7I0H6Uu/mVI/OWXRUuQVKPH3c/fQ+nMvZNFSREmnQZX31ymx/fCOfPYXSPxOjt/6hvGT+Q6ipNMg\ni46aUr9aAEAx7yvI47/Af+zP8fiXFng8Psiio/Dd95F/MUG4eW9h1+v1GBoK/Me/uLgYZrMZZ8+e\n9T/2oafqRAKJRIK/rfwLrl56ALfLixU5yf4Vy1QyY2Yscjcq8Y+f/gXXmAc/qpIi6uDqD5W3KQ29\nPQPobP0Vvw+/RvrSRMTERs6pUR9CIpFg8Y/z0fLPX/DVzFjMmhOHH5Z+G1HHAU4WEvrIavzo0SPo\n9XrExsa+7Y8yMACFQoHLly9j9mz+hmaMMVE+urD/P7VaDYvFgpkzI/8SN8YYC2d/2nXsEolkSm3F\nMMZYuPrTVuyMMcbCA995yhhjEYYLO2OMRRgu7IwxFmGEFXa73Y7CwkLk5+ejoKAADx48EBXlvc6f\nP4+cnBzodDrU1NSIjhNUuPfsqa6uRm5uLtatW4ddu3ZhdPT9B2J/Ts3NzcjJyUF2djbq6upEx5mQ\nw+HA1q1bkZeXB51Oh3PnzomOFJTP58P69euxc+dO0VGCcjqdMBqNyM3NhVarRVdXl+hIE2poaMCa\nNWug0+lQWloKl2vig378SBCDwUAtLS1ERGSz2WjLli2iooTU3t5Oer2e3G43ERG9ePFCcKKJPXv2\njAwGA6lUKhoeHhYdZ0Ktra3k9XqJiOjEiRNUU1MjONEfvF4vaTQa6u/vJ5fLRWvXrqXe3l7RsQIM\nDg5ST08PERGNjo7S6tWrwzInEVF9fT2VlpbSjh07REcJ6sCBA9TY2EhERG63m5xOp+BEgRwOB6nV\nahobGyMiot27d5PFYgk5R9iKXSKRwOl8e+KK0+mEQhGe/SQuXLiA7du3Qyp9e/fcrFnh2ZJ3MvTs\nyczMxLRpb99yGRkZcDgc75nx+XR3d2P+/PmYN28eZDIZtFotrFar6FgB5syZg5SUFABAXFwckpKS\nMDg4KDhVIIfDgVu3bmHTpk2iowQ1OjqKzs5ObNy4EQAglUrx5Zfh2WLZ5/Ph9evX8Hg8ePPmDb7+\nOnQbZGH3+paVlaGoqAhVVVUgIly8eFFUlJAeP36Mzs5OnD59GjExMdi/fz+USqXoWONMxp49jY2N\n0Gq1omP4DQwMICEhwT9WKBRhvT0IAP39/bDb7UhLSxMdJcC7hca7xVs46u/vh1wuR1lZGex2OxYt\nWoTy8nLExobXgSAKhQJ6vR4rV67E9OnTkZWVhczMzJBzPmlhD9ZnpqSkBLdv30Z5eTk0Gg2uX78O\nk8mE+vr6TxknqFD9cLxeL0ZGRnDp0iV0d3ejuLhYyEpusvTsCfWaq9VqAEBtbS1kMhl0Ot3njheU\nyOfsY7x69QpGoxEmkwlxcXGi44xjs9kQHx+PlJQU3LlzR3ScoDweD3p6enD48GEolUocO3YMdXV1\nMBqNoqONMzIyAqvVips3b2LGjBkwGo1oamoK/fn55BtEQSxZsmTcePHixYKShFZUVEQdHR3+sUaj\noZcvXwpMNN7Dhw8pMzOT1Go1qVQqSk1NJZVKRUNDQ6KjTejKlStUWFjo3y8MF/fv3yeDweAfm81m\nMpvNAhMF53a7yWAwUENDg+goEzp58iStWLGC1Go1ZWVlUUZGBu3bt090rADPnz8ntVrtH9+9ezcs\n/w+4du0alZeX+8cWi4WOHj0aco6wPXaFQoGOjg4AQFtbGxYsWCAqSkgajQZtbW0AgL6+Png8Hsjl\ncsGp/pCcnIzW1lZYrVbcuHEDCoUCFoslLBuxNTc348yZM6itrUV0dHgdvKBUKvHkyRM8ffoULpcL\nV69exapVq0THmpDJZMLChQuxbds20VEmtGfPHthsNlitVpw6dQrLli1DdXW16FgB4uPjkZCQgL6+\nPgBAe3s7kpKSBKcKNHfuXHR1dWFsbAxE9EE5he2xV1RUoLKyEj6fDzExMaioqBAVJaQNGzbAZDJB\np9NBJpOhqqpKdKSQwrlnT2VlJdxuNwwGAwAgPT0dR44cERvqf6KionDo0CEYDAYQEQoKCsLyQ37v\n3j00NTUhOTkZ+fn5kEgkKCkpwfLly0VHm5QOHjyIvXv3wuPxIDExEcePHxcdKUBaWhqys7ORn58P\nqVSK1NRUbN68OeQc7hXDGGMRhu88ZYyxCMOFnTHGIgwXdsYYizBc2BljLMJwYWeMsQjDhZ0xxiIM\nF3bGGIswXNgZYyzC/Be68EGj7hfMcwAAAABJRU5ErkJggg==\n", - "text/plain": [ - "\u003cmatplotlib.figure.Figure at 0x7f385e198650\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - } - ], - "source": [ - "def f(x):\n", - " return tf.square(tf.sin(x))\n", - "\n", - "def grad(f):\n", - " return lambda x: tfe.gradients_function(f)(x)[0]\n", - "\n", - "x = tf.lin_space(-2*pi, 2*pi, 100) # 100 points between -2π and +2π\n", - "\n", - "import matplotlib.pyplot as plt\n", - "\n", - "plt.plot(x, f(x), label=\"f\")\n", - "plt.plot(x, grad(f)(x), label=\"first derivative\")\n", - "plt.plot(x, grad(grad(f))(x), label=\"second derivative\")\n", - "plt.plot(x, grad(grad(grad(f)))(x), label=\"third derivative\")\n", - "plt.legend()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "-39gouo7mtgu" - }, - "source": [ - "## Gradient tapes\n", - "\n", - "Every differentiable TensorFlow operation has an associated gradient function. For example, the gradient function of `tf.square(x)` would be a function that returns `2.0 * x`. To compute the gradient of a user-defined function (like `f(x)` in the example above), TensorFlow first \"records\" all the operations applied to compute the output of the function. We call this record a \"tape\". It then uses that tape and the gradients functions associated with each primitive operation to compute the gradients of the user-defined function using [reverse mode differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation).\n", - "\n", - "Since operations are recorded as they are executed, Python control flow (using `if`s and `while`s for example) is naturally handled:\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "MH0UfjympWf7" - }, - "outputs": [], - "source": [ - "def f(x, y):\n", - " output = 1\n", - " for i in range(y):\n", - " output = tf.multiply(output, x)\n", - " return output\n", - "\n", - "def g(x, y):\n", - " # Return the gradient of `f` with respect to it's first parameter\n", - " return tfe.gradients_function(f)(x, y)[0]\n", - "\n", - "assert f(3.0, 2).numpy() == 9.0 # f(x, 2) is essentially x * x\n", - "assert g(3.0, 2).numpy() == 6.0 # And its gradient will be 2 * x\n", - "assert f(4.0, 3).numpy() == 64.0 # f(x, 3) is essentially x * x * x\n", - "assert g(4.0, 3).numpy() == 48.0 # And its gradient will be 3 * x * x" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "aNmR5-jhpX2t" - }, - "source": [ - "At times it may be inconvenient to encapsulate computation of interest into a function. For example, if you want the gradient of the output with respect to intermediate values computed in the function. In such cases, the slightly more verbose but explicit [tf.GradientTape](https://www.tensorflow.org/api_docs/python/tf/GradientTape) context is useful. All computation inside the context of a `tf.GradientTape` is \"recorded\".\n", - "\n", - "For example:" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "bAFeIE8EuVIq" - }, - "outputs": [], - "source": [ - "x = tf.ones((2, 2))\n", - " \n", - "# TODO(b/78880779): Remove the 'persistent=True' argument and use\n", - "# a single t.gradient() call when the bug is resolved.\n", - "with tf.GradientTape(persistent=True) as t:\n", - " # TODO(ashankar): Explain with \"watch\" argument better?\n", - " t.watch(x)\n", - " y = tf.reduce_sum(x)\n", - " z = tf.multiply(y, y)\n", - "\n", - "# Use the same tape to compute the derivative of z with respect to the\n", - "# intermediate value y.\n", - "dz_dy = t.gradient(z, y)\n", - "assert dz_dy.numpy() == 8.0\n", - "\n", - "# Derivative of z with respect to the original input tensor x\n", - "dz_dx = t.gradient(z, x)\n", - "for i in [0, 1]:\n", - " for j in [0, 1]:\n", - " assert dz_dx[i][j].numpy() == 8.0" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "DK05KXrAAld3" - }, - "source": [ - "### Higher-order gradients\n", - "\n", - "Operations inside of the `GradientTape` context manager are recorded for automatic differentiation. If gradients are computed in that context, then the gradient computation is recorded as well. As a result, the exact same API works for higher-order gradients as well. For example:" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "cPQgthZ7ugRJ" - }, - "outputs": [], - "source": [ - "# TODO(ashankar): Should we use the persistent tape here instead? Follow up on Tom and Alex's discussion\n", - "\n", - "x = tf.constant(1.0) # Convert the Python 1.0 to a Tensor object\n", - "\n", - "with tf.GradientTape() as t:\n", - " with tf.GradientTape() as t2:\n", - " t2.watch(x)\n", - " y = x * x * x\n", - " # Compute the gradient inside the 't' context manager\n", - " # which means the gradient computation is differentiable as well.\n", - " dy_dx = t2.gradient(y, x)\n", - "d2y_dx2 = t.gradient(dy_dx, x)\n", - "\n", - "assert dy_dx.numpy() == 3.0\n", - "assert d2y_dx2.numpy() == 6.0" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "4U1KKzUpNl58" - }, - "source": [ - "## Next Steps\n", - "\n", - "In this tutorial we covered gradient computation in TensorFlow. With that we have enough of the primitives required to build an train neural networks, which we will cover in the [next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/3_neural_networks.ipynb)." - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "default_view": {}, - "name": "Automatic Differentiation", - "provenance": [], - "version": "0.3.2", - "views": {} - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb deleted file mode 100644 index d268cbcd9171b0f4a4f2ab27ad958374e521685b..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb +++ /dev/null @@ -1,209 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "U9i2Dsh-ziXr" - }, - "source": [ - "# Eager Execution Tutorial: Importing Data\n", - "\n", - "This notebook demonstrates the use of the [`tf.data.Dataset` API](https://www.tensorflow.org/guide/datasets) to build pipelines to feed data to your program. It covers:\n", - "\n", - "* Creating a `Dataset`.\n", - "* Iteration over a `Dataset` with eager execution enabled.\n", - "\n", - "We recommend using the `Dataset`s API for building performant, complex input pipelines from simple, re-usable pieces that will feed your model's training or evaluation loops.\n", - "\n", - "If you're familiar with TensorFlow graphs, the API for constructing the `Dataset` object remains exactly the same when eager execution is enabled, but the process of iterating over elements of the dataset is slightly simpler.\n", - "You can use Python iteration over the `tf.data.Dataset` object and do not need to explicitly create an `tf.data.Iterator` object.\n", - "As a result, the discussion on iterators in the [TensorFlow Guide](https://www.tensorflow.org/guide/datasets) is not relevant when eager execution is enabled." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "z1JcS5iBXMRO" - }, - "source": [ - "# Setup: Enable eager execution\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "RlIWhyeLoYnG" - }, - "outputs": [], - "source": [ - "# Import TensorFlow.\n", - "import tensorflow as tf\n", - "\n", - "# Enable eager execution\n", - "tf.enable_eager_execution()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "H9UySOPLXdaw" - }, - "source": [ - "# Step 1: Create a source `Dataset`\n", - "\n", - "Create a _source_ dataset using one of the factory functions like [`Dataset.from_tensors`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensors), [`Dataset.from_tensor_slices`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensor_slices) or using objects that read from files like [`TextLineDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TextLineDataset) or [`TFRecordDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset). See the [TensorFlow Guide](https://www.tensorflow.org/guide/datasets#reading_input_data) for more information." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "WPTUfGq6kJ5w" - }, - "outputs": [], - "source": [ - "ds_tensors = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])\n", - "\n", - "# Create a CSV file\n", - "import tempfile\n", - "_, filename = tempfile.mkstemp()\n", - "with open(filename, 'w') as f:\n", - " f.write(\"\"\"Line 1\n", - "Line 2\n", - "Line 3\n", - " \"\"\")\n", - "ds_file = tf.data.TextLineDataset(filename)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "twBfWd5xyu_d" - }, - "source": [ - "# Step 2: Apply transformations\n", - "\n", - "Use the transformations functions like [`map`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map), [`batch`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch), [`shuffle`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle) etc. to apply transformations to the records of the dataset. See the [API documentation for `tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) for details." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "ngUe237Wt48W" - }, - "outputs": [], - "source": [ - "ds_tensors = ds_tensors.map(tf.square).shuffle(2).batch(2)\n", - "ds_file = ds_file.batch(2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "IDY4WsYRhP81" - }, - "source": [ - "# Step 3: Iterate\n", - "\n", - "When eager execution is enabled `Dataset` objects support iteration.\n", - "If you're familiar with the use of `Dataset`s in TensorFlow graphs, note that there is no need for calls to `Dataset.make_one_shot_iterator()` or `get_next()` calls." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "base_uri": "https://localhost:8080/", - "height": 153 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 388, - "status": "ok", - "timestamp": 1525154629129, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "lCUWzso6mbqR", - "outputId": "8e4b0298-d27d-4ac7-e26a-ef94af0594ec" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Elements of ds_tensors:\n", - "tf.Tensor([1 9], shape=(2,), dtype=int32)\n", - "tf.Tensor([16 25], shape=(2,), dtype=int32)\n", - "tf.Tensor([ 4 36], shape=(2,), dtype=int32)\n", - "\n", - "Elements in ds_file:\n", - "tf.Tensor(['Line 1' 'Line 2'], shape=(2,), dtype=string)\n", - "tf.Tensor(['Line 3' ' '], shape=(2,), dtype=string)\n" - ] - } - ], - "source": [ - "print('Elements of ds_tensors:')\n", - "for x in ds_tensors:\n", - " print(x)\n", - "\n", - "print('\\nElements in ds_file:')\n", - "for x in ds_file:\n", - " print(x)" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "default_view": {}, - "name": "Eager Execution Tutorial: Importing Data", - "provenance": [], - "version": "0.3.2", - "views": {} - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb deleted file mode 100644 index 84f1d031d40604ae029e8a8347474950ee01b38a..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb +++ /dev/null @@ -1,485 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "k2o3TTG4TFpt" - }, - "source": [ - "# Training Models\n", - "\n", - "In the previous tutorial we covered the TensorFlow APIs for automatic differentiation, a basic building block for machine learning.\n", - "In this tutorial we will use the TensorFlow primitives introduced in the prior tutorials to do some simple machine learning.\n", - "\n", - "TensorFlow also includes a higher-level neural networks API (`tf.keras`) which provides useful abstractions to reduce boilerplate. We strongly recommend those higher level APIs for people working with neural networks. However, in this short tutorial we cover neural network training from first principles to establish a strong foundation." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "3LXMVuV0VhDr" - }, - "source": [ - "## Setup" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "PJ64L90aVir3" - }, - "outputs": [], - "source": [ - "import tensorflow as tf\n", - "tf.enable_eager_execution()\n", - "tfe = tf.contrib.eager # Shorthand for some symbols" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "eMAWbDJFVmMk" - }, - "source": [ - "## Variables\n", - "\n", - "Tensors in TensorFlow are immutable stateless objects. Machine learning models, however, need to have changing state: as your model trains, the same code to compute predictions should behave differently over time (hopefully with a lower loss!). To represent this state which needs to change over the course of your computation, you can choose to rely on the fact that Python is a stateful programming language:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "VkJwtLS_Jbn8" - }, - "outputs": [], - "source": [ - "# Using python state\n", - "x = tf.zeros([10, 10])\n", - "x += 2 # This is equivalent to x = x + 2, which does not mutate the original\n", - " # value of x\n", - "print(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "wfneTXy7JcUz" - }, - "source": [ - "TensorFlow, however, has stateful operations built in, and these are often more pleasant to use than low-level Python representations of your state. To represent weights in a model, for example, it's often convenient and efficient to use TensorFlow variables.\n", - "\n", - "A Variable is an object which stores a value and, when used in a TensorFlow computation, will implicitly read from this stored value. There are operations (`tf.assign_sub`, `tf.scatter_update`, etc) which manipulate the value stored in a TensorFlow variable." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "itxmrMil6DQi" - }, - "outputs": [], - "source": [ - "v = tfe.Variable(1.0)\n", - "assert v.numpy() == 1.0\n", - "\n", - "# Re-assign the value\n", - "v.assign(3.0)\n", - "assert v.numpy() == 3.0\n", - "\n", - "# Use `v` in a TensorFlow operation like tf.square() and reassign\n", - "v.assign(tf.square(v))\n", - "assert v.numpy() == 9.0" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "-paSaeq1JzwC" - }, - "source": [ - "Computations using Variables are automatically traced when computing gradients. For Variables representing embeddings TensorFlow will do sparse updates by default, which are more computation and memory efficient.\n", - "\n", - "Using Variables is also a way to quickly let a reader of your code know that this piece of state is mutable." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "BMiFcDzE7Qu3" - }, - "source": [ - "## Example: Fitting a linear model\n", - "\n", - "Let's now put the few concepts we have so far ---`Tensor`, `GradientTape`, `Variable` --- to build and train a simple model. This typically involves a few steps:\n", - "\n", - "1. Define the model.\n", - "2. Define a loss function.\n", - "3. Obtain training data.\n", - "4. Run through the training data and use an \"optimizer\" to adjust the variables to fit the data.\n", - "\n", - "In this tutorial, we'll walk through a trivial example of a simple linear model: `f(x) = x * W + b`, which has two variables - `W` and `b`. Furthermore, we'll synthesize data such that a well trained model would have `W = 3.0` and `b = 2.0`." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "gFzH64Jn9PIm" - }, - "source": [ - "### Define the model\n", - "\n", - "Let's define a simple class to encapsulate the variables and the computation." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "_WRu7Pze7wk8" - }, - "outputs": [], - "source": [ - "class Model(object):\n", - " def __init__(self):\n", - " # Initialize variable to (5.0, 0.0)\n", - " # In practice, these should be initialized to random values.\n", - " self.W = tfe.Variable(5.0)\n", - " self.b = tfe.Variable(0.0)\n", - " \n", - " def __call__(self, x):\n", - " return self.W * x + self.b\n", - " \n", - "model = Model()\n", - "\n", - "assert model(3.0).numpy() == 15.0" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "xa6j_yXa-j79" - }, - "source": [ - "### Define a loss function\n", - "\n", - "A loss function measures how well the output of a model for a given input matches the desired output. Let's use the standard L2 loss." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "Y0ysUFGY924U" - }, - "outputs": [], - "source": [ - "def loss(predicted_y, desired_y):\n", - " return tf.reduce_mean(tf.square(predicted_y - desired_y))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "qutT_fkl_CBc" - }, - "source": [ - "### Obtain training data\n", - "\n", - "Let's synthesize the training data with some noise." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "gxPTb-kt_N5m" - }, - "outputs": [], - "source": [ - "TRUE_W = 3.0\n", - "TRUE_b = 2.0\n", - "NUM_EXAMPLES = 1000\n", - "\n", - "inputs = tf.random_normal(shape=[NUM_EXAMPLES])\n", - "noise = tf.random_normal(shape=[NUM_EXAMPLES])\n", - "outputs = inputs * TRUE_W + TRUE_b + noise" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "-50nq-wPBsAW" - }, - "source": [ - "Before we train the model let's visualize where the model stands right now. We'll plot the model's predictions in red and the training data in blue." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 293 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 1210, - "status": "ok", - "timestamp": 1527005898290, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "_eb83LtrB4nt", - "outputId": "3873f508-72fb-41e7-a7f5-3f513deefe38" - }, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAEDCAYAAAA2k7/eAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXlgU1X2xz/pAhRautCWUsCwWVlcUHHGBUFQcSg7uM8P\nFLUICo4VpygObihI3UdmUHBB0IGZQbEgFNGCqKgMolV2pKylCy1pukDp+n5/3LxmaUsDTUjSns8/\nbZKXd09C+b7zvvfccw2apmkIgiAITR4/TwcgCIIgnB9E8AVBEJoJIviCIAjNBBF8QRCEZoIIviAI\nQjNBBF8QBKGZENDYE+Tk5JCUlER+fj7+/v7cdtttTJgwgcLCQhITEzl27BidOnXijTfeICQkxBUx\nC4IgCOeAobF1+Hl5eeTn59OrVy9OnjzJ2LFj+ec//8mnn35KWFgYCQkJLFy4kKKiIh5//HFXxS0I\ngiCcJY22dKKioujVqxcAbdq0oXv37uTm5pKWlsaYMWMAGDNmDF999VVjhxIEQRAagUs9/MzMTPbs\n2cNll13GiRMniIyMBNRFoaCgwJVDCYIgCGeJywT/5MmTPPLII8ycOZM2bdpgMBhcdWpBEATBBbhE\n8CsrK3nkkUcYNWoUN910EwDt2rUjPz8fUD5/REREg+eRtj6CIAjuo9FVOgAzZ86kR48e3HPPPTXP\nDR48mE8//ZRJkyaxcuVKbrzxxgbPYzAYyMsrdkVIbiUqKkTidCESp2vxhTh9IUbwrTidodGCv23b\nNlavXk1cXByjR4/GYDCQmJhIQkICjz76KJ988gmxsbG8+eabjR1KEARBaASNFvwrr7yS3bt31/na\n4sWLG3t6QRAEwUXISltBEIRmggi+IAhCM0EEXxAEoZkggi8IgtBMEMEXBEFoJojgC4IgNBNE8AVB\nEJoJIviCIAjNBBF8QRCEZoIIviAIQjNBBF8QBKGZIIIvCILQTBDBFwRBaCaI4AuCIDQTRPAFQRCa\nCSL4giAIzQQRfEEQhLOk0GTi84R7+XbIDXyecA+FBSZPh+QULtnTVhAEoTnx7YzHuDflUwyAlv4z\nizEwfNFiT4fVIJLhC4IgnCWhhw9hsPxusDz2BVwi+DNnzuTaa69lxIgRNc/Nnz+fAQMGMGbMGMaM\nGcM333zjiqEEQRA8TqHRiGb5XQMKjV08GI3zuMTSGTt2LOPHjycpKcnu+YkTJzJx4kRXDCEIguA1\nXJ/8OosxEHr4EIXGLlyf/JqnQ3IKlwh+v379OHbsWK3nNU2r42hBEATfJjQ8wic8e0fc6uF//PHH\njBo1iqeeeori4mJ3DiUIgiA0gNsE/+677+arr74iJSWFyMhI5s6d666hBEEQXMLRjAwW9u3FWmN7\nFvbtxeGMDE+H5FLcVpYZERFR8/vtt9/O5MmTnXpfVFSIu0JyKRKna5E4XYsvxOmNMb53xQhmZh1T\n5Zalx5h3ww08cfSop8NyGS4TfEe/Pi8vj6ioKAC+/PJL4uLinDpPXp73Wz9RUSESpwuROF2LL8Tp\nTTEWmkx8O+MxQg8fIjory67cMtZk8po4z4SzF0+XCP706dPZsmULZrOZG264gWnTprFlyxZ2796N\nn58fHTt25Pnnn3fFUIIgCC7FdhHVXFSZpcHyM8vGqWgKuETwX3311VrPjRs3zhWnFgRBcCu2i6ju\nBp4JDKR7QACZ4RH839dfezAy1yMrbQVBaNbYLqK6AOgaP4L4w7lMSt+NsXt3T4bmcqSXjiAIzRpf\nXUR1LojgC4LQrPHVRVTnglg6giA0WXy1jbG7kAxfEIQmi6+2MXYXIviCIDQZbGvqC41Ggg9k+GQb\nY3chgi8IQpPBMaOfE9vRrq7eV9oYuwsRfEEQfB49s/dbn2qX0V8QEcHiq/7odAWOyWRmxoyNHD7c\nFqOxkPffHwX4uzv884YIviAIPk2hycS/B1/HJVnH2In9StnK7heelWc/Y8ZGUlLGAwbS0zWmTFnO\n/PnD3RK3JxDBFwTBJzmakUHquOFE5mTTpbqaAcD1wDygQ1AQ1UOGnnVN/eHDbcHmHuHgwWDXBu1h\nRPAFQfBJUscNt3a2BJYDdwG9gRNDhp5TNY7RWEh6uvUeoWvXEhdG7HlE8AVB8BkKTSY2PDqVgB+/\nI8ZstvPrg1HCvz22I3ec42rZ5OTBwFKLh1/EggUjqapyTezegAi+IAhejz4pq23aQBuzmWHAAuz9\n+t8CA8mPH8Edya8RGn5uXS7tu7w3vS1aRfAFQfBq9ElZR/vmbuBZwGgwkN0hlqEr19C5a7dGjdXU\nJ22ltYIgCF7NtzMe4xKL2IPVvrkAuAgwjBzDpPTdjRZ7aPqTtiL4giB4NaGHD1GC1WDR7ZvktqFs\nbH85r2eMJiHhUwoKzI0ey2gstBtJJm0FQRDciF5u2anARGZ4BC169eIBlI3TBsuk7MbNPJ60Sdkv\nuQa279CApSxaNOasx7NdbNWhw0mGDn2P7OxImbQVBEFwF/rEbNba1cysqKjZSPyF6mo+GzWW0MOH\nOGHsUjMp62i/HD7cttZK2eTkwYSHh51xXEffftSopaxffyMAERHes/euKxDBFwTBK9D74HwO9u0R\nCs3E11FT71gzbzQW1RJvZ7L+ui4cTRWXCP7MmTP5+uuvadeuHatXrwagsLCQxMREjh07RqdOnXjj\njTcICXFuZ3VBEJo+u7Zt48vRQ+ladpqDBgNRrVtjAIqxL7fMrKfE0rFmPjl5EHfcsY2zFe+6LhxN\nFZcI/tixYxk/fjxJSUk1zy1cuJBrrrmGhIQEFi5cyDvvvMPjjz/uiuEEQWgCfDkmntllp5XMahpP\nnzyJBsQDy4BC/MhoFcawxcvqfH94eFit7P1cxLuuC0dTxSWC369fP44dO2b3XFpaGh999BEAY8aM\nYfz48SL4giCwa9s20sbG0+10qZ110x14JSwMf8LZZO7HKt6G0+Hs/8dSFi3q69S5z0W867pwNMS5\nzBV4A27z8E0mE5GRkQBERUVRUFDgrqEEQfAQZyN8+qTs6VUruUjTOIS9dZMNxAwczN8Pjyc9fXTN\n+87GUz8X8T4XzmWuwBvwuknbqCjf8PklTtcicbqW8xXn1Kmf2wlfy5bL+fe/76p1nPnECVbc1J/e\nmZmUAEOBpajOltHAAYOBsBtvZMz7i1g3ZZ2dLRMXV4qfXxUPPZTKwYPBdO1azIIF8UREnJ+Muq7v\nMisrHNu5gqyscJ/423Cb4Ldr1478/HwiIyPJy8sjIsK53ha+UAIVFeUbpVoSp2uROGuzb18QtsK3\nb18QeXnFtTL/+PIVzMjMtGuN0BUYDsxqFcRfjuQCUFEFs2dfT1mZ1ZaZPXsQ99+/qubCsnWrRlnZ\nUubNG+R2W6W+7zI21oTt/UlsbIFH/zacvdi4TPA1+65DDB48mE8//ZRJkyaxcuVKbrzxRlcNJQiC\nl1DfJKlueRgwcUH6FPBbb+fXtwF+A7a0CuLmVevszlmXLVNX6aQnbRVfneh1ieBPnz6dLVu2YDab\nueGGG5g2bRqTJk3iL3/5C5988gmxsbG8+eabrhhKEAQvoi7hKzSZCN34D17hUfIoYS4VLKu29+t3\nderEHWnfOd3Vsq4Liyfr58/XXIGrcYngv/rqq3U+v3jxYlecXhAEL8VW+ApNJj57YAKl337NDSgp\n7mD5GY+yccotO1FNfn8RuXknSUhY6ZQlU9eFJSlpQ7Opn3cVXjdpKwiCb/LtjMeI/fZr7sKayb9k\n+RkG3AkstuxEFRYRwr33fe60JaNfWPS5gTvu2Far742v2CqeRARfEASnqasMs9BUwOJxCVyanU4+\nUIgSeAOqffFLQPuwMAwDB3N98muYTGamTv2c9evB1pLJyPBvMOOvr++NyWQmKcn36uLPNyL4giA4\nja3g/pqeT+DqP3BJ9UHmY83ql6E2J9GAn4Hc9pezLCqRblRzHX4251iGrbNvMh1mx44nOVPGX59v\n76t18ecbEXxBEJxGF1wDJ5jM5fyjOrNWs7NC4G3gd/zZenUS3/74ol0LY6toK2c/KKiCIUPgwIE4\nsrLOPAnrOHl7/PguDhzowaZNucDnqE488U26AVpjEMEXBKEW9a2g7djhGMb0eG5mHa3R6mx2to6r\nWcUPwCrC9uRbXjEDqaxfD+HhO4GBNe8oKytl69Z8evUKtjuTPglr36++nPbtnyY39yrgJFlZUxg7\ndgFm85PY3mMYjZVn/BzNFRF8QRBqoSySEcA60tPD2bp1CY/cW0nv1GfpDeQBuWDX7Kwc2A2s4l+W\nV04C+ZbfU4E7KS01UFqqERT0DKWlLYGZVFcbyMrSqK5+gVGjate2O9o1YWGvACNrYi0o6ITtPUZY\n2GmSk2+u873N3eoRwReEZsDZZrrKElkH3Ik/WxmSNYaCOdX0B0qAB4BVwNNAJ9QFIB9Y1vZeKNoO\n/Aj8iWuuWU6LFktZvx5KS62iXFraDyXS1ucKC411irGjbw/tsL0TCA8/Smmp9fHAgQE1n6059bp3\nBhF8QWgG1JXpnqk1gdFYyK/pp4nnEv7ATiKBUGCA5edyIALoDOwliNfIJCzsM7744g/MmfOz5Zyr\nSU4eTnh4GAkJn5KSYmv8nLT8tBXuzDpjd/Ttr7mmmhYtrHcCM2eOYs6cule9Nqde984ggi8ITRDH\njP7AgTY01JrgxImX+emnYsrKuhKifcxf2EAw0BdqGp6lAnehWiMUA7tpxRtsB8Ixm1vx7LPf0aJF\na8s41nYrtgunjh/fRVbWFEs8y/DzKyYm5gQrV1ptGltqL7q6pdbdyaJFRiff27xr9UXwBaEJ4ijm\nsbFzcJwQdbQ7Nm8uAm0ioxjFzeykEHgCa06+HNCnVf8H/MKlrGUssMvyTDybNy+kqOhBHD1z2xW5\nBQVXMmvWOvbtC8JorCQ5Of6M9lJj2hj4agsEdyGCLwhNEEcxj4jowlVXnbk1gb92kkR6MM/yzCrs\nnfM2wK/Atxh4mXuB14A1qJ6X6hynToXSkGceHh7Gv/99l090Hm1qiOALQhNEedcFqInXlvz++06O\nHGmFn18nOnSoAuztjuh2+7k07Q36Y5XrEuzLLb8DXmYDMAhQ1TLV1aUUFS1BOfoltG5toqiocZ65\nlFK6DxF8QfBRMjIOM27cKgoKOhEefpSVK0fRtavyspOTB7N16wKyslR9elnZGMrKlgHxpKauZfPm\nLwgOziW8bTsuP/4SxvTddMZe5IcCs4COwC78mc8e4D+Wo4rp1CmW7t0rSUmZgC7w/fq9xZ49cy0x\nZTJzZm1fXm+toCyd2oIupZTuQwRfEFzI+cxOx41bVSPopaUaI0c+w9VX9yArK5zYWBMREUa7lasQ\nhFoDO4PiIhN9i/7ENVk/0Qb4G/AKcDvKq28DbAPKgHXczKqaupyLgRGAxv79s3jvvTuxnRQtLw+0\ni2nOnKW1JlQbEnQppXQfIviC4EKczU6dvTDUd5zJZCYnR8O2nUBeXsuasdUuTHOxN2X2AH3wYz/j\nuYgYNHoDpZYjwlBVOCGWM5qBv/Moyqu3LacEMHD6tCrBtP18Q4akoZorpALBbNqUQ0GB2e6z2Qt6\nIZs25TJkSFrN55NSSvchgi8ILsTZ7NTZC4PjcWVl7wGQlpZLdfVMbNsJaFqE3dgFBbG0ajWL06fb\no0Q4GH/SmcooWgNXo8wZ/Qy3AWuBTGB/y1DSus7B//dDVFUtAyqBY8Bky/mV+Dt+PiXWa8HSJNls\nHk5S0lK71saHDlUCHwPDgLWYzY+Tnm79HqSU0n2I4AuCC3E2O63vwmCb0cfE5PH99/arUX/80Q+z\neSLUallWTnCwieJi69iqdcFsYmJe4FTO10xhA3HAPuBFrEK/BJiLMmwOGAL5vPsLxPVpz6fJg3n0\n0dWkpgLkoMR+Hcrw2QU8SEzMJ3YtjWfOvJJNm/6H2Xzmjpb6pC+0q3WslFK6DxF8QXAhzmanHTpk\nk57+L5SBUkSHDrZ7waoeNtAeVd9uFfGiohzL744ty/IpKTlOy5azKC/vhqYFoaZdDbSqzOMeNjCX\nusstw1E9cF7yv4viqo9hv4Hd+zVSU5+ksjIEg8FAy5a5lJXNQ9PiwLLQKiZmPgZDJCkp92N7pzJw\noL/dqlr9oud4kevS5UKMxsI6jxXcgwi+ILgQ57PTQLDbG+o9TCazpc1viuX1AcD1wDwgBmhBdXWU\n5fjrUHl5NKqTzd1o2mbKyu5CtTK7k0BWkMjtdM2HFtRfbvk9MI+HoeoKm6MKKS9vA1wClHD69KXA\nBJt3LeP06WNkZ3fA8U7l3/++krouenXd/Yh9c35xu+APHjyY4OBg/Pz8CAgIYMWKFe4eUhC8nuzs\nSGyFMjs7khkzNmI2P255vgBVUdMH8ENJdjzqYvAekAHMwX4dbIjl8XX48xBTeJvLLM9uwb7c8kng\nAiCdIBaRALyB/YYka1G1O/r5P8T+viAEaFeniNd30bMV97i4UmbPHiT2zXnG7YJvMBhYunQpoaGh\n7h5KEHwG+4VRbTh+fCdVVRdhFdV1wAzL4+G0aJFEefkBVG/KAqAH9gIcDBThzxb+j6uJRTnt+j1E\nf+ApoCfKfd9LIPN4CUjEKub6VuOlqEla2/PnYX9fUMw111Rb2hA7l6HbintUVIistPUAbhd8TdOo\nrq529zCC4HHOpgbfcWFUVtYITKZZwDisjQysgtuyZSTl5UlYBVffHlx/vJcW/EoiH2EE2qLuC/Qz\nhANGlNjPYxgwGnXXsAw4hf1W48ss77I9fy7qziIAP78sbrklnDfeGC4Zuo9xXjL8+++/H4PBwB13\n3MHtt9/u7iEFwSMkJq4hNbUt4E96egDl5Z/z4Yf/V+ex4eFhREf3tlsYdfp0F5RfHw38jvLvwwGN\n4uJg7DPuzsALQCcCWMOdfEJHqJmYreuSsA94jf0og8d2/mA2yh5S7REgwTKOyvZbtjxAWdlUoAug\nMWKErHz1Vdwu+MuXLycqKgqTycTEiRPp1q0b/fr1q/f4qKgQd4fkEiRO1+LNcZ44Yeahh1I5eDCY\nrl2LWbAgnoiIsFrHfPVVNqA6RYLGjz++WvO5Tpww88ADKWzapAF5DBgQRocO2PnfaguRGTaP56E8\n/BIgG3v5Pgp0pxWf8hc+IRi4FPtLQk9Urm4GdhDCAn4BuqPyfNsjL0c1QHseqEB1vDcAd9Kp0zx+\n/fVxpkxJZd++X8jP38vhw0amTl1d5/dwNnjzv7ktvhKnM7hd8KOiogCIiIjg5ptvZvv27WcUfF/w\n9XzFf5Q4XUNCwqqaUsmtW4P57rt/sHHjBMLDw2r62eTktKO6ugu2QlpcHMK+fUctG4Cssus5k5Ky\nBPgNeBWIRAl4H+yFuDd6GwNYgLVBcQkBHGIqM7kQ5ei3pnb1zS6gCEjmJ9Qq226Wcxc5HKkvv7rC\n8rt1nLCwzlRV+TN//nASElaSnj6DzEx9Edi5Z/re/m+u40txOoNbBb+0tJTq6mratGnDqVOn+O67\n75g6dao7hxSEs8IZ3z0jwx94B1XXspOsrF4MGrSEjRsn2PWzUatHrUJaWdmKQYOWEh3d27K61FbM\nNVT3Guvyp8DAX6ioGIO9ZBuAdJSFcyd+7Gc0/ehOUU0bYw1l7tyLtQ/O98AxgviI7aisvrvlqF7A\nXtQU7oVAS9RkrS781cDdNWfu3n1pzfcgPW58H7cKfn5+PlOnTsVgMFBVVcWIESPo37+/O4cUhLPC\nmRYHJtNhVCHjcvQtQbKyNJKSllJQYFuHPgyVscehWo/dR1bWf8nK8kdNehage/JwANs+OFBJUFB7\n2rV7kZycSNS0613AZtQdwAECuZ1EVnARqqmZ7eWjI+oeIBz4BQMvk4Ta+1XP6kNRtf0VwHPAEWAp\nEAXMx8+vkN69+9K5cxHwHtnZkbJdYBPErYLfuXNnUlJS3DmEIDQKZ7LWdu3iLJOr9hOnq1ZVomm7\nsWb1eulxgeXnRqADavJ1OJCEqoTRd4PNRU3QLgBKKSp63tJL/hmU4P8XmI4BE4MZSD920hvV0cYP\ne1PGhDJqnuIOVOb+PKp/zjKgHFWRU2zzGb5HZfnqDDExc9mwwdrKWL/zueOObTV3PrJIyveRlbZC\ns8aZrLVbt5Ns365qz21lVrUviEOJahCqQUEgyjL5K9a+MwuAKajFSvYNz2AkyqefZxnNgLJdDgNt\n8eN1pjKDi6gkFHUPEYqa2l2GtbNlJvAmM1C1OXrzhDCUPbPaMsYCwsJ2YzYPx/Hi1a5dnN1nru/O\nR6pzfBsRfKFZo2etGRmtMZn2kZFhJCHhUzsv33qMH/v3P83p00bUQqRhwHpU24NiVLZ+P8qqWYeq\naTegxHYJyj5xXK2q/x6MkvA2qJLMSIJYwiNs4VpqbyLeFdiBWob1Bd1YxR2oOwioPX2rHgcGmtiy\nZQJJSUvZtCnHIvzqmG7dTtl9L+LXN01E8IVmjb5w6J57PmbHji5kZYWwY0cOP/zwHuXlFwD5XHNN\nMG+8MYJZs75jx47nsbY+eB3ohxL7oSj/Xq+jz0ZZKmGWn0eAVjiuVlVoKKPmYcBAC/L4Cw/QAlUh\nb9s8Qd9E/DCQg4G5bEU1MzuFaocQgrJwnkDdfRxH1c8vY8CAkJrPW1BgJimpfntG/PqmiQi+0OQ4\n212nTCYzX32VhbWG/l8cP/4Mutilpi6jRYuNZGWFY93cIwNV6a7L8XxUhYttHf0ylKWi96XRPfVi\nlOtegrJgwlGZfSEB/I9HeYCXUFOqtvcDbVDSvhmYxzxURq8Bn6CqbabYjP0CMBbdVoqN3cE//zm+\n5jM3tEJW/PqmiQi+0OQ42z1RZ8zYSEVFP6zyGoKj9ZKSchz4BiW5T6KyedvVqi+gJktt31cIvI+q\njLH11FehBD8Y/QLhzxbGE04HrF1yjlG7q2UJLfgHPwArLec5iTJ42tuN3bZtB6677hNLtY2Z5OTx\ntS56Z7owSsuEpokIvtDkcPSfMzL87TbpePLJK5k792ebTUayUTX2+i5MjguTTKiKmj4o774QVSrp\n2Opgj8P7TqKmWG1ragpQ9fVRwDEMHORmRtGXHXRB1ebonW3uRuX/eqOFDYRyqtfrxBauo23bzhQV\n7aBduzgOHy6nqGgnaq5AjT1oUIsGBVs2C29+iOALTQ6r/1wIrGXv3kPs2KGqY9LTNdasmUVl5eya\n15V4ZwMXofZvLUJZNsrDV4+fw75LDdiLeybKdHkS1aYsFHjA8vM9VK+aXqhFVOpcLXmTKXSnDfZe\n/RKsPStPAj8Bb/MhsbGZpG+6tdbn7dv3LYqKpqAvuwoK+onk5IRaxzkiE7PNDz9PByAIrsRkMlNe\nXkFY2AcEBLwCDKWiwr7LTGVlZ8vjVNRkaxFqknMsSozboiY6W6BE+yrss/k+KL9cL4FcjppwDUDZ\nQS1R+XmY5fho1H+1/6F8/+W04s88yqNcBfzB4ewRqPqeA0ApLXibn4AJhIZ2r/Mzq5LKcJTFNJKe\nPS8/45yFjtFYiLrEgEzMNg8kwxe8nrOZhH300S9Yt05tuWetbdFw3A5Q/QxGTWr2xl5y+6Hq4/X3\nV1PbqgkDLkbZKDp6q4IdDsf/BDwGrCKA9fyFJfRE4xDKvsHh6L3Ad0AyM7Dtf1lYmFHnZ7auE1DH\nXXjh6TN9nTXIxGzzQwRf8HrOxmv+8UfbLvB6bcsAVHVMIcq6Kbc8PoaycRzr1k/avL8Y5bnvRmX9\nB1CLqqC21/+75Zi7gVmoydTjwP34cZw7uI8LqKpZLfsAyux5DGsPnP8BvxDLWraj7gqWoyZ9Aykp\naUtBgbnWxc5RuBcsGElVVcPfq0zMNj9E8AWv5+y8Zj1710V4p817v0cJcg9U1h2E6lLZBiW90ZZj\nJlse630oo4CHULZJAaqRWm/LuZdg7SMfClwLfInK9A1AJa14iGmspSWq4YFt82Mj8E/LmX/AwFv8\nRHT0ajiul4BqqN2n/CkqaklS0sZaIu0o3BERvtHhUTj/iOALXo+zi4BMJjMtWxZjbTlcSfv2p+jQ\noYqYmFOsWxeNmjgNQVXbPIG1cuYN1H+HauADVNOx6Vjl+VUgFtXorA/KytmL/cbeT6IuAN2Bv+HP\nVhKYRDBVhKHW49ree8SiMv1iYBVhFF74MqN676CkJJy0tGVAlkMMS2RiVWgUIviC1+Ho2c+ceSWO\nXrPtMR06ZAOB/PCDH2ZzT/SOM4GBc7jiigt4440refTRNahM3LZ2XpffdcCzNs/PcXjdgLqABKP6\n4kRZXg/HWk+Ti6rq6QQYaM10pvI6ccAh1KXAcQeqXagcftfVf+XzVbNqPv+QIWmoLQhXO8QQjtFo\nbvwXLDRbRPAFr8Pesy9g69YFREf3tpuwTUhYaXPMv7AX8uXAXVRUXEpqan9++WW+peVwFUpiQdkx\noGrsq7AX1hhqL3tqgbXR2WzUHMCtKBtHb5u8kAC+ZzjzuAhqvPo4y1nuxtp4YR/wHZEcaD+Jrz+c\nbPf5rXc09s3aYmN3kJw8HkE4V0TwBa/D3rNfR1bWk2RlWSds580bxKZNucBnqLr2AOBDy/GjUR0r\n56JWn75ETs5L2Fe5t8Bq59S1+2suKovXNwjMBfoC/0JZOlGoOv3lqHmAUYCBEN5iCjsJwbbxMDxt\n+WlEratNAl7hCoYOfYCvLRuB22Jt1uaPyTSXdu3i6NbtVJ2rZQXhbBDBF7wG3aZRu0Ppq17bYJt9\n792rcdllb1NW9kfURGknVL2LrdeeidqnNQJrWwMsP09TO6PvgrU12V7L43hUnX4R9nbPMlRWfxKV\nvz+PH/uJJ5o+VDAX1SvT9uzdUReAjqgan9dYQlBQFR9+OK7O70GqZwR3IYIveA22Vg5otG37EuXl\nJzl9Wm8ZUMC+fXuprn4Re4G3ldeLUNOhQ1HevGPVTjFqktb2Ob2RgV4FvxtlEd2Ftbe8fv5y1F3E\nt0B/gunNFPbQBWtdTrHD2fcAJ4C5PI1a2KURGvqi6744QXASEXzBa3AsvywpiaK6uhxYCORjMBRS\nXd0fewFuR+3e7yFY+9G/i/1WIUUoF/0pVO69H2XLrEb5+WGoCdqXUP89CrHtUaNkPRQD2VxPF66h\niDhUBX6X0UVuAAAgAElEQVRLyxHxWHtiHgR+IJhv+NYS0xLgd/r0EWtGOP9IawXB45w4YSYhYSUH\nDuzFdql/dfUhVOVLS+AhNK071kVSYF3s9CyqNn4Z1lYJuhV0G9bSS/189wDXAPcB/ijhH2455/2o\n7cCfQOXl01F2zyrURSKHAJ5kEg9yDUX0Rjn8D6LuJf6Gala8C7WI6oer/8qivbsIC/sSVc4ZCEzn\nxIm62yQIgjtxe4b/zTffMGfOHDRNY9y4cUyaNMndQwpegG3ZZExMHgZDJdnZHepsjfDQQ6kWK8ex\nX/x0rJt+L0fVzt+OypL1TUNOowTeH+XNL0CJ/SFUZh6GyvR160evrNmCEnRQUv0qtdsi2/aoAX/2\ncieP0Qkl246LqK6wRP078D3h7G8/hc/evIvw8DAGDowmJcW6w5T0rRE8gVsFv7q6mtmzZ7N48WKi\no6O59dZbufHGG+neXbKbpo6jH6+EfDTp6Rrl5e/QokXrmjr7I0d0K0fvF78ENXG6DjWRql8AilCZ\nfABqQrYUJdIXUbsscwLwIsryse1cqS+gMqIyeX3B1KWoUk1be2h/zWMD+dxOEp1Qa2mzqb2Iapcl\nor/zHPA05GrMmbOURYuM0rdG8ArcKvi//fYbRqORjh07AjBs2DDS0tJE8JsBGRn+WCtfirGVx82b\n8ygq6g74k54eQIcOv6AmQnWhPWY51rZ0ch5KmF8GrkZZO9NRG4E4ZubBKHHvbvndtsGZXrnTEmuZ\nZXeUXD+OtavNVmASBhYxhme4kZyada/hqBoix0VUR4BlrETdbahY9JWxUnkjeANuFfzc3Fw6dOhQ\n87h9+/Zs377dnUMKHka3cvbu3U99PeSLiqqwzchPnnyRsLBXMJs7oCpkOlN7pWs0ag9Z2wqd5aiL\nQ0vs5fc3lGUzHXVn8aHl+TxUM7OHgR9Qwr4AlZf/EVv7BvJoxTKmMZN5DiPehSoYnYOq9P8dSOav\nQDLWuxn1WcW6EbwJtwq+pmkNH+RAVFSIGyJxPRJn3Uyd+rnFyvkMW8E2GNqiaUtQAh2DdW/YYIqK\nqhk6tA2pqX6orQIN1M6hg1Btix07YZajBFvvn3MUlbFnoCyhLOy3F1mG6pXzrOW5EahGadZiSj/2\n8Sce4BKUfeM4Iqj7h2LgZ+C+las5tKyYgwdX07GjCU2rICtrNV27lrBgwUgiIs7/34ov/H36Qozg\nO3E6g1sFPyYmhqysrJrHubm5REdHn/E9vtDlLyrKN7oReiLOffuCUNKod3pUQqtpLVFTnX1Q2fda\nrFn+cNLSnqBt21YUFenyOgwl4hEosR9qeY/tRWAr1lYJlUCZ5fl01DKnO1HzAbaSHYK6IDjePagW\nyi1ZwZ9ZSRRK7B0bJ+9EXbIOA/P4KzCPqsX1t2uuqjr/f9O+8PfpCzGCb8XpDG4V/EsuuYQjR45w\n7NgxoqKiWLNmDa+99po7hxRsUOWOq5zaOMRVqD4wBajVrh+ibJTTqHLIO1HSeT3wH2xFt7z8QgwG\n6ySpyqFjUcuWdGtoKMoaCkNV1kSiBL81+mbg1sVYRcBbqAzfceGVfZ8cP78f8av+HxN5kWiUQXQZ\nSuyHYu/qlwDbCWAZe1AXDqSDpeAzuFXw/f39mTVrFvfddx+apnHrrbfKhO15xFrueP42qU5OHszW\nrQvIyrLtJvMSyh/XbZxAVAXMIpS9UwSUU1b2V1Stew+sE7ftUNXtF6JEvhR1Z6B78Ccs4ziutu1v\nGddoOacR5d/HoCqBVJ+cmBgThTmnmcrrNVO8rbGK/TqsG5OUAIb7JnHqxLWQ0s0ynvj0gu/g9jr8\nAQMGMGDAAHcPI9TBwYPBOL9xyJlxdpvB8PAwoqN7k5VlK8DtgV9RkqnbOONQojsCdVF4EXVR6Im6\nIPwN6wVjBipTvxh157Ac1YuyBEgE3qb2att1KMG3nW69HVXW+RtgwJ9D9Mx5mb5Qk9lXAL9YzqqL\n/fdArrErr//8ExVVgRQUmJESS8EXkdYKTZiuXYvZurXhjUOcwXGbQb2WPiOjNSbTXiIiutC9eyXJ\nyYOJicnDXoBboeri11HbT9d/jwIWo5oRnLa8pxRVNtkVtSh8JKoLpm255nLUBWW25Ryhlvd84zBW\nBaq0cwYQTkve5EFeJghVgW9bxT8Pa9u0X4G+H3zMjcNGEGbZSUpKLAVfRQS/CbNgQTxlZWfORPXM\nvS7hts3gHfvc/PBDMWbzg+gymZX1Pjt2BLFmzVpURfoLqJbCe1GLnlRFTm0/HcvvIVgbmC1Bib6+\n4YgJlYNrqDsAx7qZwyg//3eU7/8ZyhKy9sDx89tD9+4XcOD3KQzj31yE2tMqHzUlbHvGCNQ9QC6w\ns/f/MX2YbR2/IPguIvhNmIiIhjNRxxWxWVnL2bFjZK1NRxy3GSwpsb0AFKJE/lkqKx27wFdYfgaj\nJmuXo7L3LagFSgtQ3vpfLOcyoKpt9K0D9bJJtayp9sYkO1Fi74eaatVQ62BNGAwLMBgKCAgopLz8\nSTJ/f5XH+DctsC/UdOyGvxdY3fmv9L7iYj4Su0ZoQojgN3McM3clzKvsNh3ZsuUFIiO70bLlLMrK\nugIFVFaWohqSrUMJdBeH8/RC9YzvgzJJ/FENyu5CyeoW1MpWfd1qqOW9GsqnL0RV46iWxMHBwbRu\nncHJkyGcPDkLuBJ1FzAFa0Y/E1sZ17TOaFoorQK2MLG8KxEUcrHlXbaRhqIWUYUDu6KiefS7//FE\neIQLvl1B8C5E8Jsp1s1Gcqg94Wm/yjUnpx05OX7AH1AZ9RSsbvdc6l4odQRrqeQIm2MvRpVq9gJS\nULtP9QdmWc5/EjVluhblxa8FWtO2bQGXXRZJaupk9L48+lixsVmUlMTY1PAb0Dcab8F7TDr1Fn1Q\nS7JKLKPbRhqGarV22YrV3DZgoAu+XUHwTkTwmxG2lTbHj+8kK+shlOwtIyTkFOXlBygr80ctWtJ3\nnApFudlTsIq33mCgN9YLg75QKhrlpV+OfR6t7yk7EiXYek2+vvq1o+U1nWLgH+hZe1aWRnb2U8BS\nlGe/gKCgIMLDs4mIMFJdfYCiIpvaevZzEy25iHIuR80QBAKnLL+/iJrizQUyW7Zk8jdb6Ny1G4LQ\nlBHBb0bY+/WjgPctrxRQXNwG5YPbNv3VO0teQG3bR29yZrtQqhglpzEoJ9w2j24DVGP1823PV4i6\nSNger0u09ThNuxp1UVCWTXi4ucZ6ggJiY+cSHd2bnN8/5s8nV9ADlc3bVuC8iqrSH44ylHq/9TZT\n7ri7MV+rIPgMIvhNGMeVthkZAdgLbQFK0O+zPLbfzi8oKJLQ0AxycsB+Zep2AgK+oby8M0pC26EE\nXpU8qvLKu7GuUf0NdRFohVoEFYSSXF2GQ1H5ti7HJSg7Zz72F4GdqBYIYfj5taekRP8cAOFEtA1j\n4P4JtDhZXNPw7ANq32f8AmwA/mgptxSE5oIIfhPmgQdSSElR1S7p6RrR0S9gK6ABAcFUVtq2Frbv\nHBMenkV09CXk5NyAEu8yoAXV1X+mvPxfqIlafU3qfJTYg8r030bZNDstj5+ynONFVEbvKO6foTJ6\n2wtBEQbDU5bM/iQwGVXeeSctWhykqKhnTbxBPMWf9syhB8qmsZ3yrVXT87fnmPlIoku+Y0HwJUTw\nmzCbNtlPvppMkSi/3AAcoqoqFNiBVWSHoiZXewO7aN26DXv2bANyUPZNN5QL/jFK7HeiRPtt7Gvs\ny7BfHDUHqxVkQElxLPbibsDffzdVVXrXSwNt215ARUUwpaW23n4psbFzadu2M3v2DMOfF3iAp4lE\nTfmWoNbTrkXdY4xCFYgagX1Az7feZqRYOEIzRQTfx6ivxUFdzzvWo1RXF6AmX5cBT6BpytYJCHgG\n6EBl5WGUtbINuICMjN/RtBmo0kt9kdW/UBuRLMde1Gdh3SzcfkMSgyHC0iq72CaeocTEvMjp07EY\nDCauvroNYCQ19YGacw4atJStW49SWmr9DLGxOaSnTyMh4VP279nCgzxNR2qvvT2NMpbyLaMa3nqb\nv4rQC80cEXwfo74WB5s2VWI2twRuID09FFjK1Ve3JDX1JZS1coyIiALy8x0nTcMJDu6C2TwRtcI1\nEHgMNUmql17G2hyvi7njxGtfVG/6LNRCKqtI33ijgV275pKV1cVyvjhiY/ewceM9hIeH1bSgLSgw\n06KF/cpgs7mQMWPmUlDQifDwTFauHMnRjAx6bnyEKyiqKcB0XHt7CDUzkN82lAlfbpIKHEFABN/n\naKjFgV4yefhwW7p00YBpNa/17fsOO3a8YKmpt9opRUU5qOy8LepPwlY+9SZluoAXYW2LYOuOH0Jd\nWKpQ3S5fom3bKAYNakFy8jAAkpI2cvhwT4uYj6/VfK2uHjXh4WGkp0+refzh669wdO7zRKNaIDiu\nHNCAHwEzEP3W20yXrF4QahDB9zEcWxyoCpnaJZNGYxHHjkXYvZafH0OfPhXk5LRCTZqGAK2orn4I\nlQ8/i6qksfXWT6ImVZfTqlUZoaEZ5OYuQS2YeslyjmKUp1+NukNQq2mvu+49OwFvTMOxXdu28Wn8\nYDprGiGWT51nGd22Z/1m4GhQax7/+nvJ6gXBARF8H0H36A8caENs7BxLk7MqysurSE21XgDCwvYw\ncGABTz55Bbfeuhpb8TYai9i06TQwFaugv4+qfAFlyVyCtY/8LtS+sGHAnUREzKWg4EJUnxtFQMBz\nVFY+jaqLWYtqoaA2B8/OjnTJZ1/98VL2JD7MGyhhn24T/VxUfVBHVGfLuLfe5nHJ6gWhTkTwfQTH\nJmeXXvoe0IKjR1sTGzuXdu3i6NbtFMnJdxIeHkZCwkoyMyej574xMb9SXh5JUVE7lH0TjxLyAlQd\n/nKs1TS6NdQH+CcBAZFERBwnK2sq6uKgoQt8dXVny/kqsDY8U6tnjcZKwPle+nXxzovPU/LmK/S0\njOLY2bIT6rK0u20od4lXLwhnRATfwzgrho7e/Y8/+mE2Wy8AV11l3c3KZDKzaVMlqi7+LgCOH99h\n6UNjK+h3Uv8kbAVwjKFDI/jww7sZMiSN48f15z9E7Vg1nerqcMv5PrR7f1jYaZKTbwZqTzQ7s/NW\nWspnbEmYQBuse8sOpfZWJzuB61es5o/SA0cQGkQE38M4K4a1vXt9az8A+92sZszYiNlchbJWQoAi\nqqvD7I4PCqrA3382JSX+qBW2O7H37gNRxY7v2Yy/FvssXu+pY8DP7xjV1db4Bg4MqLlwOV6sGtp5\n6z8LF3D0bzO4Cvu2CMtRl6VZqLqhA35+jFi3kd59Lz/j+QRBUIjgexhnxVDV1VtLFsvL29h590Zj\nUc3dwvr1oNab2u4rOwfb3HjIEPjii3KsneGvB55BbczdApVPG2p8+OTkwWza9CVms2MBJIDGLbdE\n1Cqp1HG8WNW389aWDRv45s7RdEXtcWVfza9G247aB6vlW28zQ7x6QTgr3Cb48+fP5z//+Q/t2rUD\nIDExUfa2rQNnxdCxZLGumvWkJFuf374vjiqvXEZY2GkGDgwgOXkQ69dX2RwTDlyFqrixdrLU4wkP\nD2PgQH9SUmwXQe0gOrqamJh8gHptKceLVV07b+kWzlUood9B7f2xvgcKWrfmwY1SgSMI54JbM/yJ\nEycyceJEdw7h8zgjhnWhXwD0rP6OO7ZZetvrXWRKsG5Q0gb4BX//Mq65xkhy8gjCw8MID8+yW8Wq\nO+V610nHeGrHOr5mgjgl5X7qs6XOtAfsrm3bWDlyCK0qKojCauH0x9pBPwzVmu1SaYsgCI3CrYKv\nVmoKZ+JMYujMhO6jj37BunVKbK37wd4DDMVgeBlNe9Hy2giqql4lNXUKLVooQV65cpRlFWsHNO0A\nXbv2IC5udZ2LovRY580bVBNTUtIGkpMHn7VHD8q++e6uMcRpGq1QbdG+xv5+oyewBwiZ+wp/u39S\ng+cUBOHMuFXwP/74Y1JSUrj44ot54oknCAkJcedwTQZd6Otql+B4cfjxRz9sxTYg4DQXX/yZpea+\nu4PnrpqS6YLctavRbhWrM9Q1yWw0ak7ZUjqrP17KvsSHa/bK0uvpY7G3cHYAwdNncKeIvSC4hEYJ\n/sSJE8nPz6/1fGJiInfffTcPP/wwBoOB119/nblz5zJnzpwGzxkV5RsXBXfFeeKEmZtu+pjMTH17\nQGs1TFZWeK1xDYYT2MpkSEgxv/zyIACjRy+289z1n3FxpWcV/4kTZh56KJWDB4PZv78a2wtMVlY4\n69Zdz5Qpyzl4MJiuXUtYsGAkERG1z3/49995+7rr0PLyuBb7GYYY1KaFy1BtEQ4FBjLhhx+49Mor\nnY7zfNDc/z5diS/ECL4TpzM0SvA/+OADp467/fbbmTx5slPH5uUVNyak84Le7MsdJCSsIjPTdutA\na7uE2NiCWuNefXUbUlP1LpXFXHGFH6NHL7HYQBUMHfoemZlhnDixj4gII927L2X27EHk5RU7vQYg\nIWGVzWSw/d61sbEFVFX5M3/+8Jrjq6pq/ztu2bCBNXeOph2qyfJOVF2QXsV/AHVZy42MYsSaL7nN\nMinrTX8P7vx3dyW+EKcvxAi+FaczuM3SycvLIyoqCoAvv/ySuLg4dw3VpFB2i307ML1dgu0Eqi7W\nmZnRxMbutWm10NbOchk1ailpabcAt9Qay9k1APYe/TDCwl6hS5cLnZpkLjSZ+PD2MZz67Rcisd9A\nUe+8vxXVxrjrW2/zkEzKCoLbcJvgv/zyy+zevRs/Pz86duzI888/766hmhQxMXnArVhbIvzGpk33\n1Mq8HVst6CtthwxJw9kJVGcnW+1LR0MZOLA9ixbd2OBnKTSZePe6fgSeyKcDEIf9fUsUkI5aQnaD\nbDcoCG7HbYKfnJzsrlM3aQyGSlS/GmXRXH55O6daLehi7Wxd/9kce7alo4UmEx/eOoLAHdvpgloC\n1prabYx/B4au38TAmwf4xG2zIPg6stLWy8jO7oCavtQff1bncfWJta04x8WVMnt2/eLsrJCfqXTU\nkS0bNrDqztG0R7VAsF3nexfWNsbfA/1XrJa2CIJwHhHB9zIam3XbinNDE05nI+QNUWgy8cn/3U7Z\nT/8jGrVm19a+CQcWoBZRHbj8Sh5Y/gmh4REuGVsQBOcQwT+POFMV8+STV7J1q76l31FmzhxV57lc\nKdaNpdBk4tUr+hB56iRdgQyUjWNr32QBtGzFdau/kKxeEDyECP55xJmqmLlzfyYr60nAQGmpxpw5\nS1m0yOiJcJ1CX0TVDvsKnGewt286zn1FFlAJgodp8oJfV1ataZzzhhyNwZmqmHNpU+AJ0lI+Iz1h\nAj1QmyN2wN7CuQDV1fJXoK9U4AiCV9DkBb+urBo46w05XIEz/vzZVNl4gkKTifWJD3MkdY1da4SZ\n2Fs4+1FCP12EXhC8hiYv+PVnzOc/i3amKuZcu2eeD45mZLDwun6EVVfVqqmPRO2EG4US+/C/PSdZ\nvSB4GU1e8OvOmM+u2ZercGai1ZsmY3UKTSa+eHgSpWnrCUNtObgT1XxZb41QBJQBmZf2JfG/n0kF\njiB4IU1e8OvPmL0zi/Y2jmZk8J8Bf2RuRTnLgenozZZVa4RoVEYfcNUfeOCj/4jQC4IX0+QFv76M\n2duyaG9k17ZtpA4dRFfq3ua8N/CjwY/79hwQoRcEH6DJC75w9hzNyODzUX/C73guc4FXULZNMbW3\nHBz6xUYRe0HwEUTwBTv0rL43qtfNEdTq2GUooX8JaAvkRbfn9tVfyN6yguBDiOALgEXoRw7hgooK\nLgGGoerrXwKmAGtRk7SFLVpy7efrZbWsIPggIvjNHL0C50Taeru6erXHFrRHZfeHUZ0tbxOhFwSf\nRQTfDTi7k5SnSUv5jG8SJhAFGKlrjy3YB1RFRnHXmi/FvhEEH0cEvwHqEu+GthNzdicpT1JoMvFz\nwgQ6A0+gsnjbCdm9wGZUVi/2jSA0DUTwG6Au8f7sswlnfI8398PRWyPkfLWei4BAVKTxKBvnJJCN\n2nLwZulXLwhNCj9PB+DtnIt4G42FqDwZvKkfztGMDN699CJOpK7huYoKgoBjqEjDgDtRi6hCbhzC\ntL2H+OOAgZ4MVxAEFyMZfgOcSzMzb+yHU2gy8emga7m2vAwT1qx+CfA00BXYHxjI7d9tFa9eEJoo\njRL8devWMX/+fDIyMlixYgV9+vSpee2dd97hk08+wd/fn6eeeor+/fs3OlhPcC7i7S39cMwnTvDf\ne+6hcPO3VBYXM1vTMAAfY83qpwFzAgOpvOkW7ntjviyiEoQmTKMEPy4ujvnz5/P000/bPZ+RkUFq\naipr164lJyeHiRMnsn79egwGQz1n8l68RbzPlqMZGSy89goiNI0eqEVU24FLUTX2r6I6XO5v2Yp7\nf9sjQi8IzYBGCX63burWX9M0u+fT0tKIj48nICCATp06YTQa+e2337jssssaM5zgJLp9E6VpdrtQ\nPY0S/FDADJyIiua2z9eL2AtCM8EtHn5ubi59+/atedy+fXtyc3PdMZTgQKHJxL8HX0e306VUYF9b\n3xVYDByL7ci9GzeL0AtCM6NBwZ84cSL5+fm1nk9MTGTw4MF1vscx4wectnMaqnH3FrwtTvOJE6Q8\n8ACZa9Yws6LCzqvXM/x9QI9Ro3j4/fcJi/Ausfe277M+JE7X4Qsxgu/E6QwNCv4HH3xw1ieNiYkh\nOzu75nFOTg7R0dFOvTcvr/isxzvfREWFeE2cu7Zt48sx8XQ9Xcpx7FfMDgNeADqixF7fW7aiyru+\nZ2/6Ps+ExOk6fCFG8K04ncFldfi2Wf3gwYNZu3Yt5eXlHD16lCNHjnDppZe6aijBhi/HxDP7dCn3\no1bM7sW6AiAU8IvtyIC9h5h+vEi2HBSEZk6jPPyvvvqK2bNnU1BQwOTJk+nZsyfvvvsuPXr0YOjQ\noQwbNoyAgACeeeYZn6zQ8WaOZmSQOm443U6X2vn03VFtEsqBnE6duCPtO/HqBUEAwKDVZbh7EF+5\nffJUnIUmE9/OeIyDa1fzXEUFy1BdLXWffhYQGhZG62uu488fLaGiKtAjcZ4NvnTbLHG6Bl+IEXwr\nTmeQlbY+gi702qYNtDSb6YZ9D5xS4ECrIG5eta6m/01YhG/8sQqCcH4QwfcRvp3xGPemfGqXydv2\nwJkT25G/pO/2ZIiCIHg5IvhejJ7Vhx4+hHbogJ1X3xO1G1V7Pz+yYzowdOUazwUqCIJPIILvxdhm\n9Y419dlhYcQMHMz1ya/JpKwgCE4hgu9l7Nq2jS9G/wljWRm5wALgblRN/SthYXTv0o1CYxfGiNAL\ngnCWiOB7GV+OiefFsrKaTH4ZkIry6SMHDub6RYs9GZ4gCD6MCL6X0a3stJ1XHwKYgoJYPGQo1ye/\n5sHIBEHwdUTwPYztxGyh0cjuwBZo5dYMvxioHjKU4ZLZC4LQSETwPYxduWX6z/x94CCe+vF7jGVl\nHDcYaHX9QMZIZi8IggsQwfcwoYcP2Vk4nQsLuftonidDEgShiSKbmJ9HCk0mPk+4l2+H3MDnCfdQ\nWGCi0Gi02e4cCo1dPBihIAhNGcnwzyOO9s1iDFyf/DqLMVg8/C4yMSsIgtsQwT+PONo3oYcPERoe\nIROygiCcF8TSOY+IfSMIgieRDN8NOJZaXp/8OqHhEWLfCILgUUTw3UBdXv3wRYvFvhEEwaOIpeMG\n6vLqBUEQPI0IvhsQr14QBG9ELB03IF69IAjeSKMEf926dcyfP5+MjAxWrFhBnz59ADh27Bjx8fF0\n69YNgMsuu4xnn3220cH6CuLVC4LgjTRK8OPi4pg/fz5PP/10rdcuuOACVq5c2ZjTC4IgCC6kUYKv\nZ/CapjVwpCAIguBp3DZpm5mZydixYxk/fjw//fSTu4YRBEEQnKTBDH/ixInk5+fXej4xMZHBgwfX\n+Z7o6Gi+/vprQkND2blzJw8//DBr1qyhTZs2DQYUFRXiRNjnD/OJE6Q+9BDBBw9S3LUr8QsWAN4X\nZ31InK5F4nQdvhAj+E6cztCg4H/wwQdnfdLAwEBCQ0MB6NOnD507d+bQoUM1k7pnIi+v+KzHcyef\nJ0yyLqLaupXFZZVM/OwTr4uzLqKiQiROFyJxug5fiBF8K05ncJmlY+vjm0wmqqurATh69ChHjhyh\nc+fOrhrqvCKLqARBaCo0atL2q6++Yvbs2RQUFDB58mR69uzJu+++y08//cTf//53AgIC8PPz4/nn\nn6dt27auivm8Umg0oqX/XLPloCyiEgTBV2mU4N90003cdNNNtZ4fMmQIQ4YMacypvQZZRCUIQlNB\nVto2gCyiEgShqSC9dARBEJoJzVLw69pbVhAEoanTLC2d+vrVC4IgNGWaZYYvpZaCIDRHmqXgS796\nQRCaI03e0qlrf1kptRQEoTnS5AW/Pr9ePHtBEJobTd7SEb9eEARB0eQFX/x6QRAERZO3dMSvFwRB\nUDR5wZfWCIIgCIomb+kIgiAIChF8QRCEZoIIviAIQjNBBF8QBKGZIIIvCILQTBDBFwRBaCY0SvCT\nk5MZOnQoo0aNYtq0aZSUlNS89s477zBkyBCGDh3Kd9991+hABUEQhMbRKMHv378/a9asISUlBaPR\nyDvvvAPA/v37SU1NZe3atSxatIjnnnsOTdMaOJsgCILgThol+Ndeey1+fuoUffv2JScnB4ANGzYQ\nHx9PQEAAnTp1wmg08ttvvzU+WkEQBOGccZmHv2LFCgYOHAhAbm4uHTp0qHmtffv25ObmumooQRAE\n4RxosLXCxIkTyc/Pr/V8YmIigwcPBmDBggUEBgYyfPhwgDrtG4PBUOs5QRAE4fzRoOB/8MEHZ3x9\n5cqVbNq0iSVLltQ8FxMTQ3Z2ds3jnJwcoqOjnQooKirEqeM8jcTpWiRO1+ILcfpCjOA7cTpDoyyd\nb775hnfffZcFCxbQokWLmucHDx7M2rVrKS8v5+jRoxw5coRLL7200cEKgiAI545Ba0T5zJAhQ6io\nqIMzjrUAAATvSURBVCAsLAyAyy67jGeffRZQZZkrVqwgICCAp556iv79+7skYEEQBOHcaJTgC4Ig\nCL6DrLQVBEFoJojgC4IgNBNE8AVBEJoJXiv47733Hj179sRsNns6lDp58803GTlyJKNHj+b+++8n\nLy/P0yHVyZn6HXkT69atY/jw4fTq1YudO3d6Ohw7vvnmG/70pz9xyy23sHDhQk+HUy8zZ87k2muv\nZcSIEZ4OpV5ycnKYMGEC8fHxjBgxwq6c25soLy/ntttuY/To0YwYMYL58+d7OqR6qa6uZsyYMUye\nPLnhgzUvJDs7W7vvvvu0QYMGaQUFBZ4Op05KSkpqfl+yZIn29NNPezCa+tm8ebNWVVWlaZqmvfzy\ny9orr7zi4YjqJiMjQzt48KA2fvx4bceOHZ4Op4aqqirtpptu0jIzM7Xy8nJt5MiR2v79+z0dVp1s\n3bpV27VrlzZ8+HBPh1Ivx48f13bt2qVpmvo/NGTIEK/9Pk+dOqVpmqZVVlZqt912m/brr796OKK6\n+eCDD7Tp06drDz74YIPHemWGP2fOHJKSkjwdxhlp06ZNze+lpaU1PYW8jfr6HXkb3bp1o0uXLl7X\nZO+3337DaDTSsWNHAgMDGTZsGGlpaZ4Oq0769etH27ZtPR3GGYmKiqJXr16A+j/UvXt3jh8/7uGo\n6iYoKAhQ2X5lZaWHo6mbnJwcNm3axG233ebU8Q2utD3fbNiwgQ4dOnDRRRd5OpQGef3110lJSSEk\nJMRrb01tWbFiBcOGDfN0GD5FXX2htm/f7sGImg6ZmZns2bPHaxdlVldXM3bsWI4cOcKf//xnr4xT\nT46Li4udOt4jgl9ff55HH32Ud955h/fff7/mOU9mfA31EUpMTCQxMZGFCxfy0UcfMW3aNA9EeXb9\njjzp7zoTp7fhbXccTYWTJ0/yyCOPMHPmTLu7ZW/Cz8+Pzz77jJKSEh566CH2799Pjx49PB1WDV9/\n/TWRkZH06tWLLVu2OPUejwh+ff159u3bx7Fjxxg1ahSappGbm8u4ceP473//S7t27c5zlA33EdIZ\nPnw4Dz74oMcE/1z6HXkCZ79PbyImJoasrKyax7m5uU73hRLqprKykkceeYRRo0Zx0003eTqcBgkO\nDuYPf/gD3377rVcJ/s8//8yGDRvYtGkTZWVlnDx5kqSkJJKTk+t9j1cZz3FxcWzevJm0tDQ2bNhA\n+/btWblypUfEviEOHz5c83taWhrdunXzYDT1U1+/I2/Gm7LqSy65hCNHjnDs2DHKy8tZs2YNN954\no6fDqhdv+u7qY+bMmfTo0YN77rnH06HUi8lkqrFJTp8+zQ8//OB1/8cfe+wxvv76a9LS0njttdf4\n4x//eEaxBy/08G0xGAxe+wf86quvcvDgQfz8/IiNjeW5557zdEh18sILL1BRUcF9990H2Pc78ia+\n+uorZs+eTUFBAZMnT6Znz568++67ng4Lf39/Zs2axX333Yemadx66610797d02HVyfTp09myZQtm\ns5kbbriBadOmMW7cOE+HZce2bdtYvXo1cXFxjB49GoPBQGJiIgMGDPB0aHbk5eXxxBNPUF1dTXV1\nNfHx8TX7ffgy0ktHEAShmeBVlo4gCILgPkTwBUEQmgki+IIgCM0EEXxBEIRmggi+IAhCM0EEXxAE\noZkggi8IgtBMEMEXBEFoJvw//5K32R/vBHAAAAAASUVORK5CYII=\n", - "text/plain": [ - "\u003cmatplotlib.figure.Figure at 0x7f5be3c99f50\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Current loss: 9.48636\n" - ] - } - ], - "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "plt.scatter(inputs, outputs, c='b')\n", - "plt.scatter(inputs, model(inputs), c='r')\n", - "plt.show()\n", - "\n", - "print('Current loss: '),\n", - "print(loss(model(inputs), outputs).numpy())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "sSDP-yeq_4jE" - }, - "source": [ - "### Define a training loop\n", - "\n", - "We now have our network and our training data. Let's train it, i.e., use the training data to update the model's variables (`W` and `b`) so that the loss goes down using [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent). There are many variants of the gradient descent scheme that are captured in `tf.train.Optimizer` implementations. We'd highly recommend using those implementations, but in the spirit of building from first principles, in this particular example we will implement the basic math ourselves." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "MBIACgdnA55X" - }, - "outputs": [], - "source": [ - "def train(model, inputs, outputs, learning_rate):\n", - " with tf.GradientTape() as t:\n", - " current_loss = loss(model(inputs), outputs)\n", - " dW, db = t.gradient(current_loss, [model.W, model.b])\n", - " model.W.assign_sub(learning_rate * dW)\n", - " model.b.assign_sub(learning_rate * db)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "RwWPaJryD2aN" - }, - "source": [ - "Finally, let's repeatedly run through the training data and see how `W` and `b` evolve." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 446 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 569, - "status": "ok", - "timestamp": 1527005915434, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "XdfkR223D9dW", - "outputId": "c43591ae-d5ac-4f2b-a8e7-bfce607e0919" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 0: W=5.00 b=0.00, loss=9.48636\n", - "Epoch 1: W=4.58 b=0.42, loss=6.28101\n", - "Epoch 2: W=4.24 b=0.76, loss=4.29357\n", - "Epoch 3: W=3.98 b=1.02, loss=3.06128\n", - "Epoch 4: W=3.78 b=1.23, loss=2.29721\n", - "Epoch 5: W=3.61 b=1.39, loss=1.82345\n", - "Epoch 6: W=3.49 b=1.52, loss=1.52970\n", - "Epoch 7: W=3.38 b=1.62, loss=1.34756\n", - "Epoch 8: W=3.30 b=1.70, loss=1.23463\n", - "Epoch 9: W=3.24 b=1.76, loss=1.16460\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW0AAAEDCAYAAAD+/1UIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl4VOXdPvD7zJZ9XwmELQkQIAELsiTsi6xiEBGXAiIW\nbV8WBY2K0tLa4lbsr283qxURtIoioAi8SpFNg6whi0FJKAoJBgLZt5k5c87vj5OZLIRkgEnOGXJ/\nritXJsmZyT0sN1+enPOMIMuyDCIicgs6tQMQEZHzWNpERG6EpU1E5EZY2kREboSlTUTkRljaRERu\nxODMQePGjYOvry90Oh0MBgM2b97c1rmIiKgZTpW2IAjYuHEjAgIC2joPERG1wKnlEVmWIUlSW2ch\nIqJWCM5cETl+/HgEBARAEATMmTMH9957b3tkIyKiJpxaHvnggw8QFhaG4uJiLFiwAD179sTgwYPb\nOhsRETXh1PJIWFgYACA4OBgTJ05EVlZWi8fL3t6AIADdugFvvglYrTeflIiIWl8eqampgSRJ8PHx\nQXV1NR5++GEsXrwYI0aMuPadCgtRvfoFeL2zDkJtLWxdu6NqRSrMs+8DDE4N9y4XFuaHoqIKVb73\ntTCTc7SYCdBmLmZyjlYzOaPVSfvy5ct44IEHkJKSgjlz5mDcuHEtFzYAREai6oWXUHwkA9WPPApd\n4QX4L/sVgpMGwWPTvwFRdCocERE15tQPIm9Ew3/FdBcK4P3ntfB89x0IVivEmFhUP/kMzCmzAL2+\nLb79VbT6LysztU6LmQBt5mIm52g1kzPa5YpIKaozKl9+DcWHT6Jm7gLof/wB/r98BEGjh8Fj28cA\nTyckInJKu17GLnWJRuXaP6P40AnUPDgP+jN58F+0AEFjhsO0fRvLm4ioFarsPSJ1647KP/0VxWnH\nUTvnAehPf4+AhfMQNG4ETDs/A/hiOkREzVJ1wyipR09U/OV1lHx9FLX3zIH+uxwEPPQAAieMgunz\nXSxvIqImNLHLny0mDhV/fxMlB4+g9u57YMjORMDcOQicNAamPV+wvImI6miitO1scb1Q8fo6lOz/\nBrUzZsJ4Mh0B99+DwKnjYdy7h+VNRNftL395DR999IHj4+XLl2DVqlWOj//61/+HDz/8txrRboim\nStvO1iceFf96B8V702CeNgPG48cQOGcmAu+cBOOBfSxvInJa//6JyM7OAKBsfldWVorc3FzH17Oz\nM5GQMECteNdNk6VtZ+vXH+Vvv4uSPQdhnjwVxiPfIPCeGQhImQpj2ldqxyMiN5CQMBBZWZkAgLNn\nz6Bnzxj4+PigsrISVqsVP/74A+Liequc0nnqXFN+ncSEASjf8AEMJ0/A+9UX4bH7c5hSpsIycjSq\nnloJcdhwtSMSkRN8Vj8Pj+3bXPqY5jtTULX699f8emhoKPR6Ay5duoisrEz075+I6uoyZGdnwsfH\nBzExsTCotL3GjdD0pN2UOPBnKH/vI5Ts2gPL2PEwHdyPoBmTEDD7LhiOHlY7HhFpVGJiIrKyMpCd\nrZT2gAEDkJWVgaws91oaAdxk0m5KHHQ7yjZtheHIYfi8sgam/Xth2r8X5vETUZ26EuJtg9SOSETN\nqFr9+xan4rbSr18isrIy8d//KssjHh4y/vnPf8HX1wfTpt3V7nluhltN2k2JQ4aibPMnKP1kFyzJ\nI+GxZzeCJo2F/8/vhSHzpNrxiEgjEhIGIC3tIPz9/SEIAgICAlBZWYHs7Cz075+gdrzr4talbWcd\nnoyyrTtQuuUzWIYlweOL/0PQhFHwn3c/9HU/gCCijismJhbl5WXo3z+x0ef8/Pzg7+9er33bLrv8\ntStZhvHAPvi8/AcYjx0BAJin3wWPXz+Hom69lRdn0Ait7jTGTM7RYi5mco5WMznjlpi0GxEEWEeP\nRemO3Sj9YAusPxsEj88+AYYMQeCEUfBc/xaEinK1UxIR3ZBbr7TtBAHWcRNQuutLlG7aCsycCUNO\nNvxSn0BIQm/4Ll8CQ/pxXqhDRG7l1i1tO0GAdex4YMsWFJ88hapnV0EKCYHXu+8gaNJYBI4fCc+3\n/wWhvEztpERErbr1S7sBKSIS1U88heKjmSj9YAvM02bAcOpb+D29HCGJveH7xGIYThzj9E1EmtWh\nSttBp4N13ASUv/2uMn2v/DWk0DB4vbcBQZPHIWjcCHiue5PTNxFpTscs7QakiEhUP/4kio9koHTT\nVpinzYD++1Pwe2aFMn0//j8wHD/K6ZuINKHDl7aDTgfr2PHK9J2eg8rnfgMpNBxe/96IoCnjETQ2\nmdM3kZsqLPwJ8+bNUTuGS7C0myFFRKJm2QoUHzmpTN/T74L+9HfK9J3QC77LfgXDsSOcvonciKCh\nazRuBku7Jfbpe91GXEk/hcrnV0MKj4DX++8iaOoEZfp+6w0IZaVqJyWiVoiiiD/8YTXmz78fy5Yt\ng9lsVjvSDbn1roi8BpddASVJMB7YB6+N62Ha9RkEUYTs5QXzXXejZu5DEAcPcfqqS61elcVMztFi\nLq1nWr3aA9u3u3afujvvFLF6dcsFXFj4E2bPnoF//GMd+vdPwJ/+9CI6dYrGfff93KVZbkbHvSKy\nrel0sI4Zh/K3NuDKye9Q+fxvIYVHwPOD9xA0bSKCxiTB861/cvom0piIiEjH5lAzZsxAZmaGyolu\njFtuzaoVcng4apY+gZrFy2A8uB+eG9fDY+d2+D37FHx/92uYZ8xEzbwF1zV9E93KVq82tzoVt5Wm\na9ru+leSk7Yr6HSwjh6Lin+9o0zfq34HKSISnpv+XTd9D4fnv16HUFqidlKiDquw8Cd8+202AGDH\njh1ITByocqIbw9J2MTk8HDVLHkfxN+ko3fwpau+6G/q8XPitTEVIYm/4LXkMhiOHeeYJUTvr3r0H\ndu36DPPn34+ysjKkpNyjdqQbwh9EtgOhqAiem/4Nz41vw3D2vwAAsU88DAsfxpVREyH16KlKruZo\n/QdZWqLFXMzkHK1mcgYn7XYgh4WhZvEylBw6gdKPt6M25W7oz+QBTz2FkKEDETR6OLxf/gMMWRmc\nwImoRfxBZHvS6WAdORrWkaNReeUKQr/eA/Omj2A6sA8+a1+Gz9qXYYvuCvOUabBMvRPWIcMAN3qV\naCJqe2wElcghIcDChSifcS+EygoYv/wPPHZ+BtPuz+H9xj/g/cY/IAUHwzxpKixTpsMyeizg5aV2\nbCJSGUtbA2RfP1hmzIRlxkzAYoHx64NKgf/fDni9/y683n8Xsrc3LGMnwDx1OiwTJ0EODFI7NhGp\ngKWtNSYTrGPHKy/c8PJaGE4cg8euHTDt3A6PHZ/CY8enkA0GWJNGKgU+ZRqkTlFqpyaidsLS1jKd\nDuLgIRAHD0HV86uhP/09PHZ9BtPO7TAd2AvTgb3AMytg/dkgmKdMh2XqnbDF9VI7NRG1IZ494i4E\nAbbefVD9+JMo/WI/rqTnoOLFV2EZOQaGjJPw/cNvEZw8GEFJg+Dz+9XKHuCSpHZqItVVVlZi69bN\nbfb406dPQGVlJQDgypXLGDnydmRlZTT4+kSUl7vuxcSdLm1JkjBz5kw89thjLvvmdOOkzl1Qu/BR\nlH38Ka7knEH5X/8J89Q7oS/Ih/f/voagKeMRPDAevqlPwLjvS8BiUTsykSoqKsqxdetHzX5NcsFg\n07dvArKzMwEA2dmZ6NWrD7KylI/PnfsRgYFB8Pf3v+nvY+d0aW/YsAExMTEu+8bkOnJQMMz33o/y\n9e/h8qmzKHvnfdTOeQCCuRZe699C4L0pCOkbA79fPgLT9m1A3VRA1BG8/vpfceFCAR5++EH8/e//\ni/T045g3bx5++9vnMX/+fVe9QML777+Lt99+EwBQUJCPFSuW4pFH5mHx4kU4d+7Hqx4/ISHRUdpZ\nWZmYM+dBfPttfYknJCS69Pk4taZdWFiI/fv347HHHsPbb7/t0gDkYt7esEyZBsuUaYAowvhNGky7\nPoPHzs/g+fGH8Pz4Q8geHrCMGQfLlOkw3zEFcmio2qmpAwke1L/Zzxcfz3bJ8U398pdL8MMP/8W6\nde8BANLTjyMrKwsbNnyIyMhIFBb+dM0XSHjllTVITV2Jzp27ICcnG2vXvoQ///kfjY7p3z8R69e/\nBQA4depbPPLIY/joo38DUEo8IWGAUzmd5VRpr1mzBqmpqaio0NZln9QKgwHWEaNgHTEKVb9/GYas\nDOUslF074PH5Lnh8vgu+Oh2sQ4fDMnU6zFOmA2HN/wUhupUkJiYiMjKyxWNqamqQnZ2BVauehn23\nD1EUrzqub99+yM39HrW1tbDZbPD09ERUVGcUFOQjOzsD99/v2j27Wy3tffv2ITQ0FPHx8Th8+LDT\nD+zsdfTtqcNnGj9SeVv7CpCbC3zyCYStW2E6lAbToa/hu+pZICEBYWPGAGPGAKNGARqZwrX4ewdo\nM5fmMzWzxAAAYde68/Ue34TFUg69XufIEBjoDS8vL8fHklQNQajPaDQCOp0JwcHeCAgIwPbtn7by\nHfzQvXs37N//OQYMSEBYmB+GDBmMrKxjKC8vw6Br/E/hRrVa2idOnMCXX36J/fv3w2w2o6qqCqmp\nqXjllVdavJ8WN2NhpgYCI4H5jwLzH4Vw8SI8Pt8Jj53bYUr7CsjKAv7yFwCAGN8X1uHJsCSPhHVY\nMuQwZ/+quI4Wf+8AbeZipqvV1sqoqKh0ZCgtrQZQ31GSZMLly1dw5kwBPD09sXv3HgwbloSaGhkR\nEZ3w4YdbMXbsBABAXl4uYmPjrvoeffr0w7p1b2PhwkdRVFSBbt164YUXViE+vp/Tz93Zf2xbLe3l\ny5dj+fLlAIAjR45g3bp1rRY2uRc5IgK18xagdt4ChPmbUPrFPhi/Pghj2tcwHjsMw6kceK1TfjAj\n9u4D6/BkWJNHwjJ8BOTwcJXTE7XM3z8ACQkDMH/+fRg6NAnDhyc3+rrBYMCCBY9g0aL5iIrqjG7d\nuju+9utfv4A//vElvPPOOthsIsaPv6PZ0k5IGIDNmzehXz/llXF69+6DoqIizJgx0+XP57q2ZrWX\n9uuvv97qsfzXvnVukcligSH9BEyHvlKK/OhhCNXVji+Lcb1gHT4C1uQRsCaNgBTR8jqhSzJphBZz\nMZNztJrJGdxPW0VumclqheHkCRgPfQ3T1wdhOHIYuqr6UwjFmFhYk0Y43lxxib0Wf50AbeZiJudo\nNZMzeBk7XR+jEeLtQyHePhQ1S5crJZ55UllKSTsI4+Fv4LVxPbw2rgcA2Lr3UNbD65ZUpM5d1M1P\n5OZY2nRzjEaIg26HOOh21Cx5HBBFGLIylBI/9BWMh9Lg9d4GeL23AQBg69odluQR9SUe3VXlJ0Dk\nXlja5FoGA8TbBkG8bRBq/mcpYLPB8G0WjF9/VV/iddvNAoAtuiusSSNgsS+ndO3mvi+TTdQOWNrU\ntvR6iIkDISYORM0vFwM2G/Q538KUdtAxjXtu+jc8NylXkNk6d3Gsh1uSRkDq3kPlJ0CkLSxtal96\nPWwJiahJSETNo/8DSBL0p3Ial/hHH8Dzow8AALZOUcCY0fDqkwAxcQDEhETI/gEqPwki9bC0SV06\nHWz9+qOmX3/U/OKXSol//x2MaV/BlKYsqeD99+GL9x13EXv0VKb3hAF1RT5Aefk2og6ApU3aotPB\nFt8Xtvi+qF24CJBlhJUWonx/GgyZGcpb1kl4frIF+GSL4262LtH1JZ44AGLiwDY5Z5zcT2VlJXbv\n/j/MnHlPm32PNWt+i+TkkRg9elybfQ87ljZpmyAAvXrBHNQJ5pRZyudkGbr8844CN2RmwJhxEh67\nPoPHrs8cd7WFR9SXeMJAiIkDIHWJ5g86Oxj7ftpNS1uSJOh07vc6MCxtcj+CACm6KyzRXWGZdqfj\n07qLhTBknmwwkWfA4z9fwOM/XziOkYKCHAVuf7N17wm44V9edzVokE+znz9+vMolxzfVcD9tvV4P\nLy9vREVF4ttvc/Dqq39Gaurj2LBhEwBlL+3a2hosWPALFBTk47XXXkFZWSk8PT2Rmvocunbtds3v\nc/ToYXz44fsoKSnG4sVPIClphFP5rhdLm24ZUkQkLBMnwzJxsuNzwpUrMGTVl7gh82T962va7+fr\nBzEh0bE+LiYOhC02DjDwr8etoOF+2unpx5Ga+gTWrn0VRqPfTe+l3VBh4U/429/eRH7+eSxd+hg2\nbdoGo9Ho8ufDP5V0S5NDQmAdMw7WMfVrjUJ5GQzZWfVTeVYGjIcPwXTo6/r7eXlB7NvfsT4uJg6A\n2DseMJnUeBq3FGcn5Bs9vjV9+/ZDVFRUi5exO7uXdkPjxk0EAHTpEo2oqM748ccfmt1c6maxtKnD\nkf0DHOeCO1RVwZCT3WAiz4AhIx3G40fr72c0QozvpxR4/0Rg2CAIIZ2VnQ65Tu42PD09Hbf1ej1s\ntvrXibRYzAAAWZbg5+fveLUbZzSd2K81wd8sljYRAPj4OPZUcTCbYfgup9FZK4Zvs2HMPOk4JBSA\n5B8AW2wsbDFxsMX1glj33tajJ+Dh0f7PhRrx9vZGdd3OlE33xwsKCkZpaQnKy8vh6emJtLSvMGxY\nEry9fdCpUxT27v1Pq3tp2+3d+x9MnjwNFy4U4MKFghbXv28GS5voWjw8IA64DeKA2+o/Z7VCn3sa\nhqwM+F/4EeaMbOjP5MKQlQnjieON7i7rdJC6doMYG+codFtsHMTYXsqLSXA6bxcN99M2mTwQHBzs\n+Jor9tK2i47uhsWLF6GkpBhPPbWyTdazAW7Nqipmco4WMwFNcokidOd+hOFMLvS5udCfyVXKPS8X\nustFV93XMZ3H1he5LTbupqdzLf5aMZNzuDUrUXsyGCD1jIGlZwzQ4OwVABBKS6DPy4U+LxeGuvf6\nvNOtT+f2Iq9bcuF0TgBLm6jNyYFBEAcPgTh4CMwNvyCK0J/7oa7E86DPO11X7KeVc8sbnF8O1E3n\nccpSixjXS1lyccF0Ts7bsGEd9u79DwRBgCzLEAQBY8dOwNy5C9otA5dHVMRMztFiJqBtc101neee\nVpZczv4XgtXa6FjHdB7XCx59eqEyJBK26GhIXaJh69IVcmioqhO6Fn//tJrJGZy0iTTIqem8bu3c\nUFfoHrs/B3Z/Dt+mj+XlBVvnLkqJR3eFFN0VtrpCl6KjIUV2AvT6dnx2dDNY2kTuxGCArWcsbD1j\ngTumNPqSUFKM0MorKMv8Dvr8c9Dln4f+/Pm69z/CkJfb7EPKBgOkqM6wdbFP59GOYpeio2HrHM3l\nFw1haRPdIuSgYKBXN1iir3FaWmUl9PnnlUI/fx76/PPQ5Z9zFLvx0NcQrrFaaguPUAo8uiukLg0K\nvW5al32d+6893TyWNlFH4esLW5942PrEN/91sxm6CwV1ZX4e+vPn6m+fOwdDxkkYjx9r9q5SYKBS\n4F2i69bT64sdiX0A2YNLMC7C0iYihYcHpB49IfXo2fzXbTboLhbWTen1yy/224b/5kHIzmz2rqF6\nPaTQMEjhEZAiIhq/D49s9DG8vdvwSbo/ljYROUevhxTVGVJUZ4hDh139dVmGUFzcYPlFKXPv4iKI\n5wuUrXPP5ELIymjx20h+/pDCwyFFRCrvHcVu/1wEpIhIyMHBHXJLXZY2EbmGIEAOCYEYEgI0uPTf\nO8wPpQ1OrxMqK6C7dBG6ixfr3hdCd+lS3fv6z+v/e+aaa+xA3Q9Qw8KbTO0RjlJvWPJosEmUu2Np\nE1G7kn39YPP1U86AaYnVCt2Vy1eVeePCvwjD96cgZKS3+FBSQGD91B4RAXTtAm9PX0jBIZCCgyEH\nBUMKCoYcEgIpKFjTJc/SJiJtMhohRXZSziNviSxDqChvMq03md7r3gy5px13a/71cOoe0ttbKfSg\nukIPaVDswcH1X6u7LQcHQ/bxbZeLmFjaROTeBAGyfwBs/gHKKw61xGKB7nIRQsQqlJ45D11JMYSS\nYuiuXGl0Wygpga6kGIYzeRCqnXsRBtlobDSty0H1hS4FBSsTfXDj4pcDAq97XZ6lTUQdh8kEKaoz\nEOYHa9dezt3HbFYKvbgYuuIrSrHbbxcX15e9/eOfLsBwKseph5Z1OsiBgZCCQ4AG/wtoCUubiKgl\nHh7KEk1kJ9icvY8oQigtVQq9boq/ZvHX3XYWS5uIyNUMBsihobCFhgJOvkxkmJMP3fFOciQicmMs\nbSIiN8LSJiJyIyxtIiI30uoPIi0WCx588EFYrVbYbDZMmjQJixcvbo9sRETURKulbTKZsGHDBnh5\necFms+H+++/HqFGjkJiY2B75iIioAaeWR7y8vAAoU7coim0aiIiIrs2p0pYkCSkpKUhOTkZycjKn\nbCIilTh1cY1Op8O2bdtQWVmJX/3qV8jLy0NsbAs7dHXvjmDp6i0Vi49nN3t48KD+zX7epcfrhKsy\nqZoHuCqT6nmaZNJEngaZNJPH7tyPmsrD42+N41tzXVdE+vr6YsiQITh48GDLpQ1Ar7t6t6trvkR8\nM8e2xfFNM6mdp2kmLeRpmEkreeyZtJSnxfuolMd+/FX3UznPVffVQJ5GH2skj7MEWW5hl3EAxcXF\nMBqN8PPzQ21tLRYuXIhFixZh9OjRLT5wUYNNz7UgLMyPmZzATM7TYi5mco5WMzmj1Um7qKgIzzzz\nDCRJgiRJmDp1aquFTUREbaPV0u7duze2bt3aHlmIiKgVvCKSiMiNsLSJiNwIS5uIyI2wtImI3AhL\nm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uI\nyI2wtImI3AhLm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0iYjcCEubiMiN\nsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0\niYjciKG1AwoLC5GamorLly9Dr9dj9uzZmDdvXntkIyKiJlotbb1ej2effRbx8fGoqqrC3XffjeTk\nZMTExLRHPiIiaqDV5ZGwsDDEx8cDAHx8fBATE4NLly61eTAiIrrada1p5+fn47vvvkNiYmJb5SEi\noha0ujxiV1VVhaVLl2LlypXw8fFp8dju3QFJuvqY48ermj1+0KDmH8+Vx+t0V2dSMw+AqzKpnadp\nJi3kaZhJK3nszp1r9tOq5eHxt8bxrXGqtEVRxNKlS3HXXXdhwoQJTj2wTnf1EB8W5neNY5t/DFcf\n3zST2nmaZtJCnoaZtJLHnklLeVq6j1p57Mc3vZ/aeZre1kKehh9rJY+zBFmW5dYOSk1NRVBQEJ59\n9lmnH7ioqOKGArWVsDA/ZnICMzlPi7mYyTlazeSMVte0jx8/ju3bt+Obb75BSkoKZs6ciQMHDtx0\nQCIiun6tLo8MGjQIp06dao8sRETUCl4RSUTkRljaRERuhKVNRORGWNpERG6EpU1E5EacviKSiIiu\nnyQBZWVASYmA4uL6t5ISwfG5khIBn37q3OOxtImInGSxoFHRNizgpkWsfAyUlgqQJMFlGVjaRNTh\nyDJQWYmrCvdaU7D981VVzpWvXi8jKEhGaKiMuDgJQUEygoOVt6Ag1L2XHe+DgmQAvk49NkubiG4Z\nNTVAUZGAixcFXLqkw6VLyu2iIuVj5fMCLl8GLBbnLhv39lZKtUcPqVHR1pdw4/INCZHh5wcIrhuu\nG2FpE5GmSZIyEV+6JDhK2F7IDd8uXtShvLzlpvTwkBERIWPgQMDfX2yxfO23vbza6Yk6iaVNRKqo\nqUGjwm1cwvVTcVGRAFFsuYxDQiR07izhtttkhIfLiIiQEB5uvy3X3Zbg769MwMqGUTXt9Exdi6VN\nRC4likBhoYD8fB3OnxdQUQGcPevRYEpWStn5qVhCeLjUqIAblnJYmAyjsZ2enAawtInoutTUABcu\nCDh/Xof8fB3y8+23laK+cEGAzda0kE2OW9eaiusnYuVzbbku7M5Y2kTUSFkZGpVw49sCLl9u/po8\nQZARGSnjZz+T0KWL/U1GfLwnPD2rEBGhnE3RkabitsDSJupAZFlZR25Ywsq0XH+7oqL58dZolNG5\ns4z4eBFdusjo0kVCdLTkuB0VJcNkuvp+YWGeKCqS2viZdRwsbaJbiNUKnDvXtJDrlzIKCgSYzc2X\nso+P3KiEu3SxfywhOlpZtmjppdeofbC0idyMLAM//SQgL0+H3FwdzpzRIS9PeV9QAEhS8xdphIZK\niI9X1pPrC7m+mAMDuYbsDljaRBpVXQ2cOaOUccNyzsvTobr66naNiJCQlARERFgbTczR0TI6d5bg\n7a3CkyCXY2kTqcg+Nefm1heyfWrOz796LcLTU0bPnhJiYxu/xcQoZ1so5x/XqvBMqL2wtInagX1q\nbljK9um5uak5MlLCyJEiYmIal3OXLlxX7uhY2kQuIsvK+csNJ2b7W0HBtafmuDjJUc72277O7R1E\nHRBLm+g6WSzA99/rcOkScOKEqdH03NzU3KmTMjU3XMqIi5PQuTOnZrp+LG2iFtTUADk5OmRm6pGV\npbw/dUoHq9Vezh4AAC+va681c2omV2JpE9WprASys/XIzKwv6dOndY0uyfbwkJGQIKF/fxsGDzYh\nIqIasbGcmqn9sLSpQyopAbKylIJW3utx5kzj1vX2ljF4sA2JiRISEpT3cXGS4zLssDATiopsKqSn\njoylTbe8S5cEx9KGvaTPnWtc0AEBMkaOFJGQICEx0YbERBt69uT0TNrD0qZbhv3sjYblnJmpQ2Fh\n4+YNDZUwbpyIxESbo6S7dpV5NSC5BZY2uSVZBn74QXAUs30N+sqVxgUdFSVh8mRrgwlaQmQkC5rc\nF0ubNM9mA06f1jUq56ws/VWb6HfrJiEpyepYg05IkBAWJquUmqhtsLRJc8xmID1dj6+/1iMtTY/j\nx4Hqah/H1wVBeYXrCRPqp+f+/W0IDFQxNFE7YWmT6mprgRMnlIJOS9Pj2DE9amvrp+h+/YCEBKtj\nDbpfPxvPfaYOi6VN7a6mBjh+vL6kjx/XO/Z4FgQZfftKSE62YfhwG4YPF9G7NzdBIrJjaVObq64G\njh2rL+lQSYIFAAANpklEQVQTJ/SwWOpLun9/CUlJNiQl2TBsmIigIJUDE2kYS5tcrqrq6pK2X/at\n09WXdHKyiKFDuRZNdD1Y2nTTKiuBo0ftJW1AeroOolhf0omJ9klaKemAAJUDE7kxljZdt8pK4MgR\n+9kdBmRk1Je0Xi9jwAAJw4crk/SQITb4+6scmOgW0mppr1y5Evv27UNISAi2b9/eHplIYyoqgMOH\n6yfpjIz6TZT0ehkDB0pIShKRnGzDkCE8s4OoLbVa2nfffTfmzp2L1NTU9shDGlBeDnzzjVLQaWnK\nFYeSpJS0wSDjttskJCeLGD6cJU3U3lot7cGDB6OgoKA9spBKZBk4eVKHXbsMOHgQSE/3dZS00ajs\ndGc/u+P2223w8WnlAYmozXBNu4OyWoFDh/TYtcuAXbsMuHBB2bPDaARuv92G5GSlpAcPtvFVvIk0\npM1KOyzMr60e+oZ19EzV1cAXXwBbtwLbtyt7SgNAYCAwdy6QkgJMmgT4+BigtX/Ptfh7B2gzFzM5\nR4uZnNFmfzOLiira6qFvSFiYX4fMVFICfPGFATt3GrBvnwE1NcqyR2SkhAULREydKiIpyebY2N/H\np2P+Ot0ILeZiJudoNZMznCptWeZOae7kwgUBu3YpRZ2Wpnec6REba8PUqUpRDxwocYN/IjfUammv\nWLEChw8fRmlpKcaMGYMlS5Zg1qxZ7ZGNrkNurg47dypFnZ6ud3z+ttuUop4yRUSvXpKKCYnIFVot\n7bVr17ZHDrpOkqSc8WEv6rw8paj1euVls+xFHRXF/yUR3Uq09dMmapHVCqSl1Z/x8dNPyvqGl5eM\nKVOsmDpVxB13cMMlolsZS1vjqquBvXuVaXr3bgNKS5X16cBAGffeqxT1mDEiT8sj6iBY2hpUUgJ8\n/rlS1Pv315/xERUlYdYspaiHDas/44OIOg6WtkYUFAiOZY+GZ3z06lX/g8SBAyW+IC1RB8fSVtGp\nU8C775qwc6cBJ0/Wn/Hxs5/ZT82zIjaWP0gkonos7XZWWgp89JER775rxKlTAOABg0HGqFH1Z3x0\n6sSiJqLmsbTbgSwDR4/qsGGDCZ9+akBtrQCjUcbMmcCECTWYOFHkq7cQkVNY2m2orEyZqjduNOLU\nKWX5o0cPCXPnmjFnjoi+fX1RVCSqnJKI3AlL28VkGTh2TIeNG0345BPlzA+jUcZdd1kxd64VI0bY\nePk4Ed0wlraLlJUBmzcbsWFD/VTdrZuEuXMtuP9+K8LCuE5NRDePpX0TZBk4cUJZq962TZmqDQYZ\nM2YoU/XIkZyqici1WNo3oLy8fqrOyWk8Vd93nxXh4ZyqiahtsLSdJMtAeroOGzYYsW2bEdXVylQ9\nfboV8+ZZMWoUp2oianss7VZUVChT9caNRmRnK1N11671U3VEBKdqImo/LO1m2F/oduNGI7ZsUaZq\nvV7GtGnKVD16NKdqIlIHS7uBysr6qTorq36q/vnPlTNAOFUTkdpY2gAyMpS16o8/rp+qp05Vpuox\nYzhVE5F2dNjSrqwEtmxRzgDJzFSm6i5dJCxdasEDD1gRGcmpmoi0p8OVdmamDu+8o6xVV1UpU/Xk\nyVbMn69M1Xp9649BRKSWDlHalZXAtm3A3//u7dgCtXNnCYsXK1M1d9UjIndxS5d2eTnwxhsmvP66\nCeXlgE6nw+TJylr12LGcqonI/dySpV1ZCbz1lgl/+5sJpaUCQkIkrF4tICWliq9OTkRu7ZYq7epq\nYN06I/72NxOuXNEhMFDGc8+ZsXChBT16+KGoiIVNRO7tlijt2lpgwwYj/vxnE4qKdPD3l5Gaasai\nRRb4+6udjojIddy6tM1m4N13lbIuLNTBx0fG8uVmPPaYha8EQ0S3JLcsbasVeP99I/70JxMKCnTw\n9paxZIkZv/qVFSEhXAIholuXW5W2KAIffWTA2rUeOHdOB09PGY89ZsGSJRa+yAARdQhuUdo2G7Bl\niwF//KMHzp7VwWSS8cgjFixbZuF+IETUoWi6tCUJ+PRTA1591YTcXD2MRhkPPWTB449beOoeEXVI\nmixtSQJ27lTK+tQpPfR6GT//uVLWXbuyrImo49JUacsy8MUXerz8sgeys/XQ6WTMmWPF8uVm9OjB\nsiYi0kRpyzKwd69S1unpegiCjLvvtuLJJ82IjWVZExHZqVrasgwcPKiU9dGjykYgM2ZY8eSTFvTp\nI6kZjYhIk1Qr7UOH9HjpJRMOHVIiTJlixVNPWdC/P8uaiOha2r20jx7V4aWXPHDwoPKtJ04UkZpq\nxoABLGsiotY49UJaBw4cwOTJkzFp0iS88cYbN/SNTpzQ4b77vDBtmg8OHjRgzBgRu3ZV4b33aljY\nREROanXSliQJL7zwAtavX4/w8HDcc889GD9+PGJiYpz6BllZOrzyigc+/1z5ViNGiEhNtWDYMNvN\nJSci6oBaLe3MzEx069YNnTt3BgBMmzYNe/bsabW0c3J0ePVVE3bsMAIAhg4V8fTTFowYwbImIrpR\nrZb2xYsX0alTJ8fHERERyMrKavE+990HfPihN2RZwKBBNjz9tBmjR9sgCDcfmIioI2u1tGX5+s+T\n3rQJGDBAwtNPmzF+PMuaiMhVWi3tyMhIXLhwwfHxxYsXER4e3uJ9lJ7XA/C+yXiuFRbmp3aEqzCT\nc7SYCdBmLmZyjhYzOaPVs0cSEhJw7tw5FBQUwGKxYMeOHRg/fnx7ZCMioiZanbT1ej1WrVqFhx9+\nGLIs45577nH6zBEiInItQb6RRWsiIlKFUxfXEBGRNrC0iYjcCEubiMiNuHTDqAMHDmDNmjWQZRmz\nZs3CokWLXPnwN2TlypXYt28fQkJCsH37drXjAAAKCwuRmpqKy5cvQ6/XY/bs2Zg3b56qmSwWCx58\n8EFYrVbYbDZMmjQJixcvVjWTnSRJmDVrFiIiIvD666+rHQfjxo2Dr68vdDodDAYDNm/erHYkVFRU\n4LnnnkNubi50Oh3WrFmDAQMGqJrp7NmzeOKJJyAIAmRZxvnz57Fs2TLV/6yvX78emzdvhiAI6NWr\nF1588UWYTCZVM73zzjuOP0et9oHsIjabTZ4wYYKcn58vWywWecaMGXJeXp6rHv6GHT16VM7JyZGn\nT5+udhSHS5cuyTk5ObIsy3JlZaV8xx13aOLXqrq6WpZlWRZFUZ49e7ackZGhciLF22+/La9YsUJ+\n9NFH1Y4iy7Isjxs3Ti4tLVU7RiNPP/20vHnzZlmWZdlqtcoVFRUqJ2rMZrPJycnJ8oULF1TNUVhY\nKI8bN042m82yLMvysmXL5K1bt6qa6fTp0/L06dNls9ksi6IoP/TQQ/KPP/54zeNdtjzScI8So9Ho\n2KNEbYMHD4a/v7/aMRoJCwtDfHw8AMDHxwcxMTG4dOmSyqkALy8vAMrULYqiymkUhYWF2L9/P2bP\nnq12FAdZliFJ2tmZsrKyEseOHcOsWbMAAAaDAb6+viqnaiwtLQ1du3ZttCWGWiRJQk1NDURRRG1t\nbasXC7a1M2fOYODAgTCZTNDr9bj99tuxe/fuax7vstJubo8SLRSR1uXn5+O7775DYmKi2lEgSRJS\nUlKQnJyM5ORkTWRas2YNUlNTIWhoLwRBELBw4ULMmjULH374odpxkJ+fj6CgIDz77LOYOXMmVq1a\nhdraWrVjNbJz505MmzZN7RiIiIjAggULMGbMGIwaNQp+fn5ISkpSNVNcXByOHj2KsrIy1NTU4MCB\nA/jpp5+uebzLSlvm6d7XraqqCkuXLsXKlSvh4+OjdhzodDps27YNBw4cQEZGBvLy8lTNs2/fPoSG\nhiI+Pl5Tf74++OADbNmyBW+++Sbee+89HDt2TNU8oigiJycHDzzwALZu3QpPT88b3ve+LVitVnz5\n5ZeYMmWK2lFQXl6OPXv2YO/evTh48CCqq6tV/1lXTEwMfvGLX2DBggVYtGgR+vTpA4Ph2j9udFlp\n38geJR2ZKIpYunQp7rrrLkyYMEHtOI34+vpiyJAhOHjwoKo5Tpw4gS+//BLjx4/HihUrcPjwYaSm\npqqaCVCWtwAgODgYEydObHXXy7YWGRmJyMhIJCQkAAAmTZqEnJwcVTM1dODAAfTr1w/BwcFqR0Fa\nWhqio6MRGBgIvV6PiRMnIj09Xe1YmDVrFrZs2YKNGzciICAA3bp1u+axLittLe9RoqUpzW7lypWI\njY3F/Pnz1Y4CACguLkZFRQUAoLa2FocOHULPnj1VzbR8+XLs27cPe/bswWuvvYahQ4filVdeUTVT\nTU0NqqqqAADV1dX46quvEBcXp2qm0NBQdOrUCWfPngUAfPPNN5raamLHjh2YPn262jEAAFFRUcjI\nyIDZbIYsy5r5tSouLgYAXLhwAbt3727x18tlp/xpdY8S+4RWWlqKMWPGYMmSJY4f2Kjl+PHj2L59\nO3r16oWUlBQIgoAnnngCo0aNUi1TUVERnnnmGUiSBEmSMHXqVIwePVq1PFp1+fJlLF68GIIgwGaz\n4c4778SIESPUjoXnn38eTz75JERRRHR0NF588UW1IwFQBoC0tDT87ne/UzsKACAxMRGTJk1CSkoK\nDAYD+vbti3vvvVftWFiyZAnKyspgMBjwm9/8Bn5+196BkHuPEBG5EV4RSUTkRljaRERuhKVNRORG\nWNpERG6EpU1E5EZY2kREboSlTUTkRljaRERu5P8D+7Wym3BFpegAAAAASUVORK5CYII=\n", - "text/plain": [ - "\u003cmatplotlib.figure.Figure at 0x7f5be4b8ec50\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - } - ], - "source": [ - "model = Model()\n", - "\n", - "# Collect the history of W-values and b-values to plot later\n", - "Ws, bs = [], []\n", - "epochs = range(10)\n", - "for epoch in epochs:\n", - " Ws.append(model.W.numpy())\n", - " bs.append(model.b.numpy())\n", - " current_loss = loss(model(inputs), outputs)\n", - "\n", - " train(model, inputs, outputs, learning_rate=0.1)\n", - " print('Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f' %\n", - " (epoch, Ws[-1], bs[-1], current_loss))\n", - "\n", - "# Let's plot it all\n", - "plt.plot(epochs, Ws, 'r',\n", - " epochs, bs, 'b')\n", - "plt.plot([TRUE_W] * len(epochs), 'r--',\n", - " [TRUE_b] * len(epochs), 'b--')\n", - "plt.legend(['W', 'b', 'true W', 'true_b'])\n", - "plt.show()\n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "vPnIVuaSJwWz" - }, - "source": [ - "## Next Steps\n", - "\n", - "In this tutorial we covered `Variable`s and built and trained a simple linear model using the TensorFlow primitives discussed so far.\n", - "\n", - "In theory, this is pretty much all you need to use TensorFlow for your machine learning research.\n", - "In practice, particularly for neural networks, the higher level APIs like `tf.keras` will be much more convenient since it provides higher level building blocks (called \"layers\"), utilities to save and restore state, a suite of loss functions, a suite of optimization strategies etc. \n", - "\n", - "The [next tutorial](TODO) will cover these higher level APIs." - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "default_view": {}, - "name": "Training Models", - "provenance": [], - "version": "0.3.2", - "views": {} - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb deleted file mode 100644 index 5749f22ac58e0a012ed7e3fec4dfe2913d3f8273..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb +++ /dev/null @@ -1,551 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "pwX7Fii1rwsJ" - }, - "outputs": [], - "source": [ - "import tensorflow as tf\n", - "tf.enable_eager_execution()\n", - "tfe = tf.contrib.eager\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "UEu3q4jmpKVT" - }, - "source": [ - "# High level API\n", - "\n", - "We recommend using `tf.keras` as a high-level API for building neural networks. That said, most TensorFlow APIs are usable with eager execution.\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "zSFfVVjkrrsI" - }, - "source": [ - "## Layers: common sets of useful operations\n", - "\n", - "Most of the time when writing code for machine learning models you want to operate at a higher level of abstraction than individual operations and manipulation of individual variables.\n", - "\n", - "Many machine learning models are expressible as the composition and stacking of relatively simple layers, and TensorFlow provides both a set of many common layers as a well as easy ways for you to write your own application-specific layers either from scratch or as the composition of existing layers.\n", - "\n", - "TensorFlow includes the full [Keras](https://keras.io) API in the tf.keras package, and the Keras layers are very useful when building your own models.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "8PyXlPl-4TzQ" - }, - "outputs": [], - "source": [ - "# In the tf.keras.layers package, layers are objects. To construct a layer,\n", - "# simply construct the object. Most layers take as a first argument the number\n", - "# of output dimensions / channels.\n", - "layer = tf.keras.layers.Dense(100)\n", - "# The number of input dimensions is often unnecessary, as it can be inferred\n", - "# the first time the layer is used, but it can be provided if you want to \n", - "# specify it manually, which is useful in some complex models.\n", - "layer = tf.keras.layers.Dense(10, input_shape=(None, 5))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Fn69xxPO5Psr" - }, - "source": [ - "The full list of pre-existing layers can be seen in [the documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers). It includes Dense (a fully-connected layer),\n", - "Conv2D, LSTM, BatchNormalization, Dropout, and many others." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 204 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 244, - "status": "ok", - "timestamp": 1527783641557, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "E3XKNknP5Mhb", - "outputId": "c5d52434-d980-4488-efa7-5660819d0207" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\u003ctf.Tensor: id=30, shape=(10, 10), dtype=float32, numpy=\n", - "array([[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)\u003e" - ] - }, - "execution_count": 3, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "# To use a layer, simply call it.\n", - "layer(tf.zeros([10, 5]))" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 221 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 320, - "status": "ok", - "timestamp": 1527783642457, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "Wt_Nsv-L5t2s", - "outputId": "f0d96dce-0128-4080-bfe2-0ee6fbc0ad90" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[\u003ctf.Variable 'dense_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n", - " array([[ 0.43788117, -0.62099844, -0.30525017, -0.59352523, 0.1783089 ,\n", - " 0.47078604, -0.23620895, -0.30482283, 0.01366901, -0.1288507 ],\n", - " [ 0.18407935, -0.56550485, 0.54180616, -0.42254075, 0.3702994 ,\n", - " 0.36705834, -0.29678228, 0.36660975, 0.36717761, 0.46269661],\n", - " [ 0.1709305 , -0.11529458, 0.32710236, 0.46300393, -0.62802851,\n", - " 0.51641601, 0.39624029, 0.26918125, -0.25196898, 0.21353298],\n", - " [ 0.35752094, 0.44161648, 0.61500639, -0.12653333, 0.41629118,\n", - " 0.36193585, 0.066082 , -0.59253877, 0.47318751, 0.17115968],\n", - " [-0.22554061, -0.17727301, 0.5525015 , 0.3678053 , -0.00454676,\n", - " 0.24066836, -0.53640735, 0.13792562, -0.10727292, 0.59708995]], dtype=float32)\u003e,\n", - " \u003ctf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)\u003e]" - ] - }, - "execution_count": 4, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "# Layers have many useful methods. For example, you can inspect all variables\n", - "# in a layer by calling layer.variables. In this case a fully-connected layer\n", - "# will have variables for weights and biases.\n", - "layer.variables" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 221 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 226, - "status": "ok", - "timestamp": 1527783643252, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "6ilvKjz8_4MQ", - "outputId": "f647fced-c2d7-41a3-c237-242036784665" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(\u003ctf.Variable 'dense_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n", - " array([[ 0.43788117, -0.62099844, -0.30525017, -0.59352523, 0.1783089 ,\n", - " 0.47078604, -0.23620895, -0.30482283, 0.01366901, -0.1288507 ],\n", - " [ 0.18407935, -0.56550485, 0.54180616, -0.42254075, 0.3702994 ,\n", - " 0.36705834, -0.29678228, 0.36660975, 0.36717761, 0.46269661],\n", - " [ 0.1709305 , -0.11529458, 0.32710236, 0.46300393, -0.62802851,\n", - " 0.51641601, 0.39624029, 0.26918125, -0.25196898, 0.21353298],\n", - " [ 0.35752094, 0.44161648, 0.61500639, -0.12653333, 0.41629118,\n", - " 0.36193585, 0.066082 , -0.59253877, 0.47318751, 0.17115968],\n", - " [-0.22554061, -0.17727301, 0.5525015 , 0.3678053 , -0.00454676,\n", - " 0.24066836, -0.53640735, 0.13792562, -0.10727292, 0.59708995]], dtype=float32)\u003e,\n", - " \u003ctf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)\u003e)" - ] - }, - "execution_count": 5, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "# The variables are also accessible through nice accessors\n", - "layer.kernel, layer.bias" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "O0kDbE54-5VS" - }, - "source": [ - "## Implementing custom layers\n", - "The best way to implement your own layer is extending the tf.keras.Layer class and implementing:\n", - " * `__init__` , where you can do all input-independent initialization\n", - " * `build`, where you know the shapes of the input tensors and can do the rest of the initialization\n", - " * `call`, where you do the forward computation\n", - "\n", - "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`. However, the advantage of creating them in `build` is that it enables late variable creation based on the shape of the inputs the layer will operate on. On the other hand, creating variables in `__init__` would mean that shapes required to create the variables will need to be explicitly specified." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 391 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 251, - "status": "ok", - "timestamp": 1527783661512, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "5Byl3n1k5kIy", - "outputId": "6e7f9285-649a-4132-82ce-73ea92f15862" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tf.Tensor(\n", - "[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(10, 10), dtype=float32)\n", - "[\u003ctf.Variable 'my_dense_layer_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n", - "array([[-0.4011991 , 0.22458655, -0.33237562, -0.25117266, 0.33528614,\n", - " -0.01392961, 0.58580834, -0.16346583, 0.28465688, -0.47191954],\n", - " [-0.52922136, 0.22416979, -0.58209574, -0.60914612, 0.05226624,\n", - " -0.18325993, 0.5591442 , -0.24718609, 0.37148207, 0.40475875],\n", - " [ 0.16912812, -0.47618777, -0.38989353, 0.30105609, -0.08085585,\n", - " 0.44758242, 0.545829 , 0.51421839, 0.11063248, 0.20159996],\n", - " [ 0.34073615, -0.59835428, 0.06498981, -0.44489855, -0.34302285,\n", - " 0.20969599, 0.35527444, -0.03173476, -0.22227573, 0.09303057],\n", - " [ 0.41764337, -0.06435019, -0.52509922, -0.39957345, 0.56811184,\n", - " 0.23481232, -0.61666459, 0.31144124, -0.11532354, -0.42421889]], dtype=float32)\u003e]\n" - ] - } - ], - "source": [ - "class MyDenseLayer(tf.keras.layers.Layer):\n", - " def __init__(self, num_outputs):\n", - " super(MyDenseLayer, self).__init__()\n", - " self.num_outputs = num_outputs\n", - " \n", - " def build(self, input_shape):\n", - " self.kernel = self.add_variable(\"kernel\", \n", - " shape=[input_shape[-1].value, \n", - " self.num_outputs])\n", - " \n", - " def call(self, input):\n", - " return tf.matmul(input, self.kernel)\n", - " \n", - "layer = MyDenseLayer(10)\n", - "print(layer(tf.zeros([10, 5])))\n", - "print(layer.variables)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "tk8E2vY0-z4Z" - }, - "source": [ - "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`.\n", - "\n", - "Overall code is easier to read and maintain if it uses standard layers whenever possible, as other readers will be familiar with the behavior of standard layers. If you want to use a layer which is not present in tf.keras.layers or tf.contrib.layers, consider filing a [github issue](http://github.com/tensorflow/tensorflow/issues/new) or, even better, sending us a pull request!" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Qhg4KlbKrs3G" - }, - "source": [ - "## Models: composing layers\n", - "\n", - "Many interesting layer-like things in machine learning models are implemented by composing existing layers. For example, each residual block in a resnet is a composition of convolutions, batch normalizations, and a shortcut.\n", - "\n", - "The main class used when creating a layer-like thing which contains other layers is tf.keras.Model. Implementing one is done by inheriting from tf.keras.Model." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 190 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 420, - "status": "ok", - "timestamp": 1527783698512, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "N30DTXiRASlb", - "outputId": "a8b23a8e-5cf9-4bbf-f93b-6c763d74e2b3" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tf.Tensor(\n", - "[[[[ 0. 0. 0.]\n", - " [ 0. 0. 0.]\n", - " [ 0. 0. 0.]]\n", - "\n", - " [[ 0. 0. 0.]\n", - " [ 0. 0. 0.]\n", - " [ 0. 0. 0.]]]], shape=(1, 2, 3, 3), dtype=float32)\n", - "['resnet_identity_block_1/conv2d_3/kernel:0', 'resnet_identity_block_1/conv2d_3/bias:0', 'resnet_identity_block_1/batch_normalization_3/gamma:0', 'resnet_identity_block_1/batch_normalization_3/beta:0', 'resnet_identity_block_1/conv2d_4/kernel:0', 'resnet_identity_block_1/conv2d_4/bias:0', 'resnet_identity_block_1/batch_normalization_4/gamma:0', 'resnet_identity_block_1/batch_normalization_4/beta:0', 'resnet_identity_block_1/conv2d_5/kernel:0', 'resnet_identity_block_1/conv2d_5/bias:0', 'resnet_identity_block_1/batch_normalization_5/gamma:0', 'resnet_identity_block_1/batch_normalization_5/beta:0', 'resnet_identity_block_1/batch_normalization_3/moving_mean:0', 'resnet_identity_block_1/batch_normalization_3/moving_variance:0', 'resnet_identity_block_1/batch_normalization_4/moving_mean:0', 'resnet_identity_block_1/batch_normalization_4/moving_variance:0', 'resnet_identity_block_1/batch_normalization_5/moving_mean:0', 'resnet_identity_block_1/batch_normalization_5/moving_variance:0']\n" - ] - } - ], - "source": [ - "class ResnetIdentityBlock(tf.keras.Model):\n", - " def __init__(self, kernel_size, filters):\n", - " super(ResnetIdentityBlock, self).__init__(name='')\n", - " filters1, filters2, filters3 = filters\n", - "\n", - " self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))\n", - " self.bn2a = tf.keras.layers.BatchNormalization()\n", - "\n", - " self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')\n", - " self.bn2b = tf.keras.layers.BatchNormalization()\n", - "\n", - " self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))\n", - " self.bn2c = tf.keras.layers.BatchNormalization()\n", - "\n", - " def call(self, input_tensor, training=False):\n", - " x = self.conv2a(input_tensor)\n", - " x = self.bn2a(x, training=training)\n", - " x = tf.nn.relu(x)\n", - "\n", - " x = self.conv2b(x)\n", - " x = self.bn2b(x, training=training)\n", - " x = tf.nn.relu(x)\n", - "\n", - " x = self.conv2c(x)\n", - " x = self.bn2c(x, training=training)\n", - "\n", - " x += input_tensor\n", - " return tf.nn.relu(x)\n", - "\n", - " \n", - "block = ResnetIdentityBlock(1, [1, 2, 3])\n", - "print(block(tf.zeros([1, 2, 3, 3])))\n", - "print([x.name for x in block.variables])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "wYfucVw65PMj" - }, - "source": [ - "Much of the time, however, models which compose many layers simply call one layer after the other. This can be done in very little code using tf.keras.Sequential" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "base_uri": "https://localhost:8080/", - "height": 153 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 361, - "status": "ok", - "timestamp": 1526674830777, - "user": { - "displayName": "Alexandre Passos", - "photoUrl": "//lh4.googleusercontent.com/-kmTTWXEgAPw/AAAAAAAAAAI/AAAAAAAAAC0/q_DoOzKGwds/s50-c-k-no/photo.jpg", - "userId": "108023195365833072773" - }, - "user_tz": 420 - }, - "id": "L9frk7Ur4uvJ", - "outputId": "882e9076-b6d9-4380-bb1e-7c6b57d54c39" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\u003ctf.Tensor: id=1423, shape=(1, 2, 3, 3), dtype=float32, numpy=\n", - "array([[[[0., 0., 0.],\n", - " [0., 0., 0.],\n", - " [0., 0., 0.]],\n", - "\n", - " [[0., 0., 0.],\n", - " [0., 0., 0.],\n", - " [0., 0., 0.]]]], dtype=float32)\u003e" - ] - }, - "execution_count": 26, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - " my_seq = tf.keras.Sequential([tf.keras.layers.Conv2D(1, (1, 1)),\n", - " tf.keras.layers.BatchNormalization(),\n", - " tf.keras.layers.Conv2D(2, 1, \n", - " padding='same'),\n", - " tf.keras.layers.BatchNormalization(),\n", - " tf.keras.layers.Conv2D(3, (1, 1)),\n", - " tf.keras.layers.BatchNormalization()])\n", - "my_seq(tf.zeros([1, 2, 3, 3]))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "c5YwYcnuK-wc" - }, - "source": [ - "# Next steps\n", - "\n", - "Now you can go back to the previous notebook and adapt the linear regression example to use layers and models to be better structured." - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "default_view": {}, - "name": "4 - High level API - TensorFlow Eager.ipynb", - "provenance": [], - "version": "0.3.2", - "views": {} - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/README.md b/tensorflow/contrib/eager/python/examples/notebooks/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0d5ed848946d1eee643a57bf8c341520268c56b1 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/README.md @@ -0,0 +1,11 @@ +## Research and experimentation + +Eager execution provides an imperative, define-by-run interface for advanced +operations. Write custom layers, forward passes, and training loops with auto +differentiation. Start with these notebooks, then read the +[eager execution guide](https://www.tensorflow.org/guide/eager). + +1. [Eager execution basics](./eager_basics.ipynb) +2. [Automatic differentiation and gradient tapes](./automatic_differentiation.ipynb) +3. [Custom training: basics](./custom_training.ipynb) +4. [Custom layers](./custom_layers.ipynb) diff --git a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..51b7ffc4de0cee31f7a907ae7bf90f17056f9bcf --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb @@ -0,0 +1,366 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "t09eeeR5prIJ" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "form", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "GCCk8_dHpuNf" + }, + "outputs": [], + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "xh8WkEwWpnm7" + }, + "source": [ + "# Automatic differentiation and gradient tape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "idv0bPeCp325" + }, + "source": [ + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "vDJ4XzMqodTy" + }, + "source": [ + "In the previous tutorial we introduced `Tensor`s and operations on them. In this tutorial we will cover [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation), a key technique for optimizing machine learning models." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "GQJysDM__Qb0" + }, + "source": [ + "## Setup\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "OiMPZStlibBv" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "\n", + "tfe = tf.contrib.eager # Shorthand for some symbols" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "1CLWJl0QliB0" + }, + "source": [ + "## Derivatives of a function\n", + "\n", + "TensorFlow provides APIs for automatic differentiation - computing the derivative of a function. The way that more closely mimics the math is to encapsulate the computation in a Python function, say `f`, and use `tfe.gradients_function` to create a function that computes the derivatives of `f` with respect to its arguments. If you're familiar with [autograd](https://github.com/HIPS/autograd) for differentiating numpy functions, this will be familiar. For example: " + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "9FViq92UX7P8" + }, + "outputs": [], + "source": [ + "from math import pi\n", + "\n", + "def f(x):\n", + " return tf.square(tf.sin(x))\n", + "\n", + "assert f(pi/2).numpy() == 1.0\n", + "\n", + "\n", + "# grad_f will return a list of derivatives of f\n", + "# with respect to its arguments. Since f() has a single argument,\n", + "# grad_f will return a list with a single element.\n", + "grad_f = tfe.gradients_function(f)\n", + "assert tf.abs(grad_f(pi/2)[0]).numpy() \u003c 1e-7" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "v9fPs8RyopCf" + }, + "source": [ + "### Higher-order gradients\n", + "\n", + "The same API can be used to differentiate as many times as you like:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "3D0ZvnGYo0rW" + }, + "outputs": [], + "source": [ + "def f(x):\n", + " return tf.square(tf.sin(x))\n", + "\n", + "def grad(f):\n", + " return lambda x: tfe.gradients_function(f)(x)[0]\n", + "\n", + "x = tf.lin_space(-2*pi, 2*pi, 100) # 100 points between -2π and +2π\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "plt.plot(x, f(x), label=\"f\")\n", + "plt.plot(x, grad(f)(x), label=\"first derivative\")\n", + "plt.plot(x, grad(grad(f))(x), label=\"second derivative\")\n", + "plt.plot(x, grad(grad(grad(f)))(x), label=\"third derivative\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-39gouo7mtgu" + }, + "source": [ + "## Gradient tapes\n", + "\n", + "Every differentiable TensorFlow operation has an associated gradient function. For example, the gradient function of `tf.square(x)` would be a function that returns `2.0 * x`. To compute the gradient of a user-defined function (like `f(x)` in the example above), TensorFlow first \"records\" all the operations applied to compute the output of the function. We call this record a \"tape\". It then uses that tape and the gradients functions associated with each primitive operation to compute the gradients of the user-defined function using [reverse mode differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation).\n", + "\n", + "Since operations are recorded as they are executed, Python control flow (using `if`s and `while`s for example) is naturally handled:\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "MH0UfjympWf7" + }, + "outputs": [], + "source": [ + "def f(x, y):\n", + " output = 1\n", + " # Must use range(int(y)) instead of range(y) in Python 3 when\n", + " # using TensorFlow 1.10 and earlier. Can use range(y) in 1.11+\n", + " for i in range(int(y)):\n", + " output = tf.multiply(output, x)\n", + " return output\n", + "\n", + "def g(x, y):\n", + " # Return the gradient of `f` with respect to it's first parameter\n", + " return tfe.gradients_function(f)(x, y)[0]\n", + "\n", + "assert f(3.0, 2).numpy() == 9.0 # f(x, 2) is essentially x * x\n", + "assert g(3.0, 2).numpy() == 6.0 # And its gradient will be 2 * x\n", + "assert f(4.0, 3).numpy() == 64.0 # f(x, 3) is essentially x * x * x\n", + "assert g(4.0, 3).numpy() == 48.0 # And its gradient will be 3 * x * x" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "aNmR5-jhpX2t" + }, + "source": [ + "At times it may be inconvenient to encapsulate computation of interest into a function. For example, if you want the gradient of the output with respect to intermediate values computed in the function. In such cases, the slightly more verbose but explicit [tf.GradientTape](https://www.tensorflow.org/api_docs/python/tf/GradientTape) context is useful. All computation inside the context of a `tf.GradientTape` is \"recorded\".\n", + "\n", + "For example:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "bAFeIE8EuVIq" + }, + "outputs": [], + "source": [ + "x = tf.ones((2, 2))\n", + " \n", + "# TODO(b/78880779): Remove the 'persistent=True' argument and use\n", + "# a single t.gradient() call when the bug is resolved.\n", + "with tf.GradientTape(persistent=True) as t:\n", + " # TODO(ashankar): Explain with \"watch\" argument better?\n", + " t.watch(x)\n", + " y = tf.reduce_sum(x)\n", + " z = tf.multiply(y, y)\n", + "\n", + "# Use the same tape to compute the derivative of z with respect to the\n", + "# intermediate value y.\n", + "dz_dy = t.gradient(z, y)\n", + "assert dz_dy.numpy() == 8.0\n", + "\n", + "# Derivative of z with respect to the original input tensor x\n", + "dz_dx = t.gradient(z, x)\n", + "for i in [0, 1]:\n", + " for j in [0, 1]:\n", + " assert dz_dx[i][j].numpy() == 8.0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "DK05KXrAAld3" + }, + "source": [ + "### Higher-order gradients\n", + "\n", + "Operations inside of the `GradientTape` context manager are recorded for automatic differentiation. If gradients are computed in that context, then the gradient computation is recorded as well. As a result, the exact same API works for higher-order gradients as well. For example:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "cPQgthZ7ugRJ" + }, + "outputs": [], + "source": [ + "# TODO(ashankar): Should we use the persistent tape here instead? Follow up on Tom and Alex's discussion\n", + "\n", + "x = tf.constant(1.0) # Convert the Python 1.0 to a Tensor object\n", + "\n", + "with tf.GradientTape() as t:\n", + " with tf.GradientTape() as t2:\n", + " t2.watch(x)\n", + " y = x * x * x\n", + " # Compute the gradient inside the 't' context manager\n", + " # which means the gradient computation is differentiable as well.\n", + " dy_dx = t2.gradient(y, x)\n", + "d2y_dx2 = t.gradient(dy_dx, x)\n", + "\n", + "assert dy_dx.numpy() == 3.0\n", + "assert d2y_dx2.numpy() == 6.0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "4U1KKzUpNl58" + }, + "source": [ + "## Next Steps\n", + "\n", + "In this tutorial we covered gradient computation in TensorFlow. With that we have enough of the primitives required to build an train neural networks, which we will cover in the [next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/3_neural_networks.ipynb)." + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "automatic_differentiation.ipynb", + "private_outputs": true, + "provenance": [], + "toc_visible": true, + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..a0bbbb612381c5eb386b04fd7bb9914eb01f4c8e --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb @@ -0,0 +1,399 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "custom_layers.ipynb", + "version": "0.3.2", + "views": {}, + "default_view": {}, + "provenance": [], + "private_outputs": true, + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "cells": [ + { + "metadata": { + "id": "tDnwEv8FtJm7", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "##### Copyright 2018 The TensorFlow Authors." + ] + }, + { + "metadata": { + "id": "JlknJBWQtKkI", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "cellView": "form" + }, + "cell_type": "code", + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "60RdWsg1tETW", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Custom layers" + ] + }, + { + "metadata": { + "id": "BcJg7Enms86w", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "
\n", + "\n", + " Run in Google Colab\n", + "\n", + "View source on GitHub
" + ] + }, + { + "metadata": { + "id": "UEu3q4jmpKVT", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "We recommend using `tf.keras` as a high-level API for building neural networks. That said, most TensorFlow APIs are usable with eager execution.\n" + ] + }, + { + "metadata": { + "id": "pwX7Fii1rwsJ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "import tensorflow as tf\n", + "tfe = tf.contrib.eager\n", + "\n", + "tf.enable_eager_execution()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "zSFfVVjkrrsI", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Layers: common sets of useful operations\n", + "\n", + "Most of the time when writing code for machine learning models you want to operate at a higher level of abstraction than individual operations and manipulation of individual variables.\n", + "\n", + "Many machine learning models are expressible as the composition and stacking of relatively simple layers, and TensorFlow provides both a set of many common layers as a well as easy ways for you to write your own application-specific layers either from scratch or as the composition of existing layers.\n", + "\n", + "TensorFlow includes the full [Keras](https://keras.io) API in the tf.keras package, and the Keras layers are very useful when building your own models.\n" + ] + }, + { + "metadata": { + "id": "8PyXlPl-4TzQ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# In the tf.keras.layers package, layers are objects. To construct a layer,\n", + "# simply construct the object. Most layers take as a first argument the number\n", + "# of output dimensions / channels.\n", + "layer = tf.keras.layers.Dense(100)\n", + "# The number of input dimensions is often unnecessary, as it can be inferred\n", + "# the first time the layer is used, but it can be provided if you want to \n", + "# specify it manually, which is useful in some complex models.\n", + "layer = tf.keras.layers.Dense(10, input_shape=(None, 5))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Fn69xxPO5Psr", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "The full list of pre-existing layers can be seen in [the documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers). It includes Dense (a fully-connected layer),\n", + "Conv2D, LSTM, BatchNormalization, Dropout, and many others." + ] + }, + { + "metadata": { + "id": "E3XKNknP5Mhb", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# To use a layer, simply call it.\n", + "layer(tf.zeros([10, 5]))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Wt_Nsv-L5t2s", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# Layers have many useful methods. For example, you can inspect all variables\n", + "# in a layer by calling layer.variables. In this case a fully-connected layer\n", + "# will have variables for weights and biases.\n", + "layer.variables" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "6ilvKjz8_4MQ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# The variables are also accessible through nice accessors\n", + "layer.kernel, layer.bias" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "O0kDbE54-5VS", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Implementing custom layers\n", + "The best way to implement your own layer is extending the tf.keras.Layer class and implementing:\n", + " * `__init__` , where you can do all input-independent initialization\n", + " * `build`, where you know the shapes of the input tensors and can do the rest of the initialization\n", + " * `call`, where you do the forward computation\n", + "\n", + "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`. However, the advantage of creating them in `build` is that it enables late variable creation based on the shape of the inputs the layer will operate on. On the other hand, creating variables in `__init__` would mean that shapes required to create the variables will need to be explicitly specified." + ] + }, + { + "metadata": { + "id": "5Byl3n1k5kIy", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "class MyDenseLayer(tf.keras.layers.Layer):\n", + " def __init__(self, num_outputs):\n", + " super(MyDenseLayer, self).__init__()\n", + " self.num_outputs = num_outputs\n", + " \n", + " def build(self, input_shape):\n", + " self.kernel = self.add_variable(\"kernel\", \n", + " shape=[input_shape[-1].value, \n", + " self.num_outputs])\n", + " \n", + " def call(self, input):\n", + " return tf.matmul(input, self.kernel)\n", + " \n", + "layer = MyDenseLayer(10)\n", + "print(layer(tf.zeros([10, 5])))\n", + "print(layer.variables)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "tk8E2vY0-z4Z", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`.\n", + "\n", + "Overall code is easier to read and maintain if it uses standard layers whenever possible, as other readers will be familiar with the behavior of standard layers. If you want to use a layer which is not present in tf.keras.layers or tf.contrib.layers, consider filing a [github issue](http://github.com/tensorflow/tensorflow/issues/new) or, even better, sending us a pull request!" + ] + }, + { + "metadata": { + "id": "Qhg4KlbKrs3G", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Models: composing layers\n", + "\n", + "Many interesting layer-like things in machine learning models are implemented by composing existing layers. For example, each residual block in a resnet is a composition of convolutions, batch normalizations, and a shortcut.\n", + "\n", + "The main class used when creating a layer-like thing which contains other layers is tf.keras.Model. Implementing one is done by inheriting from tf.keras.Model." + ] + }, + { + "metadata": { + "id": "N30DTXiRASlb", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "class ResnetIdentityBlock(tf.keras.Model):\n", + " def __init__(self, kernel_size, filters):\n", + " super(ResnetIdentityBlock, self).__init__(name='')\n", + " filters1, filters2, filters3 = filters\n", + "\n", + " self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))\n", + " self.bn2a = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')\n", + " self.bn2b = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))\n", + " self.bn2c = tf.keras.layers.BatchNormalization()\n", + "\n", + " def call(self, input_tensor, training=False):\n", + " x = self.conv2a(input_tensor)\n", + " x = self.bn2a(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2b(x)\n", + " x = self.bn2b(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2c(x)\n", + " x = self.bn2c(x, training=training)\n", + "\n", + " x += input_tensor\n", + " return tf.nn.relu(x)\n", + "\n", + " \n", + "block = ResnetIdentityBlock(1, [1, 2, 3])\n", + "print(block(tf.zeros([1, 2, 3, 3])))\n", + "print([x.name for x in block.variables])" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "wYfucVw65PMj", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Much of the time, however, models which compose many layers simply call one layer after the other. This can be done in very little code using tf.keras.Sequential" + ] + }, + { + "metadata": { + "id": "L9frk7Ur4uvJ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + " my_seq = tf.keras.Sequential([tf.keras.layers.Conv2D(1, (1, 1)),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(2, 1, \n", + " padding='same'),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(3, (1, 1)),\n", + " tf.keras.layers.BatchNormalization()])\n", + "my_seq(tf.zeros([1, 2, 3, 3]))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "c5YwYcnuK-wc", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Next steps\n", + "\n", + "Now you can go back to the previous notebook and adapt the linear regression example to use layers and models to be better structured." + ] + } + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..5f1b48fa0d4aea06adab19a0e561923e1f557e50 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb @@ -0,0 +1,477 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Custom training: basics", + "version": "0.3.2", + "views": {}, + "default_view": {}, + "provenance": [], + "private_outputs": true, + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "metadata": { + "id": "5rmpybwysXGV", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "##### Copyright 2018 The TensorFlow Authors." + ] + }, + { + "metadata": { + "id": "m8y3rGtQsYP2", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "cellView": "form" + }, + "cell_type": "code", + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "hrXv0rU9sIma", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Custom training: basics" + ] + }, + { + "metadata": { + "id": "7S0BwJ_8sLu7", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "
\n", + "\n", + " Run in Google Colab\n", + "\n", + "View source on GitHub
" + ] + }, + { + "metadata": { + "id": "k2o3TTG4TFpt", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "In the previous tutorial we covered the TensorFlow APIs for automatic differentiation, a basic building block for machine learning.\n", + "In this tutorial we will use the TensorFlow primitives introduced in the prior tutorials to do some simple machine learning.\n", + "\n", + "TensorFlow also includes a higher-level neural networks API (`tf.keras`) which provides useful abstractions to reduce boilerplate. We strongly recommend those higher level APIs for people working with neural networks. However, in this short tutorial we cover neural network training from first principles to establish a strong foundation." + ] + }, + { + "metadata": { + "id": "3LXMVuV0VhDr", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Setup" + ] + }, + { + "metadata": { + "id": "PJ64L90aVir3", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "import tensorflow as tf\n", + "\n", + "tf.enable_eager_execution()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "eMAWbDJFVmMk", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Variables\n", + "\n", + "Tensors in TensorFlow are immutable stateless objects. Machine learning models, however, need to have changing state: as your model trains, the same code to compute predictions should behave differently over time (hopefully with a lower loss!). To represent this state which needs to change over the course of your computation, you can choose to rely on the fact that Python is a stateful programming language:\n" + ] + }, + { + "metadata": { + "id": "VkJwtLS_Jbn8", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# Using python state\n", + "x = tf.zeros([10, 10])\n", + "x += 2 # This is equivalent to x = x + 2, which does not mutate the original\n", + " # value of x\n", + "print(x)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "wfneTXy7JcUz", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "TensorFlow, however, has stateful operations built in, and these are often more pleasant to use than low-level Python representations of your state. To represent weights in a model, for example, it's often convenient and efficient to use TensorFlow variables.\n", + "\n", + "A Variable is an object which stores a value and, when used in a TensorFlow computation, will implicitly read from this stored value. There are operations (`tf.assign_sub`, `tf.scatter_update`, etc) which manipulate the value stored in a TensorFlow variable." + ] + }, + { + "metadata": { + "id": "itxmrMil6DQi", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "v = tf.Variable(1.0)\n", + "assert v.numpy() == 1.0\n", + "\n", + "# Re-assign the value\n", + "v.assign(3.0)\n", + "assert v.numpy() == 3.0\n", + "\n", + "# Use `v` in a TensorFlow operation like tf.square() and reassign\n", + "v.assign(tf.square(v))\n", + "assert v.numpy() == 9.0" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "-paSaeq1JzwC", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Computations using Variables are automatically traced when computing gradients. For Variables representing embeddings TensorFlow will do sparse updates by default, which are more computation and memory efficient.\n", + "\n", + "Using Variables is also a way to quickly let a reader of your code know that this piece of state is mutable." + ] + }, + { + "metadata": { + "id": "BMiFcDzE7Qu3", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Example: Fitting a linear model\n", + "\n", + "Let's now put the few concepts we have so far ---`Tensor`, `GradientTape`, `Variable` --- to build and train a simple model. This typically involves a few steps:\n", + "\n", + "1. Define the model.\n", + "2. Define a loss function.\n", + "3. Obtain training data.\n", + "4. Run through the training data and use an \"optimizer\" to adjust the variables to fit the data.\n", + "\n", + "In this tutorial, we'll walk through a trivial example of a simple linear model: `f(x) = x * W + b`, which has two variables - `W` and `b`. Furthermore, we'll synthesize data such that a well trained model would have `W = 3.0` and `b = 2.0`." + ] + }, + { + "metadata": { + "id": "gFzH64Jn9PIm", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Define the model\n", + "\n", + "Let's define a simple class to encapsulate the variables and the computation." + ] + }, + { + "metadata": { + "id": "_WRu7Pze7wk8", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "class Model(object):\n", + " def __init__(self):\n", + " # Initialize variable to (5.0, 0.0)\n", + " # In practice, these should be initialized to random values.\n", + " self.W = tf.Variable(5.0)\n", + " self.b = tf.Variable(0.0)\n", + " \n", + " def __call__(self, x):\n", + " return self.W * x + self.b\n", + " \n", + "model = Model()\n", + "\n", + "assert model(3.0).numpy() == 15.0" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "xa6j_yXa-j79", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Define a loss function\n", + "\n", + "A loss function measures how well the output of a model for a given input matches the desired output. Let's use the standard L2 loss." + ] + }, + { + "metadata": { + "id": "Y0ysUFGY924U", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def loss(predicted_y, desired_y):\n", + " return tf.reduce_mean(tf.square(predicted_y - desired_y))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "qutT_fkl_CBc", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Obtain training data\n", + "\n", + "Let's synthesize the training data with some noise." + ] + }, + { + "metadata": { + "id": "gxPTb-kt_N5m", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "TRUE_W = 3.0\n", + "TRUE_b = 2.0\n", + "NUM_EXAMPLES = 1000\n", + "\n", + "inputs = tf.random_normal(shape=[NUM_EXAMPLES])\n", + "noise = tf.random_normal(shape=[NUM_EXAMPLES])\n", + "outputs = inputs * TRUE_W + TRUE_b + noise" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "-50nq-wPBsAW", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Before we train the model let's visualize where the model stands right now. We'll plot the model's predictions in red and the training data in blue." + ] + }, + { + "metadata": { + "id": "_eb83LtrB4nt", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "plt.scatter(inputs, outputs, c='b')\n", + "plt.scatter(inputs, model(inputs), c='r')\n", + "plt.show()\n", + "\n", + "print('Current loss: '),\n", + "print(loss(model(inputs), outputs).numpy())" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "sSDP-yeq_4jE", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Define a training loop\n", + "\n", + "We now have our network and our training data. Let's train it, i.e., use the training data to update the model's variables (`W` and `b`) so that the loss goes down using [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent). There are many variants of the gradient descent scheme that are captured in `tf.train.Optimizer` implementations. We'd highly recommend using those implementations, but in the spirit of building from first principles, in this particular example we will implement the basic math ourselves." + ] + }, + { + "metadata": { + "id": "MBIACgdnA55X", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def train(model, inputs, outputs, learning_rate):\n", + " with tf.GradientTape() as t:\n", + " current_loss = loss(model(inputs), outputs)\n", + " dW, db = t.gradient(current_loss, [model.W, model.b])\n", + " model.W.assign_sub(learning_rate * dW)\n", + " model.b.assign_sub(learning_rate * db)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "RwWPaJryD2aN", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Finally, let's repeatedly run through the training data and see how `W` and `b` evolve." + ] + }, + { + "metadata": { + "id": "XdfkR223D9dW", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "model = Model()\n", + "\n", + "# Collect the history of W-values and b-values to plot later\n", + "Ws, bs = [], []\n", + "epochs = range(10)\n", + "for epoch in epochs:\n", + " Ws.append(model.W.numpy())\n", + " bs.append(model.b.numpy())\n", + " current_loss = loss(model(inputs), outputs)\n", + "\n", + " train(model, inputs, outputs, learning_rate=0.1)\n", + " print('Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f' %\n", + " (epoch, Ws[-1], bs[-1], current_loss))\n", + "\n", + "# Let's plot it all\n", + "plt.plot(epochs, Ws, 'r',\n", + " epochs, bs, 'b')\n", + "plt.plot([TRUE_W] * len(epochs), 'r--',\n", + " [TRUE_b] * len(epochs), 'b--')\n", + "plt.legend(['W', 'b', 'true W', 'true_b'])\n", + "plt.show()\n", + " " + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "vPnIVuaSJwWz", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Next Steps\n", + "\n", + "In this tutorial we covered `Variable`s and built and trained a simple linear model using the TensorFlow primitives discussed so far.\n", + "\n", + "In theory, this is pretty much all you need to use TensorFlow for your machine learning research.\n", + "In practice, particularly for neural networks, the higher level APIs like `tf.keras` will be much more convenient since it provides higher level building blocks (called \"layers\"), utilities to save and restore state, a suite of loss functions, a suite of optimization strategies etc. \n", + "\n", + "The [next tutorial](TODO) will cover these higher level APIs." + ] + } + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb similarity index 50% rename from tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb rename to tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb index 51d10a778413cfbb574b4e22e8adcb18bd731dee..f1e13de5dec2fbda126caeb355494875317e3373 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb @@ -1,27 +1,107 @@ { + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "eager_basics.ipynb", + "version": "0.3.2", + "views": {}, + "default_view": {}, + "provenance": [], + "private_outputs": true, + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, "cells": [ { + "metadata": { + "id": "iPpI7RaYoZuE", + "colab_type": "text" + }, "cell_type": "markdown", + "source": [ + "##### Copyright 2018 The TensorFlow Authors." + ] + }, + { "metadata": { - "colab_type": "text", - "id": "U9i2Dsh-ziXr" + "id": "hro2InpHobKk", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "cellView": "form" }, + "cell_type": "code", + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "U9i2Dsh-ziXr", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Eager execution basics" + ] + }, + { + "metadata": { + "id": "Hndw-YcxoOJK", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "
\n", + "\n", + " Run in Google Colab\n", + "\n", + "View source on GitHub
" + ] + }, + { + "metadata": { + "id": "6sILUVbHoSgH", + "colab_type": "text" + }, + "cell_type": "markdown", "source": [ - "# An introduction to TensorFlow\n", - "\n", "This is an introductory tutorial for using TensorFlow. It will cover:\n", "\n", "* Importing required packages\n", "* Creating and using Tensors\n", - "* Using GPU acceleration\n" + "* Using GPU acceleration\n", + "* Datasets" ] }, { - "cell_type": "markdown", "metadata": { - "colab_type": "text", - "id": "z1JcS5iBXMRO" + "id": "z1JcS5iBXMRO", + "colab_type": "text" }, + "cell_type": "markdown", "source": [ "## Import TensorFlow\n", "\n", @@ -30,32 +110,32 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "cellView": "code", + "id": "RlIWhyeLoYnG", + "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, - "colab_type": "code", - "id": "RlIWhyeLoYnG" + "cellView": "code" }, - "outputs": [], + "cell_type": "code", "source": [ "import tensorflow as tf\n", "\n", "tf.enable_eager_execution()" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { - "colab_type": "text", - "id": "H9UySOPLXdaw" + "id": "H9UySOPLXdaw", + "colab_type": "text" }, + "cell_type": "markdown", "source": [ "## Tensors\n", "\n", @@ -63,46 +143,18 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "cellView": "code", + "id": "ngUe237Wt48W", + "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - }, - "height": 125 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 320, - "status": "ok", - "timestamp": 1526420535530, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 + } }, - "id": "ngUe237Wt48W", - "outputId": "b1a1cd60-4eb3-443d-cd6b-68406390784e" + "cellView": "code" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tf.Tensor(3, shape=(), dtype=int32)\n", - "tf.Tensor([4 6], shape=(2,), dtype=int32)\n", - "tf.Tensor(25, shape=(), dtype=int32)\n", - "tf.Tensor(6, shape=(), dtype=int32)\n", - "tf.Tensor(aGVsbG8gd29ybGQ, shape=(), dtype=string)\n", - "tf.Tensor(13, shape=(), dtype=int32)\n" - ] - } - ], + "cell_type": "code", "source": [ "print(tf.add(1, 2))\n", "print(tf.add([1, 2], [3, 4]))\n", @@ -112,66 +164,46 @@ "\n", "# Operator overloading is also supported\n", "print(tf.square(2) + tf.square(3))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { - "colab_type": "text", - "id": "IDY4WsYRhP81" + "id": "IDY4WsYRhP81", + "colab_type": "text" }, + "cell_type": "markdown", "source": [ "Each Tensor has a shape and a datatype" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { + "id": "srYWH1MdJNG7", + "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - }, - "height": 53 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 215, - "status": "ok", - "timestamp": 1526420538162, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "srYWH1MdJNG7", - "outputId": "5e4ac41c-5115-4e50-eba0-42e249c16561" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(1, 2)\n", - "\u003cdtype: 'int32'\u003e\n" - ] + } } - ], + }, + "cell_type": "code", "source": [ "x = tf.matmul([[1]], [[2, 3]])\n", "print(x.shape)\n", "print(x.dtype)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { - "colab_type": "text", - "id": "eBPw8e8vrsom" + "id": "eBPw8e8vrsom", + "colab_type": "text" }, + "cell_type": "markdown", "source": [ "The most obvious differences between NumPy arrays and TensorFlow Tensors are:\n", "\n", @@ -180,11 +212,11 @@ ] }, { - "cell_type": "markdown", "metadata": { - "colab_type": "text", - "id": "Dwi1tdW3JBw6" + "id": "Dwi1tdW3JBw6", + "colab_type": "text" }, + "cell_type": "markdown", "source": [ "### NumPy Compatibility\n", "\n", @@ -197,52 +229,17 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { + "id": "lCUWzso6mbqR", + "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - }, - "height": 251 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 238, - "status": "ok", - "timestamp": 1526420540562, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "lCUWzso6mbqR", - "outputId": "fd0a22bc-8249-49dd-fcbd-63161cc47e46" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "TensorFlow operations convert numpy arrays to Tensors automatically\n", - "tf.Tensor(\n", - "[[ 42. 42. 42.]\n", - " [ 42. 42. 42.]\n", - " [ 42. 42. 42.]], shape=(3, 3), dtype=float64)\n", - "And NumPy operations convert Tensors to numpy arrays automatically\n", - "[[ 43. 43. 43.]\n", - " [ 43. 43. 43.]\n", - " [ 43. 43. 43.]]\n", - "The .numpy() method explicitly converts a Tensor to a numpy array\n", - "[[ 42. 42. 42.]\n", - " [ 42. 42. 42.]\n", - " [ 42. 42. 42.]]\n" - ] + } } - ], + }, + "cell_type": "code", "source": [ "import numpy as np\n", "\n", @@ -258,14 +255,16 @@ "\n", "print(\"The .numpy() method explicitly converts a Tensor to a numpy array\")\n", "print(tensor.numpy())" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { - "colab_type": "text", - "id": "PBNP8yTRfu_X" + "id": "PBNP8yTRfu_X", + "colab_type": "text" }, + "cell_type": "markdown", "source": [ "## GPU acceleration\n", "\n", @@ -273,42 +272,18 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "cellView": "code", + "id": "3Twf_Rw-gQFM", + "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - }, - "height": 53 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 340, - "status": "ok", - "timestamp": 1526420543562, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 + } }, - "id": "3Twf_Rw-gQFM", - "outputId": "2239ae2b-adf3-4895-b1f3-464cf5361d1b" + "cellView": "code" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Is there a GPU available: False\n", - "Is the Tensor on GPU #0: False\n" - ] - } - ], + "cell_type": "code", "source": [ "x = tf.random_uniform([3, 3])\n", "\n", @@ -317,26 +292,28 @@ "\n", "print(\"Is the Tensor on GPU #0: \"),\n", "print(x.device.endswith('GPU:0'))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { - "colab_type": "text", - "id": "vpgYzgVXW2Ud" + "id": "vpgYzgVXW2Ud", + "colab_type": "text" }, + "cell_type": "markdown", "source": [ "### Device Names\n", "\n", - "The `Tensor.device` property provides a fully qualified string name of the device hosting the contents of the Tensor. This name encodes a bunch of details, such as an identifier of the network address of the host on which this program is executing and the device within that host. This is required for distributed execution of TensorFlow programs, but we'll skip that for now. The string will end with `GPU:\u003cN\u003e` if the tensor is placed on the `N`-th tensor on the host." + "The `Tensor.device` property provides a fully qualified string name of the device hosting the contents of the Tensor. This name encodes a bunch of details, such as an identifier of the network address of the host on which this program is executing and the device within that host. This is required for distributed execution of TensorFlow programs, but we'll skip that for now. The string will end with `GPU:` if the tensor is placed on the `N`-th tensor on the host." ] }, { - "cell_type": "markdown", "metadata": { - "colab_type": "text", - "id": "ZWZQCimzuqyP" + "id": "ZWZQCimzuqyP", + "colab_type": "text" }, + "cell_type": "markdown", "source": [ "\n", "\n", @@ -346,41 +323,17 @@ ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { + "id": "RjkNZTuauy-Q", + "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - }, - "height": 53 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 1762, - "status": "ok", - "timestamp": 1526420547562, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "RjkNZTuauy-Q", - "outputId": "2e613293-ccac-4db2-b793-8ceb5b5adcfd" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "On CPU:\n", - "10 loops, best of 3: 35.8 ms per loop\n" - ] + } } - ], + }, + "cell_type": "code", "source": [ "def time_matmul(x):\n", " %timeit tf.matmul(x, x)\n", @@ -398,32 +351,141 @@ " x = tf.random_uniform([1000, 1000])\n", " assert x.device.endswith(\"GPU:0\")\n", " time_matmul(x)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "o1K4dlhhHtQj", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Datasets\n", + "\n", + "This section demonstrates the use of the [`tf.data.Dataset` API](https://www.tensorflow.org/guide/datasets) to build pipelines to feed data to your model. It covers:\n", + "\n", + "* Creating a `Dataset`.\n", + "* Iteration over a `Dataset` with eager execution enabled.\n", + "\n", + "We recommend using the `Dataset`s API for building performant, complex input pipelines from simple, re-usable pieces that will feed your model's training or evaluation loops.\n", + "\n", + "If you're familiar with TensorFlow graphs, the API for constructing the `Dataset` object remains exactly the same when eager execution is enabled, but the process of iterating over elements of the dataset is slightly simpler.\n", + "You can use Python iteration over the `tf.data.Dataset` object and do not need to explicitly create an `tf.data.Iterator` object.\n", + "As a result, the discussion on iterators in the [TensorFlow Guide](https://www.tensorflow.org/guide/datasets) is not relevant when eager execution is enabled." ] }, { + "metadata": { + "id": "zI0fmOynH-Ne", + "colab_type": "text" + }, "cell_type": "markdown", + "source": [ + "### Create a source `Dataset`\n", + "\n", + "Create a _source_ dataset using one of the factory functions like [`Dataset.from_tensors`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensors), [`Dataset.from_tensor_slices`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensor_slices) or using objects that read from files like [`TextLineDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TextLineDataset) or [`TFRecordDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset). See the [TensorFlow Guide](https://www.tensorflow.org/guide/datasets#reading_input_data) for more information." + ] + }, + { "metadata": { - "colab_type": "text", - "id": "YEOJTNiOvnpQ" + "id": "F04fVOHQIBiG", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } }, + "cell_type": "code", "source": [ - "## Next Steps\n", + "ds_tensors = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])\n", "\n", - "In this tutorial we covered the most fundamental concepts in TensorFlow - `Tensor`s, operations, and devices.\n", - "In [the next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/2_gradients.ipynb) we will cover automatic differentiation - a building block required for training many machine learning models like neural networks." + "# Create a CSV file\n", + "import tempfile\n", + "_, filename = tempfile.mkstemp()\n", + "\n", + "with open(filename, 'w') as f:\n", + " f.write(\"\"\"Line 1\n", + "Line 2\n", + "Line 3\n", + " \"\"\")\n", + "\n", + "ds_file = tf.data.TextLineDataset(filename)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "vbxIhC-5IPdf", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Apply transformations\n", + "\n", + "Use the transformations functions like [`map`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map), [`batch`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch), [`shuffle`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle) etc. to apply transformations to the records of the dataset. See the [API documentation for `tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) for details." ] + }, + { + "metadata": { + "id": "uXSDZWE-ISsd", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "ds_tensors = ds_tensors.map(tf.square).shuffle(2).batch(2)\n", + "\n", + "ds_file = ds_file.batch(2)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "A8X1GNfoIZKJ", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Iterate\n", + "\n", + "When eager execution is enabled `Dataset` objects support iteration.\n", + "If you're familiar with the use of `Dataset`s in TensorFlow graphs, note that there is no need for calls to `Dataset.make_one_shot_iterator()` or `get_next()` calls." + ] + }, + { + "metadata": { + "id": "ws-WKRk5Ic6-", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "print('Elements of ds_tensors:')\n", + "for x in ds_tensors:\n", + " print(x)\n", + "\n", + "print('\\nElements in ds_file:')\n", + "for x in ds_file:\n", + " print(x)" + ], + "execution_count": 0, + "outputs": [] } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "default_view": {}, - "name": "TensorFlow: An introduction", - "provenance": [], - "version": "0.3.2", - "views": {} - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index b14ef1df8ff4c660b9b6f2abfd5df6572d10b1e8..07d8788882c2d831dfb041fe7409af51857190bf 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -29,6 +29,7 @@ import tensorflow.contrib.eager as tfe from tensorflow.contrib.eager.python.examples.resnet50 import resnet50 from tensorflow.contrib.summary import summary_test_util from tensorflow.python.client import device_lib +from tensorflow.python.eager import tape def device_and_data_format(): @@ -49,13 +50,21 @@ def random_batch(batch_size, data_format): return images, one_hot -def compute_gradients(model, images, labels): - with tf.GradientTape() as tape: +def compute_gradients(model, images, labels, num_replicas=1): + with tf.GradientTape() as grad_tape: logits = model(images, training=True) loss = tf.losses.softmax_cross_entropy( logits=logits, onehot_labels=labels) tf.contrib.summary.scalar(name='loss', tensor=loss) - return tape.gradient(loss, model.variables) + if num_replicas != 1: + loss /= num_replicas + + # TODO(b/110991947): We can mistakenly trace the gradient call in + # multi-threaded environment. Explicitly disable recording until + # this is fixed. + with tape.stop_recording(): + grads = grad_tape.gradient(loss, model.variables) + return grads def apply_gradients(model, optimizer, gradients): @@ -188,11 +197,14 @@ class ResNet50Benchmarks(tf.test.Benchmark): return (32,) return (16, 32) - def _report(self, label, start, num_iters, device, batch_size, data_format): + def _report(self, label, start, num_iters, device, batch_size, data_format, + num_replicas=1): avg_time = (time.time() - start) / num_iters dev = tf.DeviceSpec.from_string(device).device_type.lower() - name = '%s_%s_batch_%d_%s' % (label, dev, batch_size, data_format) - extras = {'examples_per_sec': batch_size / avg_time} + replica_str = '' if num_replicas == 1 else 'replicas_%d_' % num_replicas + name = '%s_%s_batch_%d_%s%s' % (label, dev, batch_size, + replica_str, data_format) + extras = {'examples_per_sec': (num_replicas * batch_size) / avg_time} self.report_benchmark( iters=num_iters, wall_time=avg_time, name=name, extras=extras) diff --git a/tensorflow/contrib/eager/python/examples/revnet/BUILD b/tensorflow/contrib/eager/python/examples/revnet/BUILD index 432bb546f83932d0e0a465d7af7c641b60d2e564..4f0d46b1bae3760a63b2abe871034bdedf258f07 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/BUILD +++ b/tensorflow/contrib/eager/python/examples/revnet/BUILD @@ -43,6 +43,27 @@ py_library( ], ) +py_library( + name = "resnet_preprocessing", + srcs = ["resnet_preprocessing.py"], + srcs_version = "PY2AND3", + tags = ["local"], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "imagenet_input", + srcs = ["imagenet_input.py"], + srcs_version = "PY2AND3", + tags = ["local"], + deps = [ + ":resnet_preprocessing", + "//tensorflow:tensorflow_py", + ], +) + # Tests cuda_py_test( name = "ops_test", @@ -72,11 +93,13 @@ cuda_py_test( size = "large", srcs = ["revnet_test.py"], additional_deps = [ + ":blocks_test", ":config", ":revnet", "//tensorflow:tensorflow_py", ], tags = [ + "no_pip", # depends on blocks_test, which is not available in pip package "optonly", ], ) @@ -87,7 +110,6 @@ py_library( srcs = ["cifar_input.py"], srcs_version = "PY2AND3", deps = [ - ":revnet", "//tensorflow:tensorflow_py", ], ) @@ -112,3 +134,39 @@ py_binary( "//tensorflow:tensorflow_py", ], ) + +py_binary( + name = "main_estimator", + srcs = ["main_estimator.py"], + srcs_version = "PY2AND3", + deps = [ + ":cifar_input", + ":main", + ":revnet", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "main_estimator_lib", + srcs = ["main_estimator.py"], + srcs_version = "PY2AND3", + deps = [ + ":cifar_input", + ":main", + ":revnet", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "main_estimator_tpu_lib", + srcs = ["main_estimator_tpu.py"], + srcs_version = "PY2AND3", + deps = [ + ":cifar_input", + ":main", + ":revnet", + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/revnet/README.md b/tensorflow/contrib/eager/python/examples/revnet/README.md new file mode 100644 index 0000000000000000000000000000000000000000..822d86e9c7a7e620da3b84ded9af98b1c1d4b701 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/README.md @@ -0,0 +1,112 @@ +# RevNet with TensorFlow eager execution + +This folder contains a TensorFlow eager implementation of the [Reversible Residual Network](https://arxiv.org/pdf/1707.04585.pdf) adapted from the released implementation by the authors. The presented implementation can be ran with both eager and graph execution. The code is considerably simplified with `tf.GradientTape`. Moreover, we reduce the a redundant forward pass in the implementation by the authors. This saves us from using `tf.stop_gradient` and makes the model run faster. + +## Content + +- `revnet.py`: The RevNet model. +- `blocks.py`: The relevant reversible blocks. +- `ops.py`: Auxiliary downsampling operation. +- `cifar_tfrecords.py`: Script to generate the TFRecords for both CIFAR-10 and CIFAR-100. +- `cifar_input.py`: Script to read from TFRecords and generate dataset objects with the `tf.data` API. +- `config.py`: Configuration file for network architectures and training hyperparameters. +- `main.py`: Main training and evaluation script. +- `main_estimator.py`: Script to train RevNet models on CIFAR-10 and CIFAR-100 with the `tf.estimator` API. +- `main_estimator_tpu.py`: Script to train RevNet models on ImageNet with TPU estimators on Cloud TPUs. +- `resnet_preprocessing.py`, `imagenet_input.py`: Boilerplate to read ImageNet data from TFRecords. + +## Train on CIFAR-10/CIFAR-100 +- Make sure you have installed TensorFlow 1.10+ or the latest `tf-nightly` +or `tf-nightly-gpu` pip package in order to access the eager execution feature. + +- First run + +```bash +python cifar_tfrecords.py --data_dir ${PWD}/cifar +``` +to download the cifar dataset and convert them +to TFRecords. This produces TFRecord files for both CIFAR-10 and CIFAR-100. + +- To train a model, run + +```bash +python main.py --data_dir ${PWD}/cifar +``` + +- Optional arguments for `main.py` include + - `train_dir`: Directory to store eventfiles and checkpoints. + - `restore`: Restore the latest checkpoint. + - `validate`: Use validation set for training monitoring. + - `dataset`: Use either `cifar-10` or `cifar-100`. + - `config`: RevNet configuration. + - `use_defun`: Use `tfe.defun` to boost performance. + +- To train a model with estimators in graph execution, run + +```bash +python main_estimator.py --data_dir ${PWD}/cifar +``` +To ensure our code works properly when using the Keras model in an estimator, +`tf-nightly` or `tf-nightly-gpu` is highly recommended as of August 2018. + +- Optional arguments for `main.py` include + - `model_dir`: Directory to store eventfiles and checkpoints. + - `dataset`: Use either `cifar-10` or `cifar-100`. + - `config`: RevNet configuration. + - `export`: Export the model for serving if True. + +## Speed up with `tfe.defun` +To ensure that `tf.contrib.eager.defun` in our code works properly with all +part of the model during training, the latest `tf-nightly` or `tf-nightly-gpu` +is highly recommended as of August 2018. + +Even though the speed difference between pure eager execution and graph execution is noticeable, +the difference between fully "defunned" model training and graph +training is negligible. + +## Train on ImageNet with Cloud TPUs +The standard way to train models on Cloud TPUs is via TPU estimators and graph +execution. Models built with the `tf.keras` API are fully compatible with TPU estimators. +To ensure our code works properly in this setting, +`tf-nightly` or `tf-nightly-gpu` is highly recommended as of August 2018. + +### Setup a Google Cloud project + +Follow the instructions at the [Quickstart Guide](https://cloud.google.com/tpu/docs/quickstart) +to get a GCE VM with access to Cloud TPU. + +To run this model, you will need: + +* A GCE VM instance with an associated Cloud TPU resource +* A GCS bucket to store your training checkpoints +* (Optional): The ImageNet training and validation data preprocessed into + TFRecord format, and stored in GCS. + +### Format the data + +The data is expected to be formatted in TFRecord format, as generated by [this +script](https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py). + +If you do not have ImageNet dataset prepared, you can use a randomly generated +fake dataset to test the model. It is located at +`gs://cloud-tpu-test-datasets/fake_imagenet`. + +### Start training + +Train the model by executing the following command (substituting the appropriate +values): + +```bash +python main_estimator_tpu.py \ + --tpu=$TPU_NAME \ + --data_dir=$DATA_DIR \ + --model_dir=$MODEL_DIR +``` + +## Performance +- RevNet-38 achieves >92% and >71% accuracy on CIFAR-10 and CIFAR-100 respectively. +- RevNet-56 achieves <26% top-1 error rate on ImageNet. + +## Reference +The Reversible Residual Network: Backpropagation Without Storing Activations. +Aidan N. Gomez, Mengye Ren, Raquel Urtasun, Roger B. Grosse. Neural Information Processing Systems (NIPS), 2017. diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py index 74c1825a49a702c8c4cc8ec04ebb87917bca380d..f61354bc38a9fcb941f186cac4eac8097eea742d 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/blocks.py +++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py @@ -24,7 +24,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import six +import functools +import operator + import tensorflow as tf from tensorflow.contrib.eager.python.examples.revnet import ops @@ -44,8 +46,9 @@ class RevBlock(tf.keras.Model): batch_norm_first=False, data_format="channels_first", bottleneck=False, - fused=True): - """Initialize RevBlock. + fused=True, + dtype=tf.float32): + """Initialization. Args: n_res: number of residual blocks @@ -56,6 +59,7 @@ class RevBlock(tf.keras.Model): data_format: tensor data format, "NCHW"/"NHWC" bottleneck: use bottleneck residual if True fused: use fused batch normalization if True + dtype: float16, float32, or float64 """ super(RevBlock, self).__init__() self.blocks = tf.contrib.checkpoint.List() @@ -69,7 +73,8 @@ class RevBlock(tf.keras.Model): batch_norm_first=curr_batch_norm_first, data_format=data_format, bottleneck=bottleneck, - fused=fused) + fused=fused, + dtype=dtype) self.blocks.append(block) if data_format == "channels_first": @@ -86,45 +91,27 @@ class RevBlock(tf.keras.Model): h = block(h, training=training) return h - def backward_grads_and_vars(self, x, y, dy, training=True): + def backward_grads(self, x, y, dy, training=True): """Apply reversible block backward to outputs.""" grads_all = [] - vars_all = [] - for i in reversed(range(len(self.blocks))): block = self.blocks[i] if i == 0: - y_inv = x + # First block usually contains downsampling that can't be reversed + dy, grads = block.backward_grads_with_downsample( + x, y, dy, training=True) else: - # Don't update running stats when reconstructing activations - vars_and_vals = block.get_moving_stats() - y_inv = block.backward(y, training=training) - block.restore_moving_stats(vars_and_vals) - - # Update running stats when computing gradients during training - dy, grads, vars_ = block.backward_grads_and_vars( - y_inv, dy, training=training) + y, dy, grads = block.backward_grads(y, dy, training=training) + grads_all = grads + grads_all - grads_all += grads - vars_all += vars_ - - return dy, grads_all, vars_all + return dy, grads_all class _Residual(tf.keras.Model): """Single residual block contained in a _RevBlock. Each `_Residual` object has two _ResidualInner objects, corresponding to the `F` and `G` functions in the paper. - - Args: - filters: output filter size - strides: length 2 list/tuple of integers for height and width strides - input_shape: length 3 list/tuple of integers - batch_norm_first: whether to apply activation and batch norm before conv - data_format: tensor data format, "NCHW"/"NHWC", - bottleneck: use bottleneck residual if True - fused: use fused batch normalization if True """ def __init__(self, @@ -134,7 +121,20 @@ class _Residual(tf.keras.Model): batch_norm_first=True, data_format="channels_first", bottleneck=False, - fused=True): + fused=True, + dtype=tf.float32): + """Initialization. + + Args: + filters: output filter size + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC", + bottleneck: use bottleneck residual if True + fused: use fused batch normalization if True + dtype: float16, float32, or float64 + """ super(_Residual, self).__init__() self.filters = filters @@ -156,21 +156,21 @@ class _Residual(tf.keras.Model): input_shape=f_input_shape, batch_norm_first=batch_norm_first, data_format=data_format, - fused=fused) + fused=fused, + dtype=dtype) self.g = factory( filters=filters // 2, strides=(1, 1), input_shape=g_input_shape, batch_norm_first=batch_norm_first, data_format=data_format, - fused=fused) + fused=fused, + dtype=dtype) - def call(self, x, training=True, concat=True): + def call(self, x, training=True): """Apply residual block to inputs.""" - - x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis) + x1, x2 = x f_x2 = self.f(x2, training=training) - # TODO(lxuechen): Replace with simpler downsampling x1_down = ops.downsample( x1, self.filters // 2, self.strides, axis=self.axis) x2_down = ops.downsample( @@ -178,185 +178,327 @@ class _Residual(tf.keras.Model): y1 = f_x2 + x1_down g_y1 = self.g(y1, training=training) y2 = g_y1 + x2_down - if not concat: # For correct backward grads - return y1, y2 - return tf.concat([y1, y2], axis=self.axis) + return y1, y2 - def backward(self, y, training=True): - """Reconstruct inputs from outputs; only valid when stride 1.""" + def backward_grads(self, y, dy, training=True): + """Manually compute backward gradients given input and output grads.""" + dy1, dy2 = dy + y1, y2 = y + + with tf.GradientTape() as gtape: + gtape.watch(y1) + gy1 = self.g(y1, training=training) + grads_combined = gtape.gradient( + gy1, [y1] + self.g.trainable_variables, output_gradients=dy2) + dg = grads_combined[1:] + dx1 = dy1 + grads_combined[0] + # This doesn't affect eager execution, but improves memory efficiency with + # graphs + with tf.control_dependencies(dg + [dx1]): + x2 = y2 - gy1 + + with tf.GradientTape() as ftape: + ftape.watch(x2) + fx2 = self.f(x2, training=training) + grads_combined = ftape.gradient( + fx2, [x2] + self.f.trainable_variables, output_gradients=dx1) + df = grads_combined[1:] + dx2 = dy2 + grads_combined[0] + # Same behavior as above + with tf.control_dependencies(df + [dx2]): + x1 = y1 - fx2 + + x = x1, x2 + dx = dx1, dx2 + grads = df + dg - assert self.strides == (1, 1) + return x, dx, grads - y1, y2 = tf.split(y, num_or_size_splits=2, axis=self.axis) - g_y1 = self.g(y1, training=training) - x2 = y2 - g_y1 - f_x2 = self.f(x2, training=training) - x1 = y1 - f_x2 + def backward_grads_with_downsample(self, x, y, dy, training=True): + """Manually compute backward gradients given input and output grads.""" + # Splitting this from `backward_grads` for better readability + x1, x2 = x + y1, _ = y + dy1, dy2 = dy + + with tf.GradientTape() as gtape: + gtape.watch(y1) + gy1 = self.g(y1, training=training) + grads_combined = gtape.gradient( + gy1, [y1] + self.g.trainable_variables, output_gradients=dy2) + dg = grads_combined[1:] + dz1 = dy1 + grads_combined[0] + + # dx1 need one more step to backprop through downsample + with tf.GradientTape() as x1tape: + x1tape.watch(x1) + z1 = ops.downsample(x1, self.filters // 2, self.strides, axis=self.axis) + dx1 = x1tape.gradient(z1, x1, output_gradients=dz1) + + with tf.GradientTape() as ftape: + ftape.watch(x2) + fx2 = self.f(x2, training=training) + grads_combined = ftape.gradient( + fx2, [x2] + self.f.trainable_variables, output_gradients=dz1) + dx2, df = grads_combined[0], grads_combined[1:] + + # dx2 need one more step to backprop through downsample + with tf.GradientTape() as x2tape: + x2tape.watch(x2) + z2 = ops.downsample(x2, self.filters // 2, self.strides, axis=self.axis) + dx2 += x2tape.gradient(z2, x2, output_gradients=dy2) + + dx = dx1, dx2 + grads = df + dg - return tf.concat([x1, x2], axis=self.axis) + return dx, grads - def backward_grads_and_vars(self, x, dy, training=True): - """Manually compute backward gradients given input and output grads.""" - with tf.GradientTape(persistent=True) as tape: - x = tf.identity(x) # TODO(lxuechen): Remove after b/110264016 is fixed - x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis) - tape.watch([x1, x2]) - # Stitch back x for `call` so tape records correct grads - x = tf.concat([x1, x2], axis=self.axis) - dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=self.axis) - y1, y2 = self.call(x, training=training, concat=False) - x2_down = ops.downsample( - x2, self.filters // 2, self.strides, axis=self.axis) - - grads_combined = tape.gradient( - y2, [y1] + self.g.trainable_variables, output_gradients=[dy2]) - dy2_y1, dg = grads_combined[0], grads_combined[1:] - dy1_plus = dy2_y1 + dy1 - - grads_combined = tape.gradient( - y1, [x1, x2] + self.f.trainable_variables, output_gradients=[dy1_plus]) - dx1, dx2, df = grads_combined[0], grads_combined[1], grads_combined[2:] - dx2 += tape.gradient(x2_down, [x2], output_gradients=[dy2])[0] - - del tape +# Ideally, the following should be wrapped in `tf.keras.Sequential`, however +# there are subtle issues with its placeholder insertion policy and batch norm +class _BottleneckResidualInner(tf.keras.Model): + """Single bottleneck residual inner function contained in _Resdual. - grads = df + dg - vars_ = self.f.trainable_variables + self.g.trainable_variables + Corresponds to the `F`/`G` functions in the paper. + Suitable for training on ImageNet dataset. + """ + + def __init__(self, + filters, + strides, + input_shape, + batch_norm_first=True, + data_format="channels_first", + fused=True, + dtype=tf.float32): + """Initialization. - return tf.concat([dx1, dx2], axis=self.axis), grads, vars_ + Args: + filters: output filter size + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC" + fused: use fused batch normalization if True + dtype: float16, float32, or float64 + """ + super(_BottleneckResidualInner, self).__init__() + axis = 1 if data_format == "channels_first" else 3 + if batch_norm_first: + self.batch_norm_0 = tf.keras.layers.BatchNormalization( + axis=axis, input_shape=input_shape, fused=fused, dtype=dtype) + self.conv2d_1 = tf.keras.layers.Conv2D( + filters=filters // 4, + kernel_size=1, + strides=strides, + input_shape=input_shape, + data_format=data_format, + use_bias=False, + padding="SAME", + dtype=dtype) + + self.batch_norm_1 = tf.keras.layers.BatchNormalization( + axis=axis, fused=fused, dtype=dtype) + self.conv2d_2 = tf.keras.layers.Conv2D( + filters=filters // 4, + kernel_size=3, + strides=(1, 1), + data_format=data_format, + use_bias=False, + padding="SAME", + dtype=dtype) + + self.batch_norm_2 = tf.keras.layers.BatchNormalization( + axis=axis, fused=fused, dtype=dtype) + self.conv2d_3 = tf.keras.layers.Conv2D( + filters=filters, + kernel_size=1, + strides=(1, 1), + data_format=data_format, + use_bias=False, + padding="SAME", + dtype=dtype) - def get_moving_stats(self): - vars_and_vals = {} + self.batch_norm_first = batch_norm_first - def _is_moving_var(v): # pylint: disable=invalid-name - n = v.name - return n.endswith("moving_mean:0") or n.endswith("moving_variance:0") + def call(self, x, training=True): + net = x + if self.batch_norm_first: + net = self.batch_norm_0(net, training=training) + net = tf.nn.relu(net) + net = self.conv2d_1(net) - for v in filter(_is_moving_var, self.f.variables + self.g.variables): - vars_and_vals[v] = v.read_value() + net = self.batch_norm_1(net, training=training) + net = tf.nn.relu(net) + net = self.conv2d_2(net) - return vars_and_vals + net = self.batch_norm_2(net, training=training) + net = tf.nn.relu(net) + net = self.conv2d_3(net) - def restore_moving_stats(self, vars_and_vals): - for var_, val in six.iteritems(vars_and_vals): - var_.assign(val) + return net -def _BottleneckResidualInner(filters, - strides, - input_shape, - batch_norm_first=True, - data_format="channels_first", - fused=True): - """Single bottleneck residual inner function contained in _Resdual. +class _ResidualInner(tf.keras.Model): + """Single residual inner function contained in _ResdualBlock. Corresponds to the `F`/`G` functions in the paper. - Suitable for training on ImageNet dataset. + """ - Args: - filters: output filter size - strides: length 2 list/tuple of integers for height and width strides - input_shape: length 3 list/tuple of integers - batch_norm_first: whether to apply activation and batch norm before conv - data_format: tensor data format, "NCHW"/"NHWC" - fused: use fused batch normalization if True + def __init__(self, + filters, + strides, + input_shape, + batch_norm_first=True, + data_format="channels_first", + fused=True, + dtype=tf.float32): + """Initialization. - Returns: - A keras model - """ + Args: + filters: output filter size + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC" + fused: use fused batch normalization if True + dtype: float16, float32, or float64 + """ + super(_ResidualInner, self).__init__() + axis = 1 if data_format == "channels_first" else 3 + if batch_norm_first: + self.batch_norm_0 = tf.keras.layers.BatchNormalization( + axis=axis, input_shape=input_shape, fused=fused, dtype=dtype) + self.conv2d_1 = tf.keras.layers.Conv2D( + filters=filters, + kernel_size=3, + strides=strides, + input_shape=input_shape, + data_format=data_format, + use_bias=False, + padding="SAME", + dtype=dtype) + + self.batch_norm_1 = tf.keras.layers.BatchNormalization( + axis=axis, fused=fused, dtype=dtype) + self.conv2d_2 = tf.keras.layers.Conv2D( + filters=filters, + kernel_size=3, + strides=(1, 1), + data_format=data_format, + use_bias=False, + padding="SAME", + dtype=dtype) - axis = 1 if data_format == "channels_first" else 3 - model = tf.keras.Sequential() - if batch_norm_first: - model.add( - tf.keras.layers.BatchNormalization( - axis=axis, input_shape=input_shape, fused=fused)) - model.add(tf.keras.layers.Activation("relu")) - model.add( - tf.keras.layers.Conv2D( - filters=filters // 4, - kernel_size=1, - strides=strides, - input_shape=input_shape, - data_format=data_format, - use_bias=False, - padding="SAME")) - - model.add(tf.keras.layers.BatchNormalization(axis=axis, fused=fused)) - model.add(tf.keras.layers.Activation("relu")) - model.add( - tf.keras.layers.Conv2D( - filters=filters // 4, - kernel_size=3, - strides=(1, 1), - data_format=data_format, - use_bias=False, - padding="SAME")) - - model.add(tf.keras.layers.BatchNormalization(axis=axis, fused=fused)) - model.add(tf.keras.layers.Activation("relu")) - model.add( - tf.keras.layers.Conv2D( - filters=filters, - kernel_size=1, - strides=(1, 1), - data_format=data_format, - use_bias=False, - padding="SAME")) + self.batch_norm_first = batch_norm_first - return model + def call(self, x, training=True): + net = x + if self.batch_norm_first: + net = self.batch_norm_0(net, training=training) + net = tf.nn.relu(net) + net = self.conv2d_1(net) + net = self.batch_norm_1(net, training=training) + net = tf.nn.relu(net) + net = self.conv2d_2(net) -def _ResidualInner(filters, - strides, - input_shape, - batch_norm_first=True, - data_format="channels_first", - fused=True): - """Single residual inner function contained in _ResdualBlock. + return net - Corresponds to the `F`/`G` functions in the paper. - Args: - filters: output filter size - strides: length 2 list/tuple of integers for height and width strides - input_shape: length 3 list/tuple of integers - batch_norm_first: whether to apply activation and batch norm before conv - data_format: tensor data format, "NCHW"/"NHWC" - fused: use fused batch normalization if True +class InitBlock(tf.keras.Model): + """Initial block of RevNet.""" - Returns: - A keras model - """ + def __init__(self, config): + """Initialization. - axis = 1 if data_format == "channels_first" else 3 - model = tf.keras.Sequential() - if batch_norm_first: - model.add( - tf.keras.layers.BatchNormalization( - axis=axis, input_shape=input_shape, fused=fused)) - model.add(tf.keras.layers.Activation("relu")) - model.add( - tf.keras.layers.Conv2D( - filters=filters, - kernel_size=3, - strides=strides, - input_shape=input_shape, - data_format=data_format, - use_bias=False, - padding="SAME")) - - model.add(tf.keras.layers.BatchNormalization(axis=axis, fused=fused)) - model.add(tf.keras.layers.Activation("relu")) - model.add( - tf.keras.layers.Conv2D( - filters=filters, - kernel_size=3, - strides=(1, 1), - data_format=data_format, - use_bias=False, - padding="SAME")) + Args: + config: tf.contrib.training.HParams object; specifies hyperparameters + """ + super(InitBlock, self).__init__() + self.config = config + self.axis = 1 if self.config.data_format == "channels_first" else 3 + self.conv2d = tf.keras.layers.Conv2D( + filters=self.config.init_filters, + kernel_size=self.config.init_kernel, + strides=(self.config.init_stride, self.config.init_stride), + data_format=self.config.data_format, + use_bias=False, + padding="SAME", + input_shape=self.config.input_shape, + dtype=self.config.dtype) + self.batch_norm = tf.keras.layers.BatchNormalization( + axis=self.axis, fused=self.config.fused, dtype=self.config.dtype) + self.activation = tf.keras.layers.Activation("relu") + + if self.config.init_max_pool: + self.max_pool = tf.keras.layers.MaxPooling2D( + pool_size=(3, 3), + strides=(2, 2), + padding="SAME", + data_format=self.config.data_format, + dtype=self.config.dtype) + + def call(self, x, training=True): + net = x + net = self.conv2d(net) + net = self.batch_norm(net, training=training) + net = self.activation(net) + + if self.config.init_max_pool: + net = self.max_pool(net) + + return tf.split(net, num_or_size_splits=2, axis=self.axis) + + +class FinalBlock(tf.keras.Model): + """Final block of RevNet.""" + + def __init__(self, config): + """Initialization. + + Args: + config: tf.contrib.training.HParams object; specifies hyperparameters - return model + Raises: + ValueError: Unsupported data format + """ + super(FinalBlock, self).__init__() + self.config = config + self.axis = 1 if self.config.data_format == "channels_first" else 3 + + f = self.config.filters[-1] # Number of filters + r = functools.reduce(operator.mul, self.config.strides, 1) # Reduce ratio + r *= self.config.init_stride + if self.config.init_max_pool: + r *= 2 + + if self.config.data_format == "channels_first": + w, h = self.config.input_shape[1], self.config.input_shape[2] + input_shape = (f, w // r, h // r) + elif self.config.data_format == "channels_last": + w, h = self.config.input_shape[0], self.config.input_shape[1] + input_shape = (w // r, h // r, f) + else: + raise ValueError("Data format should be either `channels_first`" + " or `channels_last`") + self.batch_norm = tf.keras.layers.BatchNormalization( + axis=self.axis, + input_shape=input_shape, + fused=self.config.fused, + dtype=self.config.dtype) + self.activation = tf.keras.layers.Activation("relu") + self.global_avg_pool = tf.keras.layers.GlobalAveragePooling2D( + data_format=self.config.data_format, dtype=self.config.dtype) + self.dense = tf.keras.layers.Dense( + self.config.n_classes, dtype=self.config.dtype) + + def call(self, x, training=True): + net = tf.concat(x, axis=self.axis) + net = self.batch_norm(net, training=training) + net = self.activation(net) + net = self.global_avg_pool(net) + net = self.dense(net) + + return net diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py index a28ca6e3e076ef1d52ab5a34e5559536cf5d52cc..9ff6b605b912772a92ab9e07a0ba5b9325030e43 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py +++ b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py @@ -22,6 +22,27 @@ import tensorflow as tf from tensorflow.contrib.eager.python.examples.revnet import blocks +def compute_degree(g1, g2, eps=1e-7): + """Compute the degree between two vectors using their usual inner product.""" + + def _dot(u, v): + return tf.reduce_sum(u * v) + + g1_norm = tf.sqrt(_dot(g1, g1)) + g2_norm = tf.sqrt(_dot(g2, g2)) + if g1_norm.numpy() == 0 and g2_norm.numpy() == 0: + cosine = 1. - eps + else: + g1_norm = 1. if g1_norm.numpy() == 0 else g1_norm + g2_norm = 1. if g2_norm.numpy() == 0 else g2_norm + cosine = _dot(g1, g2) / g1_norm / g2_norm + # Restrict to arccos range + cosine = tf.minimum(tf.maximum(cosine, eps - 1.), 1. - eps) + degree = tf.acos(cosine) * 180. / 3.141592653589793 + + return degree + + def _validate_block_call_channels_last(block_factory, test): """Generic testing function for `channels_last` data format. @@ -33,30 +54,30 @@ def _validate_block_call_channels_last(block_factory, test): test: tf.test.TestCase object """ with tf.device("/cpu:0"): # NHWC format - input_shape = (224, 224, 32) + input_shape = (8, 8, 128) data_shape = (16,) + input_shape x = tf.random_normal(shape=data_shape) # Stride 1 block = block_factory( - filters=64, + filters=128, strides=(1, 1), input_shape=input_shape, data_format="channels_last") y_tr, y_ev = block(x, training=True), block(x, training=False) test.assertEqual(y_tr.shape, y_ev.shape) - test.assertEqual(y_ev.shape, (16, 224, 224, 64)) + test.assertEqual(y_ev.shape, (16, 8, 8, 128)) test.assertNotAllClose(y_tr, y_ev) # Stride of 2 block = block_factory( - filters=64, + filters=128, strides=(2, 2), input_shape=input_shape, data_format="channels_last") y_tr, y_ev = block(x, training=True), block(x, training=False) test.assertEqual(y_tr.shape, y_ev.shape) - test.assertEqual(y_ev.shape, (16, 112, 112, 64)) + test.assertEqual(y_ev.shape, (16, 4, 4, 128)) test.assertNotAllClose(y_tr, y_ev) @@ -74,267 +95,174 @@ def _validate_block_call_channels_first(block_factory, test): test.skipTest("GPU not available") with tf.device("/gpu:0"): # Default NCHW format - input_shape = (32, 224, 224) + input_shape = (128, 8, 8) data_shape = (16,) + input_shape x = tf.random_normal(shape=data_shape) # Stride of 1 - block = block_factory(filters=64, strides=(1, 1), input_shape=input_shape) + block = block_factory(filters=128, strides=(1, 1), input_shape=input_shape) y_tr, y_ev = block(x, training=True), block(x, training=False) test.assertEqual(y_tr.shape, y_ev.shape) - test.assertEqual(y_ev.shape, (16, 64, 224, 224)) + test.assertEqual(y_ev.shape, (16, 128, 8, 8)) test.assertNotAllClose(y_tr, y_ev) # Stride of 2 - block = block_factory(filters=64, strides=(2, 2), input_shape=input_shape) + block = block_factory(filters=128, strides=(2, 2), input_shape=input_shape) y_tr, y_ev = block(x, training=True), block(x, training=False) test.assertEqual(y_tr.shape, y_ev.shape) - test.assertEqual(y_ev.shape, (16, 64, 112, 112)) + test.assertEqual(y_ev.shape, (16, 128, 4, 4)) test.assertNotAllClose(y_tr, y_ev) class RevBlockTest(tf.test.TestCase): - def test_call_channels_first(self): - """Test `call` function with `channels_first` data format.""" - if not tf.test.is_gpu_available(): - self.skipTest("GPU not available") + def _check_grad_angle(self, grads, grads_true, atol=1e0): + """Check the angle between two list of vectors are all close.""" + for g1, g2 in zip(grads, grads_true): + degree = compute_degree(g1, g2) + self.assertLessEqual(degree, atol) - with tf.device("/gpu:0"): # Default NCHW format - input_shape = (32, 224, 224) - data_shape = (16,) + input_shape - x = tf.random_normal(shape=data_shape) - - # Stride of 1 - block = blocks.RevBlock( - n_res=3, filters=64, strides=(1, 1), input_shape=input_shape) - y_tr, y_ev = block(x, training=True), block(x, training=False) - self.assertEqual(y_tr.shape, y_ev.shape) - self.assertEqual(y_ev.shape, (16, 64, 224, 224)) - self.assertNotAllClose(y_tr, y_ev) - - # Stride of 2 - block = blocks.RevBlock( - n_res=3, filters=64, strides=(2, 2), input_shape=input_shape) - y_tr, y_ev = block(x, training=True), block(x, training=False) - self.assertEqual(y_tr.shape, y_ev.shape) - self.assertEqual(y_ev.shape, [16, 64, 112, 112]) - self.assertNotAllClose(y_tr, y_ev) - - def test_call_channels_last(self): - """Test `call` function with `channels_last` data format.""" - with tf.device("/cpu:0"): # NHWC format - input_shape = (224, 224, 32) - data_shape = (16,) + input_shape - x = tf.random_normal(shape=data_shape) - - # Stride 1 - block = blocks.RevBlock( - n_res=3, - filters=64, - strides=(1, 1), - input_shape=input_shape, - data_format="channels_last") - y_tr, y_ev = block(x, training=True), block(x, training=False) - self.assertEqual(y_tr.shape, y_ev.shape) - self.assertEqual(y_ev.shape, (16, 224, 224, 64)) - self.assertNotAllClose(y_tr, y_ev) - - # Stride of 2 - block = blocks.RevBlock( - n_res=3, - filters=64, - strides=(2, 2), - input_shape=input_shape, - data_format="channels_last") - y_tr, y_ev = block(x, training=True), block(x, training=False) - self.assertEqual(y_tr.shape, y_ev.shape) - self.assertEqual(y_ev.shape, (16, 112, 112, 64)) - self.assertNotAllClose(y_tr, y_ev) - - def test_backward_grads_and_vars_channels_first(self): + def test_backward_grads_channels_first(self): """Test `backward` function with `channels_first` data format.""" if not tf.test.is_gpu_available(): self.skipTest("GPU not available") with tf.device("/gpu:0"): # Default NCHW format - input_shape = (32, 224, 224) - data_shape = (16,) + input_shape - x = tf.random_normal(shape=data_shape) - # Stride 1 - y = tf.random_normal(shape=data_shape) - dy = tf.random_normal(shape=data_shape) - block = blocks.RevBlock( - n_res=3, filters=32, strides=(1, 1), input_shape=input_shape) - dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy) - self.assertEqual(dy.shape, x.shape) - self.assertTrue(isinstance(grads, list)) - self.assertTrue(isinstance(vars_, list)) - - # Stride 2 - y = tf.random_normal(shape=(16, 32, 112, 112)) - dy = tf.random_normal(shape=(16, 32, 112, 112)) - block = blocks.RevBlock( - n_res=3, filters=32, strides=(2, 2), input_shape=input_shape) - dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy) - self.assertEqual(dy.shape, x.shape) - self.assertTrue(isinstance(grads, list)) - self.assertTrue(isinstance(vars_, list)) - - def test_backward_grads_and_vars_channels_last(self): - """Test `backward` function with `channels_last` data format.""" - with tf.device("/cpu:0"): # NHWC format - input_shape = (224, 224, 32) + input_shape = (128, 8, 8) data_shape = (16,) + input_shape - x = tf.random_normal(shape=data_shape) - - # Stride 1 - y = tf.random_normal(shape=data_shape) - dy = tf.random_normal(shape=data_shape) + x = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1) block = blocks.RevBlock( n_res=3, - filters=32, + filters=128, strides=(1, 1), input_shape=input_shape, - data_format="channels_last") - dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy) - self.assertEqual(dy.shape, x.shape) - self.assertTrue(isinstance(grads, list)) - self.assertTrue(isinstance(vars_, list)) + fused=False, + dtype=tf.float64) + with tf.GradientTape() as tape: + tape.watch(x) + x1, x2 = tf.split(x, num_or_size_splits=2, axis=1) + y1, y2 = block((x1, x2), training=True) + y = tf.concat((y1, y2), axis=1) + # Compute grads from reconstruction + (dx1, dx2), dw = block.backward_grads( + x=(x1, x2), y=(y1, y2), dy=(dy1, dy2), training=True) + dx = tf.concat((dx1, dx2), axis=1) + vars_ = block.trainable_variables + # Compute true grads + grads = tape.gradient(y, [x] + vars_, output_gradients=dy) + dx_true, dw_true = grads[0], grads[1:] + self.assertAllClose(dx_true, dx) + self.assertAllClose(dw_true, dw) + self._check_grad_angle(dx_true, dx) + self._check_grad_angle(dw_true, dw) # Stride 2 - y = tf.random_normal(shape=(16, 112, 112, 32)) - dy = tf.random_normal(shape=(16, 112, 112, 32)) + x = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy = tf.random_normal(shape=(16, 128, 4, 4), dtype=tf.float64) + dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1) block = blocks.RevBlock( n_res=3, - filters=32, + filters=128, strides=(2, 2), input_shape=input_shape, - data_format="channels_last") - dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy) - self.assertEqual(dy.shape, x.shape) - self.assertTrue(isinstance(grads, list)) - self.assertTrue(isinstance(vars_, list)) - - -class _ResidualTest(tf.test.TestCase): + fused=False, + dtype=tf.float64) + with tf.GradientTape() as tape: + tape.watch(x) + x1, x2 = tf.split(x, num_or_size_splits=2, axis=1) + y1, y2 = block((x1, x2), training=True) + y = tf.concat((y1, y2), axis=1) + # Compute grads from reconstruction + (dx1, dx2), dw = block.backward_grads( + x=(x1, x2), y=(y1, y2), dy=(dy1, dy2), training=True) + dx = tf.concat((dx1, dx2), axis=1) + vars_ = block.trainable_variables + # Compute true grads + grads = tape.gradient(y, [x] + vars_, output_gradients=dy) + dx_true, dw_true = grads[0], grads[1:] + self.assertAllClose(dx_true, dx) + self.assertAllClose(dw_true, dw) + self._check_grad_angle(dx_true, dx) + self._check_grad_angle(dw_true, dw) + + def test_backward_grads_with_nativepy(self): + if not tf.test.is_gpu_available(): + self.skipTest("GPU not available") - def test_call(self): - """Test `call` function. + input_shape = (128, 8, 8) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1) + block = blocks.RevBlock( + n_res=3, + filters=128, + strides=(1, 1), + input_shape=input_shape, + fused=False, + dtype=tf.float64) + with tf.GradientTape() as tape: + tape.watch(x) + x1, x2 = tf.split(x, num_or_size_splits=2, axis=1) + y1, y2 = block((x1, x2), training=True) + y = tf.concat((y1, y2), axis=1) - Varying downsampling and data format options. - """ + # Compute true grads + dx_true = tape.gradient(y, x, output_gradients=dy) - _validate_block_call_channels_first(blocks._Residual, self) - _validate_block_call_channels_last(blocks._Residual, self) + # Compute grads from reconstruction + (dx1, dx2), _ = block.backward_grads( + x=(x1, x2), y=(y1, y2), dy=(dy1, dy2), training=True) + dx = tf.concat((dx1, dx2), axis=1) - def test_backward_channels_first(self): - """Test `backward` function with `channels_first` data format.""" - if not tf.test.is_gpu_available(): - self.skipTest("GPU not available") + thres = 1e-5 + diff_abs = tf.reshape(abs(dx - dx_true), [-1]) + assert all(diff_abs < thres) - with tf.device("/gpu:0"): # Default NCHW format - input_shape = (16, 224, 224) - data_shape = (16,) + input_shape - x = tf.random_normal(shape=data_shape) - residual = blocks._Residual( - filters=16, strides=(1, 1), input_shape=input_shape) - - y_tr, y_ev = residual(x, training=True), residual(x, training=False) - x_ = residual.backward(y_ev, training=False) - self.assertAllClose(x, x_, rtol=1e-1, atol=1e-1) - x_ = residual.backward(y_tr, training=True) # This updates moving avg - self.assertAllClose(x, x_, rtol=1e-1, atol=1e-1) - - def test_backward_channels_last(self): - """Test `backward` function with `channels_last` data format.""" - with tf.device("/cpu:0"): # NHWC format - input_shape = (224, 224, 16) - data_shape = (16,) + input_shape - x = tf.random_normal(shape=data_shape) - residual = blocks._Residual( - filters=16, - strides=(1, 1), - input_shape=input_shape, - data_format="channels_last") - y_tr, y_ev = residual(x, training=True), residual(x, training=False) - x_ = residual.backward(y_ev, training=False) - self.assertAllClose(x, x_, rtol=1e-1, atol=1e-1) - x_ = residual.backward(y_tr, training=True) # This updates moving avg - self.assertAllClose(x, x_, rtol=1e-1, atol=1e-1) +class _ResidualTest(tf.test.TestCase): - def test_backward_grads_and_vars_channels_first(self): + def test_backward_grads_channels_first(self): """Test `backward_grads` function with `channels_first` data format.""" if not tf.test.is_gpu_available(): self.skipTest("GPU not available") with tf.device("/gpu:0"): # Default NCHW format - input_shape = (16, 224, 224) - data_shape = (16,) + input_shape - x = tf.random_normal(shape=data_shape) - dy = tf.random_normal(shape=data_shape) - residual = blocks._Residual( - filters=16, strides=(1, 1), input_shape=input_shape) - - vars_and_vals = residual.get_moving_stats() - dx_tr, grads_tr, vars_tr = residual.backward_grads_and_vars( - x, dy=dy, training=True) - dx_ev, grads_ev, vars_ev = residual.backward_grads_and_vars( - x, dy=dy, training=False) - self.assertNotAllClose(dx_tr, dx_ev) - self.assertTrue(isinstance(grads_tr, list)) - self.assertTrue(isinstance(grads_ev, list)) - self.assertTrue(isinstance(vars_tr, list)) - self.assertTrue(isinstance(vars_ev, list)) - for grad_tr, var_tr, grad_ev, var_ev in zip(grads_tr, vars_tr, grads_ev, - vars_ev): - self.assertEqual(grad_tr.shape, grad_ev.shape) - self.assertEqual(var_tr.shape, var_ev.shape) - self.assertEqual(grad_tr.shape, var_tr.shape) - - # Compare against the true gradient computed by the tape - residual.restore_moving_stats(vars_and_vals) - with tf.GradientTape(persistent=True) as tape: - tape.watch(x) - y = residual(x, training=True) - grads = tape.gradient( - y, [x] + residual.trainable_variables, output_gradients=[dy]) - dx_tr_true, grads_tr_true = grads[0], grads[1:] - - del tape - - self.assertAllClose(dx_tr, dx_tr_true, rtol=1e-1, atol=1e-1) - self.assertAllClose(grads_tr, grads_tr_true, rtol=1e-1, atol=1e-1) - - def test_backward_grads_and_vars_channels_last(self): - """Test `backward_grads` function with `channels_last` data format.""" - with tf.device("/cpu:0"): # NHWC format - input_shape = (224, 224, 16) + input_shape = (128, 8, 8) data_shape = (16,) + input_shape - x = tf.random_normal(shape=data_shape) - dy = tf.random_normal(shape=data_shape) + # Use double precision for testing + x_true = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy = tf.random_normal(shape=data_shape, dtype=tf.float64) + dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=1) residual = blocks._Residual( - filters=16, + filters=128, strides=(1, 1), input_shape=input_shape, - data_format="channels_last") - - dx_tr, grads_tr, vars_tr = residual.backward_grads_and_vars( - x, dy=dy, training=True) - dx_ev, grads_ev, vars_ev = residual.backward_grads_and_vars( - x, dy=dy, training=False) - self.assertNotAllClose(dx_tr, dx_ev) - self.assertTrue(isinstance(grads_tr, list)) - self.assertTrue(isinstance(grads_ev, list)) - self.assertTrue(isinstance(vars_tr, list)) - self.assertTrue(isinstance(vars_ev, list)) - for grad_tr, var_tr, grad_ev, var_ev in zip(grads_tr, vars_tr, grads_ev, - vars_ev): - self.assertEqual(grad_tr.shape, grad_ev.shape) - self.assertEqual(var_tr.shape, var_ev.shape) - self.assertEqual(grad_tr.shape, var_tr.shape) + fused=False, + dtype=tf.float64) + + with tf.GradientTape() as tape: + tape.watch(x_true) + x1_true, x2_true = tf.split(x_true, num_or_size_splits=2, axis=1) + y1, y2 = residual((x1_true, x2_true), training=True) + y = tf.concat((y1, y2), axis=1) + + # Gradients computed due to reversibility + (x1, x2), (dx1, dx2), dw = residual.backward_grads( + y=(y1, y2), dy=(dy1, dy2), training=True) + x = tf.concat((x1, x2), axis=1) + dx = tf.concat((dx1, dx2), axis=1) + # True gradients computed by the tape + grads = tape.gradient( + y, [x_true] + residual.trainable_variables, output_gradients=dy) + dx_true, dw_true = grads[0], grads[1:] + + self.assertAllClose(x_true, x) + self.assertAllClose(dx_true, dx) + self.assertAllClose(dw_true, dw) class _ResidualInnerTest(tf.test.TestCase): diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py index e1d8b3a0559704bd8f00a8cc4b9fe735ad1de5f9..e9672f13e1587c96cea0fc7dd58b66ef256296cd 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py +++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py @@ -35,7 +35,7 @@ def get_ds_from_tfrecords(data_dir, epochs=None, shuffle=True, data_format="channels_first", - num_parallel_calls=8, + num_parallel_calls=12, prefetch=0, div255=True, dtype=tf.float32): @@ -111,6 +111,6 @@ def get_ds_from_tfrecords(data_dir, }[split] dataset = dataset.shuffle(size) - dataset = dataset.batch(batch_size) + dataset = dataset.batch(batch_size, drop_remainder=True) return dataset diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py index f79428b2a97f0ac2ce991f4c26b9123cddc24325..377844ad8fbca92629a4d71f5df2aab67b570c3c 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py +++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Read CIFAR-10 data from pickled numpy arrays and writes TFRecords. +"""Read CIFAR data from pickled numpy arrays and writes TFRecords. Generates tf.train.Example protos and writes them to TFRecord files from the -python version of the CIFAR-10 dataset downloaded from +python version of the CIFAR dataset downloaded from https://www.cs.toronto.edu/~kriz/cifar.html. """ @@ -32,20 +32,22 @@ from six.moves import cPickle as pickle from six.moves import urllib import tensorflow as tf -CIFAR_FILENAME = 'cifar-10-python.tar.gz' -CIFAR_DOWNLOAD_URL = 'https://www.cs.toronto.edu/~kriz/' + CIFAR_FILENAME -CIFAR_LOCAL_FOLDER = 'cifar-10-batches-py' +BASE_URL = 'https://www.cs.toronto.edu/~kriz/' +CIFAR_FILE_NAMES = ['cifar-10-python.tar.gz', 'cifar-100-python.tar.gz'] +CIFAR_DOWNLOAD_URLS = [BASE_URL + name for name in CIFAR_FILE_NAMES] +CIFAR_LOCAL_FOLDERS = ['cifar-10', 'cifar-100'] +EXTRACT_FOLDERS = ['cifar-10-batches-py', 'cifar-100-python'] -def download_and_extract(data_dir): - """Download CIFAR-10 if not already downloaded.""" - filepath = os.path.join(data_dir, CIFAR_FILENAME) +def download_and_extract(data_dir, file_name, url): + """Download CIFAR if not already downloaded.""" + filepath = os.path.join(data_dir, file_name) if tf.gfile.Exists(filepath): return filepath if not tf.gfile.Exists(data_dir): tf.gfile.MakeDirs(data_dir) - urllib.request.urlretrieve(CIFAR_DOWNLOAD_URL, filepath) + urllib.request.urlretrieve(url, filepath) tarfile.open(os.path.join(filepath), 'r:gz').extractall(data_dir) return filepath @@ -58,12 +60,22 @@ def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) -def _get_file_names(): +def _get_file_names(folder): """Returns the file names expected to exist in the input_dir.""" + assert folder in ['cifar-10', 'cifar-100'] + file_names = {} - file_names['train'] = ['data_batch_%d' % i for i in range(1, 5)] - file_names['validation'] = ['data_batch_5'] - file_names['test'] = ['test_batch'] + if folder == 'cifar-10': + file_names['train'] = ['data_batch_%d' % i for i in range(1, 5)] + file_names['validation'] = ['data_batch_5'] + file_names['train_all'] = ['data_batch_%d' % i for i in range(1, 6)] + file_names['test'] = ['test_batch'] + else: + file_names['train_all'] = ['train'] + file_names['test'] = ['test'] + # Split in `convert_to_tfrecord` function + file_names['train'] = ['train'] + file_names['validation'] = ['train'] return file_names @@ -76,14 +88,28 @@ def read_pickle_from_file(filename): return data_dict -def convert_to_tfrecord(input_files, output_file): +def convert_to_tfrecord(input_files, output_file, folder): """Converts files with pickled data to TFRecords.""" + assert folder in ['cifar-10', 'cifar-100'] + print('Generating %s' % output_file) with tf.python_io.TFRecordWriter(output_file) as record_writer: for input_file in input_files: data_dict = read_pickle_from_file(input_file) data = data_dict[b'data'] - labels = data_dict[b'labels'] + try: + labels = data_dict[b'labels'] + except KeyError: + labels = data_dict[b'fine_labels'] + + if folder == 'cifar-100' and input_file.endswith('train.tfrecords'): + data = data[:40000] + labels = labels[:40000] + elif folder == 'cifar-100' and input_file.endswith( + 'validation.tfrecords'): + data = data[40000:] + labels = labels[40000:] + num_entries_in_batch = len(labels) for i in range(num_entries_in_batch): @@ -97,19 +123,24 @@ def convert_to_tfrecord(input_files, output_file): def main(_): - print('Download from {} and extract.'.format(CIFAR_DOWNLOAD_URL)) - download_and_extract(FLAGS.data_dir) - file_names = _get_file_names() - input_dir = os.path.join(FLAGS.data_dir, CIFAR_LOCAL_FOLDER) - - for mode, files in file_names.items(): - input_files = [os.path.join(input_dir, f) for f in files] - output_file = os.path.join(FLAGS.data_dir, mode + '.tfrecords') - try: - os.remove(output_file) - except OSError: - pass - convert_to_tfrecord(input_files, output_file) + for file_name, url, folder, extract_folder in zip( + CIFAR_FILE_NAMES, CIFAR_DOWNLOAD_URLS, CIFAR_LOCAL_FOLDERS, + EXTRACT_FOLDERS): + print('Download from {} and extract.'.format(url)) + data_dir = os.path.join(FLAGS.data_dir, folder) + download_and_extract(data_dir, file_name, url) + file_names = _get_file_names(folder) + input_dir = os.path.join(data_dir, extract_folder) + + for mode, files in file_names.items(): + input_files = [os.path.join(input_dir, f) for f in files] + output_file = os.path.join(data_dir, mode + '.tfrecords') + try: + os.remove(output_file) + except OSError: + pass + convert_to_tfrecord(input_files, output_file, folder) + print('Done!') @@ -118,6 +149,6 @@ if __name__ == '__main__': flags.DEFINE_string( 'data_dir', default=None, - help='Directory to download and extract CIFAR-10 to.') + help='Directory to download, extract and store TFRecords.') tf.app.run(main) diff --git a/tensorflow/contrib/eager/python/examples/revnet/config.py b/tensorflow/contrib/eager/python/examples/revnet/config.py index 30b0edbf43304f4dd1b3a10165bdb28886d2d152..29f1db0e0367515757413c8e47f7b7280fc4cfbb 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/config.py +++ b/tensorflow/contrib/eager/python/examples/revnet/config.py @@ -27,17 +27,17 @@ from __future__ import division from __future__ import print_function import tensorflow as tf -tfe = tf.contrib.eager def get_hparams_cifar_38(): """RevNet-38 configurations for CIFAR-10/CIFAR-100.""" config = tf.contrib.training.HParams() + config.add_hparam("num_train_images", 50000) + config.add_hparam("num_eval_images", 10000) config.add_hparam("init_filters", 32) config.add_hparam("init_kernel", 3) config.add_hparam("init_stride", 1) - config.add_hparam("n_classes", 10) config.add_hparam("n_rev_blocks", 3) config.add_hparam("n_res", [3, 3, 3]) config.add_hparam("filters", [32, 64, 112]) @@ -46,7 +46,7 @@ def get_hparams_cifar_38(): config.add_hparam("bottleneck", False) config.add_hparam("fused", True) config.add_hparam("init_max_pool", False) - if tfe.num_gpus() > 0: + if tf.test.is_gpu_available(): config.add_hparam("input_shape", (3, 32, 32)) config.add_hparam("data_format", "channels_first") else: @@ -66,11 +66,44 @@ def get_hparams_cifar_38(): config.add_hparam("dtype", tf.float32) config.add_hparam("eval_batch_size", 1000) config.add_hparam("div255", True) - # TODO(lxuechen): This is imprecise, when training with validation set, + # This is imprecise, when training with validation set, # we only have 40k images in training data - config.add_hparam("iters_per_epoch", 50000 // config.batch_size) + config.add_hparam("iters_per_epoch", + config.num_train_images // config.batch_size) config.add_hparam("epochs", config.max_train_iter // config.iters_per_epoch) + # Customized TPU hyperparameters due to differing batch size caused by + # TPU architecture specifics + # Suggested batch sizes to reduce overhead from excessive tensor padding + # https://cloud.google.com/tpu/docs/troubleshooting + config.add_hparam("tpu_batch_size", 1024) + config.add_hparam("tpu_eval_batch_size", 1024) + config.add_hparam("tpu_iters_per_epoch", + config.num_train_images // config.tpu_batch_size) + config.add_hparam("tpu_epochs", + config.max_train_iter // config.tpu_iters_per_epoch) + config.add_hparam("tpu_eval_steps", + config.num_eval_images // config.tpu_eval_batch_size) + return config + + +def get_hparams_cifar_110(): + config = get_hparams_cifar_38() + config.filters = [32, 64, 128] + config.n_res = [9, 9, 9] + + return config + + +def get_hparams_cifar_164(): + config = get_hparams_cifar_38() + config.filters = [32, 64, 128] + config.n_res = [9, 9, 9] + config.use_bottleneck = True + # Due to bottleneck residual blocks + filters = [f * 4 for f in config.filters] + config.filters = filters + return config @@ -78,15 +111,18 @@ def get_hparams_imagenet_56(): """RevNet-56 configurations for ImageNet.""" config = tf.contrib.training.HParams() + config.add_hparam("n_classes", 1000) + config.add_hparam("dataset", "ImageNet") + config.add_hparam("num_train_images", 1281167) + config.add_hparam("num_eval_images", 50000) config.add_hparam("init_filters", 128) config.add_hparam("init_kernel", 7) config.add_hparam("init_stride", 2) - config.add_hparam("n_classes", 1000) config.add_hparam("n_rev_blocks", 4) config.add_hparam("n_res", [2, 2, 2, 2]) config.add_hparam("filters", [128, 256, 512, 832]) config.add_hparam("strides", [1, 2, 2, 2]) - config.add_hparam("batch_size", 16) + config.add_hparam("batch_size", 256) config.add_hparam("bottleneck", True) config.add_hparam("fused", True) config.add_hparam("init_max_pool", True) @@ -96,6 +132,9 @@ def get_hparams_imagenet_56(): else: config.add_hparam("input_shape", (224, 224, 3)) config.add_hparam("data_format", "channels_last") + # Due to bottleneck residual blocks + filters = [f * 4 for f in config.filters] + config.filters = filters # Training details config.add_hparam("weight_decay", 1e-4) @@ -105,17 +144,32 @@ def get_hparams_imagenet_56(): config.add_hparam("max_train_iter", 600000) config.add_hparam("seed", 1234) config.add_hparam("shuffle", True) - config.add_hparam("log_every", 50) - config.add_hparam("save_every", 50) + config.add_hparam("log_every", 500) + config.add_hparam("save_every", 500) config.add_hparam("dtype", tf.float32) - config.add_hparam("eval_batch_size", 1000) + config.add_hparam("eval_batch_size", 256) config.add_hparam("div255", True) - # TODO(lxuechen): Update this according to ImageNet data - config.add_hparam("iters_per_epoch", 50000 // config.batch_size) + config.add_hparam("iters_per_epoch", + config.num_train_images // config.batch_size) config.add_hparam("epochs", config.max_train_iter // config.iters_per_epoch) - if config.bottleneck: - filters = [f * 4 for f in config.filters] - config.filters = filters + # Customized TPU hyperparameters due to differing batch size caused by + # TPU architecture specifics + # Suggested batch sizes to reduce overhead from excessive tensor padding + # https://cloud.google.com/tpu/docs/troubleshooting + config.add_hparam("tpu_batch_size", 1024) + config.add_hparam("tpu_eval_batch_size", 1024) + config.add_hparam("tpu_iters_per_epoch", + config.num_train_images // config.tpu_batch_size) + config.add_hparam("tpu_epochs", + config.max_train_iter // config.tpu_iters_per_epoch) + config.add_hparam("tpu_eval_steps", + config.num_eval_images // config.tpu_eval_batch_size) + return config + + +def get_hparams_imagenet_104(): + config = get_hparams_imagenet_56() + config.n_res = [2, 2, 11, 2] return config diff --git a/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py b/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py new file mode 100644 index 0000000000000000000000000000000000000000..34a9984b0ecc527ad1991c28146246b716e96c98 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py @@ -0,0 +1,229 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Efficient ImageNet input pipeline using tf.data.Dataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import os + +import tensorflow as tf + +from tensorflow.contrib.eager.python.examples.revnet import resnet_preprocessing + + +def image_serving_input_fn(): + """Serving input fn for raw images.""" + + def _preprocess_image(image_bytes): + """Preprocess a single raw image.""" + image = resnet_preprocessing.preprocess_image( + image_bytes=image_bytes, is_training=False) + return image + + image_bytes_list = tf.placeholder( + shape=[None], + dtype=tf.string, + ) + images = tf.map_fn( + _preprocess_image, image_bytes_list, back_prop=False, dtype=tf.float32) + return tf.estimator.export.ServingInputReceiver( + images, {'image_bytes': image_bytes_list}) + + +class ImageNetInput(object): + """Generates ImageNet input_fn for training or evaluation. + + The training data is assumed to be in TFRecord format with keys as specified + in the dataset_parser below, sharded across 1024 files, named sequentially: + train-00000-of-01024 + train-00001-of-01024 + ... + train-01023-of-01024 + + The validation data is in the same format but sharded in 128 files. + + The format of the data required is created by the script at: + https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py + + Args: + is_training: `bool` for whether the input is for training + data_dir: `str` for the directory of the training and validation data; + if 'null' (the literal string 'null', not None), then construct a null + pipeline, consisting of empty images. + use_bfloat16: If True, use bfloat16 precision; else use float32. + transpose_input: 'bool' for whether to use the double transpose trick + num_cores: `int` for the number of TPU cores + """ + + def __init__(self, is_training, + use_bfloat16, + data_dir, + num_cores=8, + num_parallel_calls=64, + image_size=224, + transpose_input=False, + cache=False): + self.image_preprocessing_fn = resnet_preprocessing.preprocess_image + self.is_training = is_training + self.use_bfloat16 = use_bfloat16 + self.data_dir = data_dir + self.num_cores = num_cores + self.num_parallel_calls = num_parallel_calls + if self.data_dir == 'null' or self.data_dir == '': + self.data_dir = None + self.transpose_input = transpose_input + self.image_size = image_size + self.cache = cache + + def set_shapes(self, batch_size, images, labels): + """Statically set the batch_size dimension.""" + if self.transpose_input: + images.set_shape(images.get_shape().merge_with( + tf.TensorShape([None, None, None, batch_size]))) + labels.set_shape(labels.get_shape().merge_with( + tf.TensorShape([batch_size]))) + else: + images.set_shape(images.get_shape().merge_with( + tf.TensorShape([batch_size, None, None, None]))) + labels.set_shape(labels.get_shape().merge_with( + tf.TensorShape([batch_size]))) + + return images, labels + + def dataset_parser(self, value): + """Parse an ImageNet record from a serialized string Tensor.""" + keys_to_features = { + 'image/encoded': tf.FixedLenFeature((), tf.string, ''), + 'image/format': tf.FixedLenFeature((), tf.string, 'jpeg'), + 'image/class/label': tf.FixedLenFeature([], tf.int64, -1), + 'image/class/text': tf.FixedLenFeature([], tf.string, ''), + 'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32), + 'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32), + 'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32), + 'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32), + 'image/object/class/label': tf.VarLenFeature(dtype=tf.int64), + } + + parsed = tf.parse_single_example(value, keys_to_features) + image_bytes = tf.reshape(parsed['image/encoded'], shape=[]) + + image = self.image_preprocessing_fn( + image_bytes=image_bytes, + is_training=self.is_training, + image_size=self.image_size, + use_bfloat16=self.use_bfloat16) + + # Subtract one so that labels are in [0, 1000). + label = tf.cast( + tf.reshape(parsed['image/class/label'], shape=[]), dtype=tf.int32) - 1 + + return image, label + + def input_fn(self, params): + """Input function which provides a single batch for train or eval. + + Args: + params: `dict` of parameters passed from the `TPUEstimator`. + `params['batch_size']` is always provided and should be used as the + effective batch size. + + Returns: + A `tf.data.Dataset` object. + """ + if self.data_dir is None: + tf.logging.info('Using fake input.') + return self.input_fn_null(params) + + # Retrieves the batch size for the current shard. The # of shards is + # computed according to the input pipeline deployment. See + # tf.contrib.tpu.RunConfig for details. + batch_size = params['batch_size'] + + # Shuffle the filenames to ensure better randomization. + file_pattern = os.path.join( + self.data_dir, 'train-*' if self.is_training else 'validation-*') + dataset = tf.data.Dataset.list_files(file_pattern, shuffle=self.is_training) + + if self.is_training and not self.cache: + dataset = dataset.repeat() + + def fetch_dataset(filename): + buffer_size = 8 * 1024 * 1024 # 8 MiB per file + dataset = tf.data.TFRecordDataset(filename, buffer_size=buffer_size) + return dataset + + # Read the data from disk in parallel + dataset = dataset.apply( + tf.contrib.data.parallel_interleave( + fetch_dataset, cycle_length=self.num_parallel_calls, sloppy=True)) + if self.cache: + dataset = dataset.cache().apply( + tf.contrib.data.shuffle_and_repeat(1024 * 16)) + else: + dataset = dataset.shuffle(1024) + + # Use the fused map-and-batch operation. + # + # For XLA, we must used fixed shapes. Because we repeat the source training + # dataset indefinitely, we can use `drop_remainder=True` to get fixed-size + # batches without dropping any training examples. + # + # When evaluating, `drop_remainder=True` prevents accidentally evaluating + # the same image twice by dropping the final batch if it is less than a full + # batch size. As long as this validation is done with consistent batch size, + # exactly the same images will be used. + dataset = dataset.apply( + tf.contrib.data.map_and_batch( + self.dataset_parser, batch_size=batch_size, + num_parallel_batches=self.num_cores, drop_remainder=True)) + + # Transpose for performance on TPU + if self.transpose_input: + dataset = dataset.map( + lambda images, labels: (tf.transpose(images, [1, 2, 3, 0]), labels), + num_parallel_calls=self.num_cores) + + # Assign static batch size dimension + dataset = dataset.map(functools.partial(self.set_shapes, batch_size)) + + # Prefetch overlaps in-feed with training + dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE) + return dataset + + def input_fn_null(self, params): + """Input function which provides null (black) images.""" + batch_size = params['batch_size'] + dataset = tf.data.Dataset.range(1).repeat().map(self._get_null_input) + dataset = dataset.prefetch(batch_size) + + dataset = dataset.batch(batch_size, drop_remainder=True) + if self.transpose_input: + dataset = dataset.map( + lambda images, labels: (tf.transpose(images, [1, 2, 3, 0]), labels), + num_parallel_calls=8) + + dataset = dataset.map(functools.partial(self.set_shapes, batch_size)) + + dataset = dataset.prefetch(32) # Prefetch overlaps in-feed with training + tf.logging.info('Input dataset: %s', str(dataset)) + return dataset + + def _get_null_input(self, _): + null_image = tf.zeros([224, 224, 3], tf.bfloat16 + if self.use_bfloat16 else tf.float32) + return (null_image, tf.constant(0, tf.int32)) diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py index 106559250940acba1a7bb600283e25dae6252e4b..b702e91f92220c2a9003a1b82411131332012a9e 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/main.py +++ b/tensorflow/contrib/eager/python/examples/revnet/main.py @@ -23,28 +23,129 @@ import sys from absl import flags import tensorflow as tf -from tqdm import tqdm from tensorflow.contrib.eager.python.examples.revnet import cifar_input from tensorflow.contrib.eager.python.examples.revnet import config as config_ from tensorflow.contrib.eager.python.examples.revnet import revnet tfe = tf.contrib.eager +def apply_gradients(optimizer, grads, vars_, global_step=None): + """Functional style apply_grads for `tfe.defun`.""" + optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) + + def main(_): """Eager execution workflow with RevNet trained on CIFAR-10.""" - if FLAGS.data_dir is None: - raise ValueError("No supplied data directory") + tf.enable_eager_execution() - if not os.path.exists(FLAGS.data_dir): - raise ValueError("Data directory {} does not exist".format(FLAGS.data_dir)) + config = get_config(config_name=FLAGS.config, dataset=FLAGS.dataset) + ds_train, ds_train_one_shot, ds_validation, ds_test = get_datasets( + data_dir=FLAGS.data_dir, config=config) + model = revnet.RevNet(config=config) + global_step = tf.train.get_or_create_global_step() # Ensure correct summary + global_step.assign(1) + learning_rate = tf.train.piecewise_constant( + global_step, config.lr_decay_steps, config.lr_list) + optimizer = tf.train.MomentumOptimizer( + learning_rate, momentum=config.momentum) + checkpointer = tf.train.Checkpoint( + optimizer=optimizer, model=model, optimizer_step=global_step) - tf.enable_eager_execution() - config = config_.get_hparams_cifar_38() + if FLAGS.use_defun: + model.call = tfe.defun(model.call) + model.compute_gradients = tfe.defun(model.compute_gradients) + model.get_moving_stats = tfe.defun(model.get_moving_stats) + model.restore_moving_stats = tfe.defun(model.restore_moving_stats) + global apply_gradients # pylint:disable=global-variable-undefined + apply_gradients = tfe.defun(apply_gradients) + + if FLAGS.train_dir: + summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir) + if FLAGS.restore: + latest_path = tf.train.latest_checkpoint(FLAGS.train_dir) + checkpointer.restore(latest_path) + print("Restored latest checkpoint at path:\"{}\" " + "with global_step: {}".format(latest_path, global_step.numpy())) + sys.stdout.flush() + + for x, y in ds_train: + train_one_iter(model, x, y, optimizer, global_step=global_step) + + if global_step.numpy() % config.log_every == 0: + it_test = ds_test.make_one_shot_iterator() + acc_test, loss_test = evaluate(model, it_test) + + if FLAGS.validate: + it_train = ds_train_one_shot.make_one_shot_iterator() + it_validation = ds_validation.make_one_shot_iterator() + acc_train, loss_train = evaluate(model, it_train) + acc_validation, loss_validation = evaluate(model, it_validation) + print("Iter {}, " + "training set accuracy {:.4f}, loss {:.4f}; " + "validation set accuracy {:.4f}, loss {:.4f}; " + "test accuracy {:.4f}, loss {:.4f}".format( + global_step.numpy(), acc_train, loss_train, acc_validation, + loss_validation, acc_test, loss_test)) + else: + print("Iter {}, test accuracy {:.4f}, loss {:.4f}".format( + global_step.numpy(), acc_test, loss_test)) + sys.stdout.flush() + + if FLAGS.train_dir: + with summary_writer.as_default(): + with tf.contrib.summary.always_record_summaries(): + tf.contrib.summary.scalar("Test accuracy", acc_test) + tf.contrib.summary.scalar("Test loss", loss_test) + if FLAGS.validate: + tf.contrib.summary.scalar("Training accuracy", acc_train) + tf.contrib.summary.scalar("Training loss", loss_train) + tf.contrib.summary.scalar("Validation accuracy", acc_validation) + tf.contrib.summary.scalar("Validation loss", loss_validation) + + if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir: + saved_path = checkpointer.save( + file_prefix=os.path.join(FLAGS.train_dir, "ckpt")) + print("Saved checkpoint at path: \"{}\" " + "with global_step: {}".format(saved_path, global_step.numpy())) + sys.stdout.flush() + + +def get_config(config_name="revnet-38", dataset="cifar-10"): + """Return configuration.""" + print("Config: {}".format(config_name)) + sys.stdout.flush() + config = { + "revnet-38": config_.get_hparams_cifar_38(), + "revnet-110": config_.get_hparams_cifar_110(), + "revnet-164": config_.get_hparams_cifar_164(), + }[config_name] + + if dataset == "cifar-10": + config.add_hparam("n_classes", 10) + config.add_hparam("dataset", "cifar-10") + else: + config.add_hparam("n_classes", 100) + config.add_hparam("dataset", "cifar-100") + + return config + + +def get_datasets(data_dir, config): + """Return dataset.""" + if data_dir is None: + raise ValueError("No supplied data directory") + if not os.path.exists(data_dir): + raise ValueError("Data directory {} does not exist".format(data_dir)) + if config.dataset not in ["cifar-10", "cifar-100"]: + raise ValueError("Unknown dataset {}".format(config.dataset)) + print("Training on {} dataset.".format(config.dataset)) + sys.stdout.flush() + data_dir = os.path.join(data_dir, config.dataset) if FLAGS.validate: # 40k Training set ds_train = cifar_input.get_ds_from_tfrecords( - data_dir=FLAGS.data_dir, + data_dir=data_dir, split="train", data_aug=True, batch_size=config.batch_size, @@ -55,7 +156,7 @@ def main(_): prefetch=config.batch_size) # 10k Training set ds_validation = cifar_input.get_ds_from_tfrecords( - data_dir=FLAGS.data_dir, + data_dir=data_dir, split="validation", data_aug=False, batch_size=config.eval_batch_size, @@ -67,7 +168,7 @@ def main(_): else: # 50k Training set ds_train = cifar_input.get_ds_from_tfrecords( - data_dir=FLAGS.data_dir, + data_dir=data_dir, split="train_all", data_aug=True, batch_size=config.batch_size, @@ -76,10 +177,11 @@ def main(_): data_format=config.data_format, dtype=config.dtype, prefetch=config.batch_size) + ds_validation = None - # Always compute loss and accuracy on whole training and test set + # Always compute loss and accuracy on whole test set ds_train_one_shot = cifar_input.get_ds_from_tfrecords( - data_dir=FLAGS.data_dir, + data_dir=data_dir, split="train_all", data_aug=False, batch_size=config.eval_batch_size, @@ -90,7 +192,7 @@ def main(_): prefetch=config.eval_batch_size) ds_test = cifar_input.get_ds_from_tfrecords( - data_dir=FLAGS.data_dir, + data_dir=data_dir, split="test", data_aug=False, batch_size=config.eval_batch_size, @@ -100,103 +202,27 @@ def main(_): dtype=config.dtype, prefetch=config.eval_batch_size) - model = revnet.RevNet(config=config) - global_step = tfe.Variable(1, trainable=False) - learning_rate = tf.train.piecewise_constant( - global_step, config.lr_decay_steps, config.lr_list) - optimizer = tf.train.MomentumOptimizer( - learning_rate, momentum=config.momentum) - checkpointer = tf.train.Checkpoint( - optimizer=optimizer, model=model, optimizer_step=global_step) - - if FLAGS.train_dir: - summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir) - if FLAGS.restore: - latest_path = tf.train.latest_checkpoint(FLAGS.train_dir) - checkpointer.restore(latest_path) - print("Restored latest checkpoint at path:\"{}\" " - "with global_step: {}".format(latest_path, global_step.numpy())) - sys.stdout.flush() - - warmup(model, config) + return ds_train, ds_train_one_shot, ds_validation, ds_test - for x, y in ds_train: - loss = train_one_iter(model, x, y, optimizer, global_step=global_step) - if global_step.numpy() % config.log_every == 0: - it_train = ds_train_one_shot.make_one_shot_iterator() - acc_train, loss_train = evaluate(model, it_train) - it_test = ds_test.make_one_shot_iterator() - acc_test, loss_test = evaluate(model, it_test) - if FLAGS.validate: - it_validation = ds_validation.make_one_shot_iterator() - acc_validation, loss_validation = evaluate(model, it_validation) - print("Iter {}, " - "training set accuracy {:.4f}, loss {:.4f}; " - "validation set accuracy {:.4f}, loss {:4.f}" - "test accuracy {:.4f}, loss {:.4f}".format( - global_step.numpy(), acc_train, loss_train, acc_validation, - loss_validation, acc_test, loss_test)) - else: - print("Iter {}, " - "training set accuracy {:.4f}, loss {:.4f}; " - "test accuracy {:.4f}, loss {:.4f}".format( - global_step.numpy(), acc_train, loss_train, acc_test, - loss_test)) - sys.stdout.flush() - - if FLAGS.train_dir: - with summary_writer.as_default(): - with tf.contrib.summary.always_record_summaries(): - tf.contrib.summary.scalar("Training loss", loss) - tf.contrib.summary.scalar("Test accuracy", acc_test) - if FLAGS.validate: - tf.contrib.summary.scalar("Validation accuracy", acc_validation) - - if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir: - saved_path = checkpointer.save( - file_prefix=os.path.join(FLAGS.train_dir, "ckpt")) - print("Saved checkpoint at path: \"{}\" " - "with global_step: {}".format(saved_path, global_step.numpy())) - sys.stdout.flush() - - -def warmup(model, config, steps=1): - mock_input = tf.random_normal((config.batch_size,) + config.input_shape) - for _ in range(steps): - model(mock_input, training=False) - - -def train_one_iter(model, - inputs, - labels, - optimizer, - global_step=None, - verbose=False): +def train_one_iter(model, inputs, labels, optimizer, global_step=None): """Train for one iteration.""" - if FLAGS.manual_grad: - if verbose: - print("Using manual gradients") - grads, vars_, loss = model.compute_gradients(inputs, labels) - optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) - else: # For correctness validation - if verbose: - print("Not using manual gradients") - with tf.GradientTape() as tape: - logits, _ = model(inputs, training=True) - loss = model.compute_loss(logits=logits, labels=labels) - grads = tape.gradient(loss, model.trainable_variables) - optimizer.apply_gradients( - zip(grads, model.trainable_variables), global_step=global_step) - - return loss.numpy() + logits, saved_hiddens = model(inputs, training=True) + values = model.get_moving_stats() + grads, loss = model.compute_gradients(saved_hiddens, labels) + # Restore moving averages when executing eagerly to avoid updating twice + model.restore_moving_stats(values) + apply_gradients( + optimizer, grads, model.trainable_variables, global_step=global_step) + + return logits, loss def evaluate(model, iterator): """Compute accuracy with the given dataset iterator.""" mean_loss = tfe.metrics.Mean() accuracy = tfe.metrics.Accuracy() - for x, y in tqdm(iterator): + for x, y in iterator: logits, _ = model(x, training=False) loss = model.compute_loss(logits=logits, labels=y) accuracy( @@ -208,12 +234,12 @@ def evaluate(model, iterator): if __name__ == "__main__": + flags.DEFINE_string( + "data_dir", default=None, help="Directory to load tfrecords") flags.DEFINE_string( "train_dir", default=None, help="[Optional] Directory to store the training information") - flags.DEFINE_string( - "data_dir", default=None, help="Directory to load tfrecords") flags.DEFINE_boolean( "restore", default=False, @@ -222,9 +248,18 @@ if __name__ == "__main__": "validate", default=False, help="[Optional] Use the validation set or not for hyperparameter search") + flags.DEFINE_string( + "dataset", + default="cifar-10", + help="[Optional] The dataset used; either `cifar-10` or `cifar-100`") + flags.DEFINE_string( + "config", + default="revnet-38", + help="[Optional] Architecture of network. " + "Other options include `revnet-110` and `revnet-164`") flags.DEFINE_boolean( - "manual_grad", + "use_defun", default=False, - help="[Optional] Use manual gradient graph to save memory") + help="[Optional] Use `tfe.defun` to boost performance.") FLAGS = flags.FLAGS tf.app.run(main) diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..3a17eb30da3b989acb0b33f2fcb730da76546c18 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py @@ -0,0 +1,200 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Estimator workflow with RevNet train on CIFAR-10.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from absl import flags +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import cifar_input +from tensorflow.contrib.eager.python.examples.revnet import main as main_ +from tensorflow.contrib.eager.python.examples.revnet import revnet + + +def model_fn(features, labels, mode, params): + """Function specifying the model that is required by the `tf.estimator` API. + + Args: + features: Input images + labels: Labels of images + mode: One of `ModeKeys.TRAIN`, `ModeKeys.EVAL` or 'ModeKeys.PREDICT' + params: A dictionary of extra parameter that might be passed + + Returns: + An instance of `tf.estimator.EstimatorSpec` + """ + + inputs = features + if isinstance(inputs, dict): + inputs = features["image"] + + config = params["config"] + model = revnet.RevNet(config=config) + + if mode == tf.estimator.ModeKeys.TRAIN: + global_step = tf.train.get_or_create_global_step() + learning_rate = tf.train.piecewise_constant( + global_step, config.lr_decay_steps, config.lr_list) + optimizer = tf.train.MomentumOptimizer( + learning_rate, momentum=config.momentum) + logits, saved_hidden = model(inputs, training=True) + grads, loss = model.compute_gradients(saved_hidden, labels, training=True) + with tf.control_dependencies(model.get_updates_for(inputs)): + train_op = optimizer.apply_gradients( + zip(grads, model.trainable_variables), global_step=global_step) + + return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op) + else: + logits, _ = model(inputs, training=False) + predictions = tf.argmax(logits, axis=1) + probabilities = tf.nn.softmax(logits) + + if mode == tf.estimator.ModeKeys.EVAL: + loss = model.compute_loss(labels=labels, logits=logits) + return tf.estimator.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops={ + "accuracy": + tf.metrics.accuracy(labels=labels, predictions=predictions) + }) + + else: # mode == tf.estimator.ModeKeys.PREDICT + result = { + "classes": predictions, + "probabilities": probabilities, + } + + return tf.estimator.EstimatorSpec( + mode=mode, + predictions=predictions, + export_outputs={ + "classify": tf.estimator.export.PredictOutput(result) + }) + + +def get_input_fn(config, data_dir, split): + """Get the input function that is required by the `tf.estimator` API. + + Args: + config: Customized hyperparameters + data_dir: Directory where the data is stored + split: One of `train`, `validation`, `train_all`, and `test` + + Returns: + Input function required by the `tf.estimator` API + """ + + data_dir = os.path.join(data_dir, config.dataset) + # Fix split-dependent hyperparameters + if split == "train_all" or split == "train": + data_aug = True + batch_size = config.batch_size + epochs = config.epochs + shuffle = True + prefetch = config.batch_size + else: + data_aug = False + batch_size = config.eval_batch_size + epochs = 1 + shuffle = False + prefetch = config.eval_batch_size + + def input_fn(): + """Input function required by the `tf.estimator.Estimator` API.""" + return cifar_input.get_ds_from_tfrecords( + data_dir=data_dir, + split=split, + data_aug=data_aug, + batch_size=batch_size, + epochs=epochs, + shuffle=shuffle, + prefetch=prefetch, + data_format=config.data_format) + + return input_fn + + +def main(_): + tf.logging.set_verbosity(tf.logging.INFO) + + # RevNet specific configuration + config = main_.get_config(config_name=FLAGS.config, dataset=FLAGS.dataset) + + # Estimator specific configuration + run_config = tf.estimator.RunConfig( + model_dir=FLAGS.model_dir, # Directory for storing checkpoints + tf_random_seed=config.seed, + save_summary_steps=config.log_every, + save_checkpoints_steps=config.log_every, + session_config=None, # Using default + keep_checkpoint_max=100, + keep_checkpoint_every_n_hours=10000, # Using default + log_step_count_steps=config.log_every, + train_distribute=None # Default not use distribution strategy + ) + + # Construct estimator + revnet_estimator = tf.estimator.Estimator( + model_fn=model_fn, + model_dir=FLAGS.model_dir, + config=run_config, + params={"config": config}) + + # Construct input functions + train_input_fn = get_input_fn( + config=config, data_dir=FLAGS.data_dir, split="train_all") + eval_input_fn = get_input_fn( + config=config, data_dir=FLAGS.data_dir, split="test") + + # Train and evaluate estimator + revnet_estimator.train(input_fn=train_input_fn) + revnet_estimator.evaluate(input_fn=eval_input_fn) + + if FLAGS.export: + input_shape = (None,) + config.input_shape + inputs = tf.placeholder(tf.float32, shape=input_shape) + input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({ + "image": inputs + }) + revnet_estimator.export_savedmodel(FLAGS.model_dir, input_fn) + + +if __name__ == "__main__": + flags.DEFINE_string( + "data_dir", default=None, help="Directory to load tfrecords") + flags.DEFINE_string( + "model_dir", + default=None, + help="[Optional] Directory to store the training information") + flags.DEFINE_string( + "dataset", + default="cifar-10", + help="[Optional] The dataset used; either `cifar-10` or `cifar-100`") + flags.DEFINE_boolean( + "export", + default=False, + help="[Optional] Export the model for serving if True") + flags.DEFINE_string( + "config", + default="revnet-38", + help="[Optional] Architecture of network. " + "Other options include `revnet-110` and `revnet-164`") + FLAGS = flags.FLAGS + tf.app.run() diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py new file mode 100644 index 0000000000000000000000000000000000000000..8520cf5b71af503be35d5415707a283fb363a476 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py @@ -0,0 +1,394 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Cloud TPU Estimator workflow with RevNet train on ImageNet.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +from absl import flags +import tensorflow as tf +from tensorflow.contrib import summary +from tensorflow.contrib.eager.python.examples.revnet import config as config_ +from tensorflow.contrib.eager.python.examples.revnet import imagenet_input +from tensorflow.contrib.eager.python.examples.revnet import revnet +from tensorflow.contrib.training.python.training import evaluation +from tensorflow.python.estimator import estimator + +MEAN_RGB = [0.485, 0.456, 0.406] +STDDEV_RGB = [0.229, 0.224, 0.225] + + +def _host_call_fn(gs, loss, lr): + """Training host call. + + Creates scalar summaries for training metrics. + + This function is executed on the CPU and should not directly reference + any Tensors in the rest of the `model_fn`. To pass Tensors from the + model to the `metric_fn`, provide as part of the `host_call`. See + https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec + for more information. + + Arguments should match the list of `Tensor` objects passed as the second + element in the tuple passed to `host_call`. + + Args: + gs: `Tensor with shape `[batch]` for the global_step + loss: `Tensor` with shape `[batch]` for the training loss. + lr: `Tensor` with shape `[batch]` for the learning_rate. + + Returns: + List of summary ops to run on the CPU host. + """ + # Host call fns are executed FLAGS.iterations_per_loop times after one + # TPU loop is finished, setting max_queue value to the same as number of + # iterations will make the summary writer only flush the data to storage + # once per loop. + gs = gs[0] + with summary.create_file_writer( + FLAGS.model_dir, max_queue=FLAGS.iterations_per_loop).as_default(): + with summary.always_record_summaries(): + summary.scalar("loss", loss[0], step=gs) + summary.scalar("learning_rate", lr[0], step=gs) + return summary.all_summary_ops() + + +def _metric_fn(labels, logits): + """Evaluation metric function. Evaluates accuracy. + + This function is executed on the CPU and should not directly reference + any Tensors in the rest of the `model_fn`. To pass Tensors from the model + to the `metric_fn`, provide as part of the `eval_metrics`. See + https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec + for more information. + + Arguments should match the list of `Tensor` objects passed as the second + element in the tuple passed to `eval_metrics`. + + Args: + labels: `Tensor` with shape `[batch]`. + logits: `Tensor` with shape `[batch, num_classes]`. + + Returns: + A dict of the metrics to return from evaluation. + """ + predictions = tf.argmax(logits, axis=1) + top_1_accuracy = tf.metrics.accuracy(labels, predictions) + in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32) + top_5_accuracy = tf.metrics.mean(in_top_5) + + return { + "top_1_accuracy": top_1_accuracy, + "top_5_accuracy": top_5_accuracy, + } + + +def model_fn(features, labels, mode, params): + """Model function required by the `tf.contrib.tpu.TPUEstimator` API. + + Args: + features: Input images + labels: Labels of images + mode: One of `ModeKeys.TRAIN`, `ModeKeys.EVAL` or 'ModeKeys.PREDICT' + params: A dictionary of extra parameter that might be passed + + Returns: + An instance of `tf.contrib.tpu.TPUEstimatorSpec` + """ + revnet_config = params["revnet_config"] + model = revnet.RevNet(config=revnet_config) + + inputs = features + if isinstance(inputs, dict): + inputs = features["image"] + + if revnet_config.data_format == "channels_first": + assert not FLAGS.transpose_input # channels_first only for GPU + inputs = tf.transpose(inputs, [0, 3, 1, 2]) + + if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT: + inputs = tf.transpose(inputs, [3, 0, 1, 2]) # HWCN to NHWC + + # Normalize the image to zero mean and unit variance. + inputs -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=inputs.dtype) + inputs /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=inputs.dtype) + + if mode == tf.estimator.ModeKeys.TRAIN: + global_step = tf.train.get_or_create_global_step() + learning_rate = tf.train.piecewise_constant( + global_step, revnet_config.lr_decay_steps, revnet_config.lr_list) + optimizer = tf.train.MomentumOptimizer(learning_rate, + revnet_config.momentum) + if FLAGS.use_tpu: + optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) + + logits, saved_hidden = model(inputs, training=True) + grads, loss = model.compute_gradients(saved_hidden, labels, training=True) + with tf.control_dependencies(model.get_updates_for(inputs)): + train_op = optimizer.apply_gradients( + zip(grads, model.trainable_variables), global_step=global_step) + if not FLAGS.skip_host_call: + # To log the loss, current learning rate, and epoch for Tensorboard, the + # summary op needs to be run on the host CPU via host_call. host_call + # expects [batch_size, ...] Tensors, thus reshape to introduce a batch + # dimension. These Tensors are implicitly concatenated to + # [params['batch_size']]. + gs_t = tf.reshape(global_step, [1]) + loss_t = tf.reshape(loss, [1]) + lr_t = tf.reshape(learning_rate, [1]) + host_call = (_host_call_fn, [gs_t, loss_t, lr_t]) + + return tf.contrib.tpu.TPUEstimatorSpec( + mode=mode, loss=loss, train_op=train_op, host_call=host_call) + + elif mode == tf.estimator.ModeKeys.EVAL: + logits, _ = model(inputs, training=False) + loss = model.compute_loss(labels=labels, logits=logits) + + return tf.contrib.tpu.TPUEstimatorSpec( + mode=mode, loss=loss, eval_metrics=(_metric_fn, [labels, logits])) + + else: # Predict or export + logits, _ = model(inputs, training=False) + predictions = { + "classes": tf.argmax(logits, axis=1), + "probabilities": tf.nn.softmax(logits), + } + + return tf.contrib.tpu.TPUEstimatorSpec( + mode=mode, + predictions=predictions, + export_outputs={ + "classify": tf.estimator.export.PredictOutput(predictions) + }) + + +def main(_): + tf.logging.set_verbosity(tf.logging.INFO) + + # RevNet specific configuration + revnet_config = { + "revnet-56": config_.get_hparams_imagenet_56(), + "revnet-104": config_.get_hparams_imagenet_104() + }[FLAGS.revnet_config] + + if FLAGS.use_tpu: + revnet_config.data_format = "channels_last" + + tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( + FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) + + # Estimator specific configuration + config = tf.contrib.tpu.RunConfig( + cluster=tpu_cluster_resolver, + model_dir=FLAGS.model_dir, + session_config=tf.ConfigProto( + allow_soft_placement=True, log_device_placement=True), + tpu_config=tf.contrib.tpu.TPUConfig( + iterations_per_loop=FLAGS.iterations_per_loop, + num_shards=FLAGS.num_shards, + per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig. + PER_HOST_V2), + ) + + # Input pipelines are slightly different (with regards to shuffling and + # preprocessing) between training and evaluation. + imagenet_train, imagenet_eval = [ + imagenet_input.ImageNetInput( + is_training=is_training, + data_dir=FLAGS.data_dir, + transpose_input=FLAGS.transpose_input, + use_bfloat16=False) for is_training in [True, False] + ] + + revnet_classifier = tf.contrib.tpu.TPUEstimator( + model_fn=model_fn, + use_tpu=FLAGS.use_tpu, + train_batch_size=revnet_config.tpu_batch_size, + eval_batch_size=revnet_config.tpu_eval_batch_size, + config=config, + export_to_tpu=False, + params={"revnet_config": revnet_config}) + + steps_per_epoch = revnet_config.tpu_iters_per_epoch + eval_steps = revnet_config.tpu_eval_steps + + # pylint: disable=protected-access + if FLAGS.mode == "eval": + # Run evaluation when there's a new checkpoint + for ckpt in evaluation.checkpoints_iterator( + FLAGS.model_dir, timeout=FLAGS.eval_timeout): + tf.logging.info("Starting to evaluate.") + try: + start_timestamp = time.time() # This time will include compilation time + eval_results = revnet_classifier.evaluate( + input_fn=imagenet_eval.input_fn, + steps=eval_steps, + checkpoint_path=ckpt) + elapsed_time = int(time.time() - start_timestamp) + tf.logging.info("Eval results: %s. Elapsed seconds: %d" % + (eval_results, elapsed_time)) + + # Terminate eval job when final checkpoint is reached + current_step = int(os.path.basename(ckpt).split("-")[1]) + if current_step >= revnet_config.max_train_iter: + tf.logging.info( + "Evaluation finished after training step %d" % current_step) + break + + except tf.errors.NotFoundError: + # Since the coordinator is on a different job than the TPU worker, + # sometimes the TPU worker does not finish initializing until long after + # the CPU job tells it to start evaluating. In this case, the checkpoint + # file could have been deleted already. + tf.logging.info( + "Checkpoint %s no longer exists, skipping checkpoint" % ckpt) + + else: # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval' + current_step = estimator._load_global_step_from_checkpoint_dir( + FLAGS.model_dir) + + tf.logging.info( + "Training for %d steps (%.2f epochs in total). Current" + " step %d." % (revnet_config.max_train_iter, + revnet_config.max_train_iter / steps_per_epoch, + current_step)) + + start_timestamp = time.time() # This time will include compilation time + + if FLAGS.mode == "train": + revnet_classifier.train( + input_fn=imagenet_train.input_fn, + max_steps=revnet_config.max_train_iter) + + else: + assert FLAGS.mode == "train_and_eval" + while current_step < revnet_config.max_train_iter: + # Train for up to steps_per_eval number of steps. + # At the end of training, a checkpoint will be written to --model_dir. + next_checkpoint = min(current_step + FLAGS.steps_per_eval, + revnet_config.max_train_iter) + revnet_classifier.train( + input_fn=imagenet_train.input_fn, max_steps=next_checkpoint) + current_step = next_checkpoint + + tf.logging.info("Finished training up to step %d. Elapsed seconds %d." % + (next_checkpoint, int(time.time() - start_timestamp))) + + # Evaluate the model on the most recent model in --model_dir. + # Since evaluation happens in batches of --eval_batch_size, some images + # may be excluded modulo the batch size. As long as the batch size is + # consistent, the evaluated images are also consistent. + tf.logging.info("Starting to evaluate.") + eval_results = revnet_classifier.evaluate( + input_fn=imagenet_eval.input_fn, steps=eval_steps) + tf.logging.info("Eval results: %s" % eval_results) + + elapsed_time = int(time.time() - start_timestamp) + tf.logging.info("Finished training up to step %d. Elapsed seconds %d." % + (revnet_config.max_train_iter, elapsed_time)) + + if FLAGS.export_dir is not None: + # The guide to serve an exported TensorFlow model is at: + # https://www.tensorflow.org/serving/serving_basic + tf.logging.info("Starting to export model.") + revnet_classifier.export_savedmodel( + export_dir_base=FLAGS.export_dir, + serving_input_receiver_fn=imagenet_input.image_serving_input_fn) + + +if __name__ == "__main__": + # Cloud TPU Cluster Resolver flags + flags.DEFINE_string( + "tpu", + default=None, + help="The Cloud TPU to use for training. This should be either the name " + "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " + "url.") + flags.DEFINE_string( + "tpu_zone", + default=None, + help="[Optional] GCE zone where the Cloud TPU is located in. If not " + "specified, we will attempt to automatically detect the GCE project from " + "metadata.") + flags.DEFINE_string( + "gcp_project", + default=None, + help="[Optional] Project name for the Cloud TPU-enabled project. If not " + "specified, we will attempt to automatically detect the GCE project from " + "metadata.") + + # Model specific parameters + flags.DEFINE_string( + "data_dir", default=None, help="Directory to load tfrecords") + flags.DEFINE_string( + "model_dir", + default=None, + help="[Optional] Directory to store the model information") + flags.DEFINE_string( + "revnet_config", + default="revnet-56", + help="[Optional] Architecture of network. " + "Other options include `revnet-104`") + flags.DEFINE_boolean( + "use_tpu", default=True, help="[Optional] Whether to use TPU") + flags.DEFINE_integer( + "num_shards", default=8, help="Number of shards (TPU chips).") + flags.DEFINE_integer( + "iterations_per_loop", + default=100, + help=( + "Number of steps to run on TPU before feeding metrics to the CPU." + " If the number of iterations in the loop would exceed the number of" + " train steps, the loop will exit before reaching" + " --iterations_per_loop. The larger this value is, the higher the" + " utilization on the TPU.")) + flags.DEFINE_integer( + "eval_timeout", + default=None, + help="Maximum seconds between checkpoints before evaluation terminates.") + flags.DEFINE_integer( + "steps_per_eval", + default=5000, + help=( + "Controls how often evaluation is performed. Since evaluation is" + " fairly expensive, it is advised to evaluate as infrequently as" + " possible (i.e. up to --train_steps, which evaluates the model only" + " after finishing the entire training regime).")) + flags.DEFINE_bool( + "transpose_input", + default=True, + help="Use TPU double transpose optimization") + flags.DEFINE_string( + "export_dir", + default=None, + help=("The directory where the exported SavedModel will be stored.")) + flags.DEFINE_bool( + "skip_host_call", + default=False, + help=("Skip the host_call which is executed every training step. This is" + " generally used for generating training summaries (train loss," + " learning rate, etc...). When --skip_host_call=false, there could" + " be a performance drop if host_call function is slow and cannot" + " keep up with the TPU-side computation.")) + flags.DEFINE_string( + "mode", + default="train_and_eval", + help='One of {"train_and_eval", "train", "eval"}.') + FLAGS = flags.FLAGS + tf.app.run() diff --git a/tensorflow/contrib/eager/python/examples/revnet/resnet_preprocessing.py b/tensorflow/contrib/eager/python/examples/revnet/resnet_preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..21a1ab85d46cde11453e1f693cc4aabbbf3c90ed --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/resnet_preprocessing.py @@ -0,0 +1,190 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ImageNet preprocessing for ResNet.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +IMAGE_SIZE = 224 +CROP_PADDING = 32 + + +def distorted_bounding_box_crop(image_bytes, + bbox, + min_object_covered=0.1, + aspect_ratio_range=(0.75, 1.33), + area_range=(0.05, 1.0), + max_attempts=100, + scope=None): + """Generates cropped_image using one of the bboxes randomly distorted. + + See `tf.image.sample_distorted_bounding_box` for more documentation. + + Args: + image_bytes: `Tensor` of binary image data. + bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` + where each coordinate is [0, 1) and the coordinates are arranged + as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole + image. + min_object_covered: An optional `float`. Defaults to `0.1`. The cropped + area of the image must contain at least this fraction of any bounding + box supplied. + aspect_ratio_range: An optional list of `float`s. The cropped area of the + image must have an aspect ratio = width / height within this range. + area_range: An optional list of `float`s. The cropped area of the image + must contain a fraction of the supplied image within in this range. + max_attempts: An optional `int`. Number of attempts at generating a cropped + region of the image of the specified constraints. After `max_attempts` + failures, return the entire image. + scope: Optional `str` for name scope. + Returns: + cropped image `Tensor` + """ + with tf.name_scope(scope, 'distorted_bounding_box_crop', [image_bytes, bbox]): + shape = tf.image.extract_jpeg_shape(image_bytes) + sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( + shape, + bounding_boxes=bbox, + min_object_covered=min_object_covered, + aspect_ratio_range=aspect_ratio_range, + area_range=area_range, + max_attempts=max_attempts, + use_image_if_no_bounding_boxes=True) + bbox_begin, bbox_size, _ = sample_distorted_bounding_box + + # Crop the image to the specified bounding box. + offset_y, offset_x, _ = tf.unstack(bbox_begin) + target_height, target_width, _ = tf.unstack(bbox_size) + crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) + image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) + + return image + + +def _at_least_x_are_equal(a, b, x): + """At least `x` of `a` and `b` `Tensors` are equal.""" + match = tf.equal(a, b) + match = tf.cast(match, tf.int32) + return tf.greater_equal(tf.reduce_sum(match), x) + + +def _decode_and_random_crop(image_bytes, image_size): + """Make a random crop of image_size.""" + bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) + image = distorted_bounding_box_crop( + image_bytes, + bbox, + min_object_covered=0.1, + aspect_ratio_range=(3. / 4, 4. / 3.), + area_range=(0.08, 1.0), + max_attempts=10, + scope=None) + original_shape = tf.image.extract_jpeg_shape(image_bytes) + bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3) + + image = tf.cond( + bad, + lambda: _decode_and_center_crop(image_bytes, image_size), + lambda: tf.image.resize_bicubic([image], # pylint: disable=g-long-lambda + [image_size, image_size])[0]) + + return image + + +def _decode_and_center_crop(image_bytes, image_size): + """Crops to center of image with padding then scales image_size.""" + shape = tf.image.extract_jpeg_shape(image_bytes) + image_height = shape[0] + image_width = shape[1] + + padded_center_crop_size = tf.cast( + ((image_size / (image_size + CROP_PADDING)) * + tf.cast(tf.minimum(image_height, image_width), tf.float32)), + tf.int32) + + offset_height = ((image_height - padded_center_crop_size) + 1) // 2 + offset_width = ((image_width - padded_center_crop_size) + 1) // 2 + crop_window = tf.stack([offset_height, offset_width, + padded_center_crop_size, padded_center_crop_size]) + image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) + image = tf.image.resize_bicubic([image], [image_size, image_size])[0] + + return image + + +def _flip(image): + """Random horizontal image flip.""" + image = tf.image.random_flip_left_right(image) + return image + + +def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE): + """Preprocesses the given image for evaluation. + + Args: + image_bytes: `Tensor` representing an image binary of arbitrary size. + use_bfloat16: `bool` for whether to use bfloat16. + image_size: image size. + + Returns: + A preprocessed image `Tensor`. + """ + image = _decode_and_random_crop(image_bytes, image_size) + image = _flip(image) + image = tf.reshape(image, [image_size, image_size, 3]) + image = tf.image.convert_image_dtype( + image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32) + return image + + +def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE): + """Preprocesses the given image for evaluation. + + Args: + image_bytes: `Tensor` representing an image binary of arbitrary size. + use_bfloat16: `bool` for whether to use bfloat16. + image_size: image size. + + Returns: + A preprocessed image `Tensor`. + """ + image = _decode_and_center_crop(image_bytes, image_size) + image = tf.reshape(image, [image_size, image_size, 3]) + image = tf.image.convert_image_dtype( + image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32) + return image + + +def preprocess_image(image_bytes, + is_training=False, + use_bfloat16=False, + image_size=IMAGE_SIZE): + """Preprocesses the given image. + + Args: + image_bytes: `Tensor` representing an image binary of arbitrary size. + is_training: `bool` for whether the preprocessing is for training. + use_bfloat16: `bool` for whether to use bfloat16. + image_size: image size. + + Returns: + A preprocessed image `Tensor`. + """ + if is_training: + return preprocess_for_train(image_bytes, use_bfloat16, image_size) + else: + return preprocess_for_eval(image_bytes, use_bfloat16, image_size) diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py index 0228bff6fab9b9704bfea0836f06a6ec0ff7839e..1f2cb14972f0b92d29489adff8f94e790e1ec4ed 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/revnet.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py @@ -24,10 +24,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import functools -import operator - -import six import tensorflow as tf from tensorflow.contrib.eager.python.examples.revnet import blocks @@ -45,66 +41,10 @@ class RevNet(tf.keras.Model): self.axis = 1 if config.data_format == "channels_first" else 3 self.config = config - self._init_block = self._construct_init_block() + self._init_block = blocks.InitBlock(config=self.config) + self._final_block = blocks.FinalBlock(config=self.config) self._block_list = self._construct_intermediate_blocks() - self._final_block = self._construct_final_block() - self._moving_stats_vars = None - - def _construct_init_block(self): - init_block = tf.keras.Sequential( - [ - tf.keras.layers.Conv2D( - filters=self.config.init_filters, - kernel_size=self.config.init_kernel, - strides=(self.config.init_stride, self.config.init_stride), - data_format=self.config.data_format, - use_bias=False, - padding="SAME", - input_shape=self.config.input_shape), - tf.keras.layers.BatchNormalization( - axis=self.axis, fused=self.config.fused), - tf.keras.layers.Activation("relu"), - ], - name="init") - if self.config.init_max_pool: - init_block.add( - tf.keras.layers.MaxPooling2D( - pool_size=(3, 3), - strides=(2, 2), - padding="SAME", - data_format=self.config.data_format)) - return init_block - - def _construct_final_block(self): - f = self.config.filters[-1] # Number of filters - r = functools.reduce(operator.mul, self.config.strides, 1) # Reduce ratio - r *= self.config.init_stride - if self.config.init_max_pool: - r *= 2 - - if self.config.data_format == "channels_first": - w, h = self.config.input_shape[1], self.config.input_shape[2] - input_shape = (f, w // r, h // r) - elif self.config.data_format == "channels_last": - w, h = self.config.input_shape[0], self.config.input_shape[1] - input_shape = (w // r, h // r, f) - else: - raise ValueError("Data format should be either `channels_first`" - " or `channels_last`") - - final_block = tf.keras.Sequential( - [ - tf.keras.layers.BatchNormalization( - axis=self.axis, - input_shape=input_shape, - fused=self.config.fused), - tf.keras.layers.Activation("relu"), - tf.keras.layers.GlobalAveragePooling2D( - data_format=self.config.data_format), - tf.keras.layers.Dense(self.config.n_classes) - ], - name="final") - return final_block + self._moving_average_variables = [] def _construct_intermediate_blocks(self): # Precompute input shape after initial block @@ -139,7 +79,8 @@ class RevNet(tf.keras.Model): batch_norm_first=(i != 0), # Only skip on first block data_format=self.config.data_format, bottleneck=self.config.bottleneck, - fused=self.config.fused) + fused=self.config.fused, + dtype=self.config.dtype) block_list.append(rev_block) # Precompute input shape for the next block @@ -174,97 +115,103 @@ class RevNet(tf.keras.Model): def compute_loss(self, logits, labels): """Compute cross entropy loss.""" - cross_ent = tf.nn.sparse_softmax_cross_entropy_with_logits( - logits=logits, labels=labels) + if self.config.dtype == tf.float32 or self.config.dtype == tf.float16: + cross_ent = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=labels) + else: + # `sparse_softmax_cross_entropy_with_logits` does not have a GPU kernel + # for float64, int32 pairs + labels = tf.one_hot( + labels, depth=self.config.n_classes, axis=1, dtype=self.config.dtype) + cross_ent = tf.nn.softmax_cross_entropy_with_logits( + logits=logits, labels=labels) return tf.reduce_mean(cross_ent) - def compute_gradients(self, inputs, labels, training=True): + def compute_gradients(self, saved_hidden, labels, training=True, l2_reg=True): """Manually computes gradients. - This method also SILENTLY updates the running averages of batch - normalization when `training` is set to True. + This method silently updates the running averages of batch normalization. Args: - inputs: Image tensor, either NHWC or NCHW, conforming to `data_format` + saved_hidden: List of hidden states Tensors labels: One-hot labels for classification training: Use the mini-batch stats in batch norm if set to True + l2_reg: Apply l2 regularization Returns: - list of tuples each being (grad, var) for optimizer to use + A tuple with the first entry being a list of all gradients and the second + being the loss """ - # Run forward pass to record hidden states; avoid updating running averages - vars_and_vals = self.get_moving_stats() - _, saved_hidden = self.call(inputs, training=training) - self.restore_moving_stats(vars_and_vals) - - grads_all = [] - vars_all = [] + def _defunable_pop(l): + """Functional style list pop that works with `tfe.defun`.""" + t, l = l[-1], l[:-1] + return t, l - # Manually backprop through last block + # Backprop through last block x = saved_hidden[-1] with tf.GradientTape() as tape: - x = tf.identity(x) # TODO(lxuechen): Remove after b/110264016 is fixed tape.watch(x) - # Running stats updated below logits = self._final_block(x, training=training) loss = self.compute_loss(logits, labels) - grads_combined = tape.gradient(loss, [x] + self._final_block.trainable_variables) - dy, grads_ = grads_combined[0], grads_combined[1:] - grads_all += grads_ - vars_all += self._final_block.trainable_variables + dy, final_grads = grads_combined[0], grads_combined[1:] - # Manually backprop through intermediate blocks + # Backprop through intermediate blocks + intermediate_grads = [] for block in reversed(self._block_list): - y = saved_hidden.pop() + y, saved_hidden = _defunable_pop(saved_hidden) x = saved_hidden[-1] - dy, grads, vars_ = block.backward_grads_and_vars( - x, y, dy, training=training) - grads_all += grads - vars_all += vars_ - - # Manually backprop through first block - saved_hidden.pop() - x = saved_hidden.pop() - assert not saved_hidden # Cleared after backprop + dy, grads = block.backward_grads(x, y, dy, training=training) + intermediate_grads = grads + intermediate_grads + # Backprop through first block + _, saved_hidden = _defunable_pop(saved_hidden) + x, saved_hidden = _defunable_pop(saved_hidden) + assert not saved_hidden with tf.GradientTape() as tape: - x = tf.identity(x) # TODO(lxuechen): Remove after b/110264016 is fixed - # Running stats updated below y = self._init_block(x, training=training) + init_grads = tape.gradient( + y, self._init_block.trainable_variables, output_gradients=dy) - grads_all += tape.gradient( - y, self._init_block.trainable_variables, output_gradients=[dy]) - vars_all += self._init_block.trainable_variables + # Ordering match up with `model.trainable_variables` + grads_all = init_grads + final_grads + intermediate_grads + if l2_reg: + grads_all = self._apply_weight_decay(grads_all) - # Apply weight decay - grads_all = self._apply_weight_decay(grads_all, vars_all) + return grads_all, loss - return grads_all, vars_all, loss - - def _apply_weight_decay(self, grads, vars_): + def _apply_weight_decay(self, grads): """Update gradients to reflect weight decay.""" - # Don't decay bias return [ g + self.config.weight_decay * v if v.name.endswith("kernel:0") else g - for g, v in zip(grads, vars_) + for g, v in zip(grads, self.trainable_variables) ] def get_moving_stats(self): - vars_and_vals = {} - - def _is_moving_var(v): + """Get moving averages of batch normalization.""" + device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0" + with tf.device(device): + return [v.read_value() for v in self.moving_average_variables] + + def restore_moving_stats(self, values): + """Restore moving averages of batch normalization.""" + device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0" + with tf.device(device): + for var_, val in zip(self.moving_average_variables, values): + var_.assign(val) + + @property + def moving_average_variables(self): + """Get all variables that are batch norm moving averages.""" + + def _is_moving_avg(v): n = v.name return n.endswith("moving_mean:0") or n.endswith("moving_variance:0") - for v in filter(_is_moving_var, self.variables): - vars_and_vals[v] = v.read_value() - - return vars_and_vals + if not self._moving_average_variables: + self._moving_average_variables = filter(_is_moving_avg, self.variables) - def restore_moving_stats(self, vars_and_vals): - for var_, val in six.iteritems(vars_and_vals): - var_.assign(val) + return self._moving_average_variables diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py index a5f240436a51f1f07669e06017761f003bfd9395..84b2ddf0de0739936d458ae1bce832cfbb167d64 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py @@ -22,6 +22,7 @@ import gc import time import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import blocks_test from tensorflow.contrib.eager.python.examples.revnet import config as config_ from tensorflow.contrib.eager.python.examples.revnet import revnet from tensorflow.python.client import device_lib @@ -30,26 +31,33 @@ tfe = tf.contrib.eager def train_one_iter(model, inputs, labels, optimizer, global_step=None): """Train for one iteration.""" - grads, vars_, loss = model.compute_gradients(inputs, labels, training=True) - optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) + logits, saved_hidden = model(inputs) + grads, loss = model.compute_gradients( + saved_hidden=saved_hidden, labels=labels) + optimizer.apply_gradients( + zip(grads, model.trainable_variables), global_step=global_step) - return loss + return logits, loss class RevNetTest(tf.test.TestCase): def setUp(self): super(RevNetTest, self).setUp() - tf.set_random_seed(1) - config = config_.get_hparams_imagenet_56() + config = config_.get_hparams_cifar_38() + config.add_hparam("n_classes", 10) + config.add_hparam("dataset", "cifar-10") + # Reconstruction could cause numerical error, use double precision for tests + config.dtype = tf.float64 + config.fused = False # Fused batch norm does not support tf.float64 shape = (config.batch_size,) + config.input_shape self.model = revnet.RevNet(config=config) - self.x = tf.random_normal(shape=shape) + self.x = tf.random_normal(shape=shape, dtype=tf.float64) self.t = tf.random_uniform( shape=[config.batch_size], minval=0, maxval=config.n_classes, - dtype=tf.int32) + dtype=tf.int64) self.config = config def tearDown(self): @@ -65,28 +73,61 @@ class RevNetTest(tf.test.TestCase): y, _ = self.model(self.x, training=False) self.assertEqual(y.shape, [self.config.batch_size, self.config.n_classes]) + def _check_grad_angle_combined(self, grads, grads_true): + """Verify that the reconstructed gradients has correct direction. + + Due to numerical imprecision, the magnitude may be slightly different. + Yet according to the paper, the angle should be roughly the same. + + Args: + grads: list of gradients from reconstruction + grads_true: list of true gradients + """ + + def _combine(gs): + return [tf.reshape(g, [-1]) for g in gs] + + g1_all = tf.concat(_combine(grads), axis=0) + g2_all = tf.concat(_combine(grads_true), axis=0) + + self.assertEqual(len(g1_all.shape), 1) + self.assertEqual(len(g2_all.shape), 1) + + degree = blocks_test.compute_degree(g1_all, g2_all) + self.assertLessEqual(degree, 1e0) + def test_compute_gradients(self): """Test `compute_gradients` function.""" - - grads, vars_, _ = self.model.compute_gradients( - inputs=self.x, labels=self.t, training=True) + _, saved_hidden = self.model(self.x) # Initialize model + grads, loss = self.model.compute_gradients( + saved_hidden=saved_hidden, labels=self.t) + vars_ = self.model.trainable_variables self.assertTrue(isinstance(grads, list)) self.assertTrue(isinstance(vars_, list)) self.assertEqual(len(grads), len(vars_)) for grad, var in zip(grads, vars_): - if grad is not None: - self.assertEqual(grad.shape, var.shape) + self.assertEqual(grad.shape, var.shape) + + # Compare against the true gradient computed by the tape + with tf.GradientTape() as tape: + logits, _ = self.model(self.x) + loss_true = self.model.compute_loss(logits=logits, labels=self.t) + grads_true = tape.gradient(loss_true, vars_) + self.assertAllClose(loss, loss_true) + self.assertAllClose(grads, grads_true, rtol=1e-4, atol=1e-4) + self._check_grad_angle_combined(grads, grads_true) def test_call_defun(self): """Test `call` function with defun.""" - y, _ = tfe.defun(self.model.call)(self.x, training=False) self.assertEqual(y.shape, [self.config.batch_size, self.config.n_classes]) def test_compute_gradients_defun(self): """Test `compute_gradients` function with defun.""" compute_gradients = tfe.defun(self.model.compute_gradients) - grads, vars_, _ = compute_gradients(self.x, self.t, training=True) + _, saved_hidden = self.model(self.x) + grads, _ = compute_gradients(saved_hidden=saved_hidden, labels=self.t) + vars_ = self.model.trainable_variables self.assertTrue(isinstance(grads, list)) self.assertTrue(isinstance(vars_, list)) self.assertEqual(len(grads), len(vars_)) @@ -96,8 +137,11 @@ class RevNetTest(tf.test.TestCase): def test_training_graph(self): """Test model training in graph mode.""" - with tf.Graph().as_default(): + config = config_.get_hparams_cifar_38() + config.add_hparam("n_classes", 10) + config.add_hparam("dataset", "cifar-10") + x = tf.random_normal( shape=(self.config.batch_size,) + self.config.input_shape) t = tf.random_uniform( @@ -105,15 +149,13 @@ class RevNetTest(tf.test.TestCase): minval=0, maxval=self.config.n_classes, dtype=tf.int32) - global_step = tfe.Variable(0., trainable=False) - model = revnet.RevNet(config=self.config) - grads_all, vars_all, _ = model.compute_gradients(x, t, training=True) + global_step = tf.Variable(0., trainable=False) + model = revnet.RevNet(config=config) + _, saved_hidden = model(x) + grads, _ = model.compute_gradients(saved_hidden=saved_hidden, labels=t) optimizer = tf.train.AdamOptimizer(learning_rate=1e-3) - updates = model.get_updates_for(x) - self.assertEqual(len(updates), 192) - with tf.control_dependencies(model.get_updates_for(x)): - train_op = optimizer.apply_gradients( - zip(grads_all, vars_all), global_step=global_step) + train_op = optimizer.apply_gradients( + zip(grads, model.trainable_variables), global_step=global_step) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index c2340a293a80924f2dfa90e2fb23134b0f1feb6b..15776c694e92825895437a4c1547699f6d9269fb 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -310,12 +310,12 @@ def main(_): with tf.device("/device:GPU:0" if have_gpu else None): # Make learning_rate a Variable so it can be included in the checkpoint # and we can resume training with the last saved learning_rate. - learning_rate = tfe.Variable(20.0, name="learning_rate") + learning_rate = tf.Variable(20.0, name="learning_rate") model = PTBModel(corpus.vocab_size(), FLAGS.embedding_dim, FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout, use_cudnn_rnn) optimizer = tf.train.GradientDescentOptimizer(learning_rate) - checkpoint = tfe.Checkpoint( + checkpoint = tf.train.Checkpoint( learning_rate=learning_rate, model=model, # GradientDescentOptimizer has no state to checkpoint, but noting it # here lets us swap in an optimizer that does. diff --git a/tensorflow/contrib/eager/python/examples/sagan/BUILD b/tensorflow/contrib/eager/python/examples/sagan/BUILD deleted file mode 100644 index b470a41d815ce650731680065cc7341f844e3fdc..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/sagan/BUILD +++ /dev/null @@ -1,59 +0,0 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) - -load("//tensorflow:tensorflow.bzl", "cuda_py_test") - -# Model -py_library( - name = "config", - srcs = ["config.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow:tensorflow_py", - ], -) - -py_library( - name = "ops", - srcs = ["ops.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow:tensorflow_py", - ], -) - -py_library( - name = "sagan", - srcs = ["sagan.py"], - srcs_version = "PY2AND3", - deps = [ - ":ops", - "//tensorflow:tensorflow_py", - ], -) - -# Tests -cuda_py_test( - name = "ops_test", - size = "small", - srcs = ["ops_test.py"], - additional_deps = [ - ":ops", - "//tensorflow:tensorflow_py", - ], -) - -cuda_py_test( - name = "sagan_test", - size = "large", - srcs = ["sagan_test.py"], - additional_deps = [ - ":config", - ":sagan", - "//tensorflow:tensorflow_py", - ], - tags = [ - "optonly", - ], -) diff --git a/tensorflow/contrib/eager/python/examples/sagan/config.py b/tensorflow/contrib/eager/python/examples/sagan/config.py deleted file mode 100644 index 1967bbd867447d9deaf9a7cb3b22a38889276a50..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/sagan/config.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Self-attention generative adversarial with eager execution. - -Configuration in format of tf.contrib.training.HParams. -Supports default 128x128 ImageNet. - -Reference [Self-Attention Generative Adversarial -Networks](https://arxiv.org/pdf/1805.08318.pdf) - -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf -tfe = tf.contrib.eager - - -def get_hparams_imagenet(): - """Configurations to train SAGAN on 128x128 ImageNet dataset.""" - config = tf.contrib.training.HParams() - if tf.test.is_gpu_available(): - config.add_hparam("image_shape", (3, 128, 128)) - config.add_hparam("data_format", "channels_first") - config.add_hparam("g_init_shape", (512, 4, 4)) - else: - config.add_hparam("image_shape", (128, 128, 3)) - config.add_hparam("data_format", "channels_first") - config.add_hparam("g_init_shape", (4, 4, 512)) - - config.add_hparam("latent_dim", 128) - config.add_hparam("update_g_once_every", 1) - config.add_hparam("batch_size", 64) - config.add_hparam("d_init_filters", 32) - config.add_hparam("num_upsamples", 5) - # (512, 4, 4) -> (3, 128, 128) - return config - - -def get_hparams_mock(): - """Configurations of smaller networks for testing.""" - config = tf.contrib.training.HParams() - if tf.test.is_gpu_available(): - config.add_hparam("image_shape", (3, 16, 16)) - config.add_hparam("data_format", "channels_first") - config.add_hparam("g_init_shape", (32, 2, 2)) - else: - config.add_hparam("image_shape", (16, 16, 3)) - config.add_hparam("data_format", "channels_last") - config.add_hparam("g_init_shape", (2, 2, 32)) - - config.add_hparam("latent_dim", 16) - config.add_hparam("update_g_once_every", 1) - config.add_hparam("batch_size", 2) - config.add_hparam("d_init_filters", 4) - config.add_hparam("num_upsamples", 3) - # (32, 2, 2) -> (3, 16, 16) - return config diff --git a/tensorflow/contrib/eager/python/examples/sagan/ops.py b/tensorflow/contrib/eager/python/examples/sagan/ops.py deleted file mode 100644 index 9a03cab1d12fc16baa7343f72ac58ccd39f698bc..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/sagan/ops.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Self-attention generative adversarial with eager execution. - -Auxiliary operations. - -Reference [Self-Attention Generative Adversarial -Networks](https://arxiv.org/pdf/1805.08318.pdf) -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - - -def flatten_hw(x, data_format="channels_first"): - """Flatten the input tensor across height and width dimensions.""" - if data_format == "channels_last": - x = tf.transpose(x, perm=[0, 3, 1, 2]) # Convert to `channels_first` - - old_shape = tf.shape(x) - new_shape = [old_shape[0], old_shape[2] * old_shape[3], old_shape[1]] - - return tf.reshape(x, new_shape) - - -def broaden_hw(x, h, w, c, data_format="channels_first"): - """Broaden dimension so that output has height and width.""" - if data_format == "channels_first": - shape = [-1, c, h, w] - else: - shape = [-1, h, w, c] - - return tf.reshape(x, shape) - - -class BroadenHW(tf.keras.layers.Layer): - """Wrapper class so that `broaden_hw` can be used in `tf.keras.Sequential`.""" - - def __init__(self, h, w, c, data_format="channels_first"): - super(BroadenHW, self).__init__() - self.h = h - self.w = w - self.c = c - self.data_format = data_format - - def call(self, x): - return broaden_hw( - x, h=self.h, w=self.w, c=self.c, data_format=self.data_format) - - def compute_output_shape(self, input_shape): - input_shape = tf.TensorShape(input_shape).as_list() - if self.data_format == "channels_first": - output_shape = (input_shape[0], self.c, self.h, self.w) - else: - output_shape = (input_shape[0], self.h, self.w, self.c) - - return tf.TensorShape(output_shape) diff --git a/tensorflow/contrib/eager/python/examples/sagan/ops_test.py b/tensorflow/contrib/eager/python/examples/sagan/ops_test.py deleted file mode 100644 index 3454985904215b59d27fc4b76ccb4a8c2c2eff00..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/sagan/ops_test.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for auxiliary operations.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf -from tensorflow.contrib.eager.python.examples.sagan import ops - - -class OpsTest(tf.test.TestCase): - - def test_flatten_hw(self): - """Test `flatten_hw` function with mock object.""" - - batch_size = 1 - # Default NCHW format - if tf.test.is_gpu_available(): - x = tf.random_normal(shape=(batch_size, 3, 4, 4)) - y = ops.flatten_hw(x, data_format="channels_first") - self.assertEqual(y.shape, (batch_size, 4 * 4, 3)) - - # NHWC format - x = tf.random_normal(shape=(batch_size, 4, 4, 3)) - y = ops.flatten_hw(x, data_format="channels_last") - self.assertEqual(y.shape, (batch_size, 4 * 4, 3)) - - def test_broaden_hw(self): - """Test `broaden_hw` function with mock object.""" - - batch_size = 1 - # NHWC format - x = tf.random_normal(shape=[batch_size, 4 * 4 * 16]) - y = ops.broaden_hw(x, h=4, w=4, c=16, data_format="channels_last") - self.assertEqual(y.shape, (batch_size, 4, 4, 16)) - - # Default NCHW format - if tf.test.is_gpu_available(): - y = ops.broaden_hw(x, h=4, w=4, c=16, data_format="channels_first") - self.assertEqual(y.shape, (batch_size, 16, 4, 4)) - - -if __name__ == "__main__": - tf.enable_eager_execution() - tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan.py b/tensorflow/contrib/eager/python/examples/sagan/sagan.py deleted file mode 100644 index 561be36c911d7145e2d4a5ed12eccd8ceb054f45..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/sagan/sagan.py +++ /dev/null @@ -1,232 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Self-attention generative adversarial with eager execution. - -Code for main model. - -Reference [Self-Attention Generative Adversarial -Networks](https://arxiv.org/pdf/1805.08318.pdf) -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import tensorflow as tf -from tensorflow.contrib.eager.python.examples.sagan import ops -tfe = tf.contrib.eager - - -class SelfAttentionModule(tf.keras.Model): - """Self-attention module composed of convolutional layers.""" - - def __init__(self, - attention_features, - original_features, - data_format="channels_first"): - """Initialize the module. - - Args: - attention_features: Number of filters for the attention computation. - original_features: Number of filters of the original Tensor. - data_format: Either 'channels_first' or 'channels_last' - """ - super(SelfAttentionModule, self).__init__() - self.data_format = data_format - # Matrix multiplication implemented as 2D Convolution - self.f = tf.keras.layers.Conv2D( - filters=attention_features, - kernel_size=1, - strides=(1, 1), - data_format=data_format) - self.g = tf.keras.layers.Conv2D( - filters=attention_features, - kernel_size=1, - strides=(1, 1), - data_format=data_format) - self.h = tf.keras.layers.Conv2D( - filters=original_features, - kernel_size=1, - strides=(1, 1), - data_format=data_format) - self.scale = tfe.Variable(0., trainable=True) - - def call(self, x): - f = self.f(x) - g = self.g(x) - h = self.h(x) - - f_flatten = ops.flatten_hw(f, data_format=self.data_format) - g_flatten = ops.flatten_hw(g, data_format=self.data_format) - h_flatten = ops.flatten_hw(h, data_format=self.data_format) - - s = tf.matmul(g_flatten, f_flatten, transpose_b=True) - b = tf.nn.softmax(s, axis=-1) - o = tf.matmul(b, h_flatten) - y = self.scale * tf.reshape(o, tf.shape(x)) + x - - return y - - def compute_output_shape(self, input_shape): - return input_shape - - -class SAGAN(tf.contrib.checkpoint.Checkpointable): - """Self-attention generative adversarial network.""" - - def __init__(self, config): - """Initialize the model. - - Args: - config: tf.contrib.training.HParams object; specifies hyperparameters - """ - super(SAGAN, self).__init__() - self.config = config - self.generator = self._construct_generator() - self.discriminator = self._construct_discriminator() - - def _construct_generator(self): - """Construct generator.""" - # TODO(lxuechen): Add spectral normalization for WGAN - axis = 1 if self.config.data_format == "channels_first" else 3 - - generator = tf.keras.Sequential() - generator.add( - tf.keras.layers.InputLayer(input_shape=(self.config.latent_dim,))) - generator.add( - tf.keras.layers.Dense( - units=np.prod(self.config.g_init_shape), activation=tf.nn.relu)) - - if self.config.data_format == "channels_first": - c, h, w = self.config.g_init_shape - else: - h, w, c = self.config.g_init_shape - - # Reshape to NHWC/NCHW - generator.add( - ops.BroadenHW(h=h, w=w, c=c, data_format=self.config.data_format)) - - filters_list = [c // 2**p for p in range(1, self.config.num_upsamples + 1)] - filters_list[-1] = 3 # Standard RGB images - - for filters in filters_list[:len(filters_list) // 2]: - generator.add( - tf.keras.layers.Conv2DTranspose( - filters=filters, - kernel_size=4, - strides=(2, 2), - use_bias=False, - padding="SAME", - data_format=self.config.data_format)) - generator.add(tf.keras.layers.BatchNormalization(axis=axis)) - generator.add(tf.keras.layers.Activation("relu")) - - # pylint: disable=undefined-loop-variable - generator.add( - SelfAttentionModule( - original_features=filters, - attention_features=filters // 8, - data_format=self.config.data_format)) - # pylint: enable=undefined-loop-variable - - for filters in filters_list[len(filters_list) // 2:]: - generator.add( - tf.keras.layers.Conv2DTranspose( - filters=filters, - kernel_size=4, - strides=(2, 2), - use_bias=False, - padding="SAME", - data_format=self.config.data_format)) - if filters == 3: - # Assume Image rescaled to [-1, 1] - generator.add(tf.keras.layers.Activation("tanh")) - else: - generator.add(tf.keras.layers.BatchNormalization(axis=axis)) - generator.add(tf.keras.layers.Activation("relu")) - - return generator - - def _construct_discriminator(self): - """Construct discriminator.""" - # TODO(lxuechen): Add spectral normalization for WGAN - discriminator = tf.keras.Sequential() - discriminator.add( - tf.keras.layers.InputLayer(input_shape=self.config.image_shape)) - - filters_list = [ - self.config.d_init_filters * 2**p - for p in range(self.config.num_upsamples) - ] - - for filters in filters_list[:(len(filters_list) + 1) // 2]: - discriminator.add( - tf.keras.layers.Conv2D( - filters=filters, - kernel_size=4, - strides=(2, 2), - padding="SAME", - data_format=self.config.data_format)) - discriminator.add(tf.keras.layers.LeakyReLU(alpha=.1)) - - # pylint: disable=undefined-loop-variable - discriminator.add( - SelfAttentionModule( - original_features=filters, - attention_features=filters // 8, - data_format=self.config.data_format)) - # pylint: enable=undefined-loop-variable - - for filters in filters_list[(len(filters_list) + 1) // 2:]: - discriminator.add( - tf.keras.layers.Conv2D( - filters=filters, - kernel_size=4, - strides=(2, 2), - padding="SAME", - data_format=self.config.data_format)) - discriminator.add(tf.keras.layers.LeakyReLU(alpha=.1)) - - discriminator.add(tf.keras.layers.Flatten()) - discriminator.add(tf.keras.layers.Dense(units=1)) - - return discriminator - - def compute_loss_and_grads(self, real_images, noise, training=True): - """Compute loss and gradients for both generator and discriminator.""" - # TODO(lxuechen): Add gradient penalty for discriminator - with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape: - real_logits = self.discriminator(real_images, training=training) - - fake_images = self.generator.call(noise, training=training) - fake_logits = self.discriminator.call(fake_images) - - g_loss = self.compute_g_loss(fake_logits) - d_loss = self.compute_d_loss(fake_logits, real_logits) - - g_grads = g_tape.gradient(g_loss, self.generator.trainable_variables) - d_grads = d_tape.gradient(d_loss, self.discriminator.trainable_variables) - - return g_loss, d_loss, g_grads, d_grads - - def compute_g_loss(self, fake_logits): - return -tf.reduce_mean(fake_logits) # Hinge loss - - def compute_d_loss(self, fake_logits, real_logits): - # Hinge loss - real_loss = tf.reduce_mean(tf.nn.relu(1. - real_logits)) - fake_loss = tf.reduce_mean(tf.nn.relu(1. + fake_logits)) - return real_loss + fake_loss diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py b/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py deleted file mode 100644 index 18345945108111b57c5401c26b7dca0bfc8f8316..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for self-attention generative adversarial network.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf -from tensorflow.contrib.eager.python.examples.sagan import config as config_ -from tensorflow.contrib.eager.python.examples.sagan import sagan -tfe = tf.contrib.eager - - -class SAGANTest(tf.test.TestCase): - - def setUp(self): - super(SAGANTest, self).setUp() - config = config_.get_hparams_mock() - self.noise_shape = (config.batch_size, config.latent_dim) - self.logits_shape = (config.batch_size, 1) - self.images_shape = (config.batch_size,) + config.image_shape - - self.model = sagan.SAGAN(config=config) - self.noise = tf.random_normal(shape=self.noise_shape) - self.real_images = tf.random_normal(shape=self.images_shape) - self.config = config - - def tearDown(self): - del self.model - del self.noise - del self.real_images - super(SAGANTest, self).tearDown() - - def test_generator_call(self): - """Test `generator.__call__` function.""" - fake_images = self.model.generator(self.noise, training=False) - self.assertEqual(fake_images.shape, self.images_shape) - - def test_generator_call_defun(self): - """Test `generator.__call__` function with defun.""" - call_ = tfe.defun(self.model.generator.__call__) - fake_images = call_(self.noise, training=False) - self.assertEqual(fake_images.shape, self.images_shape) - - def test_discriminator_call(self): - """Test `discriminator.__call__` function.""" - real_logits = self.model.discriminator(self.real_images) - self.assertEqual(real_logits.shape, self.logits_shape) - - def test_discriminator_call_defun(self): - """Test `discriminator.__call__` function with defun.""" - call_ = tfe.defun(self.model.discriminator.__call__) - real_logits = call_(self.real_images) - self.assertEqual(real_logits.shape, self.logits_shape) - - def test_compute_loss_and_grads(self): - """Test `compute_loss_and_grads` function.""" - g_loss, d_loss, g_grads, d_grads = self.model.compute_loss_and_grads( - self.real_images, self.noise, training=False) - self.assertEqual(g_loss.shape, ()) - self.assertEqual(d_loss.shape, ()) - self.assertTrue(isinstance(g_grads, list)) - self.assertTrue(isinstance(d_grads, list)) - g_vars = self.model.generator.trainable_variables - d_vars = self.model.discriminator.trainable_variables - - self.assertEqual(len(g_grads), len(g_vars)) - self.assertEqual(len(d_grads), len(d_vars)) - - def test_compute_loss_and_grads_defun(self): - """Test `compute_loss_and_grads` function with defun.""" - compute_loss_and_grads = tfe.defun(self.model.compute_loss_and_grads) - g_loss, d_loss, g_grads, d_grads = compute_loss_and_grads( - self.real_images, self.noise, training=False) - self.assertEqual(g_loss.shape, ()) - self.assertEqual(d_loss.shape, ()) - self.assertTrue(isinstance(g_grads, list)) - self.assertTrue(isinstance(d_grads, list)) - g_vars = self.model.generator.trainable_variables - d_vars = self.model.discriminator.trainable_variables - - self.assertEqual(len(g_grads), len(g_vars)) - self.assertEqual(len(d_grads), len(d_vars)) - - -if __name__ == "__main__": - tf.enable_eager_execution() - tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index 8ac553e0ae71382966d03d9ef4429adf5137b369..d18a097063c7d25947af3e2e2959ce574edd553f 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -36,7 +36,7 @@ from third_party.examples.eager.spinn import spinn from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import test from tensorflow.python.framework import test_util -from tensorflow.python.training import saver +from tensorflow.python.training import checkpoint_management from tensorflow.python.training.checkpointable import util as checkpointable_utils # pylint: enable=g-bad-import-order @@ -422,7 +422,7 @@ class SpinnTest(test_util.TensorFlowTestCase): # 5. Verify that checkpoints exist and contains all the expected variables. self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*"))) object_graph = checkpointable_utils.object_metadata( - saver.latest_checkpoint(config.logdir)) + checkpoint_management.latest_checkpoint(config.logdir)) ckpt_variable_names = set() for node in object_graph.nodes: for attribute in node.attributes: diff --git a/tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb b/tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..75cb3f8227fe90223734f422e458f15810b8089a --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb @@ -0,0 +1,282 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "TFE Workshop: control flow", + "version": "0.3.2", + "provenance": [], + "include_colab_link": true + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "[View in Colaboratory](https://colab.research.google.com/gist/alextp/664b2f8700485ff6801f4d26293bd567/tfe-workshop-control-flow.ipynb)" + ] + }, + { + "metadata": { + "id": "9BpQzh9BvJlj", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 37 + }, + "outputId": "0b336886-8204-4815-89fa-5291a49d5784" + }, + "cell_type": "code", + "source": [ + "import tensorflow as tf\n", + "import numpy as np\n", + "tf.enable_eager_execution()" + ], + "execution_count": 1, + "outputs": [] + }, + { + "metadata": { + "id": "0roIB19GvOjI", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Eager execution basics\n", + "\n", + "When eager execution is enabled TensorFlow immediately executes operations, and Tensors are always available. " + ] + }, + { + "metadata": { + "id": "jeO8F-V-vN24", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 68 + }, + "outputId": "aeb3bdec-50b7-440d-93d8-5a171f091081" + }, + "cell_type": "code", + "source": [ + "t = tf.constant([[1, 2], [3, 4]])\n", + "t" + ], + "execution_count": 2, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 2 + } + ] + }, + { + "metadata": { + "id": "Y17RwSFxvlDL", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 68 + }, + "outputId": "cfcc10c7-707b-4997-99b3-a5f382c5166b" + }, + "cell_type": "code", + "source": [ + "tf.matmul(t, t)" + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 3 + } + ] + }, + { + "metadata": { + "id": "Dab1bS3TvmRE", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "8a624f3d-a658-4359-c586-1c5f6bf4c8b7" + }, + "cell_type": "code", + "source": [ + "# It's also possible to have Python control flow which depends on the value of tensors.\n", + "if t[0, 0] > 0.5:\n", + " print(\"T is bigger\")\n", + "else:\n", + " print(\"T is smaller\")" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "text": [ + "T is bigger\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "dPgptJcGwIon", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "c4f27f2b-0848-4475-dde5-2534dac65a5c" + }, + "cell_type": "code", + "source": [ + "# Tensors are also usable as numpy arrays\n", + "np.prod(t)" + ], + "execution_count": 6, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "24" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 6 + } + ] + }, + { + "metadata": { + "id": "p3DTfQXnwXzj", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Exercise\n", + "\n", + "The algorithm for bisecting line search is a pretty simple way to find a zero of a continuous scalar function in an interval [a,b] where f(a) and f(b) have different signs. Simply evaluate f((a+b)/2), and narrow the interval by replacing either a or b with (a+b)/2 such that the function when applied on the boundary of the interval still has different signs.\n", + "\n", + "Implement a python function `bisecting_line_search(f, a, b, epsilon)` which returns a value such that `tf.abs(f(value)) < epsilon`.\n", + "\n", + "One thing to keep in mind: python's `==` opertor is not overloaded on Tensors, so you need to use `tf.equal` to compare for equality." + ] + }, + { + "metadata": { + "id": "6eq0YuI6ykm5", + "colab_type": "code", + "colab": {} + }, + "cell_type": "code", + "source": [ + "# Example test harness to get you going\n", + "\n", + "def test_f(x):\n", + " return x - 0.1234\n", + "def bisecting_line_search(f, a, b, epsilon):\n", + " # Return x such that f(x) <= epsilon.\n", + " pass\n", + "a = tf.constant(0.0)\n", + "b = tf.constant(1.0)\n", + "epsilon = tf.constant(0.001)\n", + "x = bisecting_line_search(test_f, a, b, epsilon)\n" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "LcMmEfd_xvej", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 170 + }, + "outputId": "f402aa50-8ce3-4416-f755-8bbcd1af7809" + }, + "cell_type": "code", + "source": [ + "#@title Double-click to see the solution\n", + "\n", + "def bisecting_line_search(f, a, b, epsilon):\n", + " f_a = f(a)\n", + " f_b = f(b)\n", + " probe = (a + b) / 2\n", + " f_probe = f(probe)\n", + " while tf.abs(f_probe) > epsilon:\n", + " if tf.equal(tf.sign(f_probe), tf.sign(f_a)):\n", + " a = probe\n", + " f_a = f_probe\n", + " else:\n", + " b = probe\n", + " f_b = f_probe\n", + " probe = (a + b) / 2\n", + " f_probe = f(probe)\n", + " print(\"new probe\", probe)\n", + " return probe\n", + "\n", + "bisecting_line_search(test_f, 0., 1., 0.001)" + ], + "execution_count": 8, + "outputs": [ + { + "output_type": "stream", + "text": [ + "('new probe', 0.25)\n", + "('new probe', 0.125)\n", + "('new probe', 0.0625)\n", + "('new probe', 0.09375)\n", + "('new probe', 0.109375)\n", + "('new probe', 0.1171875)\n", + "('new probe', 0.12109375)\n", + "('new probe', 0.123046875)\n" + ], + "name": "stdout" + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "0.123046875" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 8 + } + ] + } + ] +} diff --git a/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb b/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..f3a65f5aab1fe683565caf21dcfa8054045fd759 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb @@ -0,0 +1,1018 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "TFE Workshop: Models.ipynb", + "version": "0.3.2", + "provenance": [], + "collapsed_sections": [], + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "[View in Colaboratory](https://colab.research.google.com/gist/alextp/5cfcffd408bd5103f5ae747bc97ab0b5/tfe-workshop-models.ipynb)" + ] + }, + { + "metadata": { + "id": "BMxv1O6Q0SJL", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "outputId": "8be9c556-ac7f-4142-e35e-19dc2b097121" + }, + "cell_type": "code", + "source": [ + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "tfe = tf.contrib.eager" + ], + "execution_count": 1, + "outputs": [] + }, + { + "metadata": { + "id": "lE1vJhxp0WR9", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Variables\n", + "\n", + "TensorFlow variables are useful to store the state in your program. They are integrated with other parts of the API (taking gradients, checkpointing, graph functions)." + ] + }, + { + "metadata": { + "id": "C4ztQNgc0VpW", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "8b63ae1f-2670-49c0-a31b-8cf7fc4194a1" + }, + "cell_type": "code", + "source": [ + "# Creating variables\n", + "v = tf.Variable(1.0)\n", + "v" + ], + "execution_count": 2, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 2 + } + ] + }, + { + "metadata": { + "id": "H0daItGg1IAp", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "e47d5aab-16a1-4e29-c27d-7fbc0b94b5d3" + }, + "cell_type": "code", + "source": [ + "v.assign_add(1.0)\n", + "v" + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 3 + } + ] + }, + { + "metadata": { + "id": "BJvBzcIG1hyK", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Layers: common sets of useful operations\n", + "\n", + "Most of the time when writing code for machine learning models you want to operate at a higher level of abstraction than individual operations and manipulation of individual variables.\n", + "\n", + "Many machine learning models are expressible as the composition and stacking of relatively simple layers, and TensorFlow provides both a set of many common layers as a well as easy ways for you to write your own application-specific layers either from scratch or as the composition of existing layers.\n", + "\n", + "TensorFlow includes the full [Keras](https://keras.io) API in the tf.keras package, and the Keras layers are very useful when building your own models.\n" + ] + }, + { + "metadata": { + "id": "iSQTS3QW1YQQ", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "outputId": "c5d8aa10-dcad-44f7-f0eb-0faf5249fd7e" + }, + "cell_type": "code", + "source": [ + "# In the tf.keras.layers package, layers are objects. To construct a layer,\n", + "# simply construct the object. Most layers take as a first argument the number\n", + "# of output dimensions / channels.\n", + "layer = tf.keras.layers.Dense(100)\n", + "\n", + "# The number of input dimensions is often unnecessary, as it can be inferred\n", + "# the first time the layer is used, but it can be provided if you want to \n", + "# specify it manually, which is useful in some complex models.\n", + "layer = tf.keras.layers.Dense(10, input_shape=(None, 5))\n" + ], + "execution_count": 4, + "outputs": [] + }, + { + "metadata": { + "id": "nRuUogoS1liV", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 68 + }, + "outputId": "c352ce79-d519-45e4-a12e-1eaba76871a2" + }, + "cell_type": "code", + "source": [ + "layer(tf.zeros([2, 2]))" + ], + "execution_count": 5, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 5 + } + ] + }, + { + "metadata": { + "id": "JH4Kf4ka1mht", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 136 + }, + "outputId": "c34e2378-f83d-42c5-d30a-ebe55620368a" + }, + "cell_type": "code", + "source": [ + "layer.variables" + ], + "execution_count": 6, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[,\n", + " ]" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 6 + } + ] + }, + { + "metadata": { + "id": "DSI4NF0_1vn-", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "The full list of pre-existing layers can be seen in [the documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers). It includes Dense (a fully-connected layer),\n", + "Conv2D, LSTM, BatchNormalization, Dropout, and many others." + ] + }, + { + "metadata": { + "id": "hMgDBftJ12Bp", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Models: composing layers\n", + "\n", + "Many interesting layer-like things in machine learning models are implemented by composing existing layers. For example, each residual block in a resnet is a composition of convolutions, batch normalizations, and a shortcut.\n", + "\n", + "The main class used when creating a layer-like thing which contains other layers is tf.keras.Model. Implementing one is done by inheriting from tf.keras.Model.\n" + ] + }, + { + "metadata": { + "id": "K3gVY6gj1nbe", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 190 + }, + "outputId": "6e9be0c4-960e-46c2-cdd9-7e94ad09d46b" + }, + "cell_type": "code", + "source": [ + "class ResnetIdentityBlock(tf.keras.Model):\n", + " def __init__(self, kernel_size, filters):\n", + " super(ResnetIdentityBlock, self).__init__(name='')\n", + " filters1, filters2, filters3 = filters\n", + "\n", + " self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))\n", + " self.bn2a = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')\n", + " self.bn2b = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))\n", + " self.bn2c = tf.keras.layers.BatchNormalization()\n", + "\n", + " def call(self, input_tensor, training=False):\n", + " x = self.conv2a(input_tensor)\n", + " x = self.bn2a(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2b(x)\n", + " x = self.bn2b(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2c(x)\n", + " x = self.bn2c(x, training=training)\n", + "\n", + " x += input_tensor\n", + " return tf.nn.relu(x)\n", + " \n", + "block = ResnetIdentityBlock(1, [1, 2, 3])\n", + "print(block(tf.zeros([1, 2, 3, 3])))\n", + "print([x.name for x in block.variables])" + ], + "execution_count": 7, + "outputs": [ + { + "output_type": "stream", + "text": [ + "tf.Tensor(\n", + "[[[[0. 0. 0.]\n", + " [0. 0. 0.]\n", + " [0. 0. 0.]]\n", + "\n", + " [[0. 0. 0.]\n", + " [0. 0. 0.]\n", + " [0. 0. 0.]]]], shape=(1, 2, 3, 3), dtype=float32)\n", + "['resnet_identity_block/conv2d/kernel:0', 'resnet_identity_block/conv2d/bias:0', 'resnet_identity_block/batch_normalization/gamma:0', 'resnet_identity_block/batch_normalization/beta:0', 'resnet_identity_block/conv2d_1/kernel:0', 'resnet_identity_block/conv2d_1/bias:0', 'resnet_identity_block/batch_normalization_1/gamma:0', 'resnet_identity_block/batch_normalization_1/beta:0', 'resnet_identity_block/conv2d_2/kernel:0', 'resnet_identity_block/conv2d_2/bias:0', 'resnet_identity_block/batch_normalization_2/gamma:0', 'resnet_identity_block/batch_normalization_2/beta:0', 'resnet_identity_block/batch_normalization/moving_mean:0', 'resnet_identity_block/batch_normalization/moving_variance:0', 'resnet_identity_block/batch_normalization_1/moving_mean:0', 'resnet_identity_block/batch_normalization_1/moving_variance:0', 'resnet_identity_block/batch_normalization_2/moving_mean:0', 'resnet_identity_block/batch_normalization_2/moving_variance:0']\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "LPXhHUIc1-sO", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Much of the time, however, models which compose many layers simply call one layer after the other. This can be done in very little code using tf.keras.Sequential" + ] + }, + { + "metadata": { + "id": "5pXgzNAU17xk", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 173 + }, + "outputId": "03b7eaf8-9b35-482b-bcf0-a99af6c2c6a4" + }, + "cell_type": "code", + "source": [ + " my_seq = tf.keras.Sequential([tf.keras.layers.Conv2D(1, (1, 1)),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(2, 1, \n", + " padding='same'),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(3, (1, 1)),\n", + " tf.keras.layers.BatchNormalization()])\n", + "my_seq(tf.zeros([1, 2, 3, 3]))\n" + ], + "execution_count": 8, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 8 + } + ] + }, + { + "metadata": { + "id": "MZrns6p22GEQ", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Exercise!\n", + "\n", + "Make a simple convolutional neural network model, useful for things such as MNIST which don't need too many parameters. A sequence of two or three convolutions with small output channels (say, 32 and 64) plus one or two fully connected layers is probably enough.\n", + "\n", + "The input shape should be [batch_size, 28, 28, 1]." + ] + }, + { + "metadata": { + "id": "8CAUa3KNN916", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "outputId": "97c0ff3c-c962-4c13-eee8-406101465761" + }, + "cell_type": "code", + "source": [ + "# TODO: Implement a convolutional model as described above, and assign it to\n", + "# model.\n", + "model = tf.keras.Sequential([\n", + " \n", + "])" + ], + "execution_count": 9, + "outputs": [] + }, + { + "metadata": { + "id": "vLDDduR32E82", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "09bb1d43-b4c6-44b5-916e-0d2903d10cf4" + }, + "cell_type": "code", + "source": [ + "#@title Click to see the answer\n", + "\n", + "max_pool = tf.keras.layers.MaxPooling2D(\n", + " (2, 2), (2, 2), padding='same')\n", + " # The model consists of a sequential chain of layers, so tf.keras.Sequential\n", + " # (a subclass of tf.keras.Model) makes for a compact description.\n", + "model = tf.keras.Sequential(\n", + " [\n", + " tf.keras.layers.Conv2D(\n", + " 32,\n", + " 5,\n", + " padding='same',\n", + " activation=tf.nn.relu),\n", + " max_pool,\n", + " tf.keras.layers.Conv2D(\n", + " 64,\n", + " 5,\n", + " padding='same',\n", + " activation=tf.nn.relu),\n", + " max_pool,\n", + " tf.keras.layers.Flatten(),\n", + " tf.keras.layers.Dense(1024, activation=tf.nn.relu),\n", + " tf.keras.layers.Dropout(0.4),\n", + " tf.keras.layers.Dense(10)\n", + " ])\n", + "\n", + "model(tf.zeros([1, 28, 28, 1]))" + ], + "execution_count": 10, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 10 + } + ] + }, + { + "metadata": { + "id": "H_CKVBroik4M", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Stop here for now" + ] + }, + { + "metadata": { + "id": "_yRwuE6MMmzC", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Training\n", + "\n", + "When eager execution is enabled, you can write Pythonic training loops. Simply\n", + "\n", + "1. load your data into a `tf.data.Dataset`, which lets you construct functional pipelines for processing, shuffling, and batching your data,\n", + "2. iterate over the dataset using a Python `for` loop, and\n", + "3. perform an optimization step in the body of your `for` loop.\n", + "\n", + "This workflow is exemplified in the following exercise." + ] + }, + { + "metadata": { + "id": "gj0-EkTc_Xt1", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "\n", + "\n", + "## Exercise!\n", + "\n", + "In this exercise, you'll train the convolutional model you implemented for the previous exericse on the MNIST dataset. " + ] + }, + { + "metadata": { + "id": "WOGm9HHn_byR", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "outputId": "bbccc7ad-33cd-446e-bcda-f358c7547e1b" + }, + "cell_type": "code", + "source": [ + "#@title Utilities for downloading MNIST data (double-click to show code)\n", + "import gzip\n", + "import os\n", + "import tempfile\n", + "from six.moves import urllib\n", + "import shutil\n", + "\n", + "import numpy as np\n", + "\n", + "def read32(bytestream):\n", + " \"\"\"Read 4 bytes from bytestream as an unsigned 32-bit integer.\"\"\"\n", + " dt = np.dtype(np.uint32).newbyteorder('>')\n", + " return np.frombuffer(bytestream.read(4), dtype=dt)[0]\n", + "\n", + "\n", + "def check_image_file_header(filename):\n", + " \"\"\"Validate that filename corresponds to images for the MNIST dataset.\"\"\"\n", + " with tf.gfile.Open(filename, 'rb') as f:\n", + " magic = read32(f)\n", + " read32(f) # num_images, unused\n", + " rows = read32(f)\n", + " cols = read32(f)\n", + " if magic != 2051:\n", + " raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,\n", + " f.name))\n", + " if rows != 28 or cols != 28:\n", + " raise ValueError(\n", + " 'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %\n", + " (f.name, rows, cols))\n", + "\n", + "\n", + "def check_labels_file_header(filename):\n", + " \"\"\"Validate that filename corresponds to labels for the MNIST dataset.\"\"\"\n", + " with tf.gfile.Open(filename, 'rb') as f:\n", + " magic = read32(f)\n", + " read32(f) # num_items, unused\n", + " if magic != 2049:\n", + " raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,\n", + " f.name))\n", + " \n", + "def download(directory, filename):\n", + " \"\"\"Download (and unzip) a file from the MNIST dataset if not already done.\"\"\"\n", + " filepath = os.path.join(directory, filename)\n", + " if tf.gfile.Exists(filepath):\n", + " return filepath\n", + " if not tf.gfile.Exists(directory):\n", + " tf.gfile.MakeDirs(directory)\n", + " # CVDF mirror of http://yann.lecun.com/exdb/mnist/\n", + " url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'\n", + " _, zipped_filepath = tempfile.mkstemp(suffix='.gz')\n", + " print('Downloading %s to %s' % (url, zipped_filepath))\n", + " urllib.request.urlretrieve(url, zipped_filepath)\n", + " with gzip.open(zipped_filepath, 'rb') as f_in, \\\n", + " tf.gfile.Open(filepath, 'wb') as f_out:\n", + " shutil.copyfileobj(f_in, f_out)\n", + " os.remove(zipped_filepath)\n", + " return filepath\n", + "\n", + "\n", + "def dataset(directory, images_file, labels_file):\n", + " \"\"\"Download and parse MNIST dataset.\"\"\"\n", + "\n", + " images_file = download(directory, images_file)\n", + " labels_file = download(directory, labels_file)\n", + "\n", + " check_image_file_header(images_file)\n", + " check_labels_file_header(labels_file)\n", + "\n", + " def decode_image(image):\n", + " # Normalize from [0, 255] to [0.0, 1.0]\n", + " image = tf.decode_raw(image, tf.uint8)\n", + " image = tf.cast(image, tf.float32)\n", + " image = tf.reshape(image, [28, 28, 1])\n", + " return image / 255.0\n", + "\n", + " def decode_label(label):\n", + " label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8]\n", + " label = tf.reshape(label, []) # label is a scalar\n", + " return tf.to_int32(label)\n", + "\n", + " images = tf.data.FixedLengthRecordDataset(\n", + " images_file, 28 * 28, header_bytes=16).map(decode_image)\n", + " labels = tf.data.FixedLengthRecordDataset(\n", + " labels_file, 1, header_bytes=8).map(decode_label)\n", + " return tf.data.Dataset.zip((images, labels))\n", + "\n", + "\n", + "def get_training_data(directory):\n", + " \"\"\"tf.data.Dataset object for MNIST training data.\"\"\"\n", + " return dataset(directory, 'train-images-idx3-ubyte',\n", + " 'train-labels-idx1-ubyte').take(1024)\n", + "\n", + "def get_test_data(directory):\n", + " \"\"\"tf.data.Dataset object for MNIST test data.\"\"\"\n", + " return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')" + ], + "execution_count": 11, + "outputs": [] + }, + { + "metadata": { + "id": "4ejmJ2dv_f0R", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 85 + }, + "outputId": "274c0381-e505-4e69-f910-3def6f8572a7" + }, + "cell_type": "code", + "source": [ + "# Don't forget to run the cell above!\n", + "training_data = get_training_data(\"/tmp/mnist/train\")\n", + "test_data = get_test_data(\"/tmp/mnist/test\")" + ], + "execution_count": 12, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/tmp4ull1xwa.gz\n", + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/tmp1eikhj1v.gz\n", + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/tmpcp8xah9c.gz\n", + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/tmpqww_1e74.gz\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "TANpFS6GKLMC", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Fill in the implementation of `train_one_epoch` below and run the cell to train your model. " + ] + }, + { + "metadata": { + "id": "btKL0Ss9_rmC", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 102 + }, + "outputId": "56858516-86fc-424a-f00d-6f088f98bf9b" + }, + "cell_type": "code", + "source": [ + "EPOCHS = 5\n", + "optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.5)\n", + "\n", + "def loss_fn(logits, labels):\n", + " return tf.reduce_mean(\n", + " tf.nn.sparse_softmax_cross_entropy_with_logits(\n", + " logits=tf.squeeze(logits), labels=labels))\n", + "\n", + "def train_one_epoch(model, training_data, optimizer):\n", + " # TODO: Implement an optimization step and return the average loss.\n", + " #\n", + " # Hint: Use `tf.GradientTape` to compute the gradient of the loss, and use\n", + " # `optimizer.apply_gradients` to update the model's variables, which are\n", + " # accessible as `model.variables`\n", + " average_loss = tfe.metrics.Mean('loss')\n", + " for images, labels in training_data.shuffle(buffer_size=10000).batch(64):\n", + " pass\n", + " return average_loss.result()\n", + "\n", + "for epoch in range(EPOCHS):\n", + " loss = train_one_epoch(model, training_data, optimizer)\n", + " print(\"Average loss after epoch %d: %.4f\" % (epoch, loss))" + ], + "execution_count": 14, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Average loss after epoch 0: 2.2847\n", + "Average loss after epoch 1: 2.2305\n", + "Average loss after epoch 2: 2.1334\n", + "Average loss after epoch 3: 1.9115\n", + "Average loss after epoch 4: 1.4285\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "yAOFupJN_htg", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 102 + }, + "outputId": "67e711e4-76c9-4e3f-bb49-a14955dba03a" + }, + "cell_type": "code", + "source": [ + "#@title Double-click to see a solution.\n", + "EPOCHS = 5\n", + "optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.5)\n", + "\n", + "def _loss_fn(logits, labels):\n", + " return tf.reduce_mean(\n", + " tf.nn.sparse_softmax_cross_entropy_with_logits(\n", + " logits=tf.squeeze(logits), labels=labels))\n", + "\n", + "def _train_one_epoch(model, training_data):\n", + " average_loss = tfe.metrics.Mean(\"loss\")\n", + " for images, labels in training_data.shuffle(buffer_size=10000).batch(64):\n", + " with tf.GradientTape() as tape:\n", + " logits = model(images, training=True)\n", + " loss = _loss_fn(logits, labels)\n", + " average_loss(loss)\n", + " gradients = tape.gradient(loss, model.variables)\n", + " optimizer.apply_gradients(zip(gradients, model.variables))\n", + " return average_loss.result()\n", + " \n", + "for epoch in range(EPOCHS):\n", + " loss = _train_one_epoch(model, training_data)\n", + " print(\"Average loss after epoch %d: %.4f\" % (epoch, loss))" + ], + "execution_count": 15, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Average loss after epoch 0: 1.0563\n", + "Average loss after epoch 1: 0.8013\n", + "Average loss after epoch 2: 0.6306\n", + "Average loss after epoch 3: 0.5543\n", + "Average loss after epoch 4: 0.5037\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "uDy1DrYA_2Jz", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Run the below cell to qualitatively evaluate your model. Note how eager execution interoperates seamlessly with `matplotlib`." + ] + }, + { + "metadata": { + "id": "vR7rMtpu_3nB", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1752 + }, + "outputId": "b212aefa-f4b3-425c-f34d-2491429fa521" + }, + "cell_type": "code", + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "sampled_data = test_data.batch(1).shuffle(buffer_size=10000).take(5)\n", + "for image, label in sampled_data:\n", + " plt.figure()\n", + " plt.imshow(tf.reshape(image, (28, 28)))\n", + " plt.show()\n", + " logits = model(image, training=False)\n", + " prediction = tf.argmax(logits, axis=1, output_type=tf.int64)\n", + " print(\"Prediction: %d\" % prediction)" + ], + "execution_count": 16, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEwpJREFUeJzt3X1Ilff/x/HXmScxV2GZOmLVohXK\nKmLQjbUsy+pbI7rbaEm1IFhRSU1aE+kO3LqxCGrBMlsNkq0zZIM2Cu1mUTg1itXQbVnBQqKZNtcN\n2d3J3x9ffpLrNN/ndM65jn6fj7/m5cfrvI9XPHedc7zOcTU3NzcLAPCvXnJ6AABoD4glABgQSwAw\nIJYAYEAsAcCAWAKAAbEEAANiCQAG7kB/cOPGjbpw4YJcLpdyc3M1ZMiQYM4FABEloFieOXNGV69e\nlcfj0ZUrV5SbmyuPxxPs2QAgYgT0MLy8vFwZGRmSpP79++vWrVu6e/duUAcDgEgSUCwbGhrUvXv3\nlq979Oih+vr6oA0FAJEmKC/w8F4cADq6gGKZmJiohoaGlq9v3LihhISEoA0FAJEmoFiOHj1aJSUl\nkqTq6molJiaqS5cuQR0MACJJQK+Gv/nmm3rjjTf03nvvyeVyaf369cGeCwAiios3/wWAtnEFDwAY\nEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkA\nBsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgC\ngAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMHA7\nPQAQiAcPHpjX3rlzx+f2nj17qqGhodW2kydPmvb566+/mm//xx9/NK+13r4kjRgx4pltFRUVGjly\nZKttP/30k3mfL73E+dPz8JsBAIOAziwrKyu1YsUKDRgwQJI0cOBArV27NqiDAUAkCfhh+PDhw7Vz\n585gzgIAEYuH4QBgEHAsL1++rCVLlmju3LkqKysL5kwAEHFczc3Nzf7+UF1dnc6dO6cpU6aotrZW\nCxYsUGlpqaKjo0MxIwA4LqDnLJOSkjR16lRJUp8+fdSzZ0/V1dWpd+/eQR0OeB7+dIg/HQq3gH4z\nhw4d0hdffCFJqq+v182bN5WUlBTUwQAgkgR0Zjl+/HitWrVKx48f16NHj7RhwwYeggPo0AKKZZcu\nXbR79+5gzwIAESugF3gAf1RVVZnXfvfdd6Z1hw8fNu/zzJkzPrd7vV5FRUWZ99Me+LpPDx8+NP98\nR/t9BBPP5gKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAAM+3RGtPO/qV5fL\n1ep7BQUF5n1mZWWZ1z558sS8NhRcLpdpnT9vZebPJYT9+vUzry0pKfG5/Y8//mj1NW+7Fhz8FgHA\ngFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgCt40MrBgwd9bp87d26r7y1btsy8z1de\necW89q233jKte//99837/Dfff/99q68TExNNP/fqq6+ab8Of+x8MvXv3Duvt/a/gzBIADIglABgQ\nSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABi4mp/3CVXoMB49emRe+/rrr/vcfvXqVfXt\n27fl68zMTPM+P/74Y/PauLg481ognDizBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGx\nBAADYgkABny6YztVX19vXjthwgTz2oEDB5q+l5eXZ96n223/Z/b48WPTuuvXr5v3efz4cZ/bFy5c\nqC+//NK8n0CNHTvWvLZfv34hnAQvwnRmWVNTo4yMDBUVFUn67z/U+fPnKzMzUytWrNDDhw9DOiQA\nOK3NWN67d095eXlKTU1t2bZz505lZmbqq6++Ut++fVVcXBzSIQHAaW3GMjo6WoWFha0+fL6ysrLl\noV16errKy8tDNyEARIA2n0xyu93PPOfU1NSk6OhoSVJ8fLxfz58BQHv0wi/w8HaYzkhISDCv/eWX\nX4Jym0ePHg3Kfv6N9cWg3r17m/e5cOHCgL4HPC2gWMbGxur+/fuKiYlRXV1dq4foCI9QvRqelJTk\nc/vRo0c1ceLElq+PHDli3ievhvNqeEcQ0N9Zjho1SiUlJZKk0tJSjRkzJqhDAUCkafN/+VVVVdqy\nZYuuXbsmt9utkpISbdu2TTk5OfJ4POrVq5dmzJgRjlkBwDFtxnLQoEE6cODAM9v3798fkoEAIBLx\ngWXt1A8//GBeO3v2bPPa572Ik5aWplOnTrV8ff78efM+J02aZF5rnfX333837/N5vF6voqKiAvrZ\nd99917x20KBB5rWrVq0yr42JiTGvxYvj2nAAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAY\nEEsAMCCWAGDA5Y7tlD+X23377bcvfHv/vDTQn7cS8+ft1NLS0kzr/Ln/o0aN8rk9OTn5mcsmO3Xq\nZNrn7du3zbc/YsQI89q9e/ea1y5YsMC8Fi+OM0sAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyI\nJQAYEEsAMCCWAGDQ5kfhIjItXrzYvHb06NHmtRcvXnzu9z744IOW//bnUruhQ4ea11ovN3S7g/NP\nNzk5OaCfe/qTLtvi9XrNa/351E4udwwvziwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg\nQCwBwIAreNqpjIyMkKz9N59//nlQ9tMRPHjwwOkREGacWQKAAbEEAANiCQAGxBIADIglABgQSwAw\nIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADEyxrKmpUUZGhoqKiiRJOTk5mjZtmubPn6/58+fr5MmT\noZwRABzX5rsO3bt3T3l5eUpNTW21PTs7W+np6SEbDAAiSZtnltHR0SosLFRiYmI45gGAiNTmmaXb\n7Zbb/eyyoqIi7d+/X/Hx8Vq7dq169OgRkgGBSDRx4kTzWq/XG8JJEC4Bvfnv9OnTFRcXp5SUFO3Z\ns0e7du3SunXrgj0bELGOHj1qXvuf//zHvHb27Nnmtd988415LV5cQK+Gp6amKiUlRZI0fvx41dTU\nBHUoAIg0AcUyKytLtbW1kqTKykoNGDAgqEMBQKRp82F4VVWVtmzZomvXrsntdqukpETz5s3TypUr\n1blzZ8XGxmrTpk3hmBUAHNNmLAcNGqQDBw48s33y5MkhGQgAIhGf7ggEgAsx/vdwuSMAGBBLADAg\nlgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADDgckcgAKdPnw7JfqdNmxaS/eLFcWYJAAbE\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAZcwQM85dSpU6Z1P//8s3mfL7/8snntuHHj\nzGsRXpxZAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAAy53RIf3999/+9we\nFxf3zPcyMjJM+/R6vebbP3jwoHlt7969zWsRXpxZAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg\nQCwBwIBYAoABsQQAAy53DIMnT56Y1+bm5prWbdiwwbzPmJgY89r24u7du+a1b7/9ts/tZWVlz3zP\nehnjO++8Y7792bNnm9cicplimZ+fr3Pnzunx48davHixBg8erNWrV8vr9SohIUFbt25VdHR0qGcF\nAMe0GcuKigpdunRJHo9HjY2NmjlzplJTU5WZmakpU6Zo+/btKi4uVmZmZjjmBQBHtPmc5bBhw7Rj\nxw5JUrdu3dTU1KTKykpNmDBBkpSenq7y8vLQTgkADmszllFRUYqNjZUkFRcXKy0tTU1NTS0Pu+Pj\n41VfXx/aKQHAYeYXeI4dO6bi4mLt27dPkyZNatne3NwcksE6kpdesv/RwebNm0M4ScfRpUsX89qy\nsrKAvgc8zRTL06dPa/fu3dq7d6+6du2q2NhY3b9/XzExMaqrq1NiYmKo52zXeDU8+Px5NXzy5Mk+\nt5eVlWn06NGttlVUVJj26c+r4V9//bV5rT//Y0V4tXlk7ty5o/z8fBUUFCguLk6SNGrUKJWUlEiS\nSktLNWbMmNBOCQAOa/PM8vDhw2psbNTKlStbtm3evFlr1qyRx+NRr169NGPGjJAOCQBOazOWc+bM\n0Zw5c57Zvn///pAMBACRyNXMKzQh58+HW1n/uP/TTz817zM7Ozvotx8qv/32m2nd0qVLzfs8deqU\nz+1er1dRUVHm/TyturravDY5OTmg20Bk4dlkADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUA\nGBBLADAglgBgwOWOYeDP5Y4JCQmmdbdu3TLvc+LEiea148aN87k9Jycn4PfavH//vnntJ598Ylrn\nzz/bbt26+dze2Nio7t27t9p28eJF0z6tx0mSXC6XeS0iF2eWAGBALAHAgFgCgAGxBAADYgkABsQS\nAAyIJQAYEEsAMCCWAGBALAHAgMsdI0xxcbFp3bJly8z7bGhoCHScFi/ySYj++Oflh88zefJk8z4/\n+ugjn9uHDh2q8+fPP7MN8IUzSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAw4Aqe\ndqqmpsa8Njs727z2yJEjPre/yBU8q1evNq8dPHiwaV1mZmZAswCB4swSAAyIJQAYEEsAMCCWAGBA\nLAHAgFgCgAGxBAADYgkABsQSAAyIJQAYcLkjABi4LYvy8/N17tw5PX78WIsXL9aJEydUXV2tuLg4\nSdKiRYs0bty4UM4JAI5qM5YVFRW6dOmSPB6PGhsbNXPmTI0cOVLZ2dlKT08Px4wA4Lg2Yzls2DAN\nGTJEktStWzc1NTXJ6/WGfDAAiCR+PWfp8Xh09uxZRUVFqb6+Xo8ePVJ8fLzWrl2rHj16hHJOAHCU\nOZbHjh1TQUGB9u3bp6qqKsXFxSklJUV79uzRn3/+qXXr1oV6VgBwjOlPh06fPq3du3ersLBQXbt2\nVWpqqlJSUiRJ48eP9+uNaAGgPWozlnfu3FF+fr4KCgpaXv3OyspSbW2tJKmyslIDBgwI7ZQA4LA2\nX+A5fPiwGhsbtXLlypZts2bN0sqVK9W5c2fFxsZq06ZNIR0SAJzGH6UDgAGXOwKAAbEEAANiCQAG\nxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKA\nAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABi4nbjRjRs3\n6sKFC3K5XMrNzdWQIUOcGCOoKisrtWLFCg0YMECSNHDgQK1du9bhqQJXU1OjpUuXauHChZo3b56u\nX7+u1atXy+v1KiEhQVu3blV0dLTTY/rln/cpJydH1dXViouLkyQtWrRI48aNc3ZIP+Xn5+vcuXN6\n/PixFi9erMGDB7f74yQ9e79OnDjh+LEKeyzPnDmjq1evyuPx6MqVK8rNzZXH4wn3GCExfPhw7dy5\n0+kxXti9e/eUl5en1NTUlm07d+5UZmampkyZou3bt6u4uFiZmZkOTukfX/dJkrKzs5Wenu7QVC+m\noqJCly5dksfjUWNjo2bOnKnU1NR2fZwk3/dr5MiRjh+rsD8MLy8vV0ZGhiSpf//+unXrlu7evRvu\nMfAvoqOjVVhYqMTExJZtlZWVmjBhgiQpPT1d5eXlTo0XEF/3qb0bNmyYduzYIUnq1q2bmpqa2v1x\nknzfL6/X6/BUDsSyoaFB3bt3b/m6R48eqq+vD/cYIXH58mUtWbJEc+fOVVlZmdPjBMztdismJqbV\ntqamppaHc/Hx8e3umPm6T5JUVFSkBQsW6MMPP9Rff/3lwGSBi4qKUmxsrCSpuLhYaWlp7f44Sb7v\nV1RUlOPHypHnLJ/W3Nzs9AhB8dprr2n58uWaMmWKamtrtWDBApWWlrbL54va0lGO2fTp0xUXF6eU\nlBTt2bNHu3bt0rp165wey2/Hjh1TcXGx9u3bp0mTJrVsb+/H6en7VVVV5fixCvuZZWJiohoaGlq+\nvnHjhhISEsI9RtAlJSVp6tSpcrlc6tOnj3r27Km6ujqnxwqa2NhY3b9/X5JUV1fXIR7OpqamKiUl\nRZI0fvx41dTUODyR/06fPq3du3ersLBQXbt27TDH6Z/3KxKOVdhjOXr0aJWUlEiSqqurlZiYqC5d\nuoR7jKA7dOiQvvjiC0lSfX29bt68qaSkJIenCp5Ro0a1HLfS0lKNGTPG4YleXFZWlmprayX99znZ\n//9Lhvbizp07ys/PV0FBQcurxB3hOPm6X5FwrFzNDpyrb9u2TWfPnpXL5dL69euVnJwc7hGC7u7d\nu1q1apVu376tR48eafny5Ro7dqzTYwWkqqpKW7Zs0bVr1+R2u5WUlKRt27YpJydHDx48UK9evbRp\n0yZ16tTJ6VHNfN2nefPmac+ePercubNiY2O1adMmxcfHOz2qmcfj0WeffaZ+/fq1bNu8ebPWrFnT\nbo+T5Pt+zZo1S0VFRY4eK0diCQDtDVfwAIABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwOD/\nAKCzFeFbFn4BAAAAAElFTkSuQmCC\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Prediction: 5\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEQ1JREFUeJzt3W9Ilff/x/HXSSd2VmKaRwiqjTBy\nq9gfap2iliaFQfRvsCXW1rpRRJGTCJG0MSHLIpbF8M9qN3L7cjZvNQiOVAQt7LQcBLqB1Y0QaXYs\naUa2mZ3fjS9ff7Vcvj2ec65jez7ueZ1P57wPlzy7Li8vjysUCoUEAHihcU4PAABjAbEEAANiCQAG\nxBIADIglABgQSwAwIJYAYEAsAcAgMdx/uH//fl27dk0ul0ulpaWaO3duJOcCgLgSViyvXLmiW7du\nyefz6ebNmyotLZXP54v0bAAQN8I6DW9ublZeXp4kacaMGbp//74ePHgQ0cEAIJ6EFcvu7m5NmjRp\n8Ou0tDQFg8GIDQUA8SYiF3j4WxwAXnZhxdLj8ai7u3vw6zt37igjIyNiQwFAvAkrlosWLZLf75ck\ntbW1yePxaMKECREdDADiSVhXw9955x29+eab+uijj+RyubRv375IzwUAccXFH/8FgOFxBw8AGBBL\nADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbE\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoAB\nsQQAA2IJAAbEEgAMEp0eAIgnP/30k2nd+vXrzc+Zl5dnXvvtt9+a1yK2OLIEAANiCQAGxBIADIgl\nABgQSwAwIJYAYEAsAcCAWAKAAbEEAAPu4AGecuzYMdO6YDBofk6XyxXuOIgjHFkCgEFYR5aBQEC7\ndu1SVlaWJGnmzJkqKyuL6GAAEE/CPg2fP3++qqurIzkLAMQtTsMBwCDsWN64cUPbtm3Thg0bdOnS\npUjOBABxxxUKhUIj/UddXV1qaWlRfn6+Ojo6tGnTJjU1NSkpKSkaMwKA48L6mWVmZqZWrlwpSZo2\nbZomT56srq4uTZ06NaLDAbH24Ycfmtb98MMP5ucsKCgwr21oaDCvRWyFdRp++vRpnThxQtJ/f9/s\n7t27yszMjOhgABBPwjqyzM3N1e7du3Xu3Dn19/fr888/5xQcwEstrFhOmDBBNTU1kZ4FAOIWtzsC\nT7lw4ULEn3PVqlURf07EHr9nCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIA\nDLjdES89v98/5PYVK1Y899hIPrXRqre3N+LPidjjyBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEE\nAANiCQAGxBIADLiDB2NSKBQyr21oaBhy+4oVK/7xsUh6++23o/4aiD6OLAHAgFgCgAGxBAADYgkA\nBsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgIErNJL7xoA40dnZaV47derUIbc/efJE48aFd7zw7rvv\nmtf+/PPPYb0G4gtHlgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBPd8SY\nVFlZ6ejrb9682dHXR+yZjizb29uVl5c3+LGht2/f1saNG1VQUKBdu3bpr7/+iuqQAOC0YWP58OFD\nVVRUyOv1Dm6rrq5WQUGBvvvuO02fPl2NjY1RHRIAnDZsLJOSklRfXy+PxzO4LRAIaNmyZZKknJwc\nNTc3R29CAIgDw/7MMjExUYmJzy7r6+tTUlKSJCk9PV3BYDA60wFAnBj1BR7+HCaccPz48YisffLk\nSSTGwb9AWLF0u9169OiRkpOT1dXV9cwpOhALO3bsMK/96quvhtw+mj/+O5JYb9++PazXQHwJ6ztl\n4cKF8vv9kqSmpiYtXrw4okMBQLwZ9siytbVVBw8eVGdnpxITE+X3+3X48GGVlJTI5/NpypQpWrNm\nTSxmBQDHDBvL2bNn69SpU89t/+abb6IyEADEI+7gQVyxXnCJ1oeAWX/+XlhYGJXXR/zi3nAAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGDA7Y6IKxUVFaZ10brd8dVXXzWt6+3t\nNT9nSkpKuOMgjnBkCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADLjdEXHl\nyy+/dPT1BwYGTOv8fr/5OT/99NNwx0Ec4cgSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkA\nBsQSAAy4gwdR99tvv5nXjuSDwKzcbrf5sV9++cX0nGlpaaOaCWMPR5YAYEAsAcCAWAKAAbEEAANi\nCQAGxBIADIglABgQSwAwIJYAYEAsAcCA2x0RFusHe0kj+xCyJ0+ehDPOC507d878GLcx4p9wZAkA\nBqZYtre3Ky8vTw0NDZKkkpISrVq1Shs3btTGjRt14cKFaM4IAI4b9jT84cOHqqiokNfrfWZ7cXGx\ncnJyojYYAMSTYY8sk5KSVF9fL4/HE4t5ACAuuUKhUMiy8NixY5o0aZIKCwtVUlKiYDCo/v5+paen\nq6ysjB+MA3iphXU1fPXq1UpNTVV2drbq6up0/PhxlZeXR3o2xLGRXA3fvn27eW19fX0447xQc3Pz\nkNvfe+89BQKB57YBQwnrarjX61V2drYkKTc3V+3t7REdCgDiTVix3Llzpzo6OiRJgUBAWVlZER0K\nAOLNsKfhra2tOnjwoDo7O5WYmCi/36/CwkIVFRVp/PjxcrvdqqysjMWsAOCYYWM5e/ZsnTp16rnt\nK1asiMpAABCPzFfDgafdu3fPvHby5MkRf/0PPvjAvPY///nPkNsTEhKeu1CVkJAwqrnw8uJ2RwAw\nIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYMCnO+IZ//TpiuPGjXvmsc2bN0fl\n9V0ul2ndF198YX7OF93CyO2NsOLIEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAM\nuIMHz/jf58H/3fTp05957Mcff4zK6xcWFprWzZo1KyqvD/wTjiwBwIBYAoABsQQAA2IJAAbEEgAM\niCUAGBBLADAglgBgQCwBwIBYAoABtzviGRcuXBhy+8cff/zMY6FQKCqvX15eHpXnBUaLI0sAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGDgCkXrvjXEjV9//dW8ds6cOUNuHxgY\nUEJCwuDXI/m2Wb9+vXmtz+czrRs3jv/nEVume8OrqqrU0tKix48fa+vWrZozZ4727NmjgYEBZWRk\n6NChQ0pKSor2rADgmGFjefnyZV2/fl0+n089PT1au3atvF6vCgoKlJ+fryNHjqixsVEFBQWxmBcA\nHDHsucy8efN09OhRSVJKSor6+voUCAS0bNkySVJOTo6am5ujOyUAOGzYWCYkJMjtdkuSGhsbtWTJ\nEvX19Q2edqenpysYDEZ3SgBwmPnvWZ49e1aNjY06efKkli9fPrid60Px74033jCvHRgYCOsx4GVn\niuXFixdVU1Ojr7/+WhMnTpTb7dajR4+UnJysrq4ueTyeaM+JUeBqODB6w37H9fb2qqqqSrW1tUpN\nTZUkLVy4UH6/X5LU1NSkxYsXR3dKAHDYsEeWZ86cUU9Pj4qKiga3HThwQHv37pXP59OUKVO0Zs2a\nqA4JAE7jl9L/BTgNB0aPDyz7F7AGSHpxBJ9+LCUlxfycJ06cMK8lgohXfGcCgAGxBAADYgkABsQS\nAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADbnf8F7hx44Z5rfV2x+TkZPNzjuTWSCBecWQJAAbE\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMuN3xX6C4uNi89vvvv//HxxIT///b\n5a233hrVTMBYw5ElABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABi4Qi/6hCoAgCSO\nLAHAhFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAAD06c7VlVVqaWlRY8f\nP9bWrVt1/vx5tbW1KTU1VZK0ZcsWLV26NJpzAoCjho3l5cuXdf36dfl8PvX09Gjt2rVasGCBiouL\nlZOTE4sZAcBxw8Zy3rx5mjt3riQpJSVFfX19GhgYiPpgABBPRvQn2nw+n65evaqEhAQFg0H19/cr\nPT1dZWVlSktLi+acAOAocyzPnj2r2tpanTx5Uq2trUpNTVV2drbq6ur0+++/q7y8PNqzAoBjTFfD\nL168qJqaGtXX12vixInyer3Kzs6WJOXm5qq9vT2qQwKA04aNZW9vr6qqqlRbWzt49Xvnzp3q6OiQ\nJAUCAWVlZUV3SgBw2LAXeM6cOaOenh4VFRUNblu3bp2Kioo0fvx4ud1uVVZWRnVIAHAan8EDAAbc\nwQMABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHA\ngFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsA\nMCCWAGCQ6MSL7t+/X9euXZPL5VJpaanmzp3rxBgRFQgEtGvXLmVlZUmSZs6cqbKyMoenCl97e7u2\nb9+uTz75RIWFhbp9+7b27NmjgYEBZWRk6NChQ0pKSnJ6zBH5+3sqKSlRW1ubUlNTJUlbtmzR0qVL\nnR1yhKqqqtTS0qLHjx9r69atmjNnzpjfT9Lz7+v8+fOO76uYx/LKlSu6deuWfD6fbt68qdLSUvl8\nvliPERXz589XdXW102OM2sOHD1VRUSGv1zu4rbq6WgUFBcrPz9eRI0fU2NiogoICB6ccmaHekyQV\nFxcrJyfHoalG5/Lly7p+/bp8Pp96enq0du1aeb3eMb2fpKHf14IFCxzfVzE/DW9ublZeXp4kacaM\nGbp//74ePHgQ6zHwAklJSaqvr5fH4xncFggEtGzZMklSTk6OmpubnRovLEO9p7Fu3rx5Onr0qCQp\nJSVFfX19Y34/SUO/r4GBAYenciCW3d3dmjRp0uDXaWlpCgaDsR4jKm7cuKFt27Zpw4YNunTpktPj\nhC0xMVHJycnPbOvr6xs8nUtPTx9z+2yo9yRJDQ0N2rRpkz777DPdu3fPgcnCl5CQILfbLUlqbGzU\nkiVLxvx+koZ+XwkJCY7vK0d+Zvm0UCjk9AgR8dprr2nHjh3Kz89XR0eHNm3apKampjH586LhvCz7\nbPXq1UpNTVV2drbq6up0/PhxlZeXOz3WiJ09e1aNjY06efKkli9fPrh9rO+np99Xa2ur4/sq5keW\nHo9H3d3dg1/fuXNHGRkZsR4j4jIzM7Vy5Uq5XC5NmzZNkydPVldXl9NjRYzb7dajR48kSV1dXS/F\n6azX61V2drYkKTc3V+3t7Q5PNHIXL15UTU2N6uvrNXHixJdmP/39fcXDvop5LBctWiS/3y9Jamtr\nk8fj0YQJE2I9RsSdPn1aJ06ckCQFg0HdvXtXmZmZDk8VOQsXLhzcb01NTVq8eLHDE43ezp071dHR\nIem/P5P9328yjBW9vb2qqqpSbW3t4FXil2E/DfW+4mFfuUIOHKsfPnxYV69elcvl0r59+zRr1qxY\njxBxDx480O7du/XHH3+ov79fO3bs0Pvvv+/0WGFpbW3VwYMH1dnZqcTERGVmZurw4cMqKSnRn3/+\nqSlTpqiyslKvvPKK06OaDfWeCgsLVVdXp/Hjx8vtdquyslLp6elOj2rm8/l07Ngxvf7664PbDhw4\noL17947Z/SQN/b7WrVunhoYGR/eVI7EEgLGGO3gAwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAg\nlgBg8H/nb4OLnfGqVAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Prediction: 1\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAE1ZJREFUeJzt3X1olfX/x/HXccc1DyrLuY1GaRGL\nRqZSaE7zZmqKgnhDsVwqkYGRE29QW8tp4M102solNJ03fzSqgyPoBmFDIlg1Jw0xNsrZDbKGranD\nG5x3x33/+NF+rp153js751znrOfjv13n43Xex4NPrrPL61yujo6ODgEA7muA0wMAQCwglgBgQCwB\nwIBYAoABsQQAA2IJAAbEEgAMiCUAGLiD/YM7duzQ6dOn5XK5lJ+fr9GjR4dyLgCIKkHF8uTJkzp3\n7py8Xq9+++035efny+v1hno2AIgaQX0Mr6mp0cyZMyVJjz/+uC5fvqxr166FdDAAiCZBxfLChQt6\n8MEHO38eNmyYWltbQzYUAESbkJzg4bs4APR3QcUyJSVFFy5c6Pz577//VnJycsiGAoBoE1QsJ02a\npMrKSklSQ0ODUlJSNHjw4JAOBgDRJKiz4c8884yeeuopvfzyy3K5XNqyZUuo5wKAqOLiy38BIDCu\n4AEAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg\nQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUA\nGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJ\nAAbEEgAMiCUAGLiD+UO1tbVavXq10tPTJUlPPPGECgoKQjoYAESToGIpSePHj1dJSUkoZwGAqMXH\ncAAwCDqWv/76q9544w0tXrxY33//fShnAoCo4+ro6Ojo7R9qaWlRXV2d5syZo6amJi1btkxVVVWK\nj48Px4wA4LigjixTU1M1d+5cuVwujRgxQsOHD1dLS0uoZwOAqBFULL/88ksdOnRIktTa2qqLFy8q\nNTU1pIMBQDQJ6mP4tWvXtH79el25ckW3b99Wbm6upk6dGo75ACAqBBVLAPivCfr/WQL90alTp0zr\nSktLzfssKysLdpz78nec09HRIZfL1WVbbm6ueZ+9+b/T/36e/o7/ZwkABsQSAAyIJQAYEEsAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAy4Nhz93tmzZ/1uT09P7/bY4sWLTfu0XhYZaT6fT3FxcUH/\n+Vu3bpnX9uV5YhFHlgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgwA3LEHa9uUjs\nzJkzpnXz588377Opqcnv9uvXr2vMmDFdtt28edO8Xyu32/7PrKCgwLw2Pj7e7/bCwsIuPz/77LPm\nfQ4YwPFTT/ibAQADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABtywDEG5ffu2\nee1bb71lXrt3795gxgmKv5t7PfTQQ6Y/u3r1avPzLF++3Lz2yJEj5rW5ubndtj3wwAPdLtl84IEH\nzPtEzziyBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABtzdEV3cvXvX7/YB\nAwZ0eSwvL8+8z0hewujPokWLzI999NFHpn16PB7z8y9evNi89uuvvzavbW5u7ratuLhYb7/9drdt\n6DvTkWVjY6Nmzpyp8vJySdL58+e1dOlS5eTkaPXq1bp161ZYhwQApwWM5fXr17V161ZlZmZ2bisp\nKVFOTo4++eQTjRw5UhUVFWEdEgCcFjCW8fHxKisrU0pKSue22tpazZgxQ5KUlZWlmpqa8E0IAFEg\n4O8s3W633O6uy9rb2xUfHy9JSkpKUmtra3imA4Ao0ecTPHwdZv8yYEDPHzbufey9994z77M3ayPt\n6NGjYX+OL774IuzPcS9O6IRHULH0eDy6ceOGEhIS1NLS0uUjOmKb9Wz4hg0bzPv84IMP+jxXX/R0\nNvzo0aN66aWXumyLpbPh/r6AuLi4WOvWreu2DX0X1P+znDhxoiorKyVJVVVVmjx5ckiHAoBoE/DI\nsr6+Xrt27VJzc7PcbrcqKyu1Z88e5eXlyev1Ki0tTQsWLIjErADgmICxHDVqlD7++ONu23tzrxAA\niHVcwfMf8Ndff5nXzpo1y+/2n376SWPHju38uaGhoc9z+TN06FDTutLSUvM+X3zxxR4f++yzz7r8\nfL8TXPf69NNPzc/fm99D9kZaWlqvtqNvuDYcAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkA\nBsQSAAyIJQAYuDr4QsqYdPXqVfPaUaNGmdf++eeffrf7fD7FxcWZ93Ovf75V3+LQoUOmdY888khQ\nswRivUVKdnZ2WJ7/ny/Vtjh16lS3bU8++aR++eWXbtvQdxxZAoABsQQAA2IJAAbEEgAMiCUAGBBL\nADAglgBgQCwBwIBYAoABsQQAA+7uGKPKy8vNa3u6hLEvlixZYl67Z88e89rk5GTTupaWFvM+X3/9\ndb/bv/rqK82bN6/LtsrKSvN+w6E3d43s6TJGLm8MD44sAcCAWAKAAbEEAANiCQAGxBIADIglABgQ\nSwAwIJYAYEAsAcCAG5ZFmbt375rWvfDCC+Z9fvvtt+a1Pd0wq729XYMGDer8ubGx0bzPtLQ089qf\nf/7ZtG7Dhg3mfVZVVfnd3pebsIXLjRs3zGsHDhwYxknwbxxZAoABsQQAA2IJAAbEEgAMiCUAGBBL\nADAglgBgQCwBwIBYAoABsQQAA25YFmWsV5/25hLG3vD5fKbHiouLzfv8448/zGu/+uor89pYsWDB\nAvPaaLv8Ev+PI0sAMDDFsrGxUTNnzuy8/WpeXp7mzZunpUuXaunSpWE7ygGAaBHwY/j169e1detW\nZWZmdtm+bt06ZWVlhW0wAIgmAY8s4+PjVVZWppSUlEjMAwBRKeCRpdvtltvdfVl5ebmOHDmipKQk\nFRQUaNiwYWEZ8L/G+gv++52ICZdbt25F/DnDzYm/R8SmoM6Gz58/X4mJicrIyNCBAwe0b98+bd68\nOdSz/SdZ//H29CW9fdVTrG/dutXlOVeuXGneZ7SeDY/Ul//25mz40aNHzWsHDOD8bCQF9bedmZmp\njIwMSdL06dN79a3ZABCLgorlqlWr1NTUJEmqra1Venp6SIcCgGgT8GN4fX29du3apebmZrndblVW\nVmrJkiVas2aNBg0aJI/Ho8LCwkjMCgCOCRjLUaNG6eOPP+62ffbs2WEZCACiEZc7ogvr5Y4lJSWR\nGKdf6M0JHk7aRC/eGQAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYMDljlHG\nernbsWPHzPvszeV24fiC3958MfT69etN6/Lz84MdJyS2bdtmXvvKK6+EcRJECkeWAGBALAHAgFgC\ngAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGDAFTxRxuVymdb15u6ap06dMq+9dOlSj49VV1eb93Ov\nsWPHmtfW1dUF9RyhMmbMGNO6lStXmvfJTcj6B95FADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAM\niCUAGBBLADAglgBg4Oro6Ohwegj0b21tbea1kyZNMq07c+ZMsON08vl8iouL67Lthx9+MP3Z5557\nrs/Pj9jCkSUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADDg7o4Iu5MnT5rX\nhuIyxn/Ly8szPzZ+/PiQPz/6B1Msi4qKVFdXpzt37mjFihV6+umntXHjRvl8PiUnJ2v37t2Kj48P\n96wA4JiAsTxx4oTOnj0rr9ertrY2LVy4UJmZmcrJydGcOXNUXFysiooK5eTkRGJeAHBEwN9Zjhs3\nTnv37pUkDR06VO3t7aqtrdWMGTMkSVlZWaqpqQnvlADgsICxjIuLk8fjkSRVVFRoypQpam9v7/zY\nnZSUpNbW1vBOCQAOM5/gOX78uCoqKnT48GHNmjWrcztfh4lAZs+ebV7r8/nCOEl327dvj+jzIXaZ\nYlldXa3S0lIdPHhQQ4YMkcfj0Y0bN5SQkKCWlhalpKSEe07EsMrKSvPauXPnhvz5ezobvn37dr3z\nzjtdtm3bts20T5fL1ee5EFsCfgy/evWqioqKtH//fiUmJkqSJk6c2PkPoKqqSpMnTw7vlADgsIBH\nlseOHVNbW5vWrFnTuW3nzp3atGmTvF6v0tLStGDBgrAOCQBOCxjL7OxsZWdnd9t+5MiRsAwEANGI\nG5YhKL25CVlGRoZ5bTj+Z8Xvv//ud/vIkSN17ty5btsAf7g2HAAMiCUAGBBLADAglgBgQCwBwIBY\nAoABsQQAA2IJAAbEEgAMiCUAGHDDMgSlrKzMvDYclzDm5uaa16alpQX1GHAvjiwBwIBYAoABsQQA\nA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABlzuiizt37vjd7na7uzz2+eefh+X5V61aZVr3\n/vvvm/fpcrl6fGzgwIHm/eC/jSNLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADBw\ndXR0dDg9BKLHd99953f7888/3+WxqVOnmvf58MMPm9eeOXPGtC4hIcG8TyAUOLIEAANiCQAGxBIA\nDIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAG3LAMXQwZMiSox+5ny5Yt5rVcxohoZYplUVGR\n6urqdOfOHa1YsULffPONGhoalJiYKElavny5pk2bFs45AcBRAWN54sQJnT17Vl6vV21tbVq4cKEm\nTJigdevWKSsrKxIzAoDjAsZy3LhxGj16tCRp6NCham9vl8/nC/tgABBNAp7giYuLk8fjkSRVVFRo\nypQpiouLU3l5uZYtW6a1a9fq0qVLYR8UAJxk/j7L48ePa//+/Tp8+LDq6+uVmJiojIwMHThwQH/9\n9Zc2b94c7lkBwDGmEzzV1dUqLS3VwYMHNWTIEGVmZnY+Nn36dL377rvhmg8Rdvr0ab/bx4wZ0+Wx\nZ555xrzPsrIy89rXXnvNvBaIpIAfw69evaqioiLt37+/8+z3qlWr1NTUJEmqra1Venp6eKcEAIcF\nPLI8duyY2tratGbNms5tixYt0po1azRo0CB5PB4VFhaGdUgAcFrAWGZnZys7O7vb9oULF4ZlIACI\nRlzuCAAG3N0RAAw4sgQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAM\niCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQA\nA2IJAAbEEgAM3E486Y4dO3T69Gm5XC7l5+dr9OjRTowRUrW1tVq9erXS09MlSU888YQKCgocnip4\njY2NevPNN/Xqq69qyZIlOn/+vDZu3Cifz6fk5GTt3r1b8fHxTo/ZK/9+TXl5eWpoaFBiYqIkafny\n5Zo2bZqzQ/ZSUVGR6urqdOfOHa1YsUJPP/10zL9PUvfX9c033zj+XkU8lidPntS5c+fk9Xr122+/\nKT8/X16vN9JjhMX48eNVUlLi9Bh9dv36dW3dulWZmZmd20pKSpSTk6M5c+aouLhYFRUVysnJcXDK\n3vH3miRp3bp1ysrKcmiqvjlx4oTOnj0rr9ertrY2LVy4UJmZmTH9Pkn+X9eECRMcf68i/jG8pqZG\nM2fOlCQ9/vjjunz5sq5duxbpMXAf8fHxKisrU0pKSue22tpazZgxQ5KUlZWlmpoap8YLir/XFOvG\njRunvXv3SpKGDh2q9vb2mH+fJP+vy+fzOTyVA7G8cOGCHnzwwc6fhw0bptbW1kiPERa//vqr3njj\nDS1evFjff/+90+MEze12KyEhocu29vb2zo9zSUlJMfee+XtNklReXq5ly5Zp7dq1unTpkgOTBS8u\nLk4ej0eSVFFRoSlTpsT8+yT5f11xcXGOv1eO/M7yXh0dHU6PEBKPPvqocnNzNWfOHDU1NWnZsmWq\nqqqKyd8XBdJf3rP58+crMTFRGRkZOnDggPbt26fNmzc7PVavHT9+XBUVFTp8+LBmzZrVuT3W36d7\nX1d9fb3j71XEjyxTUlJ04cKFzp///vtvJScnR3qMkEtNTdXcuXPlcrk0YsQIDR8+XC0tLU6PFTIe\nj0c3btyQJLW0tPSLj7OZmZnKyMiQJE2fPl2NjY0OT9R71dXVKi0tVVlZmYYMGdJv3qd/v65oeK8i\nHstJkyapsrJSktTQ0KCUlBQNHjw40mOE3JdffqlDhw5JklpbW3Xx4kWlpqY6PFXoTJw4sfN9q6qq\n0uTJkx2eqO9WrVqlpqYmSf/3O9l//idDrLh69aqKioq0f//+zrPE/eF98ve6ouG9cnU4cKy+Z88e\n/fjjj3K5XNqyZYuefPLJSI8QcteuXdP69et15coV3b59W7m5uZo6darTYwWlvr5eu3btUnNzs9xu\nt1JTU7Vnzx7l5eXp5s2bSktLU2FhoQYOHOj0qGb+XtOSJUt04MABDRo0SB6PR4WFhUpKSnJ6VDOv\n16sPP/xQjz32WOe2nTt3atOmTTH7Pkn+X9eiRYtUXl7u6HvlSCwBINZwBQ8AGBBLADAglgBgQCwB\nwIBYAoABsQQAA2IJAAbEEgAM/gepgR0uaefKmwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Prediction: 4\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEelJREFUeJzt3W9MlfX/x/HXEWJyhg5BIG1ZfR0u\nKr3hhopOE2Q23FxiN0xCdNmGa5pG6hhTtNn8g85NtI0/aS1Z29moG96wILM2dYDKDRu0hrpyzCkC\nkUocDeH8brQfk8R4czyH64DPx624+Hid99nFnl2H61wHl8/n8wkA8J/GOD0AAIwExBIADIglABgQ\nSwAwIJYAYEAsAcCAWAKAAbEEAINwf//h7t27denSJblcLhUUFGjGjBmBnAsAQopfsTx//ryuXbsm\nj8ejq1evqqCgQB6PJ9CzAUDI8OtleE1NjdLT0yVJU6dO1e3bt9XZ2RnQwQAglPgVy7a2Nk2YMKHv\n65iYGLW2tgZsKAAINQG5wMNncQAY7fyKZXx8vNra2vq+vnXrluLi4gI2FACEGr9iOW/ePFVVVUmS\nGhsbFR8fr6ioqIAOBgChxK+r4TNnztSrr76qt99+Wy6XSzt27Aj0XAAQUlx8+C8ADI47eADAgFgC\ngAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAz8+lO4AJz3yy+/PLLtlVdeeWT777//bt7ne++9Z147f/58\n0zqPx2PeZyjjzBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABi4fD6fz+kh\ngNHsr7/+Mq+tr683r33rrbce2dba2qq4uLh+29rb2837XL16tXntp59+alrndrvN+wxlnFkCgAGx\nBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAF/sAzww/37981rMzMzzWtPnTplXvu4O2O8\nXm+/rysrK837XLJkiXnt2LFjzWtHA84sAcDArzPLuro6bdy4UYmJiZKkadOmafv27QEdDABCid8v\nw2fNmqXi4uJAzgIAIYuX4QBg4Hcsr1y5onXr1mnlypU6d+5cIGcCgJDj1+dZtrS0qL6+XhkZGWpu\nblZOTo6qq6sVERERjBkBwHF+/c4yISGh7y0GU6ZM0cSJE9XS0qLnn38+oMMBoWoobx1aunSpee2T\nvnWos7NTUVFR/bZ9+eWX5n3y1qHH8+tl+IkTJ3T06FFJ/3wyc3t7uxISEgI6GACEEr/OLNPS0rR5\n82b98MMP6u7u1s6dO3kJDmBU8yuWUVFRKikpCfQsABCyuN0ReIj1vcNbtmwx77O7u9u8dii/9//x\nxx8H3P7zzz/3+/p///ufeZ94PN5nCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAG\nxBIADPz6PEvAaT09Pea1x48fH3D7mjVr9MUXX/Tblpuba9pnb2+v+fE/+eQT89qcnBzz2kmTJpnX\n4slxZgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABtzBgxHpcXflDGT16tUDbu/t\n7dWYMf6dL+zcudO8trCw0K/HQGjhzBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIA\nDIglABhwuyNCSnFxsWndRx99ZN7n4/642UC3O77zzjumff77D539l7CwMPNahC7OLAHAgFgCgAGx\nBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAG3OyLovF6vee2kSZNM6+7cuePvOH0Gut2x\npqbG9G9nz579xI+PkcV0ZtnU1KT09HRVVFRIkm7cuKFVq1YpKytLGzdu1N9//x3UIQHAaYPGsqur\nS7t27VJKSkrftuLiYmVlZemrr77SCy+8oMrKyqAOCQBOGzSWERERKi8vV3x8fN+2uro6LVq0SJKU\nmppqfukCACNV+KALwsMVHt5/mdfrVUREhCQpNjZWra2twZkOAELEoLEcDNeHMJjIyEjz2j///DOI\nkzyqt7d3WB8PI5dfsXS73bp3757Gjh2rlpaWfi/RgX/jajhGA7/eZzl37lxVVVVJkqqrqzV//vyA\nDgUAoWbQM8uGhgbt27dP169fV3h4uKqqqnTgwAHl5+fL4/Fo8uTJWrZs2XDMCgCO4U3pCDpehmM0\neOILPHg6ffvtt+a1hw4dMq8NRASfRElJiWkdsXz6cG84ABgQSwAwIJYAYEAsAcCAWAKAAbEEAANi\nCQAGxBIADIglABgQSwAw4HZH+MV6W6D0zydTWU2ZMsW07v79++Z9trS0mNcCj8OZJQAYEEsAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMOB2R/Rz4cKFAbcnJyf3+15tbW1QHv/77783\nrRvKX4FMTk72dxygD2eWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGDAHTzoZ8GC\nBQNu93q9/b43lD8YNhTWP1jm9XqD8vjA43BmCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEE\nAANiCQAGxBIADLjd8Slw5MgR89r/uo3R31scZ8yYYV7rcrn8eoxAuXnzpmldV1eXeZ9ut9vfcRBC\nOLMEAANTLJuampSenq6KigpJUn5+vpYuXapVq1Zp1apV+umnn4I5IwA4btCX4V1dXdq1a5dSUlL6\nbc/Ly1NqamrQBgOAUDLomWVERITKy8sVHx8/HPMAQEhy+Xw+n2Xh4cOHNWHCBGVnZys/P1+tra3q\n7u5WbGystm/frpiYmGDPCgCO8etq+Jtvvqno6GglJSWprKxMR44cUWFhYaBnQ4AM5Wr4Bx98MOD2\n3t5ejRnj3/XAoVwNP3/+vGndUK5GP+5/5AM9pzfeeMO0z6+//tr8+FwNHx38+ulPSUlRUlKSJCkt\nLU1NTU0BHQoAQo1fsdywYYOam5slSXV1dUpMTAzoUAAQagZ9Gd7Q0KB9+/bp+vXrCg8PV1VVlbKz\ns7Vp0yZFRkbK7XZrz549wzErADhm0Fi+9tprOn78+CPbrb/bAYDRgNsdnwLt7e2OPv6WLVvMayMi\nIkzrhnKBZyiqqqpM63799VfzPmfOnOnvOAgh3O4IAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoAB\nsQQAA2IJAAbEEgAMuN0RfomNjTWvTU5ODvjjnz17NuD7lNT30YODee6554Ly+AhdnFkCgAGxBAAD\nYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAF38MAv48ePN6999tlnA/74FRUVAd+nJM2aNcu0\nLiEhISiPj9DFmSUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADDgdkf45bff\nfjOv/eabb8xrs7OzTet6e3vN+/T5fH59D3gYZ5YAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQ\nSwAwIJYAYEAsAcCA2x0RdO+++25Q1lq5XC6/vgc8zBTLoqIi1dfX68GDB8rNzdX06dO1detW9fT0\nKC4uTvv371dERESwZwUAxwway9raWl2+fFkej0cdHR3KzMxUSkqKsrKylJGRoYMHD6qyslJZWVnD\nMS8AOGLQ31kmJyfr0KFDkqTx48fL6/Wqrq5OixYtkiSlpqaqpqYmuFMCgMMGjWVYWJjcbrckqbKy\nUgsWLJDX6+172R0bG6vW1tbgTgkADjNf4Dl16pQqKyt17NgxLV68uG87nwcY+nbs2BGQtUP5DMmR\nYjQ+JwSHKZZnzpxRSUmJPvvsM40bN05ut1v37t3T2LFj1dLSovj4+GDPiSfw8ccfP/Ha3t5ejRkz\nut5pNtBzWr16tenffv7558EYCSFs0J/+u3fvqqioSKWlpYqOjpYkzZ07V1VVVZKk6upqzZ8/P7hT\nAoDDBj2zPHnypDo6OrRp06a+bXv37tW2bdvk8Xg0efJkLVu2LKhDAoDTBo3lihUrtGLFike28zIE\nwNOEO3ieAnl5eea1Fy5ceOz3lixZ0vffZ8+eNe/zzp075rVAqBpdv7EHgCAhlgBgQCwBwIBYAoAB\nsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg4PLxgZTww3fffWde+/Btkk543I+4z+d75A+W1dbWmvY5\ne/bsJ54LIwtnlgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIDbHQHAgDNL\nADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbE\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAg3DLoqKiItXX1+vBgwfKzc3V6dOn1djYqOjoaEnS\n2rVrtXDhwmDOCQCOGjSWtbW1unz5sjwejzo6OpSZmak5c+YoLy9PqampwzEjADhu0FgmJydrxowZ\nkqTx48fL6/Wqp6cn6IMBQChx+Xw+n3Wxx+PRxYsXFRYWptbWVnV3dys2Nlbbt29XTExMMOcEAEeZ\nY3nq1CmVlpbq2LFjamhoUHR0tJKSklRWVqabN2+qsLAw2LMCgGNMV8PPnDmjkpISlZeXa9y4cUpJ\nSVFSUpIkKS0tTU1NTUEdEgCcNmgs7969q6KiIpWWlvZd/d6wYYOam5slSXV1dUpMTAzulADgsEEv\n8Jw8eVIdHR3atGlT37bly5dr06ZNioyMlNvt1p49e4I6JAA4bUgXeADgacUdPABgQCwBwIBYAoAB\nsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg\nQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbhTjzo7t27\ndenSJblcLhUUFGjGjBlOjBFQdXV12rhxoxITEyVJ06ZN0/bt2x2eyn9NTU16//33tWbNGmVnZ+vG\njRvaunWrenp6FBcXp/379ysiIsLpMYfk388pPz9fjY2Nio6OliStXbtWCxcudHbIISoqKlJ9fb0e\nPHig3NxcTZ8+fcQfJ+nR53X69GnHj9Wwx/L8+fO6du2aPB6Prl69qoKCAnk8nuEeIyhmzZql4uJi\np8d4Yl1dXdq1a5dSUlL6thUXFysrK0sZGRk6ePCgKisrlZWV5eCUQzPQc5KkvLw8paamOjTVk6mt\nrdXly5fl8XjU0dGhzMxMpaSkjOjjJA38vObMmeP4sRr2l+E1NTVKT0+XJE2dOlW3b99WZ2fncI+B\n/xAREaHy8nLFx8f3baurq9OiRYskSampqaqpqXFqPL8M9JxGuuTkZB06dEiSNH78eHm93hF/nKSB\nn1dPT4/DUzkQy7a2Nk2YMKHv65iYGLW2tg73GEFx5coVrVu3TitXrtS5c+ecHsdv4eHhGjt2bL9t\nXq+37+VcbGzsiDtmAz0nSaqoqFBOTo4+/PBD/fHHHw5M5r+wsDC53W5JUmVlpRYsWDDij5M08PMK\nCwtz/Fg58jvLh/l8PqdHCIgXX3xR69evV0ZGhpqbm5WTk6Pq6uoR+fuiwYyWY/bmm28qOjpaSUlJ\nKisr05EjR1RYWOj0WEN26tQpVVZW6tixY1q8eHHf9pF+nB5+Xg0NDY4fq2E/s4yPj1dbW1vf17du\n3VJcXNxwjxFwCQkJWrJkiVwul6ZMmaKJEyeqpaXF6bECxu126969e5KklpaWUfFyNiUlRUlJSZKk\ntLQ0NTU1OTzR0J05c0YlJSUqLy/XuHHjRs1x+vfzCoVjNeyxnDdvnqqqqiRJjY2Nio+PV1RU1HCP\nEXAnTpzQ0aNHJUmtra1qb29XQkKCw1MFzty5c/uOW3V1tebPn+/wRE9uw4YNam5ulvTP72T//50M\nI8Xdu3dVVFSk0tLSvqvEo+E4DfS8QuFYuXwOnKsfOHBAFy9elMvl0o4dO/Tyyy8P9wgB19nZqc2b\nN+vOnTvq7u7W+vXr9frrrzs9ll8aGhq0b98+Xb9+XeHh4UpISNCBAweUn5+v+/fva/LkydqzZ4+e\neeYZp0c1G+g5ZWdnq6ysTJGRkXK73dqzZ49iY2OdHtXM4/Ho8OHDeumll/q27d27V9u2bRuxx0ka\n+HktX75cFRUVjh4rR2IJACMNd/AAgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHA4P8ALqDX\nN3rmU3AAAAAASUVORK5CYII=\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Prediction: 1\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEqVJREFUeJzt3W9Ilff/x/HX+eWkpMQ0dQRrZdgm\nq24Miiz6Y0nrFKPVjZqiMgiW/SMX0ZxlDYJMiyALZrnqRlKc4a1u5B9cjIWZUbDA7ljWQqJMm1iR\nbSbne2P8/H7NY77P8Ryvoz0f97y8us777BpPrnMuP+e4vF6vVwCAd/o/pwcAgNGAWAKAAbEEAANi\nCQAGxBIADIglABgQSwAwIJYAYBAR6D88dOiQbt++LZfLpYKCAs2dOzeYcwFAWAkoljdu3NDDhw/l\n8XjU0tKigoICeTyeYM8GAGEjoJfhDQ0NSk9PlyTNnDlTXV1devnyZVAHA4BwElAsOzo6NHny5L6f\nY2Nj1d7eHrShACDcBOUGD5/FAWCsCyiWCQkJ6ujo6Pv56dOnio+PD9pQABBuAorlokWLVFNTI0m6\nc+eOEhISNHHixKAOBgDhJKC74Z9//rk+++wzff3113K5XDpw4ECw5wKAsOLiw38BYGis4AEAA2IJ\nAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBY\nAoABsQQAA2IJAAbEEgAMiCUAGBBLADAI6KtwgVC5ePGiab+9e/eaj/ngwQOf271er1wul/k4gWpp\naTHvm5SUFMJJMBxcWQKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAAOWOyIg\n9+/fD8lxMzMzTfutWrXKfMzBljv6MmPGjKAf88mTJ+Z9We4YvriyBAADYgkABsQSAAyIJQAYEEsA\nMCCWAGBALAHAgFgCgAGxBAADVvAgIOnp6eZ9/VntYrV06VLzvh6PZ9DfdXV19fs5OjradMwtW7aY\nH3/27NnmfRG+uLIEAIOAriwbGxu1c+dOJScnS5JmzZqlwsLCoA4GAOEk4Jfh8+fPV2lpaTBnAYCw\nxctwADAIOJb37t1Tbm6uMjIyVF9fH8yZACDsuLxer9fff9TW1qZbt27J7XartbVVOTk5qq2tVWRk\nZChmBADHBfSeZWJiolavXi1JmjZtmqZMmaK2tjZ99NFHQR0O4cufD6kNxZ8OFRUVmffdunWrz+3R\n0dF6/vz5gG0W/vzpUHFxsXlf6+Nj5AX0MvzSpUs6c+aMJKm9vV3Pnj1TYmJiUAcDgHAS0JXl8uXL\ntXv3bv3666/q6enRjz/+yEtwAGNaQLGcOHGiysrKgj0LAIStgG7wYHR5+325d9m4caPP7VVVVXK7\n3X0/V1dXD3suX6zvRebn54fk8YHB8HeWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCW\nAGBALAHAgOWO7wF/Pk5ssDX/Xq9XLpcroMf35+PUWMaIcMWVJQAYEEsAMCCWAGBALAHAgFgCgAGx\nBAADYgkABsQSAAyIJQAYsIJnlLp27Zp530WLFg378d5ewXPhwgXzv83IyBj24wNO48oSAAyIJQAY\nEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYRDg9APp7/vy5ab9gLGH0JTc31/Q7ljDi\nfcOVJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMODbHcOM2+027VddXW0+\n5qpVq8z7ejwen9ujo6P7LcWMjo42HxMYC0xXls3NzUpPT1dFRYUk6fHjx8rOzlZmZqZ27typf/75\nJ6RDAoDThozlq1evdPDgQaWmpvZtKy0tVWZmpi5cuKCPP/5YlZWVIR0SAJw2ZCwjIyNVXl6uhISE\nvm2NjY1asWKFJCktLU0NDQ2hmxAAwsCQH9EWERGhiIj+u3V3dysyMlKSFBcXp/b29tBMBwBhYtif\nZ8n9oeCqqqpyeoRBcVMH77OAYhkVFaXXr19r/Pjxamtr6/cSHcPD3XAgPAX0d5YLFy5UTU2NJKm2\ntlaLFy8O6lAAEG6GvLJsampScXGxHj16pIiICNXU1Ojo0aPKz8+Xx+PR1KlT9dVXX43ErADgmCFj\nOXv2bJ0/f37A9nPnzoVkIAAIR6zgGQH379837ztz5sygP35LS4t536SkpKA/PjAWsDYcAAyIJQAY\nEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYDPvzLDG0I0eOBP2Yubm55n1ZwggMH1eW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgOWOI6Cmpibox8zOzg76Mceq\nwb5dMykpacDvrEtT//zzT/PjT58+3byvP/+vfPLJJwO2VVVVye1299uWk5NjPuaaNWvM+0ZHR5v3\nHQu4sgQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA5fX6/U6PcRY588Xhj148MC0\nX0tLS0ge32kXL1407bd3717zMQf7b+r1euVyuczHGQ2G+5xWrVpl3tfj8Zj2GysrfbiyBAADYgkA\nBsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABix3HAFbtmwx71tWVmbabzSdtlAs9wyG\n4SwN9GdZYHV1dUCPEYiRXMJpXXI7mpbbvgtXlgBgYIplc3Oz0tPTVVFRIUnKz8/Xl19+qezsbGVn\nZ+u3334L5YwA4Lghvzf81atXOnjwoFJTU/tt37Vrl9LS0kI2GACEkyGvLCMjI1VeXq6EhISRmAcA\nwpL5Bs+JEyc0efJkZWVlKT8/X+3t7erp6VFcXJwKCwsVGxsb6lkBwDFDvgz3Ze3atYqJiVFKSopO\nnz6tkydPav/+/cGebczgbjh3w0cKd8NDJ6C74ampqUpJSZEkLV++XM3NzUEdCgDCTUCx3LFjh1pb\nWyVJjY2NSk5ODupQABBuhnwZ3tTUpOLiYj169EgRERGqqalRVlaW8vLyNGHCBEVFRamoqGgkZgUA\nxwwZy9mzZ+v8+fMDtn/xxRchGQgAwlFAN3gAt9tt3jcUN238eTWzYcOGQX/39k2KKVOmBDzTYEL1\n7YbPnz/3ub2rq6vfz99//735mNYbjJK0bds2035VVVXmY4YzljsCgAGxBAADYgkABsQSAAyIJQAY\nEEsAMCCWAGBALAHAgFgCgAGxBAADljuOUteuXTPvu3DhwmEfd+HChf1+F6rPaKyvrzft589zepfR\n/FmLgy2jfHv7Tz/9ZD6mP8sd3zdcWQKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKA\nASt4RkBxcbF535qaGtN+WVlZ5mP+8ccf5n19fZOn9O+KmcF+NxR/vlwsWCtz8F/+rPbyR2FhYUiO\nG664sgQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAYur9frdXoI/Jd1adqi\nRYtCPEl/Xq9XLpcroH/b1dVl3newL+HCQBcvXhywLSMjY8D2zMxM8zEvXLhg3nfNmjWm/cbKOeXK\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGLDccZTy5xv7grE0cjjLHf35\ndseHDx+a9svOzjYf88MPP/S5PSkpSffv3++37ZdffjEdc8mSJebH98fBgwfN+1ZXVw/YNpzzJEn1\n9fXmfd+3b+I0fRVuSUmJbt26pTdv3mjz5s2aM2eO9uzZo97eXsXHx+vIkSOKjIwM9awA4JghY3n9\n+nXdvXtXHo9HnZ2dWrdunVJTU5WZmSm3261jx46psrLSr8X6ADDaDPme5bx583T8+HFJ/356SHd3\ntxobG7VixQpJUlpamhoaGkI7JQA4bMhYjhs3TlFRUZKkyspKLVmyRN3d3X0vu+Pi4tTe3h7aKQHA\nYab3LCWprq5OlZWVOnv2rFauXNm3nftDzvDnzfVgnaOxeK6TkpL6/Zyfn+/QJP+qqqoa9jHG4nkK\nB6ZYXr16VWVlZfr55581adIkRUVF6fXr1xo/frza2tqUkJAQ6jnxFu6Gczecu+Eja8iX4S9evFBJ\nSYlOnTqlmJgYSf/+R6qpqZEk1dbWavHixaGdEgAcNuSV5eXLl9XZ2am8vLy+bYcPH9a+ffvk8Xg0\ndepUffXVVyEdEgCcNmQsN27cqI0bNw7Yfu7cuZAMBADhiBU874G335d7l23btvncXlVVJbfb3fez\nr/fLRpvhvr/ntBkzZgzYdv/+/QE3rerq6szHnDJlinnfsfJFZFasDQcAA2IJAAbEEgAMiCUAGBBL\nADAglgBgQCwBwIBYAoABsQQAA2IJAAYsd0RA/PmIuPPnz5v3tX702u+//24+5g8//OBzu6/ljr6W\nEPry7bffmh9/w4YN5n398fayRoQWV5YAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAwIJYA\nYEAsAcCA5Y4AYMCVJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBA\nLAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgEGEZaeSkhLdunVLb9680ebN\nm3XlyhXduXNHMTExkqRNmzZp2bJloZwTABw1ZCyvX7+uu3fvyuPxqLOzU+vWrdOCBQu0a9cupaWl\njcSMAOC4IWM5b948zZ07V5IUHR2t7u5u9fb2hnwwAAgnLq/X67Xu7PF4dPPmTY0bN07t7e3q6elR\nXFycCgsLFRsbG8o5AcBR5ljW1dXp1KlTOnv2rJqamhQTE6OUlBSdPn1aT5480f79+0M9KwA4xnQ3\n/OrVqyorK1N5ebkmTZqk1NRUpaSkSJKWL1+u5ubmkA4JAE4bMpYvXrxQSUmJTp061Xf3e8eOHWpt\nbZUkNTY2Kjk5ObRTAoDDhrzBc/nyZXV2diovL69v2/r165WXl6cJEyYoKipKRUVFIR0SAJzm1w0e\nAHhfsYIHAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHA\ngFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsA\nMCCWAGBALAHAIMKJBz106JBu374tl8ulgoICzZ0714kxgqqxsVE7d+5UcnKyJGnWrFkqLCx0eKrA\nNTc3a+vWrfrmm2+UlZWlx48fa8+ePert7VV8fLyOHDmiyMhIp8f0y9vPKT8/X3fu3FFMTIwkadOm\nTVq2bJmzQ/qppKREt27d0ps3b7R582bNmTNn1J8naeDzunLliuPnasRjeePGDT18+FAej0ctLS0q\nKCiQx+MZ6TFCYv78+SotLXV6jGF79eqVDh48qNTU1L5tpaWlyszMlNvt1rFjx1RZWanMzEwHp/SP\nr+ckSbt27VJaWppDUw3P9evXdffuXXk8HnV2dmrdunVKTU0d1edJ8v28FixY4Pi5GvGX4Q0NDUpP\nT5ckzZw5U11dXXr58uVIj4F3iIyMVHl5uRISEvq2NTY2asWKFZKktLQ0NTQ0ODVeQHw9p9Fu3rx5\nOn78uCQpOjpa3d3do/48Sb6fV29vr8NTORDLjo4OTZ48ue/n2NhYtbe3j/QYIXHv3j3l5uYqIyND\n9fX1To8TsIiICI0fP77ftu7u7r6Xc3FxcaPunPl6TpJUUVGhnJwcfffdd/rrr78cmCxw48aNU1RU\nlCSpsrJSS5YsGfXnSfL9vMaNG+f4uXLkPcv/5fV6nR4hKKZPn67t27fL7XartbVVOTk5qq2tHZXv\nFw1lrJyztWvXKiYmRikpKTp9+rROnjyp/fv3Oz2W3+rq6lRZWamzZ89q5cqVfdtH+3n63+fV1NTk\n+Lka8SvLhIQEdXR09P389OlTxcfHj/QYQZeYmKjVq1fL5XJp2rRpmjJlitra2pweK2iioqL0+vVr\nSVJbW9uYeDmbmpqqlJQUSdLy5cvV3Nzs8ET+u3r1qsrKylReXq5JkyaNmfP09vMKh3M14rFctGiR\nampqJEl37txRQkKCJk6cONJjBN2lS5d05swZSVJ7e7uePXumxMREh6cKnoULF/adt9raWi1evNjh\niYZvx44dam1tlfTve7L//5cMo8WLFy9UUlKiU6dO9d0lHgvnydfzCodz5fI6cK1+9OhR3bx5Uy6X\nSwcOHNCnn3460iME3cuXL7V79249f/5cPT092r59u5YuXer0WAFpampScXGxHj16pIiICCUmJuro\n0aPKz8/X33//ralTp6qoqEgffPCB06Oa+XpOWVlZOn36tCZMmKCoqCgVFRUpLi7O6VHNPB6PTpw4\noRkzZvRtO3z4sPbt2zdqz5Pk+3mtX79eFRUVjp4rR2IJAKMNK3gAwIBYAoABsQQAA2IJAAbEEgAM\niCUAGBBLADAglgBg8B9OkjtgR8VvdgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "Prediction: 6\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "4SJizeJtNaAs", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Profiling\n", + "\n", + "If you want to drill down into the performance characteristics of your code, you can use native Python profilers like [`cProfile`](https://docs.python.org/3/library/profile.html). In the next exercise, you'll do just that." + ] + }, + { + "metadata": { + "id": "_2v0QnG8__PJ", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Exercise!\n", + "\n", + "This exercise does not require coding. If you have not completed the training exercise, replace `train_one_epoch` below with `_train_one_epoch`.\n", + "\n", + "Run the below cell and inspect the printed profiles. What parts of the code appear to be hotspots or\n", + "bottlenecks? How does sorting the profile by total time compare to sorting it\n", + "by cumulative time?\n", + "\n" + ] + }, + { + "metadata": { + "id": "IFypaYbG_9fB", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 714 + }, + "outputId": "d9c3596b-a165-4edd-fc6b-53ccd0d01d19" + }, + "cell_type": "code", + "source": [ + "import cProfile\n", + "import pstats\n", + "\n", + "cProfile.run(\"train_one_epoch(model, training_data, optimizer)\", \"training_profile\")\n", + "\n", + "stats = pstats.Stats(\"training_profile\").strip_dirs().sort_stats(\"tottime\")\n", + "stats.print_stats(10)\n", + "\n", + "stats.sort_stats(\"cumtime\").print_stats(10)" + ], + "execution_count": 17, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Thu Jun 7 12:25:04 2018 training_profile\n", + "\n", + " 92209 function calls (91817 primitive calls) in 3.446 seconds\n", + "\n", + " Ordered by: internal time\n", + " List reduced from 672 to 10 due to restriction <10>\n", + "\n", + " ncalls tottime percall cumtime percall filename:lineno(function)\n", + " 1080 2.552 0.002 2.552 0.002 {built-in method _pywrap_tensorflow_internal.TFE_Py_FastPathExecute}\n", + " 83 0.753 0.009 0.753 0.009 {built-in method _pywrap_tensorflow_internal.TFE_Py_Execute}\n", + " 16 0.006 0.000 1.019 0.064 network.py:736(_run_internal_graph)\n", + " 16 0.005 0.000 2.253 0.141 {built-in method _pywrap_tensorflow_internal.TFE_Py_TapeGradient}\n", + " 2321 0.004 0.000 0.007 0.000 abc.py:178(__instancecheck__)\n", + " 288 0.004 0.000 0.009 0.000 inspect.py:2092(_signature_from_function)\n", + " 878 0.004 0.000 0.005 0.000 ops.py:5936(__enter__)\n", + " 288 0.004 0.000 0.016 0.000 inspect.py:1079(getfullargspec)\n", + " 11006 0.003 0.000 0.005 0.000 {built-in method builtins.isinstance}\n", + " 768 0.003 0.000 0.008 0.000 {built-in method _pywrap_tensorflow_internal.Flatten}\n", + "\n", + "\n", + "Thu Jun 7 12:25:04 2018 training_profile\n", + "\n", + " 92209 function calls (91817 primitive calls) in 3.446 seconds\n", + "\n", + " Ordered by: cumulative time\n", + " List reduced from 672 to 10 due to restriction <10>\n", + "\n", + " ncalls tottime percall cumtime percall filename:lineno(function)\n", + " 1 0.000 0.000 3.446 3.446 {built-in method builtins.exec}\n", + " 1 0.000 0.000 3.446 3.446 :1()\n", + " 1 0.001 0.001 3.446 3.446 :9(train_one_epoch)\n", + " 1080 2.552 0.002 2.552 0.002 {built-in method _pywrap_tensorflow_internal.TFE_Py_FastPathExecute}\n", + " 16 0.000 0.000 2.255 0.141 backprop.py:739(gradient)\n", + " 16 0.000 0.000 2.253 0.141 imperative_grad.py:31(imperative_grad)\n", + " 16 0.005 0.000 2.253 0.141 {built-in method _pywrap_tensorflow_internal.TFE_Py_TapeGradient}\n", + " 400 0.002 0.000 2.246 0.006 backprop.py:145(grad_fn)\n", + " 400 0.002 0.000 2.239 0.006 backprop.py:95(_magic_gradient_function)\n", + " 32 0.001 0.000 1.601 0.050 nn_grad.py:497(_Conv2DGrad)\n", + "\n", + "\n" + ], + "name": "stdout" + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 17 + } + ] + }, + { + "metadata": { + "id": "8ixpnyCNNTI4", + "colab_type": "code", + "colab": {} + }, + "cell_type": "code", + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/workshop/3_inspecting.ipynb b/tensorflow/contrib/eager/python/examples/workshop/3_inspecting.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..64d19ec5c9bfccd07eabb21ce8fbb62b21f23efa --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/workshop/3_inspecting.ipynb @@ -0,0 +1,443 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Debugging \"graph-first\" models with eager execution", + "version": "0.3.2", + "provenance": [], + "include_colab_link": true + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "[View in Colaboratory](https://colab.research.google.com/gist/alextp/9568ab40f6ed6f9a3ba4736f6aef6127/debugging-graph-first-models-with-eager-execution.ipynb)" + ] + }, + { + "metadata": { + "id": "mm-t0GuIu1Dt", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "This colab uses eager execution and the Python debugger to modify the execution of a translation model. This combination lets you quickly explore counterfactuals when researching and designing modifications to a model.\n", + "\n", + "The model, Transformer from [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor), was originally written with graph building in mind. Executing it eagerly can still be helpful!" + ] + }, + { + "metadata": { + "id": "gxb1DvIDg4sv", + "colab_type": "code", + "colab": {} + }, + "cell_type": "code", + "source": [ + "#@title License (double click to show)\n", + "# Copyright 2018 The TensorFlow Authors.\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Gx3HA9N1ui64", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 37 + }, + "outputId": "f6986f34-f3e1-44e1-c902-2eb33081acad" + }, + "cell_type": "code", + "source": [ + "import tensorflow as tf\n", + "import pdb\n", + "tfe = tf.contrib.eager\n", + "\n", + "tf.enable_eager_execution()" + ], + "execution_count": 1, + "outputs": [] + }, + { + "metadata": { + "id": "3LkOm2ct-Lmc", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 37 + }, + "outputId": "2edc74d9-6bc0-4e78-ab4e-83bf96099ef4" + }, + "cell_type": "code", + "source": [ + "!pip install -q -U tensor2tensor\n", + "from tensor2tensor.models import transformer" + ], + "execution_count": 2, + "outputs": [] + }, + { + "metadata": { + "id": "1Z3oMsqV0zB6", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 170 + }, + "outputId": "0a8186ee-c688-457f-c9f6-9a6c1477a93b" + }, + "cell_type": "code", + "source": [ + "#@title Create a tensor2tensor translation model, fetch a checkpoint (double click to show)\n", + "from tensor2tensor import problems\n", + "from tensor2tensor.utils import trainer_lib\n", + "from tensor2tensor.utils import registry\n", + "\n", + "import numpy as np\n", + "import os\n", + "\n", + "# Setup some directories\n", + "data_dir = os.path.expanduser(\"~/t2t/data\")\n", + "tmp_dir = os.path.expanduser(\"~/t2t/tmp\")\n", + "train_dir = os.path.expanduser(\"~/t2t/train\")\n", + "checkpoint_dir = os.path.expanduser(\"~/t2t/checkpoints\")\n", + "tf.gfile.MakeDirs(data_dir)\n", + "tf.gfile.MakeDirs(tmp_dir)\n", + "tf.gfile.MakeDirs(train_dir)\n", + "tf.gfile.MakeDirs(checkpoint_dir)\n", + "gs_data_dir = \"gs://tensor2tensor-data\"\n", + "gs_ckpt_dir = \"gs://tensor2tensor-checkpoints/\"\n", + "\n", + "# Fetch the problem\n", + "ende_problem = problems.problem(\"translate_ende_wmt32k\")\n", + "\n", + "# Copy the vocab file locally so we can encode inputs and decode model outputs\n", + "# All vocabs are stored on GCS\n", + "vocab_name = \"vocab.ende.32768\"\n", + "vocab_file = os.path.join(gs_data_dir, vocab_name)\n", + "!gsutil cp {vocab_file} {data_dir}\n", + "\n", + "# Get the encoders from the problem\n", + "encoders = ende_problem.feature_encoders(data_dir)\n", + "\n", + "# Setup helper functions for encoding and decoding\n", + "def encode(input_str, output_str=None):\n", + " \"\"\"Input str to features dict, ready for inference\"\"\"\n", + " inputs = encoders[\"inputs\"].encode(input_str) + [1] # add EOS id\n", + " batch_inputs = tf.reshape(inputs, [1, -1, 1]) # Make it 3D.\n", + " return {\"inputs\": batch_inputs}\n", + "\n", + "def decode(integers):\n", + " \"\"\"List of ints to str\"\"\"\n", + " integers = list(np.squeeze(integers))\n", + " if 1 in integers:\n", + " integers = integers[:integers.index(1)]\n", + " return encoders[\"inputs\"].decode(np.squeeze(integers))\n", + "\n", + "# Copy the pretrained checkpoint locally\n", + "ckpt_name = \"transformer_ende_test\"\n", + "gs_ckpt = os.path.join(gs_ckpt_dir, ckpt_name)\n", + "!gsutil -q cp -R {gs_ckpt} {checkpoint_dir}\n", + "checkpoint_path = tf.train.latest_checkpoint(\n", + " os.path.join(checkpoint_dir, ckpt_name))\n", + "\n", + "# Create hparams and the model\n", + "model_name = \"transformer\"\n", + "hparams_set = \"transformer_base\"\n", + "\n", + "hparams = trainer_lib.create_hparams(hparams_set, data_dir=data_dir, problem_name=\"translate_ende_wmt32k\")\n", + "\n", + "# NOTE: Only create the model once when restoring from a checkpoint; it's a\n", + "# Layer and so subsequent instantiations will have different variable scopes\n", + "# that will not match the checkpoint.\n", + "translate_model = registry.model(model_name)(hparams, tf.estimator.ModeKeys.EVAL)" + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Copying gs://tensor2tensor-data/vocab.ende.32768...\n", + "/ [1 files][316.4 KiB/316.4 KiB] \n", + "Operation completed over 1 objects/316.4 KiB. \n", + "INFO:tensorflow:Setting T2TModel mode to 'eval'\n", + "INFO:tensorflow:Setting hparams.layer_prepostprocess_dropout to 0.0\n", + "INFO:tensorflow:Setting hparams.symbol_dropout to 0.0\n", + "INFO:tensorflow:Setting hparams.attention_dropout to 0.0\n", + "INFO:tensorflow:Setting hparams.dropout to 0.0\n", + "INFO:tensorflow:Setting hparams.relu_dropout to 0.0\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "4IblPXLGjuCl", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "We've created a Transformer model and fetched an existing training checkpoint. It hasn't created variables yet, and we want to load them from the checkpoint before they're used (restore-on-create) so the first run of the model outputs the correct value. The `tfe.restore_variables_on_create` API looks up variables by name on creation and restores their values." + ] + }, + { + "metadata": { + "id": "o3MWxcAqJoqG", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + }, + "outputId": "fbc1b1bf-ffbe-4621-b3cb-5eb855fec3a8" + }, + "cell_type": "code", + "source": [ + "with tfe.restore_variables_on_create(checkpoint_path):\n", + " model_output = translate_model.infer(encode(\"Eager execution\"))\n", + "print(decode(model_output[\"outputs\"]))" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "text": [ + "INFO:tensorflow:Greedy Decoding\n", + "Hinrichtung\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "xk5HV9Hhu9zO", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Using global variable names can get somewhat fragile, so for new code we recommend the object-based `tf.keras.Model.save_weights` or `tf.train.Checkpoint`. However, these require some small code changes to work with existing graph building code.\n", + "\n", + "The Transformer model translates \"Eager execution\" in English to \"Hinrichtung\" in German, which refers to capital punishment rather than getting things done. Transformer first encodes the English, then decodes to German. We'll add a debugging hook at the start of the decode phase (once the encodings have been finalized) and see if we can correct the translation." + ] + }, + { + "metadata": { + "id": "GUGwbYvXZ9-7", + "colab_type": "code", + "colab": {} + }, + "cell_type": "code", + "source": [ + "previous_fast_decode = transformer.fast_decode\n", + "def debug_fn(*args, **kwargs):\n", + " pdb.set_trace()\n", + " return previous_fast_decode(*args, **kwargs) # \"step\" in pdb to step in\n", + "transformer.fast_decode = debug_fn # Add our debugging hook to Transformer" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "f61HlvECxJn0", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Now that we've \"monkey patched\" the model, we'll drop into a debugger just before decoding starts. In most cases it'd be simpler to add the `pdb.set_trace()` call to the code directly, but in this case we're working with prepackaged library code.\n", + "\n", + "First, let's find an encoding which represents the correct sense of \"execution\". Then we'll patch part of that encoding into the encoding of \"Eager execution\" to fix the translation. Feel free to poke around with the debugger (e.g. print a Tensor's value), but your main task is to save the encodings by assigning them to an attribute of the function:\n", + "\n", + "```\n", + "(running the next cell drops you into a pdb shell)\n", + "step\n", + "fast_decode.previous_encoding = encoder_output\n", + "continue\n", + "\n", + "```\n", + "\n", + "You can type `next` (or `n`) a few times before `continue` to watch the decoding ops run." + ] + }, + { + "metadata": { + "id": "dX4CPOGSpZrb", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 179 + }, + "outputId": "6de38c31-836f-40ef-b701-e42908172619" + }, + "cell_type": "code", + "source": [ + "model_output = translate_model.infer(encode(\"Immediate running\"))\n", + "print(decode(model_output[\"outputs\"]))" + ], + "execution_count": 7, + "outputs": [ + { + "output_type": "stream", + "text": [ + "> (4)debug_fn()\n", + "-> return previous_fast_decode(*args, **kwargs) # \"step\" in pdb to step in\n", + "(Pdb) step\n", + "--Call--\n", + "> /usr/local/lib/python2.7/dist-packages/tensor2tensor/models/transformer.py(427)fast_decode()\n", + "-> def fast_decode(encoder_output,\n", + "(Pdb) fast_decode.previous_encoding = encoder_output\n", + "(Pdb) continue\n", + "Sofortige Durchführung\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "-ZEZciV4FpLo", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Now we have an encoding saved which gets the correct sense for \"execution\"." + ] + }, + { + "metadata": { + "id": "QeC_oDVqHD_v", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 179 + }, + "outputId": "253c9af1-003e-46bd-8bf5-db968cf6a8cf" + }, + "cell_type": "code", + "source": [ + "# Assumes you followed the pdb instructions above!\n", + "transformer.fast_decode.previous_encoding" + ], + "execution_count": 8, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 8 + } + ] + }, + { + "metadata": { + "id": "bC9JjeDcHEav", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Let's replace part of the encoding for \"Eager execution\" with the encoding of \"Immediate running\".\n", + "\n", + "Again we'll drop into a pdb shell. This time we'll run some TensorFlow operations to patch the encodings while the model is running.\n", + "\n", + "```\n", + "(running the next cell again drops you into a pdb shell)\n", + "step\n", + "encoder_output = tf.concat([fast_decode.previous_encoding[:, :3], encoder_output[:, 3:]], axis=1)\n", + "continue\n", + "```" + ] + }, + { + "metadata": { + "id": "t2as_Kn1h65G", + "colab_type": "code", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 179 + }, + "outputId": "5b4e546e-3bb4-4761-c545-467b631e3ffe" + }, + "cell_type": "code", + "source": [ + "model_output = translate_model.infer(encode(\"Eager execution\"))\n", + "print(decode(model_output[\"outputs\"]))" + ], + "execution_count": 9, + "outputs": [ + { + "output_type": "stream", + "text": [ + "> (4)debug_fn()\n", + "-> return previous_fast_decode(*args, **kwargs) # \"step\" in pdb to step in\n", + "(Pdb) step\n", + "--Call--\n", + "> /usr/local/lib/python2.7/dist-packages/tensor2tensor/models/transformer.py(427)fast_decode()\n", + "-> def fast_decode(encoder_output,\n", + "(Pdb) encoder_output = tf.concat([fast_decode.previous_encoding[:, :3], encoder_output[:, 3:]], axis=1)\n", + "(Pdb) continue\n", + "sofortige Ausführung\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "rK6tYZ23I2cm", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "We get a different decoding, with the correct sense of \"execution\". Likely we're keeping just the encoding of \"tion\" from \"Eager execution\", so no great breakthrough in translation modeling.\n", + "\n", + "Similarly it's possible to modify attention vectors, or change words during decoding to help debug a beam search." + ] + }, + { + "metadata": { + "id": "Nb-4ipYNRWxA", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "This colab was adapted from the [Tensor2Tensor colab](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb). Credit to Ankur Taly for its concept." + ] + } + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index efa6ba062631500bd7cd16620ebec23d15b93b62..6efafccd6b93ad58da395e0b2e1e647809af62ad 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -291,8 +291,6 @@ class Metric(checkpointable.CheckpointableBase): class Mean(Metric): """Computes the (weighted) mean of the given values.""" - # TODO(josh11b): Maybe have a dtype argument that defaults to tf.float64? - # Or defaults to type of the input if it is tf.float32, else tf.float64? def __init__(self, name=None, dtype=dtypes.float64, use_global_variables=False): @@ -377,7 +375,7 @@ class Accuracy(Mean): array_ops.shape(labels), array_ops.shape(predictions), message="Shapes of labels and predictions are unequal") matches = math_ops.equal(labels, predictions) - matches = math_ops.cast(matches, dtypes.float64) + matches = math_ops.cast(matches, self.dtype) super(Accuracy, self).call(matches, weights=weights) if weights is None: return labels, predictions @@ -421,7 +419,7 @@ class CategoricalAccuracy(Mean): labels = math_ops.argmax(labels, axis=-1) predictions = math_ops.argmax(predictions, axis=-1) matches = math_ops.equal(labels, predictions) - matches = math_ops.cast(matches, dtypes.float64) + matches = math_ops.cast(matches, self.dtype) super(CategoricalAccuracy, self).call(matches, weights=weights) if weights is None: return labels, predictions @@ -472,7 +470,7 @@ class BinaryAccuracy(Mean): predictions = ops.convert_to_tensor(predictions) predictions = predictions > self.threshold matches = math_ops.equal(labels, predictions) - matches = math_ops.cast(matches, dtypes.float64) + matches = math_ops.cast(matches, self.dtype) super(BinaryAccuracy, self).call(matches, weights=weights) if weights is None: return labels, predictions @@ -520,7 +518,7 @@ class SparseAccuracy(Mean): predictions = math_ops.argmax(predictions, axis=-1) labels = math_ops.cast(labels, dtypes.int64) matches = math_ops.equal(labels, predictions) - matches = math_ops.cast(matches, dtypes.float64) + matches = math_ops.cast(matches, self.dtype) super(SparseAccuracy, self).call(matches, weights=weights) if weights is None: return labels, predictions diff --git a/tensorflow/contrib/eager/python/saver.py b/tensorflow/contrib/eager/python/saver.py index fdaca90fd13576e6ca8a3408aaf528dbc2384b0c..d70930864784b3e48140da27ca33ff13f593e663 100644 --- a/tensorflow/contrib/eager/python/saver.py +++ b/tensorflow/contrib/eager/python/saver.py @@ -125,8 +125,8 @@ class Saver(object): Args: var_list: The list of variables that will be saved and restored. Either a - list of `tfe.Variable` objects, or a dictionary mapping names to - `tfe.Variable` objects. + list of `tf.Variable` objects, or a dictionary mapping names to + `tf.Variable` objects. Raises: RuntimeError: if invoked when eager execution has not been enabled. diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index ca6430253b67d825290b6a376ba3f29b3ae67577..de11d00a1a0a34372467eedb02d790c920e7f449 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -34,6 +34,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@run @@enable_eager_execution +@@enable_remote_eager_execution @@custom_gradient @@ -70,6 +71,8 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@run_test_in_graph_and_eager_modes @@run_all_tests_in_graph_and_eager_modes +@@TensorSpec + @@DEVICE_PLACEMENT_EXPLICIT @@DEVICE_PLACEMENT_WARN @@DEVICE_PLACEMENT_SILENT @@ -113,7 +116,9 @@ from tensorflow.python.eager.execution_callbacks import inf_callback from tensorflow.python.eager.execution_callbacks import inf_nan_callback from tensorflow.python.eager.execution_callbacks import nan_callback from tensorflow.python.eager.execution_callbacks import seterr +from tensorflow.python.framework.tensor_spec import TensorSpec from tensorflow.python.framework.ops import enable_eager_execution +from tensorflow.python.framework.ops import enable_eager_execution_internal as enable_remote_eager_execution from tensorflow.python.framework.ops import eager_run as run from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes from tensorflow.python.framework.test_util import run_all_in_graph_and_eager_modes as run_all_tests_in_graph_and_eager_modes diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index db50b33af2e4f1cc6575d4b0d416d6d2669b5c35..4454abfb9667f824b9de0100bb81bae24ad5f7a6 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -27,7 +27,6 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import numerics -from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.summary import summary from tensorflow.python.summary.writer import writer @@ -45,12 +44,6 @@ class TFETest(test_util.TensorFlowTestCase): r'indices = 7 is not in \[0, 3\)'): array_ops.gather([0, 1, 2], 7) - def testVariableError(self): - with self.assertRaisesRegexp( - RuntimeError, - r'Variable not supported when eager execution is enabled'): - variables.Variable(initial_value=1.0) - def testGradients(self): def square(x): diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 30d297a5fb2dd2f844093d790d051a79105984dd..349f48f7f788b458af2639f7ad4cc4cd904465b4 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -18,6 +18,7 @@ py_library( ":boosted_trees", ":dnn", ":dnn_linear_combined", + ":early_stopping", ":export", ":extenders", ":head", @@ -27,7 +28,8 @@ py_library( ":multi_head", ":replicate_model_fn", ":rnn", - "//tensorflow/python:util", + ":saved_model_estimator", + "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -53,22 +55,10 @@ py_test( deps = [ ":baseline", ":head", - "//tensorflow/python:check_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:session", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variables", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:export_export", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:numpy_io", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], @@ -95,11 +85,8 @@ py_test( ], deps = [ ":boosted_trees", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:numpy_io", - "//tensorflow/python/feature_column", "//third_party/py/numpy", ], ) @@ -109,7 +96,7 @@ py_library( srcs = ["python/estimator/dnn.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:nn", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator", "//tensorflow/python/estimator:dnn", ], @@ -128,16 +115,11 @@ py_test( deps = [ ":dnn", ":head", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", - "//tensorflow/python:summary", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:dnn_testing_utils", "//tensorflow/python/estimator:export_export", "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:prediction_keys", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], @@ -148,7 +130,7 @@ py_library( srcs = ["python/estimator/dnn_linear_combined.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:nn", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator", "//tensorflow/python/estimator:dnn_linear_combined", ], @@ -167,18 +149,12 @@ py_test( deps = [ ":dnn_linear_combined", ":head", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:nn", - "//tensorflow/python:platform", - "//tensorflow/python:summary", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:dnn_testing_utils", "//tensorflow/python/estimator:export_export", "//tensorflow/python/estimator:linear_testing_utils", "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:prediction_keys", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], @@ -191,10 +167,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:clip_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:util", @@ -210,18 +183,11 @@ py_test( tags = ["notsan"], # b/62863147 deps = [ ":extenders", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/predictor", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:framework_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", - "//tensorflow/python:variables", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/estimator:linear", - "//tensorflow/python/feature_column", "//third_party/py/numpy", ], ) @@ -245,21 +211,11 @@ py_test( tags = ["notsan"], # b/62863147 deps = [ ":export", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:metrics", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:session", - "//tensorflow/python:state_ops", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variables", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator", "//tensorflow/python/estimator:export_export", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:model_fn", - "//tensorflow/python/saved_model:loader", - "//tensorflow/python/saved_model:tag_constants", ], ) @@ -270,25 +226,12 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:nn", - "//tensorflow/python:sparse_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:summary", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:prediction_keys", - "//tensorflow/python/ops/losses", - "//tensorflow/python/saved_model:signature_constants", ], ) @@ -299,25 +242,10 @@ py_test( srcs_version = "PY2AND3", deps = [ ":head", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:string_ops", - "//tensorflow/python:training", - "//tensorflow/python:variables", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:prediction_keys", - "//tensorflow/python/ops/losses", - "//tensorflow/python/saved_model:signature_constants", "//third_party/py/numpy", "@six_archive//:six", ], @@ -330,8 +258,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:estimator_py", ], ) @@ -344,10 +271,7 @@ py_test( tags = ["notsan"], deps = [ ":hooks", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", - "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:estimator_py", "//third_party/py/numpy", "@six_archive//:six", @@ -376,16 +300,11 @@ py_test( deps = [ ":head", ":linear", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", - "//tensorflow/python:summary", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:export_export", "//tensorflow/python/estimator:linear_testing_utils", "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:prediction_keys", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], @@ -398,8 +317,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:framework_ops", - "//tensorflow/python:util", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:dnn", "//tensorflow/python/estimator:linear", ], @@ -412,9 +330,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":logit_fns", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:session", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:model_fn", ], ) @@ -426,18 +342,11 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:summary", - "//tensorflow/python:training", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", - "//tensorflow/python/saved_model:signature_constants", "@six_archive//:six", ], ) @@ -450,15 +359,10 @@ py_test( deps = [ ":head", ":multi_head", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:framework_ops", - "//tensorflow/python:string_ops", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:prediction_keys", - "//tensorflow/python/saved_model:signature_constants", "//third_party/py/numpy", "@six_archive//:six", ], @@ -471,24 +375,10 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:device", - "//tensorflow/python:device_lib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:sparse_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:state_ops", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:util", - "//tensorflow/python/ops/losses", "@six_archive//:six", ], ) @@ -499,6 +389,7 @@ cuda_py_test( srcs = ["python/estimator/replicate_model_fn_test.py"], additional_deps = [ "@absl_py//absl/testing:parameterized", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator", "//tensorflow/python/estimator:dnn", "//tensorflow/python/estimator:export_export", @@ -507,21 +398,6 @@ cuda_py_test( "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:optimizers", "//tensorflow/python/estimator:prediction_keys", - "//tensorflow/python/feature_column", - "//tensorflow/python/ops/losses", - "//tensorflow/python/saved_model:signature_constants", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:platform", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", ":replicate_model_fn", ], tags = [ @@ -537,22 +413,11 @@ py_library( srcs_version = "PY2AND3", deps = [ ":extenders", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/contrib/feature_column:feature_column_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:partitioned_variables", - "//tensorflow/python:rnn", - "//tensorflow/python:rnn_cell", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", "//tensorflow/python/estimator", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:optimizers", - "//tensorflow/python/feature_column", "@six_archive//:six", ], ) @@ -571,22 +436,73 @@ py_test( deps = [ ":head", ":rnn", + "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/contrib/data", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:check_ops", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:parsing_utils", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_library( + name = "early_stopping", + srcs = ["python/estimator/early_stopping.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py_no_contrib", + "//tensorflow/python/estimator", + ], +) + +py_test( + name = "early_stopping_test", + srcs = ["python/estimator/early_stopping_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":early_stopping", + "//tensorflow:tensorflow_py_no_contrib", + "//tensorflow/python/estimator", + "@absl_py//absl/testing:parameterized", + ], +) + +py_library( + name = "saved_model_estimator", + srcs = ["python/estimator/saved_model_estimator.py"], + deps = [ + ":export", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:training", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:export", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/saved_model", + ], +) + +py_test( + name = "saved_model_estimator_test", + size = "medium", + srcs = ["python/estimator/saved_model_estimator_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":export", + ":saved_model_estimator", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", - "//tensorflow/python:lib", - "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:platform", "//tensorflow/python:state_ops", - "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:variables", - "//tensorflow/python/estimator:numpy_io", - "//tensorflow/python/estimator:parsing_utils", - "//tensorflow/python/feature_column", - "//third_party/py/numpy", - "@six_archive//:six", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:export_output", + "//tensorflow/python/estimator:model_fn", ], ) diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index 788ac5ca7046d6dd30a3d5520b243944532622fa..e1453ae1d04ebd8d72f812b51480f0b05f7a5416 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -23,6 +23,7 @@ from tensorflow.contrib.estimator.python.estimator.baseline import * from tensorflow.contrib.estimator.python.estimator.boosted_trees import * from tensorflow.contrib.estimator.python.estimator.dnn import * from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * +from tensorflow.contrib.estimator.python.estimator.early_stopping import * from tensorflow.contrib.estimator.python.estimator.export import * from tensorflow.contrib.estimator.python.estimator.extenders import * from tensorflow.contrib.estimator.python.estimator.head import * @@ -32,6 +33,8 @@ from tensorflow.contrib.estimator.python.estimator.logit_fns import * from tensorflow.contrib.estimator.python.estimator.multi_head import * from tensorflow.contrib.estimator.python.estimator.replicate_model_fn import * from tensorflow.contrib.estimator.python.estimator.rnn import * +from tensorflow.contrib.estimator.python.estimator.saved_model_estimator import * +from tensorflow.python.estimator.export.export import * from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import @@ -63,6 +66,15 @@ _allowed_symbols = [ 'RNNEstimator', 'export_saved_model_for_mode', 'export_all_saved_models', + 'make_early_stopping_hook', + 'read_eval_metrics', + 'stop_if_lower_hook', + 'stop_if_higher_hook', + 'stop_if_no_increase_hook', + 'stop_if_no_decrease_hook', + 'build_raw_supervised_input_receiver_fn', + 'build_supervised_input_receiver_fn_from_input_fn', + 'SavedModelEstimator' ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/baseline_test.py b/tensorflow/contrib/estimator/python/estimator/baseline_test.py index d0e3e670f7332811c1bfdaea65b0308ce59ade59..505c94e97192afdd4e2ce9af2abb9825320751f2 100644 --- a/tensorflow/contrib/estimator/python/estimator/baseline_test.py +++ b/tensorflow/contrib/estimator/python/estimator/baseline_test.py @@ -113,6 +113,8 @@ class BaselineEstimatorEvaluationTest(test.TestCase): self.assertDictEqual({ metric_keys.MetricKeys.LOSS: 18., metric_keys.MetricKeys.LOSS_MEAN: 9., + metric_keys.MetricKeys.PREDICTION_MEAN: 13., + metric_keys.MetricKeys.LABEL_MEAN: 10., ops.GraphKeys.GLOBAL_STEP: 100 }, eval_metrics) @@ -141,6 +143,8 @@ class BaselineEstimatorEvaluationTest(test.TestCase): self.assertDictEqual({ metric_keys.MetricKeys.LOSS: 27., metric_keys.MetricKeys.LOSS_MEAN: 9., + metric_keys.MetricKeys.PREDICTION_MEAN: 13., + metric_keys.MetricKeys.LABEL_MEAN: 10., ops.GraphKeys.GLOBAL_STEP: 100 }, eval_metrics) @@ -166,7 +170,9 @@ class BaselineEstimatorEvaluationTest(test.TestCase): self.assertItemsEqual( (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN, - ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys()) + metric_keys.MetricKeys.PREDICTION_MEAN, + metric_keys.MetricKeys.LABEL_MEAN, ops.GraphKeys.GLOBAL_STEP), + eval_metrics.keys()) # Logit is bias which is [46, 58] self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS]) diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py index bd641014e9eec6623d66574bccd08ff03ebc28ac..7ed77bcce6f00ed13e9952951800f1017d582f19 100644 --- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py @@ -49,7 +49,9 @@ class _BoostedTreesEstimator(estimator.Estimator): l2_regularization=0., tree_complexity=0., min_node_weight=0., - config=None): + config=None, + center_bias=False, + pruning_mode='none'): """Initializes a `BoostedTreesEstimator` instance. Args: @@ -82,17 +84,35 @@ class _BoostedTreesEstimator(estimator.Estimator): considered. The value will be compared with sum(leaf_hessian)/ (batch_size * n_batches_per_layer). config: `RunConfig` object to configure the runtime settings. + center_bias: Whether bias centering needs to occur. Bias centering refers + to the first node in the very first tree returning the prediction that + is aligned with the original labels distribution. For example, for + regression problems, the first node will return the mean of the labels. + For binary classification problems, it will return a logit for a prior + probability of label 1. + pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre- + pruning (do not split a node if not enough gain is observed) and post + pruning (build the tree up to a max depth and then prune branches with + negative gain). For pre and post pruning, you MUST provide + tree_complexity >0. + """ # pylint:disable=protected-access # HParams for the model. tree_hparams = canned_boosted_trees._TreeHParams( n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity, min_node_weight) + tree_complexity, min_node_weight, center_bias, pruning_mode) def _model_fn(features, labels, mode, config): return canned_boosted_trees._bt_model_fn( - features, labels, mode, head, feature_columns, tree_hparams, - n_batches_per_layer, config) + features, + labels, + mode, + head, + feature_columns, + tree_hparams, + n_batches_per_layer, + config=config) super(_BoostedTreesEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) @@ -114,7 +134,9 @@ def boosted_trees_classifier_train_in_memory( tree_complexity=0., min_node_weight=0., config=None, - train_hooks=None): + train_hooks=None, + center_bias=False, + pruning_mode='none'): """Trains a boosted tree classifier with in memory dataset. Example: @@ -186,7 +208,18 @@ def boosted_trees_classifier_train_in_memory( considered. The value will be compared with sum(leaf_hessian)/ (batch_size * n_batches_per_layer). config: `RunConfig` object to configure the runtime settings. - train_hooks: a list of Hook instances to be passed to estimator.train(). + train_hooks: a list of Hook instances to be passed to estimator.train() + center_bias: Whether bias centering needs to occur. Bias centering refers + to the first node in the very first tree returning the prediction that + is aligned with the original labels distribution. For example, for + regression problems, the first node will return the mean of the labels. + For binary classification problems, it will return a logit for a prior + probability of label 1. + pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre- + pruning (do not split a node if not enough gain is observed) and post + pruning (build the tree up to a max depth and then prune branches with + negative gain). For pre and post pruning, you MUST provide + tree_complexity >0. Returns: a `BoostedTreesClassifier` instance created with the given arguments and @@ -207,7 +240,7 @@ def boosted_trees_classifier_train_in_memory( # HParams for the model. tree_hparams = canned_boosted_trees._TreeHParams( n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity, min_node_weight) + tree_complexity, min_node_weight, center_bias, pruning_mode) def _model_fn(features, labels, mode, config): return canned_boosted_trees._bt_model_fn( @@ -247,7 +280,9 @@ def boosted_trees_regressor_train_in_memory( tree_complexity=0., min_node_weight=0., config=None, - train_hooks=None): + train_hooks=None, + center_bias=False, + pruning_mode='none'): """Trains a boosted tree regressor with in memory dataset. Example: @@ -313,6 +348,17 @@ def boosted_trees_regressor_train_in_memory( (batch_size * n_batches_per_layer). config: `RunConfig` object to configure the runtime settings. train_hooks: a list of Hook instances to be passed to estimator.train(). + center_bias: Whether bias centering needs to occur. Bias centering refers + to the first node in the very first tree returning the prediction that + is aligned with the original labels distribution. For example, for + regression problems, the first node will return the mean of the labels. + For binary classification problems, it will return a logit for a prior + probability of label 1. + pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre- + pruning (do not split a node if not enough gain is observed) and post + pruning (build the tree up to a max depth and then prune branches with + negative gain). For pre and post pruning, you MUST provide + tree_complexity >0. Returns: a `BoostedTreesClassifier` instance created with the given arguments and @@ -332,7 +378,7 @@ def boosted_trees_regressor_train_in_memory( # HParams for the model. tree_hparams = canned_boosted_trees._TreeHParams( n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity, min_node_weight) + tree_complexity, min_node_weight, center_bias, pruning_mode) def _model_fn(features, labels, mode, config): return canned_boosted_trees._bt_model_fn( diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py index 76cbefe5e94502188388df6fc2816d130ac896d5..b1581f37509b5dc2bec98942e88c024905f25d93 100644 --- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py @@ -115,6 +115,70 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): eval_res = est.evaluate(input_fn=input_fn, steps=1) self.assertAllClose(eval_res['average_loss'], 1.008551) + def testTrainAndEvaluateEstimatorWithCenterBias(self): + input_fn = _make_train_input_fn(is_classification=False) + + est = boosted_trees._BoostedTreesEstimator( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=2, + head=self._head, + max_depth=5, + center_bias=True) + + # It will stop after 11 steps because of the max depth and num trees. + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(input_fn, steps=num_steps) + # 10 steps for training and 2 step for bias centering. + self._assert_checkpoint( + est.model_dir, global_step=12, finalized_trees=2, attempted_layers=10) + eval_res = est.evaluate(input_fn=input_fn, steps=1) + self.assertAllClose(eval_res['average_loss'], 0.614642) + + def testTrainAndEvaluateEstimatorWithPrePruning(self): + input_fn = _make_train_input_fn(is_classification=False) + + est = boosted_trees._BoostedTreesEstimator( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=2, + head=self._head, + max_depth=5, + tree_complexity=0.001, + pruning_mode='pre') + + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(input_fn, steps=num_steps) + # We stop actually after 2*depth*n_trees steps (via a hook) because we still + # could not grow 2 trees of depth 5 (due to pre-pruning). + self._assert_checkpoint( + est.model_dir, global_step=21, finalized_trees=0, attempted_layers=21) + eval_res = est.evaluate(input_fn=input_fn, steps=1) + self.assertAllClose(eval_res['average_loss'], 3.83943) + + def testTrainAndEvaluateEstimatorWithPostPruning(self): + input_fn = _make_train_input_fn(is_classification=False) + + est = boosted_trees._BoostedTreesEstimator( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=2, + head=self._head, + max_depth=5, + tree_complexity=0.001, + pruning_mode='post') + + # It will stop after 10 steps because of the max depth and num trees. + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(input_fn, steps=num_steps) + self._assert_checkpoint( + est.model_dir, global_step=10, finalized_trees=2, attempted_layers=10) + eval_res = est.evaluate(input_fn=input_fn, steps=1) + self.assertAllClose(eval_res['average_loss'], 2.37652) + def testInferEstimator(self): train_input_fn = _make_train_input_fn(is_classification=False) predict_input_fn = numpy_io.numpy_input_fn( @@ -139,6 +203,33 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]], [pred['predictions'] for pred in predictions]) + def testInferEstimatorWithCenterBias(self): + train_input_fn = _make_train_input_fn(is_classification=False) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees._BoostedTreesEstimator( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=1, + max_depth=5, + center_bias=True, + head=self._head) + + # It will stop after 6 steps because of the max depth and num trees (5 for + # training and 2 for bias centering). + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(train_input_fn, steps=num_steps) + self._assert_checkpoint( + est.model_dir, global_step=7, finalized_trees=1, attempted_layers=5) + # Validate predictions. + predictions = list(est.predict(input_fn=predict_input_fn)) + + self.assertAllClose( + [[1.634501], [1.325703], [1.187431], [2.019683], [2.832683]], + [pred['predictions'] for pred in predictions]) + def testBinaryClassifierTrainInMemoryAndEvalAndInfer(self): train_input_fn = _make_train_input_fn(is_classification=True) predict_input_fn = numpy_io.numpy_input_fn( @@ -159,14 +250,65 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): self.assertAllClose([[0], [1], [1], [0], [0]], [pred['class_ids'] for pred in predictions]) + def testBinaryClassifierTrainInMemoryAndEvalAndInferWithCenterBias(self): + train_input_fn = _make_train_input_fn(is_classification=True) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees.boosted_trees_classifier_train_in_memory( + train_input_fn=train_input_fn, + feature_columns=self._feature_columns, + n_trees=1, + max_depth=5, + center_bias=True) + # It will stop after 5 steps + 3 for bias, because of the max depth and num + # trees. + self._assert_checkpoint( + est.model_dir, global_step=8, finalized_trees=1, attempted_layers=5) + + # Check evaluate and predict. + eval_res = est.evaluate(input_fn=train_input_fn, steps=1) + self.assertAllClose(eval_res['accuracy'], 1.0) + # Validate predictions. + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose([[0], [1], [1], [0], [0]], + [pred['class_ids'] for pred in predictions]) + + def testBinaryClassifierTrainInMemoryAndEvalAndInferWithPrePruning(self): + train_input_fn = _make_train_input_fn(is_classification=True) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees.boosted_trees_classifier_train_in_memory( + train_input_fn=train_input_fn, + feature_columns=self._feature_columns, + n_trees=1, + max_depth=5, + pruning_mode='pre', + tree_complexity=0.01) + # We stop actually after 2*depth*n_trees steps (via a hook) because we still + # could not grow 1 trees of depth 5 (due to pre-pruning). + self._assert_checkpoint( + est.model_dir, global_step=11, finalized_trees=0, attempted_layers=11) + + # Check evaluate and predict. + eval_res = est.evaluate(input_fn=train_input_fn, steps=1) + self.assertAllClose(eval_res['accuracy'], 1.0) + # Validate predictions. + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose([[0], [1], [1], [0], [0]], + [pred['class_ids'] for pred in predictions]) + def testBinaryClassifierTrainInMemoryWithDataset(self): train_input_fn = _make_train_input_fn_dataset(is_classification=True) predict_input_fn = numpy_io.numpy_input_fn( x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) est = boosted_trees.boosted_trees_classifier_train_in_memory( - train_input_fn=train_input_fn, feature_columns=self._feature_columns, - n_trees=1, max_depth=5) + train_input_fn=train_input_fn, + feature_columns=self._feature_columns, + n_trees=1, + max_depth=5) # It will stop after 5 steps because of the max depth and num trees. self._assert_checkpoint( est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/contrib/estimator/python/estimator/dnn.py index 4bb90cf81bc32723e24a220e45c43c1f9b3f1980..9efa8f474d865a36788cba40a15404bf0b30a17e 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn.py @@ -112,7 +112,8 @@ class DNNEstimator(estimator.Estimator): dropout=None, input_layer_partitioner=None, config=None, - warm_start_from=None): + warm_start_from=None, + batch_norm=False): """Initializes a `DNNEstimator` instance. Args: @@ -142,6 +143,7 @@ class DNNEstimator(estimator.Estimator): string filepath is provided instead of a `WarmStartSettings`, then all weights are warm-started, and it is assumed that vocabularies and Tensor names are unchanged. + batch_norm: Whether to use batch normalization after each hidden layer. """ def _model_fn(features, labels, mode, config): return dnn_lib._dnn_model_fn( # pylint: disable=protected-access @@ -155,7 +157,8 @@ class DNNEstimator(estimator.Estimator): activation_fn=activation_fn, dropout=dropout, input_layer_partitioner=input_layer_partitioner, - config=config) + config=config, + batch_norm=batch_norm) super(DNNEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config, warm_start_from=warm_start_from) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py index 894a2954987a4af760d3c08fc6f30405010150c5..2eef60c39f54bfb464b7da0eb57a47e9eee9b800 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py @@ -110,7 +110,8 @@ class DNNLinearCombinedEstimator(estimator.Estimator): dnn_activation_fn=nn.relu, dnn_dropout=None, input_layer_partitioner=None, - config=None): + config=None, + linear_sparse_combiner='sum'): """Initializes a DNNLinearCombinedEstimator instance. Args: @@ -142,6 +143,11 @@ class DNNLinearCombinedEstimator(estimator.Estimator): input_layer_partitioner: Partitioner for input layer. Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. config: RunConfig object to configure the runtime settings. + linear_sparse_combiner: A string specifying how to reduce the linear model + if a categorical column is multivalent. One of "mean", "sqrtn", and + "sum" -- these are effectively different ways to do example-level + normalization, which can be useful for bag-of-words features. For more + details, see @{tf.feature_column.linear_model$linear_model}. Raises: ValueError: If both linear_feature_columns and dnn_features_columns are @@ -169,7 +175,8 @@ class DNNLinearCombinedEstimator(estimator.Estimator): dnn_activation_fn=dnn_activation_fn, dnn_dropout=dnn_dropout, input_layer_partitioner=input_layer_partitioner, - config=config) + config=config, + linear_sparse_combiner=linear_sparse_combiner) super(DNNLinearCombinedEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py index dd009a6753f3231638f93e50fc8f19eae8820139..51b9ce7005cec3910ba73db62a674e4628ca30a2 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py @@ -100,7 +100,8 @@ def _linear_only_estimator_fn( weight_column=None, optimizer='Ftrl', config=None, - partitioner=None): + partitioner=None, + sparse_combiner='sum'): return dnn_linear_combined.DNNLinearCombinedEstimator( head=head_lib.regression_head( weight_column=weight_column, label_dimension=label_dimension, @@ -110,7 +111,8 @@ def _linear_only_estimator_fn( linear_feature_columns=feature_columns, linear_optimizer=optimizer, input_layer_partitioner=partitioner, - config=config) + config=config, + linear_sparse_combiner=sparse_combiner) class LinearOnlyEstimatorEvaluateTest( diff --git a/tensorflow/contrib/estimator/python/estimator/early_stopping.py b/tensorflow/contrib/estimator/python/estimator/early_stopping.py new file mode 100644 index 0000000000000000000000000000000000000000..3eab21d5acaf26f14a73e7fa8e9c50fffc22fe9c --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/early_stopping.py @@ -0,0 +1,469 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for early stopping.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import operator +import os + +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import gfile +from tensorflow.python.platform import tf_logging +from tensorflow.python.summary import summary_iterator +from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import session_run_hook +from tensorflow.python.training import training_util + +_EVENT_FILE_GLOB_PATTERN = 'events.out.tfevents.*' + + +def make_early_stopping_hook(estimator, + should_stop_fn, + run_every_secs=60, + run_every_steps=None): + """Creates early-stopping hook. + + Returns a `SessionRunHook` that stops training when `should_stop_fn` returns + `True`. + + Usage example: + + ```python + estimator = ... + hook = early_stopping.make_early_stopping_hook( + estimator, should_stop_fn=make_stop_fn(...)) + train_spec = tf.estimator.TrainSpec(..., hooks=[hook]) + tf.estimator.train_and_evaluate(estimator, train_spec, ...) + ``` + + Args: + estimator: A `tf.estimator.Estimator` instance. + should_stop_fn: `callable`, function that takes no arguments and returns a + `bool`. If the function returns `True`, stopping will be initiated by the + chief. + run_every_secs: If specified, calls `should_stop_fn` at an interval of + `run_every_secs` seconds. Defaults to 60 seconds. Either this or + `run_every_steps` must be set. + run_every_steps: If specified, calls `should_stop_fn` every + `run_every_steps` steps. Either this or `run_every_secs` must be set. + + Returns: + A `SessionRunHook` that periodically executes `should_stop_fn` and initiates + early stopping if the function returns `True`. + + Raises: + TypeError: If `estimator` is not of type `tf.estimator.Estimator`. + ValueError: If both `run_every_secs` and `run_every_steps` are set. + """ + if not isinstance(estimator, estimator_lib.Estimator): + raise TypeError('`estimator` must have type `tf.estimator.Estimator`. ' + 'Got: {}'.format(type(estimator))) + + if run_every_secs is not None and run_every_steps is not None: + raise ValueError('Only one of `run_every_secs` and `run_every_steps` must ' + 'be set.') + + if estimator.config.is_chief: + return _StopOnPredicateHook(should_stop_fn, run_every_secs, run_every_steps) + else: + return _CheckForStoppingHook() + + +def stop_if_higher_hook(estimator, + metric_name, + threshold, + eval_dir=None, + min_steps=0, + run_every_secs=60, + run_every_steps=None): + """Creates hook to stop if the given metric is higher than the threshold. + + Usage example: + + ```python + estimator = ... + # Hook to stop training if accuracy becomes higher than 0.9. + hook = early_stopping.stop_if_higher_hook(estimator, "accuracy", 0.9) + train_spec = tf.estimator.TrainSpec(..., hooks=[hook]) + tf.estimator.train_and_evaluate(estimator, train_spec, ...) + ``` + + Args: + estimator: A `tf.estimator.Estimator` instance. + metric_name: `str`, metric to track. "loss", "accuracy", etc. + threshold: Numeric threshold for the given metric. + eval_dir: If set, directory containing summary files with eval metrics. By + default, `estimator.eval_dir()` will be used. + min_steps: `int`, stop is never requested if global step is less than this + value. Defaults to 0. + run_every_secs: If specified, calls `should_stop_fn` at an interval of + `run_every_secs` seconds. Defaults to 60 seconds. Either this or + `run_every_steps` must be set. + run_every_steps: If specified, calls `should_stop_fn` every + `run_every_steps` steps. Either this or `run_every_secs` must be set. + + Returns: + An early-stopping hook of type `SessionRunHook` that periodically checks + if the given metric is higher than specified threshold and initiates + early stopping if true. + """ + return _stop_if_threshold_crossed_hook( + estimator=estimator, + metric_name=metric_name, + threshold=threshold, + higher_is_better=True, + eval_dir=eval_dir, + min_steps=min_steps, + run_every_secs=run_every_secs, + run_every_steps=run_every_steps) + + +def stop_if_lower_hook(estimator, + metric_name, + threshold, + eval_dir=None, + min_steps=0, + run_every_secs=60, + run_every_steps=None): + """Creates hook to stop if the given metric is lower than the threshold. + + Usage example: + + ```python + estimator = ... + # Hook to stop training if loss becomes lower than 100. + hook = early_stopping.stop_if_lower_hook(estimator, "loss", 100) + train_spec = tf.estimator.TrainSpec(..., hooks=[hook]) + tf.estimator.train_and_evaluate(estimator, train_spec, ...) + ``` + + Args: + estimator: A `tf.estimator.Estimator` instance. + metric_name: `str`, metric to track. "loss", "accuracy", etc. + threshold: Numeric threshold for the given metric. + eval_dir: If set, directory containing summary files with eval metrics. By + default, `estimator.eval_dir()` will be used. + min_steps: `int`, stop is never requested if global step is less than this + value. Defaults to 0. + run_every_secs: If specified, calls `should_stop_fn` at an interval of + `run_every_secs` seconds. Defaults to 60 seconds. Either this or + `run_every_steps` must be set. + run_every_steps: If specified, calls `should_stop_fn` every + `run_every_steps` steps. Either this or `run_every_secs` must be set. + + Returns: + An early-stopping hook of type `SessionRunHook` that periodically checks + if the given metric is lower than specified threshold and initiates + early stopping if true. + """ + return _stop_if_threshold_crossed_hook( + estimator=estimator, + metric_name=metric_name, + threshold=threshold, + higher_is_better=False, + eval_dir=eval_dir, + min_steps=min_steps, + run_every_secs=run_every_secs, + run_every_steps=run_every_steps) + + +def stop_if_no_increase_hook(estimator, + metric_name, + max_steps_without_increase, + eval_dir=None, + min_steps=0, + run_every_secs=60, + run_every_steps=None): + """Creates hook to stop if metric does not increase within given max steps. + + Usage example: + + ```python + estimator = ... + # Hook to stop training if accuracy does not increase in over 100000 steps. + hook = early_stopping.stop_if_no_increase_hook(estimator, "accuracy", 100000) + train_spec = tf.estimator.TrainSpec(..., hooks=[hook]) + tf.estimator.train_and_evaluate(estimator, train_spec, ...) + ``` + + Args: + estimator: A `tf.estimator.Estimator` instance. + metric_name: `str`, metric to track. "loss", "accuracy", etc. + max_steps_without_increase: `int`, maximum number of training steps with no + increase in the given metric. + eval_dir: If set, directory containing summary files with eval metrics. By + default, `estimator.eval_dir()` will be used. + min_steps: `int`, stop is never requested if global step is less than this + value. Defaults to 0. + run_every_secs: If specified, calls `should_stop_fn` at an interval of + `run_every_secs` seconds. Defaults to 60 seconds. Either this or + `run_every_steps` must be set. + run_every_steps: If specified, calls `should_stop_fn` every + `run_every_steps` steps. Either this or `run_every_secs` must be set. + + Returns: + An early-stopping hook of type `SessionRunHook` that periodically checks + if the given metric shows no increase over given maximum number of + training steps, and initiates early stopping if true. + """ + return _stop_if_no_metric_improvement_hook( + estimator=estimator, + metric_name=metric_name, + max_steps_without_improvement=max_steps_without_increase, + higher_is_better=True, + eval_dir=eval_dir, + min_steps=min_steps, + run_every_secs=run_every_secs, + run_every_steps=run_every_steps) + + +def stop_if_no_decrease_hook(estimator, + metric_name, + max_steps_without_decrease, + eval_dir=None, + min_steps=0, + run_every_secs=60, + run_every_steps=None): + """Creates hook to stop if metric does not decrease within given max steps. + + Usage example: + + ```python + estimator = ... + # Hook to stop training if loss does not decrease in over 100000 steps. + hook = early_stopping.stop_if_no_decrease_hook(estimator, "loss", 100000) + train_spec = tf.estimator.TrainSpec(..., hooks=[hook]) + tf.estimator.train_and_evaluate(estimator, train_spec, ...) + ``` + + Args: + estimator: A `tf.estimator.Estimator` instance. + metric_name: `str`, metric to track. "loss", "accuracy", etc. + max_steps_without_decrease: `int`, maximum number of training steps with no + decrease in the given metric. + eval_dir: If set, directory containing summary files with eval metrics. By + default, `estimator.eval_dir()` will be used. + min_steps: `int`, stop is never requested if global step is less than this + value. Defaults to 0. + run_every_secs: If specified, calls `should_stop_fn` at an interval of + `run_every_secs` seconds. Defaults to 60 seconds. Either this or + `run_every_steps` must be set. + run_every_steps: If specified, calls `should_stop_fn` every + `run_every_steps` steps. Either this or `run_every_secs` must be set. + + Returns: + An early-stopping hook of type `SessionRunHook` that periodically checks + if the given metric shows no decrease over given maximum number of + training steps, and initiates early stopping if true. + """ + return _stop_if_no_metric_improvement_hook( + estimator=estimator, + metric_name=metric_name, + max_steps_without_improvement=max_steps_without_decrease, + higher_is_better=False, + eval_dir=eval_dir, + min_steps=min_steps, + run_every_secs=run_every_secs, + run_every_steps=run_every_steps) + + +def read_eval_metrics(eval_dir): + """Helper to read eval metrics from eval summary files. + + Args: + eval_dir: Directory containing summary files with eval metrics. + + Returns: + A `dict` with global steps mapping to `dict` of metric names and values. + """ + eval_metrics_dict = {} + for event in _summaries(eval_dir): + if not event.HasField('summary'): + continue + metrics = {} + for value in event.summary.value: + if value.HasField('simple_value'): + metrics[value.tag] = value.simple_value + if metrics: + eval_metrics_dict[event.step] = metrics + return eval_metrics_dict + + +def _stop_if_threshold_crossed_hook(estimator, metric_name, threshold, + higher_is_better, eval_dir, min_steps, + run_every_secs, run_every_steps): + """Creates early-stopping hook to stop training if threshold is crossed.""" + + if eval_dir is None: + eval_dir = estimator.eval_dir() + + is_lhs_better = operator.gt if higher_is_better else operator.lt + greater_or_lesser = 'greater than' if higher_is_better else 'less than' + + def stop_if_threshold_crossed_fn(): + """Returns `True` if the given metric crosses specified threshold.""" + + eval_results = read_eval_metrics(eval_dir) + + for step, metrics in eval_results.items(): + if step < min_steps: + continue + val = metrics[metric_name] + if is_lhs_better(val, threshold): + tf_logging.info( + 'At step %s, metric "%s" has value %s which is %s the configured ' + 'threshold (%s) for early stopping.', step, metric_name, val, + greater_or_lesser, threshold) + return True + return False + + return make_early_stopping_hook( + estimator=estimator, + should_stop_fn=stop_if_threshold_crossed_fn, + run_every_secs=run_every_secs, + run_every_steps=run_every_steps) + + +def _stop_if_no_metric_improvement_hook( + estimator, metric_name, max_steps_without_improvement, higher_is_better, + eval_dir, min_steps, run_every_secs, run_every_steps): + """Returns hook to stop training if given metric shows no improvement.""" + + if eval_dir is None: + eval_dir = estimator.eval_dir() + + is_lhs_better = operator.gt if higher_is_better else operator.lt + increase_or_decrease = 'increase' if higher_is_better else 'decrease' + + def stop_if_no_metric_improvement_fn(): + """Returns `True` if metric does not improve within max steps.""" + + eval_results = read_eval_metrics(eval_dir) + + best_val = None + best_val_step = None + for step, metrics in eval_results.items(): + if step < min_steps: + continue + val = metrics[metric_name] + if best_val is None or is_lhs_better(val, best_val): + best_val = val + best_val_step = step + if step - best_val_step >= max_steps_without_improvement: + tf_logging.info( + 'No %s in metric "%s" for %s steps, which is greater than or equal ' + 'to max steps (%s) configured for early stopping.', + increase_or_decrease, metric_name, step - best_val_step, + max_steps_without_improvement) + return True + return False + + return make_early_stopping_hook( + estimator=estimator, + should_stop_fn=stop_if_no_metric_improvement_fn, + run_every_secs=run_every_secs, + run_every_steps=run_every_steps) + + +def _summaries(eval_dir): + """Yields `tensorflow.Event` protos from event files in the eval dir. + + Args: + eval_dir: Directory containing summary files with eval metrics. + + Yields: + `tensorflow.Event` object read from the event files. + """ + if gfile.Exists(eval_dir): + for event_file in gfile.Glob( + os.path.join(eval_dir, _EVENT_FILE_GLOB_PATTERN)): + for event in summary_iterator.summary_iterator(event_file): + yield event + + +def _get_or_create_stop_var(): + with variable_scope.variable_scope( + name_or_scope='signal_early_stopping', + values=[], + reuse=variable_scope.AUTO_REUSE): + return variable_scope.get_variable( + name='STOP', + shape=[], + dtype=dtypes.bool, + initializer=init_ops.constant_initializer(False), + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + trainable=False) + + +class _StopOnPredicateHook(session_run_hook.SessionRunHook): + """Hook that requests stop when `should_stop_fn` returns `True`.""" + + def __init__(self, should_stop_fn, run_every_secs=60, run_every_steps=None): + if not callable(should_stop_fn): + raise TypeError('`should_stop_fn` must be callable.') + + self._should_stop_fn = should_stop_fn + self._timer = basic_session_run_hooks.SecondOrStepTimer( + every_secs=run_every_secs, every_steps=run_every_steps) + self._global_step_tensor = None + self._stop_var = None + self._stop_op = None + + def begin(self): + self._global_step_tensor = training_util.get_global_step() + self._stop_var = _get_or_create_stop_var() + self._stop_op = state_ops.assign(self._stop_var, True) + + def before_run(self, run_context): + del run_context + return session_run_hook.SessionRunArgs(self._global_step_tensor) + + def after_run(self, run_context, run_values): + global_step = run_values.results + if self._timer.should_trigger_for_step(global_step): + self._timer.update_last_triggered_step(global_step) + if self._should_stop_fn(): + tf_logging.info('Requesting early stopping at global step %d', + global_step) + run_context.session.run(self._stop_op) + run_context.request_stop() + + +class _CheckForStoppingHook(session_run_hook.SessionRunHook): + """Hook that requests stop if stop is requested by `_StopOnPredicateHook`.""" + + def __init__(self): + self._stop_var = None + + def begin(self): + self._stop_var = _get_or_create_stop_var() + + def before_run(self, run_context): + del run_context + return session_run_hook.SessionRunArgs(self._stop_var) + + def after_run(self, run_context, run_values): + should_early_stop = run_values.results + if should_early_stop: + tf_logging.info('Early stopping requested, suspending run.') + run_context.request_stop() diff --git a/tensorflow/contrib/estimator/python/estimator/early_stopping_test.py b/tensorflow/contrib/estimator/python/estimator/early_stopping_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e4bfd4b446b9413bd1627ef6904ff2dc9f1a9120 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/early_stopping_test.py @@ -0,0 +1,246 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for early_stopping.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile + +from absl.testing import parameterized +from tensorflow.contrib.estimator.python.estimator import early_stopping +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import run_config +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.platform import test +from tensorflow.python.training import monitored_session +from tensorflow.python.training import training_util + + +class _FakeRunConfig(run_config.RunConfig): + + def __init__(self, is_chief): + super(_FakeRunConfig, self).__init__() + self._is_chief = is_chief + + @property + def is_chief(self): + return self._is_chief + + +def _dummy_model_fn(features, labels, params): + _, _, _ = features, labels, params + + +class _FakeEstimator(estimator.Estimator): + """Fake estimator for testing.""" + + def __init__(self, config): + super(_FakeEstimator, self).__init__( + model_fn=_dummy_model_fn, config=config) + + +def _write_events(eval_dir, params): + """Test helper to write events to summary files.""" + for steps, loss, accuracy in params: + estimator._write_dict_to_summary(eval_dir, { + 'loss': loss, + 'accuracy': accuracy, + }, steps) + + +class ReadEvalMetricsTest(test.TestCase): + + def test_read_eval_metrics(self): + eval_dir = tempfile.mkdtemp() + _write_events( + eval_dir, + [ + # steps, loss, accuracy + (1000, 1, 2), + (2000, 3, 4), + (3000, 5, 6), + ]) + self.assertEqual({ + 1000: { + 'loss': 1, + 'accuracy': 2 + }, + 2000: { + 'loss': 3, + 'accuracy': 4 + }, + 3000: { + 'loss': 5, + 'accuracy': 6 + }, + }, early_stopping.read_eval_metrics(eval_dir)) + + def test_read_eval_metrics_when_no_events(self): + eval_dir = tempfile.mkdtemp() + self.assertTrue(os.path.exists(eval_dir)) + + # No error should be raised when eval directory exists with no event files. + self.assertEqual({}, early_stopping.read_eval_metrics(eval_dir)) + + os.rmdir(eval_dir) + self.assertFalse(os.path.exists(eval_dir)) + + # No error should be raised when eval directory does not exist. + self.assertEqual({}, early_stopping.read_eval_metrics(eval_dir)) + + +class EarlyStoppingHooksTest(test.TestCase, parameterized.TestCase): + + def setUp(self): + config = _FakeRunConfig(is_chief=True) + self._estimator = _FakeEstimator(config=config) + eval_dir = self._estimator.eval_dir() + os.makedirs(eval_dir) + _write_events( + eval_dir, + [ + # steps, loss, accuracy + (1000, 0.8, 0.5), + (2000, 0.7, 0.6), + (3000, 0.4, 0.7), + (3500, 0.41, 0.68), + ]) + + def run_session(self, hooks, should_stop): + hooks = hooks if isinstance(hooks, list) else [hooks] + with ops.Graph().as_default(): + training_util.create_global_step() + no_op = control_flow_ops.no_op() + with monitored_session.SingularMonitoredSession(hooks=hooks) as mon_sess: + mon_sess.run(no_op) + self.assertEqual(mon_sess.should_stop(), should_stop) + + @parameterized.parameters((0.8, 0, False), (0.6, 4000, False), (0.6, 0, True)) + def test_stop_if_higher_hook(self, threshold, min_steps, should_stop): + self.run_session( + early_stopping.stop_if_higher_hook( + self._estimator, + metric_name='accuracy', + threshold=threshold, + min_steps=min_steps), should_stop) + + @parameterized.parameters((0.3, 0, False), (0.5, 4000, False), (0.5, 0, True)) + def test_stop_if_lower_hook(self, threshold, min_steps, should_stop): + self.run_session( + early_stopping.stop_if_lower_hook( + self._estimator, + metric_name='loss', + threshold=threshold, + min_steps=min_steps), should_stop) + + @parameterized.parameters((1500, 0, False), (500, 4000, False), + (500, 0, True)) + def test_stop_if_no_increase_hook(self, max_steps, min_steps, should_stop): + self.run_session( + early_stopping.stop_if_no_increase_hook( + self._estimator, + metric_name='accuracy', + max_steps_without_increase=max_steps, + min_steps=min_steps), should_stop) + + @parameterized.parameters((1500, 0, False), (500, 4000, False), + (500, 0, True)) + def test_stop_if_no_decrease_hook(self, max_steps, min_steps, should_stop): + self.run_session( + early_stopping.stop_if_no_decrease_hook( + self._estimator, + metric_name='loss', + max_steps_without_decrease=max_steps, + min_steps=min_steps), should_stop) + + @parameterized.parameters((1500, 0.3, False), (1500, 0.5, True), + (500, 0.3, True)) + def test_multiple_hooks(self, max_steps, loss_threshold, should_stop): + self.run_session([ + early_stopping.stop_if_no_decrease_hook( + self._estimator, + metric_name='loss', + max_steps_without_decrease=max_steps), + early_stopping.stop_if_lower_hook( + self._estimator, metric_name='loss', threshold=loss_threshold) + ], should_stop) + + @parameterized.parameters(False, True) + def test_make_early_stopping_hook(self, should_stop): + self.run_session([ + early_stopping.make_early_stopping_hook( + self._estimator, should_stop_fn=lambda: should_stop) + ], should_stop) + + def test_make_early_stopping_hook_typeerror(self): + with self.assertRaises(TypeError): + early_stopping.make_early_stopping_hook( + estimator=object(), should_stop_fn=lambda: True) + + def test_make_early_stopping_hook_valueerror(self): + with self.assertRaises(ValueError): + early_stopping.make_early_stopping_hook( + self._estimator, + should_stop_fn=lambda: True, + run_every_secs=60, + run_every_steps=100) + + +class StopOnPredicateHookTest(test.TestCase): + + def test_stop(self): + hook = early_stopping._StopOnPredicateHook( + should_stop_fn=lambda: False, run_every_secs=0) + with ops.Graph().as_default(): + training_util.create_global_step() + no_op = control_flow_ops.no_op() + with monitored_session.SingularMonitoredSession(hooks=[hook]) as mon_sess: + mon_sess.run(no_op) + self.assertFalse(mon_sess.should_stop()) + self.assertFalse(mon_sess.raw_session().run(hook._stop_var)) + + hook = early_stopping._StopOnPredicateHook( + should_stop_fn=lambda: True, run_every_secs=0) + with ops.Graph().as_default(): + training_util.create_global_step() + no_op = control_flow_ops.no_op() + with monitored_session.SingularMonitoredSession(hooks=[hook]) as mon_sess: + mon_sess.run(no_op) + self.assertTrue(mon_sess.should_stop()) + self.assertTrue(mon_sess.raw_session().run(hook._stop_var)) + + +class CheckForStoppingHookTest(test.TestCase): + + def test_stop(self): + hook = early_stopping._CheckForStoppingHook() + with ops.Graph().as_default(): + no_op = control_flow_ops.no_op() + assign_op = state_ops.assign(early_stopping._get_or_create_stop_var(), + True) + with monitored_session.SingularMonitoredSession(hooks=[hook]) as mon_sess: + mon_sess.run(no_op) + self.assertFalse(mon_sess.should_stop()) + mon_sess.run(assign_op) + self.assertTrue(mon_sess.should_stop()) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index c9d86ef4ab89950b0c7b0414ba60d9e0a1cbe476..34f765d56546d3cd10fcde5ac444a221c73602cd 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -943,20 +943,30 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access class_probabilities = array_ops.slice( probabilities, begin=begin, size=size) class_labels = array_ops.slice(labels, begin=begin, size=size) - prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id + if self._label_vocabulary is None: + prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id + else: + prob_key = ( + keys.PROBABILITY_MEAN_AT_NAME % self._label_vocabulary[class_id]) metric_ops[head_lib._summary_key(self._name, prob_key)] = ( # pylint:disable=protected-access head_lib._predictions_mean( # pylint:disable=protected-access predictions=class_probabilities, weights=weights, name=prob_key)) - auc_key = keys.AUC_AT_CLASS % class_id + if self._label_vocabulary is None: + auc_key = keys.AUC_AT_CLASS % class_id + else: + auc_key = keys.AUC_AT_NAME % self._label_vocabulary[class_id] metric_ops[head_lib._summary_key(self._name, auc_key)] = ( # pylint:disable=protected-access head_lib._auc( # pylint:disable=protected-access labels=class_labels, predictions=class_probabilities, weights=weights, name=auc_key)) - auc_pr_key = keys.AUC_PR_AT_CLASS % class_id + if self._label_vocabulary is None: + auc_pr_key = keys.AUC_PR_AT_CLASS % class_id + else: + auc_pr_key = keys.AUC_PR_AT_NAME % self._label_vocabulary[class_id] metric_ops[head_lib._summary_key(self._name, auc_pr_key)] = ( # pylint:disable=protected-access head_lib._auc( # pylint:disable=protected-access labels=class_labels, diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index 7b884402d4650636bc9fe053994246aabb9c312d..2d367adb47080a630d1d2ef5ecfd4e8d5d0377d9 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -694,12 +694,14 @@ class MultiLabelHead(test.TestCase): # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, keys.AUC_PR: 0.7639, - keys.PROBABILITY_MEAN_AT_CLASS % 0: np.sum(_sigmoid(logits[:, 0])) / 2., - keys.AUC_AT_CLASS % 0: 0., - keys.AUC_PR_AT_CLASS % 0: 1., - keys.PROBABILITY_MEAN_AT_CLASS % 1: np.sum(_sigmoid(logits[:, 1])) / 2., - keys.AUC_AT_CLASS % 1: 1., - keys.AUC_PR_AT_CLASS % 1: 1., + keys.PROBABILITY_MEAN_AT_NAME % 'a': + np.sum(_sigmoid(logits[:, 0])) / 2., + keys.AUC_AT_NAME % 'a': 0., + keys.AUC_PR_AT_NAME % 'a': 1., + keys.PROBABILITY_MEAN_AT_NAME % 'b': + np.sum(_sigmoid(logits[:, 1])) / 2., + keys.AUC_AT_NAME % 'b': 1., + keys.AUC_PR_AT_NAME % 'b': 1., } self._test_eval( diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py index ddd6aa442f82bad2d4714dbcdc85b20b34773068..caadafdfa6972c141d32a705e62a98d220cace41 100644 --- a/tensorflow/contrib/estimator/python/estimator/hooks.py +++ b/tensorflow/contrib/estimator/python/estimator/hooks.py @@ -189,7 +189,7 @@ class InMemoryEvaluatorHook(training.SessionRunHook): init_fn=feed_variables, copy_from_scaffold=self._scaffold) with self._graph.as_default(): - return self._estimator._evaluate_run( + self._estimator._evaluate_run( checkpoint_path=None, scaffold=scaffold, update_op=self._update_op, diff --git a/tensorflow/contrib/estimator/python/estimator/hooks_test.py b/tensorflow/contrib/estimator/python/estimator/hooks_test.py index 95ae971852ee6dffb6174fc243686721c30ef685..ee88d5ecf50aa15b2faa0f3e136c686b5b0ef62a 100644 --- a/tensorflow/contrib/estimator/python/estimator/hooks_test.py +++ b/tensorflow/contrib/estimator/python/estimator/hooks_test.py @@ -102,6 +102,7 @@ class InMemoryEvaluatorHookTest(test.TestCase): self.assertTrue(os.path.isdir(estimator.eval_dir())) step_keyword_to_value = summary_step_keyword_to_value_mapping( estimator.eval_dir()) + # 4.5 = sum(range(10))/10 # before training self.assertEqual(4.5, step_keyword_to_value[0]['mean_of_features']) @@ -110,6 +111,7 @@ class InMemoryEvaluatorHookTest(test.TestCase): self.assertEqual(4.5, step_keyword_to_value[8]['mean_of_features']) # end self.assertEqual(4.5, step_keyword_to_value[10]['mean_of_features']) + self.assertEqual(set([0, 4, 8, 10]), set(step_keyword_to_value.keys())) def test_uses_latest_variable_value(self): diff --git a/tensorflow/contrib/estimator/python/estimator/linear.py b/tensorflow/contrib/estimator/python/estimator/linear.py index b960b16f1ba6b1bf8046c922e21ac1ed136c599e..62a37abefb1f6ed291df1df3da6de35bfd2b6c52 100644 --- a/tensorflow/contrib/estimator/python/estimator/linear.py +++ b/tensorflow/contrib/estimator/python/estimator/linear.py @@ -99,7 +99,8 @@ class LinearEstimator(estimator.Estimator): model_dir=None, optimizer='Ftrl', config=None, - partitioner=None): + partitioner=None, + sparse_combiner='sum'): """Initializes a `LinearEstimator` instance. Args: @@ -116,6 +117,11 @@ class LinearEstimator(estimator.Estimator): callable. Defaults to FTRL optimizer. config: `RunConfig` object to configure the runtime settings. partitioner: Optional. Partitioner for input layer. + sparse_combiner: A string specifying how to reduce if a categorical column + is multivalent. One of "mean", "sqrtn", and "sum" -- these are + effectively different ways to do example-level normalization, which can + be useful for bag-of-words features. for more details, see + @{tf.feature_column.linear_model$linear_model}. """ def _model_fn(features, labels, mode, config): return linear_lib._linear_model_fn( # pylint: disable=protected-access @@ -126,6 +132,7 @@ class LinearEstimator(estimator.Estimator): feature_columns=tuple(feature_columns or []), optimizer=optimizer, partitioner=partitioner, - config=config) + config=config, + sparse_combiner=sparse_combiner) super(LinearEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..ce98e9987ec728fadf170e56fe4bfe24fc9a0105 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py @@ -0,0 +1,449 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class that creates an Estimator from a SavedModel.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export as export_lib +from tensorflow.python.estimator.export import export_output +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import constants +from tensorflow.python.saved_model import loader_impl +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import monitored_session +from tensorflow.python.training import training_util + + +class SavedModelEstimator(estimator_lib.Estimator): + """Create an Estimator from a SavedModel. + + Only SavedModels exported with + `tf.contrib.estimator.export_all_saved_models()` or + `tf.estimator.Estimator.export_savedmodel()` are supported for this class. + + Example with `tf.estimator.DNNClassifier`: + + **Step 1: Create and train DNNClassifier.** + + ```python + feature1 = tf.feature_column.embedding_column( + tf.feature_column.categorical_column_with_vocabulary_list( + key='feature1', vocabulary_list=('green', 'yellow')), dimension=1) + feature2 = tf.feature_column.numeric_column(key='feature2', default_value=0.0) + + classifier = tf.estimator.DNNClassifier( + hidden_units=[4,2], feature_columns=[feature1, feature2]) + + def input_fn(): + features = {'feature1': tf.constant(['green', 'green', 'yellow']), + 'feature2': tf.constant([3.5, 4.2, 6.1])} + label = tf.constant([1., 0., 0.]) + return tf.data.Dataset.from_tensors((features, label)).repeat() + + classifier.train(input_fn=input_fn, steps=10) + ``` + + **Step 2: Export classifier.** + First, build functions that specify the expected inputs. + + ```python + # During train and evaluation, both the features and labels should be defined. + supervised_input_receiver_fn = ( + tf.contrib.estimator.build_raw_supervised_input_receiver_fn( + {'feature1': tf.placeholder(dtype=tf.string, shape=[None]), + 'feature2': tf.placeholder(dtype=tf.float32, shape=[None])}, + tf.placeholder(dtype=tf.float32, shape=[None]))) + + # During predict mode, expect to receive a `tf.Example` proto, so a parsing + # function is used. + serving_input_receiver_fn = ( + tf.estimator.export.build_parsing_serving_input_receiver_fn( + tf.feature_column.make_parse_example_spec([feature1, feature2]))) + ``` + + Next, export the model as a SavedModel. A timestamped directory will be + created (for example `/tmp/export_all/1234567890`). + + ```python + # Option 1: Save all modes (train, eval, predict) + export_dir = tf.contrib.estimator.export_all_saved_models( + classifier, '/tmp/export_all', + {tf.estimator.ModeKeys.TRAIN: supervised_input_receiver_fn, + tf.estimator.ModeKeys.EVAL: supervised_input_receiver_fn, + tf.estimator.ModeKeys.PREDICT: serving_input_receiver_fn}) + + # Option 2: Only export predict mode + export_dir = classifier.export_savedmodel( + '/tmp/export_predict', serving_input_receiver_fn) + ``` + + **Step 3: Create a SavedModelEstimator from the exported SavedModel.** + + ```python + est = tf.contrib.estimator.SavedModelEstimator(export_dir) + + # If all modes were exported, you can immediately evaluate and predict, or + # continue training. Otherwise only predict is available. + eval_results = est.evaluate(input_fn=input_fn, steps=1) + print(eval_results) + + est.train(input_fn=input_fn, steps=20) + + def predict_input_fn(): + example = tf.train.Example() + example.features.feature['feature1'].bytes_list.value.extend(['yellow']) + example.features.feature['feature2'].float_list.value.extend([1.]) + return {'inputs':tf.constant([example.SerializeToString()])} + + predictions = est.predict(predict_input_fn) + print(next(predictions)) + ``` + """ + + def __init__(self, saved_model_dir, model_dir=None): + """Initialize a SavedModelEstimator. + + The SavedModelEstimator loads its model function and variable values from + the graphs defined in the SavedModel. There is no option to pass in + `RunConfig` or `params` arguments, because the model function graph is + defined statically in the SavedModel. + + Args: + saved_model_dir: Directory containing SavedModel protobuf and subfolders. + model_dir: Directory to save new checkpoints during training. + + Raises: + NotImplementedError: If a DistributionStrategy is defined in the config. + Unless the SavedModelEstimator is subclassed, this shouldn't happen. + """ + checkpoint = estimator_lib._get_saved_model_ckpt(saved_model_dir) # pylint: disable=protected-access + vars_to_warm_start = [name for name, _ in + checkpoint_utils.list_variables(checkpoint)] + warm_start_settings = estimator_lib.WarmStartSettings( + ckpt_to_initialize_from=checkpoint, + vars_to_warm_start=vars_to_warm_start) + + super(SavedModelEstimator, self).__init__( + model_fn=self._model_fn_from_saved_model, model_dir=model_dir, + warm_start_from=warm_start_settings) + if self._train_distribution or self._eval_distribution: + raise NotImplementedError( + 'SavedModelEstimator currently does not support ' + 'DistributionStrategy.') + self.saved_model_dir = saved_model_dir + self.saved_model_loader = loader_impl.SavedModelLoader(saved_model_dir) + self._available_modes = self._extract_available_modes() + + def _extract_available_modes(self): + """Return list of modes found in SavedModel.""" + available_modes = [] + logging.info('Checking available modes for SavedModelEstimator.') + for mode in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL, + model_fn_lib.ModeKeys.PREDICT]: + try: + self._get_meta_graph_def_for_mode(mode) + except RuntimeError: + logging.warning('%s mode not found in SavedModel.' % mode) + continue + + if self._get_signature_def_for_mode(mode) is not None: + available_modes.append(mode) + + logging.info('Available modes for Estimator: %s' % available_modes) + return available_modes + + def _validate_mode(self, mode): + """Make sure that mode can be run using the SavedModel.""" + if mode not in self._available_modes: + raise RuntimeError('%s mode is not available in the SavedModel. Use ' + 'saved_model_cli to check that the Metagraph for this ' + 'mode has been exported.' % mode) + + def _get_meta_graph_def_for_mode(self, mode): + tags = model_fn_lib.EXPORT_TAG_MAP[mode] + return self.saved_model_loader.get_meta_graph_def_from_tags(tags) + + def _get_signature_def_for_mode(self, mode): + meta_graph_def = self._get_meta_graph_def_for_mode(mode) + sig_def_key = (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + if mode == model_fn_lib.ModeKeys.PREDICT else mode) + if sig_def_key not in meta_graph_def.signature_def: + logging.warning('Metagraph for mode %s was found, but SignatureDef with' + ' key \"%s\" is missing.' % (mode, sig_def_key)) + return None + return meta_graph_def.signature_def[sig_def_key] + + def _create_and_assert_global_step(self, graph): + # Do nothing here. The global step variable will be created/loaded from the + # SavedModel. If a global step variable were created here, the result + # will be two duplicate global step variables, causing issues during + # the warm-start phase. + # Due to the global variable being created in the model function, this may + # cause issues when running DistributionStrategy. Thus, DistributionStrategy + # is not yet supported with SavedModelEstimator. + return None + + def _model_fn_from_saved_model(self, features, labels, mode): + """Load a SavedModel graph and return an EstimatorSpec.""" + # TODO(kathywu): Model function loads placeholders from the graph. Calling + # export_all_saved_models creates another placeholder for the inputs, on top + # of the original placeholders. There should be a way to avoid this. + self._validate_mode(mode) + + g = ops.get_default_graph() + if training_util.get_global_step(g) is not None: + raise RuntimeError( + 'Graph must not contain a global step tensor before the SavedModel is' + ' loaded. Please make sure that the input function does not create a ' + 'global step.') + + # Extract SignatureDef for information about the input and output tensors. + signature_def = self._get_signature_def_for_mode(mode) + + # Generate input map for replacing the inputs in the SavedModel graph with + # the provided features and labels. + input_map = _generate_input_map(signature_def, features, labels) + + # Create a list of the names of output tensors. When the graph is loaded, + # names of the output tensors may be remapped. This ensures that the correct + # tensors are returned in the EstimatorSpec. + output_tensor_names = [ + value.name for value in six.itervalues(signature_def.outputs)] + + # Load the graph. `output_tensors` contains output `Tensors` in the same + # same order as the `output_tensor_names` list. + tags = model_fn_lib.EXPORT_TAG_MAP[mode] + _, output_tensors = self.saved_model_loader.load_graph( + g, tags, input_map=input_map, return_elements=output_tensor_names) + + # Create a scaffold from the MetaGraphDef that contains ops to initialize + # the graph. This should mirror the steps from _add_meta_graph_for_mode(), + # which creates a MetaGraphDef from the EstimatorSpec's scaffold. + scaffold = monitored_session.Scaffold( + local_init_op=loader_impl._get_main_op_tensor( # pylint: disable=protected-access + self._get_meta_graph_def_for_mode(mode))) + + # Ensure that a global step tensor has been created. + global_step_tensor = training_util.get_global_step(g) + training_util.assert_global_step(global_step_tensor) + + # Extract values to return in the EstimatorSpec. + output_map = dict(zip(output_tensor_names, output_tensors)) + outputs = {key: output_map[value.name] + for key, value in six.iteritems(signature_def.outputs)} + + loss, predictions, metrics = _validate_and_extract_outputs( + mode, outputs, signature_def.method_name) + + train_op = ops.get_collection(constants.TRAIN_OP_KEY) + if len(train_op) > 1: + raise RuntimeError('Multiple ops found in the train_op collection.') + train_op = None if not train_op else train_op[0] + + _clear_saved_model_collections() + return model_fn_lib.EstimatorSpec( + scaffold=scaffold, + mode=mode, + loss=loss, + train_op=train_op, + predictions=predictions, + eval_metric_ops=metrics) + + +def _clear_saved_model_collections(): + """Clear collections that are expected empty when exporting a SavedModel. + + The SavedModel builder uses these collections to track ops necessary to + restore the graph state. These collections are expected to be empty before + MetaGraphs are added to the builder. + """ + del ops.get_collection_ref(constants.ASSETS_KEY)[:] + del ops.get_collection_ref(constants.LEGACY_INIT_OP_KEY)[:] + del ops.get_collection_ref(constants.MAIN_OP_KEY)[:] + del ops.get_collection_ref(constants.TRAIN_OP_KEY)[:] + + +def _generate_input_map(signature_def, features, labels): + """Return dict mapping an input tensor name to a feature or label tensor. + + Args: + signature_def: SignatureDef loaded from SavedModel + features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or + `SparseTensor`, specifying the features to be passed to the model. + labels: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or + `SparseTensor`, specifying the labels to be passed to the model. May be + `None`. + + Returns: + dict mapping string names of inputs to features or labels tensors + + Raises: + ValueError: if SignatureDef inputs are not completely mapped by the input + features and labels. + """ + # pylint: disable=protected-access + if not isinstance(features, dict): + features = {export_lib._SINGLE_FEATURE_DEFAULT_NAME: features} + if labels is not None and not isinstance(labels, dict): + labels = {export_lib._SINGLE_LABEL_DEFAULT_NAME: labels} + # pylint: enable=protected-access + + inputs = signature_def.inputs + input_map = {} + for key, tensor_info in six.iteritems(inputs): + input_name = tensor_info.name + if ':' in input_name: + input_name = input_name[:input_name.find(':')] + + # When tensors are used as control inputs for operations, their names are + # prepended with a '^' character in the GraphDef. To handle possible control + # flow edge cases, control input names must be included in the input map. + control_dependency_name = '^' + input_name + + if key in features: + _check_same_dtype_and_shape(features[key], tensor_info, key) + input_map[input_name] = input_map[control_dependency_name] = features[key] + elif labels is not None and key in labels: + _check_same_dtype_and_shape(labels[key], tensor_info, key) + input_map[input_name] = input_map[control_dependency_name] = labels[key] + else: + raise ValueError( + 'Key \"%s\" not found in features or labels passed in to the model ' + 'function. All required keys: %s' % (key, inputs.keys())) + + return input_map + + +def _check_same_dtype_and_shape(tensor, tensor_info, name): + """Validate that tensor has the same properties as the TensorInfo proto. + + Args: + tensor: a `Tensor` object. + tensor_info: a `TensorInfo` proto. + name: Name of the input (to identify Tensor if an error is raised). + + Raises: + ValueError: If the tensor shape or dtype don't match the TensorInfo + """ + dtype_error = (tensor.dtype != dtypes.DType(tensor_info.dtype)) + shape_error = not tensor.shape.is_compatible_with(tensor_info.tensor_shape) + + if dtype_error or shape_error: + msg = 'Tensor shape and/or dtype validation failed for input %s:' % name + if dtype_error: + msg += ('\n\tExpected dtype: %s, Got: %s' + % (dtypes.DType(tensor_info.dtype), tensor.dtype)) + if shape_error: + msg += ('\n\tExpected shape: %s, Got: %s' + % (tensor_shape.TensorShape(tensor_info.tensor_shape), + tensor.shape)) + + raise ValueError(msg) + + +def _extract_eval_metrics(output_dict): + """Return a eval metric dict extracted from the output_dict. + + Eval metrics consist of a value tensor and an update op. Both must be in the + passed-in tensor dictionary for an eval metric to be added to the returned + dictionary. + + Args: + output_dict: a dict that maps strings to tensors. + + Returns: + dict mapping strings to (value, update_op) tuples. + """ + # pylint: disable=protected-access + metric_ops = {} + separator_char = export_output._SupervisedOutput._SEPARATOR_CHAR + + for key, tensor in six.iteritems(output_dict): + split_key = key.split(separator_char) + + # The metric name may contain the separator character, so recreate its name. + metric_name = separator_char.join(split_key[:-1]) + + if split_key[0] == export_output._SupervisedOutput.METRICS_NAME: + # If the key ends with the value suffix, and there is a corresponding + # key ending with the update_op suffix, then add tensors to metrics dict. + if split_key[-1] == export_output._SupervisedOutput.METRIC_VALUE_SUFFIX: + update_op = ''.join( + [metric_name, separator_char, + export_output._SupervisedOutput.METRIC_UPDATE_SUFFIX]) + if update_op in output_dict: + update_op_tensor = output_dict[update_op] + metric_ops[metric_name] = (tensor, update_op_tensor) + + # pylint: enable=protected-access + return metric_ops + + +def _validate_and_extract_outputs(mode, output_dict, method_name): + """Extract values from SignatureDef output dictionary. + + Args: + mode: One of the modes enumerated in `tf.estimator.ModeKeys`. + output_dict: dict of string SignatureDef keys to `Tensor`. + method_name: Method name of the SignatureDef as a string. + + Returns: + Tuple of ( + loss: `Tensor` object, + predictions: dictionary mapping string keys to `Tensor` objects, + metrics: dictionary mapping string keys to a tuple of two `Tensor` objects + ) + + Raises: + RuntimeError: raised if SignatureDef has an invalid method name for the mode + """ + # pylint: disable=protected-access + loss, predictions, metrics = None, None, None + + if mode == model_fn_lib.ModeKeys.PREDICT: + predictions = output_dict + else: + # Validate that the SignatureDef's method name matches the expected name for + # the given mode. + expected_method_name = signature_constants.SUPERVISED_TRAIN_METHOD_NAME + if mode == model_fn_lib.ModeKeys.EVAL: + expected_method_name = signature_constants.SUPERVISED_EVAL_METHOD_NAME + if method_name != expected_method_name: + raise RuntimeError( + 'Invalid SignatureDef method name for mode %s.\n\tExpected: %s\n\t' + 'Got: %s\nPlease ensure that the SavedModel was exported with ' + '`tf.contrib.estimator.export_all_saved_models()`.' % + (mode, expected_method_name, method_name)) + + # Extract loss, metrics and predictions from the output dict. + loss = output_dict[export_output._SupervisedOutput.LOSS_NAME] + metrics = _extract_eval_metrics(output_dict) + predictions = { + key: value for key, value in six.iteritems(output_dict) + if key.split(export_output._SupervisedOutput._SEPARATOR_CHAR)[0] == ( + export_output._SupervisedOutput.PREDICTIONS_NAME)} + + # pylint: enable=protected-access + return loss, predictions, metrics diff --git a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator_test.py b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..718da1367ce69285f37269c5631fa0be2b050c97 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator_test.py @@ -0,0 +1,369 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for SavedModelEstimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import shutil +import tempfile + +from tensorflow.contrib.estimator.python.estimator import export as contrib_export +from tensorflow.contrib.estimator.python.estimator import saved_model_estimator +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.export import export_output +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import monitored_session +from tensorflow.python.training import training + + +def dummy_input_fn(): + return dataset_ops.Dataset.from_tensors(( + {'x': constant_op.constant([[1], [-2]], dtype=dtypes.int64)}, + constant_op.constant([[4], [-3]], dtype=dtypes.float32))).repeat() + + +def dummy_input_fn_features_only(): + return dataset_ops.Dataset.from_tensors( + {'x': constant_op.constant([[5], [6]], dtype=dtypes.int64)}).repeat() + + +def dummy_supervised_receiver_fn(): + feature_spec = { + 'x': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_x'), + } + label_spec = array_ops.placeholder( + dtype=dtypes.float32, shape=[2, 1], name='truth') + return export.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + +def dummy_serving_receiver_fn(): + feature_spec = {'x': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_x'),} + return export.build_raw_serving_input_receiver_fn(feature_spec) + + +def model_fn_diff_modes(features, labels, mode): + _, _ = features, labels + v = variables.Variable(21, name='some_var') + train_op = None + loss = constant_op.constant(104) + if mode == model_fn_lib.ModeKeys.TRAIN: + loss = constant_op.constant(105) + predictions = constant_op.constant([501]) + train_op = control_flow_ops.group( + state_ops.assign_add(training.get_global_step(), 1), + state_ops.assign_add(v, 3)) + elif mode == model_fn_lib.ModeKeys.EVAL: + loss = constant_op.constant(106) + predictions = constant_op.constant([502]) + else: + loss = constant_op.constant(107) + predictions = constant_op.constant([503]) + return model_fn_lib.EstimatorSpec( + mode, + loss=loss, + train_op=train_op, + eval_metric_ops={ + 'abs_err': metrics_lib.mean_absolute_error( + constant_op.constant(0), predictions)}, + predictions=predictions) + + +class SavedModelEstimatorTest(test.TestCase): + + def setUp(self): + self.tmpdirs = [] + + def tearDown(self): + for tmpdir in self.tmpdirs: + # gfile.DeleteRecursively fails in the windows cmake test, so use shutil. + shutil.rmtree(tmpdir, ignore_errors=True) + self.tmpdirs = [] + + def _get_tmp_dir(self): + tmpdir = tempfile.mkdtemp() + self.tmpdirs.append(tmpdir) + return tmpdir + + def _export_estimator(self, train=True, evaluate=True, predict=True, + model_fn=model_fn_diff_modes): + est = estimator.Estimator(model_fn, self._get_tmp_dir()) + est.train(input_fn=dummy_input_fn, steps=10) + + input_receiver_fn_map = {} + if train: + input_receiver_fn_map[model_fn_lib.ModeKeys.TRAIN] = ( + dummy_supervised_receiver_fn()) + if evaluate: + input_receiver_fn_map[model_fn_lib.ModeKeys.EVAL] = ( + dummy_supervised_receiver_fn()) + if predict: + input_receiver_fn_map[model_fn_lib.ModeKeys.PREDICT] = ( + dummy_serving_receiver_fn()) + + export_base_path = self._get_tmp_dir() + export_dir = contrib_export.export_all_saved_models( + est, export_base_path, input_receiver_fn_map) + return export_dir + + def test_load_all_modes(self): + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(), self._get_tmp_dir()) + sme.train(input_fn=dummy_input_fn, steps=1) + sme.train(input_fn=dummy_input_fn, steps=2) + self.assertEqual(13, sme.get_variable_value('global_step')) + self.assertEqual(60, sme.get_variable_value('some_var')) + + eval_results = sme.evaluate(dummy_input_fn, steps=5) + + self.assertEqual(13, eval_results['global_step']) + self.assertEqual(106, eval_results['loss']) + self.assertEqual(502, eval_results['metrics/abs_err']) + + predictions = next(sme.predict(dummy_input_fn_features_only)) + self.assertDictEqual({'output': 503}, predictions) + + def test_load_all_modes_no_train(self): + """Ensure that all functions can be used without requiring a ckpt.""" + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(), self._get_tmp_dir()) + eval_results = sme.evaluate(dummy_input_fn, steps=5) + self.assertEqual(10, eval_results['global_step']) + self.assertEqual(106, eval_results['loss']) + self.assertEqual(502, eval_results['metrics/abs_err']) + + predictions = next(sme.predict(dummy_input_fn_features_only)) + self.assertDictEqual({'output': 503}, predictions) + + def test_partial_exported_estimator(self): + sme1 = saved_model_estimator.SavedModelEstimator( + self._export_estimator(train=False, predict=False), self._get_tmp_dir()) + sme1.evaluate(dummy_input_fn, steps=5) + with self.assertRaisesRegexp(RuntimeError, 'train mode is not available'): + sme1.train(input_fn=dummy_input_fn, steps=1) + with self.assertRaisesRegexp(RuntimeError, 'infer mode is not available'): + next(sme1.predict(dummy_input_fn_features_only)) + + sme2 = saved_model_estimator.SavedModelEstimator( + self._export_estimator(evaluate=False), self._get_tmp_dir()) + sme2.train(input_fn=dummy_input_fn, steps=1) + next(sme2.predict(dummy_input_fn_features_only)) + with self.assertRaisesRegexp(RuntimeError, 'eval mode is not available'): + sme2.evaluate(dummy_input_fn, steps=5) + + def test_with_incorrect_input(self): + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(), self._get_tmp_dir()) + + def bad_shape_input_fn(): + return dataset_ops.Dataset.from_tensors(( + {'x': constant_op.constant([1, 2], dtype=dtypes.int64)}, + constant_op.constant([1, 2], dtype=dtypes.float32))) + + with self.assertRaisesRegexp(ValueError, 'Expected shape'): + sme.train(bad_shape_input_fn, steps=1) + + def bad_dtype_input_fn(): + return dataset_ops.Dataset.from_tensors(( + {'x': constant_op.constant([[1], [1]], dtype=dtypes.int32)}, + constant_op.constant([[1], [1]], dtype=dtypes.int64))) + + with self.assertRaisesRegexp(ValueError, 'Expected dtype'): + sme.train(bad_dtype_input_fn, steps=1) + + def test_input_fn_with_global_step(self): + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(), self._get_tmp_dir()) + + def bad_input_fn(): + training.get_or_create_global_step() + return dataset_ops.Dataset.from_tensors(( + {'x': constant_op.constant([[1], [1]], dtype=dtypes.int64)}, + constant_op.constant([[1], [1]], dtype=dtypes.float32))) + + with self.assertRaisesRegexp(RuntimeError, + 'Graph must not contain a global step tensor'): + sme.train(bad_input_fn, steps=1) + + def test_re_export_saved_model_serving_only(self): + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(), self._get_tmp_dir()) + sme.train(dummy_input_fn, steps=3) + self.assertEqual(13, sme.get_variable_value('global_step')) + self.assertEqual(60, sme.get_variable_value('some_var')) + + predictions = next(sme.predict(dummy_input_fn_features_only)) + self.assertDictEqual({'output': 503}, predictions) + + # Export SavedModel, and test that the variable and prediction values are + # the same. + sme_export_dir = sme.export_savedmodel( + self._get_tmp_dir(), dummy_serving_receiver_fn()) + + sme2 = saved_model_estimator.SavedModelEstimator( + sme_export_dir, self._get_tmp_dir()) + self.assertEqual(60, sme.get_variable_value('some_var')) + self.assertEqual(13, sme.get_variable_value('global_step')) + + predictions = next(sme2.predict(dummy_input_fn_features_only)) + self.assertDictEqual({'output': 503}, predictions) + + def test_re_export_saved_model(self): + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(), self._get_tmp_dir()) + self.assertDictEqual( + {'loss': 106, 'metrics/abs_err': 502, 'global_step': 10}, + sme.evaluate(dummy_input_fn, steps=1)) + + sme.train(dummy_input_fn, steps=3) + self.assertDictEqual( + {'loss': 106, 'metrics/abs_err': 502, 'global_step': 13}, + sme.evaluate(dummy_input_fn, steps=1)) + self.assertEqual(60, sme.get_variable_value('some_var')) + + predictions = next(sme.predict(dummy_input_fn_features_only)) + self.assertDictEqual({'output': 503}, predictions) + + # Export SavedModel for all modes + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: dummy_supervised_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: dummy_supervised_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: dummy_serving_receiver_fn()} + sme_export_dir = contrib_export.export_all_saved_models( + sme, self._get_tmp_dir(), input_receiver_fn_map) + + sme2 = saved_model_estimator.SavedModelEstimator( + sme_export_dir, self._get_tmp_dir()) + self.assertDictEqual( + {'loss': 106, 'metrics/abs_err': 502, 'global_step': 13}, + sme.evaluate(dummy_input_fn, steps=1)) + self.assertEqual(60, sme.get_variable_value('some_var')) + + sme.train(dummy_input_fn, steps=7) + self.assertEqual(20, sme.get_variable_value('global_step')) + + predictions = next(sme2.predict(dummy_input_fn_features_only)) + self.assertDictEqual({'output': 503}, predictions) + + def test_load_saved_model_from_serving_only(self): + def model_fn(features, labels, mode): + _, _ = features, labels + return model_fn_lib.EstimatorSpec( + mode, + loss=constant_op.constant([103]), + train_op=state_ops.assign_add(training.get_global_step(), 1), + predictions=constant_op.constant([502]), + export_outputs={'test': export_output.ClassificationOutput( + constant_op.constant([[32.]]))}) + + est = estimator.Estimator(model_fn, self._get_tmp_dir()) + est.train(input_fn=dummy_input_fn, steps=10) + + def serving_input_receiver_fn(): + return export.ServingInputReceiver( + {'test-features': constant_op.constant([[1], [1]])}, + array_ops.placeholder(dtype=dtypes.string)) + + export_dir = est.export_savedmodel( + self._get_tmp_dir(), serving_input_receiver_fn) + + sme = saved_model_estimator.SavedModelEstimator( + export_dir, self._get_tmp_dir()) + + def input_fn(): + return {'inputs': constant_op.constant('someinputstr')} + + prediction = next(sme.predict(input_fn)) + self.assertDictEqual({'scores': 32}, prediction) + + def test_with_local_init_op(self): + def model_fn(features, labels, mode): + _, _ = features, labels + v = variables.Variable(21, name='some_var') + scaffold = monitored_session.Scaffold( + local_init_op=state_ops.assign_add(v, -3).op + ) + return model_fn_lib.EstimatorSpec( + mode, + scaffold=scaffold, + train_op=state_ops.assign_add(training.get_global_step(), 1), + loss=array_ops.identity(v)) + export_dir = self._export_estimator(predict=False, model_fn=model_fn) + sme = saved_model_estimator.SavedModelEstimator( + export_dir, self._get_tmp_dir()) + + eval_results1 = sme.evaluate(dummy_input_fn, steps=2) + self.assertEqual(15, eval_results1['loss']) + + sme.train(dummy_input_fn, steps=1) + self.assertEqual(15, sme.get_variable_value('some_var')) + + eval_results2 = sme.evaluate(dummy_input_fn, steps=5) + self.assertEqual(12, eval_results2['loss']) + + def test_with_working_input_fn(self): + def model_fn(features, labels, mode): + loss = None + if labels is not None: + loss = labels[0][0] + labels[1][0] + return model_fn_lib.EstimatorSpec( + mode, + loss=loss, + train_op=state_ops.assign_add(training.get_global_step(), 1), + predictions={'features_0': array_ops.identity([features['x'][0][0]]), + 'features_1': array_ops.identity([features['x'][1][0]])}) + + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(model_fn=model_fn), self._get_tmp_dir()) + eval_results = sme.evaluate(dummy_input_fn, steps=1) + self.assertEqual(1, eval_results['loss']) + + predictions = next(sme.predict(dummy_input_fn_features_only)) + self.assertDictEqual({'features_0': 5, 'features_1': 6}, predictions) + + def test_control_dependency(self): + # Control dependencies are saved with "^" appended to the start of the input + # name. The input map must include control dependencies as well. + def model_fn(features, labels, mode): + _ = labels + with ops.control_dependencies([features['x']]): + loss = features['x'][1][0] + return model_fn_lib.EstimatorSpec( + mode, + loss=loss, + train_op=state_ops.assign_add(training.get_global_step(), 1)) + sme = saved_model_estimator.SavedModelEstimator( + self._export_estimator(train=False, predict=False, model_fn=model_fn), + self._get_tmp_dir()) + sme.evaluate(dummy_input_fn, steps=1) # Should run without error + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index dc49383c5c300e82839c478e097074b3e8776b3b..918a7e2bc772dee226e5ef23d0e3e34309f180f4 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -133,6 +133,7 @@ _nest_allowed_symbols = [ 'flatten_dict_items', 'pack_sequence_as', 'map_structure', + 'map_structure_with_paths', 'assert_shallow_structure', 'flatten_up_to', 'map_structure_up_to', diff --git a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py index 9e356dd96562c28adec7fc28fe144394e1c2ed38..e7184a01fbf57319399fc6dd287b7387138b4058 100644 --- a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py +++ b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py @@ -27,7 +27,7 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import saver +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import training as train __all__ = [ @@ -40,7 +40,7 @@ __all__ = [ def _get_checkpoint_filename(filepattern): """Returns checkpoint filename given directory or specific filepattern.""" if gfile.IsDirectory(filepattern): - return saver.latest_checkpoint(filepattern) + return checkpoint_management.latest_checkpoint(filepattern) return filepattern diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index e8e318001972934c7d2154bc14744823a3ba09f9..322d5c335e6a77c46c7ce5dd795e21a2d5a1f8f9 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -34,6 +34,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables from tensorflow.python.platform import resource_loader from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver as tf_saver @@ -199,10 +200,20 @@ def global_variable(initial_value, @contrib_add_arg_scope -def variable(name, shape=None, dtype=None, initializer=None, - regularizer=None, trainable=True, collections=None, - caching_device=None, device=None, - partitioner=None, custom_getter=None, use_resource=None): +def variable(name, + shape=None, + dtype=None, + initializer=None, + regularizer=None, + trainable=True, + collections=None, + caching_device=None, + device=None, + partitioner=None, + custom_getter=None, + use_resource=None, + synchronization=variables.VariableSynchronization.AUTO, + aggregation=variables.VariableAggregation.NONE): """Gets an existing variable with these parameters or creates a new one. Args: @@ -228,6 +239,15 @@ def variable(name, shape=None, dtype=None, initializer=None, custom_getter: Callable that allows overwriting the internal get_variable method and has to have the same signature. use_resource: If `True` use a ResourceVariable instead of a Variable. + synchronization: Indicates when a distributed a variable will be + aggregated. Accepted values are constants defined in the class + @{tf.VariableSynchronization}. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses + when to synchronize. If `synchronization` is set to `ON_READ`, + `trainable` must not be set to `True`. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + @{tf.VariableAggregation}. Returns: The created or existing variable. @@ -242,21 +262,36 @@ def variable(name, shape=None, dtype=None, initializer=None, getter = functools.partial(custom_getter, reuse=variable_scope.get_variable_scope().reuse) with ops.device(device or ''): - return getter(name, shape=shape, dtype=dtype, - initializer=initializer, - regularizer=regularizer, - trainable=trainable, - collections=collections, - caching_device=caching_device, - partitioner=partitioner, - use_resource=use_resource) + return getter( + name, + shape=shape, + dtype=dtype, + initializer=initializer, + regularizer=regularizer, + trainable=trainable, + collections=collections, + caching_device=caching_device, + partitioner=partitioner, + use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation) @contrib_add_arg_scope -def model_variable(name, shape=None, dtype=dtypes.float32, initializer=None, - regularizer=None, trainable=True, collections=None, - caching_device=None, device=None, partitioner=None, - custom_getter=None, use_resource=None): +def model_variable(name, + shape=None, + dtype=dtypes.float32, + initializer=None, + regularizer=None, + trainable=True, + collections=None, + caching_device=None, + device=None, + partitioner=None, + custom_getter=None, + use_resource=None, + synchronization=variables.VariableSynchronization.AUTO, + aggregation=variables.VariableAggregation.NONE): """Gets an existing model variable with these parameters or creates a new one. Args: @@ -283,18 +318,36 @@ def model_variable(name, shape=None, dtype=dtypes.float32, initializer=None, custom_getter: Callable that allows overwriting the internal get_variable method and has to have the same signature. use_resource: If `True` use a ResourceVariable instead of a Variable. + synchronization: Indicates when a distributed a variable will be + aggregated. Accepted values are constants defined in the class + @{tf.VariableSynchronization}. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses + when to synchronize. If `synchronization` is set to `ON_READ`, + `trainable` must not be set to `True`. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + @{tf.VariableAggregation}. Returns: The created or existing variable. """ collections = list(collections or []) collections += [ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES] - var = variable(name, shape=shape, dtype=dtype, - initializer=initializer, regularizer=regularizer, - trainable=trainable, collections=collections, - caching_device=caching_device, device=device, - partitioner=partitioner, custom_getter=custom_getter, - use_resource=use_resource) + var = variable( + name, + shape=shape, + dtype=dtype, + initializer=initializer, + regularizer=regularizer, + trainable=trainable, + collections=collections, + caching_device=caching_device, + device=device, + partitioner=partitioner, + custom_getter=custom_getter, + use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation) return var diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py index 7e0c7dbec1d9266b53a169fe83b88d1e3af77d04..3c44630a51deb8a468165e8da458600665d0ada1 100644 --- a/tensorflow/contrib/framework/python/ops/variables_test.py +++ b/tensorflow/contrib/framework/python/ops/variables_test.py @@ -106,8 +106,9 @@ class LocalVariableTest(test.TestCase): def testResourceVariable(self): a = variables_lib2.local_variable(0) b = variables_lib2.local_variable(0, use_resource=True) - self.assertEqual(type(a), variables_lib.Variable) - self.assertEqual(type(b), resource_variable_ops.ResourceVariable) + self.assertTrue(isinstance(a, variables_lib.Variable)) + self.assertFalse(isinstance(a, resource_variable_ops.ResourceVariable)) + self.assertTrue(isinstance(b, resource_variable_ops.ResourceVariable)) class GlobalVariableTest(test.TestCase): @@ -176,8 +177,9 @@ class GlobalVariableTest(test.TestCase): def testResourceVariable(self): a = variables_lib2.global_variable(0) b = variables_lib2.global_variable(0, use_resource=True) - self.assertEqual(type(a), variables_lib.Variable) - self.assertEqual(type(b), resource_variable_ops.ResourceVariable) + self.assertTrue(isinstance(a, variables_lib.Variable)) + self.assertFalse(isinstance(a, resource_variable_ops.ResourceVariable)) + self.assertTrue(isinstance(b, resource_variable_ops.ResourceVariable)) class GlobalStepTest(test.TestCase): diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 2458f7554afdc12709571c551a8323cda7fa5c17..0ccb4583ab653bc2ef6c5c810c902a9332e82df9 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -135,9 +135,12 @@ class FusedConv2DBiasActivationOp : public OpKernel { context->GetAttr("activation_mode", &activation_mode_str)); OP_REQUIRES_OK(context, GetActivationModeFromString(activation_mode_str, &activation_mode_)); - OP_REQUIRES(context, activation_mode_ == ActivationMode::RELU, - errors::InvalidArgument("Current implementation only supports " - "RELU as the activation function.")); + OP_REQUIRES(context, + activation_mode_ == ActivationMode::RELU || + activation_mode_ == ActivationMode::NONE, + errors::InvalidArgument( + "Current implementation only supports RELU or NONE " + "as the activation function.")); cudnn_use_autotune_ = CudnnUseAutotune(); } @@ -440,6 +443,8 @@ void LaunchFusedConv2DBiasActivationOp:: : dnn::DataLayout::kBatchDepthYX; constexpr auto filter_layout = is_int8x4 ? dnn::FilterLayout::kOutputInputYX4 : dnn::FilterLayout::kOutputInputYX; + constexpr auto compute_data_format = + is_int8x4 ? FORMAT_NCHW_VECT_C : FORMAT_NCHW; dnn::BatchDescriptor conv_input_desc; conv_input_desc.set_count(batch_size) @@ -526,6 +531,7 @@ void LaunchFusedConv2DBiasActivationOp:: batch_size, conv_input_depth, {{conv_input_rows, conv_input_cols}}, + compute_data_format, output_depth, {{filter_rows, filter_cols}}, // TODO(yangzihao): Add support for arbitrary dilations for fused conv. @@ -538,6 +544,18 @@ void LaunchFusedConv2DBiasActivationOp:: activation_mode, }; + dnn::ActivationMode dnn_activation_mode; + switch (activation_mode) { + case ActivationMode::NONE: + dnn_activation_mode = dnn::ActivationMode::kNone; + break; + case ActivationMode::RELU: + dnn_activation_mode = dnn::ActivationMode::kRelu; + break; + default: + LOG(FATAL) << "Activation mode " << activation_mode << " not supported"; + } + dnn::AlgorithmConfig algorithm_config; if (cudnn_use_autotune && !AutoTuneConvBiasActivation::GetInstance()->Find( fused_conv_parameters, &algorithm_config)) { @@ -558,10 +576,9 @@ void LaunchFusedConv2DBiasActivationOp:: ->ThenFusedConvolveWithAlgorithm( conv_input_desc, conv_input_ptr, conv_input_scale, filter_desc, filter_ptr, conv_desc, side_input_ptr, - side_input_scale, bias_desc, bias_ptr, - dnn::ActivationMode::kRelu, output_desc, &output_ptr, - &scratch_allocator, dnn::AlgorithmConfig(profile_algorithm), - &profile_result) + side_input_scale, bias_desc, bias_ptr, dnn_activation_mode, + output_desc, &output_ptr, &scratch_allocator, + dnn::AlgorithmConfig(profile_algorithm), &profile_result) .ok(); if (cudnn_launch_status) { if (profile_result.is_valid()) { @@ -597,7 +614,7 @@ void LaunchFusedConv2DBiasActivationOp:: ->ThenFusedConvolveWithAlgorithm( conv_input_desc, conv_input_ptr, conv_input_scale, filter_desc, filter_ptr, conv_desc, side_input_ptr, side_input_scale, - bias_desc, bias_ptr, dnn::ActivationMode::kRelu, output_desc, + bias_desc, bias_ptr, dnn_activation_mode, output_desc, &output_ptr, &scratch_allocator, algorithm_config, /*output_profile_result=*/nullptr) .ok(); diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h index ba52697679dafc239b1dac5562573b3589877a8c..b9c131a2e91469c52931080d8a5af90247bd16f0 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h @@ -29,13 +29,13 @@ namespace tensorflow { class FusedConvParameters : public ConvParameters { public: FusedConvParameters(int64 batch, int64 in_depths, const SpatialArray& in, - int64 out_depths, const SpatialArray& filter, - const SpatialArray& dilation, const SpatialArray& stride, - const SpatialArray& padding, DataType dtype, - int device_id, bool has_side_input, + TensorFormat data_format, int64 out_depths, + const SpatialArray& filter, const SpatialArray& dilation, + const SpatialArray& stride, const SpatialArray& padding, + DataType dtype, int device_id, bool has_side_input, ActivationMode activation_mode) - : ConvParameters(batch, in_depths, in, out_depths, filter, dilation, - stride, padding, dtype, device_id), + : ConvParameters(batch, in_depths, in, data_format, out_depths, filter, + dilation, stride, padding, dtype, device_id), activation_mode_(activation_mode), has_side_input_(has_side_input) { hash_code_ = Hash64Combine(hash_code_, has_side_input); diff --git a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc index bafd1d59418f0ba47ebbdaabbf06f8e5471fc1a1..410571f3783263152fda93980580182eb666886d 100644 --- a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc @@ -44,7 +44,7 @@ REGISTER_OP("FusedConv2DBiasActivation") .Attr(GetPaddingAttrString()) .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") .Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'") - .Attr("activation_mode: {'Relu'} = 'Relu'") + .Attr("activation_mode: {'Relu', 'None'} = 'Relu'") .Attr("dilations: list(int) = [1, 1, 1, 1]") .SetShapeFn([](shape_inference::InferenceContext* c) { using shape_inference::ShapeHandle; @@ -144,7 +144,7 @@ REGISTER_OP("FusedConv2DBiasActivation") `qint8 [ output_channels, input_channels / 4, kernel_height, kernel_width, input_channels % 4 ]` activation_mode: The activation applied to the output. - Currently must be "Relu". + Must be "Relu" or "None". dilations: 1-D tensor of length 4. The dilation factor for each dimension of `input`. If set to k > 1, there will be k-1 skipped cells between each filter element on that dimension. The dimension order is determined diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py index 983b6dc8e5a1512ba81ecbc8d5ca5adaea09afe4..cdc07b935dcc42ce3c0cef6bb8f4a126fe82c883 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py @@ -66,8 +66,10 @@ def fused_conv2d_bias_activation(conv_input, This is optional and defaults to 0. side_input: A `Tensor` of the format specified by `data_format`. This is useful for implementing ResNet blocks. - activation_mode: (optional) currently must be the default "Relu". - Note that in qint8 mode, it also clips to 127, so acts like ReluX. + activation_mode: (optional) currently supports the default "Relu", or + "None" activation function. + Note: in qint8 mode, "None" actually clips to the range [-128, 127], + while "Relu" clips to the range [0, 127]. data_format: Specifies the data format. Possible values are: "NHWC" float [batch, height, width, channels] diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py index 4d62ac65ff619f98a18387058fdc8a0eade0d8f8..0185ef662c2ed05b1ceaf0e3e8071bad4c0d1a0a 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py @@ -622,7 +622,7 @@ def HwioToOihw(in_tensor): def SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel, padding, strides, side_input_scale, - side_input, biases): + side_input, biases, apply_relu): """Simulates the int8 fused 2-D convolution op using separate float ops. The arguments and return values have the same format, meanings and @@ -636,6 +636,9 @@ def SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel, side_input_scale: A scalar 'float'. side_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout. biases: A `Tensor` of type `float32` in NCHW layout. + apply_relu: A boolean to specify whether to apply "Relu" activation function + that clips outputs to the range [0, 127], or "None" activation that clips + to the range [-128, 127]. Returns: A `Tensor` of type `qint8` in NCHW_VECT_C layout. """ @@ -649,10 +652,12 @@ def SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel, conv_and_side_inputs = conv_result + side_input_scale * NchwVectCToNchw( gen_array_ops.dequantize(side_input, -128, 127)) - logit = nn_ops.bias_add(conv_and_side_inputs, biases, data_format="NCHW") + output = nn_ops.bias_add(conv_and_side_inputs, biases, data_format="NCHW") + if apply_relu: + output = nn_ops.relu(output) result, _, _ = gen_array_ops.quantize_v2( - NchwToNchwVectC(nn_ops.relu(logit)), -128, 127, dtypes.qint8) + NchwToNchwVectC(output), -128, 127, dtypes.qint8) return result @@ -795,7 +800,7 @@ class FusedConvInt8Tests(test.TestCase): }, ] - def runTest(self, test_param): + def runTest(self, test_param, apply_relu): batch_size = test_param["batch_size"] input_channels = test_param["input_channels"] output_channels = test_param["output_channels"] @@ -831,8 +836,8 @@ class FusedConvInt8Tests(test.TestCase): vertical_stride, padding_type) output_width = CalculateConvolvedOutputDim(input_width, filter_width, horizontal_stride, padding_type) - tf_logging.info("output_height=", output_height, ", output_width=", - output_width) + tf_logging.info("output_height=", output_height, ", output_width=", + output_width) side_input, _, _ = gen_array_ops.quantize_v2( random_ops.random_uniform( @@ -858,12 +863,13 @@ class FusedConvInt8Tests(test.TestCase): conv_input_scale=conv_input_scale, side_input_scale=side_input_scale, side_input=side_input, + activation_mode="Relu" if apply_relu else "None", data_format="NCHW_VECT_C", filter_format="OIHW_VECT_I") expected = SimulateFusedConv2dBiasActivationInt8( conv_input_scale, conv_input, kernel, padding_type, strides, - side_input_scale, side_input, biases) + side_input_scale, side_input, biases, apply_relu) with self.test_session(use_gpu=True) as sess: actual_y, expected_y = sess.run([actual, expected]) @@ -877,8 +883,9 @@ class FusedConvInt8Tests(test.TestCase): tf_logging.info("int8 test skipped because not run with --config=cuda or " "no GPUs with compute capability >= 6.1 are available.") return - for test_param in self._test_params: - self.runTest(test_param) + for apply_relu in [True, False]: + for test_param in self._test_params: + self.runTest(test_param, apply_relu) if __name__ == "__main__": diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index b305f37791d71f5a6edeada2bb710a2e5f23087d..053d4e3e977ed1baed8ceeca1a983e999b1ad1ff 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -42,9 +42,12 @@ py_library( "//tensorflow/contrib/training:training_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:init_ops", + "//tensorflow/python:random_ops", "//tensorflow/python:training", + "//tensorflow/python:training_util", "//tensorflow/python:variable_scope", "//tensorflow/python/ops/distributions", "//tensorflow/python/ops/losses", @@ -54,26 +57,31 @@ py_library( py_test( name = "train_test", srcs = ["python/train_test.py"], + shard_count = 50, srcs_version = "PY2AND3", tags = ["notsan"], deps = [ - ":features", ":namedtuples", + ":random_tensor_pool", ":train", "//tensorflow/contrib/framework:framework_py", + "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/slim:learning", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:random_seed", "//tensorflow/python:training", + "//tensorflow/python:training_util", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//tensorflow/python/ops/distributions", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) @@ -188,10 +196,16 @@ py_test( srcs = ["python/losses/python/tuple_losses_test.py"], srcs_version = "PY2AND3", deps = [ + ":losses_impl", + ":namedtuples", ":tuple_losses", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", + "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//third_party/py/numpy", ], @@ -248,12 +262,15 @@ py_library( py_test( name = "random_tensor_pool_test", srcs = ["python/features/python/random_tensor_pool_test.py"], + shard_count = 6, srcs_version = "PY2AND3", deps = [ ":random_tensor_pool", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", "//third_party/py/numpy", ], ) @@ -344,9 +361,11 @@ py_library( "//tensorflow/python:image_ops", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:platform", "//tensorflow/python:util", + "@six_archive//:six", ], ) @@ -470,12 +489,12 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - ":head", ":namedtuples", ":summaries", ":train", "//tensorflow/contrib/framework:framework_py", "//tensorflow/python:framework_ops", + "//tensorflow/python:metrics", "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python/estimator", @@ -498,16 +517,19 @@ py_test( "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", "//tensorflow/python:parsing_ops", "//tensorflow/python:summary", "//tensorflow/python:training", - "//tensorflow/python/estimator:head", + "//tensorflow/python:training_util", + "//tensorflow/python:variable_scope", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:numpy_io", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 4092b320042162e4eb4c5f4879c2c3ea5dc14fc9..ab9886580d1648852e08f64cb3e9b51f679c25de 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -24,11 +24,11 @@ import enum from tensorflow.contrib.framework.python.ops import variables as variable_lib from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples from tensorflow.contrib.gan.python import train as tfgan_train -from tensorflow.contrib.gan.python.estimator.python import head as head_lib from tensorflow.contrib.gan.python.eval.python import summaries as tfgan_summaries from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import ops +from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import variable_scope from tensorflow.python.util import tf_inspect as inspect @@ -53,9 +53,6 @@ _summary_type_map = { } -# TODO(joelshor): For now, this only supports 1:1 generator:discriminator -# training sequentially. Find a nice way to expose options to the user without -# exposing internals. class GANEstimator(estimator.Estimator): """An estimator for Generative Adversarial Networks (GANs). @@ -154,94 +151,93 @@ class GANEstimator(estimator.Estimator): use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. config: `RunConfig` object to configure the runtime settings. + + Raises: + ValueError: If loss functions aren't callable. + ValueError: If `use_loss_summaries` isn't boolean or `None`. + ValueError: If `get_hooks_fn` isn't callable or `None`. """ - # TODO(joelshor): Explicitly validate inputs. + if not callable(generator_loss_fn): + raise ValueError('generator_loss_fn must be callable.') + if not callable(discriminator_loss_fn): + raise ValueError('discriminator_loss_fn must be callable.') + if use_loss_summaries not in [True, False, None]: + raise ValueError('use_loss_summaries must be True, False or None.') + if get_hooks_fn is not None and not callable(get_hooks_fn): + raise TypeError('get_hooks_fn must be callable.') def _model_fn(features, labels, mode): - gopt = (generator_optimizer() if callable(generator_optimizer) else - generator_optimizer) - dopt = (discriminator_optimizer() if callable(discriminator_optimizer) - else discriminator_optimizer) - gan_head = head_lib.gan_head( - generator_loss_fn, discriminator_loss_fn, gopt, dopt, - use_loss_summaries, get_hooks_fn=get_hooks_fn, - get_eval_metric_ops_fn=get_eval_metric_ops_fn) - return _gan_model_fn( - features, labels, mode, generator_fn, discriminator_fn, gan_head, + """GANEstimator model function.""" + if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL, + model_fn_lib.ModeKeys.PREDICT]: + raise ValueError('Mode not recognized: %s' % mode) + real_data = labels # rename inputs for clarity + generator_inputs = features # rename inputs for clarity + + # Make GANModel, which encapsulates the GAN model architectures. + gan_model = _get_gan_model( + mode, generator_fn, discriminator_fn, real_data, generator_inputs, add_summaries) + # Make the EstimatorSpec, which incorporates the GANModel, losses, eval + # metrics, and optimizers (if required). + return _get_estimator_spec( + mode, gan_model, generator_loss_fn, discriminator_loss_fn, + get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, + get_hooks_fn) + super(GANEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) -def _gan_model_fn( - features, - labels, - mode, - generator_fn, - discriminator_fn, - head, - add_summaries=None, - generator_scope_name='Generator'): - """The `model_fn` for the GAN estimator. - - We make the following convention: - features -> TFGAN's `generator_inputs` - labels -> TFGAN's `real_data` - - Args: - features: A dictionary to feed to generator. In the unconditional case, - this might be just `noise`. In the conditional GAN case, this - might be the generator's conditioning. The `generator_fn` determines - what the required keys are. - labels: Real data. Can be any structure, as long as `discriminator_fn` - can accept it for the first argument. - mode: Defines whether this is training, evaluation or prediction. - See `ModeKeys`. - generator_fn: A python lambda that takes `generator_inputs` as inputs and - returns the outputs of the GAN generator. - discriminator_fn: A python lambda that takes `real_data`/`generated data` - and `generator_inputs`. Outputs a Tensor in the range [-inf, inf]. - head: A `Head` instance suitable for GANs. - add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. - generator_scope_name: The name of the generator scope. We need this to be - the same for GANModels produced by TFGAN's `train.gan_model` and the - manually constructed ones for predictions. - - Returns: - `ModelFnOps` - - Raises: - ValueError: If `labels` isn't `None` during prediction. - """ - real_data = labels - generator_inputs = features - - if mode == model_fn_lib.ModeKeys.TRAIN: - gan_model = _make_train_gan_model( - generator_fn, discriminator_fn, real_data, generator_inputs, - generator_scope_name, add_summaries) - elif mode == model_fn_lib.ModeKeys.EVAL: - gan_model = _make_eval_gan_model( - generator_fn, discriminator_fn, real_data, generator_inputs, - generator_scope_name, add_summaries) - else: +def _get_gan_model( + mode, generator_fn, discriminator_fn, real_data, generator_inputs, + add_summaries, generator_scope='Generator'): + """Makes the GANModel tuple, which encapsulates the GAN model architecture.""" + if mode == model_fn_lib.ModeKeys.PREDICT: if real_data is not None: raise ValueError('`labels` must be `None` when mode is `predict`. ' 'Instead, found %s' % real_data) gan_model = _make_prediction_gan_model( - generator_inputs, generator_fn, generator_scope_name) + generator_inputs, generator_fn, generator_scope) + else: # model_fn_lib.ModeKeys.TRAIN or model_fn_lib.ModeKeys.EVAL + gan_model = _make_gan_model( + generator_fn, discriminator_fn, real_data, generator_inputs, + generator_scope, add_summaries, mode) - return head.create_estimator_spec( - features=None, - mode=mode, - logits=gan_model, - labels=None) + return gan_model + + +def _get_estimator_spec( + mode, gan_model, generator_loss_fn, discriminator_loss_fn, + get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, + get_hooks_fn=None): + """Get the EstimatorSpec for the current mode.""" + if mode == model_fn_lib.ModeKeys.PREDICT: + estimator_spec = model_fn_lib.EstimatorSpec( + mode=mode, predictions=gan_model.generated_data) + else: + gan_loss = tfgan_tuples.GANLoss( + generator_loss=generator_loss_fn(gan_model), + discriminator_loss=discriminator_loss_fn(gan_model)) + if mode == model_fn_lib.ModeKeys.EVAL: + estimator_spec = _get_eval_estimator_spec( + gan_model, gan_loss, get_eval_metric_ops_fn) + else: # model_fn_lib.ModeKeys.TRAIN: + gopt = (generator_optimizer() if callable(generator_optimizer) else + generator_optimizer) + dopt = (discriminator_optimizer() if callable(discriminator_optimizer) + else discriminator_optimizer) + get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks() + estimator_spec = _get_train_estimator_spec( + gan_model, gan_loss, gopt, dopt, get_hooks_fn) + + return estimator_spec def _make_gan_model(generator_fn, discriminator_fn, real_data, generator_inputs, generator_scope, add_summaries, mode): - """Make a `GANModel`, and optionally pass in `mode`.""" + """Construct a `GANModel`, and optionally pass in `mode`.""" # If network functions have an argument `mode`, pass mode to it. if 'mode' in inspect.getargspec(generator_fn).args: generator_fn = functools.partial(generator_fn, mode=mode) @@ -264,22 +260,6 @@ def _make_gan_model(generator_fn, discriminator_fn, real_data, return gan_model -def _make_train_gan_model(generator_fn, discriminator_fn, real_data, - generator_inputs, generator_scope, add_summaries): - """Make a `GANModel` for training.""" - return _make_gan_model(generator_fn, discriminator_fn, real_data, - generator_inputs, generator_scope, add_summaries, - model_fn_lib.ModeKeys.TRAIN) - - -def _make_eval_gan_model(generator_fn, discriminator_fn, real_data, - generator_inputs, generator_scope, add_summaries): - """Make a `GANModel` for evaluation.""" - return _make_gan_model(generator_fn, discriminator_fn, real_data, - generator_inputs, generator_scope, add_summaries, - model_fn_lib.ModeKeys.EVAL) - - def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope): """Make a `GANModel` from just the generator.""" # If `generator_fn` has an argument `mode`, pass mode to it. @@ -303,3 +283,46 @@ def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope): discriminator_variables=None, discriminator_scope=None, discriminator_fn=None) + + +def _get_eval_estimator_spec(gan_model, gan_loss, get_eval_metric_ops_fn=None, + name=None): + """Return an EstimatorSpec for the eval case.""" + scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + with ops.name_scope(None, 'metrics', + [gan_loss.generator_loss, + gan_loss.discriminator_loss]): + def _summary_key(head_name, val): + return '%s/%s' % (val, head_name) if head_name else val + eval_metric_ops = { + _summary_key(name, 'generator_loss'): + metrics_lib.mean(gan_loss.generator_loss), + _summary_key(name, 'discriminator_loss'): + metrics_lib.mean(gan_loss.discriminator_loss) + } + if get_eval_metric_ops_fn is not None: + custom_eval_metric_ops = get_eval_metric_ops_fn(gan_model) + if not isinstance(custom_eval_metric_ops, dict): + raise TypeError('get_eval_metric_ops_fn must return a dict, ' + 'received: {}'.format(custom_eval_metric_ops)) + eval_metric_ops.update(custom_eval_metric_ops) + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.EVAL, + predictions=gan_model.generated_data, + loss=scalar_loss, + eval_metric_ops=eval_metric_ops) + + +def _get_train_estimator_spec( + gan_model, gan_loss, generator_optimizer, discriminator_optimizer, + get_hooks_fn, train_op_fn=tfgan_train.gan_train_ops): + """Return an EstimatorSpec for the train case.""" + scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + train_ops = train_op_fn(gan_model, gan_loss, generator_optimizer, + discriminator_optimizer) + training_hooks = get_hooks_fn(train_ops) + return model_fn_lib.EstimatorSpec( + loss=scalar_loss, + mode=model_fn_lib.ModeKeys.TRAIN, + train_op=train_ops.global_step_inc_op, + training_hooks=training_hooks) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 955482599b372be3f0d0cbc81451c514958d0eb1..9ac9c6ca9ca86a8a9abe9c0f6ebc4cdf5dd2cfb1 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -21,30 +21,30 @@ from __future__ import print_function import shutil import tempfile +from absl.testing import parameterized import numpy as np import six from tensorflow.contrib import layers -from tensorflow.contrib.gan.python import namedtuples +from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl as estimator from tensorflow.contrib.gan.python.losses.python import tuple_losses as losses from tensorflow.contrib.learn.python.learn.learn_io import graph_io from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import input as input_lib from tensorflow.python.training import learning_rate_decay -from tensorflow.python.training import monitored_session from tensorflow.python.training import training from tensorflow.python.training import training_util @@ -60,120 +60,109 @@ def discriminator_fn(data, unused_conditioning, mode): return layers.fully_connected(data, 1) -def mock_head(testcase, expected_generator_inputs, expected_real_data, - generator_scope_name): - """Returns a mock head that validates logits values and variable names.""" - discriminator_scope_name = 'Discriminator' # comes from TFGAN defaults - generator_var_names = set([ - '%s/fully_connected/weights:0' % generator_scope_name, - '%s/fully_connected/biases:0' % generator_scope_name]) - discriminator_var_names = set([ - '%s/fully_connected/weights:0' % discriminator_scope_name, - '%s/fully_connected/biases:0' % discriminator_scope_name]) - - def _create_estimator_spec(features, mode, logits, labels): - gan_model = logits # renaming for clarity - is_predict = mode == model_fn_lib.ModeKeys.PREDICT - testcase.assertIsNone(features) - testcase.assertIsNone(labels) - testcase.assertIsInstance(gan_model, namedtuples.GANModel) - - trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) - expected_var_names = (generator_var_names if is_predict else - generator_var_names | discriminator_var_names) - testcase.assertItemsEqual(expected_var_names, - [var.name for var in trainable_vars]) - - assertions = [] - def _or_none(x): - return None if is_predict else x - testcase.assertEqual(expected_generator_inputs, gan_model.generator_inputs) - # TODO(joelshor): Add check on `generated_data`. - testcase.assertItemsEqual( - generator_var_names, - set([x.name for x in gan_model.generator_variables])) - testcase.assertEqual(generator_scope_name, gan_model.generator_scope.name) - testcase.assertEqual(_or_none(expected_real_data), gan_model.real_data) - # TODO(joelshor): Add check on `discriminator_real_outputs`. - # TODO(joelshor): Add check on `discriminator_gen_outputs`. - if is_predict: - testcase.assertIsNone(gan_model.discriminator_scope) - else: - testcase.assertEqual(discriminator_scope_name, - gan_model.discriminator_scope.name) - - with ops.control_dependencies(assertions): - if mode == model_fn_lib.ModeKeys.TRAIN: - return model_fn_lib.EstimatorSpec( - mode=mode, loss=array_ops.zeros([]), - train_op=control_flow_ops.no_op(), training_hooks=[]) - elif mode == model_fn_lib.ModeKeys.EVAL: - return model_fn_lib.EstimatorSpec( - mode=mode, predictions=gan_model.generated_data, - loss=array_ops.zeros([])) - elif mode == model_fn_lib.ModeKeys.PREDICT: - return model_fn_lib.EstimatorSpec( - mode=mode, predictions=gan_model.generated_data) - else: - testcase.fail('Invalid mode: {}'.format(mode)) - - head = test.mock.NonCallableMagicMock(spec=head_lib._Head) - head.create_estimator_spec = test.mock.MagicMock( - wraps=_create_estimator_spec) - - return head - - -class GANModelFnTest(test.TestCase): - """Tests that _gan_model_fn passes expected logits to mock head.""" - - def setUp(self): - self._model_dir = tempfile.mkdtemp() - - def tearDown(self): - if self._model_dir: - writer_cache.FileWriterCache.clear() - shutil.rmtree(self._model_dir) +class GetGANModelTest(test.TestCase, parameterized.TestCase): + """Tests that `GetGANModel` produces the correct model.""" - def _test_logits_helper(self, mode): - """Tests that the expected logits are passed to mock head.""" + @parameterized.named_parameters( + ('train', model_fn_lib.ModeKeys.TRAIN), + ('eval', model_fn_lib.ModeKeys.EVAL), + ('predict', model_fn_lib.ModeKeys.PREDICT)) + def test_get_gan_model(self, mode): with ops.Graph().as_default(): - training_util.get_or_create_global_step() - generator_inputs = {'x': array_ops.zeros([5, 4])} - real_data = (None if mode == model_fn_lib.ModeKeys.PREDICT else - array_ops.zeros([5, 4])) - generator_scope_name = 'generator' - head = mock_head(self, - expected_generator_inputs=generator_inputs, - expected_real_data=real_data, - generator_scope_name=generator_scope_name) - estimator_spec = estimator._gan_model_fn( - features=generator_inputs, - labels=real_data, - mode=mode, - generator_fn=generator_fn, - discriminator_fn=discriminator_fn, - generator_scope_name=generator_scope_name, - head=head) - with monitored_session.MonitoredTrainingSession( - checkpoint_dir=self._model_dir) as sess: - if mode == model_fn_lib.ModeKeys.TRAIN: - sess.run(estimator_spec.train_op) - elif mode == model_fn_lib.ModeKeys.EVAL: - sess.run(estimator_spec.loss) - elif mode == model_fn_lib.ModeKeys.PREDICT: - sess.run(estimator_spec.predictions) - else: - self.fail('Invalid mode: {}'.format(mode)) - - def test_logits_predict(self): - self._test_logits_helper(model_fn_lib.ModeKeys.PREDICT) - - def test_logits_eval(self): - self._test_logits_helper(model_fn_lib.ModeKeys.EVAL) - - def test_logits_train(self): - self._test_logits_helper(model_fn_lib.ModeKeys.TRAIN) + generator_inputs = {'x': array_ops.ones([3, 4])} + real_data = (array_ops.zeros([3, 4]) if + mode != model_fn_lib.ModeKeys.PREDICT else None) + gan_model = estimator._get_gan_model( + mode, generator_fn, discriminator_fn, real_data, generator_inputs, + add_summaries=False) + + self.assertEqual(generator_inputs, gan_model.generator_inputs) + self.assertIsNotNone(gan_model.generated_data) + self.assertEqual(2, len(gan_model.generator_variables)) # 1 FC layer + self.assertIsNotNone(gan_model.generator_fn) + if mode == model_fn_lib.ModeKeys.PREDICT: + self.assertIsNone(gan_model.real_data) + self.assertIsNone(gan_model.discriminator_real_outputs) + self.assertIsNone(gan_model.discriminator_gen_outputs) + self.assertIsNone(gan_model.discriminator_variables) + self.assertIsNone(gan_model.discriminator_scope) + self.assertIsNone(gan_model.discriminator_fn) + else: + self.assertIsNotNone(gan_model.real_data) + self.assertIsNotNone(gan_model.discriminator_real_outputs) + self.assertIsNotNone(gan_model.discriminator_gen_outputs) + self.assertEqual(2, len(gan_model.discriminator_variables)) # 1 FC layer + self.assertIsNotNone(gan_model.discriminator_scope) + self.assertIsNotNone(gan_model.discriminator_fn) + + +def get_dummy_gan_model(): + # TODO(joelshor): Find a better way of creating a variable scope. + with variable_scope.variable_scope('generator') as gen_scope: + gen_var = variable_scope.get_variable('dummy_var', initializer=0.0) + with variable_scope.variable_scope('discriminator') as dis_scope: + dis_var = variable_scope.get_variable('dummy_var', initializer=0.0) + return tfgan_tuples.GANModel( + generator_inputs=None, + generated_data=array_ops.ones([3, 4]), + generator_variables=[gen_var], + generator_scope=gen_scope, + generator_fn=None, + real_data=array_ops.zeros([3, 4]), + discriminator_real_outputs=array_ops.ones([1, 2, 3]) * dis_var, + discriminator_gen_outputs=array_ops.ones([1, 2, 3]) * gen_var * dis_var, + discriminator_variables=[dis_var], + discriminator_scope=dis_scope, + discriminator_fn=None) + + +def dummy_loss_fn(gan_model): + return math_ops.reduce_sum(gan_model.discriminator_real_outputs - + gan_model.discriminator_gen_outputs) + + +def get_metrics(gan_model): + return { + 'mse_custom_metric': metrics_lib.mean_squared_error( + gan_model.real_data, gan_model.generated_data) + } + + +class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase): + """Tests that the EstimatorSpec is constructed appropriately.""" + + @classmethod + def setUpClass(cls): + cls._generator_optimizer = training.GradientDescentOptimizer(1.0) + cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0) + + @parameterized.named_parameters( + ('train', model_fn_lib.ModeKeys.TRAIN), + ('eval', model_fn_lib.ModeKeys.EVAL), + ('predict', model_fn_lib.ModeKeys.PREDICT)) + def test_get_estimator_spec(self, mode): + with ops.Graph().as_default(): + self._gan_model = get_dummy_gan_model() + spec = estimator._get_estimator_spec( + mode, + self._gan_model, + generator_loss_fn=dummy_loss_fn, + discriminator_loss_fn=dummy_loss_fn, + get_eval_metric_ops_fn=get_metrics, + generator_optimizer=self._generator_optimizer, + discriminator_optimizer=self._discriminator_optimizer) + + self.assertEqual(mode, spec.mode) + if mode == model_fn_lib.ModeKeys.PREDICT: + self.assertEqual(self._gan_model.generated_data, spec.predictions) + elif mode == model_fn_lib.ModeKeys.TRAIN: + self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar + self.assertIsNotNone(spec.train_op) + self.assertIsNotNone(spec.training_hooks) + elif mode == model_fn_lib.ModeKeys.EVAL: + self.assertEqual(self._gan_model.generated_data, spec.predictions) + self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar + self.assertIsNotNone(spec.eval_metric_ops) # TODO(joelshor): Add pandas test. @@ -195,12 +184,6 @@ class GANEstimatorIntegrationTest(test.TestCase): lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9) return training.GradientDescentOptimizer(lr) - def get_metrics(gan_model): - return { - 'mse_custom_metric': metrics_lib.mean_squared_error( - gan_model.real_data, gan_model.generated_data) - } - gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) est = estimator.GANEstimator( diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index d1441e1eb2aae0fb7d1771110f969bf727ebbb14..1a0ee6dfc498eb6dc8c97411589d9e35bc352062 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -27,16 +27,21 @@ from tensorflow.python.estimator.canned import head from tensorflow.python.estimator.export import export_output from tensorflow.python.framework import ops from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.util import deprecation __all__ = [ 'GANHead', 'gan_head', ] + def _summary_key(head_name, val): return '%s/%s' % (val, head_name) if head_name else val +@deprecation.deprecated( + None, 'Please use tf.contrib.gan.GANEstimator without explicitly making a ' + 'GANHead.') def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer, use_loss_summaries=True, get_hooks_fn=tfgan_train.get_sequential_train_hooks(), @@ -77,6 +82,9 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, class GANHead(head._Head): # pylint: disable=protected-access """`Head` for a GAN.""" + @deprecation.deprecated( + None, 'Please use tf.contrib.gan.GANEstimator without explicitly making ' + 'a GANHead.') def __init__(self, generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer, use_loss_summaries=True, @@ -108,7 +116,7 @@ class GANHead(head._Head): # pylint: disable=protected-access raise TypeError('generator_loss_fn must be callable.') if not callable(discriminator_loss_fn): raise TypeError('discriminator_loss_fn must be callable.') - if not use_loss_summaries in [True, False, None]: + if use_loss_summaries not in [True, False, None]: raise ValueError('use_loss_summaries must be True, False or None.') if get_hooks_fn is not None and not callable(get_hooks_fn): raise TypeError('get_hooks_fn must be callable.') diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py index 5309d87765694fa476dae006105e842420a7c437..8205bc889dc01c8680e2139393d65723280cfbd0 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py @@ -67,7 +67,7 @@ class GANHeadTest(test.TestCase): generator_optimizer=training.GradientDescentOptimizer(1.0), discriminator_optimizer=training.GradientDescentOptimizer(1.0), get_eval_metric_ops_fn=self.get_metrics) - self.assertTrue(isinstance(self.gan_head, head.GANHead)) + self.assertIsInstance(self.gan_head, head.GANHead) def get_metrics(self, gan_model): self.assertTrue(isinstance(gan_model, tfgan_tuples.GANModel)) diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py index 9e4ec59e7098443efc53506a4ba159e84b5c1618..ca2d724b49db25191b5744e10b48c66b6bdeb120 100644 --- a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py @@ -36,16 +36,15 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import random_ops +from tensorflow.python.util import nest __all__ = [ 'tensor_pool', ] -def _to_tuple(x): - if isinstance(x, (list, tuple)): - return tuple(x) - return (x,) +def _to_list(x): + return [x] if isinstance(x, ops.Tensor) else list(x) def tensor_pool(input_values, @@ -63,8 +62,8 @@ def tensor_pool(input_values, `pool_size` = 0 or `pooling_probability` = 0. Args: - input_values: A `Tensor`, or a list or tuple of `Tensor`s from which to read - values to be pooled. + input_values: An arbitrarily nested structure of `tf.Tensors`, from which to + read values to be pooled. pool_size: An integer specifying the maximum size of the pool. Defaults to 50. pooling_probability: A float `Tensor` specifying the probability of getting @@ -72,9 +71,10 @@ def tensor_pool(input_values, name: A string prefix for the name scope for all tensorflow ops. Returns: - A `Tensor`, or a list or tuple of `Tensor`s (according to the type ofx - `input_values`) which is with given probability either the `input_values` or - a randomly chosen sample that was previously inserted in the pool. + A nested structure of `Tensor` objects with the same structure as + `input_values`. With the given probability, the Tensor values are either the + same as in `input_values` or a randomly chosen sample that was previously + inserted in the pool. Raises: ValueError: If `pool_size` is negative. @@ -86,11 +86,10 @@ def tensor_pool(input_values, return input_values original_input_values = input_values - input_values = _to_tuple(input_values) + input_values = nest.flatten(input_values) - with ops.name_scope( - '{}_pool_queue'.format(name), - values=input_values + (pooling_probability,)): + with ops.name_scope('{}_pool_queue'.format(name), + values=input_values + [pooling_probability]): pool_queue = data_flow_ops.RandomShuffleQueue( capacity=pool_size, min_after_dequeue=0, @@ -112,10 +111,10 @@ def tensor_pool(input_values, def _get_input_value_pooled(): enqueue_op = pool_queue.enqueue(input_values) with ops.control_dependencies([enqueue_op]): - return tuple(array_ops.identity(v) for v in input_values) + return [array_ops.identity(v) for v in input_values] def _get_random_pool_value_and_enqueue_input(): - dequeue_values = _to_tuple(pool_queue.dequeue()) + dequeue_values = _to_list(pool_queue.dequeue()) with ops.control_dependencies(dequeue_values): enqueue_op = pool_queue.enqueue(input_values) with ops.control_dependencies([enqueue_op]): @@ -124,7 +123,7 @@ def tensor_pool(input_values, return control_flow_ops.cond(prob, lambda: dequeue_values, lambda: input_values) - output_values = _to_tuple(control_flow_ops.cond( + output_values = _to_list(control_flow_ops.cond( pool_queue.size() < pool_size, _get_input_value_pooled, _get_random_pool_value_and_enqueue_input)) @@ -132,8 +131,4 @@ def tensor_pool(input_values, for input_value, output_value in zip(input_values, output_values): output_value.set_shape(input_value.shape) - if isinstance(original_input_values, list): - return list(output_values) - elif isinstance(original_input_values, tuple): - return output_values - return output_values[0] + return nest.pack_sequence_as(original_input_values, output_values) diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py index d8cf549cf71838178c9da01df462d41d81595fe5..08584dcd656e3e7a079a3fa36f44742b5eac1178 100644 --- a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py @@ -21,7 +21,9 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.gan.python.features.python.random_tensor_pool_impl import tensor_pool +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -111,6 +113,23 @@ class TensorPoolTest(test.TestCase): self.assertEqual(len(outs), len(input_values)) self.assertEqual(outs[1] - outs[0], 1) + def test_pool_preserves_shape(self): + t = constant_op.constant(1) + input_values = [[t, t, t], (t, t), t] + output_values = tensor_pool(input_values, pool_size=5) + print('stuff: ', output_values) + # Overall shape. + self.assertIsInstance(output_values, list) + self.assertEqual(3, len(output_values)) + # Shape of first element. + self.assertIsInstance(output_values[0], list) + self.assertEqual(3, len(output_values[0])) + # Shape of second element. + self.assertIsInstance(output_values[1], tuple) + self.assertEqual(2, len(output_values[1])) + # Shape of third element. + self.assertIsInstance(output_values[2], ops.Tensor) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py index 1ba3a641671c7f2a411a0c5f99228ca16eee1080..d3897483740faafa62befbaf873886139f1482d2 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -949,6 +949,11 @@ def cycle_consistency_loss(data_x, * loss = (loss_x2x + loss_y2y) / 2 where `loss` is the final result. + For the L1-norm, we follow the original implementation: + https://github.com/junyanz/CycleGAN/blob/master/models/cycle_gan_model.lua + we use L1-norm of pixel-wise error normalized by data size such that + `cycle_loss_weight` can be specified independent of image size. + See https://arxiv.org/abs/1703.10593 for more details. Args: @@ -965,19 +970,12 @@ def cycle_consistency_loss(data_x, A scalar `Tensor` of cycle consistency loss. """ - def _partial_cycle_consistency_loss(data, reconstructed_data): - # Following the original implementation - # https://github.com/junyanz/CycleGAN/blob/master/models/cycle_gan_model.lua - # use L1-norm of pixel-wise error normalized by data size so that - # `cycle_loss_weight` can be specified independent of image size. - return math_ops.reduce_mean(math_ops.abs(data - reconstructed_data)) - with ops.name_scope( scope, 'cycle_consistency_loss', values=[data_x, reconstructed_data_x, data_y, reconstructed_data_y]): - loss_x2x = _partial_cycle_consistency_loss(data_x, reconstructed_data_x) - loss_y2y = _partial_cycle_consistency_loss(data_y, reconstructed_data_y) + loss_x2x = losses.absolute_difference(data_x, reconstructed_data_x) + loss_y2y = losses.absolute_difference(data_y, reconstructed_data_y) loss = (loss_x2x + loss_y2y) / 2.0 if add_summaries: summary.scalar('cycle_consistency_loss_x2x', loss_x2x) diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py index dcc3f94c2d6b9e5e44036e7cc1a9d1bb39104fb5..221c70c38bd432a6be7f6cda9c6700aa2255821f 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py @@ -80,6 +80,9 @@ __all__ = [ 'mutual_information_penalty', 'combine_adversarial_loss', 'cycle_consistency_loss', + 'stargan_generator_loss_wrapper', + 'stargan_discriminator_loss_wrapper', + 'stargan_gradient_penalty_wrapper' ] @@ -277,3 +280,86 @@ def cycle_consistency_loss(cyclegan_model, scope=None, add_summaries=False): cyclegan_model.model_x2y.generator_inputs, cyclegan_model.reconstructed_x, cyclegan_model.model_y2x.generator_inputs, cyclegan_model.reconstructed_y, scope, add_summaries) + + +def stargan_generator_loss_wrapper(loss_fn): + """Convert a generator loss function to take a StarGANModel. + + The new function has the same name as the original one. + + Args: + loss_fn: A python function taking Discriminator's real/fake prediction for + generated data. + + Returns: + A new function that takes a StarGANModel namedtuple and returns the same + loss. + """ + + def new_loss_fn(stargan_model, **kwargs): + return loss_fn( + stargan_model.discriminator_generated_data_source_predication, **kwargs) + + new_docstring = """The stargan_model version of %s.""" % loss_fn.__name__ + new_loss_fn.__docstring__ = new_docstring + new_loss_fn.__name__ = loss_fn.__name__ + new_loss_fn.__module__ = loss_fn.__module__ + return new_loss_fn + + +def stargan_discriminator_loss_wrapper(loss_fn): + """Convert a discriminator loss function to take a StarGANModel. + + The new function has the same name as the original one. + + Args: + loss_fn: A python function taking Discriminator's real/fake prediction for + real data and generated data. + + Returns: + A new function that takes a StarGANModel namedtuple and returns the same + loss. + """ + + def new_loss_fn(stargan_model, **kwargs): + return loss_fn( + stargan_model.discriminator_input_data_source_predication, + stargan_model.discriminator_generated_data_source_predication, **kwargs) + + new_docstring = """The stargan_model version of %s.""" % loss_fn.__name__ + new_loss_fn.__docstring__ = new_docstring + new_loss_fn.__name__ = loss_fn.__name__ + new_loss_fn.__module__ = loss_fn.__module__ + return new_loss_fn + + +def stargan_gradient_penalty_wrapper(loss_fn): + """Convert a gradient penalty function to take a StarGANModel. + + The new function has the same name as the original one. + + Args: + loss_fn: A python function taking real_data, generated_data, + generator_inputs for Discriminator's condition (i.e. number of domains), + discriminator_fn, and discriminator_scope. + + Returns: + A new function that takes a StarGANModel namedtuple and returns the same + loss. + """ + + def new_loss_fn(stargan_model, **kwargs): + num_domains = stargan_model.input_data_domain_label.shape.as_list()[-1] + return loss_fn( + real_data=stargan_model.input_data, + generated_data=stargan_model.generated_data, + generator_inputs=num_domains, + discriminator_fn=stargan_model.discriminator_fn, + discriminator_scope=stargan_model.discriminator_scope, + **kwargs) + + new_docstring = """The stargan_model version of %s.""" % loss_fn.__name__ + new_loss_fn.__docstring__ = new_docstring + new_loss_fn.__name__ = loss_fn.__name__ + new_loss_fn.__module__ = loss_fn.__module__ + return new_loss_fn diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py index aa1ef11172dee6799994b87f70a3883cd67fd15b..a559bbfa11367afd7dfe6a72d2ce2cc9d7ba1f16 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py @@ -22,10 +22,15 @@ import collections import numpy as np +from tensorflow.contrib import layers from tensorflow.contrib.gan.python import namedtuples +from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses_impl from tensorflow.contrib.gan.python.losses.python import tuple_losses_impl as tfgan_losses from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -129,6 +134,9 @@ manual_tests = [ 'mutual_information_penalty', 'wasserstein_gradient_penalty', 'cycle_consistency_loss', + 'stargan_generator_loss_wrapper', + 'stargan_discriminator_loss_wrapper', + 'stargan_gradient_penalty_wrapper' ] discriminator_keyword_args = { @@ -175,6 +183,112 @@ class CycleConsistencyLossTest(test.TestCase): self.assertNear(5.0, loss.eval(), 1e-5) +class StarGANLossWrapperTest(test.TestCase): + + def setUp(self): + + super(StarGANLossWrapperTest, self).setUp() + + self.input_data = array_ops.ones([1, 2, 2, 3]) + self.input_data_domain_label = constant_op.constant([[0, 1]]) + self.generated_data = array_ops.ones([1, 2, 2, 3]) + self.discriminator_input_data_source_predication = array_ops.ones([1]) + self.discriminator_generated_data_source_predication = array_ops.ones([1]) + + def _discriminator_fn(inputs, num_domains): + """Differentiable dummy discriminator for StarGAN.""" + hidden = layers.flatten(inputs) + output_src = math_ops.reduce_mean(hidden, axis=1) + output_cls = layers.fully_connected( + inputs=hidden, + num_outputs=num_domains, + activation_fn=None, + normalizer_fn=None, + biases_initializer=None) + return output_src, output_cls + + with variable_scope.variable_scope('discriminator') as dis_scope: + pass + + self.model = namedtuples.StarGANModel( + input_data=self.input_data, + input_data_domain_label=self.input_data_domain_label, + generated_data=self.generated_data, + generated_data_domain_target=None, + reconstructed_data=None, + discriminator_input_data_source_predication=self. + discriminator_input_data_source_predication, + discriminator_generated_data_source_predication=self. + discriminator_generated_data_source_predication, + discriminator_input_data_domain_predication=None, + discriminator_generated_data_domain_predication=None, + generator_variables=None, + generator_scope=None, + generator_fn=None, + discriminator_variables=None, + discriminator_scope=dis_scope, + discriminator_fn=_discriminator_fn) + + self.discriminator_fn = _discriminator_fn + self.discriminator_scope = dis_scope + + def test_stargan_generator_loss_wrapper(self): + """Test StarGAN generator loss wrapper.""" + loss_fn = tfgan_losses_impl.wasserstein_generator_loss + wrapped_loss_fn = tfgan_losses.stargan_generator_loss_wrapper(loss_fn) + + loss_result_tensor = loss_fn( + self.discriminator_generated_data_source_predication) + wrapped_loss_result_tensor = wrapped_loss_fn(self.model) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + loss_result, wrapped_loss_result = sess.run( + [loss_result_tensor, wrapped_loss_result_tensor]) + self.assertAlmostEqual(loss_result, wrapped_loss_result) + + def test_stargan_discriminator_loss_wrapper(self): + """Test StarGAN discriminator loss wrapper.""" + loss_fn = tfgan_losses_impl.wasserstein_discriminator_loss + wrapped_loss_fn = tfgan_losses.stargan_discriminator_loss_wrapper(loss_fn) + + loss_result_tensor = loss_fn( + self.discriminator_generated_data_source_predication, + self.discriminator_generated_data_source_predication) + wrapped_loss_result_tensor = wrapped_loss_fn(self.model) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + loss_result, wrapped_loss_result = sess.run( + [loss_result_tensor, wrapped_loss_result_tensor]) + self.assertAlmostEqual(loss_result, wrapped_loss_result) + + def test_stargan_gradient_penalty_wrapper(self): + """Test StaGAN gradient penalty wrapper. + + Notes: + The random interpolates are handled by given setting the reconstruction to + be the same as the input. + + """ + loss_fn = tfgan_losses_impl.wasserstein_gradient_penalty + wrapped_loss_fn = tfgan_losses.stargan_gradient_penalty_wrapper(loss_fn) + + loss_result_tensor = loss_fn( + real_data=self.input_data, + generated_data=self.generated_data, + generator_inputs=self.input_data_domain_label.shape.as_list()[-1], + discriminator_fn=self.discriminator_fn, + discriminator_scope=self.discriminator_scope) + wrapped_loss_result_tensor = wrapped_loss_fn(self.model) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + loss_result, wrapped_loss_result = sess.run( + [loss_result_tensor, wrapped_loss_result_tensor]) + self.assertAlmostEqual(loss_result, wrapped_loss_result) + + if __name__ == '__main__': for loss_name in tfgan_losses.__all__: if loss_name in manual_tests: continue diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py index 25cfeafeec9000b0dc3849ebe646e59c1b4d1cc3..a462b68e28be989eee04fe4ec5ee902d75e5d909 100644 --- a/tensorflow/contrib/gan/python/namedtuples.py +++ b/tensorflow/contrib/gan/python/namedtuples.py @@ -25,12 +25,12 @@ from __future__ import print_function import collections - __all__ = [ 'GANModel', 'InfoGANModel', 'ACGANModel', 'CycleGANModel', + 'StarGANModel', 'GANLoss', 'CycleGANLoss', 'GANTrainOps', @@ -136,6 +136,54 @@ class CycleGANModel( """ +class StarGANModel( + collections.namedtuple('StarGANModel', ( + 'input_data', + 'input_data_domain_label', + 'generated_data', + 'generated_data_domain_target', + 'reconstructed_data', + 'discriminator_input_data_source_predication', + 'discriminator_generated_data_source_predication', + 'discriminator_input_data_domain_predication', + 'discriminator_generated_data_domain_predication', + 'generator_variables', + 'generator_scope', + 'generator_fn', + 'discriminator_variables', + 'discriminator_scope', + 'discriminator_fn', + ))): + """A StarGANModel contains all the pieces needed for StarGAN training. + + Args: + input_data: The real images that need to be transferred by the generator. + input_data_domain_label: The real domain labels associated with the real + images. + generated_data: The generated images produced by the generator. It has the + same shape as the input_data. + generated_data_domain_target: The target domain that the generated images + belong to. It has the same shape as the input_data_domain_label. + reconstructed_data: The reconstructed images produced by the G(enerator). + reconstructed_data = G(G(input_data, generated_data_domain_target), + input_data_domain_label). + discriminator_input_data_source: The discriminator's output for predicting + the source (real/generated) of input_data. + discriminator_generated_data_source: The discriminator's output for + predicting the source (real/generated) of generated_data. + discriminator_input_data_domain_predication: The discriminator's output for + predicting the domain_label for the input_data. + discriminator_generated_data_domain_predication: The discriminatorr's output + for predicting the domain_target for the generated_data. + generator_variables: A list of all generator variables. + generator_scope: Variable scope all generator variables live in. + generator_fn: The generator function. + discriminator_variables: A list of all discriminator variables. + discriminator_scope: Variable scope all discriminator variables live in. + discriminator_fn: The discriminator function. + """ + + class GANLoss( collections.namedtuple('GANLoss', ( 'generator_loss', diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index 6fa43059f3125daea080f780210223363d0a89f9..03f52d214b5ac2fef075fb66018f88d2be5c1941 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -34,15 +34,20 @@ from __future__ import print_function from tensorflow.contrib.framework.python.ops import variables as variables_lib from tensorflow.contrib.gan.python import losses as tfgan_losses from tensorflow.contrib.gan.python import namedtuples +from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses_impl from tensorflow.contrib.slim.python.slim import learning as slim_learning from tensorflow.contrib.training.python.training import training +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops.distributions import distribution as ds from tensorflow.python.ops.losses import losses +from tensorflow.python.summary import summary from tensorflow.python.training import session_run_hook from tensorflow.python.training import sync_replicas_optimizer from tensorflow.python.training import training_util @@ -53,6 +58,7 @@ __all__ = [ 'infogan_model', 'acgan_model', 'cyclegan_model', + 'stargan_model', 'gan_loss', 'cyclegan_loss', 'gan_train_ops', @@ -123,16 +129,9 @@ def gan_model( discriminator_variables = variables_lib.get_trainable_variables(dis_scope) return namedtuples.GANModel( - generator_inputs, - generated_data, - generator_variables, - gen_scope, - generator_fn, - real_data, - discriminator_real_outputs, - discriminator_gen_outputs, - discriminator_variables, - dis_scope, + generator_inputs, generated_data, generator_variables, gen_scope, + generator_fn, real_data, discriminator_real_outputs, + discriminator_gen_outputs, discriminator_variables, dis_scope, discriminator_fn) @@ -201,8 +200,7 @@ def infogan_model( # Get model-specific variables. generator_variables = variables_lib.get_trainable_variables(gen_scope) - discriminator_variables = variables_lib.get_trainable_variables( - disc_scope) + discriminator_variables = variables_lib.get_trainable_variables(disc_scope) return namedtuples.InfoGANModel( generator_inputs, @@ -279,12 +277,12 @@ def acgan_model( generator_inputs = _convert_tensor_or_l_or_d(generator_inputs) generated_data = generator_fn(generator_inputs) with variable_scope.variable_scope(discriminator_scope) as dis_scope: - with ops.name_scope(dis_scope.name+'/generated/'): + with ops.name_scope(dis_scope.name + '/generated/'): (discriminator_gen_outputs, discriminator_gen_classification_logits ) = _validate_acgan_discriminator_outputs( discriminator_fn(generated_data, generator_inputs)) with variable_scope.variable_scope(dis_scope, reuse=True): - with ops.name_scope(dis_scope.name+'/real/'): + with ops.name_scope(dis_scope.name + '/real/'): real_data = ops.convert_to_tensor(real_data) (discriminator_real_outputs, discriminator_real_classification_logits ) = _validate_acgan_discriminator_outputs( @@ -297,8 +295,7 @@ def acgan_model( # Get model-specific variables. generator_variables = variables_lib.get_trainable_variables(gen_scope) - discriminator_variables = variables_lib.get_trainable_variables( - dis_scope) + discriminator_variables = variables_lib.get_trainable_variables(dis_scope) return namedtuples.ACGANModel( generator_inputs, generated_data, generator_variables, gen_scope, @@ -379,6 +376,108 @@ def cyclegan_model( reconstructed_y) +def stargan_model(generator_fn, + discriminator_fn, + input_data, + input_data_domain_label, + generator_scope='Generator', + discriminator_scope='Discriminator'): + """Returns a StarGAN model outputs and variables. + + See https://arxiv.org/abs/1711.09020 for more details. + + Args: + generator_fn: A python lambda that takes `inputs` and `targets` as inputs + and returns 'generated_data' as the transformed version of `input` based + on the `target`. `input` has shape (n, h, w, c), `targets` has shape (n, + num_domains), and `generated_data` has the same shape as `input`. + discriminator_fn: A python lambda that takes `inputs` and `num_domains` as + inputs and returns a tuple (`source_prediction`, `domain_prediction`). + `source_prediction` represents the source(real/generated) prediction by + the discriminator, and `domain_prediction` represents the domain + prediction/classification by the discriminator. `source_prediction` has + shape (n) and `domain_prediction` has shape (n, num_domains). + input_data: Tensor or a list of tensor of shape (n, h, w, c) representing + the real input images. + input_data_domain_label: Tensor or a list of tensor of shape (batch_size, + num_domains) representing the domain label associated with the real + images. + generator_scope: Optional generator variable scope. Useful if you want to + reuse a subgraph that has already been created. + discriminator_scope: Optional discriminator variable scope. Useful if you + want to reuse a subgraph that has already been created. + + Returns: + StarGANModel nametuple return the tensor that are needed to compute the + loss. + + Raises: + ValueError: If the shape of `input_data_domain_label` is not rank 2 or fully + defined in every dimensions. + """ + + # Convert to tensor. + input_data = _convert_tensor_or_l_or_d(input_data) + input_data_domain_label = _convert_tensor_or_l_or_d(input_data_domain_label) + + # Convert list of tensor to a single tensor if applicable. + if isinstance(input_data, (list, tuple)): + input_data = array_ops.concat( + [ops.convert_to_tensor(x) for x in input_data], 0) + if isinstance(input_data_domain_label, (list, tuple)): + input_data_domain_label = array_ops.concat( + [ops.convert_to_tensor(x) for x in input_data_domain_label], 0) + + # Get batch_size, num_domains from the labels. + input_data_domain_label.shape.assert_has_rank(2) + input_data_domain_label.shape.assert_is_fully_defined() + batch_size, num_domains = input_data_domain_label.shape.as_list() + + # Transform input_data to random target domains. + with variable_scope.variable_scope(generator_scope) as generator_scope: + generated_data_domain_target = _generate_stargan_random_domain_target( + batch_size, num_domains) + generated_data = generator_fn(input_data, generated_data_domain_target) + + # Transform generated_data back to the original input_data domain. + with variable_scope.variable_scope(generator_scope, reuse=True): + reconstructed_data = generator_fn(generated_data, input_data_domain_label) + + # Predict source and domain for the generated_data using the discriminator. + with variable_scope.variable_scope( + discriminator_scope) as discriminator_scope: + disc_gen_data_source_pred, disc_gen_data_domain_pred = discriminator_fn( + generated_data, num_domains) + + # Predict source and domain for the input_data using the discriminator. + with variable_scope.variable_scope(discriminator_scope, reuse=True): + disc_input_data_source_pred, disc_input_data_domain_pred = discriminator_fn( + input_data, num_domains) + + # Collect trainable variables from the neural networks. + generator_variables = variables_lib.get_trainable_variables(generator_scope) + discriminator_variables = variables_lib.get_trainable_variables( + discriminator_scope) + + # Create the StarGANModel namedtuple. + return namedtuples.StarGANModel( + input_data=input_data, + input_data_domain_label=input_data_domain_label, + generated_data=generated_data, + generated_data_domain_target=generated_data_domain_target, + reconstructed_data=reconstructed_data, + discriminator_input_data_source_predication=disc_input_data_source_pred, + discriminator_generated_data_source_predication=disc_gen_data_source_pred, + discriminator_input_data_domain_predication=disc_input_data_domain_pred, + discriminator_generated_data_domain_predication=disc_gen_data_domain_pred, + generator_variables=generator_variables, + generator_scope=generator_scope, + generator_fn=generator_fn, + discriminator_variables=discriminator_variables, + discriminator_scope=discriminator_scope, + discriminator_fn=discriminator_fn) + + def _validate_aux_loss_weight(aux_loss_weight, name='aux_loss_weight'): if isinstance(aux_loss_weight, ops.Tensor): aux_loss_weight.shape.assert_is_compatible_with([]) @@ -419,33 +518,42 @@ def _tensor_pool_adjusted_model(model, tensor_pool_fn): Raises: ValueError: If tensor pool does not support the `model`. """ - if tensor_pool_fn is None: - return model - - pooled_generated_data, pooled_generator_inputs = tensor_pool_fn( - (model.generated_data, model.generator_inputs)) - if isinstance(model, namedtuples.GANModel): + pooled_generator_inputs, pooled_generated_data = tensor_pool_fn( + (model.generator_inputs, model.generated_data)) with variable_scope.variable_scope(model.discriminator_scope, reuse=True): dis_gen_outputs = model.discriminator_fn(pooled_generated_data, pooled_generator_inputs) - return model._replace(discriminator_gen_outputs=dis_gen_outputs) + return model._replace( + generator_inputs=pooled_generator_inputs, + generated_data=pooled_generated_data, + discriminator_gen_outputs=dis_gen_outputs) elif isinstance(model, namedtuples.ACGANModel): + pooled_generator_inputs, pooled_generated_data = tensor_pool_fn( + (model.generator_inputs, model.generated_data)) with variable_scope.variable_scope(model.discriminator_scope, reuse=True): - (dis_pooled_gen_outputs, - dis_pooled_gen_classification_logits) = model.discriminator_fn( + (pooled_discriminator_gen_outputs, + pooled_discriminator_gen_classification_logits) = model.discriminator_fn( pooled_generated_data, pooled_generator_inputs) return model._replace( - discriminator_gen_outputs=dis_pooled_gen_outputs, + generator_inputs=pooled_generator_inputs, + generated_data=pooled_generated_data, + discriminator_gen_outputs=pooled_discriminator_gen_outputs, discriminator_gen_classification_logits= - dis_pooled_gen_classification_logits) + pooled_discriminator_gen_classification_logits) elif isinstance(model, namedtuples.InfoGANModel): + pooled_generator_inputs, pooled_generated_data, pooled_structured_input = ( + tensor_pool_fn((model.generator_inputs, model.generated_data, + model.structured_generator_inputs))) with variable_scope.variable_scope(model.discriminator_scope, reuse=True): - (dis_pooled_gen_outputs, + (pooled_discriminator_gen_outputs, pooled_predicted_distributions) = model.discriminator_and_aux_fn( pooled_generated_data, pooled_generator_inputs) return model._replace( - discriminator_gen_outputs=dis_pooled_gen_outputs, + generator_inputs=pooled_generator_inputs, + generated_data=pooled_generated_data, + structured_generator_inputs=pooled_structured_input, + discriminator_gen_outputs=pooled_discriminator_gen_outputs, predicted_distributions=pooled_predicted_distributions) else: raise ValueError('Tensor pool does not support `model`: %s.' % type(model)) @@ -512,8 +620,8 @@ def gan_loss( `model` isn't an `InfoGANModel`. """ # Validate arguments. - gradient_penalty_weight = _validate_aux_loss_weight(gradient_penalty_weight, - 'gradient_penalty_weight') + gradient_penalty_weight = _validate_aux_loss_weight( + gradient_penalty_weight, 'gradient_penalty_weight') mutual_information_penalty_weight = _validate_aux_loss_weight( mutual_information_penalty_weight, 'infogan_weight') aux_cond_generator_weight = _validate_aux_loss_weight( @@ -537,33 +645,38 @@ def gan_loss( 'is provided, `model` must be an `ACGANModel`. Instead, was %s.' % type(model)) + # Optionally create pooled model. + pooled_model = (_tensor_pool_adjusted_model(model, tensor_pool_fn) if + tensor_pool_fn else model) + # Create standard losses. gen_loss = generator_loss_fn(model, add_summaries=add_summaries) - dis_loss = discriminator_loss_fn( - _tensor_pool_adjusted_model(model, tensor_pool_fn), - add_summaries=add_summaries) + dis_loss = discriminator_loss_fn(pooled_model, add_summaries=add_summaries) # Add optional extra losses. if _use_aux_loss(gradient_penalty_weight): gp_loss = tfgan_losses.wasserstein_gradient_penalty( - model, + pooled_model, epsilon=gradient_penalty_epsilon, target=gradient_penalty_target, one_sided=gradient_penalty_one_sided, add_summaries=add_summaries) dis_loss += gradient_penalty_weight * gp_loss if _use_aux_loss(mutual_information_penalty_weight): - info_loss = tfgan_losses.mutual_information_penalty( + gen_info_loss = tfgan_losses.mutual_information_penalty( model, add_summaries=add_summaries) - dis_loss += mutual_information_penalty_weight * info_loss - gen_loss += mutual_information_penalty_weight * info_loss + dis_info_loss = (gen_info_loss if tensor_pool_fn is None else + tfgan_losses.mutual_information_penalty( + pooled_model, add_summaries=add_summaries)) + gen_loss += mutual_information_penalty_weight * gen_info_loss + dis_loss += mutual_information_penalty_weight * dis_info_loss if _use_aux_loss(aux_cond_generator_weight): ac_gen_loss = tfgan_losses.acgan_generator_loss( model, add_summaries=add_summaries) gen_loss += aux_cond_generator_weight * ac_gen_loss if _use_aux_loss(aux_cond_discriminator_weight): ac_disc_loss = tfgan_losses.acgan_discriminator_loss( - model, add_summaries=add_summaries) + pooled_model, add_summaries=add_summaries) dis_loss += aux_cond_discriminator_weight * ac_disc_loss # Gathers auxiliary losses. if model.generator_scope: @@ -631,8 +744,8 @@ def cyclegan_loss( generator_loss_fn=generator_loss_fn, discriminator_loss_fn=discriminator_loss_fn, **kwargs) - return partial_loss._replace( - generator_loss=partial_loss.generator_loss + aux_loss) + return partial_loss._replace(generator_loss=partial_loss.generator_loss + + aux_loss) with ops.name_scope('cyclegan_loss_x2y'): loss_x2y = _partial_loss(model.model_x2y) @@ -642,6 +755,130 @@ def cyclegan_loss( return namedtuples.CycleGANLoss(loss_x2y, loss_y2x) +def stargan_loss( + model, + generator_loss_fn=tfgan_losses.stargan_generator_loss_wrapper( + tfgan_losses_impl.wasserstein_generator_loss), + discriminator_loss_fn=tfgan_losses.stargan_discriminator_loss_wrapper( + tfgan_losses_impl.wasserstein_discriminator_loss), + gradient_penalty_weight=10.0, + gradient_penalty_epsilon=1e-10, + gradient_penalty_target=1.0, + gradient_penalty_one_sided=False, + reconstruction_loss_fn=losses.absolute_difference, + reconstruction_loss_weight=10.0, + classification_loss_fn=losses.softmax_cross_entropy, + classification_loss_weight=1.0, + classification_one_hot=True, + add_summaries=True): + """StarGAN Loss. + + The four major part can be found here: http://screen/tMRMBAohDYG. + + Args: + model: (StarGAN) Model output of the stargan_model() function call. + generator_loss_fn: The loss function on the generator. Takes a + `StarGANModel` named tuple. + discriminator_loss_fn: The loss function on the discriminator. Takes a + `StarGANModel` namedtuple. + gradient_penalty_weight: (float) Gradient penalty weight. Default to 10 per + the original paper https://arxiv.org/abs/1711.09020. Set to 0 or None to + turn off gradient penalty. + gradient_penalty_epsilon: (float) A small positive number added for + numerical stability when computing the gradient norm. + gradient_penalty_target: (float, or tf.float `Tensor`) The target value of + gradient norm. Defaults to 1.0. + gradient_penalty_one_sided: (bool) If `True`, penalty proposed in + https://arxiv.org/abs/1709.08894 is used. Defaults to `False`. + reconstruction_loss_fn: The reconstruction loss function. Default to L1-norm + and the function must conform to the `tf.losses` API. + reconstruction_loss_weight: Reconstruction loss weight. Default to 10.0. + classification_loss_fn: The loss function on the discriminator's ability to + classify domain of the input. Default to one-hot softmax cross entropy + loss, and the function must conform to the `tf.losses` API. + classification_loss_weight: (float) Classification loss weight. Default to + 1.0. + classification_one_hot: (bool) If the label is one hot representation. + Default to True. If False, classification classification_loss_fn need to + be sigmoid cross entropy loss instead. + add_summaries: (bool) Add the loss to the summary + + Returns: + GANLoss namedtuple where we have generator loss and discriminator loss. + + Raises: + ValueError: If input StarGANModel.input_data_domain_label does not have rank + 2, or dimension 2 is not defined. + """ + + def _classification_loss_helper(true_labels, predict_logits, scope_name): + """Classification Loss Function Helper. + + Args: + true_labels: Tensor of shape [batch_size, num_domains] representing the + label where each row is an one-hot vector. + predict_logits: Tensor of shape [batch_size, num_domains] representing the + predicted label logit, which is UNSCALED output from the NN. + scope_name: (string) Name scope of the loss component. + + Returns: + Single scalar tensor representing the classification loss. + """ + + with ops.name_scope(scope_name, values=(true_labels, predict_logits)): + + loss = classification_loss_fn( + onehot_labels=true_labels, logits=predict_logits) + + if not classification_one_hot: + loss = math_ops.reduce_sum(loss, axis=1) + loss = math_ops.reduce_mean(loss) + + if add_summaries: + summary.scalar(scope_name, loss) + + return loss + + # Check input shape. + model.input_data_domain_label.shape.assert_has_rank(2) + model.input_data_domain_label.shape[1:].assert_is_fully_defined() + + # Adversarial Loss. + generator_loss = generator_loss_fn(model, add_summaries=add_summaries) + discriminator_loss = discriminator_loss_fn(model, add_summaries=add_summaries) + + # Gradient Penalty. + if _use_aux_loss(gradient_penalty_weight): + gradient_penalty_fn = tfgan_losses.stargan_gradient_penalty_wrapper( + tfgan_losses_impl.wasserstein_gradient_penalty) + discriminator_loss += gradient_penalty_fn( + model, + epsilon=gradient_penalty_epsilon, + target=gradient_penalty_target, + one_sided=gradient_penalty_one_sided, + add_summaries=add_summaries) * gradient_penalty_weight + + # Reconstruction Loss. + reconstruction_loss = reconstruction_loss_fn(model.input_data, + model.reconstructed_data) + generator_loss += reconstruction_loss * reconstruction_loss_weight + if add_summaries: + summary.scalar('reconstruction_loss', reconstruction_loss) + + # Classification Loss. + generator_loss += _classification_loss_helper( + true_labels=model.generated_data_domain_target, + predict_logits=model.discriminator_generated_data_domain_predication, + scope_name='generator_classification_loss') * classification_loss_weight + discriminator_loss += _classification_loss_helper( + true_labels=model.input_data_domain_label, + predict_logits=model.discriminator_input_data_domain_predication, + scope_name='discriminator_classification_loss' + ) * classification_loss_weight + + return namedtuples.GANLoss(generator_loss, discriminator_loss) + + def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True): """Gets generator and discriminator update ops. @@ -822,12 +1059,14 @@ def get_sequential_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)): Returns: A function that takes a GANTrainOps tuple and returns a list of hooks. """ + def get_hooks(train_ops): generator_hook = RunTrainOpsHook(train_ops.generator_train_op, train_steps.generator_train_steps) discriminator_hook = RunTrainOpsHook(train_ops.discriminator_train_op, train_steps.discriminator_train_steps) return [generator_hook, discriminator_hook] + return get_hooks @@ -881,23 +1120,23 @@ def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)): d_hook = RunTrainOpsHook(d_op, num_d_steps) return [joint_hook, g_hook, d_hook] + return get_hooks # TODO(joelshor): This function currently returns the global step. Find a # good way for it to return the generator, discriminator, and final losses. -def gan_train( - train_ops, - logdir, - get_hooks_fn=get_sequential_train_hooks(), - master='', - is_chief=True, - scaffold=None, - hooks=None, - chief_only_hooks=None, - save_checkpoint_secs=600, - save_summaries_steps=100, - config=None): +def gan_train(train_ops, + logdir, + get_hooks_fn=get_sequential_train_hooks(), + master='', + is_chief=True, + scaffold=None, + hooks=None, + chief_only_hooks=None, + save_checkpoint_secs=600, + save_summaries_steps=100, + config=None): """A wrapper around `contrib.training.train` that uses GAN hooks. Args: @@ -943,8 +1182,7 @@ def gan_train( config=config) -def get_sequential_train_steps( - train_steps=namedtuples.GANTrainSteps(1, 1)): +def get_sequential_train_steps(train_steps=namedtuples.GANTrainSteps(1, 1)): """Returns a thin wrapper around slim.learning.train_step, for GANs. This function is to provide support for the Supervisor. For new code, please @@ -1042,3 +1280,19 @@ def _validate_acgan_discriminator_outputs(discriminator_output): 'A discriminator function for ACGAN must output a tuple ' 'consisting of (discrimination logits, classification logits).') return a, b + + +def _generate_stargan_random_domain_target(batch_size, num_domains): + """Generate random domain label. + + Args: + batch_size: (int) Number of random domain label. + num_domains: (int) Number of domains representing with the label. + + Returns: + Tensor of shape (batch_size, num_domains) representing random label. + """ + domain_idx = random_ops.random_uniform( + [batch_size], minval=0, maxval=num_domains, dtype=dtypes.int32) + + return array_ops.one_hot(domain_idx, num_domains) diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py index 3ebbe55d059e5e72607bc4efdbf95a6c96d99f11..58f348034fdcaadd8d738517aef2a7e2f0172c13 100644 --- a/tensorflow/contrib/gan/python/train_test.py +++ b/tensorflow/contrib/gan/python/train_test.py @@ -18,8 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np +from tensorflow.contrib import layers from tensorflow.contrib.framework.python.ops import variables as variables_lib from tensorflow.contrib.gan.python import namedtuples from tensorflow.contrib.gan.python import train @@ -30,6 +32,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -84,19 +87,59 @@ class InfoGANDiscriminator(object): def acgan_discriminator_model(inputs, _, num_classes=10): - return (discriminator_model(inputs, _), array_ops.one_hot( - # TODO(haeusser): infer batch size from input - random_ops.random_uniform([3], maxval=num_classes, dtype=dtypes.int32), - num_classes)) + return ( + discriminator_model(inputs, _), + array_ops.one_hot( + # TODO(haeusser): infer batch size from input + random_ops.random_uniform( + [3], maxval=num_classes, dtype=dtypes.int32), + num_classes)) class ACGANDiscriminator(object): def __call__(self, inputs, _, num_classes=10): - return (discriminator_model(inputs, _), array_ops.one_hot( - # TODO(haeusser): infer batch size from input - random_ops.random_uniform([3], maxval=num_classes, dtype=dtypes.int32), - num_classes)) + return ( + discriminator_model(inputs, _), + array_ops.one_hot( + # TODO(haeusser): infer batch size from input + random_ops.random_uniform( + [3], maxval=num_classes, dtype=dtypes.int32), + num_classes)) + + +def stargan_generator_model(inputs, _): + """Dummy generator for StarGAN.""" + + return variable_scope.get_variable('dummy_g', initializer=0.5) * inputs + + +class StarGANGenerator(object): + + def __call__(self, inputs, _): + return stargan_generator_model(inputs, _) + + +def stargan_discriminator_model(inputs, num_domains): + """Differentiable dummy discriminator for StarGAN.""" + + hidden = layers.flatten(inputs) + + output_src = math_ops.reduce_mean(hidden, axis=1) + + output_cls = layers.fully_connected( + inputs=hidden, + num_outputs=num_domains, + activation_fn=None, + normalizer_fn=None, + biases_initializer=None) + return output_src, output_cls + + +class StarGANDiscriminator(object): + + def __call__(self, inputs, num_domains): + return stargan_discriminator_model(inputs, num_domains) def get_gan_model(): @@ -122,8 +165,7 @@ def get_gan_model(): def get_callable_gan_model(): ganmodel = get_gan_model() return ganmodel._replace( - generator_fn=Generator(), - discriminator_fn=Discriminator()) + generator_fn=Generator(), discriminator_fn=Discriminator()) def create_gan_model(): @@ -242,69 +284,84 @@ def create_callable_cyclegan_model(): data_y=array_ops.ones([1, 2])) -def get_sync_optimizer(): - return sync_replicas_optimizer.SyncReplicasOptimizer( - gradient_descent.GradientDescentOptimizer(learning_rate=1.0), - replicas_to_aggregate=1) +def get_stargan_model(): + """Similar to get_gan_model().""" + # TODO(joelshor): Find a better way of creating a variable scope. + with variable_scope.variable_scope('generator') as gen_scope: + pass + with variable_scope.variable_scope('discriminator') as dis_scope: + pass + return namedtuples.StarGANModel( + input_data=array_ops.ones([1, 2, 2, 3]), + input_data_domain_label=array_ops.ones([1, 2]), + generated_data=array_ops.ones([1, 2, 2, 3]), + generated_data_domain_target=array_ops.ones([1, 2]), + reconstructed_data=array_ops.ones([1, 2, 2, 3]), + discriminator_input_data_source_predication=array_ops.ones([1]), + discriminator_generated_data_source_predication=array_ops.ones([1]), + discriminator_input_data_domain_predication=array_ops.ones([1, 2]), + discriminator_generated_data_domain_predication=array_ops.ones([1, 2]), + generator_variables=None, + generator_scope=gen_scope, + generator_fn=stargan_generator_model, + discriminator_variables=None, + discriminator_scope=dis_scope, + discriminator_fn=stargan_discriminator_model) -def get_tensor_pool_fn(pool_size): +def get_callable_stargan_model(): + model = get_stargan_model() + return model._replace( + generator_fn=StarGANGenerator(), discriminator_fn=StarGANDiscriminator()) - def tensor_pool_fn_impl(input_values): - return random_tensor_pool.tensor_pool(input_values, pool_size=pool_size) - return tensor_pool_fn_impl +def create_stargan_model(): + return train.stargan_model( + stargan_generator_model, stargan_discriminator_model, + array_ops.ones([1, 2, 2, 3]), array_ops.ones([1, 2])) -def get_tensor_pool_fn_for_infogan(pool_size): +def create_callable_stargan_model(): + return train.stargan_model(StarGANGenerator(), StarGANDiscriminator(), + array_ops.ones([1, 2, 2, 3]), + array_ops.ones([1, 2])) - def tensor_pool_fn_impl(input_values): - generated_data, generator_inputs = input_values - output_values = random_tensor_pool.tensor_pool( - [generated_data] + generator_inputs, pool_size=pool_size) - return output_values[0], output_values[1:] - return tensor_pool_fn_impl +def get_sync_optimizer(): + return sync_replicas_optimizer.SyncReplicasOptimizer( + gradient_descent.GradientDescentOptimizer(learning_rate=1.0), + replicas_to_aggregate=1) -class GANModelTest(test.TestCase): +class GANModelTest(test.TestCase, parameterized.TestCase): """Tests for `gan_model`.""" - def _test_output_type_helper(self, create_fn, tuple_type): - self.assertTrue(isinstance(create_fn(), tuple_type)) - - def test_output_type_gan(self): - self._test_output_type_helper(get_gan_model, namedtuples.GANModel) - - def test_output_type_callable_gan(self): - self._test_output_type_helper(get_callable_gan_model, namedtuples.GANModel) - - def test_output_type_infogan(self): - self._test_output_type_helper(get_infogan_model, namedtuples.InfoGANModel) - - def test_output_type_callable_infogan(self): - self._test_output_type_helper( - get_callable_infogan_model, namedtuples.InfoGANModel) - - def test_output_type_acgan(self): - self._test_output_type_helper(get_acgan_model, namedtuples.ACGANModel) - - def test_output_type_callable_acgan(self): - self._test_output_type_helper( - get_callable_acgan_model, namedtuples.ACGANModel) - - def test_output_type_cyclegan(self): - self._test_output_type_helper(get_cyclegan_model, namedtuples.CycleGANModel) - - def test_output_type_callable_cyclegan(self): - self._test_output_type_helper(get_callable_cyclegan_model, - namedtuples.CycleGANModel) + @parameterized.named_parameters( + ('gan', get_gan_model, namedtuples.GANModel), + ('callable_gan', get_callable_gan_model, namedtuples.GANModel), + ('infogan', get_infogan_model, namedtuples.InfoGANModel), + ('callable_infogan', get_callable_infogan_model, + namedtuples.InfoGANModel), + ('acgan', get_acgan_model, namedtuples.ACGANModel), + ('callable_acgan', get_callable_acgan_model, namedtuples.ACGANModel), + ('cyclegan', get_cyclegan_model, namedtuples.CycleGANModel), + ('callable_cyclegan', get_callable_cyclegan_model, + namedtuples.CycleGANModel), + ('stargan', get_stargan_model, namedtuples.StarGANModel), + ('callabel_stargan', get_callable_stargan_model, namedtuples.StarGANModel) + ) + def test_output_type(self, create_fn, expected_tuple_type): + """Test that output type is as expected.""" + self.assertIsInstance(create_fn(), expected_tuple_type) def test_no_shape_check(self): + def dummy_generator_model(_): return (None, None) + def dummy_discriminator_model(data, conditioning): # pylint: disable=unused-argument return 1 + with self.assertRaisesRegexp(AttributeError, 'object has no attribute'): train.gan_model( dummy_generator_model, @@ -320,52 +377,182 @@ class GANModelTest(test.TestCase): check_shapes=False) -class GANLossTest(test.TestCase): - """Tests for `gan_loss`.""" +class StarGANModelTest(test.TestCase): + """Tests for `stargan_model`.""" + + @staticmethod + def create_input_and_label_tensor(batch_size, img_size, c_size, num_domains): + input_tensor_list = [] + label_tensor_list = [] + for _ in range(num_domains): + input_tensor_list.append( + random_ops.random_uniform((batch_size, img_size, img_size, c_size))) + domain_idx = random_ops.random_uniform( + [batch_size], minval=0, maxval=num_domains, dtype=dtypes.int32) + label_tensor_list.append(array_ops.one_hot(domain_idx, num_domains)) + return input_tensor_list, label_tensor_list + + def test_generate_stargan_random_domain_target(self): + batch_size = 8 + domain_numbers = 3 + + target_tensor = train._generate_stargan_random_domain_target( + batch_size, domain_numbers) + + with self.test_session() as sess: + targets = sess.run(target_tensor) + self.assertTupleEqual((batch_size, domain_numbers), targets.shape) + for target in targets: + self.assertEqual(1, np.sum(target)) + self.assertEqual(1, np.max(target)) + + def test_stargan_model_output_type(self): + batch_size = 2 + img_size = 16 + c_size = 3 + num_domains = 5 + + input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor( + batch_size, img_size, c_size, num_domains) + model = train.stargan_model( + generator_fn=stargan_generator_model, + discriminator_fn=stargan_discriminator_model, + input_data=input_tensor, + input_data_domain_label=label_tensor) + + self.assertIsInstance(model, namedtuples.StarGANModel) + self.assertTrue(isinstance(model.discriminator_variables, list)) + self.assertTrue(isinstance(model.generator_variables, list)) + self.assertIsInstance(model.discriminator_scope, + variable_scope.VariableScope) + self.assertTrue(model.generator_scope, variable_scope.VariableScope) + self.assertTrue(callable(model.discriminator_fn)) + self.assertTrue(callable(model.generator_fn)) + + def test_stargan_model_generator_output(self): + batch_size = 2 + img_size = 16 + c_size = 3 + num_domains = 5 + + input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor( + batch_size, img_size, c_size, num_domains) + model = train.stargan_model( + generator_fn=stargan_generator_model, + discriminator_fn=stargan_discriminator_model, + input_data=input_tensor, + input_data_domain_label=label_tensor) - # Test output type. - def _test_output_type_helper(self, get_gan_model_fn): - loss = train.gan_loss(get_gan_model_fn(), add_summaries=True) - self.assertTrue(isinstance(loss, namedtuples.GANLoss)) - self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0) - - def test_output_type_gan(self): - self._test_output_type_helper(get_gan_model) + with self.test_session(use_gpu=True) as sess: - def test_output_type_callable_gan(self): - self._test_output_type_helper(get_callable_gan_model) + sess.run(variables.global_variables_initializer()) - def test_output_type_infogan(self): - self._test_output_type_helper(get_infogan_model) + input_data, generated_data, reconstructed_data = sess.run( + [model.input_data, model.generated_data, model.reconstructed_data]) + self.assertTupleEqual( + (batch_size * num_domains, img_size, img_size, c_size), + input_data.shape) + self.assertTupleEqual( + (batch_size * num_domains, img_size, img_size, c_size), + generated_data.shape) + self.assertTupleEqual( + (batch_size * num_domains, img_size, img_size, c_size), + reconstructed_data.shape) + + def test_stargan_model_discriminator_output(self): + batch_size = 2 + img_size = 16 + c_size = 3 + num_domains = 5 + + input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor( + batch_size, img_size, c_size, num_domains) + model = train.stargan_model( + generator_fn=stargan_generator_model, + discriminator_fn=stargan_discriminator_model, + input_data=input_tensor, + input_data_domain_label=label_tensor) - def test_output_type_callable_infogan(self): - self._test_output_type_helper(get_callable_infogan_model) + with self.test_session(use_gpu=True) as sess: - def test_output_type_acgan(self): - self._test_output_type_helper(get_acgan_model) + sess.run(variables.global_variables_initializer()) - def test_output_type_callable_acgan(self): - self._test_output_type_helper(get_callable_acgan_model) + disc_input_data_source_pred, disc_gen_data_source_pred = sess.run([ + model.discriminator_input_data_source_predication, + model.discriminator_generated_data_source_predication + ]) + self.assertEqual(1, len(disc_input_data_source_pred.shape)) + self.assertEqual(batch_size * num_domains, + disc_input_data_source_pred.shape[0]) + self.assertEqual(1, len(disc_gen_data_source_pred.shape)) + self.assertEqual(batch_size * num_domains, + disc_gen_data_source_pred.shape[0]) + + input_label, disc_input_label, gen_label, disc_gen_label = sess.run([ + model.input_data_domain_label, + model.discriminator_input_data_domain_predication, + model.generated_data_domain_target, + model.discriminator_generated_data_domain_predication + ]) + self.assertTupleEqual((batch_size * num_domains, num_domains), + input_label.shape) + self.assertTupleEqual((batch_size * num_domains, num_domains), + disc_input_label.shape) + self.assertTupleEqual((batch_size * num_domains, num_domains), + gen_label.shape) + self.assertTupleEqual((batch_size * num_domains, num_domains), + disc_gen_label.shape) + + +class GANLossTest(test.TestCase, parameterized.TestCase): + """Tests for `gan_loss`.""" - def test_output_type_cyclegan(self): - loss = train.cyclegan_loss(create_cyclegan_model(), add_summaries=True) - self.assertIsInstance(loss, namedtuples.CycleGANLoss) + @parameterized.named_parameters( + ('gan', get_gan_model), + ('callable_gan', get_callable_gan_model), + ('infogan', get_infogan_model), + ('callable_infogan', get_callable_infogan_model), + ('acgan', get_acgan_model), + ('callable_acgan', get_callable_acgan_model), + ) + def test_output_type(self, get_gan_model_fn): + """Test output type.""" + loss = train.gan_loss(get_gan_model_fn(), add_summaries=True) + self.assertIsInstance(loss, namedtuples.GANLoss) self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0) - def test_output_type_callable_cyclegan(self): - loss = train.cyclegan_loss( - create_callable_cyclegan_model(), add_summaries=True) + @parameterized.named_parameters( + ('cyclegan', create_cyclegan_model), + ('callable_cyclegan', create_callable_cyclegan_model), + ) + def test_cyclegan_output_type(self, get_gan_model_fn): + loss = train.cyclegan_loss(get_gan_model_fn(), add_summaries=True) self.assertIsInstance(loss, namedtuples.CycleGANLoss) self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0) - # Test gradient penalty option. - def _test_grad_penalty_helper(self, create_gan_model_fn, one_sided=False): + @parameterized.named_parameters( + ('gan', create_gan_model, False), + ('gan_one_sided', create_gan_model, True), + ('callable_gan', create_callable_gan_model, False), + ('callable_gan_one_sided', create_callable_gan_model, True), + ('infogan', create_infogan_model, False), + ('infogan_one_sided', create_infogan_model, True), + ('callable_infogan', create_callable_infogan_model, False), + ('callable_infogan_one_sided', create_callable_infogan_model, True), + ('acgan', create_acgan_model, False), + ('acgan_one_sided', create_acgan_model, True), + ('callable_acgan', create_callable_acgan_model, False), + ('callable_acgan_one_sided', create_callable_acgan_model, True), + ) + def test_grad_penalty(self, create_gan_model_fn, one_sided): + """Test gradient penalty option.""" model = create_gan_model_fn() loss = train.gan_loss(model) - loss_gp = train.gan_loss(model, - gradient_penalty_weight=1.0, - gradient_penalty_one_sided=one_sided) - self.assertTrue(isinstance(loss_gp, namedtuples.GANLoss)) + loss_gp = train.gan_loss( + model, + gradient_penalty_weight=1.0, + gradient_penalty_one_sided=one_sided) + self.assertIsInstance(loss_gp, namedtuples.GANLoss) # Check values. with self.test_session(use_gpu=True) as sess: @@ -376,58 +563,28 @@ class GANLossTest(test.TestCase): [loss.discriminator_loss, loss_gp.discriminator_loss]) self.assertEqual(loss_gen_np, loss_gen_gp_np) - self.assertTrue(loss_dis_np < loss_dis_gp_np) - - def test_grad_penalty_gan(self): - self._test_grad_penalty_helper(create_gan_model) - - def test_grad_penalty_callable_gan(self): - self._test_grad_penalty_helper(create_callable_gan_model) - - def test_grad_penalty_infogan(self): - self._test_grad_penalty_helper(create_infogan_model) - - def test_grad_penalty_callable_infogan(self): - self._test_grad_penalty_helper(create_callable_infogan_model) - - def test_grad_penalty_acgan(self): - self._test_grad_penalty_helper(create_acgan_model) - - def test_grad_penalty_callable_acgan(self): - self._test_grad_penalty_helper(create_callable_acgan_model) - - def test_grad_penalty_one_sided_gan(self): - self._test_grad_penalty_helper(create_gan_model, one_sided=True) - - def test_grad_penalty_one_sided_callable_gan(self): - self._test_grad_penalty_helper(create_callable_gan_model, one_sided=True) - - def test_grad_penalty_one_sided_infogan(self): - self._test_grad_penalty_helper(create_infogan_model, one_sided=True) - - def test_grad_penalty_one_sided_callable_infogan(self): - self._test_grad_penalty_helper( - create_callable_infogan_model, one_sided=True) - - def test_grad_penalty_one_sided_acgan(self): - self._test_grad_penalty_helper(create_acgan_model, one_sided=True) - - def test_grad_penalty_one_sided_callable_acgan(self): - self._test_grad_penalty_helper(create_callable_acgan_model, one_sided=True) - - # Test mutual information penalty option. - def _test_mutual_info_penalty_helper(self, create_gan_model_fn): - train.gan_loss(create_gan_model_fn(), - mutual_information_penalty_weight=constant_op.constant(1.0)) - - def test_mutual_info_penalty_infogan(self): - self._test_mutual_info_penalty_helper(get_infogan_model) - - def test_mutual_info_penalty_callable_infogan(self): - self._test_mutual_info_penalty_helper(get_callable_infogan_model) - - # Test regularization loss. - def _test_regularization_helper(self, get_gan_model_fn): + self.assertLess(loss_dis_np, loss_dis_gp_np) + + @parameterized.named_parameters( + ('infogan', get_infogan_model), + ('callable_infogan', get_callable_infogan_model), + ) + def test_mutual_info_penalty(self, create_gan_model_fn): + """Test mutual information penalty option.""" + train.gan_loss( + create_gan_model_fn(), + mutual_information_penalty_weight=constant_op.constant(1.0)) + + @parameterized.named_parameters( + ('gan', get_gan_model), + ('callable_gan', get_callable_gan_model), + ('infogan', get_infogan_model), + ('callable_infogan', get_callable_infogan_model), + ('acgan', get_acgan_model), + ('callable_acgan', get_callable_acgan_model), + ) + def test_regularization_helper(self, get_gan_model_fn): + """Test regularization loss.""" # Evaluate losses without regularization. no_reg_loss = train.gan_loss(get_gan_model_fn()) with self.test_session(use_gpu=True): @@ -435,11 +592,11 @@ class GANLossTest(test.TestCase): no_reg_loss_dis_np = no_reg_loss.discriminator_loss.eval() with ops.name_scope(get_gan_model_fn().generator_scope.name): - ops.add_to_collection( - ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(3.0)) + ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, + constant_op.constant(3.0)) with ops.name_scope(get_gan_model_fn().discriminator_scope.name): - ops.add_to_collection( - ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(2.0)) + ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, + constant_op.constant(2.0)) # Check that losses now include the correct regularization values. reg_loss = train.gan_loss(get_gan_model_fn()) @@ -447,63 +604,47 @@ class GANLossTest(test.TestCase): reg_loss_gen_np = reg_loss.generator_loss.eval() reg_loss_dis_np = reg_loss.discriminator_loss.eval() - self.assertTrue(3.0, reg_loss_gen_np - no_reg_loss_gen_np) - self.assertTrue(3.0, reg_loss_dis_np - no_reg_loss_dis_np) - - def test_regularization_gan(self): - self._test_regularization_helper(get_gan_model) + self.assertEqual(3.0, reg_loss_gen_np - no_reg_loss_gen_np) + self.assertEqual(2.0, reg_loss_dis_np - no_reg_loss_dis_np) - def test_regularization_callable_gan(self): - self._test_regularization_helper(get_callable_gan_model) - - def test_regularization_infogan(self): - self._test_regularization_helper(get_infogan_model) - - def test_regularization_callable_infogan(self): - self._test_regularization_helper(get_callable_infogan_model) - - def test_regularization_acgan(self): - self._test_regularization_helper(get_acgan_model) - - def test_regularization_callable_acgan(self): - self._test_regularization_helper(get_callable_acgan_model) - - # Test that ACGan models work. - def _test_acgan_helper(self, create_gan_model_fn): + @parameterized.named_parameters( + ('notcallable', create_acgan_model), + ('callable', create_callable_acgan_model), + ) + def test_acgan(self, create_gan_model_fn): + """Test that ACGAN models work.""" model = create_gan_model_fn() loss = train.gan_loss(model) loss_ac_gen = train.gan_loss(model, aux_cond_generator_weight=1.0) loss_ac_dis = train.gan_loss(model, aux_cond_discriminator_weight=1.0) - self.assertTrue(isinstance(loss, namedtuples.GANLoss)) - self.assertTrue(isinstance(loss_ac_gen, namedtuples.GANLoss)) - self.assertTrue(isinstance(loss_ac_dis, namedtuples.GANLoss)) + self.assertIsInstance(loss, namedtuples.GANLoss) + self.assertIsInstance(loss_ac_gen, namedtuples.GANLoss) + self.assertIsInstance(loss_ac_dis, namedtuples.GANLoss) # Check values. with self.test_session(use_gpu=True) as sess: variables.global_variables_initializer().run() - loss_gen_np, loss_ac_gen_gen_np, loss_ac_dis_gen_np = sess.run( - [loss.generator_loss, - loss_ac_gen.generator_loss, - loss_ac_dis.generator_loss]) - loss_dis_np, loss_ac_gen_dis_np, loss_ac_dis_dis_np = sess.run( - [loss.discriminator_loss, - loss_ac_gen.discriminator_loss, - loss_ac_dis.discriminator_loss]) - - self.assertTrue(loss_gen_np < loss_dis_np) + loss_gen_np, loss_ac_gen_gen_np, loss_ac_dis_gen_np = sess.run([ + loss.generator_loss, loss_ac_gen.generator_loss, + loss_ac_dis.generator_loss + ]) + loss_dis_np, loss_ac_gen_dis_np, loss_ac_dis_dis_np = sess.run([ + loss.discriminator_loss, loss_ac_gen.discriminator_loss, + loss_ac_dis.discriminator_loss + ]) + + self.assertLess(loss_gen_np, loss_dis_np) self.assertTrue(np.isscalar(loss_ac_gen_gen_np)) self.assertTrue(np.isscalar(loss_ac_dis_gen_np)) self.assertTrue(np.isscalar(loss_ac_gen_dis_np)) self.assertTrue(np.isscalar(loss_ac_dis_dis_np)) - def test_acgan(self): - self._test_acgan_helper(create_acgan_model) - - def test_callable_acgan(self): - self._test_acgan_helper(create_callable_acgan_model) - - # Test that CycleGan models work. - def _test_cyclegan_helper(self, create_gan_model_fn): + @parameterized.named_parameters( + ('notcallable', create_cyclegan_model), + ('callable', create_callable_cyclegan_model), + ) + def test_cyclegan(self, create_gan_model_fn): + """Test that CycleGan models work.""" model = create_gan_model_fn() loss = train.cyclegan_loss(model) self.assertIsInstance(loss, namedtuples.CycleGANLoss) @@ -524,14 +665,86 @@ class GANLossTest(test.TestCase): self.assertTrue(np.isscalar(loss_y2x_gen_np)) self.assertTrue(np.isscalar(loss_y2x_dis_np)) - def test_cyclegan(self): - self._test_cyclegan_helper(create_cyclegan_model) + @parameterized.named_parameters( + ('notcallable', create_stargan_model), + ('callable', create_callable_stargan_model), + ) + def test_stargan(self, create_gan_model_fn): + + model = create_gan_model_fn() + model_loss = train.stargan_loss(model) + + self.assertIsInstance(model_loss, namedtuples.GANLoss) + + with self.test_session() as sess: + + sess.run(variables.global_variables_initializer()) + + gen_loss, disc_loss = sess.run( + [model_loss.generator_loss, model_loss.discriminator_loss]) + + self.assertTrue(np.isscalar(gen_loss)) + self.assertTrue(np.isscalar(disc_loss)) + + @parameterized.named_parameters( + ('gan', create_gan_model), + ('callable_gan', create_callable_gan_model), + ('infogan', create_infogan_model), + ('callable_infogan', create_callable_infogan_model), + ('acgan', create_acgan_model), + ('callable_acgan', create_callable_acgan_model), + ) + def test_tensor_pool(self, create_gan_model_fn): + """Test tensor pool option.""" + model = create_gan_model_fn() + tensor_pool_fn = lambda x: random_tensor_pool.tensor_pool(x, pool_size=5) + loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn) + self.assertIsInstance(loss, namedtuples.GANLoss) + + # Check values. + with self.test_session(use_gpu=True) as sess: + variables.global_variables_initializer().run() + for _ in range(10): + sess.run([loss.generator_loss, loss.discriminator_loss]) + + def test_discriminator_only_sees_pool(self): + """Checks that discriminator only sees pooled values.""" + def checker_gen_fn(_): + return constant_op.constant(0.0) + model = train.gan_model( + checker_gen_fn, + discriminator_model, + real_data=array_ops.zeros([]), + generator_inputs=random_ops.random_normal([])) + def tensor_pool_fn(_): + return (random_ops.random_uniform([]), random_ops.random_uniform([])) + def checker_dis_fn(inputs, _): + """Discriminator that checks that it only sees pooled Tensors.""" + self.assertFalse(constant_op.is_constant(inputs)) + return inputs + model = model._replace( + discriminator_fn=checker_dis_fn) + train.gan_loss(model, tensor_pool_fn=tensor_pool_fn) + + def test_doesnt_crash_when_in_nested_scope(self): + with variable_scope.variable_scope('outer_scope'): + gan_model = train.gan_model( + generator_model, + discriminator_model, + real_data=array_ops.zeros([1, 2]), + generator_inputs=random_ops.random_normal([1, 2])) + + # This should work inside a scope. + train.gan_loss(gan_model, gradient_penalty_weight=1.0) - def test_callable_cyclegan(self): - self._test_cyclegan_helper(create_callable_cyclegan_model) + # This should also work outside a scope. + train.gan_loss(gan_model, gradient_penalty_weight=1.0) - def _check_tensor_pool_adjusted_model_outputs(self, tensor1, tensor2, - pool_size): + +class TensorPoolAdjusteModelTest(test.TestCase): + + def _check_tensor_pool_adjusted_model_outputs( + self, tensor1, tensor2, pool_size): history_values = [] with self.test_session(use_gpu=True) as sess: variables.global_variables_initializer().run() @@ -548,115 +761,66 @@ class GANLossTest(test.TestCase): # pool). self.assertTrue(any([(v == t2).all() for v in history_values])) - # Test `_tensor_pool_adjusted_model` for gan model. - def test_tensor_pool_adjusted_model_gan(self): - model = create_gan_model() - - new_model = train._tensor_pool_adjusted_model(model, None) + def _make_new_model_and_check(self, model, pool_size): + pool_fn = lambda x: random_tensor_pool.tensor_pool(x, pool_size=pool_size) + new_model = train._tensor_pool_adjusted_model(model, pool_fn) # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0' self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES))) - self.assertIs(new_model.discriminator_gen_outputs, - model.discriminator_gen_outputs) - - pool_size = 5 - new_model = train._tensor_pool_adjusted_model( - model, get_tensor_pool_fn(pool_size=pool_size)) self.assertIsNot(new_model.discriminator_gen_outputs, model.discriminator_gen_outputs) + + return new_model + + def test_tensor_pool_adjusted_model_gan(self): + """Test `_tensor_pool_adjusted_model` for gan model.""" + pool_size = 5 + model = create_gan_model() + new_model = self._make_new_model_and_check(model, pool_size) + # Check values. self._check_tensor_pool_adjusted_model_outputs( model.discriminator_gen_outputs, new_model.discriminator_gen_outputs, pool_size) - # Test _tensor_pool_adjusted_model for infogan model. def test_tensor_pool_adjusted_model_infogan(self): + """Test _tensor_pool_adjusted_model for infogan model.""" + pool_size = 5 model = create_infogan_model() + new_model = self._make_new_model_and_check(model, pool_size) - pool_size = 5 - new_model = train._tensor_pool_adjusted_model( - model, get_tensor_pool_fn_for_infogan(pool_size=pool_size)) - # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0' - self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES))) - self.assertIsNot(new_model.discriminator_gen_outputs, - model.discriminator_gen_outputs) + # Check values. self.assertIsNot(new_model.predicted_distributions, model.predicted_distributions) - # Check values. self._check_tensor_pool_adjusted_model_outputs( model.discriminator_gen_outputs, new_model.discriminator_gen_outputs, pool_size) - # Test _tensor_pool_adjusted_model for acgan model. def test_tensor_pool_adjusted_model_acgan(self): + """Test _tensor_pool_adjusted_model for acgan model.""" + pool_size = 5 model = create_acgan_model() + new_model = self._make_new_model_and_check(model, pool_size) - pool_size = 5 - new_model = train._tensor_pool_adjusted_model( - model, get_tensor_pool_fn(pool_size=pool_size)) - # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0' - self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES))) - self.assertIsNot(new_model.discriminator_gen_outputs, - model.discriminator_gen_outputs) + # Check values. self.assertIsNot(new_model.discriminator_gen_classification_logits, model.discriminator_gen_classification_logits) - # Check values. self._check_tensor_pool_adjusted_model_outputs( model.discriminator_gen_outputs, new_model.discriminator_gen_outputs, pool_size) - # Test tensor pool. - def _test_tensor_pool_helper(self, create_gan_model_fn): - model = create_gan_model_fn() - if isinstance(model, namedtuples.InfoGANModel): - tensor_pool_fn = get_tensor_pool_fn_for_infogan(pool_size=5) - else: - tensor_pool_fn = get_tensor_pool_fn(pool_size=5) - loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn) - self.assertTrue(isinstance(loss, namedtuples.GANLoss)) - - # Check values. - with self.test_session(use_gpu=True) as sess: - variables.global_variables_initializer().run() - for _ in range(10): - sess.run([loss.generator_loss, loss.discriminator_loss]) - - def test_tensor_pool_gan(self): - self._test_tensor_pool_helper(create_gan_model) - - def test_tensor_pool_callable_gan(self): - self._test_tensor_pool_helper(create_callable_gan_model) - - def test_tensor_pool_infogan(self): - self._test_tensor_pool_helper(create_infogan_model) - - def test_tensor_pool_callable_infogan(self): - self._test_tensor_pool_helper(create_callable_infogan_model) - - def test_tensor_pool_acgan(self): - self._test_tensor_pool_helper(create_acgan_model) - - def test_tensor_pool_callable_acgan(self): - self._test_tensor_pool_helper(create_callable_acgan_model) - - def test_doesnt_crash_when_in_nested_scope(self): - with variable_scope.variable_scope('outer_scope'): - gan_model = train.gan_model( - generator_model, - discriminator_model, - real_data=array_ops.zeros([1, 2]), - generator_inputs=random_ops.random_normal([1, 2])) - - # This should work inside a scope. - train.gan_loss(gan_model, gradient_penalty_weight=1.0) - # This should also work outside a scope. - train.gan_loss(gan_model, gradient_penalty_weight=1.0) - - -class GANTrainOpsTest(test.TestCase): +class GANTrainOpsTest(test.TestCase, parameterized.TestCase): """Tests for `gan_train_ops`.""" - def _test_output_type_helper(self, create_gan_model_fn): + @parameterized.named_parameters( + ('gan', create_gan_model), + ('callable_gan', create_callable_gan_model), + ('infogan', create_infogan_model), + ('callable_infogan', create_callable_infogan_model), + ('acgan', create_acgan_model), + ('callable_acgan', create_callable_acgan_model), + ) + def test_output_type(self, create_gan_model_fn): model = create_gan_model_fn() loss = train.gan_loss(model) @@ -670,28 +834,24 @@ class GANTrainOpsTest(test.TestCase): summarize_gradients=True, colocate_gradients_with_ops=True) - self.assertTrue(isinstance(train_ops, namedtuples.GANTrainOps)) - - def test_output_type_gan(self): - self._test_output_type_helper(create_gan_model) - - def test_output_type_callable_gan(self): - self._test_output_type_helper(create_callable_gan_model) - - def test_output_type_infogan(self): - self._test_output_type_helper(create_infogan_model) - - def test_output_type_callable_infogan(self): - self._test_output_type_helper(create_callable_infogan_model) - - def test_output_type_acgan(self): - self._test_output_type_helper(create_acgan_model) - - def test_output_type_callable_acgan(self): - self._test_output_type_helper(create_callable_acgan_model) + self.assertIsInstance(train_ops, namedtuples.GANTrainOps) # TODO(joelshor): Add a test to check that custom update op is run. - def _test_unused_update_ops(self, create_gan_model_fn, provide_update_ops): + @parameterized.named_parameters( + ('gan', create_gan_model, False), + ('gan_provideupdates', create_gan_model, True), + ('callable_gan', create_callable_gan_model, False), + ('callable_gan_provideupdates', create_callable_gan_model, True), + ('infogan', create_infogan_model, False), + ('infogan_provideupdates', create_infogan_model, True), + ('callable_infogan', create_callable_infogan_model, False), + ('callable_infogan_provideupdates', create_callable_infogan_model, True), + ('acgan', create_acgan_model, False), + ('acgan_provideupdates', create_acgan_model, True), + ('callable_acgan', create_callable_acgan_model, False), + ('callable_acgan_provideupdates', create_callable_acgan_model, True), + ) + def test_unused_update_ops(self, create_gan_model_fn, provide_update_ops): model = create_gan_model_fn() loss = train.gan_loss(model) @@ -707,8 +867,11 @@ class GANTrainOpsTest(test.TestCase): # Add an update op outside the generator and discriminator scopes. if provide_update_ops: - kwargs = {'update_ops': - [constant_op.constant(1.0), gen_update_op, dis_update_op]} + kwargs = { + 'update_ops': [ + constant_op.constant(1.0), gen_update_op, dis_update_op + ] + } else: ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, constant_op.constant(1.0)) kwargs = {} @@ -717,8 +880,8 @@ class GANTrainOpsTest(test.TestCase): d_opt = gradient_descent.GradientDescentOptimizer(1.0) with self.assertRaisesRegexp(ValueError, 'There are unused update ops:'): - train.gan_train_ops(model, loss, g_opt, d_opt, - check_for_unused_update_ops=True, **kwargs) + train.gan_train_ops( + model, loss, g_opt, d_opt, check_for_unused_update_ops=True, **kwargs) train_ops = train.gan_train_ops( model, loss, g_opt, d_opt, check_for_unused_update_ops=False, **kwargs) @@ -735,44 +898,16 @@ class GANTrainOpsTest(test.TestCase): self.assertEqual(1, gen_update_count.eval()) self.assertEqual(1, dis_update_count.eval()) - def test_unused_update_ops_gan(self): - self._test_unused_update_ops(create_gan_model, False) - - def test_unused_update_ops_gan_provideupdates(self): - self._test_unused_update_ops(create_gan_model, True) - - def test_unused_update_ops_callable_gan(self): - self._test_unused_update_ops(create_callable_gan_model, False) - - def test_unused_update_ops_callable_gan_provideupdates(self): - self._test_unused_update_ops(create_callable_gan_model, True) - - def test_unused_update_ops_infogan(self): - self._test_unused_update_ops(create_infogan_model, False) - - def test_unused_update_ops_infogan_provideupdates(self): - self._test_unused_update_ops(create_infogan_model, True) - - def test_unused_update_ops_callable_infogan(self): - self._test_unused_update_ops(create_callable_infogan_model, False) - - def test_unused_update_ops_callable_infogan_provideupdates(self): - self._test_unused_update_ops(create_callable_infogan_model, True) - - def test_unused_update_ops_acgan(self): - self._test_unused_update_ops(create_acgan_model, False) - - def test_unused_update_ops_acgan_provideupdates(self): - self._test_unused_update_ops(create_acgan_model, True) - - def test_unused_update_ops_callable_acgan(self): - self._test_unused_update_ops(create_callable_acgan_model, False) - - def test_unused_update_ops_callable_acgan_provideupdates(self): - self._test_unused_update_ops(create_callable_acgan_model, True) - - def _test_sync_replicas_helper( - self, create_gan_model_fn, create_global_step=False): + @parameterized.named_parameters( + ('gan', create_gan_model, False), + ('callable_gan', create_callable_gan_model, False), + ('infogan', create_infogan_model, False), + ('callable_infogan', create_callable_infogan_model, False), + ('acgan', create_acgan_model, False), + ('callable_acgan', create_callable_acgan_model, False), + ('gan_canbeint32', create_gan_model, True), + ) + def test_sync_replicas(self, create_gan_model_fn, create_global_step): model = create_gan_model_fn() loss = train.gan_loss(model) num_trainable_vars = len(variables_lib.get_trainable_variables()) @@ -785,11 +920,8 @@ class GANTrainOpsTest(test.TestCase): g_opt = get_sync_optimizer() d_opt = get_sync_optimizer() train_ops = train.gan_train_ops( - model, - loss, - generator_optimizer=g_opt, - discriminator_optimizer=d_opt) - self.assertTrue(isinstance(train_ops, namedtuples.GANTrainOps)) + model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt) + self.assertIsInstance(train_ops, namedtuples.GANTrainOps) # No new trainable variables should have been added. self.assertEqual(num_trainable_vars, len(variables_lib.get_trainable_variables())) @@ -827,29 +959,8 @@ class GANTrainOpsTest(test.TestCase): coord.request_stop() coord.join(g_threads + d_threads) - def test_sync_replicas_gan(self): - self._test_sync_replicas_helper(create_gan_model) - - def test_sync_replicas_callable_gan(self): - self._test_sync_replicas_helper(create_callable_gan_model) - - def test_sync_replicas_infogan(self): - self._test_sync_replicas_helper(create_infogan_model) - def test_sync_replicas_callable_infogan(self): - self._test_sync_replicas_helper(create_callable_infogan_model) - - def test_sync_replicas_acgan(self): - self._test_sync_replicas_helper(create_acgan_model) - - def test_sync_replicas_callable_acgan(self): - self._test_sync_replicas_helper(create_callable_acgan_model) - - def test_global_step_can_be_int32(self): - self._test_sync_replicas_helper(create_gan_model, create_global_step=True) - - -class GANTrainTest(test.TestCase): +class GANTrainTest(test.TestCase, parameterized.TestCase): """Tests for `gan_train`.""" def _gan_train_ops(self, generator_add, discriminator_add): @@ -860,12 +971,20 @@ class GANTrainTest(test.TestCase): # joint training. train_ops = namedtuples.GANTrainOps( generator_train_op=step.assign_add(generator_add, use_locking=True), - discriminator_train_op=step.assign_add(discriminator_add, - use_locking=True), + discriminator_train_op=step.assign_add( + discriminator_add, use_locking=True), global_step_inc_op=step.assign_add(1)) return train_ops - def _test_run_helper(self, create_gan_model_fn): + @parameterized.named_parameters( + ('gan', create_gan_model), + ('callable_gan', create_callable_gan_model), + ('infogan', create_infogan_model), + ('callable_infogan', create_callable_infogan_model), + ('acgan', create_acgan_model), + ('callable_acgan', create_callable_acgan_model), + ) + def test_run_helper(self, create_gan_model_fn): random_seed.set_random_seed(1234) model = create_gan_model_fn() loss = train.gan_loss(model) @@ -881,30 +1000,15 @@ class GANTrainTest(test.TestCase): self.assertTrue(np.isscalar(final_step)) self.assertEqual(2, final_step) - def test_run_gan(self): - self._test_run_helper(create_gan_model) - - def test_run_callable_gan(self): - self._test_run_helper(create_callable_gan_model) - - def test_run_infogan(self): - self._test_run_helper(create_infogan_model) - - def test_run_callable_infogan(self): - self._test_run_helper(create_callable_infogan_model) - - def test_run_acgan(self): - self._test_run_helper(create_acgan_model) - - def test_run_callable_acgan(self): - self._test_run_helper(create_callable_acgan_model) - - # Test multiple train steps. - def _test_multiple_steps_helper(self, get_hooks_fn_fn): + @parameterized.named_parameters( + ('seq_train_steps', train.get_sequential_train_hooks), + ('efficient_seq_train_steps', train.get_joint_train_hooks), + ) + def test_multiple_steps(self, get_hooks_fn_fn): + """Test multiple train steps.""" train_ops = self._gan_train_ops(generator_add=10, discriminator_add=100) train_steps = namedtuples.GANTrainSteps( - generator_train_steps=3, - discriminator_train_steps=4) + generator_train_steps=3, discriminator_train_steps=4) final_step = train.gan_train( train_ops, get_hooks_fn=get_hooks_fn_fn(train_steps), @@ -914,12 +1018,6 @@ class GANTrainTest(test.TestCase): self.assertTrue(np.isscalar(final_step)) self.assertEqual(1 + 3 * 10 + 4 * 100, final_step) - def test_multiple_steps_seq_train_steps(self): - self._test_multiple_steps_helper(train.get_sequential_train_hooks) - - def test_multiple_steps_efficient_seq_train_steps(self): - self._test_multiple_steps_helper(train.get_joint_train_hooks) - def test_supervisor_run_gan_model_train_ops_multiple_steps(self): step = training_util.create_global_step() train_ops = namedtuples.GANTrainOps( @@ -927,8 +1025,7 @@ class GANTrainTest(test.TestCase): discriminator_train_op=constant_op.constant(2.0), global_step_inc_op=step.assign_add(1)) train_steps = namedtuples.GANTrainSteps( - generator_train_steps=3, - discriminator_train_steps=4) + generator_train_steps=3, discriminator_train_steps=4) final_loss = slim_learning.train( train_op=train_ops, @@ -940,10 +1037,18 @@ class GANTrainTest(test.TestCase): self.assertEqual(17.0, final_loss) -class PatchGANTest(test.TestCase): +class PatchGANTest(test.TestCase, parameterized.TestCase): """Tests that functions work on PatchGAN style output.""" - def _test_patchgan_helper(self, create_gan_model_fn): + @parameterized.named_parameters( + ('gan', create_gan_model), + ('callable_gan', create_callable_gan_model), + ('infogan', create_infogan_model), + ('callable_infogan', create_callable_infogan_model), + ('acgan', create_acgan_model), + ('callable_acgan', create_callable_acgan_model), + ) + def test_patchgan(self, create_gan_model_fn): """Ensure that patch-based discriminators work end-to-end.""" random_seed.set_random_seed(1234) model = create_gan_model_fn() @@ -960,24 +1065,6 @@ class PatchGANTest(test.TestCase): self.assertTrue(np.isscalar(final_step)) self.assertEqual(2, final_step) - def test_patchgan_gan(self): - self._test_patchgan_helper(create_gan_model) - - def test_patchgan_callable_gan(self): - self._test_patchgan_helper(create_callable_gan_model) - - def test_patchgan_infogan(self): - self._test_patchgan_helper(create_infogan_model) - - def test_patchgan_callable_infogan(self): - self._test_patchgan_helper(create_callable_infogan_model) - - def test_patchgan_acgan(self): - self._test_patchgan_helper(create_acgan_model) - - def test_patchgan_callable_acgan(self): - self._test_patchgan_helper(create_callable_acgan_model) - if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc index 81e70ae30a4c72dbcedd1aabfe758ecca4c8b366..7e6a0f14f6f5e467801fef39ebb597565b3d7e98 100644 --- a/tensorflow/contrib/gdr/gdr_memory_manager.cc +++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc @@ -33,9 +33,11 @@ limitations under the License. #include "tensorflow/core/common_runtime/bfc_allocator.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/pool_allocator.h" +#include "tensorflow/core/common_runtime/process_state.h" #if GOOGLE_CUDA +#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" #include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#include "tensorflow/core/common_runtime/gpu/process_state.h" #endif // GOOGLE_CUDA #include "tensorflow/core/framework/allocator_registry.h" #include "tensorflow/core/lib/core/status.h" @@ -172,7 +174,7 @@ class GdrMemoryManager : public RemoteMemoryManager { // Client side endpoints mutex client_mu_; std::map, RdmaEndpointPtr> clients_ - GUARDED_BY(cient_mu_); + GUARDED_BY(client_mu_); // Managed memory regions mutex alloc_mu_; @@ -181,28 +183,25 @@ class GdrMemoryManager : public RemoteMemoryManager { TF_DISALLOW_COPY_AND_ASSIGN(GdrMemoryManager); }; -// TODO(byronyi): remove this class duplicated from the one in -// common/runtime/gpu/pool_allocator.h when it is available in common_runtime -class BasicCPUAllocator : public SubAllocator { - public: - ~BasicCPUAllocator() override {} - - void* Alloc(size_t alignment, size_t num_bytes) override { - return port::AlignedMalloc(num_bytes, alignment); - } - void Free(void* ptr, size_t) override { port::AlignedFree(ptr); } -}; - // TODO(byronyi): remove this class and its registration when the default -// cpu_allocator() returns visitable allocator +// cpu_allocator() returns visitable allocator, or cpu_allocator() is no +// longer in use. class BFCRdmaAllocator : public BFCAllocator { public: BFCRdmaAllocator() - : BFCAllocator(new BasicCPUAllocator(), 1LL << 36, true, "cpu_rdma_bfc") { + : BFCAllocator(new BasicCPUAllocator(port::kNUMANoAffinity), 1LL << 36, + true, "cpu_rdma_bfc") {} +}; +class BFCRdmaAllocatorFactory : public AllocatorFactory { + public: + Allocator* CreateAllocator() override { return new BFCRdmaAllocator; } + + virtual SubAllocator* CreateSubAllocator(int numa_node) { + return new BasicCPUAllocator(numa_node); } }; -REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocator); +REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocatorFactory); GdrMemoryManager::GdrMemoryManager(const string& host, const string& port) : host_(host), @@ -274,9 +273,9 @@ Status GdrMemoryManager::Init() { Allocator* allocators[] = { #if GOOGLE_CUDA - ProcessState::singleton()->GetCUDAHostAllocator(0), - ProcessState::singleton()->GetCPUAllocator(0), + GPUProcessState::singleton()->GetCUDAHostAllocator(0), #endif // GOOGLE_CUDA + ProcessState::singleton()->GetCPUAllocator(0), cpu_allocator(), }; @@ -308,7 +307,8 @@ Status GdrMemoryManager::Init() { if (IsGDRAvailable()) { // Note we don't free allocated GPU memory so there is no free visitor int32_t bus_id = TryToReadNumaNode(listening_->verbs->device) + 1; - ProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor); + GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id, + cuda_alloc_visitor); LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id; } #endif // GOOGLE_CUDA @@ -430,7 +430,7 @@ void GdrMemoryManager::TransportOptionsFromTensor( #if GOOGLE_CUDA if (!on_host) { - Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); + Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0); Tensor* host_copy = new Tensor(alloc, tensor.dtype(), tensor.shape()); GPUUtil::CopyGPUTensorToCPU( device, device_context, &tensor, host_copy, @@ -532,7 +532,7 @@ void GdrMemoryManager::TensorFromTransportOptions( Tensor host_copy; #if GOOGLE_CUDA if (mr == nullptr && !on_host) { - Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); + Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0); host_copy = Tensor(alloc, tensor->dtype(), tensor->shape()); buffer = DMAHelper::buffer(&host_copy); addr = buffer->data(); diff --git a/tensorflow/contrib/graph_editor/reroute.py b/tensorflow/contrib/graph_editor/reroute.py index 95c02a64d47c26e731ef2628fb551529e9bc3f4d..d42e0c01f455f861e9ccdbfb79aefab762e61abe 100644 --- a/tensorflow/contrib/graph_editor/reroute.py +++ b/tensorflow/contrib/graph_editor/reroute.py @@ -208,9 +208,9 @@ def _reroute_ts(ts0, ts1, mode, can_modify=None, cannot_modify=None): def swap_ts(ts0, ts1, can_modify=None, cannot_modify=None): """For each tensor's pair, swap the end of (t0,t1). - B0 B1 B0 B1 - | | => X - A0 A1 A0 A1 + B0 B1 B0 B1 + | | => X + A0 A1 A0 A1 Args: ts0: an object convertible to a list of `tf.Tensor`. @@ -233,9 +233,9 @@ def swap_ts(ts0, ts1, can_modify=None, cannot_modify=None): def reroute_ts(ts0, ts1, can_modify=None, cannot_modify=None): """For each tensor's pair, replace the end of t1 by the end of t0. - B0 B1 B0 B1 - | | => |/ - A0 A1 A0 A1 + B0 B1 B0 B1 + | | => |/ + A0 A1 A0 A1 The end of the tensors in ts1 are left dangling. diff --git a/tensorflow/contrib/kafka/ops/kafka_ops.cc b/tensorflow/contrib/kafka/ops/kafka_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..8cdf16103bab2b22d51c144d21a589e1e39f2f0b --- /dev/null +++ b/tensorflow/contrib/kafka/ops/kafka_ops.cc @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("KafkaDataset") + .Input("topics: string") + .Input("servers: string") + .Input("group: string") + .Input("eof: bool") + .Input("timeout: int64") + .Output("handle: variant") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that emits the messages of one or more Kafka topics. + +topics: A `tf.string` tensor containing one or more subscriptions, + in the format of [topic:partition:offset:length], + by default length is -1 for unlimited. +servers: A list of bootstrap servers. +group: The consumer group id. +eof: If True, the kafka reader will stop on EOF. +timeout: The timeout value for the Kafka Consumer to wait + (in millisecond). +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py index 1f9e82b41bf09b235e93fa512a50ea4c3047c01b..cb649a37510c301cb3df997f844617e9a4e6c7be 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py @@ -18,10 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras.preprocessing.image import apply_transform from tensorflow.python.keras.preprocessing.image import array_to_img from tensorflow.python.keras.preprocessing.image import DirectoryIterator -from tensorflow.python.keras.preprocessing.image import flip_axis from tensorflow.python.keras.preprocessing.image import ImageDataGenerator from tensorflow.python.keras.preprocessing.image import img_to_array from tensorflow.python.keras.preprocessing.image import Iterator diff --git a/tensorflow/contrib/kinesis/BUILD b/tensorflow/contrib/kinesis/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..25443d0ad47aa7d503f905eb34000488b62f22c6 --- /dev/null +++ b/tensorflow/contrib/kinesis/BUILD @@ -0,0 +1,113 @@ +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_custom_op_library", + "tf_custom_op_py_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_kernel_library", + "tf_py_test", +) + +py_library( + name = "kinesis", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_ops", + ], +) + +tf_custom_op_library( + name = "_dataset_ops.so", + srcs = ["ops/dataset_ops.cc"], + deps = [":dataset_kernels"], +) + +tf_gen_op_libs( + op_lib_names = ["dataset_ops"], +) + +cc_library( + name = "dataset_kernels", + srcs = [ + "kernels/kinesis_dataset_ops.cc", + ], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core/platform/s3:aws_crypto", + "//third_party/eigen3", + "@aws", + "@protobuf_archive//:protobuf_headers", + ], + alwayslink = 1, +) + +py_library( + name = "dataset_ops", + srcs = [ + "python/ops/kinesis_dataset_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":kinesis_op_loader", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + ], +) + +tf_gen_op_wrapper_py( + name = "gen_dataset_ops", + out = "python/ops/gen_dataset_ops.py", + deps = ["//tensorflow/contrib/kinesis:dataset_ops_op_lib"], +) + +tf_kernel_library( + name = "dataset_ops_kernels", + deps = [ + ":dataset_kernels", + "//tensorflow/core:framework", + ], + alwayslink = 1, +) + +tf_custom_op_py_library( + name = "kinesis_op_loader", + srcs = ["python/ops/kinesis_op_loader.py"], + dso = ["//tensorflow/contrib/kinesis:_dataset_ops.so"], + kernels = [ + ":dataset_ops_kernels", + "//tensorflow/contrib/kinesis:dataset_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":gen_dataset_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + ], +) + +tf_py_test( + name = "kinesis_test", + srcs = ["python/kernel_tests/kinesis_test.py"], + additional_deps = [ + ":kinesis", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], + tags = [ + "manual", + "no_windows", + "notap", + ], +) diff --git a/tensorflow/contrib/kinesis/__init__.py b/tensorflow/contrib/kinesis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3824b8ae7532ab97a5ebf01ab66ece6476c87d42 --- /dev/null +++ b/tensorflow/contrib/kinesis/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Kinesis Dataset. + +@@KinesisDataset +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kinesis.python.ops.kinesis_dataset_ops import KinesisDataset + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + "KinesisDataset", +] + +remove_undocumented(__name__) diff --git a/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc b/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..3212279c4c50efb92acc712b82cb3e1a22c76870 --- /dev/null +++ b/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc @@ -0,0 +1,359 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/platform/s3/aws_crypto.h" + +namespace tensorflow { +namespace { + +Aws::Client::ClientConfiguration* InitializeDefaultClientConfig() { + static Aws::Client::ClientConfiguration config; + const char* endpoint = getenv("KINESIS_ENDPOINT"); + if (endpoint) { + config.endpointOverride = Aws::String(endpoint); + } + const char* region = getenv("AWS_REGION"); + if (region) { + config.region = Aws::String(region); + } else { + // Load config file (e.g., ~/.aws/config) only if AWS_SDK_LOAD_CONFIG + // is set with a truthy value. + const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG"); + string load_config = + load_config_env ? str_util::Lowercase(load_config_env) : ""; + if (load_config == "true" || load_config == "1") { + Aws::String config_file; + // If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config. + const char* config_file_env = getenv("AWS_CONFIG_FILE"); + if (config_file_env) { + config_file = config_file_env; + } else { + const char* home_env = getenv("HOME"); + if (home_env) { + config_file = home_env; + config_file += "/.aws/config"; + } + } + Aws::Config::AWSConfigFileProfileConfigLoader loader(config_file); + // Load the configuration. If successful, get the region. + // If the load is not successful, then generate a warning. + if (loader.Load()) { + auto profiles = loader.GetProfiles(); + if (!profiles["default"].GetRegion().empty()) { + config.region = profiles["default"].GetRegion(); + } + } else { + LOG(WARNING) << "Failed to load the profile in " << config_file << "."; + } + } + } + const char* use_https = getenv("KINESIS_USE_HTTPS"); + if (use_https) { + if (use_https[0] == '0') { + config.scheme = Aws::Http::Scheme::HTTP; + } else { + config.scheme = Aws::Http::Scheme::HTTPS; + } + } + const char* verify_ssl = getenv("KINESIS_VERIFY_SSL"); + if (verify_ssl) { + if (verify_ssl[0] == '0') { + config.verifySSL = false; + } else { + config.verifySSL = true; + } + } + const char* connect_timeout = getenv("KINESIS_CONNECT_TIMEOUT_MSEC"); + if (connect_timeout) { + int64 timeout; + + if (strings::safe_strto64(connect_timeout, &timeout)) { + config.connectTimeoutMs = timeout; + } + } + const char* request_timeout = getenv("KINESIS_REQUEST_TIMEOUT_MSEC"); + if (request_timeout) { + int64 timeout; + + if (strings::safe_strto64(request_timeout, &timeout)) { + config.requestTimeoutMs = timeout; + } + } + + return &config; +} + +Aws::Client::ClientConfiguration& GetDefaultClientConfig() { + static Aws::Client::ClientConfiguration* config = + InitializeDefaultClientConfig(); + return *config; +} + +static mutex mu(LINKER_INITIALIZED); +static unsigned count(0); +void AwsInitAPI() { + mutex_lock lock(mu); + count++; + if (count == 1) { + Aws::SDKOptions options; + options.cryptoOptions.sha256Factory_create_fn = []() { + return Aws::MakeShared(AWSCryptoAllocationTag); + }; + options.cryptoOptions.sha256HMACFactory_create_fn = []() { + return Aws::MakeShared(AWSCryptoAllocationTag); + }; + Aws::InitAPI(options); + } +} +void AwsShutdownAPI() { + mutex_lock lock(mu); + count--; + if (count == 0) { + Aws::SDKOptions options; + Aws::ShutdownAPI(options); + } +} +void ShutdownClient(Aws::Kinesis::KinesisClient* client) { + if (client != nullptr) { + delete client; + AwsShutdownAPI(); + } +} +} +class KinesisDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + std::string stream = ""; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "stream", &stream)); + std::string shard = ""; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "shard", &shard)); + bool read_indefinitely = true; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "read_indefinitely", + &read_indefinitely)); + int64 interval = -1; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "interval", &interval)); + OP_REQUIRES(ctx, (interval > 0), + errors::InvalidArgument( + "Interval value should be large than 0, got ", interval)); + *output = new Dataset(ctx, stream, shard, read_indefinitely, interval); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, const string& stream, const string& shard, + const bool read_indefinitely, const int64 interval) + : GraphDatasetBase(ctx), + stream_(stream), + shard_(shard), + read_indefinitely_(read_indefinitely), + interval_(interval) {} + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::Kinesis")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); + return *dtypes; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}}); + return *shapes; + } + + string DebugString() const override { return "KinesisDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Node** output) const override { + Node* stream = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(stream_, &stream)); + Node* shard = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(shard_, &shard)); + Node* read_indefinitely = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(read_indefinitely_, &read_indefinitely)); + Node* interval = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(interval_, &interval)); + TF_RETURN_IF_ERROR(b->AddDataset( + this, {stream, shard, read_indefinitely, interval}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params), + client_(nullptr, ShutdownClient) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + if (iterator_ == "") { + TF_RETURN_IF_ERROR(SetupStreamsLocked()); + } + do { + Aws::Kinesis::Model::GetRecordsRequest request; + auto outcome = client_->GetRecords( + request.WithShardIterator(iterator_).WithLimit(1)); + if (!outcome.IsSuccess()) { + return errors::Unknown(outcome.GetError().GetExceptionName(), ": ", + outcome.GetError().GetMessage()); + } + if (outcome.GetResult().GetRecords().size() == 0) { + // If no records were returned then nothing is available at the + // moment. + if (!dataset()->read_indefinitely_) { + *end_of_sequence = true; + return Status::OK(); + } + // Continue the loop after a period of time. + ctx->env()->SleepForMicroseconds(dataset()->interval_); + continue; + } + if (outcome.GetResult().GetRecords().size() != 1) { + return errors::Unknown("invalid number of records ", + outcome.GetResult().GetRecords().size(), + " returned"); + } + + iterator_ = outcome.GetResult().GetNextShardIterator(); + + const auto& data = outcome.GetResult().GetRecords()[0].GetData(); + StringPiece value( + reinterpret_cast(data.GetUnderlyingData()), + data.GetLength()); + Tensor value_tensor(ctx->allocator({}), DT_STRING, {}); + value_tensor.scalar()() = std::string(value); + out_tensors->emplace_back(std::move(value_tensor)); + + *end_of_sequence = false; + return Status::OK(); + } while (true); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + return errors::Unimplemented("SaveInternal is currently not supported"); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return errors::Unimplemented( + "RestoreInternal is currently not supported"); + } + + private: + // Sets up Kinesis streams to read from. + Status SetupStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + AwsInitAPI(); + client_.reset( + new Aws::Kinesis::KinesisClient(GetDefaultClientConfig())); + + Aws::Kinesis::Model::DescribeStreamRequest request; + auto outcome = client_->DescribeStream( + request.WithStreamName(dataset()->stream_.c_str())); + if (!outcome.IsSuccess()) { + return errors::Unknown(outcome.GetError().GetExceptionName(), ": ", + outcome.GetError().GetMessage()); + } + Aws::String shard; + Aws::String sequence; + if (dataset()->shard_ == "") { + if (outcome.GetResult().GetStreamDescription().GetShards().size() != + 1) { + return errors::InvalidArgument( + "shard has to be provided unless the stream only have one " + "shard, there are ", + outcome.GetResult().GetStreamDescription().GetShards().size(), + " shards in stream ", dataset()->stream_); + } + shard = outcome.GetResult() + .GetStreamDescription() + .GetShards()[0] + .GetShardId(); + sequence = outcome.GetResult() + .GetStreamDescription() + .GetShards()[0] + .GetSequenceNumberRange() + .GetStartingSequenceNumber(); + } else { + for (const auto& entry : + outcome.GetResult().GetStreamDescription().GetShards()) { + if (entry.GetShardId() == dataset()->shard_.c_str()) { + shard = entry.GetShardId(); + sequence = + entry.GetSequenceNumberRange().GetStartingSequenceNumber(); + break; + } + } + if (shard == "") { + return errors::InvalidArgument("no shard ", dataset()->shard_, + " in stream ", dataset()->stream_); + } + } + + Aws::Kinesis::Model::GetShardIteratorRequest iterator_request; + auto iterator_outcome = client_->GetShardIterator( + iterator_request.WithStreamName(dataset()->stream_.c_str()) + .WithShardId(shard) + .WithShardIteratorType( + Aws::Kinesis::Model::ShardIteratorType::AT_SEQUENCE_NUMBER) + .WithStartingSequenceNumber(sequence)); + if (!iterator_outcome.IsSuccess()) { + return errors::Unknown(iterator_outcome.GetError().GetExceptionName(), + ": ", + iterator_outcome.GetError().GetMessage()); + } + iterator_ = iterator_outcome.GetResult().GetShardIterator(); + return Status::OK(); + } + + mutex mu_; + Aws::String iterator_ GUARDED_BY(mu_); + std::unique_ptr + client_ GUARDED_BY(mu_); + }; + + const std::string stream_; + const std::string shard_; + const bool read_indefinitely_; + const int64 interval_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("KinesisDataset").Device(DEVICE_CPU), + KinesisDatasetOp); + +} // namespace tensorflow diff --git a/tensorflow/contrib/kinesis/ops/dataset_ops.cc b/tensorflow/contrib/kinesis/ops/dataset_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..54204513cf22519ecfb5fa45748250ee0f4aac7a --- /dev/null +++ b/tensorflow/contrib/kinesis/ops/dataset_ops.cc @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("KinesisDataset") + .Input("stream: string") + .Input("shard: string") + .Input("read_indefinitely: bool") + .Input("interval: int64") + .Output("handle: variant") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that emits the messages of one or more Kinesis topics. + +stream: A `tf.string` tensor containing the name of the stream. +shard: A `tf.string` tensor containing the id of the shard. +read_indefinitely: If `True`, the Kinesis dataset will keep retry + again on `EOF` after the `interval` period. If `False`, then + the dataset will stop on `EOF`. The default value is `True`. +interval: The interval for the Kinesis Client to wait before + it tries to get records again (in millisecond). +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7289b45c50fa92455b4c317b8a039ca414fa585e --- /dev/null +++ b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py @@ -0,0 +1,139 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# ============================================================================== +"""Tests for KinesisDataset. +NOTE: boto3 is needed and the test has to be invoked manually: +``` +$ bazel test -s --verbose_failures --config=opt \ + --action_env=AWS_ACCESS_KEY_ID=XXXXXX \ + --action_env=AWS_SECRET_ACCESS_KEY=XXXXXX \ + //tensorflow/contrib/kinesis:kinesis_test +``` +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import boto3 + +from tensorflow.contrib.kinesis.python.ops import kinesis_dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class KinesisDatasetTest(test.TestCase): + + def testKinesisDatasetOneShard(self): + client = boto3.client('kinesis', region_name='us-east-1') + + # Setup the Kinesis with 1 shard. + stream_name = "tf_kinesis_test_1" + client.create_stream(StreamName=stream_name, ShardCount=1) + # Wait until stream exists, default is 10 * 18 seconds. + client.get_waiter('stream_exists').wait(StreamName=stream_name) + for i in range(10): + data = "D" + str(i) + client.put_record( + StreamName=stream_name, Data=data, PartitionKey="TensorFlow" + str(i)) + + stream = array_ops.placeholder(dtypes.string, shape=[]) + num_epochs = array_ops.placeholder(dtypes.int64, shape=[]) + batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = kinesis_dataset_ops.KinesisDataset( + stream, read_indefinitely=False).repeat(num_epochs) + batch_dataset = repeat_dataset.batch(batch_size) + + iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + init_batch_op = iterator.make_initializer(batch_dataset) + get_next = iterator.get_next() + + with self.test_session() as sess: + # Basic test: read from shard 0 of stream 1. + sess.run(init_op, feed_dict={stream: stream_name, num_epochs: 1}) + for i in range(10): + self.assertEqual("D" + str(i), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + client.delete_stream(StreamName=stream_name) + # Wait until stream deleted, default is 10 * 18 seconds. + client.get_waiter('stream_not_exists').wait(StreamName=stream_name) + + def testKinesisDatasetTwoShards(self): + client = boto3.client('kinesis', region_name='us-east-1') + + # Setup the Kinesis with 2 shards. + stream_name = "tf_kinesis_test_2" + client.create_stream(StreamName=stream_name, ShardCount=2) + # Wait until stream exists, default is 10 * 18 seconds. + client.get_waiter('stream_exists').wait(StreamName=stream_name) + + for i in range(10): + data = "D" + str(i) + client.put_record( + StreamName=stream_name, Data=data, PartitionKey="TensorFlow" + str(i)) + response = client.describe_stream(StreamName=stream_name) + shard_id_0 = response["StreamDescription"]["Shards"][0]["ShardId"] + shard_id_1 = response["StreamDescription"]["Shards"][1]["ShardId"] + + stream = array_ops.placeholder(dtypes.string, shape=[]) + shard = array_ops.placeholder(dtypes.string, shape=[]) + num_epochs = array_ops.placeholder(dtypes.int64, shape=[]) + batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = kinesis_dataset_ops.KinesisDataset( + stream, shard, read_indefinitely=False).repeat(num_epochs) + batch_dataset = repeat_dataset.batch(batch_size) + + iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + init_batch_op = iterator.make_initializer(batch_dataset) + get_next = iterator.get_next() + + data = list() + with self.test_session() as sess: + # Basic test: read from shard 0 of stream 2. + sess.run( + init_op, feed_dict={ + stream: stream_name, shard: shard_id_0, num_epochs: 1}) + with self.assertRaises(errors.OutOfRangeError): + # Use range(11) to guarantee the OutOfRangeError. + for i in range(11): + data.append(sess.run(get_next)) + + # Basic test: read from shard 1 of stream 2. + sess.run( + init_op, feed_dict={ + stream: stream_name, shard: shard_id_1, num_epochs: 1}) + with self.assertRaises(errors.OutOfRangeError): + # Use range(11) to guarantee the OutOfRangeError. + for i in range(11): + data.append(sess.run(get_next)) + + data.sort() + self.assertEqual(data, ["D" + str(i) for i in range(10)]) + + client.delete_stream(StreamName=stream_name) + # Wait until stream deleted, default is 10 * 18 seconds. + client.get_waiter('stream_not_exists').wait(StreamName=stream_name) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ca2df95ba4f20ec5fa58ff13530096e6e065f4fe --- /dev/null +++ b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py @@ -0,0 +1,96 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Kinesis Dataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kinesis.python.ops import kinesis_op_loader # pylint: disable=unused-import +from tensorflow.contrib.kinesis.python.ops import gen_dataset_ops +from tensorflow.python.data.ops.dataset_ops import Dataset +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape + + +class KinesisDataset(Dataset): + """A Kinesis Dataset that consumes the message. + + Kinesis is a managed service provided by AWS for data streaming. + This dataset reads messages from Kinesis with each message presented + as a `tf.string`. + + For example, we can construct and use the KinesisDataset as follows: + ```python + dataset = tf.contrib.kinesis.KinesisDataset( + "kinesis_stream_name", read_indefinitely=False) + next = dataset.make_one_shot_iterator().get_next() + with tf.Session() as sess: + while True: + try: + print(sess.run(nxt)) + except tf.errors.OutOfRangeError: + break + ``` + + Since Kinesis is a data streaming service, data may not be available + at the time it is being read. The argument `read_indefinitely` is + used to control the behavior in this situation. If `read_indefinitely` + is `True`, then `KinesisDataset` will keep retrying to retrieve data + from the stream. If `read_indefinitely` is `False`, an `OutOfRangeError` + is returned immediately instead. + """ + + def __init__(self, + stream, + shard="", + read_indefinitely=True, + interval=100000): + """Create a KinesisDataset. + + Args: + stream: A `tf.string` tensor containing the name of the stream. + shard: A `tf.string` tensor containing the id of the shard. + read_indefinitely: If `True`, the Kinesis dataset will keep retry + again on `EOF` after the `interval` period. If `False`, then + the dataset will stop on `EOF`. The default value is `True`. + interval: The interval for the Kinesis Client to wait before + it tries to get records again (in millisecond). + """ + super(KinesisDataset, self).__init__() + self._stream = ops.convert_to_tensor( + stream, dtype=dtypes.string, name="stream") + self._shard = ops.convert_to_tensor( + shard, dtype=dtypes.string, name="shard") + self._read_indefinitely = ops.convert_to_tensor( + read_indefinitely, dtype=dtypes.bool, name="read_indefinitely") + self._interval = ops.convert_to_tensor( + interval, dtype=dtypes.int64, name="interval") + + def _as_variant_tensor(self): + return gen_dataset_ops.kinesis_dataset( + self._stream, self._shard, self._read_indefinitely, self._interval) + + @property + def output_classes(self): + return ops.Tensor + + @property + def output_shapes(self): + return tensor_shape.scalar() + + @property + def output_types(self): + return dtypes.string diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py b/tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..c9ce9f3646200a777cdbdf34b37626154ca730bb --- /dev/null +++ b/tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py @@ -0,0 +1,24 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python helper for loading kinesis ops and kernels.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader + +_dataset_ops = loader.load_op_library( + resource_loader.get_path_to_datafile("../../_dataset_ops.so")) diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index bc3359693562deb1229a78a2db5c256c76f7fd8d..a7b41b714ffaa062e2eba8caf9b4fa033c7633cd 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -121,6 +121,7 @@ from tensorflow.contrib.layers.python.layers import * from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = ['bias_add', + 'conv1d', 'conv2d', 'conv3d', 'elu', diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py index dd2395f8c9748dadbecfe47df5511874d5f848ea..7ede193029d2d95fa4953b4c417a1e86ebb4a42e 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py @@ -21,7 +21,6 @@ from __future__ import print_function import itertools import math -import sys import numpy as np diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index beeabd6b65631cad88efd10d5faee1917e162e41..6250f8852917c00c94162ce9711bb8f34051565b 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -55,9 +55,9 @@ from tensorflow.python.training import moving_averages # TODO(b/28426988): Replace legacy_* fns migrated from slim. # TODO(b/28426988): Remove legacy_* when all uses have migrated to new API. __all__ = [ - 'avg_pool2d', 'avg_pool3d', 'batch_norm', 'bias_add', 'conv2d', 'conv3d', - 'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', 'convolution', - 'convolution1d', 'convolution2d', 'convolution2d_in_plane', + 'avg_pool2d', 'avg_pool3d', 'batch_norm', 'bias_add', 'conv1d', 'conv2d', + 'conv3d', 'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', + 'convolution', 'convolution1d', 'convolution2d', 'convolution2d_in_plane', 'convolution2d_transpose', 'convolution3d', 'convolution3d_transpose', 'dense_to_sparse', 'dropout', 'elu', 'flatten', 'fully_connected', 'GDN', 'gdn', 'images_to_sequence', 'layer_norm', 'linear', 'pool', 'max_pool2d', @@ -1702,19 +1702,22 @@ def _inner_flatten(inputs, new_rank, output_collections=None, scope=None): return utils.collect_named_outputs(output_collections, sc, flattened) -def _model_variable_getter(getter, - name, - shape=None, - dtype=None, - initializer=None, - regularizer=None, - trainable=True, - collections=None, - caching_device=None, - partitioner=None, - rename=None, - use_resource=None, - **_): +def _model_variable_getter( + getter, + name, + shape=None, + dtype=None, + initializer=None, + regularizer=None, + trainable=True, + collections=None, + caching_device=None, + partitioner=None, + rename=None, + use_resource=None, + synchronization=tf_variables.VariableSynchronization.AUTO, + aggregation=tf_variables.VariableAggregation.NONE, + **_): """Getter that uses model_variable for compatibility with core layers.""" short_name = name.split('/')[-1] if rename and short_name in rename: @@ -1732,7 +1735,9 @@ def _model_variable_getter(getter, caching_device=caching_device, partitioner=partitioner, custom_getter=getter, - use_resource=use_resource) + use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation) def _build_variable_getter(rename=None): @@ -2655,7 +2660,7 @@ def separable_convolution2d( inputs, num_outputs, kernel_size, - depth_multiplier, + depth_multiplier=1, stride=1, padding='SAME', data_format=DATA_FORMAT_NHWC, @@ -3315,6 +3320,7 @@ relu6 = functools.partial(fully_connected, activation_fn=nn.relu6) linear = functools.partial(fully_connected, activation_fn=None) # Simple alias. +conv1d = convolution1d conv2d = convolution2d conv3d = convolution3d conv2d_transpose = convolution2d_transpose diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py index 0e35b1aa8bf682c1b4f7e8d974d3e8fad69e33cb..dad3da3748097c26e07b4abe0495f62a18aad369 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py @@ -514,15 +514,15 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): original_vars = set(tape.watched_variables()) # Backward pass - def grad_fn(*output_grads, **kwargs): + def _grad_fn(output_grads, variables=None): """Recompute outputs for gradient computation.""" - variables = [] + variables = variables or [] if original_vars: - variables = kwargs["variables"] - if set(variables) != original_vars: - raise ValueError(_WRONG_VARS_ERR) - del kwargs - inputs = list(args) + assert variables, ("Fn created variables but the variables were not " + "passed to the gradient fn.") + if set(variables) != original_vars: + raise ValueError(_WRONG_VARS_ERR) + inputs = [array_ops.identity(x) for x in list(args)] # Recompute outputs with framework_ops.control_dependencies(output_grads): if use_data_dep_: @@ -538,7 +538,7 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): if original_vars != recompute_vars: raise ValueError(_WRONG_VARS_ERR) - if not (isinstance(outputs, list) or isinstance(outputs, tuple)): + if not isinstance(outputs, (list, tuple)): outputs = [outputs] outputs = list(outputs) grads = gradients_impl.gradients(outputs, inputs + variables, @@ -554,6 +554,16 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): grad_vars = grads[len(inputs):] return grad_inputs, grad_vars + # custom_gradient inspects the signature of the function to determine + # whether the user expects variables passed in the grad_fn. If the function + # created variables, the grad_fn should accept the "variables" kwarg. + if original_vars: + def grad_fn(*output_grads, **kwargs): + return _grad_fn(output_grads, kwargs["variables"]) + else: + def grad_fn(*output_grads): + return _grad_fn(output_grads) + return outputs, grad_fn return fn_with_recompute(*args) diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py index bc09ba8d439808c1582f207a99504012afcf33a6..d5971fb9d8e2fbc1e14fd24fc79e7981a284a418 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py @@ -372,6 +372,26 @@ class RecomputeTest(test.TestCase): self.assertEqual(2, len(update_ops)) self.assertEqual([False, True], kwarg_values) + def testWithoutVariables(self): + + def concat_n(layer_list, num_inputs): + return math_ops.reduce_sum( + array_ops.concat([x for x in layer_list[-num_inputs:]], axis=-1), + axis=1, keepdims=True) + + @rev_block_lib.recompute_grad + def concat_n_wrap(*args): + return concat_n(args, 3) + + # DenseNet-style layers + layer_list = [random_ops.random_uniform((4, 8))] + for _ in range(5): + layer_list.append(math_ops.sqrt(concat_n_wrap(*layer_list))) + + grads = gradients_impl.gradients(layer_list[-1], layer_list[0]) + with self.test_session() as sess: + sess.run(grads) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 7a026a15e4aeea0dde4ed9f7de053a757a0abb58..c1de42782efb3497660affb3ef7162457977c150 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -72,6 +72,7 @@ from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary import summary as core_summary from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import device_setter from tensorflow.python.training import monitored_session from tensorflow.python.training import saver @@ -891,7 +892,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, # Check that model has been trained (if nothing has been set explicitly). if not checkpoint_path: - latest_path = saver.latest_checkpoint(self._model_dir) + latest_path = checkpoint_management.latest_checkpoint(self._model_dir) if not latest_path: raise NotFittedError( "Couldn't find trained model at %s." % self._model_dir) @@ -956,7 +957,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, as_iterable=True, iterate_batches=False): # Check that model has been trained. - checkpoint_path = saver.latest_checkpoint(self._model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir) if not checkpoint_path: raise NotFittedError( "Couldn't find trained model at %s." % self._model_dir) @@ -1364,7 +1365,7 @@ class Estimator(BaseEstimator): if not checkpoint_path: # Locate the latest checkpoint - checkpoint_path = saver.latest_checkpoint(self._model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir) if not checkpoint_path: raise NotFittedError( "Couldn't find trained model at %s." % self._model_dir) diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 339c4e0e360ed9ef9906f0e51b64a0dc13826259..ded93d4a7fb473c0c5df446ea89c5ab7784e9f3c 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -563,10 +563,10 @@ def _mean_squared_loss(labels, logits, weights=None): labels = ops.convert_to_tensor(labels) # To prevent broadcasting inside "-". if len(labels.get_shape()) == 1: - labels = array_ops.expand_dims(labels, dim=(1,)) + labels = array_ops.expand_dims(labels, axis=(1,)) # TODO(zakaria): make sure it does not recreate the broadcast bug. if len(logits.get_shape()) == 1: - logits = array_ops.expand_dims(logits, dim=(1,)) + logits = array_ops.expand_dims(logits, axis=(1,)) logits.get_shape().assert_is_compatible_with(labels.get_shape()) loss = math_ops.square(logits - math_ops.to_float(labels), name=name) return _compute_weighted_loss(loss, weights) @@ -579,10 +579,10 @@ def _poisson_loss(labels, logits, weights=None): labels = ops.convert_to_tensor(labels) # To prevent broadcasting inside "-". if len(labels.get_shape()) == 1: - labels = array_ops.expand_dims(labels, dim=(1,)) + labels = array_ops.expand_dims(labels, axis=(1,)) # TODO(zakaria): make sure it does not recreate the broadcast bug. if len(logits.get_shape()) == 1: - logits = array_ops.expand_dims(logits, dim=(1,)) + logits = array_ops.expand_dims(logits, axis=(1,)) logits.get_shape().assert_is_compatible_with(labels.get_shape()) loss = nn.log_poisson_loss(labels, logits, compute_full_loss=True, name=name) @@ -797,7 +797,7 @@ def _log_loss_with_two_classes(labels, logits, weights=None): # TODO(ptucker): This will break for dynamic shapes. # sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels. if len(labels.get_shape()) == 1: - labels = array_ops.expand_dims(labels, dim=(1,)) + labels = array_ops.expand_dims(labels, axis=(1,)) loss = nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits, name=name) return _compute_weighted_loss(loss, weights) diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py index 14ee2ba6094760d52180d6de7763ea88b8ee98c8..c36879e0483c92db0cc08dedbb483bcc288d4894 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py @@ -240,6 +240,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): keep_checkpoint_max=5, keep_checkpoint_every_n_hours=10000, log_step_count_steps=100, + protocol=None, evaluation_master='', model_dir=None, session_config=None): @@ -289,6 +290,8 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): session_config: a ConfigProto used to set session parameters, or None. Note - using this argument, it is easy to provide settings which break otherwise perfectly good models. Use with care. + protocol: An optional argument which specifies the protocol used when + starting server. None means default to grpc. """ # Neither parent class calls super().__init__(), so here we have to # manually call their __init__() methods. @@ -299,6 +302,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): # so instead of breaking compatibility with that assumption, we # just manually initialize this field: self._train_distribute = None + self._eval_distribute = None self._device_fn = None gpu_options = config_pb2.GPUOptions( @@ -313,6 +317,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): self._save_summary_steps = save_summary_steps self._save_checkpoints_secs = save_checkpoints_secs self._log_step_count_steps = log_step_count_steps + self._protocol = protocol self._session_config = session_config if save_checkpoints_secs == RunConfig._USE_DEFAULT: if save_checkpoints_steps is None: diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index f8a3709ee57a32734afa7ac8133271c75d152b2c..08e907a608b0c6df6e7ac9d9675f7f9e2b84ff5d 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -41,7 +41,7 @@ from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import basic_session_run_hooks -from tensorflow.python.training import saver +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import server_lib from tensorflow.python.util import compat from tensorflow.python.util import function_utils @@ -95,7 +95,7 @@ class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener): # Load and cache the path of the most recent checkpoint to avoid duplicate # searches on GCS. logging.info("Checking for checkpoint in %s", self._model_dir) - latest_path = saver.latest_checkpoint(self._model_dir) + latest_path = checkpoint_management.latest_checkpoint(self._model_dir) if not latest_path: logging.warning("Skipping evaluation and export since model has not been " @@ -516,7 +516,8 @@ class Experiment(object): start = time.time() error_msg = None - latest_path = saver.latest_checkpoint(self._estimator.model_dir) + latest_path = checkpoint_management.latest_checkpoint( + self._estimator.model_dir) if not latest_path: error_msg = ("Estimator is not fitted yet. " "Will start an evaluation when a checkpoint is ready.") @@ -778,7 +779,8 @@ class Experiment(object): saving_listeners=self._saving_listeners) logging.info("Evaluating model now.") - latest_checkpoint = saver.latest_checkpoint(self._estimator.model_dir) + latest_checkpoint = checkpoint_management.latest_checkpoint( + self._estimator.model_dir) eval_result = self._call_evaluate( input_fn=self._eval_input_fn, steps=self._eval_steps, diff --git a/tensorflow/contrib/learn/python/learn/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/graph_actions_test.py index 0d039d593b7850ead34484f88426255dc659b7fc..df156da3f467538ed1c6b640d651fdfd33ce243d 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions_test.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.summary import summary +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as saver_lib @@ -124,7 +125,7 @@ class GraphActionsTest(test.TestCase): # TODO(ptucker): Test number and contents of checkpoint files. def _assert_ckpt(self, output_dir, expected=True): - ckpt_state = saver_lib.get_checkpoint_state(output_dir) + ckpt_state = checkpoint_management.get_checkpoint_state(output_dir) if expected: pattern = '%s/model.ckpt-.*' % output_dir primary_ckpt_path = ckpt_state.model_checkpoint_path @@ -434,7 +435,7 @@ class GraphActionsTrainTest(test.TestCase): # TODO(ptucker): Test number and contents of checkpoint files. def _assert_ckpt(self, output_dir, expected=True): - ckpt_state = saver_lib.get_checkpoint_state(output_dir) + ckpt_state = checkpoint_management.get_checkpoint_state(output_dir) if expected: pattern = '%s/model.ckpt-.*' % output_dir primary_ckpt_path = ckpt_state.model_checkpoint_path diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py index 77f7c73d5412d40b338eaff4cf04d99fd0892723..3d691d434044aab1e3e86457cee6aadb5bf798c7 100644 --- a/tensorflow/contrib/learn/python/learn/monitors.py +++ b/tensorflow/contrib/learn/python/learn/monitors.py @@ -51,7 +51,7 @@ from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary as core_summary -from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util from tensorflow.python.util import deprecation @@ -735,7 +735,8 @@ class ValidationMonitor(EveryN): return False self._last_checkpoint_check_time = current_time # Check that we are not running evaluation on the same checkpoint. - latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir) + latest_path = checkpoint_management.latest_checkpoint( + self._estimator.model_dir) if latest_path is None: logging.debug("Skipping evaluation since model has not been saved yet " "at step %d.", step) @@ -1059,7 +1060,8 @@ class ExportMonitor(EveryN): def end(self, session=None): super(ExportMonitor, self).end(session=session) - latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir) + latest_path = checkpoint_management.latest_checkpoint( + self._estimator.model_dir) if latest_path is None: logging.info("Skipping export at the end since model has not been saved " "yet.") diff --git a/tensorflow/contrib/learn/python/learn/monitors_test.py b/tensorflow/contrib/learn/python/learn/monitors_test.py index 5c34d0ddb01f3bcdc407e6926e7c5b73be1863b4..ff1da32c218b4e105b5503426ac01410665f9c7e 100644 --- a/tensorflow/contrib/learn/python/learn/monitors_test.py +++ b/tensorflow/contrib/learn/python/learn/monitors_test.py @@ -39,9 +39,9 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import gradient_descent from tensorflow.python.training import monitored_session -from tensorflow.python.training import saver from tensorflow.python.training import training_util @@ -317,7 +317,7 @@ class MonitorsTest(test.TestCase): self._run_monitor(monitor) @test.mock.patch.object(estimators, 'Estimator', autospec=True) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_no_ckpt(self, mock_latest_checkpoint, mock_estimator_class): estimator = mock_estimator_class() @@ -336,7 +336,7 @@ class MonitorsTest(test.TestCase): mock_latest_checkpoint.assert_called_with(model_dir) @test.mock.patch.object(estimators, 'Estimator', autospec=True) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_no_early_stopping_rounds(self, mock_latest_checkpoint, mock_estimator_class): @@ -356,7 +356,7 @@ class MonitorsTest(test.TestCase): self._assert_validation_monitor(monitor) @test.mock.patch.object(estimators, 'Estimator', autospec=True) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_invalid_metric(self, mock_latest_checkpoint, mock_estimator_class): estimator = mock_estimator_class() @@ -375,7 +375,7 @@ class MonitorsTest(test.TestCase): self._run_monitor(monitor, num_epochs=1, num_steps_per_epoch=1) @test.mock.patch.object(estimators, 'Estimator', autospec=True) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor(self, mock_latest_checkpoint, mock_estimator_class): estimator = mock_estimator_class() @@ -464,7 +464,7 @@ class MonitorsTest(test.TestCase): monitor.epoch_end(epoch=0) monitor.end() - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_with_core_estimator(self, mock_latest_checkpoint): estimator = test.mock.Mock(spec=core_estimator.Estimator) model_dir = 'model/dir' @@ -495,7 +495,7 @@ class MonitorsTest(test.TestCase): expected_best_metrics={'loss': 42.0, 'auc': 0.5}) monitor.post_step(step=step, session=None) - @test.mock.patch.object(saver, 'latest_checkpoint') + @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') def test_validation_monitor_fail_with_core_estimator_and_metrics( self, mock_latest_checkpoint): estimator = test.mock.Mock(spec=core_estimator.Estimator) diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py index 3eacac7a3d3dcff4d39025fdee88e16e385b1b84..0144b93814a174cfb8c3162f407a595ac637f4f5 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export.py +++ b/tensorflow/contrib/learn/python/learn/utils/export.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as tf_saver from tensorflow.python.training import training_util @@ -298,7 +299,8 @@ def _export_estimator(estimator, # If checkpoint_path is specified, use the specified checkpoint path. checkpoint_path = (checkpoint_path or - tf_saver.latest_checkpoint(estimator._model_dir)) + checkpoint_management.latest_checkpoint( + estimator._model_dir)) with ops.Graph().as_default() as g: training_util.create_global_step(g) diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index f8106d1e4a7e79f1cd651c40995be480721a8129..66af6833da1644fa4f73a24987079c9ffc8cecce 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -55,7 +55,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.summary import summary_iterator -from tensorflow.python.training import saver +from tensorflow.python.training import checkpoint_management from tensorflow.python.util import compat from tensorflow.python.util.deprecation import deprecated @@ -714,7 +714,8 @@ def make_best_model_export_strategy( # as soon as contrib is cleaned up and we can thus be sure that # estimator is a tf.estimator.Estimator and not a # tf.contrib.learn.Estimator - checkpoint_path = saver.latest_checkpoint(estimator.model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint( + estimator.model_dir) export_checkpoint_path, export_eval_result = best_model_selector.update( checkpoint_path, eval_result) diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py index 5e7b422e3cc368a22eb94ed470297ae78293c4eb..e74244720896a835174f54bb97049c1d9b1c92f8 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py +++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py @@ -625,11 +625,13 @@ def attention_decoder(decoder_inputs, v = [] attention_vec_size = attn_size # Size of query vectors for attention. for a in xrange(num_heads): - k = variable_scope.get_variable("AttnW_%d" % a, - [1, 1, attn_size, attention_vec_size]) + k = variable_scope.get_variable( + "AttnW_%d" % a, [1, 1, attn_size, attention_vec_size], + dtype=dtype) hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME")) v.append( - variable_scope.get_variable("AttnV_%d" % a, [attention_vec_size])) + variable_scope.get_variable( + "AttnV_%d" % a, [attention_vec_size], dtype=dtype)) state = initial_state @@ -647,11 +649,13 @@ def attention_decoder(decoder_inputs, with variable_scope.variable_scope("Attention_%d" % a): y = Linear(query, attention_vec_size, True)(query) y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size]) + y = math_ops.cast(y, dtype) # Attention mask is a softmax of v^T * tanh(...). s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y), [2, 3]) - a = nn_ops.softmax(s) + a = nn_ops.softmax(math_ops.cast(s, dtype=dtypes.float32)) # Now calculate the attention-weighted vector d. + a = math_ops.cast(a, dtype) d = math_ops.reduce_sum( array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2]) ds.append(array_ops.reshape(d, [-1, attn_size])) @@ -681,6 +685,7 @@ def attention_decoder(decoder_inputs, raise ValueError("Could not infer input size from input: %s" % inp.name) inputs = [inp] + attns + inputs = [math_ops.cast(e, dtype) for e in inputs] x = Linear(inputs, input_size, True)(inputs) # Run the RNN. cell_output, state = cell(x, state) @@ -693,6 +698,7 @@ def attention_decoder(decoder_inputs, attns = attention(state) with variable_scope.variable_scope("AttnOutputProjection"): + cell_output = math_ops.cast(cell_output, dtype) inputs = [cell_output] + attns output = Linear(inputs, output_size, True)(inputs) if loop_function is not None: diff --git a/tensorflow/contrib/linear_optimizer/BUILD b/tensorflow/contrib/linear_optimizer/BUILD index 5b89c6cef9fa9fdef7c26ddee1efa03f3056d881..7534b50a4ae0076fb27fb9cd0d1dd58b29192876 100644 --- a/tensorflow/contrib/linear_optimizer/BUILD +++ b/tensorflow/contrib/linear_optimizer/BUILD @@ -41,6 +41,10 @@ py_test( size = "medium", srcs = ["python/kernel_tests/sdca_ops_test.py"], srcs_version = "PY2AND3", + tags = [ + "no_gpu", + "no_pip_gpu", + ], deps = [ ":sdca_ops_py", ":sparse_feature_column_py", diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 8c17c65fcc0dbd58e2b3e9042a983e400cd6c2b9..1e6f1e7da212c3aeb1563dc2f4b6dff2cb550736 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -47,6 +47,10 @@ cc_test( name = "arena_planner_test", size = "small", srcs = ["arena_planner_test.cc"], + tags = [ + "no_oss", + "tflite_not_portable", + ], deps = [ ":arena_planner", "//tensorflow/contrib/lite/testing:util", @@ -121,13 +125,26 @@ cc_library( "graph_info.cc", "interpreter.cc", "model.cc", - "nnapi_delegate.cc", "op_resolver.cc", "optional_debug_tools.cc", - ], + ] + select({ + "//tensorflow:android": [ + "nnapi_delegate.cc", + "mmap_allocation.cc", + ], + "//tensorflow:windows": [ + "nnapi_delegate_disabled.cc", + "mmap_allocation_disabled.cc", + ], + "//conditions:default": [ + "nnapi_delegate_disabled.cc", + "mmap_allocation.cc", + ], + }), hdrs = [ "allocation.h", "context.h", + "context_util.h", "error_reporter.h", "graph_info.h", "interpreter.h", @@ -145,6 +162,7 @@ cc_library( ":memory_planner", ":schema_fbs_version", ":simple_memory_arena", + ":string", ":util", "//tensorflow/contrib/lite/kernels:eigen_support", "//tensorflow/contrib/lite/kernels:gemm_support", @@ -198,6 +216,7 @@ cc_test( name = "graph_info_test", size = "small", srcs = ["graph_info_test.cc"], + tags = ["no_oss"], deps = [ ":framework", ":string_util", @@ -242,6 +261,7 @@ cc_test( name = "op_resolver_test", size = "small", srcs = ["op_resolver_test.cc"], + tags = ["no_oss"], deps = [ ":framework", "//tensorflow/contrib/lite/testing:util", @@ -274,6 +294,7 @@ cc_test( name = "util_test", size = "small", srcs = ["util_test.cc"], + tags = ["no_oss"], deps = [ ":context", ":util", diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile index 2b6997146e1e5a3873ed0f94a9221b34bed7621d..9cc8f10b4290030898cffa8a8cac6ba395a30e2e 100644 --- a/tensorflow/contrib/lite/Makefile +++ b/tensorflow/contrib/lite/Makefile @@ -17,7 +17,29 @@ else endif endif -ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi) +HOST_ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi) + +# Self-hosting +TARGET_ARCH := ${HOST_ARCH} + +# Cross compiling +ifeq ($(CROSS),rpi) + TARGET_ARCH := armv7l + TARGET_TOOLCHAIN_PREFIX := arm-linux-gnueabihf- +endif + +ifeq ($(CROSS),riscv) + TARGET_ARCH := riscv + TARGET_TOOLCHAIN_PREFIX := riscv32-unknown-elf- +endif +ifeq ($(CROSS),stm32f7) + TARGET_ARCH := armf7 + TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- +endif +ifeq ($(CROSS),stm32f1) + TARGET_ARCH := armm1 + TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- +endif # Where compiled objects are stored. OBJDIR := $(MAKEFILE_DIR)/gen/obj/ @@ -25,11 +47,47 @@ BINDIR := $(MAKEFILE_DIR)/gen/bin/ LIBDIR := $(MAKEFILE_DIR)/gen/lib/ GENDIR := $(MAKEFILE_DIR)/gen/obj/ +LIBS := +ifeq ($(TARGET_ARCH),x86_64) + CXXFLAGS += -fPIC -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK -pthread # -msse4.2 +endif + +ifeq ($(TARGET_ARCH),armv7l) + CXXFLAGS += -mfpu=neon -pthread -fPIC + LIBS += -ldl +endif + +ifeq ($(TARGET_ARCH),riscv) +# CXXFLAGS += -march=gap8 + CXXFLAGS += -DTFLITE_MCU + LIBS += -ldl + BUILD_TYPE := micro +endif + +ifeq ($(TARGET_ARCH),armf7) + CXXFLAGS += -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK -DTFLITE_MCU + CXXFLAGS += -fno-rtti -fmessage-length=0 -fno-exceptions -fno-builtin -ffunction-sections -fdata-sections + CXXFLAGS += -funsigned-char -MMD + CXXFLAGS += -mcpu=cortex-m7 -mthumb -mfpu=fpv5-sp-d16 -mfloat-abi=softfp + CXXFLAGS += '-std=gnu++11' '-fno-rtti' '-Wvla' '-c' '-Wall' '-Wextra' '-Wno-unused-parameter' '-Wno-missing-field-initializers' '-fmessage-length=0' '-fno-exceptions' '-fno-builtin' '-ffunction-sections' '-fdata-sections' '-funsigned-char' '-MMD' '-fno-delete-null-pointer-checks' '-fomit-frame-pointer' '-Os' + LIBS += -ldl + BUILD_TYPE := micro +endif +ifeq ($(TARGET_ARCH),armm1) + CXXFLAGS += -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK -mcpu=cortex-m1 -mthumb -DTFLITE_MCU + CXXFLAGS += -fno-rtti -fmessage-length=0 -fno-exceptions -fno-builtin -ffunction-sections -fdata-sections + CXXFLAGS += -funsigned-char -MMD + LIBS += -ldl +endif + # Settings for the host compiler. -CXX := $(CC_PREFIX)gcc -CXXFLAGS := --std=c++11 -O3 -DNDEBUG -CC := $(CC_PREFIX)gcc -CCFLAGS := -O3 -DNDEBUG +CXX := $(CC_PREFIX) ${TARGET_TOOLCHAIN_PREFIX}g++ +CXXFLAGS += -O3 -DNDEBUG +CCFLAGS := ${CXXFLAGS} +CXXFLAGS += --std=c++11 +CC := $(CC_PREFIX) ${TARGET_TOOLCHAIN_PREFIX}gcc +AR := $(CC_PREFIX) ${TARGET_TOOLCHAIN_PREFIX}ar +CFLAGS := LDOPTS := LDOPTS += -L/usr/local/lib ARFLAGS := -r @@ -37,6 +95,7 @@ ARFLAGS := -r INCLUDES := \ -I. \ -I$(MAKEFILE_DIR)/../../../ \ +-I$(MAKEFILE_DIR)/../../../../ \ -I$(MAKEFILE_DIR)/downloads/ \ -I$(MAKEFILE_DIR)/downloads/eigen \ -I$(MAKEFILE_DIR)/downloads/gemmlowp \ @@ -48,7 +107,7 @@ INCLUDES := \ # override local versions in the source tree. INCLUDES += -I/usr/local/include -LIBS := \ +LIBS += \ -lstdc++ \ -lpthread \ -lm \ @@ -92,18 +151,21 @@ PROFILE_SUMMARIZER_SRCS := \ CORE_CC_ALL_SRCS := \ $(wildcard tensorflow/contrib/lite/*.cc) \ +$(wildcard tensorflow/contrib/lite/*.c) +ifneq ($(BUILD_TYPE),micro) +CORE_CC_ALL_SRCS += \ $(wildcard tensorflow/contrib/lite/kernels/*.cc) \ $(wildcard tensorflow/contrib/lite/kernels/internal/*.cc) \ $(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.cc) \ $(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.cc) \ $(PROFILER_SRCS) \ -$(wildcard tensorflow/contrib/lite/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.c) \ $(wildcard tensorflow/contrib/lite/downloads/farmhash/src/farmhash.cc) \ $(wildcard tensorflow/contrib/lite/downloads/fft2d/fftsg.c) +endif # Remove any duplicates. CORE_CC_ALL_SRCS := $(sort $(CORE_CC_ALL_SRCS)) CORE_CC_EXCLUDE_SRCS := \ @@ -113,6 +175,15 @@ $(wildcard tensorflow/contrib/lite/*/*/*test.cc) \ $(wildcard tensorflow/contrib/lite/*/*/*/*test.cc) \ $(wildcard tensorflow/contrib/lite/kernels/test_util.cc) \ $(MINIMAL_SRCS) +ifeq ($(BUILD_TYPE),micro) +CORE_CC_EXCLUDE_SRCS += \ +tensorflow/contrib/lite/mmap_allocation.cc \ +tensorflow/contrib/lite/nnapi_delegate.cc +else +CORE_CC_EXCLUDE_SRCS += \ +tensorflow/contrib/lite/mmap_allocation_disabled.cc \ +tensorflow/contrib/lite/nnapi_delegate_disabled.cc +endif # Filter out all the excluded files. TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS)) # File names of the intermediate files target compilation generates. @@ -120,7 +191,6 @@ TF_LITE_CC_OBJS := $(addprefix $(OBJDIR), \ $(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS)))) LIB_OBJS := $(TF_LITE_CC_OBJS) - # Benchmark sources BENCHMARK_SRCS_DIR := tensorflow/contrib/lite/tools/benchmark BENCHMARK_ALL_SRCS := $(TFLITE_CC_SRCS) \ @@ -146,8 +216,15 @@ $(OBJDIR)%.o: %.c # The target that's compiled if there's no command-line arguments. all: $(LIB_PATH) $(MINIMAL_PATH) $(BENCHMARK_BINARY) +# The target that's compiled for micro-controllers +micro: $(LIB_PATH) + +# Hack for generating schema file bypassing flatbuffer parsing +tensorflow/contrib/lite/schema/schema_generated.h: + @cp -u tensorflow/contrib/lite/schema/schema_generated.h.OPENSOURCE tensorflow/contrib/lite/schema/schema_generated.h + # Gathers together all the objects we've compiled into a single '.a' archive. -$(LIB_PATH): $(LIB_OBJS) +$(LIB_PATH): tensorflow/contrib/lite/schema/schema_generated.h $(LIB_OBJS) @mkdir -p $(dir $@) $(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS) diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc index a4772731ecda92431c412672610a39c188dabf27..89462618148a2afbcf2ef6b1dd2985bcd0178734 100644 --- a/tensorflow/contrib/lite/allocation.cc +++ b/tensorflow/contrib/lite/allocation.cc @@ -13,56 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include +#include "tensorflow/contrib/lite/allocation.h" + #include #include -#include #include #include #include #include #include -#include "tensorflow/contrib/lite/allocation.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/error_reporter.h" -#include "tensorflow/contrib/lite/nnapi_delegate.h" namespace tflite { -MMAPAllocation::MMAPAllocation(const char* filename, - ErrorReporter* error_reporter) - : Allocation(error_reporter), mmapped_buffer_(MAP_FAILED) { - mmap_fd_ = open(filename, O_RDONLY); - if (mmap_fd_ == -1) { - error_reporter_->Report("Could not open '%s'.", filename); - return; - } - struct stat sb; - fstat(mmap_fd_, &sb); - buffer_size_bytes_ = sb.st_size; - mmapped_buffer_ = - mmap(nullptr, buffer_size_bytes_, PROT_READ, MAP_SHARED, mmap_fd_, 0); - if (mmapped_buffer_ == MAP_FAILED) { - error_reporter_->Report("Mmap of '%s' failed.", filename); - return; - } -} - -MMAPAllocation::~MMAPAllocation() { - if (valid()) { - munmap(const_cast(mmapped_buffer_), buffer_size_bytes_); - } - if (mmap_fd_ != -1) close(mmap_fd_); -} - -const void* MMAPAllocation::base() const { return mmapped_buffer_; } - -size_t MMAPAllocation::bytes() const { return buffer_size_bytes_; } - -bool MMAPAllocation::valid() const { return mmapped_buffer_ != MAP_FAILED; } - +#ifndef TFLITE_MCU FileCopyAllocation::FileCopyAllocation(const char* filename, ErrorReporter* error_reporter) : Allocation(error_reporter) { @@ -94,7 +60,9 @@ FileCopyAllocation::FileCopyAllocation(const char* filename, filename); return; } - copied_buffer_ = std::move(buffer); + // Versions of GCC before 6.2.0 don't support std::move from non-const + // char[] to const char[] unique_ptrs. + copied_buffer_.reset(const_cast(buffer.release())); } FileCopyAllocation::~FileCopyAllocation() {} @@ -104,6 +72,7 @@ const void* FileCopyAllocation::base() const { return copied_buffer_.get(); } size_t FileCopyAllocation::bytes() const { return buffer_size_bytes_; } bool FileCopyAllocation::valid() const { return copied_buffer_ != nullptr; } +#endif MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes, ErrorReporter* error_reporter) diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/contrib/lite/allocation.h index 68aee2e64473320c461ec8b3f194904e7b8da43c..121f3d264687933f45f3a2c5d2a53ad80d594ca9 100644 --- a/tensorflow/contrib/lite/allocation.h +++ b/tensorflow/contrib/lite/allocation.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/simple_memory_arena.h" +#include "tensorflow/contrib/lite/string.h" namespace tflite { @@ -51,6 +52,8 @@ class MMAPAllocation : public Allocation { size_t bytes() const override; bool valid() const override; + static bool IsSupported(); + protected: // Data required for mmap. int mmap_fd_ = -1; // mmap file descriptor diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc index 22be64d6ff649b4bff45a5e5680984d688a8cf38..02442575b3aeed04ac6569440dd52a4d5ddd4d98 100644 --- a/tensorflow/contrib/lite/arena_planner.cc +++ b/tensorflow/contrib/lite/arena_planner.cc @@ -17,14 +17,6 @@ limitations under the License. namespace tflite { -namespace { - -// Memory allocation tuning -constexpr const int kDefaultArenaAlignment = 64; -constexpr const int kDefaultTensorAlignment = 4; - -} // namespace - struct AllocationInfo { // The node index requesting this allocation. int node; @@ -35,11 +27,16 @@ struct AllocationInfo { }; ArenaPlanner::ArenaPlanner(TfLiteContext* context, - std::unique_ptr graph_info) + std::unique_ptr graph_info, + bool preserve_inputs, bool preserve_intermediates, + int tensor_alignment) : context_(context), graph_info_(std::move(graph_info)), arena_(kDefaultArenaAlignment), - persistent_arena_(kDefaultArenaAlignment) {} + persistent_arena_(kDefaultArenaAlignment), + preserve_inputs_(preserve_inputs), + preserve_intermediates_(preserve_intermediates), + tensor_alignment_(tensor_alignment) {} ArenaPlanner::~ArenaPlanner() {} @@ -112,9 +109,13 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { refcounts[tensor_index]++; } - // Queue all graph inputs for allocation. + // Queue all graph inputs for allocation. If preserve_inputs_ is true, make + // sure they never be overwritten. for (int tensor_index : graph_info_->inputs()) { if (tensor_index != kOptionalTensor) { + if (preserve_inputs_) { + refcounts[tensor_index]++; + } TF_LITE_ENSURE_STATUS(allocate(0, tensor_index)); } } @@ -159,13 +160,15 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { // Then update the ref-counts of the node's inputs, and if necessary queue // them for deallocation. - TfLiteIntArray* node_inputs = node.inputs; - for (int j = 0; j < node_inputs->size; ++j) { - int tensor_index = node_inputs->data[j]; - if (tensor_index != kOptionalTensor) { - refcounts[tensor_index]--; - if (refcounts[tensor_index] == 0) { - TF_LITE_ENSURE_STATUS(deallocate(i, tensor_index)); + if (!preserve_intermediates_) { + TfLiteIntArray* node_inputs = node.inputs; + for (int j = 0; j < node_inputs->size; ++j) { + int tensor_index = node_inputs->data[j]; + if (tensor_index != kOptionalTensor) { + refcounts[tensor_index]--; + if (refcounts[tensor_index] == 0) { + TF_LITE_ENSURE_STATUS(deallocate(i, tensor_index)); + } } } } @@ -256,14 +259,12 @@ TfLiteStatus ArenaPlanner::ResolveTensorAllocation(int tensor_index) { TfLiteStatus ArenaPlanner::CalculateTensorAllocation(int tensor_index) { TfLiteTensor& tensor = *graph_info_->tensor(tensor_index); if (tensor.allocation_type == kTfLiteArenaRw) { - TF_LITE_ENSURE_STATUS(arena_.Allocate(context_, kDefaultTensorAlignment, - tensor.bytes, - &allocs_[tensor_index])); + TF_LITE_ENSURE_STATUS(arena_.Allocate( + context_, tensor_alignment_, tensor.bytes, &allocs_[tensor_index])); } if (tensor.allocation_type == kTfLiteArenaRwPersistent) { - TF_LITE_ENSURE_STATUS( - persistent_arena_.Allocate(context_, kDefaultTensorAlignment, - tensor.bytes, &allocs_[tensor_index])); + TF_LITE_ENSURE_STATUS(persistent_arena_.Allocate( + context_, tensor_alignment_, tensor.bytes, &allocs_[tensor_index])); } return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h index e9d0fbc5a9b5aec06e28da8757466b25f40da2f5..55003cf4e92d9ca79416c0f9f7a0c57e828af4ee 100644 --- a/tensorflow/contrib/lite/arena_planner.h +++ b/tensorflow/contrib/lite/arena_planner.h @@ -25,6 +25,10 @@ limitations under the License. namespace tflite { +// Memory allocation tuning +constexpr const int kDefaultArenaAlignment = 64; +constexpr const int kDefaultTensorAlignment = 64; + struct AllocationInfo; // A memory planner that makes all the allocations using arenas. @@ -43,8 +47,12 @@ struct AllocationInfo; class ArenaPlanner : public MemoryPlanner { public: // Ownership of 'context' is not taken and it must remain util the - // ArenaPlanner is destroyed. - ArenaPlanner(TfLiteContext* context, std::unique_ptr graph_info); + // ArenaPlanner is destroyed. If 'preserve_inputs' is true the inputs to the + // graph will not share memory with any other tensor, effectively preserving + // them until the end of inference. + ArenaPlanner(TfLiteContext* context, std::unique_ptr graph_info, + bool preserve_inputs, bool preserve_intermediates, + int tensor_alignment = kDefaultTensorAlignment); ~ArenaPlanner() override; ArenaPlanner(const ArenaPlanner&) = delete; ArenaPlanner& operator=(const ArenaPlanner&) = delete; @@ -100,6 +108,18 @@ class ArenaPlanner : public MemoryPlanner { // Raw memory buffer that is allocated for persistent tensors that are // declared as kTfLiteArenaRwPersistent. SimpleMemoryArena persistent_arena_; + + // Ensure that the memory self-allocated for inputs is never reused by the + // allocator. This allows for example, multiple runs without getting + // unpredictable results. + bool preserve_inputs_; + + // If true, then no overlapping of memory areas is done, meaning intermediates + // results can be queried after running (modulo running delegates). + bool preserve_intermediates_; + + // Number of bytes that tensor buffers should be aligned to. + int tensor_alignment_; }; } // namespace tflite diff --git a/tensorflow/contrib/lite/arena_planner_test.cc b/tensorflow/contrib/lite/arena_planner_test.cc index f0fd35216f645df59b03340e00daca9322721b1b..7d7c41289cad95b73423a7218bf1e0516b2e87a2 100644 --- a/tensorflow/contrib/lite/arena_planner_test.cc +++ b/tensorflow/contrib/lite/arena_planner_test.cc @@ -24,6 +24,8 @@ limitations under the License. namespace tflite { namespace { +constexpr const int kTensorAlignment = 4; + // A simple op to be used in tests, as syntactic sugar. class TestOp { public: @@ -151,11 +153,12 @@ void ReportError(TfLiteContext* context, const char* format, ...) { class ArenaPlannerTest : public ::testing::Test { protected: - void SetGraph(TestGraph* graph) { + void SetGraph(TestGraph* graph, bool preserve_inputs = false) { graph_ = graph; context_.ReportError = ReportError; planner_.reset(new ArenaPlanner( - &context_, std::unique_ptr(new TestGraphInfo(graph)))); + &context_, std::unique_ptr(new TestGraphInfo(graph)), + preserve_inputs, /*preserve intermediates*/ false, kTensorAlignment)); CHECK(planner_->ResetAllocations() == kTfLiteOk); CHECK(planner_->PlanAllocations() == kTfLiteOk); } @@ -177,8 +180,8 @@ class ArenaPlannerTest : public ::testing::Test { const TfLiteTensor& tensor = (*graph_->tensors())[tensor_index]; int64_t offset = GetOffset(tensor_index) + tensor.bytes; // We must make sure the offset is aligned to kDefaultArenaAlignment. - if (offset % 4 != 0) { - offset += 4 - offset % 4; + if (offset % kTensorAlignment != 0) { + offset += kTensorAlignment - offset % kTensorAlignment; } return offset; }; @@ -243,6 +246,30 @@ TEST_F(ArenaPlannerTest, SimpleGraph) { EXPECT_EQ(GetOffset(3), 0); } +TEST_F(ArenaPlannerTest, SimpleGraphInputsPreserved) { + TestGraph graph({0, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2}, {}}, // First op + {{2, 0}, {4, 5}, {}}, // Second op + {{4, 5}, {3}, {}} // Third op + }, + {3}); + SetGraph(&graph, /*preserve_inputs=*/true); + Execute(0, 10); + + // Alloc(+) and dealloc(-) order: +0 +1 +2 +4 +5 -2 +3 -4 -5 + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(4)); + // Because we are keeping the inputs alive until the end (due to + // preserve_inputs=true), the output tensor will not be able to use that + // space. It will end up using the same are as tensor #2. + EXPECT_EQ(GetOffset(3), GetOffsetAfter(1)); +} + TEST_F(ArenaPlannerTest, SimpleGraphWithTemporary) { TestGraph graph({0, 1}, { diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 5543acc1f5dabaa8a54ec4d1f2027bc66a00f6db..81844756bc7239fa798ff96b8b093afdf9ea9557 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -1,193 +1,218 @@ """Generate Flatbuffer binary from json.""" + load( "//tensorflow:tensorflow.bzl", + "tf_cc_shared_object", "tf_cc_test", ) def tflite_copts(): - """Defines compile time flags.""" - copts = [ - "-DFARMHASH_NO_CXX_STRING", - ] + select({ - str(Label("//tensorflow:android_arm64")): [ - "-std=c++11", - "-O3", - ], - str(Label("//tensorflow:android_arm")): [ - "-mfpu=neon", - "-mfloat-abi=softfp", - "-std=c++11", - "-O3", - ], - str(Label("//tensorflow:android_x86")): [ - "-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK", - ], - str(Label("//tensorflow:ios_x86_64")): [ - "-msse4.1", - ], - "//conditions:default": [], - }) + select({ - str(Label("//tensorflow:with_default_optimizations")): [], - "//conditions:default": ["-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK"], - }) + """Defines compile time flags.""" + copts = [ + "-DFARMHASH_NO_CXX_STRING", + ] + select({ + str(Label("//tensorflow:android_arm64")): [ + "-std=c++11", + "-O3", + ], + str(Label("//tensorflow:android_arm")): [ + "-mfpu=neon", + "-mfloat-abi=softfp", + "-std=c++11", + "-O3", + ], + str(Label("//tensorflow:android_x86")): [ + "-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK", + ], + str(Label("//tensorflow:ios_x86_64")): [ + "-msse4.1", + ], + str(Label("//tensorflow:windows")): [ + "/DTF_COMPILE_LIBRARY", + ], + "//conditions:default": [], + }) + select({ + str(Label("//tensorflow:with_default_optimizations")): [], + "//conditions:default": ["-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK"], + }) - return copts + return copts LINKER_SCRIPT = "//tensorflow/contrib/lite/java/src/main/native:version_script.lds" def tflite_linkopts_unstripped(): - """Defines linker flags to reduce size of TFLite binary. + """Defines linker flags to reduce size of TFLite binary. - These are useful when trying to investigate the relative size of the - symbols in TFLite. + These are useful when trying to investigate the relative size of the + symbols in TFLite. - Returns: - a select object with proper linkopts - """ - return select({ - "//tensorflow:android": [ - "-Wl,--no-export-dynamic", # Only inc syms referenced by dynamic obj. - "-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export. - "-Wl,--gc-sections", # Eliminate unused code and data. - "-Wl,--as-needed", # Don't link unused libs. - ], - "//tensorflow/contrib/lite:mips": [], - "//tensorflow/contrib/lite:mips64": [], - "//conditions:default": [ - "-Wl,--icf=all", # Identical code folding. - ], - }) + Returns: + a select object with proper linkopts + """ + return select({ + "//tensorflow:android": [ + "-Wl,--no-export-dynamic", # Only inc syms referenced by dynamic obj. + "-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export. + "-Wl,--gc-sections", # Eliminate unused code and data. + "-Wl,--as-needed", # Don't link unused libs. + ], + "//tensorflow:darwin": [], + "//tensorflow/contrib/lite:mips": [], + "//tensorflow/contrib/lite:mips64": [], + "//conditions:default": [ + "-Wl,--icf=all", # Identical code folding. + ], + }) def tflite_jni_linkopts_unstripped(): - """Defines linker flags to reduce size of TFLite binary with JNI. + """Defines linker flags to reduce size of TFLite binary with JNI. - These are useful when trying to investigate the relative size of the - symbols in TFLite. + These are useful when trying to investigate the relative size of the + symbols in TFLite. - Returns: - a select object with proper linkopts - """ - return select({ - "//tensorflow:android": [ - "-Wl,--gc-sections", # Eliminate unused code and data. - "-Wl,--as-needed", # Don't link unused libs. - ], - "//tensorflow/contrib/lite:mips": [], - "//tensorflow/contrib/lite:mips64": [], - "//conditions:default": [ - "-Wl,--icf=all", # Identical code folding. - ], - }) + Returns: + a select object with proper linkopts + """ + return select({ + "//tensorflow:android": [ + "-Wl,--gc-sections", # Eliminate unused code and data. + "-Wl,--as-needed", # Don't link unused libs. + ], + "//tensorflow:darwin": [], + "//tensorflow/contrib/lite:mips": [], + "//tensorflow/contrib/lite:mips64": [], + "//conditions:default": [ + "-Wl,--icf=all", # Identical code folding. + ], + }) def tflite_linkopts(): - """Defines linker flags to reduce size of TFLite binary.""" - return tflite_linkopts_unstripped() + select({ - "//tensorflow:android": [ - "-s", # Omit symbol table. - ], - "//conditions:default": [], - }) + """Defines linker flags to reduce size of TFLite binary.""" + return tflite_linkopts_unstripped() + select({ + "//tensorflow:android": [ + "-s", # Omit symbol table. + ], + "//conditions:default": [], + }) def tflite_jni_linkopts(): - """Defines linker flags to reduce size of TFLite binary with JNI.""" - return tflite_jni_linkopts_unstripped() + select({ - "//tensorflow:android": [ - "-s", # Omit symbol table. - "-latomic", # Required for some uses of ISO C++11 in x86. - ], - "//conditions:default": [], - }) + """Defines linker flags to reduce size of TFLite binary with JNI.""" + return tflite_jni_linkopts_unstripped() + select({ + "//tensorflow:android": [ + "-s", # Omit symbol table. + "-latomic", # Required for some uses of ISO C++11 in x86. + ], + "//conditions:default": [], + }) + +def tflite_jni_binary( + name, + copts = tflite_copts(), + linkopts = tflite_jni_linkopts(), + linkscript = LINKER_SCRIPT, + linkshared = 1, + linkstatic = 1, + deps = []): + """Builds a jni binary for TFLite.""" + linkopts = linkopts + [ + "-Wl,--version-script", # Export only jni functions & classes. + "$(location {})".format(linkscript), + ] + native.cc_binary( + name = name, + copts = copts, + linkshared = linkshared, + linkstatic = linkstatic, + deps = deps + [linkscript], + linkopts = linkopts, + ) -def tflite_jni_binary(name, - copts=tflite_copts(), - linkopts=tflite_jni_linkopts(), - linkscript=LINKER_SCRIPT, - linkshared=1, - linkstatic=1, - deps=[]): - """Builds a jni binary for TFLite.""" - linkopts = linkopts + [ - "-Wl,--version-script", # Export only jni functions & classes. - "$(location {})".format(linkscript), - ] - native.cc_binary( - name=name, - copts=copts, - linkshared=linkshared, - linkstatic=linkstatic, - deps= deps + [linkscript], - linkopts=linkopts) +def tflite_cc_shared_object( + name, + copts = tflite_copts(), + linkopts = [], + linkstatic = 1, + deps = []): + """Builds a shared object for TFLite.""" + tf_cc_shared_object( + name = name, + copts = copts, + linkstatic = linkstatic, + linkopts = linkopts + tflite_jni_linkopts(), + framework_so = [], + deps = deps, + ) def tf_to_tflite(name, src, options, out): - """Convert a frozen tensorflow graphdef to TF Lite's flatbuffer. + """Convert a frozen tensorflow graphdef to TF Lite's flatbuffer. - Args: - name: Name of rule. - src: name of the input graphdef file. - options: options passed to TOCO. - out: name of the output flatbuffer file. - """ + Args: + name: Name of rule. + src: name of the input graphdef file. + options: options passed to TOCO. + out: name of the output flatbuffer file. + """ - toco_cmdline = " ".join([ - "//tensorflow/contrib/lite/toco:toco", - "--input_format=TENSORFLOW_GRAPHDEF", - "--output_format=TFLITE", - ("--input_file=$(location %s)" % src), - ("--output_file=$(location %s)" % out), - ] + options ) - native.genrule( - name = name, - srcs=[src], - outs=[out], - cmd = toco_cmdline, - tools= ["//tensorflow/contrib/lite/toco:toco"], - ) + toco_cmdline = " ".join([ + "//tensorflow/contrib/lite/toco:toco", + "--input_format=TENSORFLOW_GRAPHDEF", + "--output_format=TFLITE", + ("--input_file=$(location %s)" % src), + ("--output_file=$(location %s)" % out), + ] + options) + native.genrule( + name = name, + srcs = [src], + outs = [out], + cmd = toco_cmdline, + tools = ["//tensorflow/contrib/lite/toco:toco"], + ) def tflite_to_json(name, src, out): - """Convert a TF Lite flatbuffer to JSON. + """Convert a TF Lite flatbuffer to JSON. - Args: - name: Name of rule. - src: name of the input flatbuffer file. - out: name of the output JSON file. - """ + Args: + name: Name of rule. + src: name of the input flatbuffer file. + out: name of the output JSON file. + """ - flatc = "@flatbuffers//:flatc" - schema = "//tensorflow/contrib/lite/schema:schema.fbs" - native.genrule( - name = name, - srcs = [schema, src], - outs = [out], - cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.bin &&" + - "$(location %s) --raw-binary --strict-json -t" + - " -o /tmp $(location %s) -- $${TMP}.bin &&" + - "cp $${TMP}.json $(location %s)") - % (src, flatc, schema, out), - tools = [flatc], - ) + flatc = "@flatbuffers//:flatc" + schema = "//tensorflow/contrib/lite/schema:schema.fbs" + native.genrule( + name = name, + srcs = [schema, src], + outs = [out], + cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.bin &&" + + "$(location %s) --raw-binary --strict-json -t" + + " -o /tmp $(location %s) -- $${TMP}.bin &&" + + "cp $${TMP}.json $(location %s)") % + (src, flatc, schema, out), + tools = [flatc], + ) def json_to_tflite(name, src, out): - """Convert a JSON file to TF Lite's flatbuffer. + """Convert a JSON file to TF Lite's flatbuffer. - Args: - name: Name of rule. - src: name of the input JSON file. - out: name of the output flatbuffer file. - """ + Args: + name: Name of rule. + src: name of the input JSON file. + out: name of the output flatbuffer file. + """ - flatc = "@flatbuffers//:flatc" - schema = "//tensorflow/contrib/lite/schema:schema_fbs" - native.genrule( - name = name, - srcs = [schema, src], - outs = [out], - cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.json &&" + - "$(location %s) --raw-binary --unknown-json --allow-non-utf8 -b" + - " -o /tmp $(location %s) $${TMP}.json &&" + - "cp $${TMP}.bin $(location %s)") - % (src, flatc, schema, out), - tools = [flatc], - ) + flatc = "@flatbuffers//:flatc" + schema = "//tensorflow/contrib/lite/schema:schema_fbs" + native.genrule( + name = name, + srcs = [schema, src], + outs = [out], + cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.json &&" + + "$(location %s) --raw-binary --unknown-json --allow-non-utf8 -b" + + " -o /tmp $(location %s) $${TMP}.json &&" + + "cp $${TMP}.bin $(location %s)") % + (src, flatc, schema, out), + tools = [flatc], + ) # This is the master list of generated examples that will be made into tests. A # function called make_XXX_tests() must also appear in generate_examples.py. @@ -195,7 +220,7 @@ def json_to_tflite(name, src, out): def generated_test_models(): return [ "add", - "arg_max", + "arg_min_max", "avg_pool", "batch_to_space_nd", "concat", @@ -222,6 +247,9 @@ def generated_test_models(): "local_response_norm", "log_softmax", "log", + "logical_and", + "logical_or", + "logical_xor", "lstm", "max_pool", "maximum", @@ -230,10 +258,14 @@ def generated_test_models(): "mul", "neg", "not_equal", + "one_hot", + "pack", "pad", "padv2", - # "prelu", + "prelu", "pow", + "reduce_max", + "reduce_prod", "relu", "relu1", "relu6", @@ -257,63 +289,63 @@ def generated_test_models(): "tile", "topk", "transpose", - "transpose_conv", + #"transpose_conv", # disabled due to b/111213074 "where", ] def gen_zip_test(name, test_name, **kwargs): - """Generate a zipped-example test and its dependent zip files. + """Generate a zipped-example test and its dependent zip files. - Args: - name: Resulting cc_test target name - test_name: Test targets this model. Comes from the list above. - **kwargs: tf_cc_test kwargs. - """ - gen_zipped_test_file( - name = "zip_%s" % test_name, - file = "%s.zip" % test_name, - ) - tf_cc_test(name, **kwargs) + Args: + name: Resulting cc_test target name + test_name: Test targets this model. Comes from the list above. + **kwargs: tf_cc_test kwargs. + """ + gen_zipped_test_file( + name = "zip_%s" % test_name, + file = "%s.zip" % test_name, + ) + tf_cc_test(name, **kwargs) def gen_zipped_test_file(name, file): - """Generate a zip file of tests by using :generate_examples. + """Generate a zip file of tests by using :generate_examples. - Args: - name: Name of output. We will produce "`file`.files" as a target. - file: The name of one of the generated_examples targets, e.g. "transpose" - """ - toco = "//tensorflow/contrib/lite/toco:toco" - native.genrule( - name = file + ".files", - cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco - + " --zip_to_output " + file + " $(@D)"), - outs = [file], - tools = [ - ":generate_examples", - toco, - ], - ) + Args: + name: Name of output. We will produce "`file`.files" as a target. + file: The name of one of the generated_examples targets, e.g. "transpose" + """ + toco = "//tensorflow/contrib/lite/toco:toco" + native.genrule( + name = file + ".files", + cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco + + " --zip_to_output " + file + " $(@D)"), + outs = [file], + tools = [ + ":generate_examples", + toco, + ], + ) - native.filegroup( - name = name, - srcs = [file], - ) + native.filegroup( + name = name, + srcs = [file], + ) def gen_selected_ops(name, model): - """Generate the library that includes only used ops. + """Generate the library that includes only used ops. - Args: - name: Name of the generated library. - model: TFLite model to interpret. - """ - out = name + "_registration.cc" - tool = "//tensorflow/contrib/lite/tools:generate_op_registrations" - tflite_path = "//tensorflow/contrib/lite" - native.genrule( - name = name, - srcs = [model], - outs = [out], - cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s") - % (tool, model, out, tflite_path[2:]), - tools = [tool], - ) + Args: + name: Name of the generated library. + model: TFLite model to interpret. + """ + out = name + "_registration.cc" + tool = "//tensorflow/contrib/lite/tools:generate_op_registrations" + tflite_path = "//tensorflow/contrib/lite" + native.genrule( + name = name, + srcs = [model], + outs = [out], + cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s") % + (tool, model, out, tflite_path[2:]), + tools = [tool], + ) diff --git a/tensorflow/contrib/lite/build_ios_universal_lib.sh b/tensorflow/contrib/lite/build_ios_universal_lib.sh index e9531aef19f04adf719156aa3e874dc5ce6e2b04..31df43a1754bd753a82a613dc15704aaa056a87e 100755 --- a/tensorflow/contrib/lite/build_ios_universal_lib.sh +++ b/tensorflow/contrib/lite/build_ios_universal_lib.sh @@ -21,7 +21,7 @@ cd "$SCRIPT_DIR/../../.." # Build library for supported architectures and packs them in a fat binary. make_library() { - for arch in x86_64 i386 armv7 armv7s arm64 + for arch in x86_64 armv7 armv7s arm64 do make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=${arch} \ -j 8 \ @@ -29,7 +29,6 @@ make_library() { done lipo \ tensorflow/contrib/lite/gen/lib/ios_x86_64/${1} \ - tensorflow/contrib/lite/gen/lib/ios_i386/${1} \ tensorflow/contrib/lite/gen/lib/ios_armv7/${1} \ tensorflow/contrib/lite/gen/lib/ios_armv7s/${1} \ tensorflow/contrib/lite/gen/lib/ios_arm64/${1} \ diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index cda889bf502a535eac4249bbae645359cdb2135d..70178b2faabe85f8a53a94c2b5d2e3ea40c8ba05 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -249,6 +249,10 @@ typedef struct { TfLiteType output_type; } TfLiteArgMaxParams; +typedef struct { + TfLiteType output_type; +} TfLiteArgMinParams; + typedef struct { TfLitePadding padding; int stride_width; @@ -263,6 +267,25 @@ typedef struct { TfLiteType out_type; } TfLiteShapeParams; +typedef struct { + // Parameters supported by version 1: + float min; + float max; + int num_bits; + + // Parameters supported by version 2: + bool narrow_range; +} TfLiteFakeQuantParams; + +typedef struct { + int values_count; + int axis; +} TfLitePackParams; + +typedef struct { + int axis; +} TfLiteOneHotParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index a44e9182302d19acd1e1c183ed388531eec11d93..8a8eb9856886538a1483141ab5f67f54613ea2a1 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -104,6 +104,15 @@ typedef enum { kTfLiteBuiltinRsqrt = 76, kTfLiteBuiltinShape = 77, kTfLiteBuiltinPow = 78, + kTfLiteBuiltinArgMin = 79, + kTfLiteBuiltinFakeQuant = 80, + kTfLiteBuiltinReduceProd = 81, + kTfLiteBuiltinReduceMax = 82, + kTfLiteBuiltinPack = 83, + kTfLiteBuiltinLogicalOr = 84, + kTfLiteBuiltinOneHot = 85, + kTfLiteBuiltinLogicalAnd = 86, + kTfLiteBuiltinLogicalNot = 87, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index 1265c4cba9064cc5aba9af81415f857ad00f6d99..5bc20106d31357e2da3f005baee0f8d134d37be2 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -29,6 +29,9 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ #define TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ +#if defined(_MSC_VER) +#include +#endif #include #include #include @@ -39,6 +42,26 @@ extern "C" { typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus; +// The list of external context types known to TF Lite. This list exists solely +// to avoid conflicts and to ensure ops can share the external contexts they +// need. Access to the external contexts is controled by one of the +// corresponding support files. +typedef enum { + kTfLiteEigenContext = 0, // include eigen_support.h to use. + kTfLiteGemmLowpContext = 1, // include gemm_support.h to use. + kTfLiteMaxExternalContexts = 2 +} TfLiteExternalContextType; + +// An external context is a collection of information unrelated to the TF Lite +// framework, but useful to a subset of the ops. TF Lite knows very little +// about about the actual contexts, but it keeps a list of them, and is able to +// refresh them if configurations like the number of recommended threads +// change. +typedef struct { + TfLiteExternalContextType type; + TfLiteStatus (*Refresh)(struct TfLiteContext* context); +} TfLiteExternalContext; + // Forward declare so GetNode can use this is in Context. typedef struct _TfLiteRegistration TfLiteRegistration; typedef struct _TfLiteDelegate TfLiteDelegate; @@ -160,7 +183,11 @@ typedef union { uint8_t* uint8; bool* b; int16_t* i16; +#if defined(_MSC_VER) + _Fcomplex* c64; +#else _Complex float* c64; +#endif } TfLitePtrUnion; // Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped @@ -245,7 +272,8 @@ void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, const void* allocation, bool is_variable, TfLiteTensor* tensor); -// Resize the allocated data of a (dynamic) tensor. +// Resize the allocated data of a (dynamic) tensor. Tensors with allocation +// types other than kTfLiteDynamic will be ignored. void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor); // A structure representing an instance of a node. @@ -338,10 +366,15 @@ typedef struct TfLiteContext { // eigen. int recommended_num_threads; - // TODO(ahentz): we should create a more general mechanism for this sort of - // library-global objects. - void* gemm_context; - void* eigen_context; + // Access external contexts by type. + // WARNING: This is an experimental interface that is subject to change. + TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*, + TfLiteExternalContextType); + // Set the value of a external context. Does not take ownership of the + // pointer. + // WARNING: This is an experimental interface that is subject to change. + void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType, + TfLiteExternalContext*); } TfLiteContext; typedef struct _TfLiteRegistration { @@ -438,6 +471,12 @@ typedef struct _TfLiteDelegate { } TfLiteDelegate; // WARNING: This is an experimental interface that is subject to change. +// +// Currently, TfLiteDelegateParams has to be allocated in a way that it's +// trivially destructable. It will be stored as `builtin_data` field in +// `TfLiteNode` of the delegate node. +// +// See also the `CreateDelegateParams` function in `interpreter.cc` details. typedef struct { TfLiteDelegate* delegate; TfLiteIntArray* nodes_to_replace; diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..f21540d524e77b16a5a9fe1b66781eb6faeddd39 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/BUILD @@ -0,0 +1,186 @@ +# +# This is a TF Lite delegate that is powered by TensorFlow's Eager. +# +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "buffer_map", + srcs = ["buffer_map.cc"], + hdrs = ["buffer_map.h"], + deps = [ + ":util", + "//tensorflow/c:c_api_internal", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_test( + name = "buffer_map_test", + size = "small", + srcs = ["buffer_map_test.cc"], + tags = [ + "no_oss", + "tflite_not_portable", + ], + deps = [ + ":buffer_map", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:util", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "delegate", + srcs = [ + "delegate.cc", + ], + hdrs = [ + "delegate.h", + ], + deps = [ + ":buffer_map", + ":delegate_data", + ":kernel", + ":util", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/contrib/lite:util", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "delegate_test", + size = "small", + srcs = ["delegate_test.cc"], + tags = [ + "no_oss", + "tflite_not_portable", + ], + deps = [ + ":delegate", + ":test_util", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "delegate_data", + srcs = ["delegate_data.cc"], + hdrs = ["delegate_data.h"], + deps = [ + ":buffer_map", + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + "//tensorflow/core/common_runtime/eager:context", + ], +) + +cc_test( + name = "delegate_data_test", + size = "small", + srcs = ["delegate_data_test.cc"], + tags = [ + "no_oss", + "tflite_not_portable", + ], + deps = [ + ":delegate_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:util", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "kernel", + srcs = ["kernel.cc"], + hdrs = ["kernel.h"], + deps = [ + ":delegate_data", + ":util", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/contrib/lite/kernels:kernel_util", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:execute", + "//tensorflow/core/common_runtime/eager:tensor_handle", + "@flatbuffers", + ], +) + +cc_test( + name = "kernel_test", + size = "small", + srcs = ["kernel_test.cc"], + tags = [ + "no_oss", + "tflite_not_portable", + ], + deps = [ + ":delegate_data", + ":kernel", + ":test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "test_util", + testonly = True, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + "//tensorflow/c:c_api_internal", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_absl//absl/memory", + "@flatbuffers", + ], +) + +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + deps = [ + ":constants", + "//tensorflow/c:c_api_internal", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "util_test", + size = "small", + srcs = ["util_test.cc"], + tags = [ + "no_oss", + "tflite_not_portable", + ], + deps = [ + ":util", + "//tensorflow/contrib/lite/testing:util", + "//tensorflow/core:lib", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "constants", + hdrs = ["constants.h"], +) diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.cc b/tensorflow/contrib/lite/delegates/eager/buffer_map.cc new file mode 100644 index 0000000000000000000000000000000000000000..e5a19c39976969a0b05b28596c6d7d5ebe7c7782 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/buffer_map.cc @@ -0,0 +1,111 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h" + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/contrib/lite/delegates/eager/util.h" +#include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/log_memory.h" + +namespace tflite { +namespace eager { +namespace { +// A tensor buffer that is allocated, deallocated and populated by TF Lite. +class TfLiteTensorBuffer : public tensorflow::TensorBuffer { + public: + explicit TfLiteTensorBuffer(const TfLiteTensor* tensor) { + len_ = tensor->bytes; + // TODO(ahentz): if we can guarantee that TF Lite allocated tensors with + // the same alignment as TensorFlow (EIGEN_MAX_ALIGN_BYTES), then we can + // potentially eliminate the copy below. + data_ = + tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len_); + if (data_ != nullptr) { + if (tensorflow::LogMemory::IsEnabled()) { + tensorflow::LogMemory::RecordRawAllocation( + "TfLiteTensorBuffer_New", + tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, len_, + data_, tensorflow::cpu_allocator()); + } + std::memcpy(data_, tensor->data.raw, tensor->bytes); + } + } + + ~TfLiteTensorBuffer() override { + if (tensorflow::LogMemory::IsEnabled() && data_ != nullptr) { + tensorflow::LogMemory::RecordRawDeallocation( + "TfLiteTensorBuffer_Delete", + tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data_, + tensorflow::cpu_allocator(), false); + } + tensorflow::cpu_allocator()->DeallocateRaw(data_); + } + + void* data() const override { return data_; } + size_t size() const override { return len_; } + + TensorBuffer* root_buffer() override { return this; } + void FillAllocationDescription( + tensorflow::AllocationDescription* proto) const override { + tensorflow::int64 rb = size(); + proto->set_requested_bytes(rb); + proto->set_allocator_name(tensorflow::cpu_allocator()->Name()); + } + + // Prevents input forwarding from mutating this buffer. + bool OwnsMemory() const override { return false; } + + private: + void* data_; + size_t len_; +}; +} // namespace + +BufferMap::BufferMap() {} + +BufferMap::~BufferMap() {} + +bool BufferMap::HasTensor(int tensor_index) const { + return id_to_tensor_.count(tensor_index) != 0; +} + +tensorflow::Tensor BufferMap::GetTensor(int tensor_index) const { + return id_to_tensor_.at(tensor_index); +} + +void BufferMap::SetFromTfLite(int tensor_index, const TfLiteTensor* tensor) { + tensorflow::TensorShape shape; + int num_dims = tensor->dims->size; + for (int i = 0; i < num_dims; ++i) { + shape.AddDim(tensor->dims->data[i]); + } + // TODO(ahentz): we assume this is a new tensor and allocate a new buffer + // for it. This is not always the best approach. For example, this might + // be a reallocation after resizing tensors. In that case we would be + // preferable to somehow reuse the buffer. + auto* buf = new TfLiteTensorBuffer(tensor); + tensorflow::Tensor t = tensorflow::TensorCApi::MakeTensor( + GetTensorFlowDataType(tensor->type), shape, buf); + buf->Unref(); + + SetFromTensorFlow(tensor_index, std::move(t)); +} + +void BufferMap::SetFromTensorFlow(int tensor_index, tensorflow::Tensor tensor) { + id_to_tensor_[tensor_index] = std::move(tensor); +} + +} // namespace eager +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.h b/tensorflow/contrib/lite/delegates/eager/buffer_map.h new file mode 100644 index 0000000000000000000000000000000000000000..a28329ae7d14e3e0214c6602b28b09c43876bbf0 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/buffer_map.h @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_BUFFER_MAP_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_BUFFER_MAP_H_ + +#include + +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tflite { +namespace eager { + +// Maps a TF Lite tensor index into a TensorFlow tensor. +// +// The TF Lite interpreter assigns integer indices to each of its tensors, but +// the Eager delegate deals in terms of TensorFlow tensors. This class maps +// from indices to tensors and allows the creation of new tensors to be +// associated with a given index. +class BufferMap { + public: + BufferMap(); + ~BufferMap(); + + // Returns true if the given 'tensor_index' has a corresponding + // tensorflow::Tensor. + bool HasTensor(int tensor_index) const; + + // Returns the tensorflow::Tensor associated with the given 'tensor_index'. + // Precondition: HasTensor() is true. + tensorflow::Tensor GetTensor(int tensor_index) const; + + // Associates the given tensorflow::Tensor with the given 'tensor_index'. + // Note that tensorflow Tensors share data buffers, so this method is only a + // shallow copy. + void SetFromTensorFlow(int tensor_index, tensorflow::Tensor tensor); + + // Same as above but creates a new tensorflow::Tensor with a copy of the + // given TfLiteTensor's data. + void SetFromTfLite(int tensor_index, const TfLiteTensor* tensor); + + private: + std::map id_to_tensor_; +}; + +} // namespace eager +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_BUFFER_MAP_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc b/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..dcb3f6c94150892f565380ff0598a7a28f9399b1 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc @@ -0,0 +1,174 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h" + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/contrib/lite/util.h" + +namespace tflite { +namespace eager { +namespace { + +using ::testing::ElementsAre; + +// A bit of RAII to simplify handling of TfLiteTensors in the tests. +using UniqueTfLiteTensor = + std::unique_ptr>; + +template +UniqueTfLiteTensor MakeLiteTensor(const std::vector& shape, + const std::vector& data) { + auto tensor = UniqueTfLiteTensor(new TfLiteTensor, [](TfLiteTensor* t) { + TfLiteTensorDataFree(t); + TfLiteIntArrayFree(t->dims); + delete t; + }); + tensor->allocation_type = kTfLiteDynamic; + tensor->type = typeToTfLiteType(); + tensor->dims = ConvertVectorToTfLiteIntArray(shape); + tensor->data.raw = nullptr; + TfLiteTensorRealloc(data.size() * sizeof(T), tensor.get()); + memcpy(tensor->data.raw, data.data(), data.size() * sizeof(T)); + return tensor; +} + +template +tensorflow::Tensor MakeTensor(const std::vector& shape, + const std::vector& data) { + BufferMap buffer_map; // BufferMap is the easiest way to build the tensor. + UniqueTfLiteTensor t1 = MakeLiteTensor(shape, data); + buffer_map.SetFromTfLite(0, t1.get()); + return buffer_map.GetTensor(0); +} + +std::vector GetTensorShape(const tensorflow::Tensor& t) { + std::vector shape(t.dims()); + for (int i = 0; i < t.dims(); ++i) { + shape[i] = t.dim_size(i); + } + return shape; +} + +template +std::vector GetTensorData(const tensorflow::Tensor& t) { + const T* data = t.flat().data(); + return std::vector(data, data + t.NumElements()); +} + +TEST(BufferMapTest, EmptyBuffer) { + BufferMap buffer_map; + EXPECT_FALSE(buffer_map.HasTensor(0)); +} + +TEST(BufferMapTest, SetFromTfLite) { + BufferMap buffer_map; + + UniqueTfLiteTensor t = + MakeLiteTensor({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0}); + buffer_map.SetFromTfLite(0, t.get()); + ASSERT_TRUE(buffer_map.HasTensor(0)); + + EXPECT_THAT(GetTensorData(buffer_map.GetTensor(0)), + ElementsAre(0, 0, 0, 0.123f, 0, 0)); + + // Also check details of the tensor. + tensorflow::Tensor out_tensor = buffer_map.GetTensor(0); + ASSERT_EQ(out_tensor.dtype(), tensorflow::DT_FLOAT); + ASSERT_EQ(out_tensor.NumElements(), 6); + ASSERT_THAT(GetTensorShape(out_tensor), ElementsAre(1, 2, 1, 3)); +} + +TEST(BufferMapTest, SetFromTfLiteTwice) { + UniqueTfLiteTensor t1 = + MakeLiteTensor({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0}); + UniqueTfLiteTensor t2 = + MakeLiteTensor({1, 2, 4}, {0, 0, 0, 3, 0, 0, 1, 2}); + + BufferMap buffer_map; + buffer_map.SetFromTfLite(0, t1.get()); + buffer_map.SetFromTfLite(0, t2.get()); + + EXPECT_THAT(GetTensorData(buffer_map.GetTensor(0)), + ElementsAre(0, 0, 0, 3, 0, 0, 1, 2)); +} + +TEST(BufferMapTest, SetFromTensorFlow) { + tensorflow::Tensor t1 = + MakeTensor({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0}); + + BufferMap buffer_map; + buffer_map.SetFromTensorFlow(0, t1); + + EXPECT_THAT(GetTensorData(buffer_map.GetTensor(0)), + ElementsAre(0, 0, 0, 0.123f, 0, 0)); + + // Also check details of the tensor. + tensorflow::Tensor out_tensor = buffer_map.GetTensor(0); + ASSERT_EQ(out_tensor.dtype(), tensorflow::DT_FLOAT); + ASSERT_EQ(out_tensor.NumElements(), 6); + ASSERT_THAT(GetTensorShape(out_tensor), ElementsAre(1, 2, 1, 3)); +} + +TEST(BufferMapTest, SetFromTensorFlowTwice) { + tensorflow::Tensor t1 = + MakeTensor({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0}); + tensorflow::Tensor t2 = MakeTensor({1, 2, 4}, {0, 0, 0, 3, 0, 0, 1, 2}); + BufferMap buffer_map; + buffer_map.SetFromTensorFlow(0, t1); + buffer_map.SetFromTensorFlow(0, t2); + + EXPECT_THAT(GetTensorData(buffer_map.GetTensor(0)), + ElementsAre(0, 0, 0, 3, 0, 0, 1, 2)); +} + +TEST(BufferMapTest, TfLiteOverwritesTensorFlow) { + tensorflow::Tensor t1 = + MakeTensor({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0}); + UniqueTfLiteTensor t2 = + MakeLiteTensor({1, 2, 4}, {0, 0, 0, 3, 0, 0, 1, 2}); + + BufferMap buffer_map; + buffer_map.SetFromTensorFlow(0, t1); + buffer_map.SetFromTfLite(0, t2.get()); + + EXPECT_THAT(GetTensorData(buffer_map.GetTensor(0)), + ElementsAre(0, 0, 0, 3, 0, 0, 1, 2)); +} + +TEST(BufferMapTest, TensorFlowOverwritesTfLite) { + tensorflow::Tensor t1 = + MakeTensor({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0}); + UniqueTfLiteTensor t2 = + MakeLiteTensor({1, 2, 4}, {0, 0, 0, 3, 0, 0, 1, 2}); + BufferMap buffer_map; + buffer_map.SetFromTfLite(0, t2.get()); + buffer_map.SetFromTensorFlow(0, t1); + + EXPECT_THAT(GetTensorData(buffer_map.GetTensor(0)), + ElementsAre(0, 0, 0, 0.123f, 0, 0)); +} + +} // namespace +} // namespace eager +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc b/tensorflow/contrib/lite/delegates/eager/constants.h similarity index 51% rename from tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc rename to tensorflow/contrib/lite/delegates/eager/constants.h index 0e08a04370592f6e3c92b5811fa7e163f808e03c..7ed6ab7552792c68e6d90056c83c3c574c3f69f7 100644 --- a/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc +++ b/tensorflow/contrib/lite/delegates/eager/constants.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,27 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -#include -#include +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_CONSTANTS_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_CONSTANTS_H_ namespace tflite { +namespace eager { -// Gets the elapsed wall-clock timespec. -timespec getCurrentTime() { - timespec time; - clock_gettime(CLOCK_MONOTONIC, &time); - return time; -} - -// Computes the time diff from two timespecs. Returns '-1' if 'stop' is earlier -// than 'start'. -jlong timespec_diff_nanoseconds(struct timespec* start, struct timespec* stop) { - jlong result = stop->tv_sec - start->tv_sec; - if (result < 0) return -1; - result = 1000000000 * result + (stop->tv_nsec - start->tv_nsec); - if (result < 0) return -1; - return result; -} +// The prefix of Eager op custom code. +// This will be matched agains the `custom_code` field in `OperatorCode` +// Flatbuffer Table. +constexpr char kCustomCodePrefix[] = "Eager"; +} // namespace eager } // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_CONSTANTS_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.cc b/tensorflow/contrib/lite/delegates/eager/delegate.cc new file mode 100644 index 0000000000000000000000000000000000000000..7d22b454199e2c0d9b8fea05086a7c62d7cdbe81 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/delegate.cc @@ -0,0 +1,110 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/delegate.h" + +#include + +#include "tensorflow/contrib/lite/context_util.h" +#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h" +#include "tensorflow/contrib/lite/delegates/eager/kernel.h" +#include "tensorflow/contrib/lite/delegates/eager/util.h" +#include "tensorflow/contrib/lite/util.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tflite { +namespace eager { +namespace delegate { + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) { + // Get the nodes in the current execution plan. Interpreter owns this array. + TfLiteIntArray* plan; + TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); + + // Add all custom ops starting with "Eager" to list of supported nodes. + std::vector supported_nodes; + for (int node_index : TfLiteIntArrayView(plan)) { + TfLiteNode* node; + TfLiteRegistration* registration; + TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( + context, node_index, &node, ®istration)); + + if (IsEagerOp(registration->custom_name)) { + supported_nodes.push_back(node_index); + } + } + + // Request TFLite to partition the graph and make kernels for each independent + // subgraph. + TfLiteIntArray* size_and_nodes = + ConvertVectorToTfLiteIntArray(supported_nodes); + context->ReplaceSubgraphsWithDelegateKernels(context, GetKernel(), + size_and_nodes, delegate); + TfLiteIntArrayFree(size_and_nodes); + return kTfLiteOk; +} + +TfLiteStatus CopyFromBufferHandle(TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, void* data, + size_t size) { + // TODO(nupurgarg): Make BufferMap unique to each interpreter in order to + // support multiple interpreters using a single delegate. + BufferMap* buffer_map = + reinterpret_cast(delegate->data_)->GetBufferMap(); + + // TODO(nupurgarg): Use TfLiteContext's ReportError instead of fprinf. + if (!buffer_map->HasTensor(buffer_handle)) { + fprintf(stderr, "Invalid tensor index %d.\n", buffer_handle); + return kTfLiteError; + } + + tensorflow::Tensor t = buffer_map->GetTensor(buffer_handle); + tensorflow::StringPiece t_data = t.tensor_data(); + + if (size != t_data.size()) { + fprintf(stderr, "Not enough space to store TensorFlow's aligned buffer.\n"); + return kTfLiteError; + } + + memcpy(data, t_data.data(), t_data.size()); + return kTfLiteOk; +} + +} // namespace delegate +} // namespace eager + +EagerDelegate::EagerDelegate() {} + +EagerDelegate::~EagerDelegate() {} + +TfLiteStatus EagerDelegate::Apply(Interpreter* interpreter) { + if (!delegate_) { + if (!eager::DelegateData::Create(&delegate_data_).ok()) { + fprintf(stderr, "Unable to initialize TensorFlow context.\n"); + return kTfLiteError; + } + + delegate_.reset(new TfLiteDelegate{ + /*data_=*/delegate_data_.get(), + /*nullptr,*/ &eager::delegate::Prepare, + /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle, + /*CopyToBufferHandle=*/nullptr, + /*FreeBufferHandle=*/nullptr}); + } + + return interpreter->ModifyGraphWithDelegate(delegate_.get(), + /*allow_dynamic_tensors=*/true); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.h b/tensorflow/contrib/lite/delegates/eager/delegate.h new file mode 100644 index 0000000000000000000000000000000000000000..0defca7c323e81bfb211ac56fd59c8656b320574 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/delegate.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_ + +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" +#include "tensorflow/contrib/lite/interpreter.h" + +namespace tflite { + +// WARNING: This is an experimental interface that is subject to change. +// Delegate that can be used to extract parts of a graph that are designed to be +// executed by TensorFlow's runtime via Eager. +// +// The interpreter must be constructed after the EagerDelegate and destructed +// before the EagerDelegate. This delegate can only be used with one +// interpreter. +// +// Usage: +// EagerDelegate delegate; +// ... build interpreter ... +// +// delegate.Apply(interpreter); +// ... run inference ... +// ... destroy interpreter ... +// ... destroy delegate ... +class EagerDelegate { + public: + EagerDelegate(); + ~EagerDelegate(); + + // Modifies the graph loaded in the interpreter. + TfLiteStatus Apply(Interpreter* interpreter); + + private: + std::unique_ptr delegate_data_; + std::unique_ptr delegate_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data.cc new file mode 100644 index 0000000000000000000000000000000000000000..0fd5c976f8ca9be16f7e3c5e610573755b40c506 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/delegate_data.cc @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" + +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tflite { +namespace eager { +tensorflow::Status DelegateData::Create(std::unique_ptr* data) { + std::vector devices; + + TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices( + tensorflow::SessionOptions(), "/job:localhost/replica:0/task:0", + &devices)); + + std::unique_ptr device_mgr( + new tensorflow::DeviceMgr(devices)); + // Note that Rendezvous is ref-counted so it will be automatically deleted. + tensorflow::Rendezvous* rendezvous = + new tensorflow::IntraProcessRendezvous(device_mgr.get()); + data->reset(new DelegateData(new tensorflow::EagerContext( + tensorflow::SessionOptions(), + tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, + /*async=*/false, std::move(device_mgr), rendezvous))); + return tensorflow::Status(); +} + +DelegateData::DelegateData(tensorflow::EagerContext* eager_context) + : eager_context_(eager_context) {} + +DelegateData::~DelegateData() {} + +} // namespace eager +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data.h b/tensorflow/contrib/lite/delegates/eager/delegate_data.h new file mode 100644 index 0000000000000000000000000000000000000000..8a0e8ba8bf213341d9da15613ea40e1f903f8bb6 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/delegate_data.h @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_ + +#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h" +#include "tensorflow/core/common_runtime/eager/context.h" + +namespace tflite { +namespace eager { + +// Data kept by the Eager delegate for the lifetime of an Interpreter. +class DelegateData { + public: + // Create a new DelegateData, initialized with a newly-created EagerContext. + static tensorflow::Status Create(std::unique_ptr* data); + + ~DelegateData(); + + // The EagerContext that is required for execution of Eager Ops. + tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); } + + // Map from TF Lite tensor index to TensorFlow tensor. + BufferMap* GetBufferMap() { return &buffer_map_; } + + private: + explicit DelegateData(tensorflow::EagerContext* eager_context); + + std::unique_ptr eager_context_; + BufferMap buffer_map_; +}; + +} // namespace eager +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..30251b8f82cf623b4c45854f7f2f6e5e2c008af0 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" + +#include +#include +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { +namespace eager { +namespace { + +TEST(DelegateDataTest, Basic) { + std::unique_ptr data; + // We only check for success because it is hard to make initialization fail. + // It only happens if we manage to not link the CPU device factory into the + // binary. + EXPECT_TRUE(DelegateData::Create(&data).ok()); + + EXPECT_NE(data->GetEagerContext(), nullptr); + EXPECT_NE(data->GetBufferMap(), nullptr); +} + +} // namespace +} // namespace eager +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..88fb34044ec5f8e5b4593638163cd4e6407bf8c8 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc @@ -0,0 +1,150 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/delegate.h" + +#include +#include +#include "tensorflow/contrib/lite/delegates/eager/test_util.h" + +namespace tflite { +namespace eager { +namespace { + +using ::testing::ContainsRegex; +using ::testing::ElementsAre; + +// TODO(nupurgarg): Add a test with multiple interpreters for one delegate. + +class DelegateTest : public testing::EagerModelTest { + public: + DelegateTest() { + // The delegate needs to be constructed before the interpreter because the + // interpreter references data contained in the delegate. + delegate_.reset(new EagerDelegate()); + interpreter_.reset(new Interpreter(&error_reporter_)); + } + + ~DelegateTest() override { + // The delegate needs to be destructed after the interpreter because the + // interpreter references data contained in the delegate. + delete interpreter_.release(); + delete delegate_.release(); + } + + void ConfigureDelegate() { + CHECK(delegate_->Apply(interpreter_.get()) == kTfLiteOk); + } + + private: + std::unique_ptr delegate_; +}; + +TEST_F(DelegateTest, FullGraph) { + // Define the graph. + AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfOp(testing::kAdd, {1, 4}, {6}); + AddTfOp(testing::kAdd, {2, 5}, {7}); + AddTfOp(testing::kMul, {6, 7}, {8}); + + // Apply the delegate. + ConfigureDelegate(); + + // Define inputs. + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(3, {2, 2, 1}); + SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); + ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); +} + +TEST_F(DelegateTest, MixedGraph) { + AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfOp(testing::kAdd, {1, 4}, {6}); + AddTfOp(testing::kAdd, {2, 5}, {7}); + AddTfLiteMulOp({6, 7}, {8}); + + ConfigureDelegate(); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(3, {2, 2, 1}); + SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); + ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); +} + +TEST_F(DelegateTest, SplitGraph) { + AddTensors(10, {0}, {9}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kAdd, {1, 2}, {3}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + + AddTfLiteMulOp({4, 5}, {6}); + + AddTfOp(testing::kUnpack, {6}, {7, 8}); + AddTfOp(testing::kAdd, {7, 8}, {9}); + + ConfigureDelegate(); + + SetShape(0, {2, 2, 2, 1}); + SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(9), ElementsAre(1)); + ASSERT_THAT(GetValues(9), ElementsAre(10.0f)); +} + +TEST_F(DelegateTest, OnlyTFLite) { + // Only TFLite single op model. + AddTensors(10, {0, 1}, {2}, kTfLiteFloat32, {3}); + AddTfLiteMulOp({0, 1}, {2}); + + ConfigureDelegate(); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(1, {2, 2, 1}); + SetValues(1, {1.0f, 2.0f, 3.0f, 4.0f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1)); + ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f)); +} + +} // namespace +} // namespace eager +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/eager/kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..172798180762f87e1c080be7788db661a63208b5 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/kernel.cc @@ -0,0 +1,289 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/kernel.h" + +#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/builtin_ops.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/context_util.h" +#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" +#include "tensorflow/contrib/lite/delegates/eager/util.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/execute.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/framework/node_def.pb.h" + +// Note: this is part of TF Lite's Eager delegation code which is to be +// completed soon. + +// This is the TF Lite op that is created by the eager delegate to handle +// execution of a supported subgraph. The usual flow is that the delegate +// informs the interpreter of supported nodes in a graph, and each supported +// subgraph is replaced with one instance of this kernel. +// +// The kernel is initialized with TfLiteDelegateParams from which we retrieve +// the global EagerContext and BufferMap, as well as a list of inputs and +// outputs to the subgraph. Those are used to build the OpData, with a list of +// TensorFlow Ops that should be executed in order (which we call an OpNode). +// +// For each node included in the subgraph, we query the interpreter and +// retrieve the associated NodeDef, which is then used to configure the +// corresponding TensorFlow/Eager Op. + +namespace tflite { +namespace eager { +namespace kernel { + +// Controls the lifetime of tensor handles in a vector. +class VectorOfHandles { + public: + explicit VectorOfHandles(int num_elements) : vector_(num_elements, nullptr) {} + + ~VectorOfHandles() { + for (auto* handle : vector_) { + if (handle) handle->Unref(); + } + } + + tensorflow::gtl::InlinedVector* GetVector() { + return &vector_; + } + + tensorflow::TensorHandle* GetHandle(int index) { return vector_[index]; } + + private: + tensorflow::gtl::InlinedVector vector_; +}; + +// Executes the TensorFlow op given by 'op_name', with the attributes specified +// in 'nodedef'. Inputs and outputs are given as indices into the 'buffer_map'. +tensorflow::Status ExecuteEagerOp(tensorflow::EagerContext* eager_context, + BufferMap* buffer_map, const string& op_name, + const tensorflow::NodeDef& nodedef, + const std::vector& inputs, + const std::vector& outputs) { + const tensorflow::AttrTypeMap* attr_types; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + tensorflow::AttrTypeMapForOp(op_name.c_str(), &attr_types), + " (while processing attributes of '", op_name, "')"); + + tensorflow::EagerOperation op(eager_context, op_name.c_str(), attr_types); + for (const auto& attr : nodedef.attr()) { + op.MutableAttrs()->Set(attr.first, attr.second); + } + + for (int input_index : inputs) { + if (!buffer_map->HasTensor(input_index)) { + return tensorflow::errors::Internal( + "Cannot read from invalid tensor index ", input_index); + } + auto* handle = new tensorflow::TensorHandle( + buffer_map->GetTensor(input_index), nullptr, nullptr, nullptr); + op.AddInput(handle); + handle->Unref(); + } + + int num_retvals = outputs.size(); + VectorOfHandles retvals(num_retvals); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + EagerExecute(&op, retvals.GetVector(), &num_retvals), + " (while executing '", op_name, "' via Eager)"); + + if (num_retvals != outputs.size()) { + return tensorflow::errors::Internal( + "Unexpected number of outputs from EagerExecute"); + } + + for (int i = 0; i < num_retvals; ++i) { + const tensorflow::Tensor* tensor = nullptr; + TF_RETURN_IF_ERROR(retvals.GetHandle(i)->Tensor(&tensor)); + buffer_map->SetFromTensorFlow(outputs[i], *tensor); + } + + return tensorflow::Status::OK(); +} + +// A single node within the larger 'op'. Note that this kernel executes many +// TensorFlow ops within a single TF Lite op. +struct OpNode { + // The name of the TensorFlow op to execute. + string name; + // The corresponding NodeDef, containing the attributes for the op. + tensorflow::NodeDef nodedef; + // List of inputs, as TF Lite tensor indices. + std::vector inputs; + // List of outputs, as TF Lite tensor indices. + std::vector outputs; +}; + +// The Larger 'op', which contains all the nodes in a supported subgraph. +struct OpData { + tensorflow::EagerContext* eager_context; + BufferMap* buffer_map; + std::vector nodes; + std::vector subgraph_inputs; + std::vector subgraph_outputs; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* op_data = new OpData; + + const TfLiteDelegateParams* params = + reinterpret_cast(buffer); + CHECK(params); + CHECK(params->delegate); + CHECK(params->delegate->data_); + op_data->eager_context = + reinterpret_cast(params->delegate->data_) + ->GetEagerContext(); + op_data->buffer_map = + reinterpret_cast(params->delegate->data_)->GetBufferMap(); + + CHECK(params->output_tensors); + for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) { + op_data->subgraph_outputs.push_back(tensor_index); + } + + CHECK(params->input_tensors); + for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) { + op_data->subgraph_inputs.push_back(tensor_index); + } + + CHECK(params->nodes_to_replace); + for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) { + TfLiteNode* node; + TfLiteRegistration* reg; + context->GetNodeAndRegistration(context, node_index, &node, ®); + + op_data->nodes.push_back(OpNode()); + OpNode& node_data = op_data->nodes.back(); + + node_data.name = ""; + if (node->custom_initial_data) { + // The flexbuffer contains a vector where the first elements is the + // op name and the second is a serialized NodeDef. + const flexbuffers::Vector& v = + flexbuffers::GetRoot( + reinterpret_cast(node->custom_initial_data), + node->custom_initial_data_size) + .AsVector(); + + node_data.name = v[0].AsString().str(); + if (!node_data.nodedef.ParseFromString(v[1].AsString().str())) { + // We will just leave the nodedef empty and error out in Eval(). + node_data.nodedef.Clear(); + } + } + + for (auto input_index : TfLiteIntArrayView(node->inputs)) { + node_data.inputs.push_back(input_index); + } + for (auto output_index : TfLiteIntArrayView(node->outputs)) { + node_data.outputs.push_back(output_index); + } + } + + return op_data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const auto* op_data = reinterpret_cast(node->user_data); + TF_LITE_ENSURE_MSG( + context, op_data->eager_context != nullptr, + "Failed to initialize eager context. This often happens when a CPU " + "device has not been registered, presumably because some symbols from " + "tensorflow/core:core_cpu_impl were not linked into the binary."); + + // Whenever we find a constant tensor, insert it in the buffer map. + BufferMap* buffer_map = op_data->buffer_map; + for (auto tensor_index : op_data->subgraph_inputs) { + TfLiteTensor* tensor = &context->tensors[tensor_index]; + if (IsConstantTensor(tensor)) { + if (!buffer_map->HasTensor(tensor_index)) { + buffer_map->SetFromTfLite(tensor_index, tensor); + } + } + } + + // All output tensors are allocated by TensorFlow/Eager, so we + // mark them as kTfLiteDynamic. + for (auto tensor_index : op_data->subgraph_outputs) { + SetTensorToDynamic(&context->tensors[tensor_index]); + } + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const auto* op_data = reinterpret_cast(node->user_data); + BufferMap* buffer_map = op_data->buffer_map; + tensorflow::EagerContext* eager_context = op_data->eager_context; + + // Insert a tensor in the buffer map for all inputs that are not constant. + // Constants were handled in Prepare() already. + for (auto tensor_index : op_data->subgraph_inputs) { + TfLiteTensor* tensor = &context->tensors[tensor_index]; + if (!IsConstantTensor(tensor)) { + buffer_map->SetFromTfLite(tensor_index, tensor); + } + } + + // Execute the TensorFlow Ops sequentially. + for (const auto& node_data : op_data->nodes) { + if (node_data.nodedef.op().empty()) { + context->ReportError(context, "Invalid NodeDef in Eager op '%s'", + node_data.name.c_str()); + return kTfLiteError; + } + auto status = + ExecuteEagerOp(eager_context, buffer_map, node_data.name, + node_data.nodedef, node_data.inputs, node_data.outputs); + TF_LITE_ENSURE_OK(context, ConvertStatus(context, status)); + } + + for (auto tensor_index : op_data->subgraph_outputs) { + if (!buffer_map->HasTensor(tensor_index)) { + context->ReportError(context, "Cannot write to invalid tensor index %d", + tensor_index); + return kTfLiteError; + } + + TfLiteTensor* tensor = &context->tensors[tensor_index]; + TF_LITE_ENSURE_OK( + context, + CopyShape(context, buffer_map->GetTensor(tensor_index), tensor)); + tensor->buffer_handle = tensor_index; + tensor->data_is_stale = true; + } + + return kTfLiteOk; +} + +} // namespace kernel + +TfLiteRegistration GetKernel() { + TfLiteRegistration registration{&kernel::Init, &kernel::Free, + &kernel::Prepare, &kernel::Eval, + nullptr, kTfLiteBuiltinDelegate}; + return registration; +} + +} // namespace eager +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.h b/tensorflow/contrib/lite/delegates/eager/kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..100672c82dcd3eaee17325f3b712140b081e8efe --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/kernel.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_ + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { +namespace eager { + +// Return the registration object used to initialize and execute ops that will +// be delegated to TensorFlow's Eager runtime. This TF Lite op is created by +// the eager delegate to handle execution of a supported subgraph. The usual +// flow is that the delegate informs the interpreter of supported nodes in a +// graph, and each supported subgraph is replaced with one instance of this +// kernel. +TfLiteRegistration GetKernel(); + +} // namespace eager +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/kernel_test.cc b/tensorflow/contrib/lite/delegates/eager/kernel_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b7bfbb34e49c71142e28f0bf1b2f84e0ff570734 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/kernel_test.cc @@ -0,0 +1,228 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/kernel.h" + +#include +#include +#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" +#include "tensorflow/contrib/lite/delegates/eager/test_util.h" + +namespace tflite { +namespace eager { +namespace { + +using ::testing::ContainsRegex; +using ::testing::ElementsAre; + +TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteDelegate* delegate, + const std::vector& supported_nodes) { + TfLiteIntArray* size_and_nodes = + ConvertVectorToTfLiteIntArray(supported_nodes); + TF_LITE_ENSURE_STATUS(context->ReplaceSubgraphsWithDelegateKernels( + context, eager::GetKernel(), size_and_nodes, delegate)); + TfLiteIntArrayFree(size_and_nodes); + return kTfLiteOk; +} + +class KernelTest : public testing::EagerModelTest { + public: + KernelTest() { + CHECK(DelegateData::Create(&delegate_data_).ok()); + interpreter_.reset(new Interpreter(&error_reporter_)); + } + + ~KernelTest() override { + // The data needs to be released before the interpreter because the + // interpreter references the data. + delegate_data_.reset(); + interpreter_.reset(); + } + + template + void ConfigureDelegate(T prepare_function) { + delegate_.data_ = delegate_data_.get(); + delegate_.FreeBufferHandle = nullptr; + delegate_.Prepare = prepare_function; + delegate_.CopyFromBufferHandle = [](TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + void* data, size_t size) { + auto* delegate_data = reinterpret_cast(delegate->data_); + tensorflow::StringPiece values = + delegate_data->GetBufferMap()->GetTensor(buffer_handle).tensor_data(); + memcpy(data, values.data(), values.size()); + return kTfLiteOk; + }; + CHECK(interpreter_->ModifyGraphWithDelegate( + &delegate_, /*allow_dynamic_tensors=*/true) == kTfLiteOk); + } + + private: + std::unique_ptr delegate_data_; + TfLiteDelegate delegate_; +}; + +TEST_F(KernelTest, FullGraph) { + // Define the graph. + AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfOp(testing::kAdd, {1, 4}, {6}); + AddTfOp(testing::kAdd, {2, 5}, {7}); + AddTfOp(testing::kMul, {6, 7}, {8}); + + // Apply Delegate. + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0, 1, 2, 3, 4}); + }); + + // Define inputs. + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(3, {2, 2, 1}); + SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); + ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); +} + +TEST_F(KernelTest, BadTensorFlowOp) { + AddTensors(2, {0}, {1}, kTfLiteFloat32, {3}); + AddTfOp(testing::kNonExistent, {0}, {1}); + + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0}); + }); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_FALSE(Invoke()); + ASSERT_THAT(error_reporter().error_messages(), + ContainsRegex("while processing attributes of 'NonExistentOp'")); +} + +TEST_F(KernelTest, BadNumberOfOutputs) { + AddTensors(3, {0}, {1, 2}, kTfLiteFloat32, {3}); + AddTfOp(testing::kIdentity, {0}, {1, 2}); + + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0}); + }); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_FALSE(Invoke()); + ASSERT_THAT(error_reporter().error_messages(), + ContainsRegex("Unexpected number of outputs")); +} + +TEST_F(KernelTest, IncompatibleNodeDef) { + AddTensors(2, {0}, {1}, kTfLiteFloat32, {3}); + + // Cast is a TF op, but we don't add the proper nodedef to it in AddTfOp. + AddTfOp(testing::kIncompatibleNodeDef, {0}, {1}); + + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0}); + }); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_FALSE(Invoke()); + ASSERT_THAT(error_reporter().error_messages(), + ContainsRegex("while executing 'Cast' via Eager")); +} + +TEST_F(KernelTest, WrongSetOfNodes) { + AddTensors(4, {0}, {3}, kTfLiteFloat32, {3}); + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfLiteMulOp({1, 2}, {3}); + + // Specify that testing::kMul (#1) is supported when it actually isn't. + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0, 1}); + }); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_FALSE(Invoke()); + ASSERT_THAT(error_reporter().error_messages(), + ContainsRegex("Invalid NodeDef in Eager op")); +} + +TEST_F(KernelTest, MixedGraph) { + AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfOp(testing::kAdd, {1, 4}, {6}); + AddTfOp(testing::kAdd, {2, 5}, {7}); + AddTfLiteMulOp({6, 7}, {8}); + + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0, 1, 2, 3}); + }); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(3, {2, 2, 1}); + SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); + ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); +} + +TEST_F(KernelTest, SplitGraph) { + AddTensors(10, {0}, {9}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kAdd, {1, 2}, {3}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + + AddTfLiteMulOp({4, 5}, {6}); + + AddTfOp(testing::kUnpack, {6}, {7, 8}); + AddTfOp(testing::kAdd, {7, 8}, {9}); + + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0, 1, 2, 4, 5}); + }); + + SetShape(0, {2, 2, 2, 1}); + SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(9), ElementsAre(1)); + ASSERT_THAT(GetValues(9), ElementsAre(10.0f)); +} + +} // namespace +} // namespace eager +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.cc b/tensorflow/contrib/lite/delegates/eager/test_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..80acf5d9955f92ec06844b4bc3b980b3a924ab8f --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/test_util.cc @@ -0,0 +1,154 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/delegates/eager/test_util.h" + +#include "absl/memory/memory.h" +#include "third_party/flatbuffers/include/flatbuffers/flexbuffers.h" + +namespace tflite { +namespace eager { +namespace testing { + +bool EagerModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; } + +void EagerModelTest::SetValues(int tensor_index, + const std::vector& values) { + float* v = interpreter_->typed_tensor(tensor_index); + for (float f : values) { + *v++ = f; + } +} + +std::vector EagerModelTest::GetValues(int tensor_index) { + TfLiteTensor* o = interpreter_->tensor(tensor_index); + return std::vector(o->data.f, o->data.f + o->bytes / sizeof(float)); +} + +void EagerModelTest::SetShape(int tensor_index, + const std::vector& values) { + ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); +} + +std::vector EagerModelTest::GetShape(int tensor_index) { + std::vector result; + auto* dims = interpreter_->tensor(tensor_index)->dims; + result.reserve(dims->size); + for (int i = 0; i < dims->size; ++i) { + result.push_back(dims->data[i]); + } + return result; +} + +void EagerModelTest::AddTensors(int num_tensors, const std::vector& inputs, + const std::vector& outputs, + const TfLiteType& type, + const std::vector& dims) { + interpreter_->AddTensors(num_tensors); + for (int i = 0; i < num_tensors; ++i) { + TfLiteQuantizationParams quant; + CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, type, + /*name=*/"", + /*dims=*/dims, quant), + kTfLiteOk); + } + + CHECK_EQ(interpreter_->SetInputs(inputs), kTfLiteOk); + CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk); +} + +void EagerModelTest::AddTfLiteMulOp(const std::vector& inputs, + const std::vector& outputs) { + static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + reg.builtin_code = BuiltinOperator_MUL; + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + auto* i0 = &context->tensors[node->inputs->data[0]]; + auto* o = &context->tensors[node->outputs->data[0]]; + return context->ResizeTensor(context, o, TfLiteIntArrayCopy(i0->dims)); + }; + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + auto* i0 = &context->tensors[node->inputs->data[0]]; + auto* i1 = &context->tensors[node->inputs->data[1]]; + auto* o = &context->tensors[node->outputs->data[0]]; + for (int i = 0; i < o->bytes / sizeof(float); ++i) { + o->data.f[i] = i0->data.f[i] * i1->data.f[i]; + } + return kTfLiteOk; + }; + + CHECK_EQ(interpreter_->AddNodeWithParameters(inputs, outputs, nullptr, 0, + nullptr, ®), + kTfLiteOk); +} + +void EagerModelTest::AddTfOp(TfOpType op, const std::vector& inputs, + const std::vector& outputs) { + auto attr = [](const string& key, const string& value) { + return " attr{ key: '" + key + "' value {" + value + "}}"; + }; + + if (op == kUnpack) { + string attributes = attr("T", "type: DT_FLOAT") + attr("num", "i: 2") + + attr("axis", "i: 0"); + AddTfOp("EagerUnpack", "Unpack", attributes, inputs, outputs); + } else if (op == kIdentity) { + string attributes = attr("T", "type: DT_FLOAT"); + AddTfOp("EagerIdentity", "Identity", attributes, inputs, outputs); + } else if (op == kAdd) { + string attributes = attr("T", "type: DT_FLOAT"); + AddTfOp("EagerAdd", "Add", attributes, inputs, outputs); + } else if (op == kMul) { + string attributes = attr("T", "type: DT_FLOAT"); + AddTfOp("EagerMul", "Mul", attributes, inputs, outputs); + } else if (op == kNonExistent) { + AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs); + } else if (op == kIncompatibleNodeDef) { + // "Cast" op is created without attributes - making it incompatible. + AddTfOp("EagerCast", "Cast", "", inputs, outputs); + } +} + +void EagerModelTest::AddTfOp(const char* tflite_name, const string& tf_name, + const string& nodedef_str, + const std::vector& inputs, + const std::vector& outputs) { + static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + reg.builtin_code = BuiltinOperator_CUSTOM; + reg.custom_name = tflite_name; + + tensorflow::NodeDef nodedef; + CHECK(tensorflow::protobuf::TextFormat::ParseFromString( + nodedef_str + " op: '" + tf_name + "'", &nodedef)); + string serialized_nodedef; + CHECK(nodedef.SerializeToString(&serialized_nodedef)); + flexbuffers::Builder fbb; + fbb.Vector([&]() { + fbb.String(nodedef.op()); + fbb.String(serialized_nodedef); + }); + fbb.Finish(); + + flexbuffers_.push_back(fbb.GetBuffer()); + auto& buffer = flexbuffers_.back(); + CHECK_EQ(interpreter_->AddNodeWithParameters( + inputs, outputs, reinterpret_cast(buffer.data()), + buffer.size(), nullptr, ®), + kTfLiteOk); +} + +} // namespace testing +} // namespace eager +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.h b/tensorflow/contrib/lite/delegates/eager/test_util.h new file mode 100644 index 0000000000000000000000000000000000000000..0eab9e1135f02b4f22a4b36a85cf6771fbbb81d5 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/test_util.h @@ -0,0 +1,97 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_TEST_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_TEST_UTIL_H_ + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" + +namespace tflite { +namespace eager { +namespace testing { + +enum TfOpType { + kUnpack, + kIdentity, + kAdd, + kMul, + // Represents an op that does not exist in TensorFlow. + kNonExistent, + // Represents an valid TensorFlow op where the NodeDef is incompatible. + kIncompatibleNodeDef, +}; + +// This class creates models with TF and TFLite ops. In order to use this class +// to test the Eager delegate, implement a function that calls +// interpreter->ModifyGraphWithDelegate. +class EagerModelTest : public ::testing::Test { + public: + EagerModelTest() {} + ~EagerModelTest() {} + + bool Invoke(); + + // Sets the tensor's values at the given index. + void SetValues(int tensor_index, const std::vector& values); + + // Returns the tensor's values at the given index. + std::vector GetValues(int tensor_index); + + // Sets the tensor's shape at the given index. + void SetShape(int tensor_index, const std::vector& values); + + // Returns the tensor's shape at the given index. + std::vector GetShape(int tensor_index); + + const TestErrorReporter& error_reporter() const { return error_reporter_; } + + // Adds `num_tensor` tensors to the model. `inputs` contains the indices of + // the input tensors and `outputs` contains the indices of the output + // tensors. All tensors are set to have `type` and `dims`. + void AddTensors(int num_tensors, const std::vector& inputs, + const std::vector& outputs, const TfLiteType& type, + const std::vector& dims); + + // Adds a TFLite Mul op. `inputs` contains the indices of the input tensors + // and `outputs` contains the indices of the output tensors. + void AddTfLiteMulOp(const std::vector& inputs, + const std::vector& outputs); + + // Adds a TensorFlow op. `inputs` contains the indices of the + // input tensors and `outputs` contains the indices of the output tensors. + // This function is limited to the set of ops defined in TfOpType. + void AddTfOp(TfOpType op, const std::vector& inputs, + const std::vector& outputs); + + protected: + std::unique_ptr interpreter_; + TestErrorReporter error_reporter_; + + private: + // Helper method to add a TensorFlow op. tflite_names needs to start with + // "Eager" in order to work with the Eager delegate. + void AddTfOp(const char* tflite_name, const string& tf_name, + const string& nodedef_str, const std::vector& inputs, + const std::vector& outputs); + + std::vector> flexbuffers_; +}; + +} // namespace testing +} // namespace eager +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_TEST_UTIL_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/util.cc b/tensorflow/contrib/lite/delegates/eager/util.cc new file mode 100644 index 0000000000000000000000000000000000000000..c8aa0b7f69f8f6bd3bff52b13f3cc7d689a514da --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/util.cc @@ -0,0 +1,78 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/util.h" +#include "tensorflow/contrib/lite/delegates/eager/constants.h" + +namespace tflite { +namespace eager { + +bool IsEagerOp(const char* custom_name) { + return custom_name && strncmp(custom_name, kCustomCodePrefix, + strlen(kCustomCodePrefix)) == 0; +} + +TfLiteStatus ConvertStatus(TfLiteContext* context, + const tensorflow::Status& status) { + if (!status.ok()) { + context->ReportError(context, "%s", status.error_message().c_str()); + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src, + TfLiteTensor* tensor) { + int num_dims = src.dims(); + TfLiteIntArray* shape = TfLiteIntArrayCreate(num_dims); + for (int j = 0; j < num_dims; ++j) { + // We need to cast from TensorFlow's int64 to TF Lite's int32. Let's + // make sure there's no overflow. + if (src.dim_size(j) >= std::numeric_limits::max()) { + context->ReportError(context, + "Dimension value in TensorFlow shape is larger than " + "supported by TF Lite"); + TfLiteIntArrayFree(shape); + return kTfLiteError; + } + shape->data[j] = static_cast(src.dim_size(j)); + } + return context->ResizeTensor(context, tensor, shape); +} + +TF_DataType GetTensorFlowDataType(TfLiteType type) { + switch (type) { + case kTfLiteNoType: + return TF_FLOAT; + case kTfLiteFloat32: + return TF_FLOAT; + case kTfLiteInt16: + return TF_INT16; + case kTfLiteInt32: + return TF_INT32; + case kTfLiteUInt8: + return TF_UINT8; + case kTfLiteInt64: + return TF_INT64; + case kTfLiteComplex64: + return TF_COMPLEX64; + case kTfLiteString: + return TF_STRING; + case kTfLiteBool: + return TF_BOOL; + } +} + +} // namespace eager +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/eager/util.h new file mode 100644 index 0000000000000000000000000000000000000000..b7363361bec47f30e0741e3a76a5a375d7d9aeb1 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/util.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_ + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tflite { +namespace eager { + +// Checks whether the prefix of the custom name indicates the operation is an +// Eager operation. +bool IsEagerOp(const char* custom_name); + +// Converts a tensorflow:Status into a TfLiteStatus. If the original status +// represented an error, reports it using the given 'context'. +TfLiteStatus ConvertStatus(TfLiteContext* context, + const tensorflow::Status& status); + +// Copies the given shape of the given 'src' into a TF Lite 'tensor'. Logs an +// error and returns kTfLiteError if the shape can't be converted. +TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src, + TfLiteTensor* tensor); + +// Returns the TF C API Data type that corresponds to the given TfLiteType. +TF_DataType GetTensorFlowDataType(TfLiteType type); + +} // namespace eager +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_ diff --git a/tensorflow/contrib/lite/delegates/eager/util_test.cc b/tensorflow/contrib/lite/delegates/eager/util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4e92da8d34f4f6c9b9c1ecd959cfaed25051f826 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/eager/util_test.cc @@ -0,0 +1,123 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/eager/util.h" + +#include + +#include +#include +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { +namespace eager { +namespace { + +using tensorflow::DT_FLOAT; +using tensorflow::Tensor; +using ::testing::ElementsAre; + +struct TestContext : public TfLiteContext { + string error; + std::vector new_size; +}; + +void ReportError(TfLiteContext* context, const char* format, ...) { + TestContext* c = static_cast(context); + const size_t kBufferSize = 1024; + char temp_buffer[kBufferSize]; + + va_list args; + va_start(args, format); + vsnprintf(temp_buffer, kBufferSize, format, args); + va_end(args); + + c->error = temp_buffer; +} + +TfLiteStatus ResizeTensor(TfLiteContext* context, TfLiteTensor* tensor, + TfLiteIntArray* new_size) { + TestContext* c = static_cast(context); + c->new_size.clear(); + for (int i = 0; i < new_size->size; ++i) { + c->new_size.push_back(new_size->data[i]); + } + TfLiteIntArrayFree(new_size); + return kTfLiteOk; +} + +TEST(UtilTest, ConvertStatus) { + TestContext context; + context.ReportError = ReportError; + + EXPECT_EQ(ConvertStatus(&context, tensorflow::errors::Internal("Some Error")), + kTfLiteError); + EXPECT_EQ(context.error, "Some Error"); + + context.error.clear(); + EXPECT_EQ(ConvertStatus(&context, tensorflow::Status()), kTfLiteOk); + EXPECT_TRUE(context.error.empty()); +} + +TEST(UtilTest, CopyShape) { + TestContext context; + context.ReportError = ReportError; + context.ResizeTensor = ResizeTensor; + + TfLiteTensor dst; + + EXPECT_EQ(CopyShape(&context, Tensor(), &dst), kTfLiteOk); + EXPECT_THAT(context.new_size, ElementsAre(0)); + + EXPECT_EQ(CopyShape(&context, Tensor(DT_FLOAT, {1, 2}), &dst), kTfLiteOk); + EXPECT_THAT(context.new_size, ElementsAre(1, 2)); + + EXPECT_EQ(CopyShape(&context, Tensor(DT_FLOAT, {1LL << 44, 2}), &dst), + kTfLiteError); + EXPECT_EQ(context.error, + "Dimension value in TensorFlow shape is larger than supported by " + "TF Lite"); +} + +TEST(UtilTest, TypeConversions) { + EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteNoType)); + EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteFloat32)); + EXPECT_EQ(TF_INT16, GetTensorFlowDataType(kTfLiteInt16)); + EXPECT_EQ(TF_INT32, GetTensorFlowDataType(kTfLiteInt32)); + EXPECT_EQ(TF_UINT8, GetTensorFlowDataType(kTfLiteUInt8)); + EXPECT_EQ(TF_INT64, GetTensorFlowDataType(kTfLiteInt64)); + EXPECT_EQ(TF_COMPLEX64, GetTensorFlowDataType(kTfLiteComplex64)); + EXPECT_EQ(TF_STRING, GetTensorFlowDataType(kTfLiteString)); + EXPECT_EQ(TF_BOOL, GetTensorFlowDataType(kTfLiteBool)); +} + +TEST(UtilTest, IsEagerOp) { + EXPECT_TRUE(IsEagerOp("Eager")); + EXPECT_TRUE(IsEagerOp("EagerOp")); + EXPECT_FALSE(IsEagerOp("eager")); + EXPECT_FALSE(IsEagerOp("Eage")); + EXPECT_FALSE(IsEagerOp("OpEager")); + EXPECT_FALSE(IsEagerOp(nullptr)); + EXPECT_FALSE(IsEagerOp("")); +} + +} // namespace +} // namespace eager +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/delegates/nnapi/BUILD b/tensorflow/contrib/lite/delegates/nnapi/BUILD index 35a8f6ca4166e373ea1a0af5d4a013327b30d2b6..954955f24b87f79a8dbe2863f608d532e25902c6 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/BUILD +++ b/tensorflow/contrib/lite/delegates/nnapi/BUILD @@ -22,6 +22,10 @@ tf_cc_test( name = "nnapi_delegate_test", size = "small", srcs = ["nnapi_delegate_test.cc"], + tags = [ + "no_oss", + "noasan", # TODO(b/112326936): re-enable for asan once fixed. + ], deps = [ ":nnapi_delegate", "//tensorflow/contrib/lite:framework", diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc index fd798c209e5112235cf6e351e231d4096006a8b0..b1b8e9890c99b0e9fdbf735bf3f04795e8577203 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc @@ -142,6 +142,12 @@ class NNAPIOpBuilder { ANEURALNETWORKS_TENSOR_INT32); } + TfLiteStatus AddVectorFloat32Operand(const float* values, + uint32_t num_values) { + return AddVectorOperand(values, num_values, + ANEURALNETWORKS_TENSOR_FLOAT32); + } + TfLiteStatus AddPoolingParams(void* data) { auto builtin = reinterpret_cast(data); AddScalarInt32Operand(builtin->padding); @@ -167,6 +173,37 @@ class NNAPIOpBuilder { return kTfLiteOk; } + TfLiteStatus AddAdditionalFloat32OutputTensor(uint32_t dimension_count) { + std::vector dims(dimension_count, 0); + ANeuralNetworksOperandType operand_type{ + .type = ANEURALNETWORKS_TENSOR_FLOAT32, + .dimensionCount = dimension_count, + .dimensions = dims.data()}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + int ann_operand = operand_mapping_->add_new_non_tensor_operand(); + augmented_outputs_.push_back(ann_operand); + return kTfLiteOk; + } + + TfLiteStatus AddStateFloat32Tensor(int tensor_index, + int* ann_tensor_index_out) { + TfLiteTensor* tensor = &context_->tensors[tensor_index]; + int ann_index = operand_mapping_->add_new_non_tensor_operand(); + + ANeuralNetworksOperandType operand_type{ + ANEURALNETWORKS_TENSOR_FLOAT32, + static_cast(tensor->dims->size), + reinterpret_cast(tensor->dims->data), tensor->params.scale, + tensor->params.zero_point}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + augmented_inputs_.push_back(ann_index); + + *ann_tensor_index_out = ann_index; + return kTfLiteOk; + } + // Adds a new NN API tensor that shadows the TF Lite tensor `tensor_index`. // This returns the NN API tensor index corresponding to the created tensor. // If another caller previously created a NN API tensor for `tensor_index` @@ -198,6 +235,10 @@ class NNAPIOpBuilder { nn_type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM; scale = tensor->params.scale; zeroPoint = tensor->params.zero_point; + if (scale == 0) { + // TENSOR_QUANT8_ASYMM with zero scale is not valid in NNAPI. + scale = 1; + } break; case kTfLiteInt32: nn_type = ANEURALNETWORKS_TENSOR_INT32; @@ -290,9 +331,10 @@ class NNAPIDelegateKernel { public: NNAPIDelegateKernel() = default; - typedef ANeuralNetworksOperationType (*MappingFn)(TfLiteContext*, - NNAPIOpBuilder* builder, - TfLiteNode* node); + typedef ANeuralNetworksOperationType (*MappingFn)( + TfLiteContext*, NNAPIOpBuilder* builder, TfLiteNode* node, + std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs); // Return a function that knows how to translate a node into its operands // when called. You can use this function to see if a node is supported @@ -303,7 +345,9 @@ class NNAPIDelegateKernel { case kTfLiteBuiltinAdd: if (version == 1) { return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { auto builtin = reinterpret_cast(node->builtin_data); builder->AddScalarInt32Operand(builtin->activation); @@ -316,7 +360,9 @@ class NNAPIDelegateKernel { case kTfLiteBuiltinMul: if (version == 1) { return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { auto builtin = reinterpret_cast(node->builtin_data); builder->AddScalarInt32Operand(builtin->activation); @@ -329,7 +375,9 @@ class NNAPIDelegateKernel { case kTfLiteBuiltinAveragePool2d: if (version == 1) { return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { builder->AddPoolingParams(node->builtin_data); return ANEURALNETWORKS_AVERAGE_POOL_2D; }; @@ -340,7 +388,9 @@ class NNAPIDelegateKernel { case kTfLiteBuiltinMaxPool2d: if (version == 1) { return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { builder->AddPoolingParams(node->builtin_data); return ANEURALNETWORKS_MAX_POOL_2D; }; @@ -351,7 +401,9 @@ class NNAPIDelegateKernel { case kTfLiteBuiltinL2Pool2d: if (version == 1) { return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { builder->AddPoolingParams(node->builtin_data); return ANEURALNETWORKS_L2_POOL_2D; }; @@ -369,7 +421,9 @@ class NNAPIDelegateKernel { return nullptr; } return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { auto builtin = reinterpret_cast(node->builtin_data); builder->AddScalarInt32Operand(builtin->padding); @@ -385,7 +439,9 @@ class NNAPIDelegateKernel { case kTfLiteBuiltinDepthwiseConv2d: if (version == 1) { return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { auto builtin = reinterpret_cast( node->builtin_data); builder->AddScalarInt32Operand(builtin->padding); @@ -402,7 +458,9 @@ class NNAPIDelegateKernel { case kTfLiteBuiltinFullyConnected: if (version == 1) { return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { auto builtin = reinterpret_cast( node->builtin_data); builder->AddScalarInt32Operand(builtin->activation); @@ -415,7 +473,9 @@ class NNAPIDelegateKernel { case kTfLiteBuiltinSoftmax: if (version == 1) { return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { auto builtin = reinterpret_cast(node->builtin_data); builder->AddScalarFloat32Operand(builtin->beta); @@ -428,7 +488,9 @@ class NNAPIDelegateKernel { case kTfLiteBuiltinReshape: if (version == 1) { return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { return ANEURALNETWORKS_RESHAPE; }; } else { @@ -436,10 +498,11 @@ class NNAPIDelegateKernel { } break; case kTfLiteBuiltinSqueeze: - // Squeeze requires NNAPI1.1. if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) { return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { auto builtin = reinterpret_cast(node->builtin_data); // Note that we add the squeeze dimensions even if the dimensions @@ -452,6 +515,414 @@ class NNAPIDelegateKernel { } else { return nullptr; } + case kTfLiteBuiltinL2Normalization: { + auto builtin = + reinterpret_cast(node->builtin_data); + if (builtin->activation != kTfLiteActNone) { + // NNAPI does not support activations + return nullptr; + } + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_L2_NORMALIZATION; + }; + } + case kTfLiteBuiltinLocalResponseNormalization: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + node->builtin_data); + builder->AddScalarInt32Operand(builtin->radius); + builder->AddScalarFloat32Operand(builtin->bias); + builder->AddScalarFloat32Operand(builtin->alpha); + builder->AddScalarFloat32Operand(builtin->beta); + return ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION; + }; + } else { + // TODO(miaowang): clean-up code and return early in the unsupported + // case. + return nullptr; + } + break; + case kTfLiteBuiltinLshProjection: + if (version == 1) { + // NNAPI does not support sparse projection correctly (b/111751836). + if (reinterpret_cast(node->builtin_data) + ->type == kTfLiteLshProjectionSparse) { + return nullptr; + } + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + node->builtin_data); + builder->AddScalarInt32Operand(builtin->type); + return ANEURALNETWORKS_LSH_PROJECTION; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinConcatenation: + if (version == 1 && + reinterpret_cast(node->builtin_data) + ->activation == kTfLiteActNone) { + if (context->tensors[node->inputs->data[0]].type == kTfLiteUInt8) { + // NNAPI only support concatenating quantized tensor of the same + // scale and offset. + auto first_param = context->tensors[node->inputs->data[0]].params; + for (int i = 0; i < node->inputs->size; i++) { + auto curr_param = context->tensors[node->inputs->data[i]].params; + if (curr_param.scale != first_param.scale || + curr_param.zero_point != first_param.zero_point) { + return nullptr; + } + } + } + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + node->builtin_data); + builder->AddScalarInt32Operand(builtin->axis); + return ANEURALNETWORKS_CONCATENATION; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinDequantize: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_DEQUANTIZE; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinFloor: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_FLOOR; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinRelu: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_RELU; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinReluN1To1: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_RELU1; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinRelu6: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_RELU6; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinLogistic: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_LOGISTIC; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinTanh: + // TODO(miaowang): add additional checks for the parameters. + if (version == 1 && + context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) { + // NNAPI only support float tanh. + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_TANH; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinSub: + if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11 && + context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) { + // NNAPI only support float sub. + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_SUB; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinDiv: + if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11 && + context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) { + // NNAPI only support float div. + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_DIV; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinPad: + if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11 && + node->inputs->size == 2 && + context->tensors[node->inputs->data[0]].type == kTfLiteFloat32) { + // NNAPI does not support specifying the padding value. + // NNAPI pads physical zero for quantized tensors, so only delegate + // float pad to NNAPI. + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_PAD; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinSpaceToBatchNd: + if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_SPACE_TO_BATCH_ND; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinStridedSlice: + if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->begin_mask); + builder->AddScalarInt32Operand(builtin->end_mask); + builder->AddScalarInt32Operand(builtin->shrink_axis_mask); + return ANEURALNETWORKS_STRIDED_SLICE; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinTranspose: + // Note that the permutation input tensor value dictates the output + // dimensions. + // TODO(b/110888333): Support dynamically-sized tensors in delegates. + if ((version == 1) && + (kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) && + (node->inputs->size > 1) && + (context->tensors[node->inputs->data[1]].allocation_type == + kTfLiteMmapRo)) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_TRANSPOSE; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinRnn: + // NNAPI only support float32 weights. + // TODO(miaowang): check the number of inputs before accessing it. + if (version == 1 && + context->tensors[node->inputs->data[/*kWeightsTensor*/ 1]].type == + kTfLiteFloat32) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + // NNAPI need both state_in and state_out. + int ann_index; + builder->AddStateFloat32Tensor( + node->outputs->data[/*kHiddenStateTensor*/ 0], &ann_index); + model_state_inputs->push_back(ann_index); + model_state_tfl_outputs->push_back( + node->outputs->data[/*kHiddenStateTensor*/ 0]); + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_RNN; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinSvdf: + // NNAPI only support float32 weights. + if (version == 1 && + context->tensors[node->inputs->data[/*kWeightsFeatureTensor*/ 1]] + .type == kTfLiteFloat32) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + // NNAPI need both state_in and state_out. + int ann_index; + builder->AddStateFloat32Tensor( + node->outputs->data[/*kStateTensor*/ 0], &ann_index); + model_state_inputs->push_back(ann_index); + model_state_tfl_outputs->push_back( + node->outputs->data[/*kStateTensor*/ 0]); + + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->rank); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_SVDF; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinLstm: + // NNAPI only support float32 weights. + // TODO(miaowang): add loggings to indicate why the op is rejected. + if (version == 1 && node->inputs->size == 18 && + context->tensors[node->inputs + ->data[/*kInputToOutputWeightsTensor*/ 4]] + .type == kTfLiteFloat32) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + // NNAPI need both state_in and state_out for cell_state and + // output_state. + int ann_index; + builder->AddStateFloat32Tensor( + node->outputs->data[/*kOutputStateTensor*/ 0], &ann_index); + model_state_inputs->push_back(ann_index); + model_state_tfl_outputs->push_back( + node->outputs->data[/*kOutputStateTensor*/ 0]); + builder->AddStateFloat32Tensor( + node->outputs->data[/*kCellStateTensor*/ 1], &ann_index); + model_state_inputs->push_back(ann_index); + model_state_tfl_outputs->push_back( + node->outputs->data[/*kCellStateTensor*/ 1]); + + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + builder->AddScalarFloat32Operand(builtin->cell_clip); + builder->AddScalarFloat32Operand(builtin->proj_clip); + + // Current NNAPI implementation requires the sratch_buffer as + // output. + builder->AddAdditionalFloat32OutputTensor(2); + return ANEURALNETWORKS_LSTM; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinMean: + // NNAPI does not support generating a scalar as output for MEAN. + if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11 && + context->tensors[node->inputs->data[0]].type == kTfLiteFloat32 && + context->tensors[node->outputs->data[0]].dims->size > 0) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + int32_t keep_dims = 0; + if (builtin->keep_dims) keep_dims = 1; + builder->AddScalarInt32Operand(keep_dims); + return ANEURALNETWORKS_MEAN; + }; + } else { + return nullptr; + } + case kTfLiteBuiltinEmbeddingLookup: + // NNAPI only support float32 values. + if (version == 1 && + context->tensors[node->inputs->data[1]].type == kTfLiteFloat32) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_EMBEDDING_LOOKUP; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinHashtableLookup: + // NNAPI only support float32 output. + if (version == 1 && + context->tensors[node->outputs->data[0]].type == kTfLiteFloat32) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node, std::vector* model_state_inputs, + std::vector* model_state_tfl_outputs) + -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_HASHTABLE_LOOKUP; + }; + } else { + return nullptr; + } break; default: return nullptr; @@ -492,7 +963,12 @@ class NNAPIDelegateKernel { // Set the input tensor buffers. Note: we access tflite tensors using // absolute indices but NN api indices inputs by relative indices. int relative_input_index = 0; + int num_optional_tensors = 0; for (auto absolute_input_index : TfLiteIntArrayView(node->inputs)) { + if (absolute_input_index == kOptionalTensor) { + num_optional_tensors++; + continue; + } TfLiteTensor* tensor = &context->tensors[absolute_input_index]; // TODO(miaowang): make sure the delegation works with dequantized weights // as intermediate tensors. @@ -513,6 +989,20 @@ class NNAPIDelegateKernel { tensor->data.raw, tensor->bytes)); relative_output_index++; } + + // The state_out of previous invocation need to be mapped to state_in of + // current invocation. + for (size_t i = 0; i < model_state_tfl_outputs_.size(); i++) { + int state_tensor_idx = model_state_tfl_outputs_[i]; + TfLiteTensor* tensor = &context->tensors[state_tensor_idx]; + // Here we are using a deep copy for state_in tensors so that we are not + // reading and writing into the same buffer during a invocation. + // TODO(110369471): using double shared buffer to minimize the copies. + CHECK_NN(context, + ANeuralNetworksExecution_setInput( + execution, i + node->inputs->size - num_optional_tensors, + nullptr, tensor->data.raw, tensor->bytes)); + } // Invoke ANN in blocking fashion. ANeuralNetworksEvent* event = nullptr; CHECK_NN(context, ANeuralNetworksExecution_startCompute(execution, &event)); @@ -534,6 +1024,9 @@ class NNAPIDelegateKernel { // Track indices we use OperandMapping operand_mapping_; + std::vector model_state_inputs_; + std::vector model_state_tfl_outputs_; + TfLiteStatus AddOpsAndTensors(TfLiteContext* context) { // The operand builder allows creating a single op. We create it at this // reduced power position rather than in the for loop to avoid reallocating @@ -548,11 +1041,22 @@ class NNAPIDelegateKernel { context->GetNodeAndRegistration(context, node_index, &node, ®); // Map inputs to NN API tensor indices. for (auto input_index : TfLiteIntArrayView(node->inputs)) { - TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index)); + if (input_index == kOptionalTensor && + (reg->builtin_code == kTfLiteBuiltinLstm || + reg->builtin_code == kTfLiteBuiltinSvdf)) { + // properly handle the optional tensor for LSTM and SVDF. + // currently only support float32. + // TODO(miaowang): make sure this is also able to handle quantized + // tensor when supported by NNAPI. + TF_LITE_ENSURE_STATUS(builder.AddVectorFloat32Operand(nullptr, 0)); + } else { + TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index)); + } } // Get op type and operands int nn_op_type = Map(context, reg->builtin_code, reg->version, node)( - context, &builder, node); + context, &builder, node, &model_state_inputs_, + &model_state_tfl_outputs_); // Map outputs to NN API tensor indices. for (auto output_index : TfLiteIntArrayView(node->outputs)) { TF_LITE_ENSURE_STATUS(builder.AddTensorOutput(output_index)); @@ -576,12 +1080,20 @@ class NNAPIDelegateKernel { // Make the TensorFlow lite inputs and outputs to ann_indices. for (int i : TfLiteIntArrayView(input_tensors)) { // Constant tensors are not NNAPI inputs. - if (context->tensors[i].allocation_type != kTfLiteMmapRo) { + if (i != kOptionalTensor && + context->tensors[i].allocation_type != kTfLiteMmapRo) { inputs.push_back(operand_mapping_.lite_index_to_ann(i)); } } - for (int i : TfLiteIntArrayView(output_tensors)) + // Add state input tensors as model inputs + for (int i : model_state_inputs_) { + inputs.push_back(i); + } + + for (int i : TfLiteIntArrayView(output_tensors)) { outputs.push_back(operand_mapping_.lite_index_to_ann(i)); + } + // Tell ANN to declare inputs/outputs CHECK_NN(context, ANeuralNetworksModel_identifyInputsAndOutputs( nn_model_.get(), inputs.size(), inputs.data(), diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc index aad10c9ce730a2e90481a123a1e3e323cfb2bd42..3224b23a0c3bc8456bd75f2923d16f0eed7d53ff 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc @@ -27,14 +27,20 @@ using ::testing::ElementsAreArray; // TODO(b/110368244): figure out how to share the existing tests in kernels/ but // with the delegation on. Also, add more unit tests to improve code coverage. -class FloatAddOpModel : public SingleOpModel { +class SingleOpModelWithNNAPI : public SingleOpModel { + public: + SingleOpModelWithNNAPI() { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate(), false); + }); + } +}; + +class FloatAddOpModel : public SingleOpModelWithNNAPI { public: FloatAddOpModel(const TensorData& input1, const TensorData& input2, const TensorData& output, ActivationFunctionType activation_type) { - this->SetApplyDelegate([](Interpreter* interpreter) { - interpreter->ModifyGraphWithDelegate(NnApiDelegate()); - }); input1_ = AddInput(input1); input2_ = AddInput(input2); output_ = AddOutput(output); @@ -76,14 +82,11 @@ TEST(NNAPIDelegate, AddWithRelu) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({0.0, 0.4, 1.0, 1.3})); } -class FloatMulOpModel : public SingleOpModel { +class FloatMulOpModel : public SingleOpModelWithNNAPI { public: FloatMulOpModel(const TensorData& input1, const TensorData& input2, const TensorData& output, ActivationFunctionType activation_type) { - this->SetApplyDelegate([](Interpreter* interpreter) { - interpreter->ModifyGraphWithDelegate(NnApiDelegate()); - }); input1_ = AddInput(input1); input2_ = AddInput(input2); output_ = AddOutput(output); @@ -114,15 +117,11 @@ TEST(NNAPIDelegate, MulWithNoActivation) { ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4}))); } -class FloatPoolingOpModel : public SingleOpModel { +class FloatPoolingOpModel : public SingleOpModelWithNNAPI { public: FloatPoolingOpModel(BuiltinOperator type, const TensorData& input, int filter_width, int filter_height, const TensorData& output) { - this->SetApplyDelegate([](Interpreter* interpreter) { - interpreter->ModifyGraphWithDelegate(NnApiDelegate()); - }); - input_ = AddInput(input); output_ = AddOutput(output); @@ -185,7 +184,7 @@ TEST(NNAPIDelegate, L2PoolWithNoActivation) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.5, 6.5})); } -class BaseConvolutionOpModel : public SingleOpModel { +class BaseConvolutionOpModel : public SingleOpModelWithNNAPI { public: BaseConvolutionOpModel( const TensorData& input, const TensorData& filter, @@ -193,10 +192,6 @@ class BaseConvolutionOpModel : public SingleOpModel { enum Padding padding = Padding_VALID, enum ActivationFunctionType activation = ActivationFunctionType_NONE, int dilation_width_factor = 1, int dilation_height_factor = 1) { - this->SetApplyDelegate([](Interpreter* interpreter) { - interpreter->ModifyGraphWithDelegate(NnApiDelegate()); - }); - input_ = AddInput(input); filter_ = AddInput(filter); @@ -344,14 +339,10 @@ TEST(NNAPIDelegate, Conv2DWithNoActivation) { })); } -class DepthwiseConvolutionOpModel : public SingleOpModel { +class DepthwiseConvolutionOpModel : public SingleOpModelWithNNAPI { public: DepthwiseConvolutionOpModel(const TensorData& input, const TensorData& filter, const TensorData& output) { - this->SetApplyDelegate([](Interpreter* interpreter) { - interpreter->ModifyGraphWithDelegate(NnApiDelegate()); - }); - input_ = AddInput(input); filter_ = AddInput(filter); @@ -426,15 +417,11 @@ TEST(NNAPIDelegate, DepthwiseConv2DWithNoActivation) { })); } -class FloatFullyConnectedOpModel : public SingleOpModel { +class FloatFullyConnectedOpModel : public SingleOpModelWithNNAPI { public: FloatFullyConnectedOpModel(int units, int batches, const TensorData& input, const TensorData& output = {TensorType_FLOAT32}) : batches_(batches), units_(units) { - this->SetApplyDelegate([](Interpreter* interpreter) { - interpreter->ModifyGraphWithDelegate(NnApiDelegate()); - }); - int total_input_size = 1; for (int i = 0; i < input.shape.size(); ++i) { total_input_size *= input.shape[i]; @@ -515,14 +502,10 @@ TEST(NNAPIDelegate, FullyConnectedSimpleTest) { EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60)); } -class SoftmaxOpModel : public SingleOpModel { +class SoftmaxOpModel : public SingleOpModelWithNNAPI { public: SoftmaxOpModel(int batches, int size, float beta) : batches_(batches), input_size_(size), beta_(beta) { - this->SetApplyDelegate([](Interpreter* interpreter) { - interpreter->ModifyGraphWithDelegate(NnApiDelegate()); - }); - input_ = AddInput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions, @@ -566,14 +549,10 @@ TEST(NNAPIDelegate, SoftmaxSimpleTest) { 1e-6))); } -class ReshapeOpModel : public SingleOpModel { +class ReshapeOpModel : public SingleOpModelWithNNAPI { public: ReshapeOpModel(std::initializer_list input_shape, std::initializer_list new_shape) { - this->SetApplyDelegate([](Interpreter* interpreter) { - interpreter->ModifyGraphWithDelegate(NnApiDelegate()); - }); - input_ = AddInput(TensorType_FLOAT32); new_shape_ = AddInput(TensorType_INT32); output_ = AddOutput(TensorType_FLOAT32); @@ -605,14 +584,10 @@ TEST(NNAPIDelegate, ReshapeSimpleTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); } -class SqueezeOpModel : public SingleOpModel { +class SqueezeOpModel : public SingleOpModelWithNNAPI { public: SqueezeOpModel(const TensorData& input, const TensorData& output, std::initializer_list axis) { - this->SetApplyDelegate([](Interpreter* interpreter) { - interpreter->ModifyGraphWithDelegate(NnApiDelegate()); - }); - input_ = AddInput(input); output_ = AddOutput(output); SetBuiltinOp( @@ -666,6 +641,2880 @@ TEST(NNAPIDelegate, SqueezeWithAxisTest) { 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0})); } +class L2NormOpModel : public SingleOpModelWithNNAPI { + public: + L2NormOpModel(const TensorData& input, const TensorData& output, + ActivationFunctionType activation_type) { + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions, + CreateL2NormOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input_)}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int new_shape_; + int output_; +}; + +TEST(NNAPIDelegate, L2NormSimpleTest) { + std::initializer_list data = {-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}; + L2NormOpModel m({TensorType_FLOAT32, {1, 1, 1, 6}}, + {TensorType_FLOAT32, {1, 1, 1, 6}}, + ActivationFunctionType_NONE); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 6})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05})); +} + +class TransposeSimpleModel : public SingleOpModelWithNNAPI { + public: + TransposeSimpleModel(std::initializer_list input_shape, + std::initializer_list perm_shape, + std::initializer_list perm) { + input_ = AddInput(TensorType_FLOAT32); + perm_ = AddConstInput(TensorType_INT32, perm, perm_shape); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions, + CreateTransposeOptions(builder_).Union()); + BuildInterpreter({input_shape, perm_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int perm_; + int output_; +}; + +TEST(NNAPIDelegate, TransposeSimpleTest) { + TransposeSimpleModel m({2, 3, 4}, {3}, {2, 0, 1}); + m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21, + 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23})); +} + +class FloatSubOpModel : public SingleOpModelWithNNAPI { + public: + FloatSubOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, + ActivationFunctionType activation_type) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_SUB, BuiltinOptions_SubOptions, + CreateMulOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input1_; + int input2_; + int output_; +}; + +TEST(NNAPIDelegate, SubWithNoActivation) { + FloatSubOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-2.1, 0.0, 0.4, 0.3}))); +} + +class FloatDivOpModel : public SingleOpModelWithNNAPI { + public: + FloatDivOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, + ActivationFunctionType activation_type) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_DIV, BuiltinOptions_DivOptions, + CreateMulOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input1_; + int input2_; + int output_; +}; + +TEST(NNAPIDelegate, DivWithNoActivation) { + FloatDivOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.8, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.4, 0.2}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({-20, 1, 2, 4}))); +} + +class BaseConcatenationOpModel : public SingleOpModelWithNNAPI { + public: + BaseConcatenationOpModel() {} + BaseConcatenationOpModel(const TensorData& input_template, int axis, + int num_inputs) { + std::vector> all_input_shapes; + for (int i = 0; i < num_inputs; ++i) { + all_input_shapes.push_back(input_template.shape); + AddInput(input_template); + } + output_ = AddOutput({input_template.type, /*shape=*/{}, input_template.min, + input_template.max}); + SetBuiltinOp( + BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions, + CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE) + .Union()); + BuildInterpreter(all_input_shapes); + } + + protected: + int output_; +}; + +class ConcatenationOpModel : public BaseConcatenationOpModel { + public: + using BaseConcatenationOpModel::BaseConcatenationOpModel; + void SetInput(int index, std::initializer_list data) { + PopulateTensor(index, data); + } + std::vector GetOutput() { return ExtractVector(output_); } +}; + +TEST(NNAPIDelegate, ConcatenationThreeDimensionalOneInput) { + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/1, + /*num_inputs=*/1); + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 3, 4, 7})); +} + +TEST(NNAPIDelegate, ConcatenationFourInputs) { + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/2, + /*num_inputs=*/4); + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); + m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); + m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), + ElementsAreArray({ + 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, // + 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, // + })); +} + +class QuantizedConcatenationOpModel : public BaseConcatenationOpModel { + public: + using BaseConcatenationOpModel::BaseConcatenationOpModel; + QuantizedConcatenationOpModel(const std::vector& input_template, + int axis, int num_inputs, + const TensorData& output_template) { + std::vector> all_input_shapes; + CHECK_EQ(input_template.size(), num_inputs); + for (int i = 0; i < num_inputs; ++i) { + all_input_shapes.push_back(input_template[i].shape); + AddInput(input_template[i]); + } + output_ = AddOutput({output_template.type, /*shape=*/{}, + output_template.min, output_template.max}); + SetBuiltinOp( + BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions, + CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE) + .Union()); + BuildInterpreter(all_input_shapes); + } + void SetInput(int index, std::initializer_list data) { + QuantizeAndPopulate(index, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +TEST(NNAPIDelegate, ConcatenationFourInputsQuantized) { + QuantizedConcatenationOpModel m0({TensorType_UINT8, {2, 1, 2}, -12.7, 12.8}, + /*axis=*/2, + /*num_inputs=*/4); + + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); + m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); + m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); + m0.Invoke(); + EXPECT_THAT(m0.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, // + 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, // + }))); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({ + 137, 157, 138, 158, 139, 159, 140, 160, // + 167, 197, 168, 198, 169, 199, 170, 200, // + })); +} + +TEST(NNAPIDelegate, ConcatenationFourInputsQuantizedMixedRange) { + QuantizedConcatenationOpModel m0({{TensorType_UINT8, {2, 1, 2}, -10.7, 10.8}, + {TensorType_UINT8, {2, 1, 2}, 0, 12.8}, + {TensorType_UINT8, {2, 1, 2}, -11, 11.8}, + {TensorType_UINT8, {2, 1, 2}, 0, 7.4}}, + /*axis=*/2, /*num_inputs=*/4, + {TensorType_UINT8, {2, 1, 2}, -12.7, 12.8}); + + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); + m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); + m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); + m0.Invoke(); + EXPECT_THAT(m0.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, // + 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, // + }))); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({ + 137, 157, 138, 158, 139, 159, 140, 160, // + 167, 197, 168, 198, 169, 199, 170, 200, // + })); +} + +class DequantizeOpModel : public SingleOpModelWithNNAPI { + public: + DequantizeOpModel(std::initializer_list shape, float min, float max) { + input_ = AddInput({TensorType_UINT8, shape, min, max}); + output_ = AddOutput({TensorType_FLOAT32, shape}); + SetBuiltinOp(BuiltinOperator_DEQUANTIZE, BuiltinOptions_DequantizeOptions, + CreateDequantizeOptions(builder_).Union()); + + BuildInterpreter({GetShape(input_)}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; +}; + +TEST(NNAPIDelegate, DequantizeFourDimensional) { + DequantizeOpModel m({2, 5}, -63.5, 64); + + m.SetInput({0, 1, 2, 3, 4, 251, 252, 253, 254, 255}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64}))); +} + +class FloorOpModel : public SingleOpModelWithNNAPI { + public: + FloorOpModel(std::initializer_list input_shape, TensorType input_type) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_FLOOR, BuiltinOptions_NONE, 0); + BuildInterpreter({ + input_shape, + }); + } + + int input() { return input_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(NNAPIDelegate, FloorSingleDim) { + FloorOpModel model({2}, TensorType_FLOAT32); + model.PopulateTensor(model.input(), {8.5, 0.0}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({8, 0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2})); +} + +TEST(NNAPIDelegate, FloorMultiDims) { + FloorOpModel model({2, 1, 1, 5}, TensorType_FLOAT32); + model.PopulateTensor(model.input(), { + 0.0001, + 8.0001, + 0.9999, + 9.9999, + 0.5, + -0.0001, + -8.0001, + -0.9999, + -9.9999, + -0.5, + }); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({0, 8, 0, 9, 0, -1, -9, -1, -10, -1})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 1, 5})); +} + +class LocalResponseNormOpModel : public SingleOpModelWithNNAPI { + public: + LocalResponseNormOpModel(std::initializer_list input_shape, int radius, + float bias, float alpha, float beta) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, + BuiltinOptions_LocalResponseNormalizationOptions, + CreateLocalResponseNormalizationOptions(builder_, radius, bias, + alpha, beta) + .Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; +}; + +TEST(NNAPIDelegate, LocalResponseNormSameAsL2Norm) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0, + /*alpha=*/1.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + // The result is every input divided by 2. + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}))); +} + +TEST(NNAPIDelegate, LocalResponseNormWithAlpha) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0, + /*alpha=*/4.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + // The result is every input divided by 3. + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + {-0.275, 0.15, 0.175, 0.3, -0.175, 0.025}))); +} + +TEST(NNAPIDelegate, LocalResponseNormWithBias) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/9.0, + /*alpha=*/4.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + // The result is every input divided by 5. + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.22, 0.12, 0.14, 0.24, -0.14, 0.02}))); +} + +TEST(NNAPIDelegate, LocalResponseNormSmallRadius) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/2, /*bias=*/9.0, + /*alpha=*/4.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {-0.264926, 0.125109, 0.140112, 0.267261, -0.161788, 0.0244266}))); +} + +class LSHProjectionOpModel : public SingleOpModelWithNNAPI { + public: + LSHProjectionOpModel(LSHProjectionType type, + std::initializer_list hash_shape, + std::initializer_list input_shape, + std::initializer_list weight_shape) { + hash_ = AddInput(TensorType_FLOAT32); + input_ = AddInput(TensorType_INT32); + if (weight_shape.size() > 0) { + weight_ = AddInput(TensorType_FLOAT32); + } + output_ = AddOutput(TensorType_INT32); + + SetBuiltinOp(BuiltinOperator_LSH_PROJECTION, + BuiltinOptions_LSHProjectionOptions, + CreateLSHProjectionOptions(builder_, type).Union()); + if (weight_shape.size() > 0) { + BuildInterpreter({hash_shape, input_shape, weight_shape}); + } else { + BuildInterpreter({hash_shape, input_shape}); + } + + output_size_ = 1; + for (int i : hash_shape) { + output_size_ *= i; + if (type == LSHProjectionType_SPARSE) { + break; + } + } + } + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetHash(std::initializer_list data) { + PopulateTensor(hash_, data); + } + + void SetWeight(std::initializer_list f) { PopulateTensor(weight_, f); } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int hash_; + int weight_; + int output_; + + int output_size_; +}; + +TEST(NNAPIDelegate, LSHProjectionDense1DInputs) { + LSHProjectionOpModel m(LSHProjectionType_DENSE, {3, 2}, {5}, {5}); + + m.SetInput({12345, 54321, 67890, 9876, -12345678}); + m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321}); + m.SetWeight({1.0, 1.0, 1.0, 1.0, 1.0}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(0, 0, 0, 1, 0, 0)); +} + +TEST(NNAPIDelegate, LSHProjectionSparse1DInputs) { + LSHProjectionOpModel m(LSHProjectionType_SPARSE, {3, 2}, {5}, {}); + + m.SetInput({12345, 54321, 67890, 9876, -12345678}); + m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 0, 4 + 1, 8 + 0)); +} + +TEST(NNAPIDelegate, LSHProjectionSparse3DInputs) { + LSHProjectionOpModel m(LSHProjectionType_SPARSE, {3, 2}, {5, 2, 2}, {5}); + + m.SetInput({1234, 2345, 3456, 1234, 4567, 5678, 6789, 4567, 7891, 8912, + 9123, 7890, -987, -876, -765, -987, -543, -432, -321, -543}); + m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321}); + m.SetWeight({0.12, 0.34, 0.56, 0.67, 0.78}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 2, 4 + 1, 8 + 1)); +} + +class BaseActivationsOpModel : public SingleOpModelWithNNAPI { + public: + // Most activations don't take any options, so this constructor works for + // them. + BaseActivationsOpModel(BuiltinOperator type, TensorData input) { + input_ = AddInput(input); + if (input.type == TensorType_UINT8) { + output_ = AddOutput({input.type, {}, 0, 0, 1. / 256}); + } else { + output_ = AddOutput({input.type, {}}); + } + SetBuiltinOp(type, BuiltinOptions_NONE, 0); + BuildInterpreter({GetShape(input_)}); + } + + BaseActivationsOpModel(BuiltinOperator type, const TensorData& input, + const TensorData& output) { + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(type, BuiltinOptions_NONE, 0); + BuildInterpreter({GetShape(input_)}); + } + + protected: + int input_; + int output_; +}; + +class FloatActivationsOpModel : public BaseActivationsOpModel { + public: + using BaseActivationsOpModel::BaseActivationsOpModel; + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } +}; + +const float kQuantizedTolerance = 2 * (1. / 256); + +class QuantizedActivationsOpModel : public BaseActivationsOpModel { + public: + using BaseActivationsOpModel::BaseActivationsOpModel; + + template + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + template + + std::vector GetOutput() { + return ExtractVector(output_); + } + template + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); + } +}; + +TEST(NNAPIDelegate, Relu) { + FloatActivationsOpModel m(BuiltinOperator_RELU, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 2, 4, // + 3, 0, 10, 1, // + })); +} + +TEST(NNAPIDelegate, Relu1) { + FloatActivationsOpModel m(BuiltinOperator_RELU_N1_TO_1, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0.0, -0.6, 0.2, -0.4, // + 0.3, -2.0, 1.1, -0.1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0.0, -0.6, 0.2, -0.4, // + 0.3, -1.0, 1.0, -0.1, // + })); +} + +TEST(NNAPIDelegate, Relu6) { + FloatActivationsOpModel m(BuiltinOperator_RELU6, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 2, 4, // + 3, 0, 6, 1, // + })); +} + +TEST(NNAPIDelegate, Tanh) { + FloatActivationsOpModel m(BuiltinOperator_TANH, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0, -0.9999877, 0.9640275, 0.999329, // + 0.99505475, -0.9640275, 1, 0.7615941, // + }))); +} + +TEST(NNAPIDelegate, LogisticFloat) { + FloatActivationsOpModel m(BuiltinOperator_LOGISTIC, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.5, 0.002473, 0.880797, 0.982014, // + 0.952574, 0.119203, 0.999955, 0.731059, // + }))); +} + +TEST(NNAPIDelegate, LogisticQuantized) { + QuantizedActivationsOpModel m( + BuiltinOperator_LOGISTIC, + /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.5, 0.002473, 0.880797, 0.982014, // + 0.952574, 0.119203, 0.999955, 0.731059, // + }, + kQuantizedTolerance))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({128, 1, 227, 251, 244, 32, 255, 188})); +} + +#if 0 +class ResizeBilinearOpModel : public SingleOpModelWithNNAPI { + public: + ResizeBilinearOpModel(const TensorData& input, + std::initializer_list size_data = {}) { + bool const_size = size_data.size() != 0; + input_ = AddInput(input); + if (const_size) { + size_ = AddConstInput(TensorType_INT32, size_data, {2}); + } else { + size_ = AddInput({TensorType_INT32, {2}}); + } + output_ = AddOutput(input.type); + SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR, + BuiltinOptions_ResizeBilinearOptions, + CreateResizeBilinearOptions(builder_).Union()); + if (const_size) { + BuildInterpreter({GetShape(input_)}); + } else { + BuildInterpreter({GetShape(input_), GetShape(size_)}); + } + } + + template + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetSize(std::initializer_list data) { PopulateTensor(size_, data); } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + private: + int input_; + int size_; + int output_; +}; + +TEST(NNAPIDelegate, ResizeBilinearHorizontal) { + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}); + m.SetInput({3, 6}); + m.SetSize({1, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3}); + const_m.SetInput({3, 6}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); +} + +TEST(NNAPIDelegate, ResizeBilinearVertical) { + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}); + m.SetInput({3, 9}); + m.SetSize({3, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1}); + const_m.SetInput({3, 9}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); +} + +TEST(NNAPIDelegate, ResizeBilinearTwoDimensional) { + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}); + m.SetInput({ + 3, 6, // + 9, 12 // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); +} +#endif + +template +class PadOpModel : public SingleOpModelWithNNAPI { + public: + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetQuantizedInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + void SetQuantizedPadValue(float data) { + QuantizeAndPopulate(constant_values_, {data}); + } + + void SetPaddings(std::initializer_list paddings) { + PopulateTensor(paddings_, paddings); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } + + protected: + int input_; + int output_; + int paddings_; + int constant_values_; +}; + +class PadOpConstModel : public PadOpModel { + public: + PadOpConstModel(const TensorData& input, + std::initializer_list paddings_shape, + std::initializer_list paddings, + const TensorData& output) { + input_ = AddInput(input); + paddings_ = AddConstInput(TensorType_INT32, paddings, paddings_shape); + output_ = AddOutput(output); + + SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions, + CreatePadOptions(builder_).Union()); + BuildInterpreter({input.shape}); + } +}; + +TEST(NNAPIDelegate, PadAdvancedConstTest) { + PadOpConstModel m({TensorType_FLOAT32, {1, 2, 3, 1}}, {4, 2}, + {0, 0, 0, 2, 1, 3, 0, 0}, {TensorType_FLOAT32}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); +} + +class SpaceToBatchNDOpModel : public SingleOpModelWithNNAPI { + public: + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetBlockShape(std::initializer_list data) { + PopulateTensor(block_shape_, data); + } + + void SetPaddings(std::initializer_list data) { + PopulateTensor(paddings_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int block_shape_; + int paddings_; + int output_; +}; + +class SpaceToBatchNDOpConstModel : public SpaceToBatchNDOpModel { + public: + SpaceToBatchNDOpConstModel(std::initializer_list input_shape, + std::initializer_list block_shape, + std::initializer_list paddings) { + input_ = AddInput(TensorType_FLOAT32); + block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2}); + paddings_ = AddConstInput(TensorType_INT32, paddings, {2, 2}); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND, + BuiltinOptions_SpaceToBatchNDOptions, + CreateSpaceToBatchNDOptions(builder_).Union()); + BuildInterpreter({input_shape}); + } +}; + +TEST(NNAPIDelegate, SpaceToBatchNDSimpleConstTest) { + SpaceToBatchNDOpConstModel m({1, 4, 4, 1}, {2, 2}, {0, 0, 0, 0}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 9, 11, 2, 4, 10, 12, 5, 7, + 13, 15, 6, 8, 14, 16})); +} + +TEST(NNAPIDelegate, SpaceToBatchNDMultipleInputBatchesConstTest) { + SpaceToBatchNDOpConstModel m({2, 2, 4, 1}, {2, 2}, {0, 0, 0, 0}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({8, 1, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 9, 11, 2, 4, 10, 12, 5, 7, + 13, 15, 6, 8, 14, 16})); +} + +TEST(NNAPIDelegate, SpaceToBatchNDSimplePaddingConstTest) { + SpaceToBatchNDOpConstModel m({1, 5, 2, 1}, {3, 2}, {1, 0, 2, 0}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 0, 5, 0, 0, 0, 6, 0, 1, 0, 7, + 0, 2, 0, 8, 0, 3, 0, 9, 0, 4, 0, 10, + })); +} + +TEST(NNAPIDelegate, SpaceToBatchNDComplexPaddingConstTest) { + SpaceToBatchNDOpConstModel m({1, 4, 2, 1}, {3, 2}, {1, 1, 2, 4}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, + 0, 1, 0, 0, 0, 7, 0, 0, 0, 2, 0, 0, 0, 8, 0, 0, + 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, + })); +} + +template +class StridedSliceOpModel : public SingleOpModelWithNNAPI { + public: + StridedSliceOpModel(std::initializer_list input_shape, + std::initializer_list begin_shape, + std::initializer_list end_shape, + std::initializer_list strides_shape, int begin_mask, + int end_mask, int ellipsis_mask, int new_axis_mask, + int shrink_axis_mask) { + input_ = AddInput(tensor_input_type); + begin_ = AddInput(TensorType_INT32); + end_ = AddInput(TensorType_INT32); + strides_ = AddInput(TensorType_INT32); + output_ = AddOutput(tensor_input_type); + SetBuiltinOp( + BuiltinOperator_STRIDED_SLICE, BuiltinOptions_StridedSliceOptions, + CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask, + new_axis_mask, shrink_axis_mask) + .Union()); + BuildInterpreter({input_shape, begin_shape, end_shape, strides_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetBegin(std::initializer_list data) { + PopulateTensor(begin_, data); + } + void SetEnd(std::initializer_list data) { + PopulateTensor(end_, data); + } + void SetStrides(std::initializer_list data) { + PopulateTensor(strides_, data); + } + + std::vector GetOutput() { + return ExtractVector(output_); + } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int begin_; + int end_; + int strides_; + int output_; +}; + +TEST(NNAPIDelegate, StridedSliceIn2D) { + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, 0}); + m.SetEnd({2, 2}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5})); +} + +TEST(NNAPIDelegate, StridedSliceIn2D_ShrinkAxis_NegativeSlice) { + // This is equivalent to tf.range(4)[:, tf.newaxis][-2, -1]. + StridedSliceOpModel<> m({4, 1}, {2}, {2}, {2}, 0, 0, 0, 0, 3); + m.SetInput({0, 1, 2, 3}); + m.SetBegin({-2, -1}); + m.SetEnd({-1, 0}); + m.SetStrides({1, 1}); + + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); +} + +TEST(NNAPIDelegate, StridedSliceIn2D_ShrinkAxisMask) { + StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 3); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({0, 0}); + m.SetEnd({1, 1}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); +} + +static float rnn_input[] = { + 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133, + 0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471, + -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222, + 0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933, + 0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103, + 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043, + -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007, + -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154, + 0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584, + 0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144, + 0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351, + -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719, + 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567, + -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881, + -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032, + -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374, + 0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071, + -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219, + -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682, + 0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493, + -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265, + 0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539, + 0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446, + 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017, + -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563, + 0.93455386, -0.6324693, -0.083922029}; + +static float rnn_golden_output[] = { + 0.496726, 0, 0.965996, 0, 0.0584254, 0, + 0, 0.12315, 0, 0, 0.612266, 0.456601, + 0, 0.52286, 1.16099, 0.0291232, + + 0, 0, 0.524901, 0, 0, 0, + 0, 1.02116, 0, 1.35762, 0, 0.356909, + 0.436415, 0.0355727, 0, 0, + + 0, 0, 0, 0.262335, 0, 0, + 0, 1.33992, 0, 2.9739, 0, 0, + 1.31914, 2.66147, 0, 0, + + 0.942568, 0, 0, 0, 0.025507, 0, + 0, 0, 0.321429, 0.569141, 1.25274, 1.57719, + 0.8158, 1.21805, 0.586239, 0.25427, + + 1.04436, 0, 0.630725, 0, 0.133801, 0.210693, + 0.363026, 0, 0.533426, 0, 1.25926, 0.722707, + 0, 1.22031, 1.30117, 0.495867, + + 0.222187, 0, 0.72725, 0, 0.767003, 0, + 0, 0.147835, 0, 0, 0, 0.608758, + 0.469394, 0.00720298, 0.927537, 0, + + 0.856974, 0.424257, 0, 0, 0.937329, 0, + 0, 0, 0.476425, 0, 0.566017, 0.418462, + 0.141911, 0.996214, 1.13063, 0, + + 0.967899, 0, 0, 0, 0.0831304, 0, + 0, 1.00378, 0, 0, 0, 1.44818, + 1.01768, 0.943891, 0.502745, 0, + + 0.940135, 0, 0, 0, 0, 0, + 0, 2.13243, 0, 0.71208, 0.123918, 1.53907, + 1.30225, 1.59644, 0.70222, 0, + + 0.804329, 0, 0.430576, 0, 0.505872, 0.509603, + 0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311, + 0.0454298, 0.300267, 0.562784, 0.395095, + + 0.228154, 0, 0.675323, 0, 1.70536, 0.766217, + 0, 0, 0, 0.735363, 0.0759267, 1.91017, + 0.941888, 0, 0, 0, + + 0, 0, 1.5909, 0, 0, 0, + 0, 0.5755, 0, 0.184687, 0, 1.56296, + 0.625285, 0, 0, 0, + + 0, 0, 0.0857888, 0, 0, 0, + 0, 0.488383, 0.252786, 0, 0, 0, + 1.02817, 1.85665, 0, 0, + + 0.00981836, 0, 1.06371, 0, 0, 0, + 0, 0, 0, 0.290445, 0.316406, 0, + 0.304161, 1.25079, 0.0707152, 0, + + 0.986264, 0.309201, 0, 0, 0, 0, + 0, 1.64896, 0.346248, 0, 0.918175, 0.78884, + 0.524981, 1.92076, 2.07013, 0.333244, + + 0.415153, 0.210318, 0, 0, 0, 0, + 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, + 0.628881, 3.58099, 1.49974, 0}; + +static std::initializer_list rnn_weights = { + 0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, + 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, + 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, + -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, + -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, + -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, + -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, + 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, + 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, + 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, + -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, + 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, + -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, + -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, + 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, + 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, + 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, + -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, + 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, + 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, + -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, + 0.277308, 0.415818}; + +static std::initializer_list rnn_recurrent_weights = { + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1}; + +static std::initializer_list rnn_bias = { + 0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, -0.23566568, + -0.389184, 0.47481549, -0.4791103, 0.29931796, 0.10463274, 0.83918178, + 0.37197268, 0.61957061, 0.3956964, -0.37609905}; + +class RNNOpModel : public SingleOpModelWithNNAPI { + public: + RNNOpModel(int batches, int units, int size, + const TensorType& weights = TensorType_FLOAT32, + const TensorType& recurrent_weights = TensorType_FLOAT32) + : batches_(batches), units_(units), input_size_(size) { + input_ = AddInput(TensorType_FLOAT32); + weights_ = AddInput(weights); + recurrent_weights_ = AddInput(recurrent_weights); + bias_ = AddInput(TensorType_FLOAT32); + hidden_state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_RNN, BuiltinOptions_RNNOptions, + CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union()); + BuildInterpreter({{batches_, input_size_}, + {units_, input_size_}, + {units_, units_}, + {units_}}); + } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetWeights(std::initializer_list f) { + PopulateTensor(weights_, f); + } + + void SetRecurrentWeights(std::initializer_list f) { + PopulateTensor(recurrent_weights_, f); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + void ResetHiddenState() { + const int zero_buffer_size = units_ * batches_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(hidden_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + protected: + int input_; + int weights_; + int recurrent_weights_; + int bias_; + int hidden_state_; + int output_; + + int batches_; + int units_; + int input_size_; +}; + +TEST(NNAPIDelegate, RnnBlackBoxTest) { + RNNOpModel rnn(2, 16, 8); + rnn.SetWeights(rnn_weights); + rnn.SetBias(rnn_bias); + rnn.SetRecurrentWeights(rnn_recurrent_weights); + + rnn.ResetHiddenState(); + const int input_sequence_size = sizeof(rnn_input) / sizeof(float) / + (rnn.input_size() * rnn.num_batches()); + + for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = rnn_input + i * rnn.input_size(); + float* batch_end = batch_start + rnn.input_size(); + rnn.SetInput(0, batch_start, batch_end); + rnn.SetInput(rnn.input_size(), batch_start, batch_end); + + rnn.Invoke(); + + float* golden_start = rnn_golden_output + i * rnn.num_units(); + float* golden_end = golden_start + rnn.num_units(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +static float svdf_input[] = { + 0.12609188, -0.46347019, -0.89598465, + 0.35867718, 0.36897406, 0.73463392, + + 0.14278367, -1.64410412, -0.75222826, + -0.57290924, 0.12729003, 0.7567004, + + 0.49837467, 0.19278903, 0.26584083, + 0.17660543, 0.52949083, -0.77931279, + + -0.11186574, 0.13164264, -0.05349274, + -0.72674477, -0.5683046, 0.55900657, + + -0.68892461, 0.37783599, 0.18263303, + -0.63690937, 0.44483393, -0.71817774, + + -0.81299269, -0.86831826, 1.43940818, + -0.95760226, 1.82078898, 0.71135032, + + -1.45006323, -0.82251364, -1.69082689, + -1.65087092, -1.89238167, 1.54172635, + + 0.03966608, -0.24936394, -0.77526885, + 2.06740379, -1.51439476, 1.43768692, + + 0.11771342, -0.23761693, -0.65898693, + 0.31088525, -1.55601168, -0.87661445, + + -0.89477462, 1.67204106, -0.53235275, + -0.6230064, 0.29819036, 1.06939757, +}; + +static float svdf_golden_output_rank_1[] = { + 0.014899, -0.0517661, -0.143725, -0.00271883, + -0.03004015, 0.09565311, 0.1587342, 0.00784263, + + 0.068281, -0.162217, -0.152268, 0.00323521, + 0.01582633, 0.03858774, -0.03001583, -0.02671271, + + -0.0317821, -0.0333089, 0.0609602, 0.0333759, + -0.01432795, 0.05524484, 0.1101355, -0.02382665, + + -0.00623099, -0.077701, -0.391193, -0.0136691, + -0.02333033, 0.02293761, 0.12338032, 0.04326871, + + 0.201551, -0.164607, -0.179462, -0.0592739, + 0.01064911, -0.17503069, 0.07821996, -0.00224009, + + 0.0886511, -0.0875401, -0.269283, 0.0281379, + -0.02282338, 0.09741908, 0.32973239, 0.12281385, + + -0.201174, -0.586145, -0.628624, -0.0330412, + 0.24780814, -0.39304617, -0.22473189, 0.02589256, + + -0.0839096, -0.299329, 0.108746, 0.109808, + 0.10084175, -0.06416984, 0.28936723, 0.0026358, + + 0.419114, -0.237824, -0.422627, 0.175115, + -0.2314795, -0.18584411, -0.4228974, -0.12928449, + + 0.36726, -0.522303, -0.456502, -0.175475, + 0.17012937, -0.34447709, 0.38505614, -0.28158101, +}; + +static float svdf_golden_output_rank_2[] = { + -0.09623547, -0.10193135, 0.11083051, -0.0347917, + 0.1141196, 0.12965347, -0.12652366, 0.01007236, + + -0.16396809, -0.21247184, 0.11259045, -0.04156673, + 0.10132131, -0.06143532, -0.00924693, 0.10084561, + + 0.01257364, 0.0506071, -0.19287863, -0.07162561, + -0.02033747, 0.22673416, 0.15487903, 0.02525555, + + -0.1411963, -0.37054959, 0.01774767, 0.05867489, + 0.09607603, -0.0141301, -0.08995658, 0.12867066, + + -0.27142537, -0.16955489, 0.18521598, -0.12528358, + 0.00331409, 0.11167502, 0.02218599, -0.07309391, + + 0.09593632, -0.28361851, -0.0773851, 0.17199151, + -0.00075242, 0.33691186, -0.1536046, 0.16572715, + + -0.27916506, -0.27626723, 0.42615682, 0.3225764, + -0.37472126, -0.55655634, -0.05013514, 0.289112, + + -0.24418658, 0.07540751, -0.1940318, -0.08911639, + 0.00732617, 0.46737891, 0.26449674, 0.24888524, + + -0.17225097, -0.54660404, -0.38795233, 0.08389944, + 0.07736043, -0.28260678, 0.15666828, 1.14949894, + + -0.57454878, -0.64704704, 0.73235172, -0.34616736, + 0.21120001, -0.22927976, 0.02455296, -0.35906726, +}; + +class BaseSVDFOpModel : public SingleOpModelWithNNAPI { + public: + BaseSVDFOpModel(int batches, int units, int input_size, int memory_size, + int rank, + TensorType weights_feature_type = TensorType_FLOAT32, + TensorType weights_time_type = TensorType_FLOAT32) + : batches_(batches), + units_(units), + input_size_(input_size), + memory_size_(memory_size), + rank_(rank) { + input_ = AddInput(TensorType_FLOAT32); + weights_feature_ = AddInput(weights_feature_type); + weights_time_ = AddInput(weights_time_type); + bias_ = AddNullInput(); + state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions, + CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union()); + BuildInterpreter({ + {batches_, input_size_}, // Input tensor + {units_ * rank, input_size_}, // weights_feature tensor + {units_ * rank, memory_size_}, // weights_time tensor + {units_} // bias tensor + }); + } + + // Populates the weights_feature tensor. + void SetWeightsFeature(std::initializer_list f) { + PopulateTensor(weights_feature_, f); + } + + // Populates the weights_time tensor. + void SetWeightsTime(std::initializer_list f) { + PopulateTensor(weights_time_, f); + } + + // Populates the input tensor. + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + // Resets the state of SVDF op by filling it with 0's. + void ResetState() { + const int zero_buffer_size = rank_ * units_ * batches_ * memory_size_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + // Extracts the output tensor from the SVDF op. + std::vector GetOutput() { return ExtractVector(output_); } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + protected: + int input_; + int weights_feature_; + int weights_time_; + int bias_; + int state_; + int output_; + + int batches_; + int units_; + int input_size_; + int memory_size_; + int rank_; +}; + +class SVDFOpModel : public BaseSVDFOpModel { + public: + using BaseSVDFOpModel::BaseSVDFOpModel; + + void VerifyGoldens(float golden_input[], float golden_output[], + int golden_size, float tolerance = 1e-5) { + const int svdf_num_batches = num_batches(); + const int svdf_input_size = input_size(); + const int svdf_num_units = num_units(); + const int input_sequence_size = + golden_size / sizeof(float) / (svdf_input_size * svdf_num_batches); + // Going over each input batch, setting the input tensor, invoking the SVDF + // op and checking the output with the expected golden values. + for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = + golden_input + i * svdf_input_size * svdf_num_batches; + float* batch_end = batch_start + svdf_input_size * svdf_num_batches; + SetInput(0, batch_start, batch_end); + + Invoke(); + + const float* golden_start = + golden_output + i * svdf_num_units * svdf_num_batches; + const float* golden_end = + golden_start + svdf_num_units * svdf_num_batches; + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + } + } +}; + +TEST(NNAPIDelegate, SVDFBlackBoxTestRank1) { + SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/1); + svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, + 0.22197971, 0.12416199, 0.27901134, 0.27557442, + 0.3905206, -0.36137494, -0.06634006, -0.10640851}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657}); + + svdf.ResetState(); + svdf.VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input)); +} + +TEST(NNAPIDelegate, SVDFBlackBoxTestRank2) { + SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/2); + svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, + 0.12416199, 0.15785322, 0.27901134, 0.3905206, + 0.21931258, -0.36137494, -0.10640851, 0.31053296, + -0.36118156, -0.0976817, -0.36916667, 0.22197971, + 0.15294972, 0.38031587, 0.27557442, 0.39635518, + -0.21580373, -0.06634006, -0.02702999, 0.27072677}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657, + + -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486, + 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187, + + -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589, + 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836, + + -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277, + -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214, + + 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326, + 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}); + + svdf.ResetState(); + svdf.VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input)); +} + +class LSTMOpModel : public SingleOpModelWithNNAPI { + public: + LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg, + bool use_peephole, bool use_projection_weights, + bool use_projection_bias, float cell_clip, float proj_clip, + const std::vector>& input_shapes, + const TensorType& weight_type = TensorType_FLOAT32) + : n_batch_(n_batch), + n_input_(n_input), + n_cell_(n_cell), + n_output_(n_output) { + input_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + input_to_input_weights_ = AddNullInput(); + } else { + input_to_input_weights_ = AddInput(weight_type); + } + + input_to_forget_weights_ = AddInput(weight_type); + input_to_cell_weights_ = AddInput(weight_type); + input_to_output_weights_ = AddInput(weight_type); + + if (use_cifg) { + recurrent_to_input_weights_ = AddNullInput(); + } else { + recurrent_to_input_weights_ = AddInput(weight_type); + } + + recurrent_to_forget_weights_ = AddInput(weight_type); + recurrent_to_cell_weights_ = AddInput(weight_type); + recurrent_to_output_weights_ = AddInput(weight_type); + + if (use_peephole) { + if (use_cifg) { + cell_to_input_weights_ = AddNullInput(); + } else { + cell_to_input_weights_ = AddInput(weight_type); + } + cell_to_forget_weights_ = AddInput(weight_type); + cell_to_output_weights_ = AddInput(weight_type); + } else { + cell_to_input_weights_ = AddNullInput(); + cell_to_forget_weights_ = AddNullInput(); + cell_to_output_weights_ = AddNullInput(); + } + + if (use_cifg) { + input_gate_bias_ = AddNullInput(); + } else { + input_gate_bias_ = AddInput(TensorType_FLOAT32); + } + forget_gate_bias_ = AddInput(TensorType_FLOAT32); + cell_bias_ = AddInput(TensorType_FLOAT32); + output_gate_bias_ = AddInput(TensorType_FLOAT32); + + if (use_projection_weights) { + projection_weights_ = AddInput(weight_type); + if (use_projection_bias) { + projection_bias_ = AddInput(TensorType_FLOAT32); + } else { + projection_bias_ = AddNullInput(); + } + } else { + projection_weights_ = AddNullInput(); + projection_bias_ = AddNullInput(); + } + + output_state_ = AddOutput(TensorType_FLOAT32); + cell_state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, + CreateLSTMOptions(builder_, ActivationFunctionType_TANH, + cell_clip, proj_clip) + .Union()); + BuildInterpreter(input_shapes); + } + + void SetInputToInputWeights(std::initializer_list f) { + PopulateTensor(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + PopulateTensor(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + PopulateTensor(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list f) { + PopulateTensor(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list f) { + PopulateTensor(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list f) { + PopulateTensor(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list f) { + PopulateTensor(cell_to_output_weights_, f); + } + + void SetInputGateBias(std::initializer_list f) { + PopulateTensor(input_gate_bias_, f); + } + + void SetForgetGateBias(std::initializer_list f) { + PopulateTensor(forget_gate_bias_, f); + } + + void SetCellBias(std::initializer_list f) { + PopulateTensor(cell_bias_, f); + } + + void SetOutputGateBias(std::initializer_list f) { + PopulateTensor(output_gate_bias_, f); + } + + void SetProjectionWeights(std::initializer_list f) { + PopulateTensor(projection_weights_, f); + } + + void SetProjectionBias(std::initializer_list f) { + PopulateTensor(projection_bias_, f); + } + + void ResetOutputState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(output_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void ResetCellState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(cell_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void SetInput(int offset, const float* begin, const float* end) { + PopulateTensor(input_, offset, const_cast(begin), + const_cast(end)); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int num_inputs() { return n_input_; } + int num_outputs() { return n_output_; } + int num_cells() { return n_cell_; } + int num_batches() { return n_batch_; } + + protected: + int input_; + int input_to_input_weights_; + int input_to_forget_weights_; + int input_to_cell_weights_; + int input_to_output_weights_; + + int recurrent_to_input_weights_; + int recurrent_to_forget_weights_; + int recurrent_to_cell_weights_; + int recurrent_to_output_weights_; + + int cell_to_input_weights_; + int cell_to_forget_weights_; + int cell_to_output_weights_; + + int input_gate_bias_; + int forget_gate_bias_; + int cell_bias_; + int output_gate_bias_; + + int projection_weights_; + int projection_bias_; + int input_activation_state_; + int input_cell_state_; + + int output_; + int output_state_; + int cell_state_; + + int n_batch_; + int n_input_; + int n_cell_; + int n_output_; +}; + +class BaseLstmTest : public ::testing::Test { + protected: + // Weights of the LSTM model. Some are optional. + std::initializer_list input_to_input_weights_; + std::initializer_list input_to_cell_weights_; + std::initializer_list input_to_forget_weights_; + std::initializer_list input_to_output_weights_; + std::initializer_list input_gate_bias_; + std::initializer_list cell_gate_bias_; + std::initializer_list forget_gate_bias_; + std::initializer_list output_gate_bias_; + std::initializer_list recurrent_to_input_weights_; + std::initializer_list recurrent_to_cell_weights_; + std::initializer_list recurrent_to_forget_weights_; + std::initializer_list recurrent_to_output_weights_; + std::initializer_list cell_to_input_weights_; + std::initializer_list cell_to_forget_weights_; + std::initializer_list cell_to_output_weights_; + std::initializer_list projection_weights_; + + // LSTM input is stored as num_batch x num_inputs vector. + std::vector> lstm_input_; + // LSTM output is stored as num_batch x num_outputs vector. + std::vector> lstm_golden_output_; + + // Compares output up to tolerance to the result of the lstm given the input. + void VerifyGoldens(const std::vector>& input, + const std::vector>& output, + LSTMOpModel* lstm, float tolerance = 1e-5) { + const int num_batches = input.size(); + EXPECT_GT(num_batches, 0); + const int num_inputs = lstm->num_inputs(); + EXPECT_GT(num_inputs, 0); + const int input_sequence_size = input[0].size() / num_inputs; + EXPECT_GT(input_sequence_size, 0); + for (int i = 0; i < input_sequence_size; ++i) { + for (int b = 0; b < num_batches; ++b) { + const float* batch_start = input[b].data() + i * num_inputs; + const float* batch_end = batch_start + num_inputs; + + lstm->SetInput(b * lstm->num_inputs(), batch_start, batch_end); + } + + lstm->Invoke(); + + const int num_outputs = lstm->num_outputs(); + std::vector expected; + for (int b = 0; b < num_batches; ++b) { + const float* golden_start_batch = output[b].data() + i * num_outputs; + const float* golden_end_batch = golden_start_batch + num_outputs; + expected.insert(expected.end(), golden_start_batch, golden_end_batch); + } + EXPECT_THAT(lstm->GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + } + } +}; + +class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, + -0.34856534, 0.43890524}; + input_to_cell_weights_ = {-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, -0.29909778}; + input_to_forget_weights_ = {0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}; + input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, + -0.1556896, 0.19487578}; + input_gate_bias_ = {0., 0., 0., 0.}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_input_weights_ = { + -0.0063535, -0.2042388, 0.31454784, -0.35746509, + 0.28902304, 0.08183324, -0.16555229, 0.02286911, + -0.13566875, 0.03034258, 0.48091322, -0.12528998, + 0.24077177, -0.51332325, -0.33502164, 0.10629296}; + + recurrent_to_cell_weights_ = { + -0.3407414, 0.24443203, -0.2078532, 0.26320225, + 0.05695659, -0.00123841, -0.4744786, -0.35869038, + -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}; + + recurrent_to_forget_weights_ = { + -0.48684245, -0.06655136, 0.42224967, 0.2112639, + 0.27654213, 0.20864892, -0.07646349, 0.45877004, + 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}; + + recurrent_to_output_weights_ = { + 0.43385774, -0.17194885, 0.2718237, 0.09215671, + 0.24107647, -0.39835793, 0.18212086, 0.01301402, + 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765, + -0.03716109, 0.12507336, 0.41193449, -0.20860538, + -0.15053082, 0.09120187, 0.24278517, -0.12222792}}; + } +}; + +TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight_tensor + {n_cell, n_output}, // recurrent_to_forget_weight_tensor + {n_cell, n_output}, // recurrent_to_cell_weight_tensor + {n_cell, n_output}, // recurrent_to_output_weight_tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} + +class CifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726, + 0.05100781, 0.04717243, 0.48944736, + -0.38535351, -0.17212132}; + + input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, + 0.24407166, 0.33826375}; + + input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_cell_weights_ = { + 0.54066205, -0.32668582, -0.43562764, -0.56094903, + 0.42957711, 0.01841056, -0.32764608, -0.33027974, + -0.10826075, 0.20675004, 0.19069612, -0.03026325, + -0.54532051, 0.33003211, 0.44901288, 0.21193194}; + + recurrent_to_forget_weights_ = { + -0.13832897, -0.0515101, -0.2359007, -0.16661474, + -0.14340827, 0.36986142, 0.23414481, 0.55899, + 0.10798943, -0.41174671, 0.17751795, -0.34484994, + -0.35874045, -0.11352962, 0.27268326, 0.54058349}; + + recurrent_to_output_weights_ = { + 0.41613156, 0.42610586, -0.16495961, -0.5663873, + 0.30579174, -0.05115908, -0.33941799, 0.23364776, + 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}; + + cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408, + 0.31544167}; + cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703, + -0.77109635}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646, + -0.42312205, -0.01218222, 0.24201041, -0.08124574, + -0.358325, -0.04621704, 0.21641694, -0.06471302}}; + } +}; + +TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} + +class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = { + 0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, + 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, + -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, + -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, + -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, + -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, + -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, + 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, + 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, + 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, + -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, + 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, + -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, + -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, + -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, + 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, + -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, + -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, + -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, + -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}; + + input_to_forget_weights_ = { + -0.0018401089, -0.004852237, 0.03698424, 0.014181704, + 0.028273236, -0.016726194, -0.05249759, -0.10204261, + 0.00861066, -0.040979505, -0.009899187, 0.01923892, + -0.028177269, -0.08535103, -0.14585495, 0.10662567, + -0.01909731, -0.017883534, -0.0047269356, -0.045103323, + 0.0030784295, 0.076784775, 0.07463696, 0.094531395, + 0.0814421, -0.12257899, -0.033945758, -0.031303465, + 0.045630626, 0.06843887, -0.13492945, -0.012480007, + -0.0811829, -0.07224499, -0.09628791, 0.045100946, + 0.0012300825, 0.013964662, 0.099372394, 0.02543059, + 0.06958324, 0.034257296, 0.0482646, 0.06267997, + 0.052625068, 0.12784666, 0.07077897, 0.025725935, + 0.04165009, 0.07241905, 0.018668644, -0.037377294, + -0.06277783, -0.08833636, -0.040120605, -0.011405586, + -0.007808335, -0.010301386, -0.005102167, 0.027717464, + 0.05483423, 0.11449111, 0.11289652, 0.10939839, + 0.13396506, -0.08402166, -0.01901462, -0.044678304, + -0.07720565, 0.014350063, -0.11757958, -0.0652038, + -0.08185733, -0.076754324, -0.092614375, 0.10405491, + 0.052960336, 0.035755895, 0.035839386, -0.012540553, + 0.036881298, 0.02913376, 0.03420159, 0.05448447, + -0.054523353, 0.02582715, 0.02327355, -0.011857179, + -0.0011980024, -0.034641717, -0.026125094, -0.17582615, + -0.15923657, -0.27486774, -0.0006143371, 0.0001771948, + -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}; + + input_to_cell_weights_ = { + -0.04580283, -0.09549462, -0.032418985, -0.06454633, + -0.043528453, 0.043018587, -0.049152344, -0.12418144, + -0.078985475, -0.07596889, 0.019484362, -0.11434962, + -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, + -0.025034338, -0.0028890965, 0.048929527, 0.06235075, + 0.10665918, -0.032036792, -0.08505916, -0.10843358, + -0.13002433, -0.036816437, -0.02130134, -0.016518239, + 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, + -0.10652836, -0.1037554, -0.13056071, -0.03266643, + -0.033702414, -0.006473424, -0.04611692, 0.014419339, + -0.025174323, 0.0396852, 0.081777506, 0.06157468, + 0.10210095, -0.009658194, 0.046511717, 0.03603906, + 0.0069369148, 0.015960095, -0.06507666, 0.09551598, + 0.053568836, 0.06408714, 0.12835667, -0.008714329, + -0.20211966, -0.12093674, 0.029450472, 0.2849013, + -0.029227901, 0.1164364, -0.08560263, 0.09941786, + -0.036999565, -0.028842626, -0.0033637602, -0.017012902, + -0.09720865, -0.11193351, -0.029155117, -0.017936034, + -0.009768936, -0.04223324, -0.036159635, 0.06505112, + -0.021742892, -0.023377212, -0.07221364, -0.06430552, + 0.05453865, 0.091149814, 0.06387331, 0.007518393, + 0.055960953, 0.069779344, 0.046411168, 0.10509911, + 0.07463894, 0.0075130584, 0.012850982, 0.04555431, + 0.056955688, 0.06555285, 0.050801456, -0.009862683, + 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}; + + input_to_output_weights_ = { + -0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, + -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, + 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, + -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, + -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, + 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, + -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, + -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, + -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, + -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, + 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, + 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, + 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, + -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, + 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, + 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, + -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, + 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, + -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, + -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}; + + input_gate_bias_ = {0.02234832, 0.14757581, 0.18176508, 0.10380666, + 0.053110216, -0.06928846, -0.13942584, -0.11816189, + 0.19483899, 0.03652339, -0.10250295, 0.036714908, + -0.18426876, 0.036065217, 0.21810818, 0.02383196, + -0.043370757, 0.08690144, -0.04444982, 0.00030581196}; + + forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696, + 0.11098921, 0.15378423, 0.09263801, 0.09790885, + 0.09508917, 0.061199076, 0.07665568, -0.015443159, + -0.03499149, 0.046190713, 0.08895977, 0.10899629, + 0.40694186, 0.06030037, 0.012413437, -0.06108739}; + + cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132, 0.033463873, + -0.1483596, -0.10639995, -0.091433935, 0.058573797, + -0.06809782, -0.07889636, -0.043246906, -0.09829136, + -0.4279842, 0.034901652, 0.18797937, 0.0075234566, + 0.016178843, 0.1749513, 0.13975595, 0.92058027}; + + output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469, 0.12648113, + 0.027195795, 0.35373217, -0.018957434, 0.008907322, + -0.0762701, 0.12018895, 0.04216877, 0.0022856654, + 0.040952638, 0.3147856, 0.08225149, -0.057416286, + -0.14995944, -0.008040261, 0.13208859, 0.029760877}; + + recurrent_to_input_weights_ = { + -0.001374326, -0.078856036, 0.10672688, 0.029162422, + -0.11585556, 0.02557986, -0.13446963, -0.035785314, + -0.01244275, 0.025961924, -0.02337298, -0.044228926, + -0.055839065, -0.046598054, -0.010546039, -0.06900766, + 0.027239809, 0.022582639, -0.013296484, -0.05459212, + 0.08981, -0.045407712, 0.08682226, -0.06867011, + -0.14390695, -0.02916037, 0.000996957, 0.091420636, + 0.14283475, -0.07390571, -0.06402044, 0.062524505, + -0.093129106, 0.04860203, -0.08364217, -0.08119002, + 0.009352075, 0.22920375, 0.0016303885, 0.11583097, + -0.13732095, 0.012405723, -0.07551853, 0.06343048, + 0.12162708, -0.031923793, -0.014335606, 0.01790974, + -0.10650317, -0.0724401, 0.08554849, -0.05727212, + 0.06556731, -0.042729504, -0.043227166, 0.011683251, + -0.013082158, -0.029302018, -0.010899579, -0.062036745, + -0.022509435, -0.00964907, -0.01567329, 0.04260106, + -0.07787477, -0.11576462, 0.017356863, 0.048673786, + -0.017577527, -0.05527947, -0.082487635, -0.040137455, + -0.10820036, -0.04666372, 0.022746278, -0.07851417, + 0.01068115, 0.032956902, 0.022433773, 0.0026891115, + 0.08944216, -0.0685835, 0.010513544, 0.07228705, + 0.02032331, -0.059686817, -0.0005566496, -0.086984694, + 0.040414046, -0.1380399, 0.094208956, -0.05722982, + 0.012092817, -0.04989123, -0.086576, -0.003399834, + -0.04696032, -0.045747425, 0.10091314, 0.048676282, + -0.029037097, 0.031399418, -0.0040285117, 0.047237843, + 0.09504992, 0.041799378, -0.049185462, -0.031518843, + -0.10516937, 0.026374253, 0.10058866, -0.0033195973, + -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, + -0.10167381, 0.042500053, -0.01447153, 0.06464186, + -0.017142897, 0.03312627, 0.009205989, 0.024138335, + -0.011337001, 0.035530265, -0.010912711, 0.0706555, + -0.005894094, 0.051841937, -0.1401738, -0.02351249, + 0.0365468, 0.07590991, 0.08838724, 0.021681072, + -0.10086113, 0.019608743, -0.06195883, 0.077335775, + 0.023646897, -0.095322326, 0.02233014, 0.09756986, + -0.048691444, -0.009579111, 0.07595467, 0.11480546, + -0.09801813, 0.019894179, 0.08502348, 0.004032281, + 0.037211012, 0.068537936, -0.048005626, -0.091520436, + -0.028379958, -0.01556313, 0.06554592, -0.045599163, + -0.01672207, -0.020169014, -0.011877351, -0.20212261, + 0.010889619, 0.0047078193, 0.038385306, 0.08540671, + -0.017140968, -0.0035865551, 0.016678626, 0.005633034, + 0.015963363, 0.00871737, 0.060130805, 0.028611384, + 0.10109069, -0.015060172, -0.07894427, 0.06401885, + 0.011584063, -0.024466386, 0.0047652307, -0.09041358, + 0.030737216, -0.0046374933, 0.14215417, -0.11823516, + 0.019899689, 0.006106124, -0.027092824, 0.0786356, + 0.05052217, -0.058925, -0.011402121, -0.024987547, + -0.0013661642, -0.06832946, -0.015667673, -0.1083353, + -0.00096863037, -0.06988685, -0.053350925, -0.027275559, + -0.033664223, -0.07978348, -0.025200296, -0.017207067, + -0.058403496, -0.055697463, 0.005798788, 0.12965427, + -0.062582195, 0.0013350133, -0.10482091, 0.0379771, + 0.072521195, -0.0029455067, -0.13797039, -0.03628521, + 0.013806405, -0.017858358, -0.01008298, -0.07700066, + -0.017081132, 0.019358726, 0.0027079724, 0.004635139, + 0.062634714, -0.02338735, -0.039547626, -0.02050681, + 0.03385117, -0.083611414, 0.002862572, -0.09421313, + 0.058618143, -0.08598433, 0.00972939, 0.023867095, + -0.053934585, -0.023203006, 0.07452513, -0.048767887, + -0.07314807, -0.056307215, -0.10433547, -0.06440842, + 0.04328182, 0.04389765, -0.020006588, -0.09076438, + -0.11652589, -0.021705797, 0.03345259, -0.010329105, + -0.025767034, 0.013057034, -0.07316461, -0.10145612, + 0.06358255, 0.18531723, 0.07759293, 0.12006465, + 0.1305557, 0.058638252, -0.03393652, 0.09622831, + -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, + -0.005644518, 0.06857898, -0.12598175, -0.035084512, + 0.03156317, -0.12794146, -0.031963028, 0.04692781, + 0.030070418, 0.0071660685, -0.095516115, -0.004643372, + 0.040170413, -0.062104587, -0.0037324072, 0.0554317, + 0.08184801, -0.019164372, 0.06791302, 0.034257166, + -0.10307039, 0.021943003, 0.046745934, 0.0790918, + -0.0265588, -0.007824208, 0.042546265, -0.00977924, + -0.0002440307, -0.017384544, -0.017990116, 0.12252321, + -0.014512694, -0.08251313, 0.08861942, 0.13589665, + 0.026351685, 0.012641483, 0.07466548, 0.044301085, + -0.045414884, -0.051112458, 0.03444247, -0.08502782, + -0.04106223, -0.028126027, 0.028473156, 0.10467447}; + + recurrent_to_cell_weights_ = { + -0.037322544, 0.018592842, 0.0056175636, -0.06253426, + 0.055647098, -0.05713207, -0.05626563, 0.005559383, + 0.03375411, -0.025757805, -0.088049285, 0.06017052, + -0.06570978, 0.007384076, 0.035123326, -0.07920549, + 0.053676967, 0.044480428, -0.07663568, 0.0071805613, + 0.08089997, 0.05143358, 0.038261272, 0.03339287, + -0.027673481, 0.044746667, 0.028349208, 0.020090483, + -0.019443132, -0.030755889, -0.0040000007, 0.04465846, + -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, + -0.10893326, 0.076739706, -0.08509834, -0.027997585, + 0.037871376, 0.01449768, -0.09002357, -0.06111149, + -0.046195522, 0.0422062, -0.005683705, -0.1253618, + -0.012925729, -0.04890792, 0.06985068, 0.037654128, + 0.03398274, -0.004781977, 0.007032333, -0.031787455, + 0.010868644, -0.031489216, 0.09525667, 0.013939797, + 0.0058680447, 0.0167067, 0.02668468, -0.04797466, + -0.048885044, -0.12722108, 0.035304096, 0.06554885, + 0.00972396, -0.039238118, -0.05159735, -0.11329045, + 0.1613692, -0.03750952, 0.06529313, -0.071974665, + -0.11769596, 0.015524369, -0.0013754242, -0.12446318, + 0.02786344, -0.014179351, 0.005264273, 0.14376344, + 0.015983658, 0.03406988, -0.06939408, 0.040699873, + 0.02111075, 0.09669095, 0.041345075, -0.08316494, + -0.07684199, -0.045768797, 0.032298047, -0.041805092, + 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, + -0.024950314, 0.11574242, 0.04508852, -0.04335324, + 0.06760663, -0.027437469, 0.07216407, 0.06977076, + -0.05438599, 0.034033038, -0.028602652, 0.05346137, + 0.043184172, -0.037189785, 0.10420091, 0.00882477, + -0.054019816, -0.074273005, -0.030617684, -0.0028467078, + 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, + 0.04361412, -0.007001822, 0.09631092, -0.06702025, + -0.042049985, -0.035070654, -0.04103342, -0.10273396, + 0.0544271, 0.037184782, -0.13150354, -0.0058036847, + -0.008264958, 0.042035464, 0.05891794, 0.029673764, + 0.0063542654, 0.044788733, 0.054816857, 0.062257513, + -0.00093483756, 0.048938446, -0.004952862, -0.007730018, + -0.04043371, -0.017094059, 0.07229206, -0.023670016, + -0.052195564, -0.025616996, -0.01520939, 0.045104615, + -0.007376126, 0.003533447, 0.006570588, 0.056037236, + 0.12436656, 0.051817212, 0.028532185, -0.08686856, + 0.11868599, 0.07663395, -0.07323171, 0.03463402, + -0.050708205, -0.04458982, -0.11590894, 0.021273347, + 0.1251325, -0.15313013, -0.12224372, 0.17228661, + 0.023029093, 0.086124025, 0.006445803, -0.03496501, + 0.028332196, 0.04449512, -0.042436164, -0.026587414, + -0.006041347, -0.09292539, -0.05678812, 0.03897832, + 0.09465633, 0.008115513, -0.02171956, 0.08304309, + 0.071401566, 0.019622514, 0.032163795, -0.004167056, + 0.02295182, 0.030739572, 0.056506045, 0.004612461, + 0.06524936, 0.059999723, 0.046395954, -0.0045512207, + -0.1335546, -0.030136576, 0.11584653, -0.014678886, + 0.0020118146, -0.09688814, -0.0790206, 0.039770417, + -0.0329582, 0.07922767, 0.029322514, 0.026405897, + 0.04207835, -0.07073373, 0.063781224, 0.0859677, + -0.10925287, -0.07011058, 0.048005477, 0.03438226, + -0.09606514, -0.006669445, -0.043381985, 0.04240257, + -0.06955775, -0.06769346, 0.043903265, -0.026784198, + -0.017840602, 0.024307009, -0.040079936, -0.019946516, + 0.045318738, -0.12233574, 0.026170589, 0.0074471775, + 0.15978073, 0.10185836, 0.10298046, -0.015476589, + -0.039390966, -0.072174534, 0.0739445, -0.1211869, + -0.0347889, -0.07943156, 0.014809798, -0.12412325, + -0.0030663363, 0.039695457, 0.0647603, -0.08291318, + -0.018529687, -0.004423833, 0.0037507233, 0.084633216, + -0.01514876, -0.056505352, -0.012800942, -0.06994386, + 0.012962922, -0.031234352, 0.07029052, 0.016418684, + 0.03618972, 0.055686004, -0.08663945, -0.017404709, + -0.054761406, 0.029065743, 0.052404847, 0.020238016, + 0.0048197987, -0.0214882, 0.07078733, 0.013016777, + 0.06262858, 0.009184685, 0.020785125, -0.043904778, + -0.0270329, -0.03299152, -0.060088247, -0.015162964, + -0.001828936, 0.12642565, -0.056757294, 0.013586685, + 0.09232601, -0.035886683, 0.06000002, 0.05229691, + -0.052580316, -0.082029596, -0.010794592, 0.012947712, + -0.036429964, -0.085508935, -0.13127148, -0.017744139, + 0.031502828, 0.036232427, -0.031581745, 0.023051167, + -0.05325106, -0.03421577, 0.028793324, -0.034633752, + -0.009881397, -0.043551125, -0.018609839, 0.0019097115, + -0.008799762, 0.056595087, 0.0022273948, 0.055752404}; + + recurrent_to_forget_weights_ = { + -0.057784554, -0.026057621, -0.068447545, -0.022581743, + 0.14811787, 0.10826372, 0.09471067, 0.03987225, + -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, + 0.08414449, -0.022036452, -0.00066928595, -0.09203576, + 0.032950465, -0.10985798, -0.023809856, 0.0021431844, + -0.02196096, -0.00326074, 0.00058621005, -0.074678116, + -0.06193199, 0.055729095, 0.03736828, 0.020123724, + 0.061878487, -0.04729229, 0.034919553, -0.07585433, + -0.04421272, -0.044019096, 0.085488975, 0.04058006, + -0.06890133, -0.030951202, -0.024628663, -0.07672815, + 0.034293607, 0.08556707, -0.05293577, -0.033561368, + -0.04899627, 0.0241671, 0.015736353, -0.095442444, + -0.029564252, 0.016493602, -0.035026584, 0.022337519, + -0.026871363, 0.004780428, 0.0077918363, -0.03601621, + 0.016435321, -0.03263031, -0.09543275, -0.047392778, + 0.013454138, 0.028934088, 0.01685226, -0.086110644, + -0.046250615, -0.01847454, 0.047608484, 0.07339695, + 0.034546845, -0.04881143, 0.009128804, -0.08802852, + 0.03761666, 0.008096139, -0.014454086, 0.014361001, + -0.023502491, -0.0011840804, -0.07607001, 0.001856849, + -0.06509276, -0.006021153, -0.08570962, -0.1451793, + 0.060212336, 0.055259194, 0.06974018, 0.049454916, + -0.027794661, -0.08077226, -0.016179763, 0.1169753, + 0.17213494, -0.0056326236, -0.053934924, -0.0124349, + -0.11520337, 0.05409887, 0.088759385, 0.0019655675, + 0.0042065294, 0.03881498, 0.019844765, 0.041858196, + -0.05695512, 0.047233116, 0.038937137, -0.06542224, + 0.014429736, -0.09719407, 0.13908425, -0.05379757, + 0.012321099, 0.082840554, -0.029899208, 0.044217527, + 0.059855383, 0.07711018, -0.045319796, 0.0948846, + -0.011724666, -0.0033288454, -0.033542685, -0.04764985, + -0.13873616, 0.040668588, 0.034832682, -0.015319203, + -0.018715994, 0.046002675, 0.0599172, -0.043107376, + 0.0294216, -0.002314414, -0.022424703, 0.0030315618, + 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, + 0.12375372, -0.0006038222, 0.029104086, 0.087442465, + 0.052958444, 0.07558703, 0.04817258, 0.044462286, + -0.015213451, -0.08783778, -0.0561384, -0.003008196, + 0.047060397, -0.002058388, 0.03429439, -0.018839769, + 0.024734668, 0.024614193, -0.042046934, 0.09597743, + -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, + -0.02558259, -0.022822596, -0.023273505, -0.02464396, + -0.10991725, -0.006240552, 0.0074488563, 0.024044557, + 0.04383914, -0.046476185, 0.028658995, 0.060410924, + 0.050786525, 0.009452605, -0.0073054377, -0.024810238, + 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, + 0.015898481, 0.021362653, -0.030262267, 0.016587038, + -0.011442813, 0.041154444, -0.007631438, -0.03423484, + -0.010977775, 0.036152758, 0.0066366293, 0.11915515, + 0.02318443, -0.041350313, 0.021485701, -0.10906167, + -0.028218046, -0.00954771, 0.020531068, -0.11995105, + -0.03672871, 0.024019798, 0.014255957, -0.05221243, + -0.00661567, -0.04630967, 0.033188973, 0.10107534, + -0.014027541, 0.030796422, -0.10270911, -0.035999842, + 0.15443139, 0.07684145, 0.036571592, -0.035900835, + -0.0034699554, 0.06209149, 0.015920248, -0.031122351, + -0.03858649, 0.01849943, 0.13872518, 0.01503974, + 0.069941424, -0.06948533, -0.0088794185, 0.061282158, + -0.047401894, 0.03100163, -0.041533746, -0.10430945, + 0.044574402, -0.01425562, -0.024290353, 0.034563623, + 0.05866852, 0.023947537, -0.09445152, 0.035450947, + 0.02247216, -0.0042998926, 0.061146557, -0.10250651, + 0.020881841, -0.06747029, 0.10062043, -0.0023941975, + 0.03532124, -0.016341697, 0.09685456, -0.016764693, + 0.051808182, 0.05875331, -0.04536488, 0.001626336, + -0.028892258, -0.01048663, -0.009793449, -0.017093895, + 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, + -0.001845119, -0.03551521, 0.0018358806, 0.05763657, + -0.01769146, 0.040995963, 0.02235177, -0.060430344, + 0.11475477, -0.023854522, 0.10071741, 0.0686208, + -0.014250481, 0.034261297, 0.047418304, 0.08562733, + -0.030519066, 0.0060542435, 0.014653856, -0.038836084, + 0.04096551, 0.032249358, -0.08355519, -0.026823482, + 0.056386515, -0.010401743, -0.028396193, 0.08507674, + 0.014410365, 0.020995233, 0.17040324, 0.11511526, + 0.02459721, 0.0066619175, 0.025853224, -0.023133837, + -0.081302024, 0.017264642, -0.009585969, 0.09491168, + -0.051313367, 0.054532815, -0.014298593, 0.10657464, + 0.007076659, 0.10964551, 0.0409152, 0.008275321, + -0.07283536, 0.07937492, 0.04192024, -0.1075027}; + + recurrent_to_output_weights_ = { + 0.025825322, -0.05813119, 0.09495884, -0.045984812, + -0.01255415, -0.0026479573, -0.08196161, -0.054914974, + -0.0046604523, -0.029587349, -0.044576716, -0.07480124, + -0.082868785, 0.023254942, 0.027502948, -0.0039728214, + -0.08683098, -0.08116779, -0.014675607, -0.037924774, + -0.023314456, -0.007401714, -0.09255757, 0.029460307, + -0.08829125, -0.005139627, -0.08989442, -0.0555066, + 0.13596267, -0.025062224, -0.048351806, -0.03850004, + 0.07266485, -0.022414139, 0.05940088, 0.075114764, + 0.09597592, -0.010211725, -0.0049794707, -0.011523867, + -0.025980417, 0.072999895, 0.11091378, -0.081685916, + 0.014416728, 0.043229222, 0.034178585, -0.07530371, + 0.035837382, -0.085607, -0.007721233, -0.03287832, + -0.043848954, -0.06404588, -0.06632928, -0.073643476, + 0.008214239, -0.045984086, 0.039764922, 0.03474462, + 0.060612556, -0.080590084, 0.049127717, 0.04151091, + -0.030063879, 0.008801774, -0.023021035, -0.019558564, + 0.05158114, -0.010947698, -0.011825728, 0.0075720972, + 0.0699727, -0.0039981045, 0.069350146, 0.08799282, + 0.016156472, 0.035502106, 0.11695009, 0.006217345, + 0.13392477, -0.037875112, 0.025745004, 0.08940699, + -0.00924166, 0.0046702605, -0.036598757, -0.08811812, + 0.10522024, -0.032441203, 0.008176899, -0.04454919, + 0.07058152, 0.0067963637, 0.039206743, 0.03259838, + 0.03725492, -0.09515802, 0.013326398, -0.052055415, + -0.025676316, 0.03198509, -0.015951829, -0.058556724, + 0.036879618, 0.043357447, 0.028362012, -0.05908629, + 0.0059240665, -0.04995891, -0.019187413, 0.0276265, + -0.01628143, 0.0025863599, 0.08800015, 0.035250366, + -0.022165963, -0.07328642, -0.009415526, -0.07455109, + 0.11690406, 0.0363299, 0.07411125, 0.042103454, + -0.009660886, 0.019076364, 0.018299393, -0.046004917, + 0.08891175, 0.0431396, -0.026327137, -0.051502608, + 0.08979574, -0.051670972, 0.04940282, -0.07491107, + -0.021240504, 0.022596184, -0.034280192, 0.060163025, + -0.058211457, -0.051837247, -0.01349775, -0.04639988, + -0.035936575, -0.011681591, 0.064818054, 0.0073146066, + -0.021745546, -0.043124277, -0.06471268, -0.07053354, + -0.029321948, -0.05330136, 0.016933719, -0.053782392, + 0.13747959, -0.1361751, -0.11569455, 0.0033329215, + 0.05693899, -0.053219706, 0.063698, 0.07977434, + -0.07924483, 0.06936997, 0.0034815092, -0.007305279, + -0.037325785, -0.07251102, -0.033633437, -0.08677009, + 0.091591336, -0.14165086, 0.021752775, 0.019683983, + 0.0011612234, -0.058154266, 0.049996935, 0.0288841, + -0.0024567875, -0.14345716, 0.010955264, -0.10234828, + 0.1183656, -0.0010731248, -0.023590032, -0.072285876, + -0.0724771, -0.026382286, -0.0014920527, 0.042667855, + 0.0018776858, 0.02986552, 0.009814309, 0.0733756, + 0.12289186, 0.018043943, -0.0458958, 0.049412545, + 0.033632483, 0.05495232, 0.036686596, -0.013781798, + -0.010036754, 0.02576849, -0.08307328, 0.010112348, + 0.042521734, -0.05869831, -0.071689695, 0.03876447, + -0.13275425, -0.0352966, -0.023077697, 0.10285965, + 0.084736146, 0.15568255, -0.00040734606, 0.027835453, + -0.10292561, -0.032401145, 0.10053256, -0.026142767, + -0.08271222, -0.0030240538, -0.016368777, 0.1070414, + 0.042672627, 0.013456989, -0.0437609, -0.022309763, + 0.11576483, 0.04108048, 0.061026827, -0.0190714, + -0.0869359, 0.037901703, 0.0610107, 0.07202949, + 0.01675338, 0.086139716, -0.08795751, -0.014898893, + -0.023771819, -0.01965048, 0.007955471, -0.043740474, + 0.03346837, -0.10549954, 0.090567775, 0.042013682, + -0.03176985, 0.12569028, -0.02421228, -0.029526481, + 0.023851605, 0.031539805, 0.05292009, -0.02344001, + -0.07811758, -0.08834428, 0.10094801, 0.16594367, + -0.06861939, -0.021256343, -0.041093912, -0.06669611, + 0.035498552, 0.021757556, -0.09302526, -0.015403468, + -0.06614931, -0.051798206, -0.013874718, 0.03630673, + 0.010412845, -0.08077351, 0.046185967, 0.0035662893, + 0.03541868, -0.094149634, -0.034814864, 0.003128424, + -0.020674974, -0.03944324, -0.008110165, -0.11113267, + 0.08484226, 0.043586485, 0.040582247, 0.0968012, + -0.065249965, -0.028036479, 0.0050708856, 0.0017462453, + 0.0326779, 0.041296225, 0.09164146, -0.047743853, + -0.015952192, -0.034451712, 0.084197424, -0.05347844, + -0.11768019, 0.085926116, -0.08251791, -0.045081906, + 0.0948852, 0.068401024, 0.024856757, 0.06978981, + -0.057309967, -0.012775832, -0.0032452994, 0.01977615, + -0.041040014, -0.024264973, 0.063464895, 0.05431621, + }; + + cell_to_input_weights_ = { + 0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, + -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, + -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, + 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}; + + cell_to_forget_weights_ = { + -0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, + -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, + -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, + 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}; + + cell_to_output_weights_ = { + 0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, + -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, + -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, + 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}; + + projection_weights_ = { + -0.009802181, 0.09401916, 0.0717386, -0.13895074, + 0.09641832, 0.060420845, 0.08539281, 0.054285463, + 0.061395317, 0.034448683, -0.042991187, 0.019801661, + -0.16840284, -0.015726732, -0.23041931, -0.024478018, + -0.10959692, -0.013875541, 0.18600968, -0.061274476, + 0.0138165, -0.08160894, -0.07661644, 0.032372914, + 0.16169067, 0.22465782, -0.03993472, -0.004017731, + 0.08633481, -0.28869787, 0.08682067, 0.17240396, + 0.014975425, 0.056431185, 0.031037588, 0.16702051, + 0.0077946745, 0.15140012, 0.29405436, 0.120285, + -0.188994, -0.027265169, 0.043389652, -0.022061434, + 0.014777949, -0.20203483, 0.094781205, 0.19100232, + 0.13987629, -0.036132768, -0.06426278, -0.05108664, + 0.13221376, 0.009441198, -0.16715929, 0.15859416, + -0.040437475, 0.050779544, -0.022187516, 0.012166504, + 0.027685808, -0.07675938, -0.0055694645, -0.09444123, + 0.0046453946, 0.050794356, 0.10770313, -0.20790008, + -0.07149004, -0.11425117, 0.008225835, -0.035802525, + 0.14374903, 0.15262283, 0.048710253, 0.1847461, + -0.007487823, 0.11000021, -0.09542012, 0.22619456, + -0.029149994, 0.08527916, 0.009043713, 0.0042746216, + 0.016261552, 0.022461696, 0.12689082, -0.043589946, + -0.12035478, -0.08361797, -0.050666027, -0.1248618, + -0.1275799, -0.071875185, 0.07377272, 0.09944291, + -0.18897448, -0.1593054, -0.06526116, -0.040107165, + -0.004618631, -0.067624845, -0.007576253, 0.10727444, + 0.041546922, -0.20424393, 0.06907816, 0.050412357, + 0.00724631, 0.039827548, 0.12449835, 0.10747581, + 0.13708383, 0.09134148, -0.12617786, -0.06428341, + 0.09956831, 0.1208086, -0.14676677, -0.0727722, + 0.1126304, 0.010139365, 0.015571211, -0.038128063, + 0.022913318, -0.042050496, 0.16842307, -0.060597885, + 0.10531834, -0.06411776, -0.07451711, -0.03410368, + -0.13393489, 0.06534304, 0.003620307, 0.04490757, + 0.05970546, 0.05197996, 0.02839995, 0.10434969, + -0.013699693, -0.028353551, -0.07260381, 0.047201227, + -0.024575593, -0.036445823, 0.07155557, 0.009672501, + -0.02328883, 0.009533515, -0.03606021, -0.07421458, + -0.028082801, -0.2678904, -0.13221288, 0.18419984, + -0.13012612, -0.014588381, -0.035059117, -0.04824723, + 0.07830115, -0.056184657, 0.03277091, 0.025466874, + 0.14494097, -0.12522776, -0.098633975, -0.10766018, + -0.08317623, 0.08594209, 0.07749552, 0.039474737, + 0.1776665, -0.07409566, -0.0477268, 0.29323658, + 0.10801441, 0.1154011, 0.013952499, 0.10739139, + 0.10708251, -0.051456142, 0.0074137426, -0.10430189, + 0.10034707, 0.045594677, 0.0635285, -0.0715442, + -0.089667566, -0.10811871, 0.00026344223, 0.08298446, + -0.009525053, 0.006585689, -0.24567553, -0.09450807, + 0.09648481, 0.026996298, -0.06419476, -0.04752702, + -0.11063944, -0.23441927, -0.17608605, -0.052156363, + 0.067035615, 0.19271925, -0.0032889997, -0.043264326, + 0.09663576, -0.057112187, -0.10100678, 0.0628376, + 0.04447668, 0.017961001, -0.10094388, -0.10190601, + 0.18335468, 0.10494553, -0.052095775, -0.0026118709, + 0.10539724, -0.04383912, -0.042349473, 0.08438151, + -0.1947263, 0.02251204, 0.11216432, -0.10307853, + 0.17351969, -0.039091777, 0.08066188, -0.00561982, + 0.12633002, 0.11335965, -0.0088127935, -0.019777594, + 0.06864014, -0.059751723, 0.016233567, -0.06894641, + -0.28651384, -0.004228674, 0.019708522, -0.16305895, + -0.07468996, -0.0855457, 0.099339016, -0.07580735, + -0.13775392, 0.08434318, 0.08330512, -0.12131499, + 0.031935584, 0.09180414, -0.08876437, -0.08049874, + 0.008753825, 0.03498998, 0.030215185, 0.03907079, + 0.089751154, 0.029194152, -0.03337423, -0.019092513, + 0.04331237, 0.04299654, -0.036394123, -0.12915532, + 0.09793732, 0.07512415, -0.11319543, -0.032502122, + 0.15661901, 0.07671967, -0.005491124, -0.19379048, + -0.218606, 0.21448623, 0.017840758, 0.1416943, + -0.07051762, 0.19488361, 0.02664691, -0.18104725, + -0.09334311, 0.15026465, -0.15493552, -0.057762887, + -0.11604192, -0.262013, -0.01391798, 0.012185008, + 0.11156489, -0.07483202, 0.06693364, -0.26151478, + 0.046425626, 0.036540434, -0.16435726, 0.17338543, + -0.21401681, -0.11385144, -0.08283257, -0.069031075, + 0.030635102, 0.010969227, 0.11109743, 0.010919218, + 0.027526086, 0.13519906, 0.01891392, -0.046839405, + -0.040167913, 0.017953383, -0.09700955, 0.0061885654, + -0.07000971, 0.026893595, -0.038844477, 0.14543656}; + + lstm_input_ = { + {// Batch0: 4 (input_sequence_size) * 5 (n_input) + 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, // step 0 + 0.596268, 0.998386, 0.568695, 0.864524, 0.571277, // step 1 + 0.073204, 0.296072, 0.743333, 0.069199, 0.045348, // step 2 + 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, // step 3 + + {// Batch1: 4 (input_sequence_size) * 5 (n_input) + 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, // step 0 + 0.642421, 0.524260, 0.134799, 0.003639, 0.162482, // step 1 + 0.640394, 0.930399, 0.050782, 0.432485, 0.988078, // step 2 + 0.082922, 0.563329, 0.865614, 0.333232, 0.259916} // step 3 + }; + + lstm_golden_output_ = { + {// Batch0: 4 (input_sequence_size) * 16 (n_output) + -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, + -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, + -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, + 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, + -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, + -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, + 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, + 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, + 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, + 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, + -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, + -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, + 0.0286833, 0.00824207, 0.0264887, 0.0305169}, + {// Batch1: 4 (input_sequence_size) * 16 (n_output) + -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, + -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, + 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, + 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, + -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, + -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, + 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, + 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, + 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, + 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, + -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, + -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, + 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + } +}; + +TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 20; + const int n_output = 16; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToInputWeights(cell_to_input_weights_); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + lstm.SetProjectionWeights(projection_weights_); + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} + +class BaseReduceOpModel : public SingleOpModelWithNNAPI { + public: + void SetAxis(const std::vector& data) { PopulateTensor(axis_, data); } + + template + void SetInput(std::vector data) { + PopulateTensor(input_, data); + } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } + + std::vector GetOutputShape() { return GetTensorShape(output_); } + + int Input() { return input_; } + + protected: + int input_; + int axis_; + int output_; +}; + +// Model for the tests case where axis is a const tensor. +class MeanOpConstModel : public BaseReduceOpModel { + public: + MeanOpConstModel(const TensorData& input, const TensorData& output, + std::initializer_list axis_shape, + std::initializer_list axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddConstInput(TensorType_INT32, axis, axis_shape); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// Tests for reduce_mean +TEST(NNAPIDelegate, MeanFloatNotKeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, + {4}, {1, 0, -3, -3}, false); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13}))); +} + +TEST(NNAPIDelegate, MeanFloatKeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, + {2}, {0, 2}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5}))); +} + +class BaseEmbeddingLookupOpModel : public SingleOpModelWithNNAPI { + public: + BaseEmbeddingLookupOpModel(std::initializer_list index_shape, + std::initializer_list weight_shape, + TensorType weight_type = TensorType_FLOAT32) { + input_ = AddInput(TensorType_INT32); + weight_ = AddInput(weight_type); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0); + BuildInterpreter({index_shape, weight_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int weight_; + int output_; +}; + +class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel { + public: + using BaseEmbeddingLookupOpModel::BaseEmbeddingLookupOpModel; + + void Set3DWeightMatrix(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(weight_); + int rows = tensor->dims->data[0]; + int columns = tensor->dims->data[1]; + int features = tensor->dims->data[2]; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < columns; j++) { + for (int k = 0; k < features; k++) { + tensor->data.f[(i * columns + j) * features + k] = function(i, j, k); + } + } + } + } +}; + +TEST(NNAPIDelegate, EmbeddingLookupSimpleTest) { + EmbeddingLookupOpModel m({3}, {3, 2, 4}); + m.SetInput({1, 0, 2}); + m.Set3DWeightMatrix( + [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }))); +} + +class HashtableLookupOpModel : public SingleOpModelWithNNAPI { + public: + HashtableLookupOpModel(std::initializer_list lookup_shape, + std::initializer_list key_shape, + std::initializer_list value_shape, + TensorType type) { + lookup_ = AddInput(TensorType_INT32); + key_ = AddInput(TensorType_INT32); + value_ = AddInput(type); + output_ = AddOutput(type); + hit_ = AddOutput(TensorType_UINT8); + SetBuiltinOp(BuiltinOperator_HASHTABLE_LOOKUP, BuiltinOptions_NONE, 0); + BuildInterpreter({lookup_shape, key_shape, value_shape}); + } + + void SetLookup(std::initializer_list data) { + PopulateTensor(lookup_, data); + } + + void SetHashtableKey(std::initializer_list data) { + PopulateTensor(key_, data); + } + + void SetHashtableValue(const std::vector& content) { + PopulateStringTensor(value_, content); + } + + void SetHashtableValue(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(value_); + int rows = tensor->dims->data[0]; + for (int i = 0; i < rows; i++) { + tensor->data.f[i] = function(i); + } + } + + void SetHashtableValue(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(value_); + int rows = tensor->dims->data[0]; + int features = tensor->dims->data[1]; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < features; j++) { + tensor->data.f[i * features + j] = function(i, j); + } + } + } + + std::vector GetStringOutput() { + TfLiteTensor* output = interpreter_->tensor(output_); + int num = GetStringCount(output); + std::vector result(num); + for (int i = 0; i < num; i++) { + auto ref = GetString(output, i); + result[i] = string(ref.str, ref.len); + } + return result; + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetHit() { return ExtractVector(hit_); } + + private: + int lookup_; + int key_; + int value_; + int output_; + int hit_; +}; + +TEST(NNAPIDelegate, HashtableLookupTest2DInput) { + HashtableLookupOpModel m({4}, {3}, {3, 2}, TensorType_FLOAT32); + + m.SetLookup({1234, -292, -11, 0}); + m.SetHashtableKey({-11, 0, 1234}); + m.SetHashtableValue([](int i, int j) { return i + j / 10.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 2.0, 2.1, // 2-nd item + 0, 0, // Not found + 0.0, 0.1, // 0-th item + 1.0, 1.1, // 1-st item + }))); + EXPECT_THAT(m.GetHit(), ElementsAreArray({ + 1, + 0, + 1, + 1, + })); +} + +TEST(NNAPIDelegate, HashtableLookupTest1DInput) { + HashtableLookupOpModel m({4}, {3}, {3}, TensorType_FLOAT32); + + m.SetLookup({1234, -292, -11, 0}); + m.SetHashtableKey({-11, 0, 1234}); + m.SetHashtableValue([](int i) { return i * i / 10.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.4, // 2-nd item + 0, // Not found + 0.0, // 0-th item + 0.1, // 1-st item + }))); + EXPECT_THAT(m.GetHit(), ElementsAreArray({ + 1, + 0, + 1, + 1, + })); +} } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh index 840015a7fad173dbd2ea353786871dd4e89abb98..8c7df474d55a85d7a6659b436e33ebf7632ab960 100755 --- a/tensorflow/contrib/lite/download_dependencies.sh +++ b/tensorflow/contrib/lite/download_dependencies.sh @@ -35,7 +35,7 @@ GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.g ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip" FARMHASH_URL="https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz" -FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/master.zip" +FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/v1.8.0.zip" FFT2D_URL="https://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz" # TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64, diff --git a/tensorflow/contrib/lite/examples/android/BUILD b/tensorflow/contrib/lite/examples/android/BUILD index dd2cd173246719976d7cd6e52d65f63125b5b2db..4d2437e7d3714e1b8b427b0c6197b295c0355b07 100644 --- a/tensorflow/contrib/lite/examples/android/BUILD +++ b/tensorflow/contrib/lite/examples/android/BUILD @@ -37,6 +37,7 @@ android_binary( "@tflite_conv_actions_frozen//:conv_actions_frozen.tflite", "//tensorflow/contrib/lite/examples/android/app/src/main/assets:conv_actions_labels.txt", "@tflite_mobilenet_ssd//:mobilenet_ssd.tflite", + "@tflite_mobilenet_ssd_quant//:detect.tflite", "//tensorflow/contrib/lite/examples/android/app/src/main/assets:box_priors.txt", "//tensorflow/contrib/lite/examples/android/app/src/main/assets:coco_labels_list.txt", ], diff --git a/tensorflow/contrib/lite/examples/android/app/README.md b/tensorflow/contrib/lite/examples/android/app/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cbdeeac8790d93210a6c637953605b4ca270d3f6 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/app/README.md @@ -0,0 +1,19 @@ +# TF Lite Android App Example + +## Building from Source with Bazel + +1. Install [Bazel](https://docs.bazel.build/versions/master/install.html), the Android NDK and SDK. The recommended versions are specified on this [webpage](https://www.tensorflow.org/mobile/tflite/demo_android#build_tensorflow_lite_and_the_demo_app_from_source). + +2. Build this demo app with Bazel. The demo needs C++11. We configure the fat_apk_cpu flag to package support for 4 hardware variants. You may replace it with --config=android_arm64 on a 64-bit device and --config=android_arm for 32-bit device: + + ```shell + bazel build -c opt --cxxopt='--std=c++11' --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ + //tensorflow/contrib/lite/examples/android:tflite_demo + ``` + +3. Install the demo on a + [debug-enabled device](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install): + + ```shell + adb install bazel-bin/tensorflow/contrib/lite/examples/android/tflite_demo.apk + ``` diff --git a/tensorflow/contrib/lite/examples/android/app/build.gradle b/tensorflow/contrib/lite/examples/android/app/build.gradle index 8e0a98ed63f99b7477cdb2f851a19cd31b45f314..eb7fd705e18f53eb026600207faefa3d2bb072af 100644 --- a/tensorflow/contrib/lite/examples/android/app/build.gradle +++ b/tensorflow/contrib/lite/examples/android/app/build.gradle @@ -9,7 +9,7 @@ android { targetSdkVersion 26 versionCode 1 versionName "1.0" - testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" // Remove this block. jackOptions { @@ -51,7 +51,7 @@ apply from: "download-models.gradle" dependencies { compile fileTree(dir: 'libs', include: ['*.jar']) - androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { + androidTestCompile('androidx.test.espresso:espresso-core:3.1.0-alpha3', { exclude group: 'com.android.support', module: 'support-annotations' }) compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly' diff --git a/tensorflow/contrib/lite/examples/android/app/download-models.gradle b/tensorflow/contrib/lite/examples/android/app/download-models.gradle index 8e65dc076f2a8daaddf01ceab6796b8ed1127af3..c100e37c16f38a65f7b1f64a3f6e3eaa1477e8eb 100644 --- a/tensorflow/contrib/lite/examples/android/app/download-models.gradle +++ b/tensorflow/contrib/lite/examples/android/app/download-models.gradle @@ -12,8 +12,9 @@ def models = ['conv_actions_tflite.zip', 'mobilenet_ssd_tflite_v1.zip', - 'mobilenet_v1_224_android_quant_2017_11_08.zip'] -// LINT.ThenChange(//tensorflow/examples/android/BUILD) + 'mobilenet_v1_224_android_quant_2017_11_08.zip', + 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip'] +// LINT.ThenChange(//tensorflow/contrib/lite/examples/android/BUILD) // Root URL for model archives def MODEL_URL = 'https://storage.googleapis.com/download.tensorflow.org/models/tflite' diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/assets/pets_labels_list.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/pets_labels_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..d581f733e48ff8c2ba88162ee56b5e9d12aec7de --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/app/src/main/assets/pets_labels_list.txt @@ -0,0 +1,38 @@ +??? +Abyssinian +american_bulldog +american_pit_bull_terrier +basset_hound +beagle +Bengal +Birman +Bombay +boxer +British_Shorthair +chihuahua +Egyptian_Mau +english_cocker_spaniel +english_setter +german_shorthaired +great_pyrenees +havanese +japanese_chin +keeshond +leonberger +Maine_Coon +miniature_pinscher +newfoundland +Persian +pomeranian +pug +Ragdoll +Russian_Blue +saint_bernard +samoyed +scottish_terrier +shiba_inu +Siamese +Sphynx +staffordshire_bull_terrier +wheaten_terrier +yorkshire_terrier diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java index de997e454a1e33254cb7c2c932ca79d0072539fa..87160f6b3fb8c0d24e5df131d9becbb3eb6e2980 100644 --- a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java +++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java @@ -1,5 +1,5 @@ /* - * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * Copyright 2018 The TensorFlow Authors. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -50,9 +50,10 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable // Configuration values for the prepackaged SSD model. private static final int TF_OD_API_INPUT_SIZE = 300; - private static final String TF_OD_API_MODEL_FILE = "mobilenet_ssd.tflite"; + private static final boolean TF_OD_API_IS_QUANTIZED = true; + private static final String TF_OD_API_MODEL_FILE = "detect.tflite"; private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/coco_labels_list.txt"; - + // Which detection model to use: by default uses Tensorflow Object Detection API frozen // checkpoints. private enum DetectorMode { @@ -107,7 +108,11 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable try { detector = TFLiteObjectDetectionAPIModel.create( - getAssets(), TF_OD_API_MODEL_FILE, TF_OD_API_LABELS_FILE, TF_OD_API_INPUT_SIZE); + getAssets(), + TF_OD_API_MODEL_FILE, + TF_OD_API_LABELS_FILE, + TF_OD_API_INPUT_SIZE, + TF_OD_API_IS_QUANTIZED); cropSize = TF_OD_API_INPUT_SIZE; } catch (final IOException e) { LOGGER.e("Exception initializing classifier!", e); diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java index bfb4a0a04bc90566736864bf62340d1032961858..9eb21de9d03e387d3c25b38171e154a358dc81ce 100644 --- a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java +++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java @@ -25,15 +25,14 @@ import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.util.ArrayList; -import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.PriorityQueue; -import java.util.StringTokenizer; import java.util.Vector; import org.tensorflow.demo.env.Logger; import org.tensorflow.lite.Interpreter; @@ -46,32 +45,35 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { private static final Logger LOGGER = new Logger(); // Only return this many results. - private static final int NUM_RESULTS = 1917; - private static final int NUM_CLASSES = 91; - - private static final float Y_SCALE = 10.0f; - private static final float X_SCALE = 10.0f; - private static final float H_SCALE = 5.0f; - private static final float W_SCALE = 5.0f; - + private static final int NUM_DETECTIONS = 10; + private boolean isModelQuantized; + // Float model + private static final float IMAGE_MEAN = 128.0f; + private static final float IMAGE_STD = 128.0f; + // Number of threads in the java app + private static final int NUM_THREADS = 4; // Config values. private int inputSize; - - private final float[][] boxPriors = new float[4][NUM_RESULTS]; - // Pre-allocated buffers. private Vector labels = new Vector(); private int[] intValues; + // outputLocations: array of shape [Batchsize, NUM_DETECTIONS,4] + // contains the location of detected boxes private float[][][] outputLocations; - private float[][][] outputClasses; - - float[][][][] img; + // outputClasses: array of shape [Batchsize, NUM_DETECTIONS] + // contains the classes of detected boxes + private float[][] outputClasses; + // outputScores: array of shape [Batchsize, NUM_DETECTIONS] + // contains the scores of detected boxes + private float[][] outputScores; + // numDetections: array of shape [Batchsize] + // contains the number of detected boxes + private float[] numDetections; + + private ByteBuffer imgData; private Interpreter tfLite; - private float expit(final float x) { - return (float) (1. / (1. + Math.exp(-x))); - } /** Memory-map the model file in Assets. */ private static MappedByteBuffer loadModelFile(AssetManager assets, String modelFilename) @@ -84,77 +86,24 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); } - private void loadCoderOptions( - final AssetManager assetManager, final String locationFilename, final float[][] boxPriors) - throws IOException { - // Try to be intelligent about opening from assets or sdcard depending on prefix. - final String assetPrefix = "file:///android_asset/"; - InputStream is; - if (locationFilename.startsWith(assetPrefix)) { - is = assetManager.open(locationFilename.split(assetPrefix, -1)[1]); - } else { - is = new FileInputStream(locationFilename); - } - - final BufferedReader reader = new BufferedReader(new InputStreamReader(is)); - - for (int lineNum = 0; lineNum < 4; ++lineNum) { - String line = reader.readLine(); - final StringTokenizer st = new StringTokenizer(line, ", "); - int priorIndex = 0; - while (st.hasMoreTokens()) { - final String token = st.nextToken(); - try { - final float number = Float.parseFloat(token); - boxPriors[lineNum][priorIndex++] = number; - } catch (final NumberFormatException e) { - // Silently ignore. - } - } - if (priorIndex != NUM_RESULTS) { - throw new RuntimeException( - "BoxPrior length mismatch: " + priorIndex + " vs " + NUM_RESULTS); - } - } - - LOGGER.i("Loaded box priors!"); - } - - void decodeCenterSizeBoxes(float[][][] predictions) { - for (int i = 0; i < NUM_RESULTS; ++i) { - float ycenter = predictions[0][i][0] / Y_SCALE * boxPriors[2][i] + boxPriors[0][i]; - float xcenter = predictions[0][i][1] / X_SCALE * boxPriors[3][i] + boxPriors[1][i]; - float h = (float) Math.exp(predictions[0][i][2] / H_SCALE) * boxPriors[2][i]; - float w = (float) Math.exp(predictions[0][i][3] / W_SCALE) * boxPriors[3][i]; - - float ymin = ycenter - h / 2.f; - float xmin = xcenter - w / 2.f; - float ymax = ycenter + h / 2.f; - float xmax = xcenter + w / 2.f; - - predictions[0][i][0] = ymin; - predictions[0][i][1] = xmin; - predictions[0][i][2] = ymax; - predictions[0][i][3] = xmax; - } - } - /** * Initializes a native TensorFlow session for classifying images. * * @param assetManager The asset manager to be used to load assets. * @param modelFilename The filepath of the model GraphDef protocol buffer. * @param labelFilename The filepath of label file for classes. + * @param inputSize The size of image input + * @param isQuantized Boolean representing model is quantized or not */ public static Classifier create( final AssetManager assetManager, final String modelFilename, final String labelFilename, - final int inputSize) throws IOException { + final int inputSize, + final boolean isQuantized) + throws IOException { final TFLiteObjectDetectionAPIModel d = new TFLiteObjectDetectionAPIModel(); - d.loadCoderOptions(assetManager, "file:///android_asset/box_priors.txt", d.boxPriors); - InputStream labelsInput = null; String actualFilename = labelFilename.split("file:///android_asset/")[1]; labelsInput = assetManager.open(actualFilename); @@ -175,12 +124,23 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { throw new RuntimeException(e); } + d.isModelQuantized = isQuantized; // Pre-allocate buffers. - d.img = new float[1][inputSize][inputSize][3]; - + int numBytesPerChannel; + if (isQuantized) { + numBytesPerChannel = 1; // Quantized + } else { + numBytesPerChannel = 4; // Floating point + } + d.imgData = ByteBuffer.allocateDirect(1 * d.inputSize * d.inputSize * 3 * numBytesPerChannel); + d.imgData.order(ByteOrder.nativeOrder()); d.intValues = new int[d.inputSize * d.inputSize]; - d.outputLocations = new float[1][NUM_RESULTS][4]; - d.outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES]; + + d.tfLite.setNumThreads(NUM_THREADS); + d.outputLocations = new float[1][NUM_DETECTIONS][4]; + d.outputClasses = new float[1][NUM_DETECTIONS]; + d.outputScores = new float[1][NUM_DETECTIONS]; + d.numDetections = new float[1]; return d; } @@ -196,25 +156,37 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { // on the provided parameters. bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); + imgData.rewind(); for (int i = 0; i < inputSize; ++i) { for (int j = 0; j < inputSize; ++j) { - int pixel = intValues[j * inputSize + i]; - img[0][j][i][2] = (float) (pixel & 0xFF) / 128.0f - 1.0f; - img[0][j][i][1] = (float) ((pixel >> 8) & 0xFF) / 128.0f - 1.0f; - img[0][j][i][0] = (float) ((pixel >> 16) & 0xFF) / 128.0f - 1.0f; + int pixelValue = intValues[i * inputSize + j]; + if (isModelQuantized) { + // Quantized model + imgData.put((byte) ((pixelValue >> 16) & 0xFF)); + imgData.put((byte) ((pixelValue >> 8) & 0xFF)); + imgData.put((byte) (pixelValue & 0xFF)); + } else { // Float model + imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD); + imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD); + imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD); + } } } Trace.endSection(); // preprocessBitmap // Copy the input data into TensorFlow. Trace.beginSection("feed"); - outputLocations = new float[1][NUM_RESULTS][4]; - outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES]; + outputLocations = new float[1][NUM_DETECTIONS][4]; + outputClasses = new float[1][NUM_DETECTIONS]; + outputScores = new float[1][NUM_DETECTIONS]; + numDetections = new float[1]; - Object[] inputArray = {img}; + Object[] inputArray = {imgData}; Map outputMap = new HashMap<>(); outputMap.put(0, outputLocations); outputMap.put(1, outputClasses); + outputMap.put(2, outputScores); + outputMap.put(3, numDetections); Trace.endSection(); // Run the inference call. @@ -222,56 +194,26 @@ public class TFLiteObjectDetectionAPIModel implements Classifier { tfLite.runForMultipleInputsOutputs(inputArray, outputMap); Trace.endSection(); - decodeCenterSizeBoxes(outputLocations); - - // Find the best detections. - final PriorityQueue pq = - new PriorityQueue( - 1, - new Comparator() { - @Override - public int compare(final Recognition lhs, final Recognition rhs) { - // Intentionally reversed to put high confidence at the head of the queue. - return Float.compare(rhs.getConfidence(), lhs.getConfidence()); - } - }); - - // Scale them back to the input size. - for (int i = 0; i < NUM_RESULTS; ++i) { - float topClassScore = -1000f; - int topClassScoreIndex = -1; - - // Skip the first catch-all class. - for (int j = 1; j < NUM_CLASSES; ++j) { - float score = expit(outputClasses[0][i][j]); - - if (score > topClassScore) { - topClassScoreIndex = j; - topClassScore = score; - } - } - - if (topClassScore > 0.001f) { - final RectF detection = - new RectF( - outputLocations[0][i][1] * inputSize, - outputLocations[0][i][0] * inputSize, - outputLocations[0][i][3] * inputSize, - outputLocations[0][i][2] * inputSize); - - pq.add( - new Recognition( - "" + i, - labels.get(topClassScoreIndex), - outputClasses[0][i][topClassScoreIndex], - detection)); - } - } - - final ArrayList recognitions = new ArrayList(); - for (int i = 0; i < Math.min(pq.size(), 10); ++i) { - Recognition recog = pq.poll(); - recognitions.add(recog); + // Show the best detections. + // after scaling them back to the input size. + final ArrayList recognitions = new ArrayList<>(NUM_DETECTIONS); + for (int i = 0; i < NUM_DETECTIONS; ++i) { + final RectF detection = + new RectF( + outputLocations[0][i][1] * inputSize, + outputLocations[0][i][0] * inputSize, + outputLocations[0][i][3] * inputSize, + outputLocations[0][i][2] * inputSize); + // SSD Mobilenet V1 Model assumes class 0 is background class + // in label file and class labels start from 1 to number_of_classes+1, + // while outputClasses correspond to class index from 0 to number_of_classes + int labelOffset = 1; + recognitions.add( + new Recognition( + "" + i, + labels.get((int) outputClasses[0][i] + labelOffset), + outputScores[0][i], + detection)); } Trace.endSection(); // "recognizeImage" return recognitions; diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm index d74e275f0439b1ce56b29e0eadff5f211f6a4faa..30fee64a6f621016446eff58c305e88fda01fa76 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm @@ -315,7 +315,7 @@ static void GetTopN(const uint8_t* prediction, const int prediction_size, const labelLayers = [[NSMutableArray alloc] init]; oldPredictionValues = [[NSMutableDictionary alloc] init]; - NSString* graph_path = FilePathForResourceName(model_file_name, @"tflite"); + NSString* graph_path = FilePathForResourceName(model_file_name, model_file_type); model = tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String]); if (!model) { LOG(FATAL) << "Failed to mmap model " << graph_path; diff --git a/tensorflow/contrib/lite/examples/ios/camera/Podfile b/tensorflow/contrib/lite/examples/ios/camera/Podfile index c7d3b1c966eaa0de71f5c37a6a77b3881e30ddd7..cd8c39043f6df61ed83e75e80a42156fdba68642 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/Podfile +++ b/tensorflow/contrib/lite/examples/ios/camera/Podfile @@ -2,4 +2,4 @@ platform :ios, '8.0' inhibit_all_warnings! target 'tflite_camera_example' - pod 'TensorFlowLite' + pod 'TensorFlowLite', '0.1.7' diff --git a/tensorflow/contrib/lite/examples/ios/simple/Podfile b/tensorflow/contrib/lite/examples/ios/simple/Podfile index e4aca2be82d437a0225d2c15d3e486b0344aa978..c885398f44456bc1b7429b4f6605237bbc64e654 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/Podfile +++ b/tensorflow/contrib/lite/examples/ios/simple/Podfile @@ -2,4 +2,4 @@ platform :ios, '8.0' inhibit_all_warnings! target 'tflite_simple_example' - pod 'TensorFlowLite' + pod 'TensorFlowLite', '0.1.7' diff --git a/tensorflow/contrib/lite/examples/label_image/BUILD b/tensorflow/contrib/lite/examples/label_image/BUILD index c61445114ecc6dfbe4f2b6ab666b28a8aa746be3..fc55a78019b4a12b24231034a7e4b912869389f2 100644 --- a/tensorflow/contrib/lite/examples/label_image/BUILD +++ b/tensorflow/contrib/lite/examples/label_image/BUILD @@ -63,6 +63,7 @@ cc_test( data = [ "testdata/grace_hopper.bmp", ], + tags = ["no_oss"], deps = [ ":bitmap_helpers", "@com_google_googletest//:gtest", diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h index e36218e4f12057a362af47c48454f7930fc495f2..6fdcf78b69c6799fc2e666af1150efb88b55ff5c 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h @@ -16,11 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_ #define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/string_util.h" -#include "tensorflow/contrib/lite/version.h" +#include "tensorflow/contrib/lite/examples/label_image/label_image.h" #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/interpreter.h" @@ -28,8 +24,6 @@ limitations under the License. #include "tensorflow/contrib/lite/string_util.h" #include "tensorflow/contrib/lite/version.h" -#include "tensorflow/contrib/lite/examples/label_image/label_image.h" - namespace tflite { namespace label_image { diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc index 86d7d1cc4a625243791d5e7d5b746526a58efb6d..7c6f523041ad5a516f348c1b4f66683128838228 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.cc +++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc @@ -213,22 +213,23 @@ void RunInference(Settings* s) { } } - const int output_size = 1000; - const size_t num_results = 5; const float threshold = 0.001f; std::vector> top_results; int output = interpreter->outputs()[0]; + TfLiteIntArray* output_dims = interpreter->tensor(output)->dims; + // assume output dims to be something like (1, 1, ... ,size) + auto output_size = output_dims->data[output_dims->size - 1]; switch (interpreter->tensor(output)->type) { case kTfLiteFloat32: get_top_n(interpreter->typed_output_tensor(0), output_size, - num_results, threshold, &top_results, true); + s->number_of_results, threshold, &top_results, true); break; case kTfLiteUInt8: get_top_n(interpreter->typed_output_tensor(0), - output_size, num_results, threshold, &top_results, - false); + output_size, s->number_of_results, threshold, + &top_results, false); break; default: LOG(FATAL) << "cannot handle output type " @@ -259,6 +260,7 @@ void display_usage() { << "--labels, -l: labels for the model\n" << "--tflite_model, -m: model_name.tflite\n" << "--profiling, -p: [0|1], profiling or not\n" + << "--num_results, -r: number of results to show\n" << "--threads, -t: number of threads\n" << "--verbose, -v: [0|1] print more information\n" << "\n"; @@ -280,12 +282,13 @@ int Main(int argc, char** argv) { {"threads", required_argument, nullptr, 't'}, {"input_mean", required_argument, nullptr, 'b'}, {"input_std", required_argument, nullptr, 's'}, + {"num_results", required_argument, nullptr, 'r'}, {nullptr, 0, nullptr, 0}}; /* getopt_long stores the option index here. */ int option_index = 0; - c = getopt_long(argc, argv, "a:b:c:f:i:l:m:p:s:t:v:", long_options, + c = getopt_long(argc, argv, "a:b:c:f:i:l:m:p:r:s:t:v:", long_options, &option_index); /* Detect the end of the options. */ @@ -315,6 +318,10 @@ int Main(int argc, char** argv) { s.profiling = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; + case 'r': + s.number_of_results = + strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) + break; case 's': s.input_std = strtod(optarg, nullptr); break; diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.h b/tensorflow/contrib/lite/examples/label_image/label_image.h index 4b48014e1c77eca1eca081f0fe906441a5dcce22..34c223f713b9fe7692440a6b7538f00be995ad11 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.h +++ b/tensorflow/contrib/lite/examples/label_image/label_image.h @@ -34,6 +34,7 @@ struct Settings { string labels_file_name = "./labels.txt"; string input_layer_type = "uint8_t"; int number_of_threads = 4; + int number_of_results = 5; }; } // namespace label_image diff --git a/tensorflow/contrib/lite/experimental/c/BUILD b/tensorflow/contrib/lite/experimental/c/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..50f8da66d06abaf0637866e85c04e80fee042071 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/c/BUILD @@ -0,0 +1,59 @@ +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow/contrib/lite:build_def.bzl", + "tflite_cc_shared_object", + "tflite_copts", + "tflite_jni_binary", +) + +tflite_cc_shared_object( + name = "libtensorflowlite_c.so", + linkopts = select({ + "//tensorflow:darwin": [ + "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file + "$(location //tensorflow/contrib/lite/experimental/c:exported_symbols.lds)", + "-Wl,-install_name,@rpath/libtensorflowlite_c.so", + ], + "//tensorflow:windows": [], + "//conditions:default": [ + "-z defs", + "-Wl,--version-script", # This line must be directly followed by the version_script.lds file + "$(location //tensorflow/contrib/lite/experimental/c:version_script.lds)", + ], + }), + deps = [ + ":c_api", + ":exported_symbols.lds", + ":version_script.lds", + ], +) + +cc_library( + name = "c_api", + srcs = ["c_api.cc"], + hdrs = ["c_api.h"], + copts = tflite_copts(), + deps = [ + "//tensorflow/contrib/lite:context", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ], +) + +cc_test( + name = "c_api_test", + size = "small", + srcs = ["c_api_test.cc"], + data = ["//tensorflow/contrib/lite:testdata/add.bin"], + deps = [ + ":c_api", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/experimental/c/c_api.cc b/tensorflow/contrib/lite/experimental/c/c_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..9d29e8b3e055e86a9e68285d81de742e36452215 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/c/c_api.cc @@ -0,0 +1,122 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/c/c_api.h" + +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +struct _TFL_Interpreter { + std::unique_ptr impl; +}; + +// LINT.IfChange + +TFL_Interpreter* TFL_NewInterpreter(const void* model_data, + int32_t model_size) { + auto model = tflite::FlatBufferModel::BuildFromBuffer( + static_cast(model_data), static_cast(model_size)); + if (!model) { + return nullptr; + } + + tflite::ops::builtin::BuiltinOpResolver resolver; + tflite::InterpreterBuilder builder(*model, resolver); + std::unique_ptr interpreter_impl; + if (builder(&interpreter_impl) != kTfLiteOk) { + return nullptr; + } + + return new TFL_Interpreter{std::move(interpreter_impl)}; +} + +void TFL_DeleteInterpreter(TFL_Interpreter* interpreter) { delete interpreter; } + +int32_t TFL_InterpreterGetInputTensorCount(const TFL_Interpreter* interpreter) { + return static_cast(interpreter->impl->inputs().size()); +} + +TFL_Tensor* TFL_InterpreterGetInputTensor(const TFL_Interpreter* interpreter, + int32_t input_index) { + return interpreter->impl->tensor(interpreter->impl->inputs()[input_index]); +} + +TFL_Status TFL_InterpreterResizeInputTensor(TFL_Interpreter* interpreter, + int32_t input_index, + const int* input_dims, + int32_t input_dims_size) { + std::vector dims{input_dims, input_dims + input_dims_size}; + return interpreter->impl->ResizeInputTensor( + interpreter->impl->inputs()[input_index], dims); +} + +TFL_Status TFL_InterpreterAllocateTensors(TFL_Interpreter* interpreter) { + return interpreter->impl->AllocateTensors(); +} + +TFL_Status TFL_InterpreterInvoke(TFL_Interpreter* interpreter) { + return interpreter->impl->Invoke(); +} + +int32_t TFL_InterpreterGetOutputTensorCount( + const TFL_Interpreter* interpreter) { + return static_cast(interpreter->impl->outputs().size()); +} + +const TFL_Tensor* TFL_InterpreterGetOutputTensor( + const TFL_Interpreter* interpreter, int32_t output_index) { + return interpreter->impl->tensor(interpreter->impl->outputs()[output_index]); +} + +TFL_Type TFL_TensorType(const TFL_Tensor* tensor) { return tensor->type; } + +int32_t TFL_TensorNumDims(const TFL_Tensor* tensor) { + return tensor->dims->size; +} + +int32_t TFL_TensorDim(const TFL_Tensor* tensor, int32_t dim_index) { + return tensor->dims->data[dim_index]; +} + +size_t TFL_TensorByteSize(const TFL_Tensor* tensor) { return tensor->bytes; } + +TFL_Status TFL_TensorCopyFromBuffer(TFL_Tensor* tensor, const void* input_data, + int32_t input_data_size) { + if (tensor->bytes != static_cast(input_data_size)) { + return kTfLiteError; + } + memcpy(tensor->data.raw, input_data, input_data_size); + return kTfLiteOk; +} + +TFL_Status TFL_TensorCopyToBuffer(const TFL_Tensor* tensor, void* output_data, + int32_t output_data_size) { + if (tensor->bytes != static_cast(output_data_size)) { + return kTfLiteError; + } + memcpy(output_data, tensor->data.raw, output_data_size); + return kTfLiteOk; +} + +// LINT.ThenChange(//tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs) + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/tensorflow/contrib/lite/experimental/c/c_api.h b/tensorflow/contrib/lite/experimental/c/c_api.h new file mode 100644 index 0000000000000000000000000000000000000000..070f1add13c9904e1a2b3736001ada0e274fdc55 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/c/c_api.h @@ -0,0 +1,149 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_H_ + +#include + +// Eventually the various C APIs defined in context.h will be migrated into +// the appropriate /c/c_api*.h header. For now, we pull in existing definitions +// for convenience. +#include "tensorflow/contrib/lite/context.h" + +// -------------------------------------------------------------------------- +// Experimental C API for TensorFlowLite. +// +// The API leans towards simplicity and uniformity instead of convenience, as +// most usage will be by language-specific wrappers. +// +// Conventions: +// * We use the prefix TFL_ for everything in the API. + +#ifdef SWIG +#define TFL_CAPI_EXPORT +#else +#if defined(_WIN32) +#ifdef TF_COMPILE_LIBRARY +#define TFL_CAPI_EXPORT __declspec(dllexport) +#else +#define TFL_CAPI_EXPORT __declspec(dllimport) +#endif // TF_COMPILE_LIBRARY +#else +#define TFL_CAPI_EXPORT __attribute__((visibility("default"))) +#endif // _WIN32 +#endif // SWIG + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef TfLiteTensor TFL_Tensor; +typedef TfLiteStatus TFL_Status; +typedef TfLiteType TFL_Type; + +// -------------------------------------------------------------------------- +// TFL_Interpreter provides inference from a provided model. +typedef struct _TFL_Interpreter TFL_Interpreter; + +// Returns an interpreter for the provided model, or null on failure. +// +// NOTE: The client *must* explicitly allocate tensors before attempting to +// access input tensor data or invoke the interpreter. +TFL_CAPI_EXPORT extern TFL_Interpreter* TFL_NewInterpreter( + const void* model_data, int32_t model_size); + +// Destroys the interpreter. +TFL_CAPI_EXPORT extern void TFL_DeleteInterpreter(TFL_Interpreter* interpreter); + +// Returns the number of input tensors associated with the model. +TFL_CAPI_EXPORT extern int TFL_InterpreterGetInputTensorCount( + const TFL_Interpreter* interpreter); + +// Returns the tensor associated with the input index. +// REQUIRES: 0 <= input_index < TFL_InterpreterGetInputTensorCount(tensor) +TFL_CAPI_EXPORT extern TFL_Tensor* TFL_InterpreterGetInputTensor( + const TFL_Interpreter* interpreter, int32_t input_index); + +// Attempts to resize the specified input tensor. +// NOTE: After a resize, the client *must* explicitly allocate tensors before +// attempting to access the resized tensor data or invoke the interpreter. +// REQUIRES: 0 <= input_index < TFL_InterpreterGetInputTensorCount(tensor) +TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterResizeInputTensor( + TFL_Interpreter* interpreter, int32_t input_index, const int* input_dims, + int32_t input_dims_size); + +// Updates allocations for all tensors, resizing dependent tensors using the +// specified input tensor dimensionality. +// +// This is a relatively expensive operation, and need only be called after +// creating the graph and/or resizing any inputs. +TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterAllocateTensors( + TFL_Interpreter* interpreter); + +// Runs inference for the loaded graph. +// +// NOTE: It is possible that the interpreter is not in a ready state to +// evaluate (e.g., if a ResizeInputTensor() has been performed without a call to +// AllocateTensors()). +TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterInvoke( + TFL_Interpreter* interpreter); + +// Returns the number of output tensors associated with the model. +TFL_CAPI_EXPORT extern int32_t TFL_InterpreterGetOutputTensorCount( + const TFL_Interpreter* interpreter); + +// Returns the tensor associated with the output index. +// REQUIRES: 0 <= input_index < TFL_InterpreterGetOutputTensorCount(tensor) +TFL_CAPI_EXPORT extern const TFL_Tensor* TFL_InterpreterGetOutputTensor( + const TFL_Interpreter* interpreter, int32_t output_index); + +// -------------------------------------------------------------------------- +// TFL_Tensor wraps data associated with a graph tensor. +// +// Note that, while the TFL_Tensor struct is not currently opaque, and its +// fields can be accessed directly, these methods are still convenient for +// language bindings. In the future the tensor struct will likely be made opaque +// in the public API. + +// Returns the type of a tensor element. +TFL_CAPI_EXPORT extern TFL_Type TFL_TensorType(const TFL_Tensor* tensor); + +// Returns the number of dimensions that the tensor has. +TFL_CAPI_EXPORT extern int32_t TFL_TensorNumDims(const TFL_Tensor* tensor); + +// Returns the length of the tensor in the "dim_index" dimension. +// REQUIRES: 0 <= dim_index < TFLiteTensorNumDims(tensor) +TFL_CAPI_EXPORT extern int32_t TFL_TensorDim(const TFL_Tensor* tensor, + int32_t dim_index); + +// Returns the size of the underlying data in bytes. +TFL_CAPI_EXPORT extern size_t TFL_TensorByteSize(const TFL_Tensor* tensor); + +// Copies from the provided input buffer into the tensor's buffer. +// REQUIRES: input_data_size == TFL_TensorByteSize(tensor) +TFL_CAPI_EXPORT extern TFL_Status TFL_TensorCopyFromBuffer( + TFL_Tensor* tensor, const void* input_data, int32_t input_data_size); + +// Copies to the provided output buffer from the tensor's buffer. +// REQUIRES: output_data_size == TFL_TensorByteSize(tensor) +TFL_CAPI_EXPORT extern TFL_Status TFL_TensorCopyToBuffer( + const TFL_Tensor* output_tensor, void* output_data, + int32_t output_data_size); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_H_ diff --git a/tensorflow/contrib/lite/experimental/c/c_api_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bc925e00a6096c5e8abcc0fa68b335c4db4401c3 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/c/c_api_test.cc @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/contrib/lite/experimental/c/c_api.h" + +#include +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/testing/util.h" + +namespace { + +TEST(CApiSimple, Smoke) { + tflite::FileCopyAllocation model_file( + "tensorflow/contrib/lite/testdata/add.bin", + tflite::DefaultErrorReporter()); + + TFL_Interpreter* interpreter = + TFL_NewInterpreter(model_file.base(), model_file.bytes()); + ASSERT_NE(interpreter, nullptr); + ASSERT_EQ(TFL_InterpreterAllocateTensors(interpreter), kTfLiteOk); + + ASSERT_EQ(TFL_InterpreterGetInputTensorCount(interpreter), 1); + ASSERT_EQ(TFL_InterpreterGetOutputTensorCount(interpreter), 1); + + std::array input_dims = {2}; + ASSERT_EQ(TFL_InterpreterResizeInputTensor(interpreter, 0, input_dims.data(), + input_dims.size()), + kTfLiteOk); + ASSERT_EQ(TFL_InterpreterAllocateTensors(interpreter), kTfLiteOk); + + TFL_Tensor* input_tensor = TFL_InterpreterGetInputTensor(interpreter, 0); + ASSERT_NE(input_tensor, nullptr); + EXPECT_EQ(TFL_TensorType(input_tensor), kTfLiteFloat32); + EXPECT_EQ(TFL_TensorNumDims(input_tensor), 1); + EXPECT_EQ(TFL_TensorDim(input_tensor, 0), 2); + EXPECT_EQ(TFL_TensorByteSize(input_tensor), sizeof(float) * 2); + + std::array input = {1.f, 3.f}; + ASSERT_EQ(TFL_TensorCopyFromBuffer(input_tensor, input.data(), + input.size() * sizeof(float)), + kTfLiteOk); + + ASSERT_EQ(TFL_InterpreterInvoke(interpreter), kTfLiteOk); + + const TFL_Tensor* output_tensor = + TFL_InterpreterGetOutputTensor(interpreter, 0); + ASSERT_NE(output_tensor, nullptr); + EXPECT_EQ(TFL_TensorType(output_tensor), kTfLiteFloat32); + EXPECT_EQ(TFL_TensorNumDims(output_tensor), 1); + EXPECT_EQ(TFL_TensorDim(output_tensor, 0), 2); + EXPECT_EQ(TFL_TensorByteSize(output_tensor), sizeof(float) * 2); + + std::array output; + ASSERT_EQ(TFL_TensorCopyToBuffer(output_tensor, output.data(), + output.size() * sizeof(float)), + kTfLiteOk); + EXPECT_EQ(output[0], 3.f); + EXPECT_EQ(output[1], 9.f); + + TFL_DeleteInterpreter(interpreter); +} + +} // namespace + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/experimental/c/exported_symbols.lds b/tensorflow/contrib/lite/experimental/c/exported_symbols.lds new file mode 100644 index 0000000000000000000000000000000000000000..a3ddc6bc8d370b1715fb1ebf2a66122296330249 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/c/exported_symbols.lds @@ -0,0 +1 @@ +_TFL_* diff --git a/tensorflow/contrib/lite/experimental/c/version_script.lds b/tensorflow/contrib/lite/experimental/c/version_script.lds new file mode 100644 index 0000000000000000000000000000000000000000..c0c8a2bca19afed186e6f8c72a58989a79c7b251 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/c/version_script.lds @@ -0,0 +1,9 @@ +VERS_1.0 { + # Export symbols in c_api.h. + global: + *TFL_*; + + # Hide everything else. + local: + *; +}; diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/.gitignore b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c72a5cae9ebfb15f60961fe25e622663cad89a41 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/.gitignore @@ -0,0 +1,13 @@ +# Unity generated +Builds/ +Temp/ +Library/ +obj/ +# Visual Studio / MonoDevelop generated +*.csproj +*.unityproj +*.sln +*.suo +*.userprefs +# OS generated +.DS_Store diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite.meta new file mode 100644 index 0000000000000000000000000000000000000000..ed9337b53e880b62f70953f197613dcb1409d208 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 71d1b4219b1da4aeaa1cebbec324fc81 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples.meta new file mode 100644 index 0000000000000000000000000000000000000000..edcce00939a298683b15ea45a5ec92709c6abc4f --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: d948aead14abd4c88947c9886d16f774 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite.meta new file mode 100644 index 0000000000000000000000000000000000000000..36b35516f0cee064c8d8e4814a2ae515e28590ce --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: b810b85b794fa48fd93100acf5525e1f +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes.meta new file mode 100644 index 0000000000000000000000000000000000000000..d4133da49a88d38a57d074d28b903f9f18102413 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 154f4201e2e454d4696fa5834eaa3ad3 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity new file mode 100644 index 0000000000000000000000000000000000000000..bcf24b89e335781877a7046001ac4deb6fc55041 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity @@ -0,0 +1,477 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!29 &1 +OcclusionCullingSettings: + m_ObjectHideFlags: 0 + serializedVersion: 2 + m_OcclusionBakeSettings: + smallestOccluder: 5 + smallestHole: 0.25 + backfaceThreshold: 100 + m_SceneGUID: 00000000000000000000000000000000 + m_OcclusionCullingData: {fileID: 0} +--- !u!104 &2 +RenderSettings: + m_ObjectHideFlags: 0 + serializedVersion: 8 + m_Fog: 0 + m_FogColor: {r: 0.5, g: 0.5, b: 0.5, a: 1} + m_FogMode: 3 + m_FogDensity: 0.01 + m_LinearFogStart: 0 + m_LinearFogEnd: 300 + m_AmbientSkyColor: {r: 0.212, g: 0.227, b: 0.259, a: 1} + m_AmbientEquatorColor: {r: 0.114, g: 0.125, b: 0.133, a: 1} + m_AmbientGroundColor: {r: 0.047, g: 0.043, b: 0.035, a: 1} + m_AmbientIntensity: 1 + m_AmbientMode: 3 + m_SubtractiveShadowColor: {r: 0.42, g: 0.478, b: 0.627, a: 1} + m_SkyboxMaterial: {fileID: 0} + m_HaloStrength: 0.5 + m_FlareStrength: 1 + m_FlareFadeSpeed: 3 + m_HaloTexture: {fileID: 0} + m_SpotCookie: {fileID: 10001, guid: 0000000000000000e000000000000000, type: 0} + m_DefaultReflectionMode: 0 + m_DefaultReflectionResolution: 128 + m_ReflectionBounces: 1 + m_ReflectionIntensity: 1 + m_CustomReflection: {fileID: 0} + m_Sun: {fileID: 0} + m_IndirectSpecularColor: {r: 0, g: 0, b: 0, a: 1} +--- !u!157 &3 +LightmapSettings: + m_ObjectHideFlags: 0 + serializedVersion: 11 + m_GIWorkflowMode: 1 + m_GISettings: + serializedVersion: 2 + m_BounceScale: 1 + m_IndirectOutputScale: 1 + m_AlbedoBoost: 1 + m_TemporalCoherenceThreshold: 1 + m_EnvironmentLightingMode: 0 + m_EnableBakedLightmaps: 0 + m_EnableRealtimeLightmaps: 0 + m_LightmapEditorSettings: + serializedVersion: 9 + m_Resolution: 2 + m_BakeResolution: 40 + m_TextureWidth: 1024 + m_TextureHeight: 1024 + m_AO: 0 + m_AOMaxDistance: 1 + m_CompAOExponent: 1 + m_CompAOExponentDirect: 0 + m_Padding: 2 + m_LightmapParameters: {fileID: 0} + m_LightmapsBakeMode: 1 + m_TextureCompression: 1 + m_FinalGather: 0 + m_FinalGatherFiltering: 1 + m_FinalGatherRayCount: 256 + m_ReflectionCompression: 2 + m_MixedBakeMode: 2 + m_BakeBackend: 0 + m_PVRSampling: 1 + m_PVRDirectSampleCount: 32 + m_PVRSampleCount: 500 + m_PVRBounces: 2 + m_PVRFilterTypeDirect: 0 + m_PVRFilterTypeIndirect: 0 + m_PVRFilterTypeAO: 0 + m_PVRFilteringMode: 1 + m_PVRCulling: 1 + m_PVRFilteringGaussRadiusDirect: 1 + m_PVRFilteringGaussRadiusIndirect: 5 + m_PVRFilteringGaussRadiusAO: 2 + m_PVRFilteringAtrousPositionSigmaDirect: 0.5 + m_PVRFilteringAtrousPositionSigmaIndirect: 2 + m_PVRFilteringAtrousPositionSigmaAO: 1 + m_ShowResolutionOverlay: 1 + m_LightingDataAsset: {fileID: 0} + m_UseShadowmask: 1 +--- !u!196 &4 +NavMeshSettings: + serializedVersion: 2 + m_ObjectHideFlags: 0 + m_BuildSettings: + serializedVersion: 2 + agentTypeID: 0 + agentRadius: 0.5 + agentHeight: 2 + agentSlope: 45 + agentClimb: 0.4 + ledgeDropHeight: 0 + maxJumpAcrossDistance: 0 + minRegionArea: 2 + manualCellSize: 0 + cellSize: 0.16666667 + manualTileSize: 0 + tileSize: 256 + accuratePlacement: 0 + debug: + m_Flags: 0 + m_NavMeshData: {fileID: 0} +--- !u!1 &492081941 +GameObject: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + serializedVersion: 5 + m_Component: + - component: {fileID: 492081945} + - component: {fileID: 492081944} + - component: {fileID: 492081943} + - component: {fileID: 492081942} + m_Layer: 0 + m_Name: Main Camera + m_TagString: MainCamera + m_Icon: {fileID: 0} + m_NavMeshLayer: 0 + m_StaticEditorFlags: 0 + m_IsActive: 1 +--- !u!81 &492081942 +AudioListener: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 492081941} + m_Enabled: 1 +--- !u!124 &492081943 +Behaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 492081941} + m_Enabled: 1 +--- !u!20 &492081944 +Camera: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 492081941} + m_Enabled: 1 + serializedVersion: 2 + m_ClearFlags: 1 + m_BackGroundColor: {r: 0.21933319, g: 0.21933319, b: 0.21933319, a: 0} + m_NormalizedViewPortRect: + serializedVersion: 2 + x: 0 + y: 0 + width: 1 + height: 1 + near clip plane: 0.3 + far clip plane: 1000 + field of view: 60 + orthographic: 1 + orthographic size: 5 + m_Depth: -1 + m_CullingMask: + serializedVersion: 2 + m_Bits: 4294967295 + m_RenderingPath: -1 + m_TargetTexture: {fileID: 0} + m_TargetDisplay: 0 + m_TargetEye: 3 + m_HDR: 1 + m_AllowMSAA: 1 + m_AllowDynamicResolution: 0 + m_ForceIntoRT: 0 + m_OcclusionCulling: 1 + m_StereoConvergence: 10 + m_StereoSeparation: 0.022 +--- !u!4 &492081945 +Transform: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 492081941} + m_LocalRotation: {x: 0, y: 0, z: 0, w: 1} + m_LocalPosition: {x: 0, y: 0, z: -10} + m_LocalScale: {x: 1, y: 1, z: 1} + m_Children: + - {fileID: 904015944} + m_Father: {fileID: 0} + m_RootOrder: 0 + m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} +--- !u!1 &871349752 +GameObject: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + serializedVersion: 5 + m_Component: + - component: {fileID: 871349756} + - component: {fileID: 871349755} + - component: {fileID: 871349754} + - component: {fileID: 871349753} + m_Layer: 5 + m_Name: Canvas + m_TagString: Untagged + m_Icon: {fileID: 0} + m_NavMeshLayer: 0 + m_StaticEditorFlags: 0 + m_IsActive: 1 +--- !u!114 &871349753 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 871349752} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 1301386320, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3} + m_Name: + m_EditorClassIdentifier: + m_IgnoreReversedGraphics: 1 + m_BlockingObjects: 0 + m_BlockingMask: + serializedVersion: 2 + m_Bits: 4294967295 +--- !u!114 &871349754 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 871349752} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 1980459831, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3} + m_Name: + m_EditorClassIdentifier: + m_UiScaleMode: 0 + m_ReferencePixelsPerUnit: 100 + m_ScaleFactor: 1 + m_ReferenceResolution: {x: 800, y: 600} + m_ScreenMatchMode: 0 + m_MatchWidthOrHeight: 0 + m_PhysicalUnit: 3 + m_FallbackScreenDPI: 96 + m_DefaultSpriteDPI: 96 + m_DynamicPixelsPerUnit: 1 +--- !u!223 &871349755 +Canvas: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 871349752} + m_Enabled: 1 + serializedVersion: 3 + m_RenderMode: 0 + m_Camera: {fileID: 0} + m_PlaneDistance: 100 + m_PixelPerfect: 0 + m_ReceivesEvents: 1 + m_OverrideSorting: 0 + m_OverridePixelPerfect: 0 + m_SortingBucketNormalizedSize: 0 + m_AdditionalShaderChannelsFlag: 0 + m_SortingLayerID: 0 + m_SortingOrder: 0 + m_TargetDisplay: 0 +--- !u!224 &871349756 +RectTransform: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 871349752} + m_LocalRotation: {x: 0, y: 0, z: 0, w: 1} + m_LocalPosition: {x: 0, y: 0, z: 0} + m_LocalScale: {x: 0, y: 0, z: 0} + m_Children: + - {fileID: 1726294324} + m_Father: {fileID: 0} + m_RootOrder: 1 + m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} + m_AnchorMin: {x: 0, y: 0} + m_AnchorMax: {x: 0, y: 0} + m_AnchoredPosition: {x: 0, y: 0} + m_SizeDelta: {x: 0, y: 0} + m_Pivot: {x: 0, y: 0} +--- !u!1 &904015943 +GameObject: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + serializedVersion: 5 + m_Component: + - component: {fileID: 904015944} + - component: {fileID: 904015945} + m_Layer: 0 + m_Name: HelloTFLite + m_TagString: Untagged + m_Icon: {fileID: 0} + m_NavMeshLayer: 0 + m_StaticEditorFlags: 0 + m_IsActive: 1 +--- !u!4 &904015944 +Transform: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 904015943} + m_LocalRotation: {x: 0, y: 0, z: 0, w: 1} + m_LocalPosition: {x: 0, y: 0, z: 0} + m_LocalScale: {x: 1, y: 1, z: 1} + m_Children: [] + m_Father: {fileID: 492081945} + m_RootOrder: 0 + m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} +--- !u!114 &904015945 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 904015943} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 11500000, guid: 899510441e0ca4be0879d3055e467878, type: 3} + m_Name: + m_EditorClassIdentifier: + model: {fileID: 4900000, guid: adff4e1dbdba344c199ee4fe7e84457e, type: 3} + inputs: + - 1 + - 3 + - 7 + inferenceText: {fileID: 1726294325} +--- !u!1 &1726294323 +GameObject: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + serializedVersion: 5 + m_Component: + - component: {fileID: 1726294324} + - component: {fileID: 1726294326} + - component: {fileID: 1726294325} + m_Layer: 5 + m_Name: InferenceText + m_TagString: Untagged + m_Icon: {fileID: 0} + m_NavMeshLayer: 0 + m_StaticEditorFlags: 0 + m_IsActive: 1 +--- !u!224 &1726294324 +RectTransform: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 1726294323} + m_LocalRotation: {x: -0, y: -0, z: -0, w: 1} + m_LocalPosition: {x: 0, y: 0, z: 0} + m_LocalScale: {x: 1, y: 1, z: 1} + m_Children: [] + m_Father: {fileID: 871349756} + m_RootOrder: 0 + m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} + m_AnchorMin: {x: 0.5, y: 0.5} + m_AnchorMax: {x: 0.5, y: 0.5} + m_AnchoredPosition: {x: 0, y: 25} + m_SizeDelta: {x: 450, y: 250} + m_Pivot: {x: 0.5, y: 0.5} +--- !u!114 &1726294325 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 1726294323} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 708705254, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3} + m_Name: + m_EditorClassIdentifier: + m_Material: {fileID: 0} + m_Color: {r: 0.9338235, g: 0.9338235, b: 0.9338235, a: 1} + m_RaycastTarget: 1 + m_OnCullStateChanged: + m_PersistentCalls: + m_Calls: [] + m_TypeName: UnityEngine.UI.MaskableGraphic+CullStateChangedEvent, UnityEngine.UI, + Version=1.0.0.0, Culture=neutral, PublicKeyToken=null + m_FontData: + m_Font: {fileID: 10102, guid: 0000000000000000e000000000000000, type: 0} + m_FontSize: 35 + m_FontStyle: 0 + m_BestFit: 0 + m_MinSize: 2 + m_MaxSize: 40 + m_Alignment: 4 + m_AlignByGeometry: 0 + m_RichText: 1 + m_HorizontalOverflow: 0 + m_VerticalOverflow: 0 + m_LineSpacing: 1 + m_Text: 'Inference took 0.0153 ms + + Input: 1,3,7 + + Output: 3,9,21' +--- !u!222 &1726294326 +CanvasRenderer: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 1726294323} +--- !u!1 &2026426602 +GameObject: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + serializedVersion: 5 + m_Component: + - component: {fileID: 2026426605} + - component: {fileID: 2026426604} + - component: {fileID: 2026426603} + m_Layer: 0 + m_Name: EventSystem + m_TagString: Untagged + m_Icon: {fileID: 0} + m_NavMeshLayer: 0 + m_StaticEditorFlags: 0 + m_IsActive: 1 +--- !u!114 &2026426603 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 2026426602} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 1077351063, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3} + m_Name: + m_EditorClassIdentifier: + m_HorizontalAxis: Horizontal + m_VerticalAxis: Vertical + m_SubmitButton: Submit + m_CancelButton: Cancel + m_InputActionsPerSecond: 10 + m_RepeatDelay: 0.5 + m_ForceModuleActive: 0 +--- !u!114 &2026426604 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 2026426602} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: -619905303, guid: f5f67c52d1564df4a8936ccd202a3bd8, type: 3} + m_Name: + m_EditorClassIdentifier: + m_FirstSelected: {fileID: 0} + m_sendNavigationEvents: 1 + m_DragThreshold: 5 +--- !u!4 &2026426605 +Transform: + m_ObjectHideFlags: 0 + m_PrefabParentObject: {fileID: 0} + m_PrefabInternal: {fileID: 0} + m_GameObject: {fileID: 2026426602} + m_LocalRotation: {x: 0, y: 0, z: 0, w: 1} + m_LocalPosition: {x: 0, y: 0, z: 0} + m_LocalScale: {x: 1, y: 1, z: 1} + m_Children: [] + m_Father: {fileID: 0} + m_RootOrder: 2 + m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity.meta new file mode 100644 index 0000000000000000000000000000000000000000..e1e13efb66027b555f1d45c76fe58fe2103774a2 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/HelloTFLite.unity.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: f8a8c37a396584bb7b21687f33d6d3f8 +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes new file mode 100644 index 0000000000000000000000000000000000000000..aef0fe3d82c9d92dc444076d3b46e05af1923f46 Binary files /dev/null and b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes differ diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes.meta new file mode 100644 index 0000000000000000000000000000000000000000..ba24871413e06154afd0c0d5e2db83b7619d34a9 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scenes/add.bytes.meta @@ -0,0 +1,7 @@ +fileFormatVersion: 2 +guid: adff4e1dbdba344c199ee4fe7e84457e +TextScriptImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts.meta new file mode 100644 index 0000000000000000000000000000000000000000..28fde68b8b1346e88375dc7a8613270f0e2f2762 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: f7d1e2dec09b64acdb7b8f5aef9fcb44 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs new file mode 100644 index 0000000000000000000000000000000000000000..83291e61794819e7c57f69ed2be6ea40294e01da --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs @@ -0,0 +1,85 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using TensorFlowLite; +using UnityEngine; +using UnityEngine.UI; + +/// +/// Simple example demonstrating use of the experimental C# bindings for TensorFlowLite. +/// +public class HelloTFLite : MonoBehaviour { + + [Tooltip("Configurable TFLite model.")] + public TextAsset model; + + [Tooltip("Configurable TFLite input tensor data.")] + public float[] inputs; + + [Tooltip("Target Text widget for display of inference execution.")] + public Text inferenceText; + + private Interpreter interpreter; + private float[] outputs; + + void Awake() { + // As the demo is extremely simple, there's no need to run at full frame-rate. + QualitySettings.vSyncCount = 0; + Application.targetFrameRate = 5; + } + + void Start () { + interpreter = new Interpreter(model.bytes); + Debug.LogFormat( + "InputCount: {0}, OutputCount: {1}", + interpreter.GetInputTensorCount(), + interpreter.GetOutputTensorCount()); + } + + void Update () { + if (inputs == null) { + return; + } + + if (outputs == null || outputs.Length != inputs.Length) { + interpreter.ResizeInputTensor(0, new int[]{inputs.Length}); + interpreter.AllocateTensors(); + outputs = new float[inputs.Length]; + } + + float startTimeSeconds = Time.realtimeSinceStartup; + interpreter.SetInputTensorData(0, inputs); + interpreter.Invoke(); + interpreter.GetOutputTensorData(0, outputs); + float inferenceTimeSeconds = Time.realtimeSinceStartup - startTimeSeconds; + + inferenceText.text = string.Format( + "Inference took {0:0.0000} ms\nInput(s): {1}\nOutput(s): {2}", + inferenceTimeSeconds * 1000.0, + ArrayToString(inputs), + ArrayToString(outputs)); + } + + void OnDestroy() { + interpreter.Dispose(); + } + + private static string ArrayToString(float[] values) { + return string.Join(",", values.Select(x => x.ToString()).ToArray()); + } +} diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs.meta new file mode 100644 index 0000000000000000000000000000000000000000..ba83f45084bb624e5e7777684b0fda98b4d46688 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/Examples/HelloTFLite/Scripts/HelloTFLite.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 899510441e0ca4be0879d3055e467878 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK.meta new file mode 100644 index 0000000000000000000000000000000000000000..bf5ce15c6a6932398d798d193b54f4ecfd8ba2d8 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 16dad1655bcdc48f7b325a2a634b9c69 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts.meta new file mode 100644 index 0000000000000000000000000000000000000000..22ed2c466bde1668595967f7a07f34a9193aaec8 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: d70863368f8904d509a9b73d3a555914 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs new file mode 100644 index 0000000000000000000000000000000000000000..ab966bae2efb9431e2f9f35dc818d130aabd71f6 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs @@ -0,0 +1,145 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +using System; +using System.Runtime.InteropServices; + +using TFL_Interpreter = System.IntPtr; +using TFL_Tensor = System.IntPtr; + +namespace TensorFlowLite +{ + /// + /// Simple C# bindings for the experimental TensorFlowLite C API. + /// + public class Interpreter : IDisposable + { + private const string TensorFlowLibrary = "tensorflowlite_c"; + + private TFL_Interpreter handle; + + public Interpreter(byte[] modelData) { + GCHandle modelDataHandle = GCHandle.Alloc(modelData, GCHandleType.Pinned); + IntPtr modelDataPtr = modelDataHandle.AddrOfPinnedObject(); + handle = TFL_NewInterpreter(modelDataPtr, modelData.Length); + if (handle == IntPtr.Zero) throw new Exception("Failed to create TensorFlowLite Interpreter"); + } + + ~Interpreter() { + Dispose(); + } + + public void Dispose() { + if (handle != IntPtr.Zero) TFL_DeleteInterpreter(handle); + handle = IntPtr.Zero; + } + + public void Invoke() { + ThrowIfError(TFL_InterpreterInvoke(handle)); + } + + public int GetInputTensorCount() { + return TFL_InterpreterGetInputTensorCount(handle); + } + + public void SetInputTensorData(int inputTensorIndex, Array inputTensorData) { + GCHandle tensorDataHandle = GCHandle.Alloc(inputTensorData, GCHandleType.Pinned); + IntPtr tensorDataPtr = tensorDataHandle.AddrOfPinnedObject(); + TFL_Tensor tensor = TFL_InterpreterGetInputTensor(handle, inputTensorIndex); + ThrowIfError(TFL_TensorCopyFromBuffer( + tensor, tensorDataPtr, Buffer.ByteLength(inputTensorData))); + } + + public void ResizeInputTensor(int inputTensorIndex, int[] inputTensorShape) { + ThrowIfError(TFL_InterpreterResizeInputTensor( + handle, inputTensorIndex, inputTensorShape, inputTensorShape.Length)); + } + + public void AllocateTensors() { + ThrowIfError(TFL_InterpreterAllocateTensors(handle)); + } + + public int GetOutputTensorCount() { + return TFL_InterpreterGetOutputTensorCount(handle); + } + + public void GetOutputTensorData(int outputTensorIndex, Array outputTensorData) { + GCHandle tensorDataHandle = GCHandle.Alloc(outputTensorData, GCHandleType.Pinned); + IntPtr tensorDataPtr = tensorDataHandle.AddrOfPinnedObject(); + TFL_Tensor tensor = TFL_InterpreterGetOutputTensor(handle, outputTensorIndex); + ThrowIfError(TFL_TensorCopyToBuffer( + tensor, tensorDataPtr, Buffer.ByteLength(outputTensorData))); + } + + private static void ThrowIfError(int resultCode) { + if (resultCode != 0) throw new Exception("TensorFlowLite operation failed."); + } + + #region Externs + + [DllImport (TensorFlowLibrary)] + private static extern unsafe TFL_Interpreter TFL_NewInterpreter( + IntPtr model_data, + int model_size); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe void TFL_DeleteInterpreter(TFL_Interpreter interpreter); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe int TFL_InterpreterGetInputTensorCount( + TFL_Interpreter interpreter); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe TFL_Tensor TFL_InterpreterGetInputTensor( + TFL_Interpreter interpreter, + int input_index); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe int TFL_InterpreterResizeInputTensor( + TFL_Interpreter interpreter, + int input_index, + int[] input_dims, + int input_dims_size); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe int TFL_InterpreterAllocateTensors( + TFL_Interpreter interpreter); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe int TFL_InterpreterInvoke(TFL_Interpreter interpreter); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe int TFL_InterpreterGetOutputTensorCount( + TFL_Interpreter interpreter); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe TFL_Tensor TFL_InterpreterGetOutputTensor( + TFL_Interpreter interpreter, + int output_index); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe int TFL_TensorCopyFromBuffer( + TFL_Tensor tensor, + IntPtr input_data, + int input_data_size); + + [DllImport (TensorFlowLibrary)] + private static extern unsafe int TFL_TensorCopyToBuffer( + TFL_Tensor tensor, + IntPtr output_data, + int output_data_size); + + #endregion + } +} diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs.meta b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs.meta new file mode 100644 index 0000000000000000000000000000000000000000..5ec84ef7f70e9be45ff6292ed7a412fac35010de --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 0bbaf59e6ac914ed1b28174fb9008a09 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/AudioManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/AudioManager.asset new file mode 100644 index 0000000000000000000000000000000000000000..da6112576a5ca4290108f6d4c731bd4c391e91d4 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/AudioManager.asset @@ -0,0 +1,17 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!11 &1 +AudioManager: + m_ObjectHideFlags: 0 + m_Volume: 1 + Rolloff Scale: 1 + Doppler Factor: 1 + Default Speaker Mode: 2 + m_SampleRate: 0 + m_DSPBufferSize: 0 + m_VirtualVoiceCount: 512 + m_RealVoiceCount: 32 + m_SpatializerPlugin: + m_AmbisonicDecoderPlugin: + m_DisableAudio: 0 + m_VirtualizeEffects: 1 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ClusterInputManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ClusterInputManager.asset new file mode 100644 index 0000000000000000000000000000000000000000..e7886b266a005f4d9d80f2fef8d1649dcfd3ed2b --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ClusterInputManager.asset @@ -0,0 +1,6 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!236 &1 +ClusterInputManager: + m_ObjectHideFlags: 0 + m_Inputs: [] diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/DynamicsManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/DynamicsManager.asset new file mode 100644 index 0000000000000000000000000000000000000000..78992f08c7ab7a4353c8a7d07cf1548174aaacbf --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/DynamicsManager.asset @@ -0,0 +1,29 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!55 &1 +PhysicsManager: + m_ObjectHideFlags: 0 + serializedVersion: 7 + m_Gravity: {x: 0, y: -9.81, z: 0} + m_DefaultMaterial: {fileID: 0} + m_BounceThreshold: 2 + m_SleepThreshold: 0.005 + m_DefaultContactOffset: 0.01 + m_DefaultSolverIterations: 6 + m_DefaultSolverVelocityIterations: 1 + m_QueriesHitBackfaces: 0 + m_QueriesHitTriggers: 1 + m_EnableAdaptiveForce: 0 + m_ClothInterCollisionDistance: 0 + m_ClothInterCollisionStiffness: 0 + m_ContactsGeneration: 1 + m_LayerCollisionMatrix: ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff + m_AutoSimulation: 1 + m_AutoSyncTransforms: 1 + m_ClothInterCollisionSettingsToggle: 0 + m_ContactPairsMode: 0 + m_BroadphaseType: 0 + m_WorldBounds: + m_Center: {x: 0, y: 0, z: 0} + m_Extent: {x: 250, y: 250, z: 250} + m_WorldSubdivisions: 8 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorBuildSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorBuildSettings.asset new file mode 100644 index 0000000000000000000000000000000000000000..6dc24f7dfdb697ad6f5d0a4ec5599bcd3cbd2f43 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorBuildSettings.asset @@ -0,0 +1,7 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!1045 &1 +EditorBuildSettings: + m_ObjectHideFlags: 0 + serializedVersion: 2 + m_Scenes: [] diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorSettings.asset new file mode 100644 index 0000000000000000000000000000000000000000..fcd016402f97e4c009a16640517a6930ed615ef9 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/EditorSettings.asset @@ -0,0 +1,21 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!159 &1 +EditorSettings: + m_ObjectHideFlags: 0 + serializedVersion: 7 + m_ExternalVersionControlSupport: Visible Meta Files + m_SerializationMode: 2 + m_LineEndingsForNewScripts: 1 + m_DefaultBehaviorMode: 1 + m_SpritePackerMode: 4 + m_SpritePackerPaddingPower: 1 + m_EtcTextureCompressorBehavior: 1 + m_EtcTextureFastCompressor: 1 + m_EtcTextureNormalCompressor: 2 + m_EtcTextureBestCompressor: 4 + m_ProjectGenerationIncludedExtensions: txt;xml;fnt;cd;asmdef;rsp + m_ProjectGenerationRootNamespace: + m_UserGeneratedProjectSuffix: + m_CollabEditorSettings: + inProgressEnabled: 1 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset new file mode 100644 index 0000000000000000000000000000000000000000..a9bbfb02d1e7065b7d0e90609a3928d667933477 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/GraphicsSettings.asset @@ -0,0 +1,64 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!30 &1 +GraphicsSettings: + m_ObjectHideFlags: 0 + serializedVersion: 12 + m_Deferred: + m_Mode: 1 + m_Shader: {fileID: 69, guid: 0000000000000000f000000000000000, type: 0} + m_DeferredReflections: + m_Mode: 1 + m_Shader: {fileID: 74, guid: 0000000000000000f000000000000000, type: 0} + m_ScreenSpaceShadows: + m_Mode: 1 + m_Shader: {fileID: 64, guid: 0000000000000000f000000000000000, type: 0} + m_LegacyDeferred: + m_Mode: 1 + m_Shader: {fileID: 63, guid: 0000000000000000f000000000000000, type: 0} + m_DepthNormals: + m_Mode: 1 + m_Shader: {fileID: 62, guid: 0000000000000000f000000000000000, type: 0} + m_MotionVectors: + m_Mode: 1 + m_Shader: {fileID: 75, guid: 0000000000000000f000000000000000, type: 0} + m_LightHalo: + m_Mode: 1 + m_Shader: {fileID: 105, guid: 0000000000000000f000000000000000, type: 0} + m_LensFlare: + m_Mode: 1 + m_Shader: {fileID: 102, guid: 0000000000000000f000000000000000, type: 0} + m_AlwaysIncludedShaders: + - {fileID: 7, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 15104, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 15105, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 15106, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 10753, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 10770, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 17000, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 16000, guid: 0000000000000000f000000000000000, type: 0} + - {fileID: 16002, guid: 0000000000000000f000000000000000, type: 0} + m_PreloadedShaders: [] + m_SpritesDefaultMaterial: {fileID: 10754, guid: 0000000000000000f000000000000000, + type: 0} + m_CustomRenderPipeline: {fileID: 0} + m_TransparencySortMode: 0 + m_TransparencySortAxis: {x: 0, y: 0, z: 1} + m_DefaultRenderingPath: 1 + m_DefaultMobileRenderingPath: 1 + m_TierSettings: [] + m_LightmapStripping: 0 + m_FogStripping: 0 + m_InstancingStripping: 0 + m_LightmapKeepPlain: 1 + m_LightmapKeepDirCombined: 1 + m_LightmapKeepDynamicPlain: 1 + m_LightmapKeepDynamicDirCombined: 1 + m_LightmapKeepShadowMask: 1 + m_LightmapKeepSubtractive: 1 + m_FogKeepLinear: 1 + m_FogKeepExp: 1 + m_FogKeepExp2: 1 + m_AlbedoSwatchInfos: [] + m_LightsUseLinearIntensity: 0 + m_LightsUseColorTemperature: 0 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/InputManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/InputManager.asset new file mode 100644 index 0000000000000000000000000000000000000000..17c8f538e2152c0a0310b4870979eeecece2153c --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/InputManager.asset @@ -0,0 +1,295 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!13 &1 +InputManager: + m_ObjectHideFlags: 0 + serializedVersion: 2 + m_Axes: + - serializedVersion: 3 + m_Name: Horizontal + descriptiveName: + descriptiveNegativeName: + negativeButton: left + positiveButton: right + altNegativeButton: a + altPositiveButton: d + gravity: 3 + dead: 0.001 + sensitivity: 3 + snap: 1 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Vertical + descriptiveName: + descriptiveNegativeName: + negativeButton: down + positiveButton: up + altNegativeButton: s + altPositiveButton: w + gravity: 3 + dead: 0.001 + sensitivity: 3 + snap: 1 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Fire1 + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: left ctrl + altNegativeButton: + altPositiveButton: mouse 0 + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Fire2 + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: left alt + altNegativeButton: + altPositiveButton: mouse 1 + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Fire3 + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: left shift + altNegativeButton: + altPositiveButton: mouse 2 + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Jump + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: space + altNegativeButton: + altPositiveButton: + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Mouse X + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: + altNegativeButton: + altPositiveButton: + gravity: 0 + dead: 0 + sensitivity: 0.1 + snap: 0 + invert: 0 + type: 1 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Mouse Y + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: + altNegativeButton: + altPositiveButton: + gravity: 0 + dead: 0 + sensitivity: 0.1 + snap: 0 + invert: 0 + type: 1 + axis: 1 + joyNum: 0 + - serializedVersion: 3 + m_Name: Mouse ScrollWheel + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: + altNegativeButton: + altPositiveButton: + gravity: 0 + dead: 0 + sensitivity: 0.1 + snap: 0 + invert: 0 + type: 1 + axis: 2 + joyNum: 0 + - serializedVersion: 3 + m_Name: Horizontal + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: + altNegativeButton: + altPositiveButton: + gravity: 0 + dead: 0.19 + sensitivity: 1 + snap: 0 + invert: 0 + type: 2 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Vertical + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: + altNegativeButton: + altPositiveButton: + gravity: 0 + dead: 0.19 + sensitivity: 1 + snap: 0 + invert: 1 + type: 2 + axis: 1 + joyNum: 0 + - serializedVersion: 3 + m_Name: Fire1 + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: joystick button 0 + altNegativeButton: + altPositiveButton: + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Fire2 + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: joystick button 1 + altNegativeButton: + altPositiveButton: + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Fire3 + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: joystick button 2 + altNegativeButton: + altPositiveButton: + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Jump + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: joystick button 3 + altNegativeButton: + altPositiveButton: + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Submit + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: return + altNegativeButton: + altPositiveButton: joystick button 0 + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Submit + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: enter + altNegativeButton: + altPositiveButton: space + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 + - serializedVersion: 3 + m_Name: Cancel + descriptiveName: + descriptiveNegativeName: + negativeButton: + positiveButton: escape + altNegativeButton: + altPositiveButton: joystick button 1 + gravity: 1000 + dead: 0.001 + sensitivity: 1000 + snap: 0 + invert: 0 + type: 0 + axis: 0 + joyNum: 0 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NavMeshAreas.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NavMeshAreas.asset new file mode 100644 index 0000000000000000000000000000000000000000..3b0b7c3d183abdd300112f56965916ef11667f54 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NavMeshAreas.asset @@ -0,0 +1,91 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!126 &1 +NavMeshProjectSettings: + m_ObjectHideFlags: 0 + serializedVersion: 2 + areas: + - name: Walkable + cost: 1 + - name: Not Walkable + cost: 1 + - name: Jump + cost: 2 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + - name: + cost: 1 + m_LastAgentTypeID: -887442657 + m_Settings: + - serializedVersion: 2 + agentTypeID: 0 + agentRadius: 0.5 + agentHeight: 2 + agentSlope: 45 + agentClimb: 0.75 + ledgeDropHeight: 0 + maxJumpAcrossDistance: 0 + minRegionArea: 2 + manualCellSize: 0 + cellSize: 0.16666667 + manualTileSize: 0 + tileSize: 256 + accuratePlacement: 0 + debug: + m_Flags: 0 + m_SettingNames: + - Humanoid diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NetworkManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NetworkManager.asset new file mode 100644 index 0000000000000000000000000000000000000000..5dc6a831d9f2a11f08ed96571e0f602e3c3908b5 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/NetworkManager.asset @@ -0,0 +1,8 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!149 &1 +NetworkManager: + m_ObjectHideFlags: 0 + m_DebugLevel: 0 + m_Sendrate: 15 + m_AssetToPrefab: {} diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/Physics2DSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/Physics2DSettings.asset new file mode 100644 index 0000000000000000000000000000000000000000..132ee6bc868f1aae138555dc139e054b0d1d8620 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/Physics2DSettings.asset @@ -0,0 +1,37 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!19 &1 +Physics2DSettings: + m_ObjectHideFlags: 0 + serializedVersion: 3 + m_Gravity: {x: 0, y: -9.81} + m_DefaultMaterial: {fileID: 0} + m_VelocityIterations: 8 + m_PositionIterations: 3 + m_VelocityThreshold: 1 + m_MaxLinearCorrection: 0.2 + m_MaxAngularCorrection: 8 + m_MaxTranslationSpeed: 100 + m_MaxRotationSpeed: 360 + m_BaumgarteScale: 0.2 + m_BaumgarteTimeOfImpactScale: 0.75 + m_TimeToSleep: 0.5 + m_LinearSleepTolerance: 0.01 + m_AngularSleepTolerance: 2 + m_DefaultContactOffset: 0.01 + m_AutoSimulation: 1 + m_QueriesHitTriggers: 1 + m_QueriesStartInColliders: 1 + m_ChangeStopsCallbacks: 0 + m_CallbacksOnDisable: 1 + m_AutoSyncTransforms: 1 + m_AlwaysShowColliders: 0 + m_ShowColliderSleep: 1 + m_ShowColliderContacts: 0 + m_ShowColliderAABB: 0 + m_ContactArrowScale: 0.2 + m_ColliderAwakeColor: {r: 0.5686275, g: 0.95686275, b: 0.54509807, a: 0.7529412} + m_ColliderAsleepColor: {r: 0.5686275, g: 0.95686275, b: 0.54509807, a: 0.36078432} + m_ColliderContactColor: {r: 1, g: 0, b: 1, a: 0.6862745} + m_ColliderAABBColor: {r: 1, g: 1, b: 0, a: 0.2509804} + m_LayerCollisionMatrix: ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectSettings.asset new file mode 100644 index 0000000000000000000000000000000000000000..3fbfab76c13c84f66a166c5dfe1d4552503350ff --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectSettings.asset @@ -0,0 +1,641 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!129 &1 +PlayerSettings: + m_ObjectHideFlags: 0 + serializedVersion: 14 + productGUID: a084943b991dd4597b140f4ce2b41c65 + AndroidProfiler: 0 + AndroidFilterTouchesWhenObscured: 0 + defaultScreenOrientation: 4 + targetDevice: 2 + useOnDemandResources: 0 + accelerometerFrequency: 60 + companyName: DefaultCompany + productName: TensorFlowLitePlugin + defaultCursor: {fileID: 0} + cursorHotspot: {x: 0, y: 0} + m_SplashScreenBackgroundColor: {r: 0.13725491, g: 0.12156863, b: 0.1254902, a: 1} + m_ShowUnitySplashScreen: 1 + m_ShowUnitySplashLogo: 1 + m_SplashScreenOverlayOpacity: 1 + m_SplashScreenAnimation: 1 + m_SplashScreenLogoStyle: 1 + m_SplashScreenDrawMode: 0 + m_SplashScreenBackgroundAnimationZoom: 1 + m_SplashScreenLogoAnimationZoom: 1 + m_SplashScreenBackgroundLandscapeAspect: 1 + m_SplashScreenBackgroundPortraitAspect: 1 + m_SplashScreenBackgroundLandscapeUvs: + serializedVersion: 2 + x: 0 + y: 0 + width: 1 + height: 1 + m_SplashScreenBackgroundPortraitUvs: + serializedVersion: 2 + x: 0 + y: 0 + width: 1 + height: 1 + m_SplashScreenLogos: [] + m_VirtualRealitySplashScreen: {fileID: 0} + m_HolographicTrackingLossScreen: {fileID: 0} + defaultScreenWidth: 1024 + defaultScreenHeight: 768 + defaultScreenWidthWeb: 960 + defaultScreenHeightWeb: 600 + m_StereoRenderingPath: 0 + m_ActiveColorSpace: 0 + m_MTRendering: 1 + m_StackTraceTypes: 010000000100000001000000010000000100000001000000 + iosShowActivityIndicatorOnLoading: -1 + androidShowActivityIndicatorOnLoading: -1 + tizenShowActivityIndicatorOnLoading: -1 + iosAppInBackgroundBehavior: 0 + displayResolutionDialog: 1 + iosAllowHTTPDownload: 1 + allowedAutorotateToPortrait: 1 + allowedAutorotateToPortraitUpsideDown: 1 + allowedAutorotateToLandscapeRight: 1 + allowedAutorotateToLandscapeLeft: 1 + useOSAutorotation: 1 + use32BitDisplayBuffer: 1 + preserveFramebufferAlpha: 0 + disableDepthAndStencilBuffers: 0 + androidBlitType: 0 + defaultIsFullScreen: 1 + defaultIsNativeResolution: 1 + macRetinaSupport: 1 + runInBackground: 0 + captureSingleScreen: 0 + muteOtherAudioSources: 0 + Prepare IOS For Recording: 0 + Force IOS Speakers When Recording: 0 + deferSystemGesturesMode: 0 + hideHomeButton: 0 + submitAnalytics: 1 + usePlayerLog: 1 + bakeCollisionMeshes: 0 + forceSingleInstance: 0 + resizableWindow: 0 + useMacAppStoreValidation: 0 + macAppStoreCategory: public.app-category.games + gpuSkinning: 0 + graphicsJobs: 0 + xboxPIXTextureCapture: 0 + xboxEnableAvatar: 0 + xboxEnableKinect: 0 + xboxEnableKinectAutoTracking: 0 + xboxEnableFitness: 0 + visibleInBackground: 1 + allowFullscreenSwitch: 1 + graphicsJobMode: 0 + macFullscreenMode: 2 + d3d11FullscreenMode: 1 + xboxSpeechDB: 0 + xboxEnableHeadOrientation: 0 + xboxEnableGuest: 0 + xboxEnablePIXSampling: 0 + metalFramebufferOnly: 0 + n3dsDisableStereoscopicView: 0 + n3dsEnableSharedListOpt: 1 + n3dsEnableVSync: 0 + xboxOneResolution: 0 + xboxOneSResolution: 0 + xboxOneXResolution: 3 + xboxOneMonoLoggingLevel: 0 + xboxOneLoggingLevel: 1 + xboxOneDisableEsram: 0 + xboxOnePresentImmediateThreshold: 0 + videoMemoryForVertexBuffers: 0 + psp2PowerMode: 0 + psp2AcquireBGM: 1 + wiiUTVResolution: 0 + wiiUGamePadMSAA: 1 + wiiUSupportsNunchuk: 0 + wiiUSupportsClassicController: 0 + wiiUSupportsBalanceBoard: 0 + wiiUSupportsMotionPlus: 0 + wiiUSupportsProController: 0 + wiiUAllowScreenCapture: 1 + wiiUControllerCount: 0 + m_SupportedAspectRatios: + 4:3: 1 + 5:4: 1 + 16:10: 1 + 16:9: 1 + Others: 1 + bundleVersion: 1.0 + preloadedAssets: [] + metroInputSource: 0 + wsaTransparentSwapchain: 0 + m_HolographicPauseOnTrackingLoss: 1 + xboxOneDisableKinectGpuReservation: 0 + xboxOneEnable7thCore: 0 + vrSettings: + cardboard: + depthFormat: 0 + enableTransitionView: 0 + daydream: + depthFormat: 0 + useSustainedPerformanceMode: 0 + enableVideoLayer: 0 + useProtectedVideoMemory: 0 + minimumSupportedHeadTracking: 0 + maximumSupportedHeadTracking: 1 + hololens: + depthFormat: 1 + depthBufferSharingEnabled: 0 + oculus: + sharedDepthBuffer: 0 + dashSupport: 0 + protectGraphicsMemory: 0 + useHDRDisplay: 0 + m_ColorGamuts: 00000000 + targetPixelDensity: 30 + resolutionScalingMode: 0 + androidSupportedAspectRatio: 1 + androidMaxAspectRatio: 2.1 + applicationIdentifier: {} + buildNumber: {} + AndroidBundleVersionCode: 1 + AndroidMinSdkVersion: 16 + AndroidTargetSdkVersion: 0 + AndroidPreferredInstallLocation: 1 + aotOptions: + stripEngineCode: 1 + iPhoneStrippingLevel: 0 + iPhoneScriptCallOptimization: 0 + ForceInternetPermission: 0 + ForceSDCardPermission: 0 + CreateWallpaper: 0 + APKExpansionFiles: 0 + keepLoadedShadersAlive: 0 + StripUnusedMeshComponents: 0 + VertexChannelCompressionMask: + serializedVersion: 2 + m_Bits: 238 + iPhoneSdkVersion: 988 + iOSTargetOSVersionString: 7.0 + tvOSSdkVersion: 0 + tvOSRequireExtendedGameController: 0 + tvOSTargetOSVersionString: 9.0 + uIPrerenderedIcon: 0 + uIRequiresPersistentWiFi: 0 + uIRequiresFullScreen: 1 + uIStatusBarHidden: 1 + uIExitOnSuspend: 0 + uIStatusBarStyle: 0 + iPhoneSplashScreen: {fileID: 0} + iPhoneHighResSplashScreen: {fileID: 0} + iPhoneTallHighResSplashScreen: {fileID: 0} + iPhone47inSplashScreen: {fileID: 0} + iPhone55inPortraitSplashScreen: {fileID: 0} + iPhone55inLandscapeSplashScreen: {fileID: 0} + iPhone58inPortraitSplashScreen: {fileID: 0} + iPhone58inLandscapeSplashScreen: {fileID: 0} + iPadPortraitSplashScreen: {fileID: 0} + iPadHighResPortraitSplashScreen: {fileID: 0} + iPadLandscapeSplashScreen: {fileID: 0} + iPadHighResLandscapeSplashScreen: {fileID: 0} + appleTVSplashScreen: {fileID: 0} + appleTVSplashScreen2x: {fileID: 0} + tvOSSmallIconLayers: [] + tvOSSmallIconLayers2x: [] + tvOSLargeIconLayers: [] + tvOSTopShelfImageLayers: [] + tvOSTopShelfImageLayers2x: [] + tvOSTopShelfImageWideLayers: [] + tvOSTopShelfImageWideLayers2x: [] + iOSLaunchScreenType: 0 + iOSLaunchScreenPortrait: {fileID: 0} + iOSLaunchScreenLandscape: {fileID: 0} + iOSLaunchScreenBackgroundColor: + serializedVersion: 2 + rgba: 0 + iOSLaunchScreenFillPct: 100 + iOSLaunchScreenSize: 100 + iOSLaunchScreenCustomXibPath: + iOSLaunchScreeniPadType: 0 + iOSLaunchScreeniPadImage: {fileID: 0} + iOSLaunchScreeniPadBackgroundColor: + serializedVersion: 2 + rgba: 0 + iOSLaunchScreeniPadFillPct: 100 + iOSLaunchScreeniPadSize: 100 + iOSLaunchScreeniPadCustomXibPath: + iOSUseLaunchScreenStoryboard: 0 + iOSLaunchScreenCustomStoryboardPath: + iOSDeviceRequirements: [] + iOSURLSchemes: [] + iOSBackgroundModes: 0 + iOSMetalForceHardShadows: 0 + metalEditorSupport: 1 + metalAPIValidation: 1 + iOSRenderExtraFrameOnPause: 0 + appleDeveloperTeamID: + iOSManualSigningProvisioningProfileID: + tvOSManualSigningProvisioningProfileID: + appleEnableAutomaticSigning: 0 + clonedFromGUID: 00000000000000000000000000000000 + AndroidTargetDevice: 0 + AndroidSplashScreenScale: 0 + androidSplashScreen: {fileID: 0} + AndroidKeystoreName: + AndroidKeyaliasName: + AndroidTVCompatibility: 1 + AndroidIsGame: 1 + AndroidEnableTango: 0 + androidEnableBanner: 1 + androidUseLowAccuracyLocation: 0 + m_AndroidBanners: + - width: 320 + height: 180 + banner: {fileID: 0} + androidGamepadSupportLevel: 0 + resolutionDialogBanner: {fileID: 0} + m_BuildTargetIcons: [] + m_BuildTargetBatching: [] + m_BuildTargetGraphicsAPIs: [] + m_BuildTargetVRSettings: [] + m_BuildTargetEnableVuforiaSettings: [] + openGLRequireES31: 0 + openGLRequireES31AEP: 0 + m_TemplateCustomTags: {} + mobileMTRendering: + Android: 1 + iPhone: 1 + tvOS: 1 + m_BuildTargetGroupLightmapEncodingQuality: [] + wiiUTitleID: 0005000011000000 + wiiUGroupID: 00010000 + wiiUCommonSaveSize: 4096 + wiiUAccountSaveSize: 2048 + wiiUOlvAccessKey: 0 + wiiUTinCode: 0 + wiiUJoinGameId: 0 + wiiUJoinGameModeMask: 0000000000000000 + wiiUCommonBossSize: 0 + wiiUAccountBossSize: 0 + wiiUAddOnUniqueIDs: [] + wiiUMainThreadStackSize: 3072 + wiiULoaderThreadStackSize: 1024 + wiiUSystemHeapSize: 128 + wiiUTVStartupScreen: {fileID: 0} + wiiUGamePadStartupScreen: {fileID: 0} + wiiUDrcBufferDisabled: 0 + wiiUProfilerLibPath: + playModeTestRunnerEnabled: 0 + actionOnDotNetUnhandledException: 1 + enableInternalProfiler: 0 + logObjCUncaughtExceptions: 1 + enableCrashReportAPI: 0 + cameraUsageDescription: + locationUsageDescription: + microphoneUsageDescription: + switchNetLibKey: + switchSocketMemoryPoolSize: 6144 + switchSocketAllocatorPoolSize: 128 + switchSocketConcurrencyLimit: 14 + switchScreenResolutionBehavior: 2 + switchUseCPUProfiler: 0 + switchApplicationID: 0x01004b9000490000 + switchNSODependencies: + switchTitleNames_0: + switchTitleNames_1: + switchTitleNames_2: + switchTitleNames_3: + switchTitleNames_4: + switchTitleNames_5: + switchTitleNames_6: + switchTitleNames_7: + switchTitleNames_8: + switchTitleNames_9: + switchTitleNames_10: + switchTitleNames_11: + switchTitleNames_12: + switchTitleNames_13: + switchTitleNames_14: + switchPublisherNames_0: + switchPublisherNames_1: + switchPublisherNames_2: + switchPublisherNames_3: + switchPublisherNames_4: + switchPublisherNames_5: + switchPublisherNames_6: + switchPublisherNames_7: + switchPublisherNames_8: + switchPublisherNames_9: + switchPublisherNames_10: + switchPublisherNames_11: + switchPublisherNames_12: + switchPublisherNames_13: + switchPublisherNames_14: + switchIcons_0: {fileID: 0} + switchIcons_1: {fileID: 0} + switchIcons_2: {fileID: 0} + switchIcons_3: {fileID: 0} + switchIcons_4: {fileID: 0} + switchIcons_5: {fileID: 0} + switchIcons_6: {fileID: 0} + switchIcons_7: {fileID: 0} + switchIcons_8: {fileID: 0} + switchIcons_9: {fileID: 0} + switchIcons_10: {fileID: 0} + switchIcons_11: {fileID: 0} + switchIcons_12: {fileID: 0} + switchIcons_13: {fileID: 0} + switchIcons_14: {fileID: 0} + switchSmallIcons_0: {fileID: 0} + switchSmallIcons_1: {fileID: 0} + switchSmallIcons_2: {fileID: 0} + switchSmallIcons_3: {fileID: 0} + switchSmallIcons_4: {fileID: 0} + switchSmallIcons_5: {fileID: 0} + switchSmallIcons_6: {fileID: 0} + switchSmallIcons_7: {fileID: 0} + switchSmallIcons_8: {fileID: 0} + switchSmallIcons_9: {fileID: 0} + switchSmallIcons_10: {fileID: 0} + switchSmallIcons_11: {fileID: 0} + switchSmallIcons_12: {fileID: 0} + switchSmallIcons_13: {fileID: 0} + switchSmallIcons_14: {fileID: 0} + switchManualHTML: + switchAccessibleURLs: + switchLegalInformation: + switchMainThreadStackSize: 1048576 + switchPresenceGroupId: + switchLogoHandling: 0 + switchReleaseVersion: 0 + switchDisplayVersion: 1.0.0 + switchStartupUserAccount: 0 + switchTouchScreenUsage: 0 + switchSupportedLanguagesMask: 0 + switchLogoType: 0 + switchApplicationErrorCodeCategory: + switchUserAccountSaveDataSize: 0 + switchUserAccountSaveDataJournalSize: 0 + switchApplicationAttribute: 0 + switchCardSpecSize: -1 + switchCardSpecClock: -1 + switchRatingsMask: 0 + switchRatingsInt_0: 0 + switchRatingsInt_1: 0 + switchRatingsInt_2: 0 + switchRatingsInt_3: 0 + switchRatingsInt_4: 0 + switchRatingsInt_5: 0 + switchRatingsInt_6: 0 + switchRatingsInt_7: 0 + switchRatingsInt_8: 0 + switchRatingsInt_9: 0 + switchRatingsInt_10: 0 + switchRatingsInt_11: 0 + switchLocalCommunicationIds_0: + switchLocalCommunicationIds_1: + switchLocalCommunicationIds_2: + switchLocalCommunicationIds_3: + switchLocalCommunicationIds_4: + switchLocalCommunicationIds_5: + switchLocalCommunicationIds_6: + switchLocalCommunicationIds_7: + switchParentalControl: 0 + switchAllowsScreenshot: 1 + switchAllowsVideoCapturing: 1 + switchAllowsRuntimeAddOnContentInstall: 0 + switchDataLossConfirmation: 0 + switchSupportedNpadStyles: 3 + switchSocketConfigEnabled: 0 + switchTcpInitialSendBufferSize: 32 + switchTcpInitialReceiveBufferSize: 64 + switchTcpAutoSendBufferSizeMax: 256 + switchTcpAutoReceiveBufferSizeMax: 256 + switchUdpSendBufferSize: 9 + switchUdpReceiveBufferSize: 42 + switchSocketBufferEfficiency: 4 + switchSocketInitializeEnabled: 1 + switchNetworkInterfaceManagerInitializeEnabled: 1 + switchPlayerConnectionEnabled: 1 + ps4NPAgeRating: 12 + ps4NPTitleSecret: + ps4NPTrophyPackPath: + ps4ParentalLevel: 11 + ps4ContentID: ED1633-NPXX51362_00-0000000000000000 + ps4Category: 0 + ps4MasterVersion: 01.00 + ps4AppVersion: 01.00 + ps4AppType: 0 + ps4ParamSfxPath: + ps4VideoOutPixelFormat: 0 + ps4VideoOutInitialWidth: 1920 + ps4VideoOutBaseModeInitialWidth: 1920 + ps4VideoOutReprojectionRate: 60 + ps4PronunciationXMLPath: + ps4PronunciationSIGPath: + ps4BackgroundImagePath: + ps4StartupImagePath: + ps4StartupImagesFolder: + ps4IconImagesFolder: + ps4SaveDataImagePath: + ps4SdkOverride: + ps4BGMPath: + ps4ShareFilePath: + ps4ShareOverlayImagePath: + ps4PrivacyGuardImagePath: + ps4NPtitleDatPath: + ps4RemotePlayKeyAssignment: -1 + ps4RemotePlayKeyMappingDir: + ps4PlayTogetherPlayerCount: 0 + ps4EnterButtonAssignment: 1 + ps4ApplicationParam1: 0 + ps4ApplicationParam2: 0 + ps4ApplicationParam3: 0 + ps4ApplicationParam4: 0 + ps4DownloadDataSize: 0 + ps4GarlicHeapSize: 2048 + ps4ProGarlicHeapSize: 2560 + ps4Passcode: d3hjjul8UhK6ZnQCEBYYQPozR9sQV066 + ps4pnSessions: 1 + ps4pnPresence: 1 + ps4pnFriends: 1 + ps4pnGameCustomData: 1 + playerPrefsSupport: 0 + restrictedAudioUsageRights: 0 + ps4UseResolutionFallback: 0 + ps4ReprojectionSupport: 0 + ps4UseAudio3dBackend: 0 + ps4SocialScreenEnabled: 0 + ps4ScriptOptimizationLevel: 0 + ps4Audio3dVirtualSpeakerCount: 14 + ps4attribCpuUsage: 0 + ps4PatchPkgPath: + ps4PatchLatestPkgPath: + ps4PatchChangeinfoPath: + ps4PatchDayOne: 0 + ps4attribUserManagement: 0 + ps4attribMoveSupport: 0 + ps4attrib3DSupport: 0 + ps4attribShareSupport: 0 + ps4attribExclusiveVR: 0 + ps4disableAutoHideSplash: 0 + ps4videoRecordingFeaturesUsed: 0 + ps4contentSearchFeaturesUsed: 0 + ps4attribEyeToEyeDistanceSettingVR: 0 + ps4IncludedModules: [] + monoEnv: + psp2Splashimage: {fileID: 0} + psp2NPTrophyPackPath: + psp2NPSupportGBMorGJP: 0 + psp2NPAgeRating: 12 + psp2NPTitleDatPath: + psp2NPCommsID: + psp2NPCommunicationsID: + psp2NPCommsPassphrase: + psp2NPCommsSig: + psp2ParamSfxPath: + psp2ManualPath: + psp2LiveAreaGatePath: + psp2LiveAreaBackroundPath: + psp2LiveAreaPath: + psp2LiveAreaTrialPath: + psp2PatchChangeInfoPath: + psp2PatchOriginalPackage: + psp2PackagePassword: 3onkgZsAECEn0fzCoWiCtWCKe4l74pE5 + psp2KeystoneFile: + psp2MemoryExpansionMode: 0 + psp2DRMType: 0 + psp2StorageType: 0 + psp2MediaCapacity: 0 + psp2DLCConfigPath: + psp2ThumbnailPath: + psp2BackgroundPath: + psp2SoundPath: + psp2TrophyCommId: + psp2TrophyPackagePath: + psp2PackagedResourcesPath: + psp2SaveDataQuota: 10240 + psp2ParentalLevel: 1 + psp2ShortTitle: Not Set + psp2ContentID: IV0000-ABCD12345_00-0123456789ABCDEF + psp2Category: 0 + psp2MasterVersion: 01.00 + psp2AppVersion: 01.00 + psp2TVBootMode: 0 + psp2EnterButtonAssignment: 2 + psp2TVDisableEmu: 0 + psp2AllowTwitterDialog: 1 + psp2Upgradable: 0 + psp2HealthWarning: 0 + psp2UseLibLocation: 0 + psp2InfoBarOnStartup: 0 + psp2InfoBarColor: 0 + psp2ScriptOptimizationLevel: 0 + psmSplashimage: {fileID: 0} + splashScreenBackgroundSourceLandscape: {fileID: 0} + splashScreenBackgroundSourcePortrait: {fileID: 0} + spritePackerPolicy: + webGLMemorySize: 256 + webGLExceptionSupport: 1 + webGLNameFilesAsHashes: 0 + webGLDataCaching: 0 + webGLDebugSymbols: 0 + webGLEmscriptenArgs: + webGLModulesDirectory: + webGLTemplate: APPLICATION:Default + webGLAnalyzeBuildSize: 0 + webGLUseEmbeddedResources: 0 + webGLUseWasm: 0 + webGLCompressionFormat: 1 + scriptingDefineSymbols: {} + platformArchitecture: {} + scriptingBackend: {} + incrementalIl2cppBuild: {} + additionalIl2CppArgs: + scriptingRuntimeVersion: 0 + apiCompatibilityLevelPerPlatform: {} + m_RenderingPath: 1 + m_MobileRenderingPath: 1 + metroPackageName: TensorFlowLitePlugin + metroPackageVersion: + metroCertificatePath: + metroCertificatePassword: + metroCertificateSubject: + metroCertificateIssuer: + metroCertificateNotAfter: 0000000000000000 + metroApplicationDescription: TensorFlowLitePlugin + wsaImages: {} + metroTileShortName: + metroCommandLineArgsFile: + metroTileShowName: 0 + metroMediumTileShowName: 0 + metroLargeTileShowName: 0 + metroWideTileShowName: 0 + metroDefaultTileSize: 1 + metroTileForegroundText: 2 + metroTileBackgroundColor: {r: 0.13333334, g: 0.17254902, b: 0.21568628, a: 0} + metroSplashScreenBackgroundColor: {r: 0.12941177, g: 0.17254902, b: 0.21568628, + a: 1} + metroSplashScreenUseBackgroundColor: 0 + platformCapabilities: {} + metroFTAName: + metroFTAFileTypes: [] + metroProtocolName: + metroCompilationOverrides: 1 + tizenProductDescription: + tizenProductURL: + tizenSigningProfileName: + tizenGPSPermissions: 0 + tizenMicrophonePermissions: 0 + tizenDeploymentTarget: + tizenDeploymentTargetType: -1 + tizenMinOSVersion: 1 + n3dsUseExtSaveData: 0 + n3dsCompressStaticMem: 1 + n3dsExtSaveDataNumber: 0x12345 + n3dsStackSize: 131072 + n3dsTargetPlatform: 2 + n3dsRegion: 7 + n3dsMediaSize: 0 + n3dsLogoStyle: 3 + n3dsTitle: GameName + n3dsProductCode: + n3dsApplicationId: 0xFF3FF + XboxOneProductId: + XboxOneUpdateKey: + XboxOneSandboxId: + XboxOneContentId: + XboxOneTitleId: + XboxOneSCId: + XboxOneGameOsOverridePath: + XboxOnePackagingOverridePath: + XboxOneAppManifestOverridePath: + XboxOnePackageEncryption: 0 + XboxOnePackageUpdateGranularity: 2 + XboxOneDescription: + XboxOneLanguage: + - enus + XboxOneCapability: [] + XboxOneGameRating: {} + XboxOneIsContentPackage: 0 + XboxOneEnableGPUVariability: 0 + XboxOneSockets: {} + XboxOneSplashScreen: {fileID: 0} + XboxOneAllowedProductIds: [] + XboxOnePersistentLocalStorageSize: 0 + XboxOneXTitleMemory: 8 + xboxOneScriptCompiler: 0 + vrEditorSettings: + daydream: + daydreamIconForeground: {fileID: 0} + daydreamIconBackground: {fileID: 0} + cloudServicesEnabled: {} + facebookSdkVersion: 7.9.4 + apiCompatibilityLevel: 2 + cloudProjectId: + projectName: + organizationId: + cloudEnabled: 0 + enableNativePlatformBackendsForNewInputSystem: 0 + disableOldInputManagerSupport: 0 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectVersion.txt b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectVersion.txt new file mode 100644 index 0000000000000000000000000000000000000000..4a9cfb61ab55abc2f0d09b0225a802ef8122eaaf --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/ProjectVersion.txt @@ -0,0 +1 @@ +m_EditorVersion: 2017.4.6f1 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/QualitySettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/QualitySettings.asset new file mode 100644 index 0000000000000000000000000000000000000000..05daac3c4922feef068af19efa921fcbb476afde --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/QualitySettings.asset @@ -0,0 +1,191 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!47 &1 +QualitySettings: + m_ObjectHideFlags: 0 + serializedVersion: 5 + m_CurrentQuality: 5 + m_QualitySettings: + - serializedVersion: 2 + name: Very Low + pixelLightCount: 0 + shadows: 0 + shadowResolution: 0 + shadowProjection: 1 + shadowCascades: 1 + shadowDistance: 15 + shadowNearPlaneOffset: 3 + shadowCascade2Split: 0.33333334 + shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667} + shadowmaskMode: 0 + blendWeights: 1 + textureQuality: 1 + anisotropicTextures: 0 + antiAliasing: 0 + softParticles: 0 + softVegetation: 0 + realtimeReflectionProbes: 0 + billboardsFaceCameraPosition: 0 + vSyncCount: 0 + lodBias: 0.3 + maximumLODLevel: 0 + particleRaycastBudget: 4 + asyncUploadTimeSlice: 2 + asyncUploadBufferSize: 4 + resolutionScalingFixedDPIFactor: 1 + excludedTargetPlatforms: [] + - serializedVersion: 2 + name: Low + pixelLightCount: 0 + shadows: 0 + shadowResolution: 0 + shadowProjection: 1 + shadowCascades: 1 + shadowDistance: 20 + shadowNearPlaneOffset: 3 + shadowCascade2Split: 0.33333334 + shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667} + shadowmaskMode: 0 + blendWeights: 2 + textureQuality: 0 + anisotropicTextures: 0 + antiAliasing: 0 + softParticles: 0 + softVegetation: 0 + realtimeReflectionProbes: 0 + billboardsFaceCameraPosition: 0 + vSyncCount: 0 + lodBias: 0.4 + maximumLODLevel: 0 + particleRaycastBudget: 16 + asyncUploadTimeSlice: 2 + asyncUploadBufferSize: 4 + resolutionScalingFixedDPIFactor: 1 + excludedTargetPlatforms: [] + - serializedVersion: 2 + name: Medium + pixelLightCount: 1 + shadows: 1 + shadowResolution: 0 + shadowProjection: 1 + shadowCascades: 1 + shadowDistance: 20 + shadowNearPlaneOffset: 3 + shadowCascade2Split: 0.33333334 + shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667} + shadowmaskMode: 0 + blendWeights: 2 + textureQuality: 0 + anisotropicTextures: 1 + antiAliasing: 0 + softParticles: 0 + softVegetation: 0 + realtimeReflectionProbes: 0 + billboardsFaceCameraPosition: 0 + vSyncCount: 1 + lodBias: 0.7 + maximumLODLevel: 0 + particleRaycastBudget: 64 + asyncUploadTimeSlice: 2 + asyncUploadBufferSize: 4 + resolutionScalingFixedDPIFactor: 1 + excludedTargetPlatforms: [] + - serializedVersion: 2 + name: High + pixelLightCount: 2 + shadows: 2 + shadowResolution: 1 + shadowProjection: 1 + shadowCascades: 2 + shadowDistance: 40 + shadowNearPlaneOffset: 3 + shadowCascade2Split: 0.33333334 + shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667} + shadowmaskMode: 1 + blendWeights: 2 + textureQuality: 0 + anisotropicTextures: 1 + antiAliasing: 0 + softParticles: 0 + softVegetation: 1 + realtimeReflectionProbes: 1 + billboardsFaceCameraPosition: 1 + vSyncCount: 1 + lodBias: 1 + maximumLODLevel: 0 + particleRaycastBudget: 256 + asyncUploadTimeSlice: 2 + asyncUploadBufferSize: 4 + resolutionScalingFixedDPIFactor: 1 + excludedTargetPlatforms: [] + - serializedVersion: 2 + name: Very High + pixelLightCount: 3 + shadows: 2 + shadowResolution: 2 + shadowProjection: 1 + shadowCascades: 2 + shadowDistance: 70 + shadowNearPlaneOffset: 3 + shadowCascade2Split: 0.33333334 + shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667} + shadowmaskMode: 1 + blendWeights: 4 + textureQuality: 0 + anisotropicTextures: 2 + antiAliasing: 2 + softParticles: 1 + softVegetation: 1 + realtimeReflectionProbes: 1 + billboardsFaceCameraPosition: 1 + vSyncCount: 1 + lodBias: 1.5 + maximumLODLevel: 0 + particleRaycastBudget: 1024 + asyncUploadTimeSlice: 2 + asyncUploadBufferSize: 4 + resolutionScalingFixedDPIFactor: 1 + excludedTargetPlatforms: [] + - serializedVersion: 2 + name: Ultra + pixelLightCount: 4 + shadows: 2 + shadowResolution: 2 + shadowProjection: 1 + shadowCascades: 4 + shadowDistance: 150 + shadowNearPlaneOffset: 3 + shadowCascade2Split: 0.33333334 + shadowCascade4Split: {x: 0.06666667, y: 0.2, z: 0.46666667} + shadowmaskMode: 1 + blendWeights: 4 + textureQuality: 0 + anisotropicTextures: 2 + antiAliasing: 2 + softParticles: 1 + softVegetation: 1 + realtimeReflectionProbes: 1 + billboardsFaceCameraPosition: 1 + vSyncCount: 1 + lodBias: 2 + maximumLODLevel: 0 + particleRaycastBudget: 4096 + asyncUploadTimeSlice: 2 + asyncUploadBufferSize: 4 + resolutionScalingFixedDPIFactor: 1 + excludedTargetPlatforms: [] + m_PerPlatformDefaultQuality: + Android: 2 + Nintendo 3DS: 5 + Nintendo Switch: 5 + PS4: 5 + PSM: 5 + PSP2: 2 + Standalone: 5 + Tizen: 2 + WebGL: 3 + WiiU: 5 + Windows Store Apps: 5 + XboxOne: 5 + iPhone: 2 + tvOS: 2 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TagManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TagManager.asset new file mode 100644 index 0000000000000000000000000000000000000000..1c92a7840ec11895c76785f65d949a3d20d53355 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TagManager.asset @@ -0,0 +1,43 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!78 &1 +TagManager: + serializedVersion: 2 + tags: [] + layers: + - Default + - TransparentFX + - Ignore Raycast + - + - Water + - UI + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + m_SortingLayers: + - name: Default + uniqueID: 0 + locked: 0 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TimeManager.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TimeManager.asset new file mode 100644 index 0000000000000000000000000000000000000000..558a017e1f50b2db73414a1abad3c033922774f8 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/TimeManager.asset @@ -0,0 +1,9 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!5 &1 +TimeManager: + m_ObjectHideFlags: 0 + Fixed Timestep: 0.02 + Maximum Allowed Timestep: 0.33333334 + m_TimeScale: 1 + Maximum Particle Timestep: 0.03 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/UnityConnectSettings.asset b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/UnityConnectSettings.asset new file mode 100644 index 0000000000000000000000000000000000000000..3da14d5baf1fa24df1746c3ce9d969eda3a9c59d --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/ProjectSettings/UnityConnectSettings.asset @@ -0,0 +1,34 @@ +%YAML 1.1 +%TAG !u! tag:unity3d.com,2011: +--- !u!310 &1 +UnityConnectSettings: + m_ObjectHideFlags: 0 + m_Enabled: 0 + m_TestMode: 0 + m_TestEventUrl: + m_TestConfigUrl: + m_TestInitMode: 0 + CrashReportingSettings: + m_EventUrl: https://perf-events.cloud.unity3d.com/api/events/crashes + m_NativeEventUrl: https://perf-events.cloud.unity3d.com/symbolicate + m_Enabled: 0 + m_CaptureEditorExceptions: 1 + UnityPurchasingSettings: + m_Enabled: 0 + m_TestMode: 0 + UnityAnalyticsSettings: + m_Enabled: 0 + m_InitializeOnStartup: 1 + m_TestMode: 0 + m_TestEventUrl: + m_TestConfigUrl: + UnityAdsSettings: + m_Enabled: 0 + m_InitializeOnStartup: 1 + m_TestMode: 0 + m_IosGameId: + m_AndroidGameId: + m_GameIds: {} + m_GameId: + PerformanceReportingSettings: + m_Enabled: 0 diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f480c49cd050de2192e9673f72c9e4d5c3c6ceff --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/README.md @@ -0,0 +1,29 @@ +# TF Lite Experimental Unity Plugin + +This directory contains an experimental sample Unity (2017) Plugin, based on +the experimental TF Lite C API. The sample demonstrates running inference within +Unity by way of a C# `Interpreter` wrapper. + +Note that the native TF Lite plugin(s) *must* be built before using the Unity +Plugin, and placed in Assets/TensorFlowLite/SDK/Plugins/. For the editor (note +that this has only been tested on Linux; the syntax may differ on Mac/Windows): + +```sh +bazel build -c opt --cxxopt=--std=c++11 \ + //tensorflow/contrib/lite/experimental/c:libtensorflowlite_c.so +``` + +and for Android: + +```sh +bazel build -c opt --cxxopt=--std=c++11 \ + --crosstool_top=//external:android/crosstool \ + --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ + --cpu=armeabi-v7a \ + //tensorflow/contrib/lite/experimental/c:libtensorflowlite_c.so +``` + +If you encounter issues with native plugin discovery on Mac ("Darwin") +platforms, try renaming `libtensorflowlite_c.so` to `tensorflowlite_c.bundle`. +Similarly, on Windows you'll likely need to rename `libtensorflowlite_c.so` to +`tensorflowlite_c.dll`. diff --git a/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/UnityPackageManager/manifest.json b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/UnityPackageManager/manifest.json new file mode 100644 index 0000000000000000000000000000000000000000..526aca60573f334a6b6bd536fa5be9c26d678e0f --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/unity/TensorFlowLitePlugin/UnityPackageManager/manifest.json @@ -0,0 +1,4 @@ +{ + "dependencies": { + } +} diff --git a/tensorflow/contrib/lite/experimental/kernels/BUILD b/tensorflow/contrib/lite/experimental/kernels/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..9c06c4ebd958294586dbb1fde5040a0d328954ac --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/BUILD @@ -0,0 +1,84 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +# ctc support classes imported directly from TensorFlow. +cc_library( + name = "ctc_utils", + hdrs = [ + "ctc_beam_entry.h", + "ctc_beam_scorer.h", + "ctc_beam_search.h", + "ctc_decoder.h", + "ctc_loss_util.h", + ], + deps = [ + ":top_n", + "//tensorflow/contrib/lite/kernels/internal:types", + "//third_party/eigen3", + ], +) + +# top_n support classes imported directly from TensorFlow. +cc_library( + name = "top_n", + hdrs = [ + "top_n.h", + ], + deps = [ + "//tensorflow/contrib/lite/kernels/internal:types", + ], +) + +cc_library( + name = "experimental_ops", + srcs = [ + "ctc_beam_search_decoder.cc", + ], + # Suppress warnings that are introduced by Eigen Tensor. + copts = tflite_copts() + [ + "-Wno-error=reorder", + ] + select({ + "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"], + "//conditions:default": [ + ], + }), + deps = [ + ":ctc_utils", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:gemm_support", + "//tensorflow/contrib/lite/kernels:kernel_util", + "//tensorflow/contrib/lite/kernels:op_macros", + "//tensorflow/contrib/lite/kernels/internal:kernel_utils", + "//tensorflow/contrib/lite/kernels/internal:optimized", + "//tensorflow/contrib/lite/kernels/internal:optimized_base", + "//tensorflow/contrib/lite/kernels/internal:quantization_util", + "//tensorflow/contrib/lite/kernels/internal:reference", + "//tensorflow/contrib/lite/kernels/internal:reference_base", + "//tensorflow/contrib/lite/kernels/internal:tensor_utils", + "@flatbuffers", + ], +) + +tf_cc_test( + name = "ctc_beam_search_decoder_test", + size = "small", + srcs = ["ctc_beam_search_decoder_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":experimental_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h new file mode 100644 index 0000000000000000000000000000000000000000..a60ff2a1c53f1b3f9f490ab5cf2bc429ba09dff0 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h @@ -0,0 +1,150 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Copied from tensorflow/core/util/ctc/ctc_beam_entry.h +// TODO(b/111524997): Remove this file. +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_ + +#include +#include +#include +#include + +#include "third_party/eigen3/Eigen/Core" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h" + +namespace tflite { +namespace experimental { +namespace ctc { + +// The ctc_beam_search namespace holds several classes meant to be accessed only +// in case of extending the CTCBeamSearch decoder to allow custom scoring +// functions. +// +// BeamEntry is exposed through template arguments BeamScorer and BeamComparer +// of CTCBeamSearch (ctc_beam_search.h). +namespace ctc_beam_search { + +struct EmptyBeamState {}; + +struct BeamProbability { + BeamProbability() : total(kLogZero), blank(kLogZero), label(kLogZero) {} + void Reset() { + total = kLogZero; + blank = kLogZero; + label = kLogZero; + } + float total; + float blank; + float label; +}; + +template +class BeamRoot; + +template +struct BeamEntry { + // BeamRoot::AddEntry() serves as the factory method. + friend BeamEntry* BeamRoot::AddEntry( + BeamEntry* p, int l); + inline bool Active() const { return newp.total != kLogZero; } + // Return the child at the given index, or construct a new one in-place if + // none was found. + BeamEntry& GetChild(int ind) { + auto entry = children.emplace(ind, nullptr); + auto& child_entry = entry.first->second; + // If this is a new child, populate the BeamEntry*. + if (entry.second) { + child_entry = beam_root->AddEntry(this, ind); + } + return *child_entry; + } + std::vector LabelSeq(bool merge_repeated) const { + std::vector labels; + int prev_label = -1; + const BeamEntry* c = this; + while (c->parent != nullptr) { // Checking c->parent to skip root leaf. + if (!merge_repeated || c->label != prev_label) { + labels.push_back(c->label); + } + prev_label = c->label; + c = c->parent; + } + std::reverse(labels.begin(), labels.end()); + return labels; + } + + BeamEntry* parent; + int label; + // All instances of child BeamEntry are owned by *beam_root. + std::unordered_map*> children; + BeamProbability oldp; + BeamProbability newp; + CTCBeamState state; + + private: + // Constructor giving parent, label, and the beam_root. + // The object pointed to by p cannot be copied and should not be moved, + // otherwise parent will become invalid. + // This private constructor is only called through the factory method + // BeamRoot::AddEntry(). + BeamEntry(BeamEntry* p, int l, BeamRoot* beam_root) + : parent(p), label(l), beam_root(beam_root) {} + BeamRoot* beam_root; + + BeamEntry(const BeamEntry&) = delete; + void operator=(const BeamEntry&) = delete; +}; + +// This class owns all instances of BeamEntry. This is used to avoid recursive +// destructor call during destruction. +template +class BeamRoot { + public: + BeamRoot(BeamEntry* p, int l) { root_entry_ = AddEntry(p, l); } + BeamRoot(const BeamRoot&) = delete; + BeamRoot& operator=(const BeamRoot&) = delete; + + BeamEntry* AddEntry(BeamEntry* p, int l) { + auto* new_entry = new BeamEntry(p, l, this); + beam_entries_.emplace_back(new_entry); + return new_entry; + } + BeamEntry* RootEntry() const { return root_entry_; } + + private: + BeamEntry* root_entry_ = nullptr; + std::vector>> beam_entries_; +}; + +// BeamComparer is the default beam comparer provided in CTCBeamSearch. +template +class BeamComparer { + public: + virtual ~BeamComparer() {} + virtual bool inline operator()(const BeamEntry* a, + const BeamEntry* b) const { + return a->newp.total > b->newp.total; + } +}; + +} // namespace ctc_beam_search + +} // namespace ctc +} // namespace experimental +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_ENTRY_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h new file mode 100644 index 0000000000000000000000000000000000000000..ec60e26257b0f4126e7a7abed6a663abe277ef12 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Collection of scoring classes that can be extended and provided to the +// CTCBeamSearchDecoder to incorporate additional scoring logic (such as a +// language model). +// +// To build a custom scorer extend and implement the pure virtual methods from +// BeamScorerInterface. The default CTC decoding behavior is implemented +// through BaseBeamScorer. + +// Copied from tensorflow/core/util/ctc/ctc_beam_scorer.h +// TODO(b/111524997): Remove this file. +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SCORER_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SCORER_H_ + +#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h" + +namespace tflite { +namespace experimental { +namespace ctc { + +// Base implementation of a beam scorer used by default by the decoder that can +// be subclassed and provided as an argument to CTCBeamSearchDecoder, if complex +// scoring is required. Its main purpose is to provide a thin layer for +// integrating language model scoring easily. +template +class BaseBeamScorer { + public: + virtual ~BaseBeamScorer() {} + // State initialization. + virtual void InitializeState(CTCBeamState* root) const {} + // ExpandState is called when expanding a beam to one of its children. + // Called at most once per child beam. In the simplest case, no state + // expansion is done. + virtual void ExpandState(const CTCBeamState& from_state, int from_label, + CTCBeamState* to_state, int to_label) const {} + // ExpandStateEnd is called after decoding has finished. Its purpose is to + // allow a final scoring of the beam in its current state, before resorting + // and retrieving the TopN requested candidates. Called at most once per beam. + virtual void ExpandStateEnd(CTCBeamState* state) const {} + // GetStateExpansionScore should be an inexpensive method to retrieve the + // (cached) expansion score computed within ExpandState. The score is + // multiplied (log-addition) with the input score at the current step from + // the network. + // + // The score returned should be a log-probability. In the simplest case, as + // there's no state expansion logic, the expansion score is zero. + virtual float GetStateExpansionScore(const CTCBeamState& state, + float previous_score) const { + return previous_score; + } + // GetStateEndExpansionScore should be an inexpensive method to retrieve the + // (cached) expansion score computed within ExpandStateEnd. The score is + // multiplied (log-addition) with the final probability of the beam. + // + // The score returned should be a log-probability. + virtual float GetStateEndExpansionScore(const CTCBeamState& state) const { + return 0; + } +}; + +} // namespace ctc +} // namespace experimental +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SCORER_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h new file mode 100644 index 0000000000000000000000000000000000000000..c658e43092519ba29d880a670a890af148230091 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h @@ -0,0 +1,420 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Copied from tensorflow/core/util/ctc/ctc_beam_search.h +// TODO(b/111524997): Remove this file. +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_ + +#include +#include +#include +#include +#include + +#include "third_party/eigen3/Eigen/Core" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_entry.h" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_scorer.h" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h" +#include "tensorflow/contrib/lite/experimental/kernels/top_n.h" +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" + +namespace tflite { +namespace experimental { +namespace ctc { + +template > +class CTCBeamSearchDecoder : public CTCDecoder { + // Beam Search + // + // Example (GravesTh Fig. 7.5): + // a - + // P = [ 0.3 0.7 ] t = 0 + // [ 0.4 0.6 ] t = 1 + // + // Then P(l = -) = P(--) = 0.7 * 0.6 = 0.42 + // P(l = a) = P(a-) + P(aa) + P(-a) = 0.3*0.4 + ... = 0.58 + // + // In this case, Best Path decoding is suboptimal. + // + // For Beam Search, we use the following main recurrence relations: + // + // Relation 1: + // ---------------------------------------------------------- Eq. 1 + // P(l=abcd @ t=7) = P(l=abc @ t=6) * P(d @ 7) + // + P(l=abcd @ t=6) * (P(d @ 7) + P(- @ 7)) + // where P(l=? @ t=7), ? = a, ab, abc, abcd are all stored and + // updated recursively in the beam entry. + // + // Relation 2: + // ---------------------------------------------------------- Eq. 2 + // P(l=abc? @ t=3) = P(l=abc @ t=2) * P(? @ 3) + // for ? in a, b, d, ..., (not including c or the blank index), + // and the recurrence starts from the beam entry for P(l=abc @ t=2). + // + // For this case, the length of the new sequence equals t+1 (t + // starts at 0). This special case can be calculated as: + // P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3) + // but we calculate it recursively for speed purposes. + typedef ctc_beam_search::BeamEntry BeamEntry; + typedef ctc_beam_search::BeamRoot BeamRoot; + typedef ctc_beam_search::BeamProbability BeamProbability; + + public: + typedef BaseBeamScorer DefaultBeamScorer; + + // The beam search decoder is constructed specifying the beam_width (number of + // candidates to keep at each decoding timestep) and a beam scorer (used for + // custom scoring, for example enabling the use of a language model). + // The ownership of the scorer remains with the caller. The default + // implementation, CTCBeamSearchDecoder<>::DefaultBeamScorer, generates the + // standard beam search. + CTCBeamSearchDecoder(int num_classes, int beam_width, + BaseBeamScorer* scorer, int batch_size = 1, + bool merge_repeated = false) + : CTCDecoder(num_classes, batch_size, merge_repeated), + beam_width_(beam_width), + leaves_(beam_width), + beam_scorer_(scorer) { + Reset(); + } + + ~CTCBeamSearchDecoder() override {} + + // Run the hibernating beam search algorithm on the given input. + bool Decode(const CTCDecoder::SequenceLength& seq_len, + const std::vector& input, + std::vector* output, + CTCDecoder::ScoreOutput* scores) override; + + // Calculate the next step of the beam search and update the internal state. + template + void Step(const Vector& log_input_t); + + template + float GetTopK(const int K, const Vector& input, + std::vector* top_k_logits, + std::vector* top_k_indices); + + // Retrieve the beam scorer instance used during decoding. + BaseBeamScorer* GetBeamScorer() const { return beam_scorer_; } + + // Set label selection parameters for faster decoding. + // See comments for label_selection_size_ and label_selection_margin_. + void SetLabelSelectionParameters(int label_selection_size, + float label_selection_margin) { + label_selection_size_ = label_selection_size; + label_selection_margin_ = label_selection_margin; + } + + // Reset the beam search + void Reset(); + + // Extract the top n paths at current time step + bool TopPaths(int n, std::vector>* paths, + std::vector* log_probs, bool merge_repeated) const; + + private: + int beam_width_; + + // Label selection is designed to avoid possibly very expensive scorer calls, + // by pruning the hypotheses based on the input alone. + // Label selection size controls how many items in each beam are passed + // through to the beam scorer. Only items with top N input scores are + // considered. + // Label selection margin controls the difference between minimal input score + // (versus the best scoring label) for an item to be passed to the beam + // scorer. This margin is expressed in terms of log-probability. + // Default is to do no label selection. + // For more detail: https://research.google.com/pubs/pub44823.html + int label_selection_size_ = 0; // zero means unlimited + float label_selection_margin_ = -1; // -1 means unlimited. + + gtl::TopN leaves_; + std::unique_ptr beam_root_; + BaseBeamScorer* beam_scorer_; + + CTCBeamSearchDecoder(const CTCBeamSearchDecoder&) = delete; + void operator=(const CTCBeamSearchDecoder&) = delete; +}; + +template +bool CTCBeamSearchDecoder::Decode( + const CTCDecoder::SequenceLength& seq_len, + const std::vector& input, + std::vector* output, ScoreOutput* scores) { + // Storage for top paths. + std::vector> beams; + std::vector beam_log_probabilities; + int top_n = output->size(); + if (std::any_of(output->begin(), output->end(), + [this](const CTCDecoder::Output& output) -> bool { + return output.size() < this->batch_size_; + })) { + return false; + } + if (scores->rows() < batch_size_ || scores->cols() < top_n) { + return false; + } + + for (int b = 0; b < batch_size_; ++b) { + int seq_len_b = seq_len[b]; + Reset(); + + for (int t = 0; t < seq_len_b; ++t) { + // Pass log-probabilities for this example + time. + Step(input[t].row(b)); + } // for (int t... + + // O(n * log(n)) + std::unique_ptr> branches(leaves_.Extract()); + leaves_.Reset(); + for (int i = 0; i < branches->size(); ++i) { + BeamEntry* entry = (*branches)[i]; + beam_scorer_->ExpandStateEnd(&entry->state); + entry->newp.total += + beam_scorer_->GetStateEndExpansionScore(entry->state); + leaves_.push(entry); + } + + bool status = + TopPaths(top_n, &beams, &beam_log_probabilities, merge_repeated_); + if (!status) { + return status; + } + + TFLITE_DCHECK_EQ(top_n, beam_log_probabilities.size()); + TFLITE_DCHECK_EQ(beams.size(), beam_log_probabilities.size()); + + for (int i = 0; i < top_n; ++i) { + // Copy output to the correct beam + batch + (*output)[i][b].swap(beams[i]); + (*scores)(b, i) = -beam_log_probabilities[i]; + } + } // for (int b... + return true; +} + +template +template +float CTCBeamSearchDecoder::GetTopK( + const int K, const Vector& input, std::vector* top_k_logits, + std::vector* top_k_indices) { + // Find Top K choices, complexity nk in worst case. The array input is read + // just once. + TFLITE_DCHECK_EQ(num_classes_, input.size()); + top_k_logits->clear(); + top_k_indices->clear(); + top_k_logits->resize(K, -INFINITY); + top_k_indices->resize(K, -1); + for (int j = 0; j < num_classes_ - 1; ++j) { + const float logit = input(j); + if (logit > (*top_k_logits)[K - 1]) { + int k = K - 1; + while (k > 0 && logit > (*top_k_logits)[k - 1]) { + (*top_k_logits)[k] = (*top_k_logits)[k - 1]; + (*top_k_indices)[k] = (*top_k_indices)[k - 1]; + k--; + } + (*top_k_logits)[k] = logit; + (*top_k_indices)[k] = j; + } + } + // Return max value which is in 0th index or blank character logit + return std::max((*top_k_logits)[0], input(num_classes_ - 1)); +} + +template +template +void CTCBeamSearchDecoder::Step( + const Vector& raw_input) { + std::vector top_k_logits; + std::vector top_k_indices; + const bool top_k = + (label_selection_size_ > 0 && label_selection_size_ < raw_input.size()); + // Number of character classes to consider in each step. + const int max_classes = top_k ? label_selection_size_ : (num_classes_ - 1); + // Get max coefficient and remove it from raw_input later. + float max_coeff; + if (top_k) { + max_coeff = GetTopK(label_selection_size_, raw_input, &top_k_logits, + &top_k_indices); + } else { + max_coeff = raw_input.maxCoeff(); + } + const float label_selection_input_min = + (label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_) + : -std::numeric_limits::infinity(); + + // Extract the beams sorted in decreasing new probability + TFLITE_DCHECK_EQ(num_classes_, raw_input.size()); + + std::unique_ptr> branches(leaves_.Extract()); + leaves_.Reset(); + + for (BeamEntry* b : *branches) { + // P(.. @ t) becomes the new P(.. @ t-1) + b->oldp = b->newp; + } + + for (BeamEntry* b : *branches) { + if (b->parent != nullptr) { // if not the root + if (b->parent->Active()) { + // If last two sequence characters are identical: + // Plabel(l=acc @ t=6) = (Plabel(l=acc @ t=5) + // + Pblank(l=ac @ t=5)) + // else: + // Plabel(l=abc @ t=6) = (Plabel(l=abc @ t=5) + // + P(l=ab @ t=5)) + float previous = (b->label == b->parent->label) ? b->parent->oldp.blank + : b->parent->oldp.total; + b->newp.label = + LogSumExp(b->newp.label, + beam_scorer_->GetStateExpansionScore(b->state, previous)); + } + // Plabel(l=abc @ t=6) *= P(c @ 6) + b->newp.label += raw_input(b->label) - max_coeff; + } + // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6) + b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff; + // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6) + b->newp.total = LogSumExp(b->newp.blank, b->newp.label); + + // Push the entry back to the top paths list. + // Note, this will always fill leaves back up in sorted order. + leaves_.push(b); + } + + // we need to resort branches in descending oldp order. + + // branches is in descending oldp order because it was + // originally in descending newp order and we copied newp to oldp. + + // Grow new leaves + for (BeamEntry* b : *branches) { + // A new leaf (represented by its BeamProbability) is a candidate + // iff its total probability is nonzero and either the beam list + // isn't full, or the lowest probability entry in the beam has a + // lower probability than the leaf. + auto is_candidate = [this](const BeamProbability& prob) { + return (prob.total > kLogZero && + (leaves_.size() < beam_width_ || + prob.total > leaves_.peek_bottom()->newp.total)); + }; + + if (!is_candidate(b->oldp)) { + continue; + } + + for (int ind = 0; ind < max_classes; ind++) { + const int label = top_k ? top_k_indices[ind] : ind; + const float logit = top_k ? top_k_logits[ind] : raw_input(ind); + // Perform label selection: if input for this label looks very + // unpromising, never evaluate it with a scorer. + if (logit < label_selection_input_min) { + continue; + } + BeamEntry& c = b->GetChild(label); + if (!c.Active()) { + // Pblank(l=abcd @ t=6) = 0 + c.newp.blank = kLogZero; + // If new child label is identical to beam label: + // Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6) + // Otherwise: + // Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6) + beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label); + float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total; + c.newp.label = logit - max_coeff + + beam_scorer_->GetStateExpansionScore(c.state, previous); + // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6) + c.newp.total = c.newp.label; + + if (is_candidate(c.newp)) { + // Before adding the new node to the beam, check if the beam + // is already at maximum width. + if (leaves_.size() == beam_width_) { + // Bottom is no longer in the beam search. Reset + // its probability; signal it's no longer in the beam search. + BeamEntry* bottom = leaves_.peek_bottom(); + bottom->newp.Reset(); + } + leaves_.push(&c); + } else { + // Deactivate child. + c.oldp.Reset(); + c.newp.Reset(); + } + } + } + } // for (BeamEntry* b... +} + +template +void CTCBeamSearchDecoder::Reset() { + leaves_.Reset(); + + // This beam root, and all of its children, will be in memory until + // the next reset. + beam_root_.reset(new BeamRoot(nullptr, -1)); + beam_root_->RootEntry()->newp.total = 0.0; // ln(1) + beam_root_->RootEntry()->newp.blank = 0.0; // ln(1) + + // Add the root as the initial leaf. + leaves_.push(beam_root_->RootEntry()); + + // Call initialize state on the root object. + beam_scorer_->InitializeState(&beam_root_->RootEntry()->state); +} + +template +bool CTCBeamSearchDecoder::TopPaths( + int n, std::vector>* paths, std::vector* log_probs, + bool merge_repeated) const { + TFLITE_DCHECK(paths); + TFLITE_DCHECK(log_probs); + paths->clear(); + log_probs->clear(); + if (n > beam_width_) { + return false; + } + if (n > leaves_.size()) { + return false; + } + + gtl::TopN top_branches(n); + + // O(beam_width_ * log(n)), space complexity is O(n) + for (auto it = leaves_.unsorted_begin(); it != leaves_.unsorted_end(); ++it) { + top_branches.push(*it); + } + // O(n * log(n)) + std::unique_ptr> branches(top_branches.Extract()); + + for (int i = 0; i < n; ++i) { + BeamEntry* e((*branches)[i]); + paths->push_back(e->LabelSeq(merge_repeated)); + log_probs->push_back(e->newp.total); + } + return true; +} + +} // namespace ctc +} // namespace experimental +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_BEAM_SEARCH_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc new file mode 100644 index 0000000000000000000000000000000000000000..834d1ebd666db2be46394166edadf2a166d958aa --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc @@ -0,0 +1,247 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace experimental { +namespace ctc_beam_search_decoder { + +constexpr int kInputsTensor = 0; +constexpr int kSequenceLengthTensor = 1; + +typedef struct { + int beam_width; + int top_paths; + bool merge_repeated; +} CTCBeamSearchDecoderParams; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + TFLITE_CHECK(buffer != nullptr); + const uint8_t* buffer_t = reinterpret_cast(buffer); + const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); + + CTCBeamSearchDecoderParams* option = new CTCBeamSearchDecoderParams; + option->beam_width = m["beam_width"].AsInt32(); + option->top_paths = m["top_paths"].AsInt32(); + option->merge_repeated = m["merge_repeated"].AsBool(); + + return option; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const CTCBeamSearchDecoderParams* option = + reinterpret_cast(node->user_data); + const int top_paths = option->top_paths; + TF_LITE_ENSURE(context, option->beam_width >= top_paths); + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + // The outputs should be top_paths * 3 + 1. + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 3 * top_paths + 1); + + const TfLiteTensor* inputs = GetInput(context, node, kInputsTensor); + TF_LITE_ENSURE_EQ(context, NumDimensions(inputs), 3); + // TensorFlow only supports float. + TF_LITE_ENSURE_EQ(context, inputs->type, kTfLiteFloat32); + const int batch_size = SizeOfDimension(inputs, 1); + + const TfLiteTensor* sequence_length = + GetInput(context, node, kSequenceLengthTensor); + TF_LITE_ENSURE_EQ(context, NumDimensions(sequence_length), 1); + TF_LITE_ENSURE_EQ(context, NumElements(sequence_length), batch_size); + // TensorFlow only supports int32. + TF_LITE_ENSURE_EQ(context, sequence_length->type, kTfLiteInt32); + + // Resize decoded outputs. + // Do not resize indices & values cause we don't know the values yet. + for (int i = 0; i < top_paths; ++i) { + TfLiteTensor* indices = GetOutput(context, node, i); + SetTensorToDynamic(indices); + TfLiteTensor* values = GetOutput(context, node, i + top_paths); + SetTensorToDynamic(values); + TfLiteTensor* output_shape = GetOutput(context, node, i + 2 * top_paths); + SetTensorToDynamic(output_shape); + } + + // Resize log probability outputs. + TfLiteTensor* log_probability_output = + GetOutput(context, node, top_paths * 3); + TfLiteIntArray* log_probability_output_shape_array = TfLiteIntArrayCreate(2); + log_probability_output_shape_array->data[0] = batch_size; + log_probability_output_shape_array->data[1] = top_paths; + return context->ResizeTensor(context, log_probability_output, + log_probability_output_shape_array); +} + +TfLiteStatus Resize(TfLiteContext* context, + std::initializer_list output_shape, + TfLiteTensor* output) { + const int dimensions = output_shape.size(); + TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(dimensions); + int i = 0; + for (const int v : output_shape) { + output_shape_array->data[i++] = v; + } + return context->ResizeTensor(context, output, output_shape_array); +} + +TfLiteStatus StoreAllDecodedSequences( + TfLiteContext* context, + const std::vector>>& sequences, + TfLiteNode* node, int top_paths) { + const int32_t batch_size = sequences.size(); + std::vector num_entries(top_paths, 0); + + // Calculate num_entries per path + for (const auto& batch_s : sequences) { + TF_LITE_ENSURE_EQ(context, batch_s.size(), top_paths); + for (int p = 0; p < top_paths; ++p) { + num_entries[p] += batch_s[p].size(); + } + } + + for (int p = 0; p < top_paths; ++p) { + const int32_t p_num = num_entries[p]; + + // Resize the decoded outputs. + TfLiteTensor* indices = GetOutput(context, node, p); + TF_LITE_ENSURE_OK(context, Resize(context, {p_num, 2}, indices)); + + TfLiteTensor* values = GetOutput(context, node, p + top_paths); + TF_LITE_ENSURE_OK(context, Resize(context, {p_num}, values)); + + TfLiteTensor* decoded_shape = GetOutput(context, node, p + 2 * top_paths); + TF_LITE_ENSURE_OK(context, Resize(context, {2}, decoded_shape)); + + int32_t max_decoded = 0; + int32_t offset = 0; + + int32_t* indices_data = GetTensorData(indices); + int32_t* values_data = GetTensorData(values); + int32_t* decoded_shape_data = GetTensorData(decoded_shape); + for (int b = 0; b < batch_size; ++b) { + auto& p_batch = sequences[b][p]; + int32_t num_decoded = p_batch.size(); + max_decoded = std::max(max_decoded, num_decoded); + + std::copy_n(p_batch.begin(), num_decoded, values_data + offset); + for (int32_t t = 0; t < num_decoded; ++t, ++offset) { + indices_data[offset * 2] = b; + indices_data[offset * 2 + 1] = t; + } + } + + decoded_shape_data[0] = batch_size; + decoded_shape_data[1] = max_decoded; + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* inputs = GetInput(context, node, kInputsTensor); + const TfLiteTensor* sequence_length = + GetInput(context, node, kSequenceLengthTensor); + const CTCBeamSearchDecoderParams* option = + reinterpret_cast(node->user_data); + + const int max_time = SizeOfDimension(inputs, 0); + const int batch_size = SizeOfDimension(inputs, 1); + const int num_classes = SizeOfDimension(inputs, 2); + + const int beam_width = option->beam_width; + const int top_paths = option->top_paths; + const bool merge_repeated = option->merge_repeated; + + // Validate sequence length is less or equal than max time. + for (int i = 0; i < batch_size; ++i) { + TF_LITE_ENSURE(context, + max_time >= GetTensorData(sequence_length)[i]); + } + + // The following logic is implemented like + // tensorflow/core/kernels/ctc_decoder_ops.cc + std::vector::UnalignedConstMatrix> input_list_t; + + for (std::size_t t = 0; t < max_time; ++t) { + input_list_t.emplace_back( + GetTensorData(inputs) + t * batch_size * num_classes, batch_size, + num_classes); + } + + ::tflite::experimental::ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer + beam_scorer; + ::tflite::experimental::ctc::CTCBeamSearchDecoder<> beam_search( + num_classes, beam_width, &beam_scorer, 1 /* batch_size */, + merge_repeated); + + // Allocate temporary memory for holding chip operation data. + float* input_chip_t_data = + static_cast(malloc(num_classes * sizeof(float))); + Eigen::array dims; + dims[0] = num_classes; + optimized_ops::TTypes::Flat input_chip_t(input_chip_t_data, dims); + + std::vector>> best_paths(batch_size); + std::vector log_probs; + + TfLiteTensor* log_probabilities = GetOutput(context, node, 3 * top_paths); + float* log_probabilities_output = GetTensorData(log_probabilities); + + // Assumption: the blank index is num_classes - 1 + for (int b = 0; b < batch_size; ++b) { + auto& best_paths_b = best_paths[b]; + best_paths_b.resize(top_paths); + for (int t = 0; t < GetTensorData(sequence_length)[b]; ++t) { + input_chip_t = input_list_t[t].chip(b, 0); + auto input_bi = + Eigen::Map(input_chip_t.data(), num_classes); + beam_search.Step(input_bi); + } + TF_LITE_ENSURE(context, beam_search.TopPaths(top_paths, &best_paths_b, + &log_probs, merge_repeated)); + beam_search.Reset(); + + // Fill in log_probabilities output. + for (int bp = 0; bp < top_paths; ++bp) { + log_probabilities_output[b * top_paths + bp] = log_probs[bp]; + } + } + + free(input_chip_t_data); + return StoreAllDecodedSequences(context, best_paths, node, top_paths); +} + +} // namespace ctc_beam_search_decoder + +TfLiteRegistration* Register_CTC_BEAM_SEARCH_DECODER() { + static TfLiteRegistration r = { + ctc_beam_search_decoder::Init, ctc_beam_search_decoder::Free, + ctc_beam_search_decoder::Prepare, ctc_beam_search_decoder::Eval}; + return &r; +} + +} // namespace experimental +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9d1e6a562f00905d1db7f7e055ac1c6b1cc34f9e --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc @@ -0,0 +1,238 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace ops { +namespace experimental { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +TfLiteRegistration* Register_CTC_BEAM_SEARCH_DECODER(); + +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +class CTCBeamSearchDecoderOpModel : public SingleOpModel { + public: + CTCBeamSearchDecoderOpModel(std::initializer_list input_shape, + std::initializer_list sequence_length_shape, + int beam_width, int top_paths, + bool merge_repeated) { + inputs_ = AddInput(TensorType_FLOAT32); + sequence_length_ = AddInput(TensorType_INT32); + + for (int i = 0; i < top_paths * 3; ++i) { + outputs_.push_back(AddOutput(TensorType_INT32)); + } + outputs_.push_back(AddOutput(TensorType_FLOAT32)); + + flexbuffers::Builder fbb; + fbb.Map([&]() { + fbb.Int("beam_width", beam_width); + fbb.Int("top_paths", top_paths); + fbb.Bool("merge_repeated", merge_repeated); + }); + fbb.Finish(); + SetCustomOp("CTCBeamSearchDecoder", fbb.GetBuffer(), + Register_CTC_BEAM_SEARCH_DECODER); + BuildInterpreter({input_shape, sequence_length_shape}); + } + + int inputs() { return inputs_; } + + int sequence_length() { return sequence_length_; } + + std::vector> GetDecodedOutpus() { + std::vector> outputs; + for (int i = 0; i < outputs_.size() - 1; ++i) { + outputs.push_back(ExtractVector(outputs_[i])); + } + return outputs; + } + + std::vector GetLogProbabilitiesOutput() { + return ExtractVector(outputs_[outputs_.size() - 1]); + } + + std::vector> GetOutputShapes() { + std::vector> output_shapes; + for (const int output : outputs_) { + output_shapes.push_back(GetTensorShape(output)); + } + return output_shapes; + } + + private: + int inputs_; + int sequence_length_; + std::vector outputs_; +}; + +TEST(CTCBeamSearchTest, SimpleTest) { + CTCBeamSearchDecoderOpModel m({2, 1, 2}, {1}, 1, 1, true); + m.PopulateTensor(m.inputs(), + {-0.50922557, -1.35512652, -2.55445064, -1.58419356}); + m.PopulateTensor(m.sequence_length(), {2}); + m.Invoke(); + + // Make sure the output shapes are right. + const std::vector>& output_shapes = m.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 4); + EXPECT_THAT(output_shapes[0], ElementsAre(1, 2)); + EXPECT_THAT(output_shapes[1], ElementsAre(1)); + EXPECT_THAT(output_shapes[2], ElementsAre(2)); + EXPECT_THAT(output_shapes[3], ElementsAre(1, 1)); + + // Check decoded outputs. + const std::vector>& decoded_outputs = m.GetDecodedOutpus(); + EXPECT_EQ(decoded_outputs.size(), 3); + EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0)); + EXPECT_THAT(decoded_outputs[1], ElementsAre(0)); + EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 1)); + // Check log probabilities output. + EXPECT_THAT(m.GetLogProbabilitiesOutput(), + ElementsAreArray(ArrayFloatNear({0.32134813}))); +} + +TEST(CTCBeamSearchTest, MultiBatchTest) { + CTCBeamSearchDecoderOpModel m({3, 3, 3}, {3}, 1, 1, true); + m.PopulateTensor( + m.inputs(), + {-0.63649208, -0.00487571, -0.04249819, -0.67754697, -1.0341399, + -2.14717721, -0.77686821, -3.41973774, -0.05151402, -0.21482619, + -0.57411168, -1.45039917, -0.73769373, -2.10941739, -0.44818325, + -0.25287673, -2.80057302, -0.54748312, -0.73334867, -0.86537719, + -0.2065197, -0.18725838, -1.42770405, -0.86051965, -1.61642301, + -2.07275114, -0.9201845}); + m.PopulateTensor(m.sequence_length(), {3, 3, 3}); + m.Invoke(); + + // Make sure the output shapes are right. + const std::vector>& output_shapes = m.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 4); + EXPECT_THAT(output_shapes[0], ElementsAre(4, 2)); + EXPECT_THAT(output_shapes[1], ElementsAre(4)); + EXPECT_THAT(output_shapes[2], ElementsAre(2)); + EXPECT_THAT(output_shapes[3], ElementsAre(3, 1)); + + // Check decoded outputs. + const std::vector>& decoded_outputs = m.GetDecodedOutpus(); + EXPECT_EQ(decoded_outputs.size(), 3); + EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0, 0, 1, 1, 0, 2, 0)); + EXPECT_THAT(decoded_outputs[1], ElementsAre(1, 0, 0, 0)); + EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 2)); + // Check log probabilities output. + EXPECT_THAT( + m.GetLogProbabilitiesOutput(), + ElementsAreArray(ArrayFloatNear({0.46403232, 0.49500442, 0.40443572}))); +} + +TEST(CTCBeamSearchTest, MultiPathsTest) { + CTCBeamSearchDecoderOpModel m({3, 2, 5}, {2}, 3, 2, true); + m.PopulateTensor( + m.inputs(), + {-2.206851, -0.09542714, -0.2393415, -3.81866197, -0.27241158, + -0.20371124, -0.68236623, -1.1397166, -0.17422639, -1.85224048, + -0.9406037, -0.32544678, -0.21846784, -0.38377237, -0.33498676, + -0.10139782, -0.51886883, -0.21678554, -0.15267063, -1.91164412, + -0.31328673, -0.27462716, -0.65975336, -1.53671973, -2.76554225, + -0.23920634, -1.2370502, -4.98751576, -3.12995717, -0.43129368}); + m.PopulateTensor(m.sequence_length(), {3, 3}); + m.Invoke(); + + // Make sure the output shapes are right. + const std::vector>& output_shapes = m.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 7); + EXPECT_THAT(output_shapes[0], ElementsAre(4, 2)); + EXPECT_THAT(output_shapes[1], ElementsAre(3, 2)); + EXPECT_THAT(output_shapes[2], ElementsAre(4)); + EXPECT_THAT(output_shapes[3], ElementsAre(3)); + EXPECT_THAT(output_shapes[4], ElementsAre(2)); + EXPECT_THAT(output_shapes[5], ElementsAre(2)); + EXPECT_THAT(output_shapes[6], ElementsAre(2, 2)); + + // Check decoded outputs. + const std::vector>& decoded_outputs = m.GetDecodedOutpus(); + EXPECT_EQ(decoded_outputs.size(), 6); + EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0, 0, 1, 1, 0, 1, 1)); + EXPECT_THAT(decoded_outputs[1], ElementsAre(0, 0, 0, 1, 1, 0)); + EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 2, 3, 0)); + EXPECT_THAT(decoded_outputs[3], ElementsAre(2, 1, 0)); + EXPECT_THAT(decoded_outputs[4], ElementsAre(2, 2)); + EXPECT_THAT(decoded_outputs[5], ElementsAre(2, 2)); + // Check log probabilities output. + EXPECT_THAT(m.GetLogProbabilitiesOutput(), + ElementsAreArray(ArrayFloatNear( + {0.91318405, 0.9060272, 1.0780245, 0.64358956}))); +} + +TEST(CTCBeamSearchTest, NonEqualSequencesTest) { + CTCBeamSearchDecoderOpModel m({3, 3, 4}, {3}, 3, 1, true); + m.PopulateTensor( + m.inputs(), + {-1.26658163, -0.25760023, -0.03917975, -0.63772235, -0.03794756, + -0.45063099, -0.27706473, -0.01569179, -0.59940385, -0.35700127, + -0.48920721, -1.42635476, -1.3462478, -0.02565498, -0.30179568, + -0.6491698, -0.55017719, -2.92291466, -0.92522973, -0.47592022, + -0.07099135, -0.31575624, -0.86345281, -0.36017021, -0.79208612, + -1.75306124, -0.65089224, -0.00912786, -0.42915003, -1.72606203, + -1.66337589, -0.70800793, -2.52272352, -0.67329562, -2.49145522, + -0.49786342}); + m.PopulateTensor(m.sequence_length(), {1, 2, 3}); + m.Invoke(); + + // Make sure the output shapes are right. + const std::vector>& output_shapes = m.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 4); + EXPECT_THAT(output_shapes[0], ElementsAre(3, 2)); + EXPECT_THAT(output_shapes[1], ElementsAre(3)); + EXPECT_THAT(output_shapes[2], ElementsAre(2)); + EXPECT_THAT(output_shapes[3], ElementsAre(3, 1)); + + // Check decoded outputs. + const std::vector>& decoded_outputs = m.GetDecodedOutpus(); + EXPECT_EQ(decoded_outputs.size(), 3); + EXPECT_THAT(decoded_outputs[0], ElementsAre(0, 0, 1, 0, 2, 0)); + EXPECT_THAT(decoded_outputs[1], ElementsAre(2, 0, 1)); + EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 1)); + // Check log probabilities output. + EXPECT_THAT(m.GetLogProbabilitiesOutput(), + ElementsAreArray(ArrayFloatNear({0., 1.0347567, 0.7833005}))); +} + +} // namespace +} // namespace experimental +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h b/tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h new file mode 100644 index 0000000000000000000000000000000000000000..596ad4a5f7264ae24caa5592d10c09c256629b06 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_decoder.h @@ -0,0 +1,114 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Copied from tensorflow/core/util/ctc/ctc_decoder.h +// TODO(b/111524997): Remove this file. +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_DECODER_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_DECODER_H_ + +#include +#include + +#include "third_party/eigen3/Eigen/Core" + +namespace tflite { +namespace experimental { +namespace ctc { + +// The CTCDecoder is an abstract interface to be implemented when providing a +// decoding method on the timestep output of a RNN trained with CTC loss. +// +// The two types of decoding available are: +// - greedy path, through the CTCGreedyDecoder +// - beam search, through the CTCBeamSearchDecoder +class CTCDecoder { + public: + typedef Eigen::Map SequenceLength; + typedef Eigen::Map Input; + typedef std::vector> Output; + typedef Eigen::Map ScoreOutput; + + CTCDecoder(int num_classes, int batch_size, bool merge_repeated) + : num_classes_(num_classes), + blank_index_(num_classes - 1), + batch_size_(batch_size), + merge_repeated_(merge_repeated) {} + + virtual ~CTCDecoder() {} + + // Dimensionality of the input/output is expected to be: + // - seq_len[b] - b = 0 to batch_size_ + // - input[t].rows(b) - t = 0 to timesteps; b = 0 t batch_size_ + // - output.size() specifies the number of beams to be returned. + // - scores(b, i) - b = 0 to batch_size; i = 0 to output.size() + virtual bool Decode(const SequenceLength& seq_len, + const std::vector& input, + std::vector* output, ScoreOutput* scores) = 0; + + int batch_size() { return batch_size_; } + int num_classes() { return num_classes_; } + + protected: + int num_classes_; + int blank_index_; + int batch_size_; + bool merge_repeated_; +}; + +// CTCGreedyDecoder is an implementation of the simple best path decoding +// algorithm, selecting at each timestep the most likely class at each timestep. +class CTCGreedyDecoder : public CTCDecoder { + public: + CTCGreedyDecoder(int num_classes, int batch_size, bool merge_repeated) + : CTCDecoder(num_classes, batch_size, merge_repeated) {} + + bool Decode(const CTCDecoder::SequenceLength& seq_len, + const std::vector& input, + std::vector* output, + CTCDecoder::ScoreOutput* scores) override { + if (output->empty() || (*output)[0].size() < batch_size_) { + return false; + } + if (scores->rows() < batch_size_ || scores->cols() == 0) { + return false; + } + // For each batch entry, identify the transitions + for (int b = 0; b < batch_size_; ++b) { + int seq_len_b = seq_len[b]; + // Only writing to beam 0 + std::vector& output_b = (*output)[0][b]; + + int prev_class_ix = -1; + (*scores)(b, 0) = 0; + for (int t = 0; t < seq_len_b; ++t) { + auto row = input[t].row(b); + int max_class_ix; + (*scores)(b, 0) += -row.maxCoeff(&max_class_ix); + if (max_class_ix != blank_index_ && + !(merge_repeated_ && max_class_ix == prev_class_ix)) { + output_b.push_back(max_class_ix); + } + prev_class_ix = max_class_ix; + } + } + return true; + } +}; + +} // namespace ctc +} // namespace experimental +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_DECODER_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h b/tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h new file mode 100644 index 0000000000000000000000000000000000000000..0bae732533716ac047a55ea31633c8ed51253fe0 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_loss_util.h @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Copied from tensorflow/core/util/ctc/ctc_loss_util.h +// TODO(b/111524997): Remove this file. +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_LOSS_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_LOSS_UTIL_H_ + +#include +#include + +namespace tflite { +namespace experimental { +namespace ctc { + +const float kLogZero = -std::numeric_limits::infinity(); + +// Add logarithmic probabilities using: +// ln(a + b) = ln(a) + ln(1 + exp(ln(b) - ln(a))) +// The two inputs are assumed to be log probabilities. +// (GravesTh) Eq. 7.18 +inline float LogSumExp(float log_prob_1, float log_prob_2) { + // Always have 'b' be the smaller number to avoid the exponential from + // blowing up. + if (log_prob_1 == kLogZero && log_prob_2 == kLogZero) { + return kLogZero; + } else { + return (log_prob_1 > log_prob_2) + ? log_prob_1 + log1pf(expf(log_prob_2 - log_prob_1)) + : log_prob_2 + log1pf(expf(log_prob_1 - log_prob_2)); + } +} + +} // namespace ctc +} // namespace experimental +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_CTC_LOSS_UTIL_H_ diff --git a/tensorflow/contrib/lite/experimental/kernels/top_n.h b/tensorflow/contrib/lite/experimental/kernels/top_n.h new file mode 100644 index 0000000000000000000000000000000000000000..cd2a2f1c80276d4659ccd2f8f05af3af030acb90 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/kernels/top_n.h @@ -0,0 +1,341 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This simple class finds the top n elements of an incrementally provided set +// of elements which you push one at a time. If the number of elements exceeds +// n, the lowest elements are incrementally dropped. At the end you get +// a vector of the top elements sorted in descending order (through Extract() or +// ExtractNondestructive()), or a vector of the top elements but not sorted +// (through ExtractUnsorted() or ExtractUnsortedNondestructive()). +// +// The value n is specified in the constructor. If there are p elements pushed +// altogether: +// The total storage requirements are O(min(n, p)) elements +// The running time is O(p * log(min(n, p))) comparisons +// If n is a constant, the total storage required is a constant and the running +// time is linear in p. +// +// NOTE(zhifengc): There is a way to do this in O(min(n, p)) storage and O(p) +// runtime. The basic idea is to repeatedly fill up a buffer of 2 * n elements, +// discarding the lowest n elements whenever the buffer is full using a linear- +// time median algorithm. This may have better performance when the input +// sequence is partially sorted. +// +// NOTE(zhifengc): This class should be redesigned to avoid reallocating a +// vector for each Extract. + +// Copied from tensorflow/core/lib/gtl/top_n.h +// TODO(b/111524997): Remove this file. +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_TOP_N_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_TOP_N_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" + +namespace tflite { +namespace gtl { + +// Cmp is an stl binary predicate. Note that Cmp is the "greater" predicate, +// not the more commonly used "less" predicate. +// +// If you use a "less" predicate here, the TopN will pick out the bottom N +// elements out of the ones passed to it, and it will return them sorted in +// ascending order. +// +// TopN is rule-of-zero copyable and movable if its members are. +template > +class TopN { + public: + // The TopN is in one of the three states: + // + // o UNORDERED: this is the state an instance is originally in, + // where the elements are completely orderless. + // + // o BOTTOM_KNOWN: in this state, we keep the invariant that there + // is at least one element in it, and the lowest element is at + // position 0. The elements in other positions remain + // unsorted. This state is reached if the state was originally + // UNORDERED and a peek_bottom() function call is invoked. + // + // o HEAP_SORTED: in this state, the array is kept as a heap and + // there are exactly (limit_+1) elements in the array. This + // state is reached when at least (limit_+1) elements are + // pushed in. + // + // The state transition graph is at follows: + // + // peek_bottom() (limit_+1) elements + // UNORDERED --------------> BOTTOM_KNOWN --------------------> HEAP_SORTED + // | ^ + // | (limit_+1) elements | + // +-----------------------------------------------------------+ + + enum State { UNORDERED, BOTTOM_KNOWN, HEAP_SORTED }; + using UnsortedIterator = typename std::vector::const_iterator; + + // 'limit' is the maximum number of top results to return. + explicit TopN(size_t limit) : TopN(limit, Cmp()) {} + TopN(size_t limit, const Cmp &cmp) : limit_(limit), cmp_(cmp) {} + + size_t limit() const { return limit_; } + + // Number of elements currently held by this TopN object. This + // will be no greater than 'limit' passed to the constructor. + size_t size() const { return std::min(elements_.size(), limit_); } + + bool empty() const { return size() == 0; } + + // If you know how many elements you will push at the time you create the + // TopN object, you can call reserve to preallocate the memory that TopN + // will need to process all 'n' pushes. Calling this method is optional. + void reserve(size_t n) { elements_.reserve(std::min(n, limit_ + 1)); } + + // Push 'v'. If the maximum number of elements was exceeded, drop the + // lowest element and return it in 'dropped' (if given). If the maximum is not + // exceeded, 'dropped' will remain unchanged. 'dropped' may be omitted or + // nullptr, in which case it is not filled in. + // Requires: T is CopyAssignable, Swappable + void push(const T &v) { push(v, nullptr); } + void push(const T &v, T *dropped) { PushInternal(v, dropped); } + + // Move overloads of push. + // Requires: T is MoveAssignable, Swappable + void push(T &&v) { // NOLINT(build/c++11) + push(std::move(v), nullptr); + } + void push(T &&v, T *dropped) { // NOLINT(build/c++11) + PushInternal(std::move(v), dropped); + } + + // Peeks the bottom result without calling Extract() + const T &peek_bottom(); + + // Extract the elements as a vector sorted in descending order. The caller + // assumes ownership of the vector and must delete it when done. This is a + // destructive operation. The only method that can be called immediately + // after Extract() is Reset(). + std::vector *Extract(); + + // Similar to Extract(), but makes no guarantees the elements are in sorted + // order. As with Extract(), the caller assumes ownership of the vector and + // must delete it when done. This is a destructive operation. The only + // method that can be called immediately after ExtractUnsorted() is Reset(). + std::vector *ExtractUnsorted(); + + // A non-destructive version of Extract(). Copy the elements in a new vector + // sorted in descending order and return it. The caller assumes ownership of + // the new vector and must delete it when done. After calling + // ExtractNondestructive(), the caller can continue to push() new elements. + std::vector *ExtractNondestructive() const; + + // A non-destructive version of Extract(). Copy the elements to a given + // vector sorted in descending order. After calling + // ExtractNondestructive(), the caller can continue to push() new elements. + // Note: + // 1. The given argument must to be allocated. + // 2. Any data contained in the vector prior to the call will be deleted + // from it. After the call the vector will contain only the elements + // from the data structure. + void ExtractNondestructive(std::vector *output) const; + + // A non-destructive version of ExtractUnsorted(). Copy the elements in a new + // vector and return it, with no guarantees the elements are in sorted order. + // The caller assumes ownership of the new vector and must delete it when + // done. After calling ExtractUnsortedNondestructive(), the caller can + // continue to push() new elements. + std::vector *ExtractUnsortedNondestructive() const; + + // A non-destructive version of ExtractUnsorted(). Copy the elements into + // a given vector, with no guarantees the elements are in sorted order. + // After calling ExtractUnsortedNondestructive(), the caller can continue + // to push() new elements. + // Note: + // 1. The given argument must to be allocated. + // 2. Any data contained in the vector prior to the call will be deleted + // from it. After the call the vector will contain only the elements + // from the data structure. + void ExtractUnsortedNondestructive(std::vector *output) const; + + // Return an iterator to the beginning (end) of the container, + // with no guarantees about the order of iteration. These iterators are + // invalidated by mutation of the data structure. + UnsortedIterator unsorted_begin() const { return elements_.begin(); } + UnsortedIterator unsorted_end() const { return elements_.begin() + size(); } + + // Accessor for comparator template argument. + Cmp *comparator() { return &cmp_; } + + // This removes all elements. If Extract() or ExtractUnsorted() have been + // called, this will put it back in an empty but useable state. + void Reset(); + + private: + template + void PushInternal(U &&v, T *dropped); // NOLINT(build/c++11) + + // elements_ can be in one of two states: + // elements_.size() <= limit_: elements_ is an unsorted vector of elements + // pushed so far. + // elements_.size() > limit_: The last element of elements_ is unused; + // the other elements of elements_ are an stl heap whose size is exactly + // limit_. In this case elements_.size() is exactly one greater than + // limit_, but don't use "elements_.size() == limit_ + 1" to check for + // that because you'll get a false positive if limit_ == size_t(-1). + std::vector elements_; + size_t limit_; // Maximum number of elements to find + Cmp cmp_; // Greater-than comparison function + State state_ = UNORDERED; +}; + +// ---------------------------------------------------------------------- +// Implementations of non-inline functions + +template +template +void TopN::PushInternal(U &&v, T *dropped) { // NOLINT(build/c++11) + if (limit_ == 0) { + if (dropped) *dropped = std::forward(v); // NOLINT(build/c++11) + return; + } + if (state_ != HEAP_SORTED) { + elements_.push_back(std::forward(v)); // NOLINT(build/c++11) + if (state_ == UNORDERED || cmp_(elements_.back(), elements_.front())) { + // Easy case: we just pushed the new element back + } else { + // To maintain the BOTTOM_KNOWN state, we need to make sure that + // the element at position 0 is always the smallest. So we put + // the new element at position 0 and push the original bottom + // element in the back. + // Warning: this code is subtle. + using std::swap; + swap(elements_.front(), elements_.back()); + } + if (elements_.size() == limit_ + 1) { + // Transition from unsorted vector to a heap. + std::make_heap(elements_.begin(), elements_.end(), cmp_); + if (dropped) *dropped = std::move(elements_.front()); + std::pop_heap(elements_.begin(), elements_.end(), cmp_); + state_ = HEAP_SORTED; + } + } else { + // Only insert the new element if it is greater than the least element. + if (cmp_(v, elements_.front())) { + elements_.back() = std::forward(v); // NOLINT(build/c++11) + std::push_heap(elements_.begin(), elements_.end(), cmp_); + if (dropped) *dropped = std::move(elements_.front()); + std::pop_heap(elements_.begin(), elements_.end(), cmp_); + } else { + if (dropped) *dropped = std::forward(v); // NOLINT(build/c++11) + } + } +} + +template +const T &TopN::peek_bottom() { + TFLITE_DCHECK(!empty()); + if (state_ == UNORDERED) { + // We need to do a linear scan to find out the bottom element + int min_candidate = 0; + for (size_t i = 1; i < elements_.size(); ++i) { + if (cmp_(elements_[min_candidate], elements_[i])) { + min_candidate = i; + } + } + // By swapping the element at position 0 and the minimal + // element, we transition to the BOTTOM_KNOWN state + if (min_candidate != 0) { + using std::swap; + swap(elements_[0], elements_[min_candidate]); + } + state_ = BOTTOM_KNOWN; + } + return elements_.front(); +} + +template +std::vector *TopN::Extract() { + auto out = new std::vector; + out->swap(elements_); + if (state_ != HEAP_SORTED) { + std::sort(out->begin(), out->end(), cmp_); + } else { + out->pop_back(); + std::sort_heap(out->begin(), out->end(), cmp_); + } + return out; +} + +template +std::vector *TopN::ExtractUnsorted() { + auto out = new std::vector; + out->swap(elements_); + if (state_ == HEAP_SORTED) { + // Remove the limit_+1'th element. + out->pop_back(); + } + return out; +} + +template +std::vector *TopN::ExtractNondestructive() const { + auto out = new std::vector; + ExtractNondestructive(out); + return out; +} + +template +void TopN::ExtractNondestructive(std::vector *output) const { + TFLITE_DCHECK(output); + *output = elements_; + if (state_ != HEAP_SORTED) { + std::sort(output->begin(), output->end(), cmp_); + } else { + output->pop_back(); + std::sort_heap(output->begin(), output->end(), cmp_); + } +} + +template +std::vector *TopN::ExtractUnsortedNondestructive() const { + auto elements = new std::vector; + ExtractUnsortedNondestructive(elements); + return elements; +} + +template +void TopN::ExtractUnsortedNondestructive(std::vector *output) const { + TFLITE_DCHECK(output); + *output = elements_; + if (state_ == HEAP_SORTED) { + // Remove the limit_+1'th element. + output->pop_back(); + } +} + +template +void TopN::Reset() { + elements_.clear(); + state_ = UNORDERED; +} + +} // namespace gtl +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_KERNELS_TOP_N_H_ diff --git a/tensorflow/contrib/lite/g3doc/README.md b/tensorflow/contrib/lite/g3doc/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e3db4784815b7562588d3afbd34f837b101f0977 --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/README.md @@ -0,0 +1,4 @@ +This is a *work-in-progress* TF Lite subsite for: +https://www.tensorflow.org/mobile + +DO NOT PUBLISH diff --git a/tensorflow/contrib/lite/g3doc/_book.yaml b/tensorflow/contrib/lite/g3doc/_book.yaml new file mode 100644 index 0000000000000000000000000000000000000000..98abd5743b2412399496f2fb3a70cd25d8597bca --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/_book.yaml @@ -0,0 +1,58 @@ +upper_tabs: +# Tabs left of dropdown menu +- include: /_upper_tabs_left.yaml +# Dropdown menu +- name: Ecosystem + path: /ecosystem + is_default: True + menu: + - include: /ecosystem/_menu_toc.yaml + lower_tabs: + # Subsite tabs + other: + - name: Guide + contents: + - title: Overview + path: /mobile/overview + - title: Developer Guide + path: /mobile/devguide + - title: Android Demo App + path: /mobile/demo_android + - title: iOS Demo App + path: /mobile/demo_ios + - title: Performance + path: /mobile/performance + - break: True + - title: TensorFlow Lite APIs + path: /mobile/apis + - title: Custom operators + path: /mobile/custom_operators + - title: TensorFlow Lite Ops Versioning + path: /mobile/ops_versioning + - title: TensorFlow Lite Compatibility Guide + path: /mobile/tf_ops_compatibility + - title: List of Hosted Models + path: /mobile/models + - title: TensorFlow Lite for iOS + path: /mobile/ios + - title: TensorFlow Lite for Raspberry Pi + path: /mobile/rpi + + - heading: TF Mobile + status: deprecated + - title: Overview + path: /mobile/tfmobile/ + - title: Building TensorFlow on Android + path: /mobile/tfmobile/android_build + - title: Building TensorFlow on IOS + path: /mobile/tfmobile/ios_build + - title: Integrating TensorFlow libraries + path: /mobile/tfmobile/linking_libs + - title: Preparing models for mobile deployment + path: /mobile/tfmobile/prepare_models + - title: Optimizing for mobile + path: /mobile/tfmobile/optimizing + + - name: API + contents: + - include: /mobile/api_docs/python/_toc.yaml diff --git a/tensorflow/contrib/lite/g3doc/_index.yaml b/tensorflow/contrib/lite/g3doc/_index.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9119e49117ffbda268f36324072d30ffd83c9e6c --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/_index.yaml @@ -0,0 +1,67 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml +description: +landing_page: + rows: + - heading: TensorFlow Lite is a lightweight solution for mobile and embedded devices. + items: + - description: > + TensorFlow Lite is TensorFlow’s lightweight solution for mobile and + embedded devices. It enables on-device machine learning inference with + low latency and a small binary size. TensorFlow Lite also supports + hardware acceleration with the + Android Neural Networks API. + list: + - heading: Key point 1 + description: > + [high-level overview] + icon: + icon_name: chevron_right + foreground: theme + background: grey + - heading: Key point 2 + description: > + [high-level overview] + icon: + icon_name: chevron_right + foreground: theme + background: grey + - heading: Key point 3 + description: > + [high-level overview] + icon: + icon_name: chevron_right + foreground: theme + background: grey + - code_block: | +
+        $ toco --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \
+               --input_format=TENSORFLOW_GRAPHDEF \
+               --output_format=TFLITE \
+               --output_file=/tmp/mobilenet_v1_1.0_224.tflite \
+               --inference_type=FLOAT \
+               --input_type=FLOAT \
+               --input_arrays=input \
+               --output_arrays=MobilenetV1/Predictions/Reshape_1 \
+               --input_shapes=1,224,224,3
+        
+ + - classname: devsite-landing-row-cards + items: + - heading: Using TensorFlow Lite on Android + image_path: /ecosystem/images/tf-logo-card-16x9.png + path: https://medium.com/tensorflow/using-tensorflow-lite-on-android-9bbc9cb7d69d + buttons: + - label: Read on TensorFlow blog + path: https://medium.com/tensorflow/using-tensorflow-lite-on-android-9bbc9cb7d69d + - heading: TensorFlow Lite at the Dev Summit + youtube_id: FAMfy7izB6A + buttons: + - label: Watch the video + path: https://www.youtube.com/watch?v=FAMfy7izB6A + - heading: TensorFlow Lite on GitHub + image_path: /ecosystem/images/github-card-16x9.png + path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite + buttons: + - label: View on GitHub + path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite diff --git a/tensorflow/contrib/lite/g3doc/_project.yaml b/tensorflow/contrib/lite/g3doc/_project.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b39666516baab42d289e4d40077c2877ed65d396 --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/_project.yaml @@ -0,0 +1,10 @@ +name: TensorFlow Lite +breadcrumb_name: Mobile +home_url: /mobile/ +parent_project_metadata_path: /_project.yaml +description: > + TensorFlow Lite is a lightweight solution for mobile and embedded devices. +use_site_branding: True +hide_from_products_list: True +content_license: cc3-apache2 +buganizer_id: 316308 diff --git a/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml b/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1e1c44c6929571144d8cf0b54463c48e37466022 --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml @@ -0,0 +1,6 @@ +# Automatically generated file; please do not edit +toc: + - title: TensorFlow Lite + section: + - title: Overview + path: /mobile/api_docs/python/ diff --git a/tensorflow/contrib/lite/g3doc/api_docs/python/index.md b/tensorflow/contrib/lite/g3doc/api_docs/python/index.md new file mode 100644 index 0000000000000000000000000000000000000000..70031a3c3d26eb6557014879cc92288cd22331eb --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/api_docs/python/index.md @@ -0,0 +1,10 @@ +Project: /mobile/_project.yaml +Book: /mobile/_book.yaml +page_type: reference + + + + +# All symbols in TensorFlow Lite + +TEMP PAGE diff --git a/tensorflow/contrib/lite/g3doc/apis.md b/tensorflow/contrib/lite/g3doc/apis.md index a591a353dd8f0ac94ecaa3f12e1aa1c57566ef69..776803da8c7126c6198e3740448888119df030b9 100644 --- a/tensorflow/contrib/lite/g3doc/apis.md +++ b/tensorflow/contrib/lite/g3doc/apis.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # TensorFlow Lite APIs TensorFlow Lite provides programming APIs in C++ and Java, and in both cases @@ -53,6 +56,7 @@ typedef enum { ``` Failures can be easily verified with: + ```c++ if (status != kTfLiteOk) { // ... error handling here ... diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/contrib/lite/g3doc/custom_operators.md index 972e57f73e82961ebc5e341dd7a41bc00acc5d21..d979353bb3550fe53d86b2e6c76702a3970b01fe 100644 --- a/tensorflow/contrib/lite/g3doc/custom_operators.md +++ b/tensorflow/contrib/lite/g3doc/custom_operators.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # How to use custom operators TensorFlow Lite currently supports a subset of TensorFlow operators. However, it @@ -89,3 +92,83 @@ builtins.AddCustom("Sin", Register_SIN()); Note that a similar process as above can be followed for supporting for a set of operations instead of a single operator. + +## Best Practices for writing custom operators + +1. Optimize memory allocations and de-allocations cautiously. It is more + efficient to allocate memory in Prepare() instead of Invoke(), and allocate + memory before a loop instead of in every iteration. Use temporary tensors + data rather than mallocing yourself (see item 2). Use pointers/references + instead of copying as much as possible. + +2. If a data structure will persist during the entire operation, we advise + pre-allocating the memory using temporary tensors. You may need to use + OpData struct to reference the tensor indices in other functions. See + example in the + [kernel for convolution](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/kernels/conv.cc). + A sample code snippet is below + + ``` + auto* op_data = reinterpret_cast(node->user_data); + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(1); + node->temporaries->data[0] = op_data->temp_tensor_index; + TfLiteTensor* temp_tensor = &context->tensors[op_data->temp_tensor_index]; + temp_tensor->type = kTfLiteFloat32; + temp_tensor->allocation_type = kTfLiteArenaRw; + ``` + +3. If it doesn't cost too much wasted memory, prefer using a static fixed size + array (or in Resize() pre-allocated std::vector) rather than using a + dynamically allocating std::vector every iteration of execution. + +4. Avoid instantiating standard library container templates that don't already + exist, because they affect binary size. For example, if you need a std::map + in your operation that doesn't exist in other kernels, using a std::vector + with direct indexing mapping could work while keeping the binary size small. + See what other kernels use to gain insight (or ask). + +5. Check the pointer to the memory returned by malloc. If this pointer is + nullptr, no operations should be performed using that pointer. If you + malloc() in a function and have an error exit, deallocate memory before you + exit. + +6. Use TF_LITE_ENSURE(context, condition) to check for a specific condition. + Your code must not leave memory hanging when TF_LITE_ENSURE is done, i.e., + these should be done before any resources are allocated that will leak. + +## Special TF Graph Attributes + +When Toco convertes a TF graph into TFLite format, it makes some assumption +about custom operations that might be not correct. In this case, the generated +graph can be not executable. + +It is possible to add aditional information about your custom op output to TF +graph before it is converted. The following attributes are supported: + +- **_output_quantized** a boolean attribute, true if the operation outputs are + quantized +- **_output_types** a list of types for output tensors +- **_output_shapes** a list of shapes for output tensors + +### Setting the Attributes + +This is an example how the attributes can be set: + +```python +frozen_graph_def = tf.graph_util.convert_variables_to_constants(...) +for node in frozen_graph_def.node: + if node.op == 'sin': + node.attr['_output_types'].list.type.extend([ + types_pb2.DT_FLOAT, + ]) + node.attr['_output_shapes'].list.shape.extend([ + tf.TensorShape([10]), + ]) + node.attr['_output_quantized'].b = False +tflite_model = tf.contrib.lite.toco_convert( + frozen_graph_def,...) +``` + +**Note:** After the attributes are set, the graph can not be executed by +Tensorflow, therefore it should be done just before the conversion. diff --git a/tensorflow/docs_src/mobile/tflite/demo_android.md b/tensorflow/contrib/lite/g3doc/demo_android.md similarity index 98% rename from tensorflow/docs_src/mobile/tflite/demo_android.md rename to tensorflow/contrib/lite/g3doc/demo_android.md index fdf0bcf3c1135f0e702c7dda4d1d608a26169470..d79a2696b4e9cc10480aa67c7eaec5a356eff596 100644 --- a/tensorflow/docs_src/mobile/tflite/demo_android.md +++ b/tensorflow/contrib/lite/g3doc/demo_android.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # Android Demo App An example Android application using TensorFLow Lite is available diff --git a/tensorflow/docs_src/mobile/tflite/demo_ios.md b/tensorflow/contrib/lite/g3doc/demo_ios.md similarity index 97% rename from tensorflow/docs_src/mobile/tflite/demo_ios.md rename to tensorflow/contrib/lite/g3doc/demo_ios.md index 3be21da89f9e53d324c2ade0cb937f4b5b30fad4..a554898899e67a6bc2bc52733f5301767bc1c06a 100644 --- a/tensorflow/docs_src/mobile/tflite/demo_ios.md +++ b/tensorflow/contrib/lite/g3doc/demo_ios.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # iOS Demo App The TensorFlow Lite demo is a camera app that continuously classifies whatever diff --git a/tensorflow/docs_src/mobile/tflite/devguide.md b/tensorflow/contrib/lite/g3doc/devguide.md similarity index 89% rename from tensorflow/docs_src/mobile/tflite/devguide.md rename to tensorflow/contrib/lite/g3doc/devguide.md index 4133bc172a1924f0ce8bb515d66fc03d716923c8..dc9cc98c0821edff57cb9428a50637a15211cfda 100644 --- a/tensorflow/docs_src/mobile/tflite/devguide.md +++ b/tensorflow/contrib/lite/g3doc/devguide.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # Developer Guide Using a TensorFlow Lite model in your mobile app requires multiple @@ -54,10 +57,11 @@ both floating point and quantized inference. ### Train a custom model A developer may choose to train a custom model using Tensorflow (see the -@{$tutorials} for examples of building and training models). If you have already -written a model, the first step is to export this to a @{tf.GraphDef} file. This -is required because some formats do not store the model structure outside the -code, and we must communicate with other parts of the framework. See +[TensorFlow tutorials](../../tutorials/) for examples of building and training +models). If you have already written a model, the first step is to export this +to a `tf.GraphDef` file. This is required because some formats do not store the +model structure outside the code, and we must communicate with other parts of the +framework. See [Exporting the Inference Graph](https://github.com/tensorflow/models/blob/master/research/slim/README.md) to create .pb file for the custom model. @@ -70,12 +74,12 @@ grow in future Tensorflow Lite releases. ## 2. Convert the model format The model generated (or downloaded) in the previous step is a *standard* -Tensorflow model and you should now have a .pb or .pbtxt @{tf.GraphDef} file. +Tensorflow model and you should now have a .pb or .pbtxt `tf.GraphDef` file. Models generated with transfer learning (re-training) or custom models must be converted—but, we must first freeze the graph to convert the model to the Tensorflow Lite format. This process uses several model formats: -* @{tf.GraphDef} (.pb) —A protobuf that represents the TensorFlow training or +* `tf.GraphDef` (.pb) —A protobuf that represents the TensorFlow training or computation graph. It contains operators, tensors, and variables definitions. * *CheckPoint* (.ckpt) —Serialized variables from a TensorFlow graph. Since this does not contain a graph structure, it cannot be interpreted by itself. @@ -142,11 +146,11 @@ containing the model architecture. The [frozen_graph.pb](https://storage.googlea file used here is available for download. `output_file` is where the TensorFlow Lite model will get generated. The `input_type` and `inference_type` arguments should be set to `FLOAT`, unless converting a -@{$performance/quantization$quantized model}. Setting the `input_array`, -`output_array`, and `input_shape` arguments are not as straightforward. The -easiest way to find these values is to explore the graph using Tensorboard. Reuse -the arguments for specifying the output nodes for inference in the -`freeze_graph` step. +quantized model. +Setting the `input_array`, `output_array`, and `input_shape` arguments are not as +straightforward. The easiest way to find these values is to explore the graph +using Tensorboard. Reuse the arguments for specifying the output nodes for +inference in the `freeze_graph` step. It is also possible to use the Tensorflow Optimizing Converter with protobufs from either Python or from the command line (see the @@ -203,16 +207,16 @@ The open source Android demo app uses the JNI interface and is available [on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app). You can also download a [prebuilt APK](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk). -See the @{$tflite/demo_android} guide for details. +See the Android demo guide for details. -The @{$mobile/android_build} guide has instructions for installing TensorFlow on -Android and setting up `bazel` and Android Studio. +The Android mobile guide has instructions for +installing TensorFlow on Android and setting up `bazel` and Android Studio. ### iOS To integrate a TensorFlow model in an iOS app, see the [TensorFlow Lite for iOS](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/ios.md) -guide and @{$tflite/demo_ios} guide. +guide and iOS demo guide. #### Core ML support diff --git a/tensorflow/contrib/lite/g3doc/ios.md b/tensorflow/contrib/lite/g3doc/ios.md index e0358a444d6dffc377bf13ee72ba5477359d6e07..d78d373ccfea074872773693c562253b202a646b 100644 --- a/tensorflow/contrib/lite/g3doc/ios.md +++ b/tensorflow/contrib/lite/g3doc/ios.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # TensorFlow Lite for iOS ## Building diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md index c1c8ef049f693dae038e5e0ca242b9219329cc50..4ceb9a53dc0967ab6320a1bfdb1ddb859482c5dd 100644 --- a/tensorflow/contrib/lite/g3doc/models.md +++ b/tensorflow/contrib/lite/g3doc/models.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # List of Hosted Models ## Image classification (Float Models) @@ -39,22 +42,22 @@ single thread large core. Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance ------------------------ | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------: -Mobilenet_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.9% | 65.8% | 3.7 ms -Mobilenet_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 43.5% | 69.1% | 5.5 ms -Mobilenet_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 45.8% | 71.9% | 7.9 ms -Mobilenet_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 48.2% | 73.8% | 10.4 ms -Mobilenet_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 54.9% | 78.9% | 8.8 ms -Mobilenet_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.7% | 81.3% | 13.0 ms -Mobilenet_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 60.4% | 83.2% | 18.3 ms -Mobilenet_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 62.2% | 84.5% | 24.7 ms -Mobilenet_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 59.8% | 82.8% | 16.2 ms -Mobilenet_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 63.9% | 85.5% | 24.3 ms -Mobilenet_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 66.2% | 87.1% | 33.8 ms -Mobilenet_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 67.9% | 88.1% | 45.4 ms -Mobilenet_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 64.0% | 85.5% | 24.9 ms -Mobilenet_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 67.3% | 87.7% | 37.4 ms -Mobilenet_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.0% | 88.9% | 51.9 ms -Mobilenet_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 69.7% | 89.5% | 70.2 ms +Mobilenet_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.5% | 64.4% | 3.7 ms +Mobilenet_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 43.4% | 68.5% | 5.5 ms +Mobilenet_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 46.0% | 71.2% | 7.9 ms +Mobilenet_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 48.0% | 72.8% | 10.4 ms +Mobilenet_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 54.5% | 77.7% | 8.8 ms +Mobilenet_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.7% | 80.4% | 13.0 ms +Mobilenet_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 60.0% | 82.2% | 18.3 ms +Mobilenet_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 60.7% | 83.2% | 24.7 ms +Mobilenet_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 55.8% | 78.8% | 16.2 ms +Mobilenet_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 62.3% | 83.8% | 24.3 ms +Mobilenet_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 66.1% | 86.4% | 33.8 ms +Mobilenet_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 66.8% | 87.0% | 45.4 ms +Mobilenet_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 63.4% | 84.2% | 24.9 ms +Mobilenet_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 67.2% | 86.7% | 37.4 ms +Mobilenet_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.2% | 88.3% | 51.9 ms +Mobilenet_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 70.1% | 88.9% | 70.2 ms ## Other models diff --git a/tensorflow/contrib/lite/g3doc/ops_versioning.md b/tensorflow/contrib/lite/g3doc/ops_versioning.md index bd2f797e6c5b05f52bec9fc34f1b8011aca70330..b06f4fd3b893e5e5977f92de26109a6dd264531f 100644 --- a/tensorflow/contrib/lite/g3doc/ops_versioning.md +++ b/tensorflow/contrib/lite/g3doc/ops_versioning.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # TensorFlow Lite Ops Versioning This document describes TensorFlow Lite's op versioning schema. Op diff --git a/tensorflow/docs_src/mobile/tflite/index.md b/tensorflow/contrib/lite/g3doc/overview.md similarity index 93% rename from tensorflow/docs_src/mobile/tflite/index.md rename to tensorflow/contrib/lite/g3doc/overview.md index 3d1733024e493042a2cc85aa9f2fec4b75eefa94..be60d7941ade824ee201bfd05400fb3e4e9fae7e 100644 --- a/tensorflow/docs_src/mobile/tflite/index.md +++ b/tensorflow/contrib/lite/g3doc/overview.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # Introduction to TensorFlow Lite TensorFlow Lite is TensorFlow’s lightweight solution for mobile and embedded @@ -70,10 +73,9 @@ There are several factors which are fueling interest in this domain: We believe the next wave of machine learning applications will have significant processing on mobile and embedded devices. -## TensorFlow Lite developer preview highlights +## TensorFlow Lite highlights -TensorFlow Lite is available as a developer preview and includes the -following: +TensorFlow Lite provides: - A set of core operators, both quantized and float, many of which have been tuned for mobile platforms. These can be used to create and run custom @@ -129,9 +131,6 @@ following: - Java and C++ API support -Note: This is a developer release, and it’s likely that there will be changes in -the API in upcoming versions. We do not guarantee backward or forward -compatibility with this release. ## Getting Started @@ -201,9 +200,5 @@ possible performance for a particular model on a particular device. ## Next Steps -For the developer preview, most of our documentation is on GitHub. Please take a -look at the [TensorFlow Lite -repository](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite) -on GitHub for more information and for code samples, demo applications, and -more. - +The TensorFlow Lite [GitHub repository](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite). +contains additional docs, code samples, and demo applications. diff --git a/tensorflow/contrib/lite/g3doc/benchmarks.md b/tensorflow/contrib/lite/g3doc/performance.md similarity index 72% rename from tensorflow/contrib/lite/g3doc/benchmarks.md rename to tensorflow/contrib/lite/g3doc/performance.md index 29b087bea7aab1fcbc87ef764795f01e87b0bf9e..5cd0aab44f10de1b76e1acb302fc1ee2711c8d74 100644 --- a/tensorflow/contrib/lite/g3doc/benchmarks.md +++ b/tensorflow/contrib/lite/g3doc/performance.md @@ -1,27 +1,26 @@ -# Performance Benchmark numbers +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml -This document contains the performance benchmark numbers for running a few well +# Performance + +This document lists TensorFlow Lite performance benchmarks when running well known models on some Android and iOS devices. -The benchmark numbers were generated by running the [TFLite benchmark -binary](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark) -on Android and running the [iOS benchmark -app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios) -on iOS. +These performance benchmark numbers were generated with the +[Android TFLite benchmark binary](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark) +and the [iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios). -# Android benchmarks +# Android performance benchmarks -When running Android benchmarks, the CPU affinity is set to use big cores on the -device to reduce variance (see -[details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#reducing-variance-between-runs-on-android)). +For Android benchmarks, the CPU affinity is set to use big cores on the device to +reduce variance (see [details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#reducing-variance-between-runs-on-android)). -Models are assumed to have been downloaded from the link, unzipped and pushed to -`/data/local/tmp/tflite_models` folder. The benchmark binary is built according -to instructions listed -[here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#on-android). -and is assumed to have been pushed to `/data/local/tmp`. +It assumes that models were download and unzipped to the +`/data/local/tmp/tflite_models` directory. The benchmark binary is built +using [these instructions](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#on-android) +and assumed in the `/data/local/tmp` directory. -The following command was used to run the benchmark: +To run the benchmark: ``` adb shell taskset ${CPU_MASK} /data/local/tmp/benchmark_model \ @@ -32,7 +31,7 @@ adb shell taskset ${CPU_MASK} /data/local/tmp/benchmark_model \ --use_nnapi=false ``` -where `${GRAPH}` is the name of model and `${CPU_MASK}` is the CPU affinity +Here, `${GRAPH}` is the name of model and `${CPU_MASK}` is the CPU affinity chosen according to the following table: Device | CPU_MASK | @@ -40,7 +39,6 @@ Device | CPU_MASK | Pixel 2 | f0 | Pixel xl | 0c | - @@ -51,7 +49,7 @@ Pixel xl | 0c | @@ -62,7 +60,7 @@ Pixel xl | 0c | @@ -120,7 +118,7 @@ Pixel xl | 0c | # iOS benchmarks -For running iOS benchmarks, the [benchmark +To run iOS benchmarks, the [benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios) was modified to include the appropriate model and `benchmark_params.json` was modified to set `num_threads` to 1. @@ -135,14 +133,14 @@ modified to set `num_threads` to 1. diff --git a/tensorflow/contrib/lite/g3doc/rpi.md b/tensorflow/contrib/lite/g3doc/rpi.md index ab50789307414255bccd84d4cfcb6ddecc25ba08..cdc9172d873bfd32811ca69901ed2e4eedf902a3 100644 --- a/tensorflow/contrib/lite/g3doc/rpi.md +++ b/tensorflow/contrib/lite/g3doc/rpi.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # TensorFlow Lite for Raspberry Pi ## Cross compiling diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index dcd17bbeabda08eaf86f8d5ac7f26cea0d3719a3..aa65ec99887a61df658dd7add7b5cc3b91d81846 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # TensorFlow Lite & TensorFlow Compatibility Guide TensorFlow Lite supports a number of TensorFlow operations used in common @@ -42,6 +45,7 @@ counterparts: *as long as the input tensor is 4D (1 batch + 2 spatial + 1 other) and the crops attribute is not used* * [tf.exp](https://www.tensorflow.org/api_docs/python/tf/exp) +* [tf.fake_quant*](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_args) * [tf.matmul](https://www.tensorflow.org/api_docs/python/tf/matmul) - *as long as the second argument is constant and transposition is not used* * [tf.nn.avg_pool](https://www.tensorflow.org/api_docs/python/tf/nn/avg_pool) @@ -58,6 +62,7 @@ counterparts: * [tf.nn.softmax](https://www.tensorflow.org/api_docs/python/tf/nn/softmax) - *as long as tensors are 2D and axis is the last dimension* * [tf.nn.top_k](https://www.tensorflow.org/api_docs/python/tf/nn/top_k) +* [tf.one_hot](https://www.tensorflow.org/api_docs/python/tf/one_hot) * [tf.pad](https://www.tensorflow.org/api_docs/python/tf/pad) - *as long as mode and constant_values are not used* * [tf.reduce_mean](https://www.tensorflow.org/api_docs/python/tf/reduce_mean) - @@ -790,6 +795,54 @@ Outputs { } ``` +**ARG_MAX** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: A tensor of indices of maximum values. +} +``` + +**ARG_MIN** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: A tensor of indices of minium values. +} +``` + +**PACK** + +``` +Inputs { + 0: a list of tensors. + 1: an integer. +} +Outputs { + 0: A tensor of stacked tensors. +} +``` + +**LOGICAL_OR** + +``` +Inputs { + 0: a list of tensors. + 1: a list of tensors. +} +Outputs { + 0: A tensor of logical_or output tensors. +} +``` + And these are TensorFlow Lite operations that are present but not ready for custom models yet: diff --git a/tensorflow/docs_src/mobile/android_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md similarity index 97% rename from tensorflow/docs_src/mobile/android_build.md rename to tensorflow/contrib/lite/g3doc/tfmobile/android_build.md index f4b07db4591dddcfbf3633f471072f4a0eea9843..76e16fc9db27782fe0f9454ba463722f4bf6eb4b 100644 --- a/tensorflow/docs_src/mobile/android_build.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # Building TensorFlow on Android To get you started working with TensorFlow on Android, we'll walk through two @@ -91,7 +94,8 @@ using [ADB](https://developer.android.com/studio/command-line/adb.html). This requires some knowledge of build systems and Android developer tools, but we'll guide you through the basics here. -- First, follow our instructions for @{$install/install_sources$installing from sources}. +- First, follow our instructions for + installing from sources. This will also guide you through installing Bazel and cloning the TensorFlow code. diff --git a/tensorflow/docs_src/mobile/mobile_intro.md b/tensorflow/contrib/lite/g3doc/tfmobile/index.md similarity index 85% rename from tensorflow/docs_src/mobile/mobile_intro.md rename to tensorflow/contrib/lite/g3doc/tfmobile/index.md index 241f01d460ae35e818a61be4c4914b3bd8dae00a..bd047bfceceddfd0b5a9fd0c83cb47a339299abf 100644 --- a/tensorflow/docs_src/mobile/mobile_intro.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/index.md @@ -1,4 +1,45 @@ -# Introduction to TensorFlow Mobile +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + +# Overview + +TensorFlow was designed to be a good deep learning solution for mobile +platforms. Currently we have two solutions for deploying machine learning +applications on mobile and embedded devices: TensorFlow for Mobile and +TensorFlow Lite. + +## TensorFlow Lite versus TensorFlow Mobile + +Here are a few of the differences between the two: + +- TensorFlow Lite is an evolution of TensorFlow Mobile. In most cases, apps + developed with TensorFlow Lite will have a smaller binary size, fewer + dependencies, and better performance. + +- TensorFlow Lite is in developer preview, so not all use cases are covered yet. + We expect you to use TensorFlow Mobile to cover production cases. + +- TensorFlow Lite supports only a limited set of operators, so not all models + will work on it by default. TensorFlow for Mobile has a fuller set of + supported functionality. + +TensorFlow Lite provides better performance and a small binary size on mobile +platforms as well as the ability to leverage hardware acceleration if available +on their platforms. In addition, it has many fewer dependencies so it can be +built and hosted on simpler, more constrained device scenarios. TensorFlow Lite +also allows targeting accelerators through the [Neural Networks +API](https://developer.android.com/ndk/guides/neuralnetworks/index.html). + +TensorFlow Lite currently has coverage for a limited set of operators. While +TensorFlow for Mobile supports only a constrained set of ops by default, in +principle if you use an arbitrary operator in TensorFlow, it can be customized +to build that kernel. Thus use cases which are not currently supported by +TensorFlow Lite should continue to use TensorFlow for Mobile. As TensorFlow Lite +evolves, it will gain additional operators, and the decision will be easier to +make. + + +## Introduction to TensorFlow Mobile TensorFlow was designed from the ground up to be a good deep learning solution for mobile platforms like Android and iOS. This mobile guide should help you @@ -38,7 +79,8 @@ speech-driven interface, and many of these require on-device processing. Most of the time a user isn’t giving commands, and so streaming audio continuously to a remote server would be a waste of bandwidth, since it would mostly be silence or background noises. To solve this problem it’s common to have a small neural -network running on-device @{$tutorials/audio_recognition$listening out for a particular keyword}. +network running on-device +[listening out for a particular keyword](../tutorials/sequences/audio_recognition). Once that keyword has been spotted, the rest of the conversation can be transmitted over to the server for further processing if more computing power is needed. @@ -166,7 +208,7 @@ interesting products possible. TensorFlow runs on Ubuntu Linux, Windows 10, and OS X. For a list of all supported operating systems and instructions to install TensorFlow, see -@{$install$Installing Tensorflow}. +Installing Tensorflow. Note that some of the sample code we provide for mobile TensorFlow requires you to compile TensorFlow from source, so you’ll need more than just `pip install` @@ -240,8 +282,3 @@ results you’ll see. It’s common for an algorithm to get great training accur numbers but then fail to be useful within a real application because there’s a mismatch between the dataset and real usage. Prototype end-to-end usage as soon as possible to create a consistent user experience. - -## Next Steps - -We suggest you get started by building one of our demos for -@{$mobile/android_build$Android} or @{$mobile/ios_build$iOS}. diff --git a/tensorflow/docs_src/mobile/ios_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md similarity index 98% rename from tensorflow/docs_src/mobile/ios_build.md rename to tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md index 4c84a1214a26eeb90c1b6a186a369212377b06cd..6223707892ce7b288ecabf932b33cd39860446a6 100644 --- a/tensorflow/docs_src/mobile/ios_build.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # Building TensorFlow on iOS ## Using CocoaPods diff --git a/tensorflow/docs_src/mobile/linking_libs.md b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md similarity index 83% rename from tensorflow/docs_src/mobile/linking_libs.md rename to tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md index efef5dd0daa0b267d8384d32d62d9ce0226dc102..4c2071ed053125cfa643ed785fe302198f734ead 100644 --- a/tensorflow/docs_src/mobile/linking_libs.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # Integrating TensorFlow libraries Once you have made some progress on a model that addresses the problem you’re @@ -14,11 +17,11 @@ TensorFlow mobile demo apps. After you've managed to build the examples, you'll probably want to call TensorFlow from one of your existing applications. The very easiest way to do -this is to use the Pod installation steps described -@{$mobile/ios_build#using_cocoapods$here}, but if you want to build TensorFlow -from source (for example to customize which operators are included) you'll need -to break out TensorFlow as a framework, include the right header files, and link -against the built libraries and dependencies. +this is to use the Pod installation steps described in +Building TensorFlow on iOS, but if you want to build +TensorFlow from source (for example to customize which operators are included) +you'll need to break out TensorFlow as a framework, include the right header +files, and link against the built libraries and dependencies. ### Android @@ -82,10 +85,12 @@ recompile of the core. To achieve this capability, TensorFlow uses a registration pattern in a lot of places. In the code, it looks like this: - class MulKernel : OpKernel { - Status Compute(OpKernelContext* context) { … } - }; - REGISTER_KERNEL(MulKernel, “Mul”); +``` +class MulKernel : OpKernel { + Status Compute(OpKernelContext* context) { … } +}; +REGISTER_KERNEL(MulKernel, “Mul”); +``` This would be in a standalone `.cc` file linked into your application, either as part of the main set of kernels or as a separate custom library. The magic @@ -101,15 +106,17 @@ doesn’t offer a good mechanism for doing this sort of registration, so we have to resort to some tricky code. Under the hood, the macro is implemented so that it produces something like this: - class RegisterMul { - public: - RegisterMul() { - global_kernel_registry()->Register(“Mul”, [](){ - return new MulKernel() - }); - } - }; - RegisterMul g_register_mul; +``` +class RegisterMul { + public: + RegisterMul() { + global_kernel_registry()->Register(“Mul”, [](){ + return new MulKernel() + }); + } +}; +RegisterMul g_register_mul; +``` This sets up a class `RegisterMul` with a constructor that tells the global kernel registry what function to call when somebody asks it how to create a @@ -176,8 +183,10 @@ have an experimental script at [rename_protobuf.sh](https://github.com/tensorflo You need to run this as part of the makefile build, after you’ve downloaded all the dependencies: - tensorflow/contrib/makefile/download_dependencies.sh - tensorflow/contrib/makefile/rename_protobuf.sh +``` +tensorflow/contrib/makefile/download_dependencies.sh +tensorflow/contrib/makefile/rename_protobuf.sh +``` ## Calling the TensorFlow API @@ -193,18 +202,20 @@ use case, while on iOS and Raspberry Pi you call directly into the C++ API. Here’s what a typical Inference Library sequence looks like on Android: - // Load the model from disk. - TensorFlowInferenceInterface inferenceInterface = - new TensorFlowInferenceInterface(assetManager, modelFilename); +``` +// Load the model from disk. +TensorFlowInferenceInterface inferenceInterface = +new TensorFlowInferenceInterface(assetManager, modelFilename); - // Copy the input data into TensorFlow. - inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3); +// Copy the input data into TensorFlow. +inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3); - // Run the inference call. - inferenceInterface.run(outputNames, logStats); +// Run the inference call. +inferenceInterface.run(outputNames, logStats); - // Copy the output Tensor back into the output array. - inferenceInterface.fetch(outputName, outputs); +// Copy the output Tensor back into the output array. +inferenceInterface.fetch(outputName, outputs); +``` You can find the source of this code in the [Android examples](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java#L107). @@ -212,27 +223,29 @@ You can find the source of this code in the [Android examples](https://github.co Here’s the equivalent code for iOS and Raspberry Pi: - // Load the model. - PortableReadFileToProto(file_path, &tensorflow_graph); - - // Create a session from the model. - tensorflow::Status s = session->Create(tensorflow_graph); - if (!s.ok()) { - LOG(FATAL) << "Could not create TensorFlow Graph: " << s; - } - - // Run the model. - std::string input_layer = "input"; - std::string output_layer = "output"; - std::vector outputs; - tensorflow::Status run_status = session->Run({{input_layer, image_tensor}}, +``` +// Load the model. +PortableReadFileToProto(file_path, &tensorflow_graph); + +// Create a session from the model. +tensorflow::Status s = session->Create(tensorflow_graph); +if (!s.ok()) { + LOG(FATAL) << "Could not create TensorFlow Graph: " << s; +} + +// Run the model. +std::string input_layer = "input"; +std::string output_layer = "output"; +std::vector outputs; +tensorflow::Status run_status = session->Run({\{input_layer, image_tensor}}, {output_layer}, {}, &outputs); - if (!run_status.ok()) { - LOG(FATAL) << "Running model failed: " << run_status; - } +if (!run_status.ok()) { + LOG(FATAL) << "Running model failed: " << run_status; +} - // Access the output data. - tensorflow::Tensor* output = &outputs[0]; +// Access the output data. +tensorflow::Tensor* output = &outputs[0]; +``` This is all based on the [iOS sample code](https://www.tensorflow.org/code/tensorflow/examples/ios/simple/RunModelViewController.mm), diff --git a/tensorflow/docs_src/mobile/optimizing.md b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md similarity index 98% rename from tensorflow/docs_src/mobile/optimizing.md rename to tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md index 778e4d3a6233c3bec70b830bc998013745a1f0ba..a0192c3541483437b817e22eb92193bd7bcb4c28 100644 --- a/tensorflow/docs_src/mobile/optimizing.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # Optimizing for mobile There are some special issues that you have to deal with when you’re trying to @@ -77,7 +80,7 @@ out of a mobile device's memory faster. To understand how large your network will be on disk, start by looking at the size on disk of your `GraphDef` file after you’ve run `freeze_graph` and -`strip_unused_nodes` on it (see @{$mobile/prepare_models$Preparing models} for +`strip_unused_nodes` on it (see Preparing models for more details on these tools), since then it should only contain inference-related nodes. To double-check that your results are as expected, run the `summarize_graph` tool to see how many parameters are in constants: @@ -103,7 +106,8 @@ you multiply the number of const parameters by four, you should get something that’s close to the size of the file on disk. You can often get away with only eight-bits per parameter with very little loss of accuracy in the final result, so if your file size is too large you can try using -@{$performance/quantization$quantize_weights} to transform the parameters down. +quantize_weights +to transform the parameters down. bazel build tensorflow/tools/graph_transforms:transform_graph && \ bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ @@ -292,7 +296,8 @@ run it on a 64-bit ARM device: You can interpret the results in exactly the same way as the desktop version above. If you have any trouble figuring out what the right input and output -names and types are, take a look at the @{$mobile/prepare_models$Preparing models} +names and types are, take a look at the +Preparing models page for details about detecting these for your model, and look at the `summarize_graph` tool which may give you helpful information. diff --git a/tensorflow/docs_src/mobile/prepare_models.md b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md similarity index 98% rename from tensorflow/docs_src/mobile/prepare_models.md rename to tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md index 2b84dbb97388b16c6a4ae1d3472e0b1a993285f0..6b4e4a92bd9262139be3cf650b7d16714ee3a277 100644 --- a/tensorflow/docs_src/mobile/prepare_models.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md @@ -1,3 +1,6 @@ +book_path: /mobile/_book.yaml +project_path: /mobile/_project.yaml + # Preparing models for mobile deployment The requirements for storing model information during training are very @@ -255,8 +258,8 @@ The criteria for including ops and types fall into several categories: These ops are trimmed by default to optimize for inference on mobile, but it is possible to alter some build files to change the default. After alternating the build files, you will need to recompile TensorFlow. See below for more details -on how to do this, and also see @{$mobile/optimizing#binary_size$Optimizing} for -more on reducing your binary size. +on how to do this, and also see optimizing binary size +for more on reducing your binary size. ### Locate the implementation diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 62a0b1ff0817d25bc8d4caaedf96d27c141b85ef..7a680f5c6400a94a2746d09891e0e39a410404a2 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -22,10 +22,9 @@ limitations under the License. #include "tensorflow/contrib/lite/arena_planner.h" #include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/context_util.h" #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/graph_info.h" -#include "tensorflow/contrib/lite/kernels/eigen_support.h" -#include "tensorflow/contrib/lite/kernels/gemm_support.h" #include "tensorflow/contrib/lite/memory_planner.h" #include "tensorflow/contrib/lite/nnapi_delegate.h" #include "tensorflow/contrib/lite/profiling/profiler.h" @@ -33,9 +32,21 @@ limitations under the License. #include "tensorflow/contrib/lite/util.h" namespace tflite { - namespace { +TfLiteStatus ReportOpError(TfLiteContext* context, const TfLiteNode& node, + const TfLiteRegistration& registration, + int node_index, const char* message) { + context->ReportError( + context, "Node number %d (%s) %s.\n", node_index, + registration.custom_name + ? registration.custom_name + : EnumNameBuiltinOperator( + static_cast(registration.builtin_code)), + message); + return kTfLiteError; +} + // Stub method which returns kTfLiteError when the function is forbidden. // We're registrating this function to several different function to save // compiled binary size. Please note the restrictions: @@ -53,6 +64,19 @@ void SetForbiddenContextFunction(FunctionType* func) { *func = reinterpret_cast(ForbiddenContextFunction); } +// Returns true if at least one tensor in the given list is kTfLiteDynamic. +template +bool HasDynamicTensorImpl(const TfLiteContext& context, + const TensorIntArray& int_array) { + for (int i : int_array) { + const TfLiteTensor& tensor = context.tensors[i]; + if (tensor.allocation_type == kTfLiteDynamic) { + return true; + } + } + return false; +} + } // namespace // A trivial implementation of GraphInfo around the Interpreter. @@ -99,19 +123,22 @@ Interpreter::Interpreter(ErrorReporter* error_reporter) context_.AddTensors = AddTensors; context_.tensors = nullptr; context_.tensors_size = 0; - context_.eigen_context = nullptr; - context_.gemm_context = nullptr; context_.recommended_num_threads = -1; + context_.GetExternalContext = GetExternalContext; + context_.SetExternalContext = SetExternalContext; // Invalid to call these these except from TfLiteDelegate - SetForbiddenContextFunction(&context_.GetNodeAndRegistration); - SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels); - SetForbiddenContextFunction(&context_.GetExecutionPlan); + SwitchToKernelContext(); // Reserve some space for the tensors to avoid excessive resizing. tensors_.reserve(kTensorsReservedCapacity); nodes_and_registration_.reserve(kTensorsReservedCapacity); next_execution_plan_index_to_prepare_ = 0; + + for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) { + external_contexts_[i] = nullptr; + } + UseNNAPI(false); } @@ -246,8 +273,9 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( int node_index; TfLiteDelegateParams* params = CreateDelegateParams(delegate, subgraph); - AddNodeWithParameters(subgraph.input_tensors, subgraph.output_tensors, - nullptr, 0, params, ®istration, &node_index); + TF_LITE_ENSURE_STATUS(AddNodeWithParameters( + subgraph.input_tensors, subgraph.output_tensors, nullptr, 0, params, + ®istration, &node_index)); // Initialize the output tensors's delegate-related fields. for (int tensor_index : subgraph.output_tensors) { @@ -269,6 +297,33 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( return kTfLiteOk; } +TfLiteExternalContext* Interpreter::GetExternalContext( + TfLiteExternalContextType type) { + if (type >= 0 && type < kTfLiteMaxExternalContexts) { + return external_contexts_[type]; + } + return nullptr; +} + +TfLiteExternalContext* Interpreter::GetExternalContext( + struct TfLiteContext* context, TfLiteExternalContextType type) { + return static_cast(context->impl_)->GetExternalContext(type); +} + +void Interpreter::SetExternalContext(TfLiteExternalContextType type, + TfLiteExternalContext* ctx) { + if (type >= 0 && type < kTfLiteMaxExternalContexts) { + external_contexts_[type] = ctx; + } +} + +void Interpreter::SetExternalContext(struct TfLiteContext* context, + TfLiteExternalContextType type, + TfLiteExternalContext* ctx) { + return static_cast(context->impl_) + ->SetExternalContext(type, ctx); +} + // Gets an TfLiteIntArray* representing the execution plan. The interpreter owns // this memory and it is only guaranteed to exist during the invocation of the // delegate prepare. @@ -372,23 +427,33 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, } TfLiteStatus Interpreter::AllocateTensors() { - next_execution_plan_index_to_prepare_ = 0; - if (memory_planner_) { - TF_LITE_ENSURE_STATUS(memory_planner_->ResetAllocations()); - } - if (!consistent_) { ReportError(&context_, "AllocateTensors() called on inconsistent model."); return kTfLiteError; } - TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors()); + // Explicit (re)allocation is necessary if nodes have been changed or tensors + // have been resized. For inputs marked as dynamic, we can't short-circuit the + // allocation as the client may have done the resize manually. + if (state_ != kStateUninvokable && !HasDynamicTensorImpl(context_, inputs_)) { + return kTfLiteOk; + } - if (state_ == kStateUninvokable) { - state_ = kStateInvokable; + next_execution_plan_index_to_prepare_ = 0; + if (memory_planner_) { + TF_LITE_ENSURE_STATUS(memory_planner_->ResetAllocations()); } - TF_LITE_ENSURE(&context_, state_ == kStateInvokable || - state_ == kStateInvokableAndImmutable); + + TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors()); + + state_ = kStateInvokable; + + // Reset the variable tensors to zero after (re)allocating the tensors. + // Developers shouldn't rely on the side effect of this function to reset + // variable tesnsors. They should call `ResetVariableTensorsToZero` directly + // instead. + ResetVariableTensorsToZero(); + return kTfLiteOk; } @@ -481,26 +546,26 @@ TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index, "ResizeInputTensor is disallowed when graph is immutable."); return kTfLiteError; } - state_ = kStateUninvokable; // TODO(aselle): All bounds checks can be implemented as one-sided bounds // checks by casting to unsigned for efficiency. Profile before doing this. TF_LITE_ENSURE(&context_, tensor_index < context_.tensors_size && tensor_index >= 0); - TfLiteIntArray* dims_lite = ConvertVectorToTfLiteIntArray(dims); - return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite); + TfLiteTensor* tensor = &context_.tensors[tensor_index]; + + // Short-circuit the state change if the dimensions don't change, avoiding + // unnecessary (re)allocations. + if (EqualArrayAndTfLiteIntArray(tensor->dims, dims.size(), dims.data())) { + return kTfLiteOk; + } + + state_ = kStateUninvokable; + return ResizeTensorImpl(tensor, ConvertVectorToTfLiteIntArray(dims)); } -// Returns true if at least one tensor in the given list is kTfLiteDynamic. bool HasDynamicTensor(const TfLiteContext& context, - const TfLiteIntArray* tensors) { - for (int i = 0; i < tensors->size; ++i) { - const TfLiteTensor& tensor = context.tensors[tensors->data[i]]; - if (tensor.allocation_type == kTfLiteDynamic) { - return true; - } - } - return false; + const TfLiteIntArray* int_array) { + return HasDynamicTensorImpl(context, TfLiteIntArrayView{int_array}); } TfLiteStatus Interpreter::PrepareOpsStartingAt( @@ -513,7 +578,8 @@ TfLiteStatus Interpreter::PrepareOpsStartingAt( nodes_and_registration_[node_index].second; EnsureTensorsVectorCapacity(); if (OpPrepare(registration, &node) == kTfLiteError) { - return kTfLiteError; + return ReportOpError(&context_, node, registration, node_index, + "failed to prepare"); } *last_execution_plan_index_prepared = execution_plan_index; @@ -531,7 +597,8 @@ TfLiteStatus Interpreter::PrepareOpsStartingAt( TfLiteStatus Interpreter::PrepareOpsAndTensors() { if (!memory_planner_) { memory_planner_.reset(new ArenaPlanner( - &context_, std::unique_ptr(new InterpreterInfo(this)))); + &context_, std::unique_ptr(new InterpreterInfo(this)), + /*preserve_inputs=*/true, /*preserve_intermediates*/ false)); memory_planner_->PlanAllocations(); } @@ -610,7 +677,8 @@ TfLiteStatus Interpreter::Invoke() { EnsureTensorsVectorCapacity(); tensor_resized_since_op_invoke_ = false; if (OpInvoke(registration, &node) == kTfLiteError) { - status = kTfLiteError; + status = ReportOpError(&context_, node, registration, node_index, + "failed to invoke"); } // Force execution prep for downstream ops if the latest op triggered the @@ -829,7 +897,7 @@ void Interpreter::UseNNAPI(bool enable) { // TODO(aselle): This is a workaround for finding if NNAPI exists. // We also need to make sure getLibraryHandle() is renamed to be NNAPI // prefixed. - if (!NNAPIExists()) enable = false; + if (!NNAPIDelegate::IsSupported()) enable = false; if (!enable) { nnapi_delegate_.reset(); } else if (!nnapi_delegate_) { @@ -840,10 +908,25 @@ void Interpreter::UseNNAPI(bool enable) { void Interpreter::SetNumThreads(int num_threads) { context_.recommended_num_threads = num_threads; - // TODO(ahentz): find a way to avoid this. It causes gemmlowp and eigen to - // be required in order to compile the framework. - gemm_support::SetNumThreads(&context_, num_threads); - eigen_support::SetNumThreads(&context_, num_threads); + for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) { + auto* c = external_contexts_[i]; + if (c && c->Refresh) { + c->Refresh(&context_); + } + } +} + +void Interpreter::SwitchToDelegateContext() { + context_.GetNodeAndRegistration = GetNodeAndRegistration; + context_.ReplaceSubgraphsWithDelegateKernels = + ReplaceSubgraphsWithDelegateKernels; + context_.GetExecutionPlan = GetExecutionPlan; +} + +void Interpreter::SwitchToKernelContext() { + SetForbiddenContextFunction(&context_.GetNodeAndRegistration); + SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels); + SetForbiddenContextFunction(&context_.GetExecutionPlan); } TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate, @@ -872,24 +955,20 @@ TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate, // TODO(aselle): Consider if it is worth storing pointers to delegates. // Setup additional context interface. - context_.GetNodeAndRegistration = GetNodeAndRegistration; - context_.ReplaceSubgraphsWithDelegateKernels = - ReplaceSubgraphsWithDelegateKernels; - context_.GetExecutionPlan = GetExecutionPlan; + SwitchToDelegateContext(); TfLiteStatus status = delegate->Prepare(&context_, delegate); // Remove additional context info. - SetForbiddenContextFunction(&context_.GetNodeAndRegistration); - SetForbiddenContextFunction(&context_.ReplaceSubgraphsWithDelegateKernels); - SetForbiddenContextFunction(&context_.GetExecutionPlan); + SwitchToKernelContext(); TF_LITE_ENSURE_OK(&context_, status); if (!allow_dynamic_tensors) { + // Reset the state to force tensor/op reallocation. + state_ = kStateUninvokable; TF_LITE_ENSURE_OK(&context_, AllocateTensors()); - TF_LITE_ENSURE(&context_, state_ == kStateInvokable || - state_ == kStateInvokableAndImmutable); + TF_LITE_ENSURE_EQ(&context_, state_, kStateInvokable); // After using a delegate which doesn't support dynamic tensors, make the // entire graph immutable. state_ = kStateInvokableAndImmutable; diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 033b8ee5fabc416fd5936b7ff69697235cd9e7e7..e8301ff5076ec104d09351d081a28f5eb0964bc6 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -63,6 +63,10 @@ template <> constexpr TfLiteType typeToTfLiteType>() { return kTfLiteComplex64; } +template <> +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteString; +} // Forward declare since NNAPIDelegate uses Interpreter. class NNAPIDelegate; @@ -107,7 +111,7 @@ class Interpreter { // processing this model will be forwarded to the error_reporter object. // // Note, if error_reporter is nullptr, then a default StderrReporter is - // used. + // used. Ownership of 'error_reporter' remains with the caller. explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter()); ~Interpreter(); @@ -410,6 +414,15 @@ class Interpreter { } private: + friend class InterpreterTest; + + // Prevent 'context_' from accessing functions that are only available to + // delegated kernels. + void SwitchToKernelContext(); + + // Add delegate-only functions to 'context_'. + void SwitchToDelegateContext(); + // Give 'op_reg' a chance to initialize itself using the contents of // 'buffer'. void* OpInit(const TfLiteRegistration& op_reg, const char* buffer, @@ -496,6 +509,7 @@ class Interpreter { // Update the execution graph to replace some of the nodes with stub // nodes. Specifically any node index that has `nodes[index]==1` will be // slated for replacement with a delegate kernel specified by registration. + // Ownership of 'nodes_to_replace' and 'delegate' remains with the caller. // WARNING: This is an experimental interface that is subject to change. TfLiteStatus ReplaceSubgraphsWithDelegateKernels( TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace, @@ -513,15 +527,28 @@ class Interpreter { TfLiteRegistration** registration); // WARNING: This is an experimental interface that is subject to change. - // Gets an TfLiteIntArray* representing the execution plan. The caller owns - // this memory and must free it with TfLiteIntArrayFree(). + // Gets an TfLiteIntArray* representing the execution plan. The interpreter + // owns this memory and it is only guaranteed to exist during the invocation + // of the delegate prepare. TfLiteStatus GetExecutionPlan(TfLiteIntArray** execution_plan); // WARNING: This is an experimental interface that is subject to change. - // Entry point for C node plugin API to get the execution plan + // Entry point for C node plugin API to get the execution plan. static TfLiteStatus GetExecutionPlan(struct TfLiteContext* context, TfLiteIntArray** execution_plan); + // Retrieve an existing external context by type. + TfLiteExternalContext* GetExternalContext(TfLiteExternalContextType type); + static TfLiteExternalContext* GetExternalContext( + struct TfLiteContext* context, TfLiteExternalContextType type); + + // Set the value of an external context. + void SetExternalContext(TfLiteExternalContextType type, + TfLiteExternalContext* ctx); + static void SetExternalContext(struct TfLiteContext* context, + TfLiteExternalContextType type, + TfLiteExternalContext* ctx); + // Ensures that `tensors_` has at least `kTensorsCapacityHeadroom` extra // capacity. Calling this function may invalidate existing pointers to // tensors. After calling this function, adding `kTensorsCapacityHeadroom` @@ -611,7 +638,10 @@ class Interpreter { bool tensor_resized_since_op_invoke_ = false; // Profiler for this interpreter instance. - profiling::Profiler* profiler_; + profiling::Profiler* profiler_ = nullptr; + + // List of active external contexts. + TfLiteExternalContext* external_contexts_[kTfLiteMaxExternalContexts]; }; } // namespace tflite diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index 21cdf87d1e421868d1b62c5e23c2481cfbb4c989..2bf598bad71b87afaa22c1eb95474c49386c122f 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -23,6 +23,15 @@ limitations under the License. #include "tensorflow/contrib/lite/testing/util.h" namespace tflite { + +// InterpreterTest is a friend of Interpreter, so it can access context_. +class InterpreterTest : public ::testing::Test { + protected: + TfLiteContext* GetInterpreterContext() { return &interpreter_.context_; } + + Interpreter interpreter_; +}; + namespace ops { namespace builtin { TfLiteRegistration* Register_PADV2(); @@ -48,6 +57,22 @@ TEST(BasicInterpreter, InvokeInvalidModel) { ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); } +TEST(BasicInterpreter, TestAllocateTensorsResetVariableTensors) { + Interpreter interpreter; + int tensor_index; + ASSERT_EQ(interpreter.AddTensors(1, &tensor_index), kTfLiteOk); + constexpr int kTensorSize = 16; + interpreter.SetTensorParametersReadWrite(tensor_index, kTfLiteFloat32, "", + {kTensorSize}, {}, true); + interpreter.SetVariables({tensor_index}); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + TfLiteTensor* tensor = interpreter.tensor(tensor_index); + // Ensure that variable tensors are reset to zero. + for (int i = 0; i < kTensorSize; ++i) { + ASSERT_EQ(tensor->data.f[i], 0.0f); + } +} + // Test size accessor functions. TEST(BasicInterpreter, TestSizeFunctions) { Interpreter interpreter; @@ -231,32 +256,16 @@ TEST(BasicInterpreter, CheckArenaAllocation) { ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); - ASSERT_EQ(interpreter.tensor(0)->data.raw, interpreter.tensor(4)->data.raw); - ASSERT_EQ(interpreter.tensor(1)->data.raw, interpreter.tensor(7)->data.raw); - ASSERT_EQ(interpreter.tensor(8)->data.raw, nullptr); - - ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(1)->data.raw); - ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(1)->data.raw); ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(1)->data.raw); - - ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(2)->data.raw); ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(3)->data.raw); - ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(3)->data.raw); - - ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(4)->data.raw); ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(5)->data.raw); - ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(5)->data.raw, interpreter.tensor(7)->data.raw); + ASSERT_EQ(interpreter.tensor(6)->data.raw, interpreter.tensor(2)->data.raw); + // #7 is the one with the largest pointer. + ASSERT_EQ(interpreter.tensor(8)->data.raw, nullptr); + ASSERT_EQ(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw); } TEST(BasicInterpreter, BufferAccess) { @@ -292,6 +301,57 @@ TEST(BasicInterpreter, NoOpInterpreter) { ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); } +TEST(BasicInterpreter, RedundantAllocateTensors) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + const auto data_raw = interpreter.tensor(0)->data.raw; + ASSERT_NE(data_raw, nullptr); + + // A redundant allocation request should have no impact. + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.tensor(0)->data.raw, data_raw); +} + +TEST(BasicInterpreter, RedundantAllocateTensorsWithDynamicInputs) { + Interpreter interpreter; + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + interpreter.SetInputs({0}); + interpreter.SetOutputs({1}); + interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, ®); + + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 1, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + + // Configure the input tensor as dynamic. + interpreter.tensor(0)->data.raw = nullptr; + interpreter.tensor(0)->allocation_type = kTfLiteDynamic; + + ASSERT_EQ(interpreter.ResizeInputTensor(interpreter.inputs()[0], {1, 2, 3}), + kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr); + + // Reset the output tensor's buffer. + interpreter.tensor(1)->data.raw = nullptr; + + // A redundant allocation request should be honored, as the input tensor + // was marked dynamic. + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr); +} + TEST(BasicInterpreter, ResizingTensors) { Interpreter interpreter; ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); @@ -349,6 +409,37 @@ TEST(BasicInterpreter, ResizingTensors) { tensor->data.f[15] = 0.123f; } +TEST(BasicInterpreter, NoopResizingTensors) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk); + + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + + int t = interpreter.inputs()[0]; + TfLiteTensor* tensor = interpreter.tensor(t); + + ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 3}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 6 * sizeof(float)); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + tensor->data.f[5] = 0.123f; + + // Resizing to the same size should not trigger re-allocation. + ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 3}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 6 * sizeof(float)); + ASSERT_NE(tensor->data.raw, nullptr); + ASSERT_EQ(tensor->data.f[5], 0.123f); + + // Explicitly allocating should be a no-op, as no resize was performed. + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 6 * sizeof(float)); + ASSERT_NE(tensor->data.raw, nullptr); + ASSERT_EQ(tensor->data.f[5], 0.123f); +} + TEST(BasicInterpreter, OneOpInterpreter) { Interpreter interpreter; ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); @@ -556,18 +647,6 @@ TEST(BasicInterpreter, AllocateTwice) { ASSERT_EQ(old_tensor1_ptr, interpreter.tensor(1)->data.raw); } -struct TestErrorReporter : public ErrorReporter { - int Report(const char* format, va_list args) override { - char buffer[1024]; - int size = vsnprintf(buffer, sizeof(buffer), format, args); - all_reports += buffer; - calls++; - return size; - } - int calls = 0; - std::string all_reports; -}; - TEST(BasicInterpreter, TestNullErrorReporter) { TestErrorReporter reporter; Interpreter interpreter; @@ -577,8 +656,9 @@ TEST(BasicInterpreter, TestCustomErrorReporter) { TestErrorReporter reporter; Interpreter interpreter(&reporter); ASSERT_NE(interpreter.Invoke(), kTfLiteOk); - ASSERT_EQ(reporter.all_reports, "Invoke called on model that is not ready."); - ASSERT_EQ(reporter.calls, 1); + ASSERT_EQ(reporter.error_messages(), + "Invoke called on model that is not ready."); + ASSERT_EQ(reporter.num_calls(), 1); } TEST(BasicInterpreter, TestUnsupportedDelegateFunctions) { @@ -714,6 +794,47 @@ TEST(InterpreterTensorsCapacityTest, TestExceedHeadroom) { ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); } +struct TestExternalContext : public TfLiteExternalContext { + static const TfLiteExternalContextType kType = kTfLiteGemmLowpContext; + + static TestExternalContext* Get(TfLiteContext* context) { + return reinterpret_cast( + context->GetExternalContext(context, kType)); + } + + static void Set(TfLiteContext* context, TestExternalContext* value) { + context->SetExternalContext(context, kType, value); + } + + int num_refreshes = 0; +}; + +TEST_F(InterpreterTest, GetSetResetExternalContexts) { + auto* context = GetInterpreterContext(); + + TestExternalContext external_context; + external_context.Refresh = [](TfLiteContext* context) { + auto* ptr = TestExternalContext::Get(context); + if (ptr != nullptr) { + ++ptr->num_refreshes; + } + return kTfLiteOk; + }; + + EXPECT_EQ(TestExternalContext::Get(context), nullptr); + interpreter_.SetNumThreads(4); + + TestExternalContext::Set(context, &external_context); + EXPECT_EQ(TestExternalContext::Get(context), &external_context); + interpreter_.SetNumThreads(4); + interpreter_.SetNumThreads(5); + EXPECT_EQ(external_context.num_refreshes, 2); + + TestExternalContext::Set(context, nullptr); + EXPECT_EQ(TestExternalContext::Get(context), nullptr); + interpreter_.SetNumThreads(4); +} + // Test fixture that allows playing with execution plans. It creates a two // node graph that can be executed in either [0,1] order or [1,0] order. // The CopyOp records when it is invoked in the class member run_order_ diff --git a/tensorflow/contrib/lite/java/AndroidManifest.xml b/tensorflow/contrib/lite/java/AndroidManifest.xml index f705feacbec38ab5152ce52b701320d8f1cd8d3d..b91c6d149a213926be90b9b131bd632d4f79a0fc 100644 --- a/tensorflow/contrib/lite/java/AndroidManifest.xml +++ b/tensorflow/contrib/lite/java/AndroidManifest.xml @@ -1,7 +1,12 @@ - - + package="org.tensorflow.lite"> + + + + + diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD index 593af81a18a1e20a41dcc8d9bb3a1d815876e294..098ba7e7731d833678fbd5eab9cce3f022570f23 100644 --- a/tensorflow/contrib/lite/java/BUILD +++ b/tensorflow/contrib/lite/java/BUILD @@ -69,6 +69,7 @@ java_test( size = "small", srcs = ["src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java"], javacopts = JAVACOPTS, + tags = ["no_oss"], test_class = "org.tensorflow.lite.TensorFlowLiteTest", deps = [ ":libtensorflowlite_jni.so", @@ -83,6 +84,7 @@ java_test( size = "small", srcs = ["src/test/java/org/tensorflow/lite/DataTypeTest.java"], javacopts = JAVACOPTS, + tags = ["no_oss"], test_class = "org.tensorflow.lite.DataTypeTest", deps = [ ":libtensorflowlite_jni.so", @@ -105,6 +107,7 @@ java_test( "src/testdata/with_custom_op.lite", ], javacopts = JAVACOPTS, + tags = ["no_oss"], test_class = "org.tensorflow.lite.NativeInterpreterWrapperTest", deps = [ ":libtensorflowlite_jni.so", @@ -124,6 +127,7 @@ java_test( "src/testdata/mobilenet.tflite.bin", ], javacopts = JAVACOPTS, + tags = ["no_oss"], test_class = "org.tensorflow.lite.InterpreterTest", visibility = ["//visibility:private"], deps = [ @@ -142,6 +146,7 @@ java_test( "src/testdata/add.bin", ], javacopts = JAVACOPTS, + tags = ["no_oss"], test_class = "org.tensorflow.lite.TensorTest", deps = [ ":tensorflowlitelib", diff --git a/tensorflow/contrib/lite/java/demo/app/build.gradle b/tensorflow/contrib/lite/java/demo/app/build.gradle index 192162cfce787ffbf13e2b0db2da972116407888..92f04c651c0488a5202def593774890630c8631f 100644 --- a/tensorflow/contrib/lite/java/demo/app/build.gradle +++ b/tensorflow/contrib/lite/java/demo/app/build.gradle @@ -10,7 +10,7 @@ android { targetSdkVersion 26 versionCode 1 versionName "1.0" - testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" // Remove this block. jackOptions { @@ -44,7 +44,7 @@ repositories { dependencies { compile fileTree(dir: 'libs', include: ['*.jar']) - androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { + androidTestCompile('androidx.test.espresso:espresso-core:3.1.0-alpha3', { exclude group: 'com.android.support', module: 'support-annotations' }) compile 'com.android.support:appcompat-v7:25.2.0' @@ -92,4 +92,4 @@ class DownloadUrlTask extends DefaultTask { void download() { ant.get(src: sourceUrl, dest: target) } -} \ No newline at end of file +} diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD index f232b00045cf1df6a31ada80af4cc5885a4c0099..06f46fb92394b19415ddb95dcf8c798753b630e3 100644 --- a/tensorflow/contrib/lite/java/ovic/BUILD +++ b/tensorflow/contrib/lite/java/ovic/BUILD @@ -18,6 +18,7 @@ java_test( "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", ], javacopts = JAVACOPTS, + tags = ["no_oss"], test_class = "org.tensorflow.ovic.OvicClassifierTest", visibility = ["//visibility:public"], deps = [ diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle index c5d19bad89a93988a6830a17fe2fb4a60e2fb00f..2a08608bbb121a2e279bbd16de6a014e68848796 100644 --- a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle +++ b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle @@ -9,7 +9,7 @@ android { targetSdkVersion 26 versionCode 1 versionName "1.0" - testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" // Remove this block. jackOptions { @@ -43,7 +43,7 @@ repositories { dependencies { compile fileTree(dir: 'libs', include: ['*.jar']) - androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { + androidTestCompile('androidx.test.espresso:espresso-core:3.1.0-alpha3', { exclude group: 'com.android.support', module: 'support-annotations' }) compile 'com.android.support:appcompat-v7:25.2.0' diff --git a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java index 56f3e7604a5b172e907edbe862b017957594397f..1587c3c56f45c0baddfa75286c979fe0c0edffcc 100644 --- a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java +++ b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java @@ -127,12 +127,8 @@ public final class OvicClassifierTest { try { testResult = classifier.classifyByteBuffer(testImage); fail(); - } catch (RuntimeException e) { - assertThat(e) - .hasMessageThat() - .contains( - "Failed to get input dimensions. 0-th input should have 49152 bytes, " - + "but found 150528 bytes."); + } catch (IllegalArgumentException e) { + // Success. } } diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java index 75334cd96e8daadc356dadea063eee30ef6d5245..94a1ec65d64b6493cdb309fc0c19155eb9cb26cb 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java @@ -27,10 +27,7 @@ enum DataType { UINT8(3), /** 64-bit signed integer. */ - INT64(4), - - /** A {@link ByteBuffer}. */ - BYTEBUFFER(999); + INT64(4); private final int value; @@ -69,8 +66,6 @@ enum DataType { return 1; case INT64: return 8; - case BYTEBUFFER: - return 1; } throw new IllegalArgumentException( "DataType error: DataType " + this + " is not supported yet"); @@ -87,8 +82,6 @@ enum DataType { return "byte"; case INT64: return "long"; - case BYTEBUFFER: - return "ByteBuffer"; } throw new IllegalArgumentException( "DataType error: DataType " + this + " is not supported yet"); diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java index fd1f0ffa68eeca7b5866b146ecaa1f9216ef377d..7002f826775b216e0a27ebe00f30680c9ce362bb 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -135,7 +135,8 @@ public final class Interpreter implements AutoCloseable { * including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass large * input data. When {@link ByteBuffer} is used, its content should remain unchanged until * model inference is done. - * @param output a multidimensional array of output data. + * @param output a multidimensional array of output data, or a {@link ByteBuffer} of primitive + * types including int, float, long, and byte. */ public void run(@NonNull Object input, @NonNull Object output) { Object[] inputs = {input}; @@ -155,28 +156,16 @@ public final class Interpreter implements AutoCloseable { * primitive types including int, float, long, and byte. {@link ByteBuffer} is the preferred * way to pass large input data. When {@link ByteBuffer} is used, its content should remain * unchanged until model inference is done. - * @param outputs a map mapping output indices to multidimensional arrays of output data. It only - * needs to keep entries for the outputs to be used. + * @param outputs a map mapping output indices to multidimensional arrays of output data or {@link + * ByteBuffer}s of primitive types including int, float, long, and byte. It only needs to keep + * entries for the outputs to be used. */ public void runForMultipleInputsOutputs( @NonNull Object[] inputs, @NonNull Map outputs) { if (wrapper == null) { throw new IllegalStateException("Internal error: The Interpreter has already been closed."); } - Tensor[] tensors = wrapper.run(inputs); - if (outputs == null || tensors == null || outputs.size() > tensors.length) { - throw new IllegalArgumentException("Output error: Outputs do not match with model outputs."); - } - final int size = tensors.length; - for (Integer idx : outputs.keySet()) { - if (idx == null || idx < 0 || idx >= size) { - throw new IllegalArgumentException( - String.format( - "Output error: Invalid index of output %d (should be in range [0, %d))", - idx, size)); - } - tensors[idx].copyTo(outputs.get(idx)); - } + wrapper.run(inputs, outputs); } /** @@ -249,8 +238,10 @@ public final class Interpreter implements AutoCloseable { /** Release resources associated with the {@code Interpreter}. */ @Override public void close() { - wrapper.close(); - wrapper = null; + if (wrapper != null) { + wrapper.close(); + wrapper = null; + } } @Override diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index 80de88b6a1cd75b033e116f76f5612ee66e48f03..767a220f8cd5381ce10e044553317b1cb05ba17b 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -15,10 +15,10 @@ limitations under the License. package org.tensorflow.lite; -import java.lang.reflect.Array; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.MappedByteBuffer; +import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -40,6 +40,8 @@ final class NativeInterpreterWrapper implements AutoCloseable { modelHandle = createModel(modelPath, errorHandle); interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads); isMemoryAllocated = true; + inputTensors = new Tensor[getInputCount(interpreterHandle)]; + outputTensors = new Tensor[getOutputCount(interpreterHandle)]; } /** @@ -72,6 +74,8 @@ final class NativeInterpreterWrapper implements AutoCloseable { modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle); interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads); isMemoryAllocated = true; + inputTensors = new Tensor[getInputCount(interpreterHandle)]; + outputTensors = new Tensor[getOutputCount(interpreterHandle)]; } /** Releases resources associated with this {@code NativeInterpreterWrapper}. */ @@ -85,75 +89,63 @@ final class NativeInterpreterWrapper implements AutoCloseable { inputsIndexes = null; outputsIndexes = null; isMemoryAllocated = false; + Arrays.fill(inputTensors, null); + Arrays.fill(outputTensors, null); } /** Sets inputs, runs model inference and returns outputs. */ - Tensor[] run(Object[] inputs) { + void run(Object[] inputs, Map outputs) { + inferenceDurationNanoseconds = -1; if (inputs == null || inputs.length == 0) { throw new IllegalArgumentException("Input error: Inputs should not be null or empty."); } - int[] dataTypes = new int[inputs.length]; - Object[] sizes = new Object[inputs.length]; - int[] numsOfBytes = new int[inputs.length]; + if (outputs == null || outputs.isEmpty()) { + throw new IllegalArgumentException("Input error: Outputs should not be null or empty."); + } + + // TODO(b/80431971): Remove implicit resize after deprecating multi-dimensional array inputs. + // Rather than forcing an immediate resize + allocation if an input's shape differs, we first + // flush all resizes, avoiding redundant allocations. for (int i = 0; i < inputs.length; ++i) { - DataType dataType = dataTypeOf(inputs[i]); - dataTypes[i] = dataType.getNumber(); - if (dataType == DataType.BYTEBUFFER) { - ByteBuffer buffer = (ByteBuffer) inputs[i]; - if (buffer == null || !buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()) { - throw new IllegalArgumentException( - "Input error: ByteBuffer should be a direct ByteBuffer that uses " - + "ByteOrder.nativeOrder()."); - } - numsOfBytes[i] = buffer.limit(); - sizes[i] = getInputDims(interpreterHandle, i, numsOfBytes[i]); - } else if (isNonEmptyArray(inputs[i])) { - int[] dims = shapeOf(inputs[i]); - sizes[i] = dims; - numsOfBytes[i] = dataType.elemByteSize() * numElements(dims); - } else { - throw new IllegalArgumentException( - String.format( - "Input error: %d-th element of the %d inputs is not an array or a ByteBuffer.", - i, inputs.length)); + Tensor tensor = getInputTensor(i); + int[] newShape = tensor.getInputShapeIfDifferent(inputs[i]); + if (newShape != null) { + resizeInput(i, newShape); } } - inferenceDurationNanoseconds = -1; - long[] outputsHandles = - run( - interpreterHandle, - errorHandle, - sizes, - dataTypes, - numsOfBytes, - inputs, - this, - isMemoryAllocated); - if (outputsHandles == null || outputsHandles.length == 0) { - throw new IllegalStateException("Internal error: Interpreter has no outputs."); + + if (!isMemoryAllocated) { + allocateTensors(interpreterHandle, errorHandle); + isMemoryAllocated = true; + // Allocation can trigger dynamic resizing of output tensors, so clear the + // output tensor cache. + Arrays.fill(outputTensors, null); } - isMemoryAllocated = true; - Tensor[] outputs = new Tensor[outputsHandles.length]; - for (int i = 0; i < outputsHandles.length; ++i) { - outputs[i] = Tensor.fromHandle(outputsHandles[i]); + + for (int i = 0; i < inputs.length; ++i) { + getInputTensor(i).setTo(inputs[i]); + } + + long inferenceStartNanos = System.nanoTime(); + run(interpreterHandle, errorHandle); + long inferenceDurationNanoseconds = System.nanoTime() - inferenceStartNanos; + + for (Map.Entry output : outputs.entrySet()) { + getOutputTensor(output.getKey()).copyTo(output.getValue()); } - return outputs; + + // Only set if the entire operation succeeds. + this.inferenceDurationNanoseconds = inferenceDurationNanoseconds; } - private static native long[] run( - long interpreterHandle, - long errorHandle, - Object[] sizes, - int[] dtypes, - int[] numsOfBytes, - Object[] values, - NativeInterpreterWrapper wrapper, - boolean memoryAllocated); + private static native boolean run(long interpreterHandle, long errorHandle); /** Resizes dimensions of a specific input. */ void resizeInput(int idx, int[] dims) { if (resizeInput(interpreterHandle, errorHandle, idx, dims)) { isMemoryAllocated = false; + // Resizing will invalidate the Tensor's shape, so invalidate the Tensor handle. + inputTensors[idx] = null; } } @@ -212,78 +204,6 @@ final class NativeInterpreterWrapper implements AutoCloseable { } } - static int numElements(int[] shape) { - if (shape == null) { - return 0; - } - int n = 1; - for (int i = 0; i < shape.length; i++) { - n *= shape[i]; - } - return n; - } - - static boolean isNonEmptyArray(Object o) { - return (o != null && o.getClass().isArray() && Array.getLength(o) != 0); - } - - /** Returns the type of the data. */ - static DataType dataTypeOf(Object o) { - if (o != null) { - Class c = o.getClass(); - while (c.isArray()) { - c = c.getComponentType(); - } - if (float.class.equals(c)) { - return DataType.FLOAT32; - } else if (int.class.equals(c)) { - return DataType.INT32; - } else if (byte.class.equals(c)) { - return DataType.UINT8; - } else if (long.class.equals(c)) { - return DataType.INT64; - } else if (ByteBuffer.class.isInstance(o)) { - return DataType.BYTEBUFFER; - } - } - throw new IllegalArgumentException( - "DataType error: cannot resolve DataType of " + o.getClass().getName()); - } - - /** Returns the shape of an object as an int array. */ - static int[] shapeOf(Object o) { - int size = numDimensions(o); - int[] dimensions = new int[size]; - fillShape(o, 0, dimensions); - return dimensions; - } - - static int numDimensions(Object o) { - if (o == null || !o.getClass().isArray()) { - return 0; - } - if (Array.getLength(o) == 0) { - throw new IllegalArgumentException("Array lengths cannot be 0."); - } - return 1 + numDimensions(Array.get(o, 0)); - } - - static void fillShape(Object o, int dim, int[] shape) { - if (shape == null || dim == shape.length) { - return; - } - final int len = Array.getLength(o); - if (shape[dim] == 0) { - shape[dim] = len; - } else if (shape[dim] != len) { - throw new IllegalArgumentException( - String.format("Mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim)); - } - for (int i = 0; i < len; ++i) { - fillShape(Array.get(o, i), dim + 1, shape); - } - } - /** * Gets the last inference duration in nanoseconds. It returns null if there is no previous * inference run or the last inference run failed. @@ -293,40 +213,55 @@ final class NativeInterpreterWrapper implements AutoCloseable { } /** - * Gets the dimensions of an input. It throws IllegalArgumentException if input index is invalid. + * Gets the quantization zero point of an output. + * + * @throws IllegalArgumentException if the output index is invalid. */ - int[] getInputDims(int index) { - return getInputDims(interpreterHandle, index, -1); + int getOutputQuantizationZeroPoint(int index) { + return getOutputQuantizationZeroPoint(interpreterHandle, index); } /** - * Gets the dimensions of an input. If numBytes >= 0, it will check whether num of bytes match the - * input. + * Gets the quantization scale of an output. + * + * @throws IllegalArgumentException if the output index is invalid. */ - private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes); - - /** Gets the type of an output. It throws IllegalArgumentException if output index is invalid. */ - String getOutputDataType(int index) { - int type = getOutputDataType(interpreterHandle, index); - return DataType.fromNumber(type).toStringName(); + float getOutputQuantizationScale(int index) { + return getOutputQuantizationScale(interpreterHandle, index); } /** - * Gets the quantization zero point of an output. + * Gets the input {@link Tensor} for the provided input index. * - * @throws IllegalArgumentExeption if the output index is invalid. + * @throws IllegalArgumentException if the input index is invalid. */ - int getOutputQuantizationZeroPoint(int index) { - return getOutputQuantizationZeroPoint(interpreterHandle, index); + Tensor getInputTensor(int index) { + if (index < 0 || index >= inputTensors.length) { + throw new IllegalArgumentException("Invalid input Tensor index: " + index); + } + Tensor inputTensor = inputTensors[index]; + if (inputTensor == null) { + inputTensor = + inputTensors[index] = Tensor.fromHandle(getInputTensor(interpreterHandle, index)); + } + return inputTensor; } /** - * Gets the quantization scale of an output. + * Gets the output {@link Tensor} for the provided output index. * - * @throws IllegalArgumentExeption if the output index is invalid. + * @throws IllegalArgumentException if the output index is invalid. */ - float getOutputQuantizationScale(int index) { - return getOutputQuantizationScale(interpreterHandle, index); + Tensor getOutputTensor(int index) { + if (index < 0 || index >= outputTensors.length) { + throw new IllegalArgumentException("Invalid output Tensor index: " + index); + } + Tensor outputTensor = outputTensors[index]; + if (outputTensor == null) { + outputTensor = + outputTensors[index] = Tensor.fromHandle(getOutputTensor(interpreterHandle, index)); + } + return outputTensor; } private static native int getOutputDataType(long interpreterHandle, int outputIdx); @@ -343,18 +278,30 @@ final class NativeInterpreterWrapper implements AutoCloseable { private long modelHandle; - private int inputSize; - private long inferenceDurationNanoseconds = -1; private ByteBuffer modelByteBuffer; + // Lazily constructed maps of input and output names to input and output Tensor indexes. private Map inputsIndexes; - private Map outputsIndexes; + // Lazily constructed and populated arrays of input and output Tensor wrappers. + private final Tensor[] inputTensors; + private final Tensor[] outputTensors; + private boolean isMemoryAllocated = false; + private static native long allocateTensors(long interpreterHandle, long errorHandle); + + private static native long getInputTensor(long interpreterHandle, int inputIdx); + + private static native long getOutputTensor(long interpreterHandle, int outputIdx); + + private static native int getInputCount(long interpreterHandle); + + private static native int getOutputCount(long interpreterHandle); + private static native String[] getInputNames(long interpreterHandle); private static native String[] getOutputNames(long interpreterHandle); diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java index 09e887aae3339e9f114c07d689c0d7b5e2fc384b..2403570c527e762f6782e313731e383feeeef46d 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java @@ -15,6 +15,9 @@ limitations under the License. package org.tensorflow.lite; +import java.lang.reflect.Array; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.Arrays; /** @@ -29,30 +32,179 @@ final class Tensor { return new Tensor(nativeHandle); } - /** Reads Tensor content into an array. */ - T copyTo(T dst) { - if (NativeInterpreterWrapper.dataTypeOf(dst) != dtype) { + /** Returns the {@link DataType} of elements stored in the Tensor. */ + public DataType dataType() { + return dtype; + } + + /** Returns the size, in bytes, of the tensor data. */ + public int numBytes() { + return numBytes(nativeHandle); + } + + /** + * Returns the shape of + * the Tensor, i.e., the sizes of each dimension. + * + * @return an array where the i-th element is the size of the i-th dimension of the tensor. + */ + public int[] shape() { + return shapeCopy; + } + + /** + * Copies the contents of the provided {@code src} object to the Tensor. + * + *

The {@code src} should either be a (multi-dimensional) array with a shape matching that of + * this tensor, or a {@link ByteByffer} of compatible primitive type with a matching flat size. + * + * @throws IllegalArgumentException if the tensor is a scalar or if {@code src} is not compatible + * with the tensor (for example, mismatched data types or shapes). + */ + void setTo(Object src) { + throwExceptionIfTypeIsIncompatible(src); + if (isByteBuffer(src)) { + ByteBuffer srcBuffer = (ByteBuffer) src; + // For direct ByteBuffer instances we support zero-copy. Note that this assumes the caller + // retains ownership of the source buffer until inference has completed. + if (srcBuffer.isDirect() && srcBuffer.order() == ByteOrder.nativeOrder()) { + writeDirectBuffer(nativeHandle, srcBuffer); + } else { + buffer().put(srcBuffer); + } + return; + } + writeMultiDimensionalArray(nativeHandle, src); + } + + /** + * Copies the contents of the tensor to {@code dst} and returns {@code dst}. + * + * @param dst the destination buffer, either an explicitly-typed array or a {@link ByteBuffer}. + * @throws IllegalArgumentException if {@code dst} is not compatible with the tensor (for example, + * mismatched data types or shapes). + */ + Object copyTo(Object dst) { + throwExceptionIfTypeIsIncompatible(dst); + if (dst instanceof ByteBuffer) { + ByteBuffer dstByteBuffer = (ByteBuffer) dst; + dstByteBuffer.put(buffer()); + return dst; + } + readMultiDimensionalArray(nativeHandle, dst); + return dst; + } + + /** Returns the provided buffer's shape if specified and different from this Tensor's shape. */ + // TODO(b/80431971): Remove this method after deprecating multi-dimensional array inputs. + int[] getInputShapeIfDifferent(Object input) { + // Implicit resizes based on ByteBuffer capacity isn't supported, so short-circuit that path. + // The ByteBuffer's size will be validated against this Tensor's size in {@link #setTo(Object)}. + if (isByteBuffer(input)) { + return null; + } + int[] inputShape = shapeOf(input); + if (Arrays.equals(shapeCopy, inputShape)) { + return null; + } + return inputShape; + } + + /** Returns the type of the data. */ + static DataType dataTypeOf(Object o) { + if (o != null) { + Class c = o.getClass(); + while (c.isArray()) { + c = c.getComponentType(); + } + if (float.class.equals(c)) { + return DataType.FLOAT32; + } else if (int.class.equals(c)) { + return DataType.INT32; + } else if (byte.class.equals(c)) { + return DataType.UINT8; + } else if (long.class.equals(c)) { + return DataType.INT64; + } + } + throw new IllegalArgumentException( + "DataType error: cannot resolve DataType of " + o.getClass().getName()); + } + + /** Returns the shape of an object as an int array. */ + static int[] shapeOf(Object o) { + int size = numDimensions(o); + int[] dimensions = new int[size]; + fillShape(o, 0, dimensions); + return dimensions; + } + + /** Returns the number of dimensions of a multi-dimensional array, otherwise 0. */ + static int numDimensions(Object o) { + if (o == null || !o.getClass().isArray()) { + return 0; + } + if (Array.getLength(o) == 0) { + throw new IllegalArgumentException("Array lengths cannot be 0."); + } + return 1 + numDimensions(Array.get(o, 0)); + } + + /** Recursively populates the shape dimensions for a given (multi-dimensional) array. */ + static void fillShape(Object o, int dim, int[] shape) { + if (shape == null || dim == shape.length) { + return; + } + final int len = Array.getLength(o); + if (shape[dim] == 0) { + shape[dim] = len; + } else if (shape[dim] != len) { + throw new IllegalArgumentException( + String.format("Mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim)); + } + for (int i = 0; i < len; ++i) { + fillShape(Array.get(o, i), dim + 1, shape); + } + } + + private void throwExceptionIfTypeIsIncompatible(Object o) { + if (isByteBuffer(o)) { + ByteBuffer oBuffer = (ByteBuffer) o; + if (oBuffer.capacity() != numBytes()) { + throw new IllegalArgumentException( + String.format( + "Cannot convert between a TensorFlowLite buffer with %d bytes and a " + + "ByteBuffer with %d bytes.", + numBytes(), oBuffer.capacity())); + } + return; + } + DataType oType = dataTypeOf(o); + if (oType != dtype) { throw new IllegalArgumentException( String.format( - "Output error: Cannot convert an TensorFlowLite tensor with type %s to a Java " - + "object of type %s (which is compatible with the TensorFlowLite type %s)", - dtype, dst.getClass().getName(), NativeInterpreterWrapper.dataTypeOf(dst))); + "Cannot convert between a TensorFlowLite tensor with type %s and a Java " + + "object of type %s (which is compatible with the TensorFlowLite type %s).", + dtype, o.getClass().getName(), oType)); } - int[] dstShape = NativeInterpreterWrapper.shapeOf(dst); - if (!Arrays.equals(dstShape, shapeCopy)) { + + int[] oShape = shapeOf(o); + if (!Arrays.equals(oShape, shapeCopy)) { throw new IllegalArgumentException( String.format( - "Output error: Shape of output target %s does not match with the shape of the " - + "Tensor %s.", - Arrays.toString(dstShape), Arrays.toString(shapeCopy))); + "Cannot copy between a TensorFlowLite tensor with shape %s and a Java object " + + "with shape %s.", + Arrays.toString(shapeCopy), Arrays.toString(oShape))); } - readMultiDimensionalArray(nativeHandle, dst); - return dst; } - final long nativeHandle; - final DataType dtype; - final int[] shapeCopy; + private static boolean isByteBuffer(Object o) { + return o instanceof ByteBuffer; + } + + private final long nativeHandle; + private final DataType dtype; + private final int[] shapeCopy; private Tensor(long nativeHandle) { this.nativeHandle = nativeHandle; @@ -60,11 +212,23 @@ final class Tensor { this.shapeCopy = shape(nativeHandle); } + private ByteBuffer buffer() { + return buffer(nativeHandle).order(ByteOrder.nativeOrder()); + } + + private static native ByteBuffer buffer(long handle); + + private static native void writeDirectBuffer(long handle, ByteBuffer src); + private static native int dtype(long handle); private static native int[] shape(long handle); - private static native void readMultiDimensionalArray(long handle, Object value); + private static native int numBytes(long handle); + + private static native void readMultiDimensionalArray(long handle, Object dst); + + private static native void writeMultiDimensionalArray(long handle, Object src); static { TensorFlowLite.init(); diff --git a/tensorflow/contrib/lite/java/src/main/native/BUILD b/tensorflow/contrib/lite/java/src/main/native/BUILD index 4399ed202597082fba36c04a744bf6378e4539a2..4b4e1c21d818dc56803ff31d83d19dea2ac08707 100644 --- a/tensorflow/contrib/lite/java/src/main/native/BUILD +++ b/tensorflow/contrib/lite/java/src/main/native/BUILD @@ -11,7 +11,6 @@ licenses(["notice"]) # Apache 2.0 cc_library( name = "native_framework_only", srcs = [ - "duration_utils_jni.cc", "exception_jni.cc", "nativeinterpreterwrapper_jni.cc", "tensor_jni.cc", diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index 31f7b58fbc30cab9e6cb813094ea4b2627ba5cba..fdcf00a0a08459d8d669f1def3ae2eb21dbd31c3 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -16,9 +16,6 @@ limitations under the License. #include "tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h" namespace { -const int kByteBufferValue = 999; -const int kBufferSize = 256; - tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) { if (handle == 0) { throwException(env, kIllegalArgumentException, @@ -62,22 +59,6 @@ std::vector convertJIntArrayToVector(JNIEnv* env, jintArray inputs) { return outputs; } -bool isByteBuffer(jint data_type) { return data_type == kByteBufferValue; } - -TfLiteType resolveDataType(jint data_type) { - switch (data_type) { - case 1: - return kTfLiteFloat32; - case 2: - return kTfLiteInt32; - case 3: - return kTfLiteUInt8; - case 4: - return kTfLiteInt64; - default: - return kTfLiteNoType; - } -} int getDataType(TfLiteType data_type) { switch (data_type) { @@ -108,64 +89,6 @@ void printDims(char* buffer, int max_size, int* dims, int num_dims) { } } -TfLiteStatus checkInputs(JNIEnv* env, tflite::Interpreter* interpreter, - const int input_size, jintArray data_types, - jintArray nums_of_bytes, jobjectArray values, - jobjectArray sizes) { - if (input_size != interpreter->inputs().size()) { - throwException(env, kIllegalArgumentException, - "Input error: Expected num of inputs is %d but got %d", - interpreter->inputs().size(), input_size); - return kTfLiteError; - } - if (input_size != env->GetArrayLength(data_types) || - input_size != env->GetArrayLength(nums_of_bytes) || - input_size != env->GetArrayLength(values)) { - throwException(env, kIllegalArgumentException, - "Internal error: Arrays in arguments should be of the same " - "length, but got %d sizes, %d data_types, %d nums_of_bytes, " - "and %d values", - input_size, env->GetArrayLength(data_types), - env->GetArrayLength(nums_of_bytes), - env->GetArrayLength(values)); - return kTfLiteError; - } - for (int i = 0; i < input_size; ++i) { - int input_idx = interpreter->inputs()[i]; - TfLiteTensor* target = interpreter->tensor(input_idx); - jintArray dims = - static_cast(env->GetObjectArrayElement(sizes, i)); - int num_dims = static_cast(env->GetArrayLength(dims)); - if (target->dims->size != num_dims) { - throwException(env, kIllegalArgumentException, - "Input error: %d-th input should have %d dimensions, but " - "found %d dimensions", - i, target->dims->size, num_dims); - return kTfLiteError; - } - jint* ptr = env->GetIntArrayElements(dims, nullptr); - for (int j = 1; j < num_dims; ++j) { - if (target->dims->data[j] != ptr[j]) { - std::unique_ptr expected_dims(new char[kBufferSize]); - std::unique_ptr obtained_dims(new char[kBufferSize]); - printDims(expected_dims.get(), kBufferSize, target->dims->data, - num_dims); - printDims(obtained_dims.get(), kBufferSize, ptr, num_dims); - throwException(env, kIllegalArgumentException, - "Input error: %d-th input dimension should be [%s], but " - "found [%s]", - i, expected_dims.get(), obtained_dims.get()); - env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT); - return kTfLiteError; - } - } - env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT); - env->DeleteLocalRef(dims); - if (env->ExceptionCheck()) return kTfLiteError; - } - return kTfLiteOk; -} - // Checks whether there is any difference between dimensions of a tensor and a // given dimensions. Returns true if there is difference, else false. bool areDimsDifferent(JNIEnv* env, TfLiteTensor* tensor, jintArray dims) { @@ -188,74 +111,6 @@ bool areDimsDifferent(JNIEnv* env, TfLiteTensor* tensor, jintArray dims) { return false; } -bool areInputDimensionsTheSame(JNIEnv* env, tflite::Interpreter* interpreter, - int input_size, jobjectArray sizes) { - if (interpreter->inputs().size() != input_size) { - return false; - } - for (int i = 0; i < input_size; ++i) { - int input_idx = interpreter->inputs()[i]; - jintArray dims = - static_cast(env->GetObjectArrayElement(sizes, i)); - TfLiteTensor* target = interpreter->tensor(input_idx); - if (areDimsDifferent(env, target, dims)) return false; - env->DeleteLocalRef(dims); - if (env->ExceptionCheck()) return false; - } - return true; -} - -TfLiteStatus resizeInputs(JNIEnv* env, tflite::Interpreter* interpreter, - int input_size, jobjectArray sizes) { - for (int i = 0; i < input_size; ++i) { - int input_idx = interpreter->inputs()[i]; - jintArray dims = - static_cast(env->GetObjectArrayElement(sizes, i)); - TfLiteStatus status = interpreter->ResizeInputTensor( - input_idx, convertJIntArrayToVector(env, dims)); - if (status != kTfLiteOk) { - return status; - } - env->DeleteLocalRef(dims); - if (env->ExceptionCheck()) return kTfLiteError; - } - return kTfLiteOk; -} - -TfLiteStatus setInputs(JNIEnv* env, tflite::Interpreter* interpreter, - int input_size, jintArray data_types, - jintArray nums_of_bytes, jobjectArray values) { - jint* data_type = env->GetIntArrayElements(data_types, nullptr); - jint* num_bytes = env->GetIntArrayElements(nums_of_bytes, nullptr); - for (int i = 0; i < input_size; ++i) { - int input_idx = interpreter->inputs()[i]; - TfLiteTensor* target = interpreter->tensor(input_idx); - jobject value = env->GetObjectArrayElement(values, i); - bool is_byte_buffer = isByteBuffer(data_type[i]); - if (is_byte_buffer) { - writeByteBuffer(env, value, &(target->data.raw), - static_cast(num_bytes[i])); - } else { - TfLiteType type = resolveDataType(data_type[i]); - if (type != target->type) { - throwException(env, kIllegalArgumentException, - "Input error: DataType (%d) of input data does not " - "match with the DataType (%d) of model inputs.", - type, target->type); - return kTfLiteError; - } - writeMultiDimensionalArray(env, value, target->type, target->dims->size, - &(target->data.raw), - static_cast(num_bytes[i])); - } - env->DeleteLocalRef(value); - if (env->ExceptionCheck()) return kTfLiteError; - } - env->ReleaseIntArrayElements(data_types, data_type, JNI_ABORT); - env->ReleaseIntArrayElements(nums_of_bytes, num_bytes, JNI_ABORT); - return kTfLiteOk; -} - // TODO(yichengfan): evaluate the benefit to use tflite verifier. bool VerifyModel(const void* buf, size_t len) { flatbuffers::Verifier verifier(static_cast(buf), len); @@ -287,6 +142,64 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env, return names; } +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors( + JNIEnv* env, jclass clazz, jlong handle, jlong error_handle) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return; + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return; + + if (interpreter->AllocateTensors() != kTfLiteOk) { + throwException( + env, kIllegalStateException, + "Internal error: Unexpected failure when preparing tensor allocations:" + " %s", + error_reporter->CachedErrorMessage()); + } +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensor(JNIEnv* env, + jclass clazz, + jlong handle, + jint index) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return 0; + return reinterpret_cast( + interpreter->tensor(interpreter->inputs()[index])); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensor(JNIEnv* env, + jclass clazz, + jlong handle, + jint index) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return 0; + return reinterpret_cast( + interpreter->tensor(interpreter->outputs()[index])); +} + +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env, + jclass clazz, + jlong handle) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return 0; + return static_cast(interpreter->inputs().size()); +} + +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env, + jclass clazz, + jlong handle) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return 0; + return static_cast(interpreter->outputs().size()); +} + JNIEXPORT jobjectArray JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env, jclass clazz, @@ -424,124 +337,32 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( // allocates memory status = interpreter->AllocateTensors(); if (status != kTfLiteOk) { - throwException(env, kNullPointerException, - "Internal error: Cannot allocate memory for the interpreter:" - " %s", - error_reporter->CachedErrorMessage()); + throwException( + env, kIllegalStateException, + "Internal error: Unexpected failure when preparing tensor allocations:" + " %s", + error_reporter->CachedErrorMessage()); return 0; } return reinterpret_cast(interpreter.release()); } // Sets inputs, runs inference, and returns outputs as long handles. -JNIEXPORT jlongArray JNICALL -Java_org_tensorflow_lite_NativeInterpreterWrapper_run( - JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, - jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes, - jobjectArray values, jobject wrapper, jboolean memory_allocated) { +JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) { tflite::Interpreter* interpreter = convertLongToInterpreter(env, interpreter_handle); - if (interpreter == nullptr) return nullptr; + if (interpreter == nullptr) return; BufferErrorReporter* error_reporter = convertLongToErrorReporter(env, error_handle); - if (error_reporter == nullptr) return nullptr; - const int input_size = env->GetArrayLength(sizes); - // validates inputs - TfLiteStatus status = checkInputs(env, interpreter, input_size, data_types, - nums_of_bytes, values, sizes); - if (status != kTfLiteOk) return nullptr; - if (!memory_allocated || - !areInputDimensionsTheSame(env, interpreter, input_size, sizes)) { - // resizes inputs - status = resizeInputs(env, interpreter, input_size, sizes); - if (status != kTfLiteOk) { - throwException(env, kNullPointerException, - "Internal error: Can not resize the input: %s", - error_reporter->CachedErrorMessage()); - return nullptr; - } - // allocates memory - status = interpreter->AllocateTensors(); - if (status != kTfLiteOk) { - throwException(env, kNullPointerException, - "Internal error: Can not allocate memory for the given " - "inputs: %s", - error_reporter->CachedErrorMessage()); - return nullptr; - } - } - // sets inputs - status = setInputs(env, interpreter, input_size, data_types, nums_of_bytes, - values); - if (status != kTfLiteOk) return nullptr; - timespec beforeInference = ::tflite::getCurrentTime(); - // runs inference + if (error_reporter == nullptr) return; + if (interpreter->Invoke() != kTfLiteOk) { throwException(env, kIllegalArgumentException, "Internal error: Failed to run on the given Interpreter: %s", error_reporter->CachedErrorMessage()); - return nullptr; - } - timespec afterInference = ::tflite::getCurrentTime(); - jclass wrapper_clazz = env->GetObjectClass(wrapper); - jfieldID fid = - env->GetFieldID(wrapper_clazz, "inferenceDurationNanoseconds", "J"); - if (env->ExceptionCheck()) { - env->ExceptionClear(); - } else if (fid != nullptr) { - env->SetLongField( - wrapper, fid, - ::tflite::timespec_diff_nanoseconds(&beforeInference, &afterInference)); - } - // returns outputs - const std::vector& results = interpreter->outputs(); - if (results.empty()) { - throwException( - env, kIllegalArgumentException, - "Internal error: The Interpreter does not have any outputs."); - return nullptr; - } - jlongArray outputs = env->NewLongArray(results.size()); - size_t size = results.size(); - for (int i = 0; i < size; ++i) { - TfLiteTensor* source = interpreter->tensor(results[i]); - jlong output = reinterpret_cast(source); - env->SetLongArrayRegion(outputs, i, 1, &output); - } - return outputs; -} - -JNIEXPORT jintArray JNICALL -Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims( - JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes) { - tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); - if (interpreter == nullptr) return nullptr; - const int idx = static_cast(input_idx); - if (input_idx < 0 || input_idx >= interpreter->inputs().size()) { - throwException(env, kIllegalArgumentException, - "Input error: Out of range: Failed to get %d-th input out of" - " %d inputs", - input_idx, interpreter->inputs().size()); - return nullptr; - } - TfLiteTensor* target = interpreter->tensor(interpreter->inputs()[idx]); - int size = target->dims->size; - if (num_bytes >= 0) { // verifies num of bytes matches if num_bytes if valid. - int expected_num_bytes = elementByteSize(target->type); - for (int i = 0; i < size; ++i) { - expected_num_bytes *= target->dims->data[i]; - } - if (num_bytes != expected_num_bytes) { - throwException(env, kIllegalArgumentException, - "Input error: Failed to get input dimensions. %d-th input " - "should have %d bytes, but found %d bytes.", - idx, expected_num_bytes, num_bytes); - return nullptr; - } + return; } - jintArray outputs = env->NewIntArray(size); - env->SetIntArrayRegion(outputs, 0, size, &(target->dims->data[0])); - return outputs; } JNIEXPORT jint JNICALL diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h index 128ece49811a112684dac7b36810e920eeeb7351..618fba480e4a1c4a1ff8531cb3fbc29fcb8191d8 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h @@ -29,15 +29,63 @@ limitations under the License. namespace tflite { // This is to be provided at link-time by a library. extern std::unique_ptr CreateOpResolver(); -extern timespec getCurrentTime(); -extern jlong timespec_diff_nanoseconds(struct timespec* start, - struct timespec* stop); } // namespace tflite #ifdef __cplusplus extern "C" { #endif // __cplusplus +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: allocateTensors + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors( + JNIEnv* env, jclass clazz, jlong handle, jlong error_handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: getInputTensor + * Signature: (JI)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensor(JNIEnv* env, + jclass clazz, + jlong handle, + jint index); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: getOutputTensor + * Signature: (JI)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensor(JNIEnv* env, + jclass clazz, + jlong handle, + jint index); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: getInputCount + * Signature: (J)I + */ +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: getOutputCount + * Signature: (J)I + */ +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env, + jclass clazz, + jlong handle); + /* * Class: org_tensorflow_lite_NativeInterpreterWrapper * Method: @@ -118,28 +166,11 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( /* * Class: org_tensorflow_lite_NativeInterpreterWrapper - * Method: - * Signature: - * (JJ[Ljava/lang/Object;[I[I[Ljava/lang/Object;Ljava/lang/Object;Z)[J - */ -JNIEXPORT jlongArray JNICALL -Java_org_tensorflow_lite_NativeInterpreterWrapper_run( - JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, - jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes, - jobjectArray values, jobject wrapper, jboolean memory_allocated); - -/* - * Class: org_tensorflow_lite_NativeInterpreterWrapper - * Method: - * Signature: (JII)[I - * - * Gets input dimensions. If num_bytes is non-negative, it will check whether - * num_bytes matches num of bytes required by the input, and return null and - * throw IllegalArgumentException if not. + * Method: run + * Signature: (JJ)V */ -JNIEXPORT jintArray JNICALL -Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims( - JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes); +JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle); /* * Class: org_tensorflow_lite_NativeInterpreterWrapper diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc index 9e9387da86ebde7d711a7ce967461e370c95bc3e..7ff96a3172dcf020b34fcbe7491c9022fc7f51de 100644 --- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc @@ -29,6 +29,35 @@ TfLiteTensor* convertLongToTensor(JNIEnv* env, jlong handle) { return reinterpret_cast(handle); } +size_t elementByteSize(TfLiteType data_type) { + // The code in this file makes the assumption that the + // TensorFlow TF_DataTypes and the Java primitive types + // have the same byte sizes. Validate that: + switch (data_type) { + case kTfLiteFloat32: + static_assert(sizeof(jfloat) == 4, + "Interal error: Java float not compatible with " + "kTfLiteFloat"); + return 4; + case kTfLiteInt32: + static_assert(sizeof(jint) == 4, + "Interal error: Java int not compatible with kTfLiteInt"); + return 4; + case kTfLiteUInt8: + static_assert(sizeof(jbyte) == 1, + "Interal error: Java byte not compatible with " + "kTfLiteUInt8"); + return 1; + case kTfLiteInt64: + static_assert(sizeof(jlong) == 8, + "Interal error: Java long not compatible with " + "kTfLiteInt64"); + return 8; + default: + return 0; + } +} + size_t writeOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type, void* dst, size_t dst_size) { jarray array = static_cast(object); @@ -141,48 +170,6 @@ size_t readMultiDimensionalArray(JNIEnv* env, TfLiteType data_type, char* src, } } -} // namespace - -size_t elementByteSize(TfLiteType data_type) { - // The code in this file makes the assumption that the - // TensorFlow TF_DataTypes and the Java primitive types - // have the same byte sizes. Validate that: - switch (data_type) { - case kTfLiteFloat32: - static_assert(sizeof(jfloat) == 4, - "Interal error: Java float not compatible with " - "kTfLiteFloat"); - return 4; - case kTfLiteInt32: - static_assert(sizeof(jint) == 4, - "Interal error: Java int not compatible with kTfLiteInt"); - return 4; - case kTfLiteUInt8: - static_assert(sizeof(jbyte) == 1, - "Interal error: Java byte not compatible with " - "kTfLiteUInt8"); - return 1; - case kTfLiteInt64: - static_assert(sizeof(jlong) == 8, - "Interal error: Java long not compatible with " - "kTfLiteInt64"); - return 8; - default: - return 0; - } -} - -size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size) { - char* buf = static_cast(env->GetDirectBufferAddress(object)); - if (!buf) { - throwException(env, kIllegalArgumentException, - "Input ByteBuffer is not a direct buffer"); - return 0; - } - *dst = buf; - return dst_size; -} - size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, int dims_left, char** dst, int dst_size) { if (dims_left <= 1) { @@ -203,6 +190,37 @@ size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, } } +} // namespace + +JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env, + jclass clazz, + jlong handle) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return nullptr; + if (tensor->data.raw == nullptr) { + throwException(env, kIllegalArgumentException, + "Internal error: Tensor hasn't been allocated."); + return nullptr; + } + return env->NewDirectByteBuffer(static_cast(tensor->data.raw), + static_cast(tensor->bytes)); +} + +JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer( + JNIEnv* env, jclass clazz, jlong handle, jobject src) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return; + + char* src_data_raw = static_cast(env->GetDirectBufferAddress(src)); + if (!src_data_raw) { + throwException(env, kIllegalArgumentException, + "Input ByteBuffer is not a direct buffer"); + return; + } + + tensor->data.raw = src_data_raw; +} + JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, jclass clazz, @@ -220,6 +238,27 @@ Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, num_dims, static_cast(value)); } +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env, + jclass clazz, + jlong handle, + jobject src) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return; + if (tensor->data.raw == nullptr) { + throwException(env, kIllegalArgumentException, + "Internal error: Target Tensor hasn't been allocated."); + return; + } + if (tensor->dims->size == 0) { + throwException(env, kIllegalArgumentException, + "Internal error: Cannot copy empty/scalar Tensors."); + return; + } + writeMultiDimensionalArray(env, src, tensor->type, tensor->dims->size, + &tensor->data.raw, tensor->bytes); +} + JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env, jclass clazz, jlong handle) { @@ -237,3 +276,11 @@ Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jclass clazz, jlong handle) { env->SetIntArrayRegion(result, 0, num_dims, tensor->dims->data); return result; } + +JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env, + jclass clazz, + jlong handle) { + const TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return 0; + return static_cast(tensor->bytes); +} diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h index 3a4910dcc3a719fbb9f365dae693423de768349c..06e2546af8400de117ed6923a1d1bd67bcb998e2 100644 --- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h @@ -24,8 +24,25 @@ extern "C" { #endif // __cplusplus /* - * Class: org_tensorflow_lite_TfLiteTensor - * Method: + * Class: org_tensorflow_lite_Tensor + * Method: buffer + * Signature: (J)Ljava/nio/ByteBuffer; + */ +JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_Tensor + * Method: writeDirectBuffer + * Signature: (JLjava/nio/ByteBuffer;) + */ +JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer( + JNIEnv* env, jclass clazz, jlong handle, jobject src); + +/* + * Class: org_tensorflow_lite_Tensor + * Method: dtype * Signature: (J)I */ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env, @@ -33,8 +50,8 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env, jlong handle); /* - * Class: org_tensorflow_lite_TfLiteTensor - * Method: + * Class: org_tensorflow_lite_Tensor + * Method: shape * Signature: (J)[I */ JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, @@ -42,31 +59,35 @@ JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jlong handle); /* - * Class: org_tensorflow_lite_TfLiteTensor - * Method: + * Class: org_tensorflow_lite_Tensor + * Method: numBytes + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_Tensor + * Method: readMultiDimensionalArray * Signature: (JLjava/lang/Object;) */ JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, jclass clazz, jlong handle, - jobject value); + jobject dst); /* - * Finds the size of each data type. - */ -size_t elementByteSize(TfLiteType data_type); - -/* - * Writes data of a ByteBuffer into dest. - */ -size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size); - -/* - * Writes a multi-dimensional array into dest. + * Class: org_tensorflow_lite_Tensor + * Method: writeMultidimensionalArray + * Signature: (JLjava/lang/Object;) */ -size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, - int dims_left, char** dst, int dst_size); +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env, + jclass clazz, + jlong handle, + jobject src); #ifdef __cplusplus } // extern "C" diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index 82007a6ab5be3492495125b1c20ed155907ae5a0..d66a73db94f06776fe2a7310ed0837941aba87c4 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -164,6 +164,24 @@ public final class InterpreterTest { interpreter.close(); } + @Test + public void testRunWithByteBufferOutput() { + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + ByteBuffer parsedOutput = + ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder()); + try (Interpreter interpreter = new Interpreter(MODEL_FILE)) { + interpreter.run(fourD, parsedOutput); + } + float[] outputOneD = { + parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8) + }; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + } + @Test public void testMobilenetRun() { // Create a gray image. @@ -203,7 +221,9 @@ public final class InterpreterTest { assertThat(e) .hasMessageThat() .contains( - "DataType (2) of input data does not match with the DataType (1) of model inputs."); + "Cannot convert between a TensorFlowLite tensor with type " + + "FLOAT32 and a Java object of type [[[[I (which is compatible with the" + + " TensorFlowLite type INT32)"); } interpreter.close(); } @@ -223,8 +243,8 @@ public final class InterpreterTest { assertThat(e) .hasMessageThat() .contains( - "Cannot convert an TensorFlowLite tensor with type " - + "FLOAT32 to a Java object of type [[[[I (which is compatible with the" + "Cannot convert between a TensorFlowLite tensor with type " + + "FLOAT32 and a Java object of type [[[[I (which is compatible with the" + " TensorFlowLite type INT32)"); } interpreter.close(); @@ -311,4 +331,11 @@ public final class InterpreterTest { interpreter.close(); fileChannel.close(); } + + @Test + public void testRedundantClose() throws Exception { + Interpreter interpreter = new Interpreter(MODEL_FILE); + interpreter.close(); + interpreter.close(); + } } diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java index 9e41cb132d8386748e24c46d846e04f158d8b4c6..9c4a5acd797ec3476f44fb203901c9ba0429ab26 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -20,6 +20,8 @@ import static org.junit.Assert.fail; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.HashMap; +import java.util.Map; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -101,16 +103,37 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); float[][][][] parsedOutputs = new float[2][8][8][3]; - outputs[0].copyTo(parsedOutputs); + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); float[] outputOneD = parsedOutputs[0][0][0]; float[] expected = {3.69f, -19.62f, 23.43f}; assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); wrapper.close(); } + @Test + public void testRunWithBufferOutput() { + try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH)) { + float[] oneD = {1.23f, -6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + ByteBuffer parsedOutput = + ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder()); + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutput); + wrapper.run(inputs, outputs); + float[] outputOneD = { + parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8) + }; + float[] expected = {3.69f, -19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + } + } + @Test public void testRunWithInputsOfSameDims() { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); @@ -119,17 +142,16 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); float[][][][] parsedOutputs = new float[2][8][8][3]; - outputs[0].copyTo(parsedOutputs); + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); float[] outputOneD = parsedOutputs[0][0][0]; float[] expected = {3.69f, -19.62f, 23.43f}; assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); - outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); parsedOutputs = new float[2][8][8][3]; - outputs[0].copyTo(parsedOutputs); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); outputOneD = parsedOutputs[0][0][0]; assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); wrapper.close(); @@ -143,10 +165,10 @@ public final class NativeInterpreterWrapperTest { int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; int[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); int[][][][] parsedOutputs = new int[2][4][4][12]; - outputs[0].copyTo(parsedOutputs); + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); int[] outputOneD = parsedOutputs[0][0][0]; int[] expected = {3, 7, -4, 3, 7, -4, 3, 7, -4, 3, 7, -4}; assertThat(outputOneD).isEqualTo(expected); @@ -161,10 +183,10 @@ public final class NativeInterpreterWrapperTest { long[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; long[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); long[][][][] parsedOutputs = new long[2][4][4][12]; - outputs[0].copyTo(parsedOutputs); + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); long[] outputOneD = parsedOutputs[0][0][0]; long[] expected = {-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L}; @@ -182,10 +204,10 @@ public final class NativeInterpreterWrapperTest { Object[] inputs = {fourD}; int[] inputDims = {2, 8, 8, 3}; wrapper.resizeInput(0, inputDims); - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); byte[][][][] parsedOutputs = new byte[2][4][4][12]; - outputs[0].copyTo(parsedOutputs); + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); byte[] outputOneD = parsedOutputs[0][0][0]; byte[] expected = {(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0}; @@ -208,13 +230,14 @@ public final class NativeInterpreterWrapperTest { } } } + bbuf.rewind(); Object[] inputs = {bbuf}; int[] inputDims = {2, 8, 8, 3}; wrapper.resizeInput(0, inputDims); - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); byte[][][][] parsedOutputs = new byte[2][4][4][12]; - outputs[0].copyTo(parsedOutputs); + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); byte[] outputOneD = parsedOutputs[0][0][0]; byte[] expected = { (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, @@ -240,21 +263,22 @@ public final class NativeInterpreterWrapperTest { } } Object[] inputs = {bbuf}; + float[][][][] parsedOutputs = new float[4][8][8][3]; + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { assertThat(e) .hasMessageThat() .contains( - "Failed to get input dimensions. 0-th input should have 768 bytes, but found 3072 bytes"); + "Cannot convert between a TensorFlowLite buffer with 768 bytes and a " + + "ByteBuffer with 3072 bytes."); } int[] inputDims = {4, 8, 8, 3}; wrapper.resizeInput(0, inputDims); - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); - float[][][][] parsedOutputs = new float[4][8][8][3]; - outputs[0].copyTo(parsedOutputs); + wrapper.run(inputs, outputs); float[] outputOneD = parsedOutputs[0][0][0]; float[] expected = {3.69f, -19.62f, 23.43f}; assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); @@ -267,14 +291,18 @@ public final class NativeInterpreterWrapperTest { ByteBuffer bbuf = ByteBuffer.allocateDirect(2 * 7 * 8 * 3); bbuf.order(ByteOrder.nativeOrder()); Object[] inputs = {bbuf}; + Map outputs = new HashMap<>(); + ByteBuffer parsedOutput = ByteBuffer.allocateDirect(2 * 7 * 8 * 3); + outputs.put(0, parsedOutput); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { assertThat(e) .hasMessageThat() .contains( - "Failed to get input dimensions. 0-th input should have 192 bytes, but found 336 bytes."); + "Cannot convert between a TensorFlowLite buffer with 192 bytes and a " + + "ByteBuffer with 336 bytes."); } wrapper.close(); } @@ -287,14 +315,18 @@ public final class NativeInterpreterWrapperTest { int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; int[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; + int[][][][] parsedOutputs = new int[2][8][8][3]; + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { assertThat(e) .hasMessageThat() .contains( - "DataType (2) of input data does not match with the DataType (1) of model inputs."); + "Cannot convert between a TensorFlowLite tensor with type FLOAT32 and a Java object " + + "of type [[[[I (which is compatible with the TensorFlowLite type INT32)"); } wrapper.close(); } @@ -308,8 +340,11 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { assertThat(e).hasMessageThat().contains("Invalid handle to Interpreter."); @@ -321,7 +356,7 @@ public final class NativeInterpreterWrapperTest { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); try { Object[] inputs = {}; - wrapper.run(inputs); + wrapper.run(inputs, null); fail(); } catch (IllegalArgumentException e) { assertThat(e).hasMessageThat().contains("Inputs should not be null or empty."); @@ -337,11 +372,14 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD, fourD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().contains("Expected num of inputs is 1 but got 2"); + assertThat(e).hasMessageThat().contains("Invalid input Tensor index: 1"); } wrapper.close(); } @@ -353,13 +391,18 @@ public final class NativeInterpreterWrapperTest { float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD}; float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; Object[] inputs = {threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { assertThat(e) .hasMessageThat() - .contains("0-th input should have 4 dimensions, but found 3 dimensions"); + .contains( + "Cannot copy between a TensorFlowLite tensor with shape [8, 7, 3] and a " + + "Java object with shape [2, 8, 8, 3]."); } wrapper.close(); } @@ -372,91 +415,22 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { assertThat(e) .hasMessageThat() - .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]"); + .contains( + "Cannot copy between a TensorFlowLite tensor with shape [2, 8, 7, 3] and a " + + "Java object with shape [2, 8, 8, 3]."); } wrapper.close(); } - @Test - public void testNumElements() { - int[] shape = {2, 3, 4}; - int num = NativeInterpreterWrapper.numElements(shape); - assertThat(num).isEqualTo(24); - shape = null; - num = NativeInterpreterWrapper.numElements(shape); - assertThat(num).isEqualTo(0); - } - - @Test - public void testIsNonEmtpyArray() { - assertThat(NativeInterpreterWrapper.isNonEmptyArray(null)).isFalse(); - assertThat(NativeInterpreterWrapper.isNonEmptyArray(3.2)).isFalse(); - int[] emptyArray = {}; - assertThat(NativeInterpreterWrapper.isNonEmptyArray(emptyArray)).isFalse(); - int[] validArray = {9, 5, 2, 1}; - assertThat(NativeInterpreterWrapper.isNonEmptyArray(validArray)).isTrue(); - } - - @Test - public void testDataTypeOf() { - float[] testEmtpyArray = {}; - DataType dataType = NativeInterpreterWrapper.dataTypeOf(testEmtpyArray); - assertThat(dataType).isEqualTo(DataType.FLOAT32); - float[] testFloatArray = {0.783f, 0.251f}; - dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray); - assertThat(dataType).isEqualTo(DataType.FLOAT32); - float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray}; - dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray); - assertThat(dataType).isEqualTo(DataType.FLOAT32); - try { - double[] testDoubleArray = {0.783, 0.251}; - NativeInterpreterWrapper.dataTypeOf(testDoubleArray); - fail(); - } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().contains("cannot resolve DataType of"); - } - try { - Float[] testBoxedArray = {0.783f, 0.251f}; - NativeInterpreterWrapper.dataTypeOf(testBoxedArray); - fail(); - } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;"); - } - } - - @Test - public void testNumDimensions() { - int scalar = 1; - assertThat(NativeInterpreterWrapper.numDimensions(scalar)).isEqualTo(0); - int[][] array = {{2, 4}, {1, 9}}; - assertThat(NativeInterpreterWrapper.numDimensions(array)).isEqualTo(2); - try { - int[] emptyArray = {}; - NativeInterpreterWrapper.numDimensions(emptyArray); - fail(); - } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().contains("Array lengths cannot be 0."); - } - } - - @Test - public void testFillShape() { - int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}}; - int num = NativeInterpreterWrapper.numDimensions(array); - int[] shape = new int[num]; - NativeInterpreterWrapper.fillShape(array, 0, shape); - assertThat(num).isEqualTo(3); - assertThat(shape[0]).isEqualTo(2); - assertThat(shape[1]).isEqualTo(3); - assertThat(shape[2]).isEqualTo(1); - } - @Test public void testGetInferenceLatency() { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); @@ -465,8 +439,10 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; - Tensor[] outputs = wrapper.run(inputs); - assertThat(outputs.length).isEqualTo(1); + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + wrapper.run(inputs, outputs); assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isGreaterThan(0L); wrapper.close(); } @@ -486,13 +462,14 @@ public final class NativeInterpreterWrapperTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); try { - wrapper.run(inputs); + wrapper.run(inputs, outputs); fail(); } catch (IllegalArgumentException e) { - assertThat(e) - .hasMessageThat() - .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]"); + // Expected. } assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isNull(); wrapper.close(); @@ -502,41 +479,7 @@ public final class NativeInterpreterWrapperTest { public void testGetInputDims() { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); int[] expectedDims = {1, 8, 8, 3}; - assertThat(wrapper.getInputDims(0)).isEqualTo(expectedDims); - wrapper.close(); - } - - @Test - public void testGetInputDimsOutOfRange() { - NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); - try { - wrapper.getInputDims(-1); - fail(); - } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().contains("Out of range"); - } - try { - wrapper.getInputDims(1); - fail(); - } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().contains("Out of range"); - } - wrapper.close(); - } - - @Test - public void testGetOutputDataType() { - NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); - assertThat(wrapper.getOutputDataType(0)).contains("float"); - wrapper.close(); - wrapper = new NativeInterpreterWrapper(LONG_MODEL_PATH); - assertThat(wrapper.getOutputDataType(0)).contains("long"); - wrapper.close(); - wrapper = new NativeInterpreterWrapper(INT_MODEL_PATH); - assertThat(wrapper.getOutputDataType(0)).contains("int"); - wrapper.close(); - wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH); - assertThat(wrapper.getOutputDataType(0)).contains("byte"); + assertThat(wrapper.getInputTensor(0).shape()).isEqualTo(expectedDims); wrapper.close(); } diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java index 94b6632bb8dd7117bf4074da1939bd23ce732efd..71ef04494357e8b951cbbbd2c68385b17c472736 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java @@ -18,6 +18,10 @@ package org.tensorflow.lite; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.HashMap; +import java.util.Map; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -32,7 +36,7 @@ public final class TensorTest { "tensorflow/contrib/lite/java/src/testdata/add.bin"; private NativeInterpreterWrapper wrapper; - private long nativeHandle; + private Tensor tensor; @Before public void setUp() { @@ -42,8 +46,10 @@ public final class TensorTest { float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; float[][][][] fourD = {threeD, threeD}; Object[] inputs = {fourD}; - Tensor[] outputs = wrapper.run(inputs); - nativeHandle = outputs[0].nativeHandle; + Map outputs = new HashMap<>(); + outputs.put(0, new float[2][8][8][3]); + wrapper.run(inputs, outputs); + tensor = wrapper.getOutputTensor(0); } @After @@ -52,17 +58,16 @@ public final class TensorTest { } @Test - public void testFromHandle() throws Exception { - Tensor tensor = Tensor.fromHandle(nativeHandle); + public void testBasic() throws Exception { assertThat(tensor).isNotNull(); int[] expectedShape = {2, 8, 8, 3}; - assertThat(tensor.shapeCopy).isEqualTo(expectedShape); - assertThat(tensor.dtype).isEqualTo(DataType.FLOAT32); + assertThat(tensor.shape()).isEqualTo(expectedShape); + assertThat(tensor.dataType()).isEqualTo(DataType.FLOAT32); + assertThat(tensor.numBytes()).isEqualTo(2 * 8 * 8 * 3 * 4); } @Test public void testCopyTo() { - Tensor tensor = Tensor.fromHandle(nativeHandle); float[][][][] parsedOutputs = new float[2][8][8][3]; tensor.copyTo(parsedOutputs); float[] outputOneD = parsedOutputs[0][0][0]; @@ -70,9 +75,32 @@ public final class TensorTest { assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); } + @Test + public void testCopyToByteBuffer() { + ByteBuffer parsedOutput = + ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder()); + tensor.copyTo(parsedOutput); + assertThat(parsedOutput.position()).isEqualTo(2 * 8 * 8 * 3 * 4); + float[] outputOneD = { + parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8) + }; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + } + + @Test + public void testCopyToInvalidByteBuffer() { + ByteBuffer parsedOutput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder()); + try { + tensor.copyTo(parsedOutput); + fail(); + } catch (IllegalArgumentException e) { + // Expected. + } + } + @Test public void testCopyToWrongType() { - Tensor tensor = Tensor.fromHandle(nativeHandle); int[][][][] parsedOutputs = new int[2][8][8][3]; try { tensor.copyTo(parsedOutputs); @@ -81,15 +109,13 @@ public final class TensorTest { assertThat(e) .hasMessageThat() .contains( - "Cannot convert an TensorFlowLite tensor with type " - + "FLOAT32 to a Java object of type [[[[I (which is compatible with the TensorFlowLite " - + "type INT32)"); + "Cannot convert between a TensorFlowLite tensor with type FLOAT32 and a Java object " + + "of type [[[[I (which is compatible with the TensorFlowLite type INT32)"); } } @Test public void testCopyToWrongShape() { - Tensor tensor = Tensor.fromHandle(nativeHandle); float[][][][] parsedOutputs = new float[1][8][8][3]; try { tensor.copyTo(parsedOutputs); @@ -98,8 +124,104 @@ public final class TensorTest { assertThat(e) .hasMessageThat() .contains( - "Shape of output target [1, 8, 8, 3] does not match " - + "with the shape of the Tensor [2, 8, 8, 3]."); + "Cannot copy between a TensorFlowLite tensor with shape [2, 8, 8, 3] " + + "and a Java object with shape [1, 8, 8, 3]."); + } + } + + @Test + public void testSetTo() { + float[][][][] input = new float[2][8][8][3]; + float[][][][] output = new float[2][8][8][3]; + ByteBuffer inputByteBuffer = + ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder()); + + input[0][0][0][0] = 2.0f; + tensor.setTo(input); + tensor.copyTo(output); + assertThat(output[0][0][0][0]).isEqualTo(2.0f); + + inputByteBuffer.putFloat(0, 3.0f); + tensor.setTo(inputByteBuffer); + tensor.copyTo(output); + assertThat(output[0][0][0][0]).isEqualTo(3.0f); + } + + @Test + public void testSetToInvalidByteBuffer() { + ByteBuffer input = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder()); + try { + tensor.setTo(input); + fail(); + } catch (IllegalArgumentException e) { + // Success. + } + } + + @Test + public void testGetInputShapeIfDifferent() { + ByteBuffer bytBufferInput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder()); + assertThat(tensor.getInputShapeIfDifferent(bytBufferInput)).isNull(); + + float[][][][] sameShapeInput = new float[2][8][8][3]; + assertThat(tensor.getInputShapeIfDifferent(sameShapeInput)).isNull(); + + float[][][][] differentShapeInput = new float[1][8][8][3]; + assertThat(tensor.getInputShapeIfDifferent(differentShapeInput)) + .isEqualTo(new int[] {1, 8, 8, 3}); + } + + @Test + public void testDataTypeOf() { + float[] testEmptyArray = {}; + DataType dataType = Tensor.dataTypeOf(testEmptyArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + float[] testFloatArray = {0.783f, 0.251f}; + dataType = Tensor.dataTypeOf(testFloatArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray}; + dataType = Tensor.dataTypeOf(testFloatArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + try { + double[] testDoubleArray = {0.783, 0.251}; + Tensor.dataTypeOf(testDoubleArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("cannot resolve DataType of"); } + try { + Float[] testBoxedArray = {0.783f, 0.251f}; + Tensor.dataTypeOf(testBoxedArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;"); + } + } + + @Test + public void testNumDimensions() { + int scalar = 1; + assertThat(Tensor.numDimensions(scalar)).isEqualTo(0); + int[][] array = {{2, 4}, {1, 9}}; + assertThat(Tensor.numDimensions(array)).isEqualTo(2); + try { + int[] emptyArray = {}; + Tensor.numDimensions(emptyArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("Array lengths cannot be 0."); + } + } + + @Test + public void testFillShape() { + int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}}; + int num = Tensor.numDimensions(array); + int[] shape = new int[num]; + Tensor.fillShape(array, 0, shape); + assertThat(num).isEqualTo(3); + assertThat(shape[0]).isEqualTo(2); + assertThat(shape[1]).isEqualTo(3); + assertThat(shape[2]).isEqualTo(1); } } diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java index 3aef0c3bb6cc4748de0e55d31f0215a77320ae69..c23521c0774ebab01f38db8b416020ae5755cee9 100644 --- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java +++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java @@ -58,7 +58,7 @@ public class TestHelper { */ public static int[] getInputDims(Interpreter interpreter, int index) { if (interpreter != null && interpreter.wrapper != null) { - return interpreter.wrapper.getInputDims(index); + return interpreter.wrapper.getInputTensor(index).shape(); } else { throw new IllegalArgumentException( "Interpreter has not initialized;" + " Failed to get input dimensions."); @@ -77,7 +77,7 @@ public class TestHelper { */ public static String getOutputDataType(Interpreter interpreter, int index) { if (interpreter != null && interpreter.wrapper != null) { - return interpreter.wrapper.getOutputDataType(index); + return interpreter.wrapper.getOutputTensor(index).dataType().toStringName(); } else { throw new IllegalArgumentException( "Interpreter has not initialized;" + " Failed to get output data type."); diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index 61d5af3478474f006fe50cbbc9d2749127086c51..c5586475ec258849948ff6b960abc846e2ea1b3c 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -8,11 +8,27 @@ load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +# Suppress warnings that are introduced by Eigen Tensor. +EXTRA_EIGEN_COPTS = select({ + "//tensorflow:ios": [ + "-Wno-error=invalid-partial-specialization", + "-Wno-error=reorder", + ], + "//tensorflow:windows": [ + "/DEIGEN_HAS_C99_MATH", + "/DEIGEN_AVOID_STL_ARRAY", + ], + "//conditions:default": ["-Wno-error=reorder"], +}) + tf_cc_test( name = "optional_tensor_test", size = "small", srcs = ["optional_tensor_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -46,11 +62,12 @@ cc_library( hdrs = [ "eigen_support.h", ], - copts = tflite_copts(), + copts = tflite_copts() + EXTRA_EIGEN_COPTS, deps = [ ":op_macros", + "//tensorflow/contrib/lite:arena_planner", "//tensorflow/contrib/lite:context", - "//third_party/eigen3", + "//tensorflow/contrib/lite/kernels/internal:optimized", ], ) @@ -106,7 +123,10 @@ tf_cc_test( name = "kernel_util_test", size = "small", srcs = ["kernel_util_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":kernel_util", "//tensorflow/contrib/lite/testing:util", @@ -118,6 +138,7 @@ tf_cc_test( name = "test_util_test", size = "small", srcs = ["test_util_test.cc"], + tags = ["no_oss"], deps = [ ":test_util", "//tensorflow/contrib/lite/testing:util", @@ -130,7 +151,7 @@ cc_library( srcs = [ "activations.cc", "add.cc", - "arg_max.cc", + "arg_min_max.cc", "audio_spectrogram.cc", "basic_rnn.cc", "batch_to_space_nd.cc", @@ -149,18 +170,22 @@ cc_library( "embedding_lookup_sparse.cc", "exp.cc", "expand_dims.cc", + "fake_quant.cc", "floor.cc", "fully_connected.cc", "gather.cc", "hashtable_lookup.cc", "l2norm.cc", "local_response_norm.cc", + "logical.cc", "lsh_projection.cc", "lstm.cc", "maximum_minimum.cc", "mfcc.cc", "mul.cc", "neg.cc", + "one_hot.cc", + "pack.cc", "pad.cc", "pooling.cc", "pow.cc", @@ -191,14 +216,7 @@ cc_library( "padding.h", "register.h", ], - # Suppress warnings that are introduced by Eigen Tensor. - copts = tflite_copts() + [ - "-Wno-error=reorder", - ] + select({ - "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"], - "//conditions:default": [ - ], - }), + copts = tflite_copts() + EXTRA_EIGEN_COPTS, deps = [ ":activation_functor", ":eigen_support", @@ -225,7 +243,10 @@ tf_cc_test( name = "audio_spectrogram_test", size = "small", srcs = ["audio_spectrogram_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -239,7 +260,10 @@ tf_cc_test( name = "mfcc_test", size = "small", srcs = ["mfcc_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -253,7 +277,10 @@ tf_cc_test( name = "detection_postprocess_test", size = "small", srcs = ["detection_postprocess_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -290,10 +317,11 @@ tf_cc_test( ) tf_cc_test( - name = "arg_max_test", + name = "arg_min_max_test", size = "small", - srcs = ["arg_max_test.cc"], + srcs = ["arg_min_max_test.cc"], tags = [ + "no_oss", "tflite_not_portable_ios", ], deps = [ @@ -308,7 +336,10 @@ tf_cc_test( name = "div_test", size = "small", srcs = ["div_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -321,7 +352,10 @@ tf_cc_test( name = "sub_test", size = "small", srcs = ["sub_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -334,7 +368,10 @@ tf_cc_test( name = "transpose_test", size = "small", srcs = ["transpose_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -349,7 +386,10 @@ tf_cc_test( name = "space_to_batch_nd_test", size = "small", srcs = ["space_to_batch_nd_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -362,7 +402,10 @@ tf_cc_test( name = "batch_to_space_nd_test", size = "small", srcs = ["batch_to_space_nd_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -375,7 +418,10 @@ tf_cc_test( name = "cast_test", size = "small", srcs = ["cast_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -428,7 +474,10 @@ tf_cc_test( name = "dequantize_test", size = "small", srcs = ["dequantize_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -455,7 +504,10 @@ tf_cc_test( name = "bidirectional_sequence_lstm_test", size = "small", srcs = ["bidirectional_sequence_lstm_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -468,7 +520,10 @@ tf_cc_test( name = "floor_test", size = "small", srcs = ["floor_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -481,7 +536,10 @@ tf_cc_test( name = "elementwise_test", size = "small", srcs = ["elementwise_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -494,7 +552,10 @@ tf_cc_test( name = "unidirectional_sequence_lstm_test", size = "small", srcs = ["unidirectional_sequence_lstm_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -508,6 +569,7 @@ tf_cc_test( size = "small", srcs = ["bidirectional_sequence_rnn_test.cc"], tags = [ + "no_oss", "tflite_not_portable", ], deps = [ @@ -522,7 +584,10 @@ tf_cc_test( name = "unidirectional_sequence_rnn_test", size = "small", srcs = ["unidirectional_sequence_rnn_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -548,7 +613,26 @@ tf_cc_test( name = "exp_test", size = "small", srcs = ["exp_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "fake_quant_test", + size = "small", + srcs = ["fake_quant_test.cc"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -561,7 +645,10 @@ tf_cc_test( name = "maximum_minimum_test", size = "small", srcs = ["maximum_minimum_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -574,7 +661,10 @@ tf_cc_test( name = "reduce_test", size = "small", srcs = ["reduce_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -600,7 +690,10 @@ tf_cc_test( name = "pad_test", size = "small", srcs = ["pad_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -626,7 +719,10 @@ tf_cc_test( name = "gather_test", size = "small", srcs = ["gather_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:builtin_op_data", @@ -640,7 +736,10 @@ tf_cc_test( name = "topk_v2_test", size = "small", srcs = ["topk_v2_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:builtin_op_data", @@ -761,7 +860,10 @@ tf_cc_test( name = "log_softmax_test", size = "small", srcs = ["log_softmax_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -842,7 +944,10 @@ tf_cc_test( name = "split_test", size = "small", srcs = ["split_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -855,7 +960,10 @@ tf_cc_test( name = "squeeze_test", size = "small", srcs = ["squeeze_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -868,7 +976,10 @@ tf_cc_test( name = "strided_slice_test", size = "small", srcs = ["strided_slice_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -881,7 +992,10 @@ tf_cc_test( name = "tile_test", size = "small", srcs = ["tile_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:builtin_op_data", @@ -898,6 +1012,7 @@ tf_cc_test( "comparisons_test.cc", ], tags = [ + "no_oss", "tflite_not_portable_ios", ], deps = [ @@ -912,7 +1027,10 @@ tf_cc_test( name = "neg_test", size = "small", srcs = ["neg_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -928,6 +1046,7 @@ tf_cc_test( "select_test.cc", ], tags = [ + "no_oss", "tflite_not_portable_ios", ], deps = [ @@ -945,6 +1064,7 @@ tf_cc_test( "slice_test.cc", ], tags = [ + "no_oss", "tflite_not_portable_ios", ], deps = [ @@ -959,7 +1079,10 @@ tf_cc_test( name = "transpose_conv_test", size = "small", srcs = ["transpose_conv_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:framework", @@ -972,7 +1095,10 @@ tf_cc_test( name = "expand_dims_test", size = "small", srcs = ["expand_dims_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:builtin_op_data", @@ -986,7 +1112,10 @@ tf_cc_test( name = "sparse_to_dense_test", size = "small", srcs = ["sparse_to_dense_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:builtin_op_data", @@ -1000,7 +1129,10 @@ tf_cc_test( name = "shape_test", size = "small", srcs = ["shape_test.cc"], - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":builtin_ops", "//tensorflow/contrib/lite:builtin_op_data", @@ -1014,6 +1146,50 @@ tf_cc_test( name = "pow_test", size = "small", srcs = ["pow_test.cc"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "pack_test", + size = "small", + srcs = ["pack_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "one_hot_test", + size = "small", + srcs = ["one_hot_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "logical_test", + size = "small", + srcs = ["logical_test.cc"], tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc index 99f81c4a8a78ab0b2a24955d77f25ed09da13b84..817266a47147980699a348a5c26ed637828e80c6 100644 --- a/tensorflow/contrib/lite/kernels/activations.cc +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -186,8 +185,8 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, 0); TF_LITE_ENSURE_EQ(context, input->type, output->type); - TF_LITE_ENSURE(context, - NumDimensions(input) == 2 || NumDimensions(input) == 4); + const int num_dims = NumDimensions(input); + TF_LITE_ENSURE(context, num_dims == 1 || num_dims == 2 || num_dims == 4); if (input->type == kTfLiteUInt8) { TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); @@ -213,25 +212,25 @@ TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, 0); const TfLiteTensor* alpha = GetInput(context, node, 1); - output->type = input->type; - // Currently only Float32 is supported // TODO(ycling): Support other data types. TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TF_LITE_ENSURE_EQ(context, alpha->type, kTfLiteFloat32); + output->type = input->type; - // Currently, only support 4D `input` and 3D `alpha` with shape - // (1, 1, channels). - // TODO(impjdi): Support other cases where `alpha` is broadcastable - // to `input`. - TF_LITE_ENSURE_EQ(context, input->dims->size, 4); - TF_LITE_ENSURE_EQ(context, alpha->dims->size, 3); - TF_LITE_ENSURE_EQ(context, alpha->dims->data[0], 1); - TF_LITE_ENSURE_EQ(context, alpha->dims->data[1], 1); - TF_LITE_ENSURE_EQ(context, alpha->dims->data[2], input->dims->data[3]); + // PRelu (parameteric Relu) shares the same alpha value on "shared axis". + // This means it's always required to "broadcast" alpha values in PRelu. + TfLiteIntArray* output_size = nullptr; + TF_LITE_ENSURE_OK( + context, CalculateShapeForBroadcast(context, input, alpha, &output_size)); - return context->ResizeTensor(context, output, - TfLiteIntArrayCopy(input->dims)); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size)); + // After broadcasting, the output shape should always be the same as the + // input shape. + TF_LITE_ENSURE(context, HaveSameShapes(input, output)); + + return kTfLiteOk; } TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) { @@ -365,13 +364,9 @@ TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -// Takes a 2D tensor and perform softmax along the second dimension. -void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output, - TfLiteSoftmaxParams* params) { - const int batch_size = input->dims->data[0]; - const int input_size = input->dims->data[1]; - float* in = input->data.f; - float* out = output->data.f; +// Performs softmax along the input of size (input_size * batch_size). +void Softmax(const float* in, const int input_size, const int batch_size, + const float beta, float* out) { TF_LITE_ASSERT(input_size > 0); // For each batch @@ -385,7 +380,7 @@ void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output, // Compute the normalized sum of exps. float exp_sum = 0.0; for (int i = 0; i < input_size; i++) { - out[i] = std::exp((in[i] - max_coeff) * params->beta); + out[i] = std::exp((in[i] - max_coeff) * beta); exp_sum += out[i]; } @@ -401,6 +396,33 @@ void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output, } } +// Takes a 1D tensor and performs softmax along it. +void Softmax1DFloat(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params) { + const int input_size = input->dims->data[0]; + Softmax(input->data.f, input_size, 1, params->beta, output->data.f); +} + +// Takes a 2D tensor and perform softmax along the last dimension. +void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params) { + const int batch_size = input->dims->data[0]; + const int input_size = input->dims->data[1]; + Softmax(input->data.f, input_size, batch_size, params->beta, output->data.f); +} + +void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params, OpData* data) { + // TODO(ahentz): this is arguably a dirty trick. Since the implementation + // always traverses the last dimension of a 4D tensor, we will pretend our 1D + // tensor is 4D in a special way. We will convert a (Y) shape into a (1, + // 1, 1, Y) shape. + const int input_size = input->dims->data[0]; + optimized_ops::Softmax( + GetTensorData(input), GetTensorShape({1, 1, 1, input_size}), + data->input_multiplier, data->input_left_shift, data->diff_min, + GetTensorData(output), GetTensorShape({1, 1, 1, input_size})); +} void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output, TfLiteSoftmaxParams* params, OpData* data) { // TODO(ahentz): this is arguably a dirty trick. Since the implementation @@ -443,6 +465,10 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { // dimensions. switch (input->type) { case kTfLiteFloat32: { + if (NumDimensions(input) == 1) { + Softmax1DFloat(input, output, params); + return kTfLiteOk; + } if (NumDimensions(input) == 2) { Softmax2DFloat(input, output, params); return kTfLiteOk; @@ -452,11 +478,15 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } context->ReportError( - context, "Only 2D and 4D tensors supported currently, got %dD.", + context, "Only 1D, 2D and 4D tensors supported currently, got %dD.", NumDimensions(input)); return kTfLiteError; } case kTfLiteUInt8: { + if (NumDimensions(input) == 1) { + Softmax1DQuantized(input, output, params, data); + return kTfLiteOk; + } if (NumDimensions(input) == 2) { Softmax2DQuantized(input, output, params, data); return kTfLiteOk; @@ -494,33 +524,24 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) { } } +template +T ApplyPrelu(T input, T alpha) { + return input >= 0.0 ? input : input * alpha; +} + TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input = GetInput(context, node, 0); const TfLiteTensor* alpha = GetInput(context, node, 1); - const TfLiteTensor* output = GetOutput(context, node, 0); - + TfLiteTensor* output = GetOutput(context, node, 0); if (input->type != kTfLiteFloat32) { context->ReportError(context, "Only float32 supported currently, got %d.", input->type); return kTfLiteError; } - TF_LITE_ENSURE_EQ(context, input->dims->size, 4); - const int batches = input->dims->data[0]; - const int height = input->dims->data[1]; - const int width = input->dims->data[2]; - const int channels = input->dims->data[3]; - - TF_LITE_ENSURE_EQ(context, alpha->dims->size, 3); - TF_LITE_ENSURE_EQ(context, alpha->dims->data[0], 1); - TF_LITE_ENSURE_EQ(context, alpha->dims->data[1], 1); - TF_LITE_ENSURE_EQ(context, alpha->dims->data[2], channels); - - const int n = batches * height * width * channels; - for (int i = 0; i < n; ++i) { - const float x = input->data.f[i]; - output->data.f[i] = x >= 0.0f ? x : alpha->data.f[i % channels] * x; - } - + reference_ops::BroadcastBinaryFunction( + GetTensorData(input), GetTensorDims(input), + GetTensorData(alpha), GetTensorDims(alpha), + GetTensorData(output), GetTensorDims(output), ApplyPrelu); return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc index 587e1303da6afed1fc711100f457f1bf62b0b7e1..083cdf78d76991b89c4c2caf03dcb6db404a2578 100644 --- a/tensorflow/contrib/lite/kernels/activations_test.cc +++ b/tensorflow/contrib/lite/kernels/activations_test.cc @@ -339,6 +339,29 @@ TEST(QuantizedActivationsOpTest, Softmax4D) { kQuantizedTolerance))); } +TEST(FloatActivationsOpTest, Softmax1D) { + FloatActivationsOpModel m(0.1, + /*input=*/{TensorType_FLOAT32, {8}}); + m.SetInput({0, -6, 2, 4, 3, -2, 10, 1}); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {.09752, .05352, .11911, .14548, .13164, .07984, .26509, .10778}))); +} + +TEST(QuantizedActivationsOpTest, Softmax1D) { + QuantizedActivationsOpModel m(0.1, + /*input=*/{TensorType_UINT8, {8}, -10, 10}); + m.SetInput({0, -6, 2, 4, 3, -2, 10, 1}); + m.Invoke(); + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({0.09766, 0.05469, 0.12109, 0.14453, + 0.13281, 0.07813, 0.26563, 0.10938}, + kQuantizedTolerance))); +} + TEST(FloatActivationsOpTest, Softmax2D) { FloatActivationsOpModel m(0.1, /*input=*/{TensorType_FLOAT32, {2, 4}}); diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc index f44d531cbfa9ed41f881380752558555aab97b4d..af9b5c7013afc5d32d01cba07492a282727b3e12 100644 --- a/tensorflow/contrib/lite/kernels/add.cc +++ b/tensorflow/contrib/lite/kernels/add.cc @@ -110,15 +110,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { QuantizeMultiplierSmallerThanOneExp( real_input1_multiplier, &data->input1_multiplier, &data->input1_shift); - data->input1_shift *= -1; QuantizeMultiplierSmallerThanOneExp( real_input2_multiplier, &data->input2_multiplier, &data->input2_shift); - data->input2_shift *= -1; QuantizeMultiplierSmallerThanOneExp( real_output_multiplier, &data->output_multiplier, &data->output_shift); - data->output_shift *= -1; CalculateActivationRangeUint8(params->activation, output, &data->output_activation_min, @@ -152,14 +149,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { CheckedLog2(output->params.scale, &output_scale_log2_rounded); TF_LITE_ENSURE(context, output_scale_is_pot); - data->input1_shift = output_scale_log2_rounded - input1_scale_log2_rounded; - data->input2_shift = output_scale_log2_rounded - input2_scale_log2_rounded; + data->input1_shift = input1_scale_log2_rounded - output_scale_log2_rounded; + data->input2_shift = input2_scale_log2_rounded - output_scale_log2_rounded; // Shifting of one input is supported. The graph quantization should ensure // that the other input matches the output. TF_LITE_ENSURE(context, data->input1_shift == 0 || data->input2_shift == 0); - TF_LITE_ENSURE(context, data->input1_shift >= 0); - TF_LITE_ENSURE(context, data->input2_shift >= 0); + TF_LITE_ENSURE(context, data->input1_shift <= 0); + TF_LITE_ENSURE(context, data->input2_shift <= 0); CalculateActivationRangeQuantized(context, params->activation, output, &data->output_activation_min, @@ -173,24 +170,27 @@ template void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params, const OpData* data, const TfLiteTensor* input1, const TfLiteTensor* input2, TfLiteTensor* output) { -#define TF_LITE_ADD(type, opname, data_type) \ - data_type output_activation_min, output_activation_max; \ - CalculateActivationRange(params->activation, &output_activation_min, \ - &output_activation_max); \ - type::opname(GetTensorData(input1), GetTensorDims(input1), \ - GetTensorData(input2), GetTensorDims(input2), \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)) +#define TF_LITE_ADD(type, opname, data_type) \ + data_type output_activation_min, output_activation_max; \ + CalculateActivationRange(params->activation, &output_activation_min, \ + &output_activation_max); \ + tflite::ArithmeticParams op_params; \ + SetActivationParams(output_activation_min, output_activation_max, \ + &op_params); \ + type::opname(op_params, GetTensorShape(input1), \ + GetTensorData(input1), GetTensorShape(input2), \ + GetTensorData(input2), GetTensorShape(output), \ + GetTensorData(output)) if (output->type == kTfLiteInt32) { if (kernel_type == kReference) { if (data->requires_broadcast) { - TF_LITE_ADD(reference_ops, BroadcastAdd, int32_t); + TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, int32_t); } else { TF_LITE_ADD(reference_ops, Add, int32_t); } } else { if (data->requires_broadcast) { - TF_LITE_ADD(optimized_ops, BroadcastAdd, int32_t); + TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, int32_t); } else { TF_LITE_ADD(optimized_ops, Add, int32_t); } @@ -198,13 +198,13 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params, } else if (output->type == kTfLiteFloat32) { if (kernel_type == kReference) { if (data->requires_broadcast) { - TF_LITE_ADD(reference_ops, BroadcastAdd, float); + TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow, float); } else { TF_LITE_ADD(reference_ops, Add, float); } } else { if (data->requires_broadcast) { - TF_LITE_ADD(optimized_ops, BroadcastAdd, float); + TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, float); } else { TF_LITE_ADD(optimized_ops, Add, float); } @@ -220,30 +220,43 @@ TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input2, TfLiteTensor* output) { if (output->type == kTfLiteUInt8) { -#define TF_LITE_ADD(type, opname) \ - type::opname( \ - data->left_shift, GetTensorData(input1), GetTensorDims(input1), \ - data->input1_offset, data->input1_multiplier, data->input1_shift, \ - GetTensorData(input2), GetTensorDims(input2), \ - data->input2_offset, data->input2_multiplier, data->input2_shift, \ - data->output_offset, data->output_multiplier, data->output_shift, \ - data->output_activation_min, data->output_activation_max, \ - GetTensorData(output), GetTensorDims(output)); +#define TF_LITE_ADD(type, opname) \ + tflite::ArithmeticParams op_params; \ + op_params.left_shift = data->left_shift; \ + op_params.input1_offset = data->input1_offset; \ + op_params.input1_multiplier = data->input1_multiplier; \ + op_params.input1_shift = data->input1_shift; \ + op_params.input2_offset = data->input2_offset; \ + op_params.input2_multiplier = data->input2_multiplier; \ + op_params.input2_shift = data->input2_shift; \ + op_params.output_offset = data->output_offset; \ + op_params.output_multiplier = data->output_multiplier; \ + op_params.output_shift = data->output_shift; \ + SetActivationParams(data->output_activation_min, \ + data->output_activation_max, &op_params); \ + type::opname(op_params, GetTensorShape(input1), \ + GetTensorData(input1), GetTensorShape(input2), \ + GetTensorData(input2), GetTensorShape(output), \ + GetTensorData(output)) // The quantized version of Add doesn't support activations, so we // always use BroadcastAdd. if (kernel_type == kReference) { - TF_LITE_ADD(reference_ops, BroadcastAdd); + TF_LITE_ADD(reference_ops, BroadcastAdd4DSlow); } else { - TF_LITE_ADD(optimized_ops, BroadcastAdd); + TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow); } #undef TF_LITE_ADD } else if (output->type == kTfLiteInt16) { -#define TF_LITE_ADD(type, opname) \ - type::opname(GetTensorData(input1), GetTensorDims(input1), \ - data->input1_shift, GetTensorData(input2), \ - GetTensorDims(input2), data->input2_shift, \ - data->output_activation_min, data->output_activation_max, \ - GetTensorData(output), GetTensorDims(output)); +#define TF_LITE_ADD(type, opname) \ + tflite::ArithmeticParams op_params; \ + op_params.input1_shift = data->input1_shift; \ + op_params.input2_shift = data->input2_shift; \ + SetActivationParams(data->output_activation_min, \ + data->output_activation_max, &op_params); \ + type::opname(op_params, GetTensorShape(input1), \ + GetTensorData(input1), GetTensorShape(input2), \ + GetTensorData(input2), GetTensorShape(output), \ + GetTensorData(output)) // The quantized version of Add doesn't support activations, so we // always use BroadcastAdd. if (kernel_type == kReference) { diff --git a/tensorflow/contrib/lite/kernels/arg_max.cc b/tensorflow/contrib/lite/kernels/arg_min_max.cc similarity index 70% rename from tensorflow/contrib/lite/kernels/arg_max.cc rename to tensorflow/contrib/lite/kernels/arg_min_max.cc index 26f57e88962116f446e72fbc164d2747e8b633b4..4f30d09030fb8d26c08090b180fdd352a967807f 100644 --- a/tensorflow/contrib/lite/kernels/arg_max.cc +++ b/tensorflow/contrib/lite/kernels/arg_min_max.cc @@ -23,7 +23,7 @@ limitations under the License. namespace tflite { namespace ops { namespace builtin { -namespace arg_max { +namespace arg_min_max { constexpr int kInputTensor = 0; constexpr int kAxis = 1; @@ -80,30 +80,39 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return context->ResizeTensor(context, output, output_size); } +template +std::function GetComparefunction(bool is_arg_max) { + if (is_arg_max) { + return std::greater(); + } else { + return std::less(); + } +} + // The current impl actually ignores the axis argument. // Only determine the index of the maximum value in the last dimension. -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) { const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* axis = GetInput(context, node, kAxis); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); -#define TF_LITE_ARG_MAX(data_type, axis_type, output_type) \ - optimized_ops::ArgMax(GetTensorData(axis), \ - GetTensorData(input), GetTensorDims(input), \ - GetTensorData(output), \ - GetTensorDims(output)) +#define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \ + optimized_ops::ArgMinMax( \ + GetTensorData(axis), GetTensorData(input), \ + GetTensorDims(input), GetTensorData(output), \ + GetTensorDims(output), GetComparefunction(is_arg_max)) if (axis->type == kTfLiteInt32) { switch (output->type) { case kTfLiteInt32: { switch (input->type) { case kTfLiteFloat32: - TF_LITE_ARG_MAX(float, int32_t, int32_t); + TF_LITE_ARG_MIN_MAX(float, int32_t, int32_t); break; case kTfLiteUInt8: - TF_LITE_ARG_MAX(uint8_t, int32_t, int32_t); + TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t); break; case kTfLiteInt32: - TF_LITE_ARG_MAX(int32_t, int32_t, int32_t); + TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int32_t); break; default: return kTfLiteError; @@ -112,13 +121,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: { switch (input->type) { case kTfLiteFloat32: - TF_LITE_ARG_MAX(float, int32_t, int64_t); + TF_LITE_ARG_MIN_MAX(float, int32_t, int64_t); break; case kTfLiteUInt8: - TF_LITE_ARG_MAX(uint8_t, int32_t, int64_t); + TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int64_t); break; case kTfLiteInt32: - TF_LITE_ARG_MAX(int32_t, int32_t, int64_t); + TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int64_t); break; default: return kTfLiteError; @@ -132,13 +141,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt32: { switch (input->type) { case kTfLiteFloat32: - TF_LITE_ARG_MAX(float, int64_t, int32_t); + TF_LITE_ARG_MIN_MAX(float, int64_t, int32_t); break; case kTfLiteUInt8: - TF_LITE_ARG_MAX(uint8_t, int64_t, int32_t); + TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int32_t); break; case kTfLiteInt32: - TF_LITE_ARG_MAX(int32_t, int64_t, int32_t); + TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int32_t); break; default: return kTfLiteError; @@ -147,13 +156,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: { switch (input->type) { case kTfLiteFloat32: - TF_LITE_ARG_MAX(float, int64_t, int64_t); + TF_LITE_ARG_MIN_MAX(float, int64_t, int64_t); break; case kTfLiteUInt8: - TF_LITE_ARG_MAX(uint8_t, int64_t, int64_t); + TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int64_t); break; case kTfLiteInt32: - TF_LITE_ARG_MAX(int32_t, int64_t, int64_t); + TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int64_t); break; default: return kTfLiteError; @@ -163,16 +172,30 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteError; } } -#undef TF_LITE_ARG_MAX +#undef TF_LITE_ARG_MIN_MAX return kTfLiteOk; } -} // namespace arg_max +TfLiteStatus ArgMinEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, false); +} + +TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, true); +} + +} // namespace arg_min_max TfLiteRegistration* Register_ARG_MAX() { - static TfLiteRegistration r = {nullptr, nullptr, arg_max::Prepare, - arg_max::Eval}; + static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare, + arg_min_max::ArgMaxEval}; + return &r; +} + +TfLiteRegistration* Register_ARG_MIN() { + static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare, + arg_min_max::ArgMinEval}; return &r; } diff --git a/tensorflow/contrib/lite/kernels/arg_max_test.cc b/tensorflow/contrib/lite/kernels/arg_min_max_test.cc similarity index 52% rename from tensorflow/contrib/lite/kernels/arg_max_test.cc rename to tensorflow/contrib/lite/kernels/arg_min_max_test.cc index 31b15fe19ab87027c28bde9eaff7d88d03b2c213..90e5fdc532c821691aaeca6e6faa4c24919ca2c8 100644 --- a/tensorflow/contrib/lite/kernels/arg_max_test.cc +++ b/tensorflow/contrib/lite/kernels/arg_min_max_test.cc @@ -24,16 +24,13 @@ namespace { using ::testing::ElementsAreArray; template -class ArgMaxOpModel : public SingleOpModel { +class ArgBaseOpModel : public SingleOpModel { public: - ArgMaxOpModel(std::initializer_list input_shape, TensorType input_type, - TensorType output_type, TensorType index_output_type) { + ArgBaseOpModel(std::initializer_list input_shape, TensorType input_type, + TensorType output_type, TensorType index_output_type) { input_ = AddInput(input_type); axis_ = AddInput(TensorType_INT32); output_ = AddOutput(output_type); - SetBuiltinOp(BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions, - CreateArgMaxOptions(builder_, index_output_type).Union()); - BuildInterpreter({input_shape, {1, 1, 1, 1}}); } int input() { return input_; } @@ -42,12 +39,42 @@ class ArgMaxOpModel : public SingleOpModel { std::vector GetOutput() { return ExtractVector(output_); } std::vector GetOutputShape() { return GetTensorShape(output_); } - private: + protected: int input_; int axis_; int output_; }; +template +class ArgMaxOpModel : public ArgBaseOpModel { + public: + ArgMaxOpModel(std::initializer_list input_shape, TensorType input_type, + TensorType output_type, TensorType index_output_type) + : ArgBaseOpModel(input_shape, input_type, output_type, + index_output_type) { + ArgBaseOpModel::SetBuiltinOp( + BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions, + CreateArgMaxOptions(ArgBaseOpModel::builder_, index_output_type) + .Union()); + ArgBaseOpModel::BuildInterpreter({input_shape, {1, 1, 1, 1}}); + } +}; + +template +class ArgMinOpModel : public ArgBaseOpModel { + public: + ArgMinOpModel(std::initializer_list input_shape, TensorType input_type, + TensorType output_type, TensorType index_output_type) + : ArgBaseOpModel(input_shape, input_type, output_type, + index_output_type) { + ArgBaseOpModel::SetBuiltinOp( + BuiltinOperator_ARG_MIN, BuiltinOptions_ArgMinOptions, + CreateArgMinOptions(ArgBaseOpModel::builder_, index_output_type) + .Union()); + ArgBaseOpModel::BuildInterpreter({input_shape, {1, 1, 1, 1}}); + } +}; + TEST(ArgMaxOpTest, GetMaxArgFloat) { ArgMaxOpModel model({1, 1, 1, 4}, TensorType_FLOAT32, TensorType_INT32, TensorType_INT32); @@ -96,6 +123,54 @@ TEST(ArgMaxOpTest, GetMaxArgOutput64) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1})); } +TEST(ArgMinOpTest, GetMinArgFloat) { + ArgMinOpModel model({1, 1, 1, 4}, TensorType_FLOAT32, + TensorType_INT32, TensorType_INT32); + model.PopulateTensor(model.input(), {0.1, 0.9, 0.7, 0.3}); + // Currently only support the last dimension. + model.PopulateTensor(model.axis(), {3}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 1})); +} + +TEST(ArgMinOpTest, GetMinArgInt) { + ArgMinOpModel model({1, 1, 1, 4}, TensorType_INT32, TensorType_INT32, + TensorType_INT32); + model.PopulateTensor(model.input(), {1, 9, 7, 3}); + // Currently only support the last dimension. + model.PopulateTensor(model.axis(), {3}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 1})); +} + +TEST(ArgMinOpTest, GetMinArgMulDimensions) { + ArgMinOpModel model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32, + TensorType_INT32); + model.PopulateTensor(model.input(), {1, 2, 7, 8, 1, 9, 7, 3}); + // Currently only support the last dimension. + model.PopulateTensor(model.axis(), {3}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1})); +} + +TEST(ArgMinOpTest, GetMinArgOutput64) { + ArgMinOpModel model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT64, + TensorType_INT64); + model.PopulateTensor(model.input(), {10, 2, 7, 8, 1, 9, 7, 3}); + // Currently only support the last dimension. + model.PopulateTensor(model.axis(), {3}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc index 3425288f027a6fd9eb65f730bc7d039c832ace1c..a11a59aa050675314ac8b1316cdd0f15c81b8b15 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -276,27 +275,33 @@ TfLiteStatus CheckLstmTensorDimensions( TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TfLiteNode* node, int n_input, int n_output, int n_cell) { - CheckLstmTensorDimensions( - context, node, n_input, n_output, n_cell, kFwInputToInputWeightsTensor, - kFwInputToForgetWeightsTensor, kFwInputToCellWeightsTensor, - kFwInputToOutputWeightsTensor, kFwRecurrentToInputWeightsTensor, - kFwRecurrentToForgetWeightsTensor, kFwRecurrentToCellWeightsTensor, - kFwRecurrentToOutputWeightsTensor, kFwCellToInputWeightsTensor, - kFwCellToForgetWeightsTensor, kFwCellToOutputWeightsTensor, - kFwInputGateBiasTensor, kFwForgetGateBiasTensor, kFwCellGateBiasTensor, - kFwOutputGateBiasTensor, kFwProjectionWeightsTensor, - kFwProjectionBiasTensor); - - CheckLstmTensorDimensions( - context, node, n_input, n_output, n_cell, kBwInputToInputWeightsTensor, - kBwInputToForgetWeightsTensor, kBwInputToCellWeightsTensor, - kBwInputToOutputWeightsTensor, kBwRecurrentToInputWeightsTensor, - kBwRecurrentToForgetWeightsTensor, kBwRecurrentToCellWeightsTensor, - kBwRecurrentToOutputWeightsTensor, kBwCellToInputWeightsTensor, - kBwCellToForgetWeightsTensor, kBwCellToOutputWeightsTensor, - kBwInputGateBiasTensor, kBwForgetGateBiasTensor, kBwCellGateBiasTensor, - kBwOutputGateBiasTensor, kBwProjectionWeightsTensor, - kBwProjectionBiasTensor); + TF_LITE_ENSURE_OK( + context, + CheckLstmTensorDimensions( + context, node, n_input, n_output, n_cell, + kFwInputToInputWeightsTensor, kFwInputToForgetWeightsTensor, + kFwInputToCellWeightsTensor, kFwInputToOutputWeightsTensor, + kFwRecurrentToInputWeightsTensor, kFwRecurrentToForgetWeightsTensor, + kFwRecurrentToCellWeightsTensor, kFwRecurrentToOutputWeightsTensor, + kFwCellToInputWeightsTensor, kFwCellToForgetWeightsTensor, + kFwCellToOutputWeightsTensor, kFwInputGateBiasTensor, + kFwForgetGateBiasTensor, kFwCellGateBiasTensor, + kFwOutputGateBiasTensor, kFwProjectionWeightsTensor, + kFwProjectionBiasTensor)); + + TF_LITE_ENSURE_OK( + context, + CheckLstmTensorDimensions( + context, node, n_input, n_output, n_cell, + kBwInputToInputWeightsTensor, kBwInputToForgetWeightsTensor, + kBwInputToCellWeightsTensor, kBwInputToOutputWeightsTensor, + kBwRecurrentToInputWeightsTensor, kBwRecurrentToForgetWeightsTensor, + kBwRecurrentToCellWeightsTensor, kBwRecurrentToOutputWeightsTensor, + kBwCellToInputWeightsTensor, kBwCellToForgetWeightsTensor, + kBwCellToOutputWeightsTensor, kBwInputGateBiasTensor, + kBwForgetGateBiasTensor, kBwCellGateBiasTensor, + kBwOutputGateBiasTensor, kBwProjectionWeightsTensor, + kBwProjectionBiasTensor)); // Check if Forward and Backward tensors match along required dimensions. return kTfLiteOk; @@ -334,7 +339,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1]; // Check that input tensor dimensions matches with each other. - CheckInputTensorDimensions(context, node, n_input, n_fw_output, n_fw_cell); + TF_LITE_ENSURE_OK( + context, CheckInputTensorDimensions(context, node, n_input, n_fw_output, + n_fw_cell)); // Get the pointer to output, state and scratch buffer tensors. TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor); @@ -404,7 +411,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1]; // Check that input tensor dimensions matches with each other. - CheckInputTensorDimensions(context, node, n_input, n_bw_output, n_bw_cell); + TF_LITE_ENSURE_OK( + context, CheckInputTensorDimensions(context, node, n_input, n_bw_output, + n_bw_cell)); // Get the pointer to output, output_state and cell_state buffer tensors. TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc index aa24c1f34cd1e8c02a6a75b62fbe5f3c629498ca..517309a226bcfb717186be8c1d02d68e3b337f8e 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc index f678f48fa5bbbcece6c5b87030d951783378d78f..8b4d778332afd5f4b53509bd669a674c63d9f6f9 100644 --- a/tensorflow/contrib/lite/kernels/comparisons.cc +++ b/tensorflow/contrib/lite/kernels/comparisons.cc @@ -57,6 +57,57 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { return context->ResizeTensor(context, output, output_size); } +// TODO(ruic): optimize macros below to using template functions. +#define TF_LITE_QUANTIZE_COMPARISON(opname) \ + void EvalQuantized##opname(TfLiteContext* context, TfLiteNode* node, \ + const TfLiteTensor* input1, \ + const TfLiteTensor* input2, TfLiteTensor* output, \ + bool requires_broadcast) { \ + if (input1->type == kTfLiteUInt8) { \ + auto input1_offset = -input1->params.zero_point; \ + auto input2_offset = -input2->params.zero_point; \ + const int left_shift = 20; \ + const double twice_max_input_scale = \ + 2 * std::max(input1->params.scale, input2->params.scale); \ + const double real_input1_multiplier = \ + input1->params.scale / twice_max_input_scale; \ + const double real_input2_multiplier = \ + input2->params.scale / twice_max_input_scale; \ + \ + int32 input1_multiplier; \ + int input1_shift; \ + QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, \ + &input1_multiplier, &input1_shift); \ + int32 input2_multiplier; \ + int input2_shift; \ + QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, \ + &input2_multiplier, &input2_shift); \ + \ + if (requires_broadcast) { \ + reference_ops::Broadcast##opname( \ + left_shift, GetTensorData(input1), GetTensorDims(input1), \ + input1_offset, input1_multiplier, input1_shift, \ + GetTensorData(input2), GetTensorDims(input2), \ + input2_offset, input2_multiplier, input2_shift, \ + GetTensorData(output), GetTensorDims(output)); \ + } else { \ + reference_ops::opname( \ + left_shift, GetTensorData(input1), GetTensorDims(input1), \ + input1_offset, input1_multiplier, input1_shift, \ + GetTensorData(input2), GetTensorDims(input2), \ + input2_offset, input2_multiplier, input2_shift, \ + GetTensorData(output), GetTensorDims(output)); \ + } \ + } \ + } +TF_LITE_QUANTIZE_COMPARISON(Equal); +TF_LITE_QUANTIZE_COMPARISON(NotEqual); +TF_LITE_QUANTIZE_COMPARISON(Greater); +TF_LITE_QUANTIZE_COMPARISON(GreaterEqual); +TF_LITE_QUANTIZE_COMPARISON(Less); +TF_LITE_QUANTIZE_COMPARISON(LessEqual); +#undef TF_LITE_QUANTIZE_COMPARISON + #define TF_LITE_COMPARISON(type, opname, requires_broadcast) \ requires_broadcast \ ? reference_ops::Broadcast##opname( \ @@ -73,7 +124,6 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); - // TODO(renjieliu): Support quantized data. switch (input1->type) { case kTfLiteFloat32: TF_LITE_COMPARISON(float, Equal, requires_broadcast); @@ -84,9 +134,13 @@ TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: TF_LITE_COMPARISON(int64_t, Equal, requires_broadcast); break; + case kTfLiteUInt8: + EvalQuantizedEqual(context, node, input1, input2, output, + requires_broadcast); + break; default: context->ReportError(context, - "Does not support type %d, requires float|int", + "Does not support type %d, requires float|int|uint8", input1->type); return kTfLiteError; } @@ -99,7 +153,6 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); - // TODO(renjieliu): Support quantized data. switch (input1->type) { case kTfLiteFloat32: TF_LITE_COMPARISON(float, NotEqual, requires_broadcast); @@ -110,9 +163,13 @@ TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: TF_LITE_COMPARISON(int64_t, NotEqual, requires_broadcast); break; + case kTfLiteUInt8: + EvalQuantizedNotEqual(context, node, input1, input2, output, + requires_broadcast); + break; default: context->ReportError(context, - "Does not support type %d, requires float|int", + "Does not support type %d, requires float|int|uint8", input1->type); return kTfLiteError; } @@ -124,7 +181,6 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); - // TODO(renjieliu): Support quantized data. switch (input1->type) { case kTfLiteFloat32: TF_LITE_COMPARISON(float, Greater, requires_broadcast); @@ -135,9 +191,13 @@ TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: TF_LITE_COMPARISON(int64_t, Greater, requires_broadcast); break; + case kTfLiteUInt8: + EvalQuantizedGreater(context, node, input1, input2, output, + requires_broadcast); + break; default: context->ReportError(context, - "Does not support type %d, requires float|int", + "Does not support type %d, requires float|int|uint8", input1->type); return kTfLiteError; } @@ -149,7 +209,6 @@ TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); - // TODO(renjieliu): Support quantized data. switch (input1->type) { case kTfLiteFloat32: TF_LITE_COMPARISON(float, GreaterEqual, requires_broadcast); @@ -160,9 +219,13 @@ TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: TF_LITE_COMPARISON(int64_t, GreaterEqual, requires_broadcast); break; + case kTfLiteUInt8: + EvalQuantizedGreaterEqual(context, node, input1, input2, output, + requires_broadcast); + break; default: context->ReportError(context, - "Does not support type %d, requires float|int", + "Does not support type %d, requires float|int|uint8", input1->type); return kTfLiteError; } @@ -174,7 +237,6 @@ TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); - // TODO(renjieliu): Support quantized data. switch (input1->type) { case kTfLiteFloat32: TF_LITE_COMPARISON(float, Less, requires_broadcast); @@ -185,9 +247,13 @@ TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: TF_LITE_COMPARISON(int64_t, Less, requires_broadcast); break; + case kTfLiteUInt8: + EvalQuantizedLess(context, node, input1, input2, output, + requires_broadcast); + break; default: context->ReportError(context, - "Does not support type %d, requires float|int", + "Does not support type %d, requires float|int|uint8", input1->type); return kTfLiteError; } @@ -199,7 +265,6 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); bool requires_broadcast = !HaveSameShapes(input1, input2); - // TODO(renjieliu): Support quantized data. switch (input1->type) { case kTfLiteFloat32: TF_LITE_COMPARISON(float, LessEqual, requires_broadcast); @@ -210,9 +275,13 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteInt64: TF_LITE_COMPARISON(int64_t, LessEqual, requires_broadcast); break; + case kTfLiteUInt8: + EvalQuantizedLessEqual(context, node, input1, input2, output, + requires_broadcast); + break; default: context->ReportError(context, - "Does not support type %d, requires float|int", + "Does not support type %d, requires float|int|uint8", input1->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc index bb02e1c812fdc40bf515f1f978e9e39b5a16a4ea..67a91c17fd4f25e4a9ea22de5e2a10dc1c17656d 100644 --- a/tensorflow/contrib/lite/kernels/comparisons_test.cc +++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc @@ -35,6 +35,15 @@ class ComparisonOpModel : public SingleOpModel { BuildInterpreter({input1_shape, input2_shape}); } + ComparisonOpModel(const TensorData& input1, const TensorData& input2, + TensorType input_type, BuiltinOperator op) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(TensorType_BOOL); + ConfigureBuiltinOp(op); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + int input1() { return input1_; } int input2() { return input2_; } @@ -354,6 +363,192 @@ TEST(ComparisonsTest, LessEqualBroadcastTwoD) { EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } +TEST(QuantizedComparisonsTest, EqualQuantized) { + const float kMin = -1.f; + const float kMax = 128.f; + ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_EQUAL); + model.QuantizeAndPopulate(model.input1(), {1, 9, 7, 3}); + model.QuantizeAndPopulate(model.input2(), {1, 2, 7, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, false)); +} + +TEST(QuantizedComparisonsTest, NotEqualQuantized) { + const float kMin = -1.f; + const float kMax = 128.f; + ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_NOT_EQUAL); + model.QuantizeAndPopulate(model.input1(), {1, 9, 7, 3}); + model.QuantizeAndPopulate(model.input2(), {1, 2, 7, 0}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, true)); +} + +TEST(ComparisonsTest, GreaterQuantized) { + const float kMin = -1.f; + const float kMax = 128.f; + ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_GREATER); + model.QuantizeAndPopulate(model.input1(), {1, 9, 7, 3}); + model.QuantizeAndPopulate(model.input2(), {1, 2, 6, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); +} + +TEST(ComparisonsTest, GreaterEqualQuantized) { + const float kMin = -1.f; + const float kMax = 128.f; + ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_GREATER_EQUAL); + model.QuantizeAndPopulate(model.input1(), {1, 9, 7, 3}); + model.QuantizeAndPopulate(model.input2(), {1, 2, 6, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, true, false)); +} + +TEST(ComparisonsTest, LessQuantized) { + const float kMin = -1.f; + const float kMax = 128.f; + ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_LESS); + model.QuantizeAndPopulate(model.input1(), {1, 9, 7, 3}); + model.QuantizeAndPopulate(model.input2(), {1, 2, 6, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, true)); +} + +TEST(ComparisonsTest, LessEqualQuantized) { + const float kMin = -1.f; + const float kMax = 128.f; + ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_UINT8, {1, 2, 2, 1}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_LESS_EQUAL); + model.QuantizeAndPopulate(model.input1(), {1, 9, 7, 3}); + model.QuantizeAndPopulate(model.input2(), {1, 2, 6, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); +} + +TEST(ComparisonsTest, QuantizedEqualWithBroadcast) { + const float kMin = -1.f; + const float kMax = 128.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax}, + {TensorType_UINT8, {}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_EQUAL); + model.QuantizeAndPopulate(model.input1(), {20, 2, 7, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {2}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(false, true, false, false, false, false)) + << "With shape number " << i; + } +} + +TEST(ComparisonsTest, QuantizedNotEqualWithBroadcast) { + const float kMin = -1.f; + const float kMax = 128.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax}, + {TensorType_UINT8, {}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_NOT_EQUAL); + model.QuantizeAndPopulate(model.input1(), {20, 2, 7, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {2}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, true, true, true, true)) + << "With shape number " << i; + } +} + +TEST(ComparisonsTest, QuantizedGreaterWithBroadcast) { + const float kMin = -1.f; + const float kMax = 128.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax}, + {TensorType_UINT8, {}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_GREATER); + model.QuantizeAndPopulate(model.input1(), {20, 2, 7, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {8}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, false, false, true, true)) + << "With shape number " << i; + } +} + +TEST(ComparisonsTest, QuantizedGreaterEqualWithBroadcast) { + const float kMin = -1.f; + const float kMax = 128.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax}, + {TensorType_UINT8, {}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_GREATER_EQUAL); + model.QuantizeAndPopulate(model.input1(), {20, 2, 7, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {8}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, false, true, true, true)) + << "With shape number " << i; + } +} + +TEST(ComparisonsTest, QuantizedLessWithBroadcast) { + const float kMin = -1.f; + const float kMax = 128.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax}, + {TensorType_UINT8, {}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_LESS); + model.QuantizeAndPopulate(model.input1(), {20, 2, 7, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {8}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(false, true, true, false, false, false)) + << "With shape number " << i; + } +} + +TEST(ComparisonsTest, QuantizedLessEqualWithBroadcast) { + const float kMin = -1.f; + const float kMax = 128.f; + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + ComparisonOpModel model({TensorType_UINT8, test_shapes[i], kMin, kMax}, + {TensorType_UINT8, {}, kMin, kMax}, + TensorType_UINT8, BuiltinOperator_LESS_EQUAL); + model.QuantizeAndPopulate(model.input1(), {20, 2, 7, 8, 11, 20}); + model.QuantizeAndPopulate(model.input2(), {8}); + model.Invoke(); + EXPECT_THAT(model.GetOutput(), + ElementsAre(false, true, true, true, false, false)) + << "With shape number " << i; + } +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc index 45ea8d00498455be98467f2f1addc8ad7dcf35fa..605a20ac3e7c8346db2bcf64e9422132b433b3da 100644 --- a/tensorflow/contrib/lite/kernels/concatenation.cc +++ b/tensorflow/contrib/lite/kernels/concatenation.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -58,7 +57,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, t0->dims->size <= 4); TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); TF_LITE_ENSURE(context, - input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8); + input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 || + input_type == kTfLiteInt16 || input_type == kTfLiteInt32 || + input_type == kTfLiteInt64); // Output dimensions will match input dimensions, except 'axis', which // will be the sum of inputs @@ -122,6 +123,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_CONCATENATION(optimized_ops, float); } break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_CONCATENATION(reference_ops, int32); + } else { + TF_LITE_CONCATENATION(optimized_ops, int32); + } + break; case kTfLiteUInt8: if (kernel_type == kReference) { TF_LITE_CONCATENATION_QUANTIZED(reference_ops); @@ -129,6 +137,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_CONCATENATION_QUANTIZED(optimized_ops); } break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_CONCATENATION(reference_ops, int64_t); + } else { + TF_LITE_CONCATENATION(optimized_ops, int64_t); + } + break; + default: context->ReportError(context, "Only float32 and uint8 are currently supported."); diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index 0321b2e2a0088bdb09b2c3c61827be8064fe939b..04c0263b789e75727ed3bd4d6b3292063a4530e0 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -257,10 +256,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { double real_multiplier = 0.0; TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( context, input, filter, bias, output, &real_multiplier)); - TF_LITE_ENSURE(context, real_multiplier < 1.0); - QuantizeMultiplierSmallerThanOneExp( - real_multiplier, &data->output_multiplier, &data->output_shift); - data->output_shift *= -1; + + int exponent; + QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); + data->output_shift = -exponent; CalculateActivationRangeUint8(params->activation, output, &data->output_activation_min, &data->output_activation_max); @@ -418,6 +417,7 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, filter_data = GetTensorData(filter); } multithreaded_ops::Conv( + *eigen_support::GetThreadPoolDevice(context), GetTensorData(input), GetTensorDims(input), filter_data, GetTensorDims(filter), GetTensorData(bias), GetTensorDims(bias), params->stride_width, params->stride_height, diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc index 0dcfc826fd218d2d2dfbf89201d2c13fbfe6f0e1..24633c2fd7cb3725977ae6c6459daa829165ccfd 100644 --- a/tensorflow/contrib/lite/kernels/conv_test.cc +++ b/tensorflow/contrib/lite/kernels/conv_test.cc @@ -64,12 +64,6 @@ class BaseConvolutionOpModel : public SingleOpModel { } output_ = AddOutput(output); - if (input.type != TensorType_FLOAT32) { - // The following is required by quantized inference. It is the unittest's - // responsibility to make sure the output scale falls into the correct - // range. - CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_)); - } SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions, CreateConv2DOptions( @@ -441,6 +435,44 @@ TEST_P(ConvolutionOpTest, SimpleTestQuantized) { })); } +TEST_P(ConvolutionOpTest, SimpleTestQuantizedOutputMultiplierGreaterThan1) { + // output_multiplier = 1.0118 + QuantizedConvolutionOpModel quant_op( + GetRegistration(), {TensorType_UINT8, {2, 2, 4, 1}, -128.5, 128}, + {TensorType_UINT8, {3, 2, 2, 1}, -128.5, 128}, + {TensorType_UINT8, {}, -127, 128}); + ConvolutionOpModel float_op( + GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}}, + {TensorType_FLOAT32, {3, 2, 2, 1}}, {TensorType_FLOAT32, {}}); + std::initializer_list input = { + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }; + std::initializer_list filter = { + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }; + std::initializer_list bias = {1, 2, 3}; + + quant_op.SetInput(input); + quant_op.SetFilter(filter); + quant_op.SetBias(bias); + quant_op.Invoke(); + + float_op.SetInput(input); + float_op.SetFilter(filter); + float_op.SetBias(bias); + float_op.Invoke(); + + EXPECT_THAT(quant_op.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1))); +} + TEST_P(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) { QuantizedConvolutionOpModel m(GetRegistration(), {TensorType_UINT8, {1, 3, 6, 1}, -63.5, 64}, diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc index 16e5f1d065d8ea6d187c5e368d6c9385fe62514b..21518156b851892f50c62df7901d71c41fd733f7 100644 --- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc +++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess.cc b/tensorflow/contrib/lite/kernels/detection_postprocess.cc index 0c532cac5a9f59c8b09ff9aefc294e243561f027..d7bde0ff79bd23fa4c277dd04ec4343663e0ad00 100644 --- a/tensorflow/contrib/lite/kernels/detection_postprocess.cc +++ b/tensorflow/contrib/lite/kernels/detection_postprocess.cc @@ -40,8 +40,8 @@ constexpr int kOutputTensorDetectionClasses = 1; constexpr int kOutputTensorDetectionScores = 2; constexpr int kOutputTensorNumDetections = 3; -constexpr size_t kNumCoordBox = 4; -constexpr size_t kBatchSize = 1; +constexpr int kNumCoordBox = 4; +constexpr int kBatchSize = 1; // Object Detection model produces axis-aligned boxes in two formats: // BoxCorner represents the upper right (xmin, ymin) and diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc index bc5c3783fd63451fd6d600df2d8e93f740c68e95..d7420ddd8e41a57c901527884e942d444e543aa6 100644 --- a/tensorflow/contrib/lite/kernels/div.cc +++ b/tensorflow/contrib/lite/kernels/div.cc @@ -78,29 +78,44 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } template -void EvalFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteDivParams* params, const OpData* data, - const TfLiteTensor* input1, const TfLiteTensor* input2, - TfLiteTensor* output) { - float output_activation_min, output_activation_max; - CalculateActivationRange(params->activation, &output_activation_min, - &output_activation_max); -#define TF_LITE_DIV(type, opname) \ - type::opname(GetTensorData(input1), GetTensorDims(input1), \ - GetTensorData(input2), GetTensorDims(input2), \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)) - if (kernel_type == kReference) { - if (data->requires_broadcast) { - TF_LITE_DIV(reference_ops, BroadcastDiv); +void EvalDiv(TfLiteContext* context, TfLiteNode* node, TfLiteDivParams* params, + const OpData* data, const TfLiteTensor* input1, + const TfLiteTensor* input2, TfLiteTensor* output) { +#define TF_LITE_DIV(type, opname, data_type) \ + data_type output_activation_min, output_activation_max; \ + CalculateActivationRange(params->activation, &output_activation_min, \ + &output_activation_max); \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (output->type == kTfLiteInt32) { + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_DIV(reference_ops, BroadcastDiv, int32_t); + } else { + TF_LITE_DIV(reference_ops, Div, int32_t); + } } else { - TF_LITE_DIV(reference_ops, Div); + if (data->requires_broadcast) { + TF_LITE_DIV(optimized_ops, BroadcastDiv, int32_t); + } else { + TF_LITE_DIV(optimized_ops, Div, int32_t); + } } - } else { - if (data->requires_broadcast) { - TF_LITE_DIV(optimized_ops, BroadcastDiv); + } else if (output->type == kTfLiteFloat32) { + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_DIV(reference_ops, BroadcastDiv, float); + } else { + TF_LITE_DIV(reference_ops, Div, float); + } } else { - TF_LITE_DIV(optimized_ops, Div); + if (data->requires_broadcast) { + TF_LITE_DIV(optimized_ops, BroadcastDiv, float); + } else { + TF_LITE_DIV(optimized_ops, Div, float); + } } } #undef TF_LITE_DIV @@ -115,11 +130,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - if (output->type == kTfLiteFloat32) { - EvalFloat(context, node, params, data, input1, input2, output); + if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) { + EvalDiv(context, node, params, data, input1, input2, output); } else { context->ReportError( - context, "Div only supports FLOAT32 and quantized UINT8 now, got %d.", + context, + "Div only supports FLOAT32, INT32 and quantized UINT8 now, got %d.", output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/div_test.cc b/tensorflow/contrib/lite/kernels/div_test.cc index 276b8289fbc1b4dcbf4624b76b854300d0fd4912..97aa2fe04e27416b99f48ab61ece54b745597ae3 100644 --- a/tensorflow/contrib/lite/kernels/div_test.cc +++ b/tensorflow/contrib/lite/kernels/div_test.cc @@ -52,6 +52,13 @@ class FloatDivOpModel : public BaseDivOpModel { std::vector GetOutput() { return ExtractVector(output_); } }; +class IntegerDivOpModel : public BaseDivOpModel { + public: + using BaseDivOpModel::BaseDivOpModel; + + std::vector GetOutput() { return ExtractVector(output_); } +}; + TEST(FloatDivOpTest, NoActivation) { FloatDivOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, @@ -75,7 +82,7 @@ TEST(FloatDivOpTest, ActivationRELU_N1_TO_1) { } TEST(FloatDivOpTest, VariousInputShapes) { - std::vector> test_shapes = { + std::vector> test_shapes = { {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; for (int i = 0; i < test_shapes.size(); ++i) { FloatDivOpModel m({TensorType_FLOAT32, test_shapes[i]}, @@ -92,7 +99,7 @@ TEST(FloatDivOpTest, VariousInputShapes) { } TEST(FloatDivOpTest, WithBroadcast) { - std::vector> test_shapes = { + std::vector> test_shapes = { {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; for (int i = 0; i < test_shapes.size(); ++i) { FloatDivOpModel m({TensorType_FLOAT32, test_shapes[i]}, @@ -108,6 +115,56 @@ TEST(FloatDivOpTest, WithBroadcast) { } } +TEST(IntegerDivOpTest, NoActivation) { + IntegerDivOpModel m({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}}, + ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2, 2, -15, 8}); + m.PopulateTensor(m.input2(), {5, -2, -3, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, -1, 5, 1})); +} + +TEST(IntegerDivOpTest, ActivationRELU_N1_TO_1) { + IntegerDivOpModel m({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}}, + ActivationFunctionType_RELU_N1_TO_1); + m.PopulateTensor(m.input1(), {-2, 2, -12, 8}); + m.PopulateTensor(m.input2(), {1, 2, -15, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 1, 0, 1})); +} + +TEST(IntegerDivOpTest, VariousInputShapes) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + IntegerDivOpModel m({TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 3, 8, 11, -20}); + m.PopulateTensor(m.input2(), {1, 2, 6, 5, -11, -1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-20, 1, 0, 1, -1, 20})) + << "With shape number " << i; + } +} + +TEST(IntegerDivOpTest, WithBroadcast) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + IntegerDivOpModel m({TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, {}}, // always a scalar + {TensorType_INT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 21, 7, 8, 11, -123}); + m.PopulateTensor(m.input2(), {3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-6, 7, 2, 2, 3, -41})) + << "With shape number " << i; + } +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/contrib/lite/kernels/eigen_support.cc index f1fdb42624073717fb70423ff70dfad08e578ca6..e542ad076528fa30152abba074a5c7dcd6ca1f48 100644 --- a/tensorflow/contrib/lite/kernels/eigen_support.cc +++ b/tensorflow/contrib/lite/kernels/eigen_support.cc @@ -14,31 +14,100 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/kernels/eigen_support.h" -#include "third_party/eigen3/Eigen/Core" +#include + +#include "tensorflow/contrib/lite/arena_planner.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { namespace eigen_support { +namespace { + +#ifndef EIGEN_DONT_ALIGN +// Eigen may require buffers to be algiend to 16, 32 or 64 bytes depending on +// hardware architecture and build configurations. +// If the static assertion fails, try to increase `kDefaultTensorAlignment` to +// in `arena_planner.h` to 32 or 64. +static_assert( + kDefaultTensorAlignment % EIGEN_MAX_ALIGN_BYTES == 0, + "kDefaultArenaAlignment doesn't comply with Eigen alignment requirement."); +#endif // EIGEN_DONT_ALIGN + +// We have a single global threadpool for all convolution operations. This means +// that inferences started from different threads may block each other, but +// since the underlying resource of CPU cores should be consumed by the +// operations anyway, it shouldn't affect overall performance. +class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { + public: + // Takes ownership of 'pool' + explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {} + ~EigenThreadPoolWrapper() override {} -struct RefCountedEigenContext { + void Schedule(std::function fn) override { + pool_->Schedule(std::move(fn)); + } + int NumThreads() const override { return pool_->NumThreads(); } + int CurrentThreadId() const override { return pool_->CurrentThreadId(); } + + private: + std::unique_ptr pool_; +}; + +struct RefCountedEigenContext : public TfLiteExternalContext { + std::unique_ptr thread_pool_wrapper; + std::unique_ptr device; int num_references = 0; }; +RefCountedEigenContext* GetEigenContext(TfLiteContext* context) { + return reinterpret_cast( + context->GetExternalContext(context, kTfLiteEigenContext)); +} + +void InitDevice(TfLiteContext* context, RefCountedEigenContext* ptr) { + int num_threads = 4; + if (context->recommended_num_threads != -1) { + num_threads = context->recommended_num_threads; + } + ptr->device.reset(); // destroy before we invalidate the thread pool + ptr->thread_pool_wrapper.reset( + new EigenThreadPoolWrapper(new Eigen::ThreadPool(num_threads))); + ptr->device.reset( + new Eigen::ThreadPoolDevice(ptr->thread_pool_wrapper.get(), num_threads)); +} + +TfLiteStatus Refresh(TfLiteContext* context) { + Eigen::setNbThreads(context->recommended_num_threads); + + auto* ptr = GetEigenContext(context); + if (ptr != nullptr) { + InitDevice(context, ptr); + } + + return kTfLiteOk; +} + +} // namespace + void IncrementUsageCounter(TfLiteContext* context) { - auto* ptr = reinterpret_cast(context->eigen_context); + auto* ptr = GetEigenContext(context); if (ptr == nullptr) { if (context->recommended_num_threads != -1) { Eigen::setNbThreads(context->recommended_num_threads); } ptr = new RefCountedEigenContext; + ptr->type = kTfLiteEigenContext; + ptr->Refresh = Refresh; ptr->num_references = 0; - context->eigen_context = ptr; + InitDevice(context, ptr); + context->SetExternalContext(context, kTfLiteEigenContext, ptr); } ptr->num_references++; } void DecrementUsageCounter(TfLiteContext* context) { - auto* ptr = reinterpret_cast(context->eigen_context); + auto* ptr = GetEigenContext(context); if (ptr == nullptr) { TF_LITE_FATAL( "Call to DecrementUsageCounter() not preceded by " @@ -46,14 +115,17 @@ void DecrementUsageCounter(TfLiteContext* context) { } if (--ptr->num_references == 0) { delete ptr; - context->eigen_context = nullptr; + context->SetExternalContext(context, kTfLiteEigenContext, nullptr); } } -void SetNumThreads(TfLiteContext* context, int num_threads) { - IncrementUsageCounter(context); - Eigen::setNbThreads(num_threads); - DecrementUsageCounter(context); +const Eigen::ThreadPoolDevice* GetThreadPoolDevice(TfLiteContext* context) { + auto* ptr = GetEigenContext(context); + if (ptr == nullptr) { + TF_LITE_FATAL( + "Call to GetFromContext() not preceded by IncrementUsageCounter()"); + } + return ptr->device.get(); } } // namespace eigen_support diff --git a/tensorflow/contrib/lite/kernels/eigen_support.h b/tensorflow/contrib/lite/kernels/eigen_support.h index aa8c351fd8e8dae45f7d4807ce24d80bb393c41c..ec77856b1054e85c405193c6f44dc6e74b58a645 100644 --- a/tensorflow/contrib/lite/kernels/eigen_support.h +++ b/tensorflow/contrib/lite/kernels/eigen_support.h @@ -17,6 +17,10 @@ limitations under the License. #include "tensorflow/contrib/lite/context.h" +namespace EigenForTFLite { +class ThreadPoolDevice; +} + namespace tflite { namespace eigen_support { @@ -28,8 +32,8 @@ void IncrementUsageCounter(TfLiteContext* context); // usages all temporary Eigen objects will be deleted. void DecrementUsageCounter(TfLiteContext* context); -// Set the number of threads that can be used by Eigen. -void SetNumThreads(TfLiteContext* context, int num_threads); +const EigenForTFLite::ThreadPoolDevice* GetThreadPoolDevice( + TfLiteContext* context); } // namespace eigen_support } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc index 59bab3c4ecd20bf938919ca606a5933f3112f233..e19779ea59d441984d3562508e4237e10ce17515 100644 --- a/tensorflow/contrib/lite/kernels/elementwise.cc +++ b/tensorflow/contrib/lite/kernels/elementwise.cc @@ -22,79 +22,118 @@ namespace tflite { namespace ops { namespace builtin { namespace elementwise { +namespace { +bool IsNumericSupportedType(const TfLiteType type) { + return type == kTfLiteFloat32; +} + +bool IsLogicalSupportedType(const TfLiteType type) { + return type == kTfLiteBool; +} + +typedef bool (*IsSupportedType)(TfLiteType); +template TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); TF_LITE_ENSURE_EQ(context, input->type, output->type); - // Quantized float is not supported yet. - TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + if (!IsSupportedType(input->type)) { + context->ReportError(context, "Current data type %d is not supported.", + input->type); + return kTfLiteError; + } return context->ResizeTensor(context, output, TfLiteIntArrayCopy(input->dims)); } -inline TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, - float float_func(float)) { +template +inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node, + T func(T), TfLiteType expected_type) { const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); - switch (input->type) { - case kTfLiteFloat32: { - size_t elements = NumElements(input); - const float* in = GetTensorData(input); - const float* in_end = in + elements; - float* out = output->data.f; - for (; in < in_end; in++, out++) *out = float_func(*in); - return kTfLiteOk; - } - default: { - context->ReportError(context, "Input type is %d, requires float32", - input->type); - return kTfLiteError; - } + TF_LITE_ENSURE_EQ(context, input->type, expected_type); + const int64_t num_elements = NumElements(input); + const T* in_data = GetTensorData(input); + T* out_data = GetTensorData(output); + for (int64_t i = 0; i < num_elements; ++i) { + out_data[i] = func(in_data[i]); } + return kTfLiteOk; +} + +inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node, + float float_func(float)) { + return EvalImpl(context, node, float_func, kTfLiteFloat32); +} + +inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node, + bool bool_func(bool)) { + return EvalImpl(context, node, bool_func, kTfLiteBool); } TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { - return Eval(context, node, std::sin); + return EvalNumeric(context, node, std::sin); } TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) { - return Eval(context, node, std::log); + return EvalNumeric(context, node, std::log); } TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) { - return Eval(context, node, std::sqrt); + return EvalNumeric(context, node, std::sqrt); } TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) { - return Eval(context, node, [](float f) { return 1.f / std::sqrt(f); }); + return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); }); +} + +TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) { + return EvalLogical(context, node, [](bool v) { return !v; }); } +} // namespace } // namespace elementwise TfLiteRegistration* Register_SIN() { - static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, - elementwise::SinEval}; + static TfLiteRegistration r = { + /*init=*/nullptr, /*free=*/nullptr, + elementwise::GenericPrepare, + elementwise::SinEval}; return &r; } TfLiteRegistration* Register_LOG() { - static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, - elementwise::LogEval}; + static TfLiteRegistration r = { + /*init=*/nullptr, /*free=*/nullptr, + elementwise::GenericPrepare, + elementwise::LogEval}; return &r; } TfLiteRegistration* Register_SQRT() { - static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, - elementwise::SqrtEval}; + static TfLiteRegistration r = { + /*init=*/nullptr, /*free=*/nullptr, + elementwise::GenericPrepare, + elementwise::SqrtEval}; return &r; } TfLiteRegistration* Register_RSQRT() { - static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, - elementwise::RsqrtEval}; + static TfLiteRegistration r = { + /*init=*/nullptr, /*free=*/nullptr, + elementwise::GenericPrepare, + elementwise::RsqrtEval}; + return &r; +} + +TfLiteRegistration* Register_LOGICAL_NOT() { + static TfLiteRegistration r = { + /*init=*/nullptr, /*free=*/nullptr, + elementwise::GenericPrepare, + elementwise::LogicalNotEval}; return &r; } diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc index ce4c602ee5c788d67701af3ecd3e023f2b25aae7..b9d7d73c52862da9166f6881b1e27a6ff6b76bbc 100644 --- a/tensorflow/contrib/lite/kernels/elementwise_test.cc +++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc @@ -24,26 +24,40 @@ namespace { using ::testing::ElementsAreArray; -class ElementWiseOpModel : public SingleOpModel { +class ElementWiseOpBaseModel : public SingleOpModel { public: - ElementWiseOpModel(BuiltinOperator op, - std::initializer_list input_shape) { + int input() const { return input_; } + int output() const { return output_; } + + protected: + int input_; + int output_; +}; + +class ElementWiseOpFloatModel : public ElementWiseOpBaseModel { + public: + ElementWiseOpFloatModel(BuiltinOperator op, + std::initializer_list input_shape) { input_ = AddInput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(op, BuiltinOptions_NONE, 0); BuildInterpreter({input_shape}); } +}; - int input() const { return input_; } - int output() const { return output_; } - - private: - int input_; - int output_; +class ElementWiseOpBoolModel : public ElementWiseOpBaseModel { + public: + ElementWiseOpBoolModel(BuiltinOperator op, + std::initializer_list input_shape) { + input_ = AddInput(TensorType_BOOL); + output_ = AddOutput(TensorType_BOOL); + SetBuiltinOp(op, BuiltinOptions_NONE, 0); + BuildInterpreter({input_shape}); + } }; TEST(ElementWise, Sin) { - ElementWiseOpModel m(BuiltinOperator_SIN, {1, 1, 4, 1}); + ElementWiseOpFloatModel m(BuiltinOperator_SIN, {1, 1, 4, 1}); m.PopulateTensor(m.input(), {0, 3.1415926, -3.1415926, 1}); m.Invoke(); EXPECT_THAT(m.ExtractVector(m.output()), @@ -52,7 +66,7 @@ TEST(ElementWise, Sin) { } TEST(ElementWise, Log) { - ElementWiseOpModel m(BuiltinOperator_LOG, {1, 1, 4, 1}); + ElementWiseOpFloatModel m(BuiltinOperator_LOG, {1, 1, 4, 1}); m.PopulateTensor(m.input(), {1, 3.1415926, 1, 1}); m.Invoke(); EXPECT_THAT(m.ExtractVector(m.output()), @@ -61,7 +75,7 @@ TEST(ElementWise, Log) { } TEST(ElementWise, Sqrt) { - ElementWiseOpModel m(BuiltinOperator_SQRT, {1, 1, 4, 1}); + ElementWiseOpFloatModel m(BuiltinOperator_SQRT, {1, 1, 4, 1}); m.PopulateTensor(m.input(), {0, 1, 2, 4}); m.Invoke(); EXPECT_THAT(m.ExtractVector(m.output()), @@ -70,7 +84,7 @@ TEST(ElementWise, Sqrt) { } TEST(ElementWise, Rsqrt) { - ElementWiseOpModel m(BuiltinOperator_RSQRT, {1, 1, 4, 1}); + ElementWiseOpFloatModel m(BuiltinOperator_RSQRT, {1, 1, 4, 1}); m.PopulateTensor(m.input(), {1, 2, 4, 9}); m.Invoke(); EXPECT_THAT(m.ExtractVector(m.output()), @@ -78,6 +92,15 @@ TEST(ElementWise, Rsqrt) { EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); } +TEST(ElementWise, LogicalNot) { + ElementWiseOpBoolModel m(BuiltinOperator_LOGICAL_NOT, {1, 1, 4, 1}); + m.PopulateTensor(m.input(), {true, false, true, false}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray({false, true, false, true})); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc index 9410bead5e7a68363d034c22fb2c0eff9f060ef1..b2dff87e6296c6038241c704d9158e174501f026 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc @@ -29,7 +29,6 @@ limitations under the License. // When indices are out of bound, the ops will not succeed. // -#include #include #include #include @@ -94,7 +93,7 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* lookup, const TfLiteTensor* value, TfLiteTensor* output) { const int row_size = SizeOfDimension(value, 0); - const double scaling_factor = 1.0 / value->params.scale; + const double scaling_factor = value->params.scale; // col_size after we flatten tensor into 2D. int col_size = 1; @@ -112,8 +111,9 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, // TODO(alanchiao): refactor scalar multiply into separate function // for ease of adding a neon equivalent if ever necessary. for (int j = 0; j < col_size; j++) { + const int8_t* value_ptr = reinterpret_cast(value->data.uint8); output->data.f[j + i * col_size] = - value->data.uint8[j + idx * col_size] * scaling_factor; + value_ptr[j + idx * col_size] * scaling_factor; } } } diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc index 04657fd86323ef1c58d069c06097c7665f55cc87..4a88d168c60203f10802e634def9b1d1316c9c6d 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc @@ -107,9 +107,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTest) { HybridEmbeddingLookupOpModel m({3}, {3, 8}); m.SetInput({1, 0, 2}); m.SetWeight({ - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }); m.Invoke(); @@ -117,9 +117,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTest) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( { - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }, 7.41e-03))); } @@ -128,9 +128,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTest) { HybridEmbeddingLookupOpModel m({3}, {3, 2, 4}); m.SetInput({1, 0, 2}); m.SetWeight({ - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }); m.Invoke(); @@ -138,9 +138,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTest) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( { - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }, 7.41e-03))); } @@ -149,9 +149,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTest) { HybridEmbeddingLookupOpModel m({3}, {3, 2, 2, 2}); m.SetInput({1, 0, 2}); m.SetWeight({ - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }); m.Invoke(); @@ -159,9 +159,9 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTest) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( { - 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 - 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 - 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + 1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 }, 7.41e-03))); } diff --git a/tensorflow/contrib/lite/kernels/fake_quant.cc b/tensorflow/contrib/lite/kernels/fake_quant.cc new file mode 100644 index 0000000000000000000000000000000000000000..0ef1a50b308b2e8a781bc9ed7195c22e627ea2de --- /dev/null +++ b/tensorflow/contrib/lite/kernels/fake_quant.cc @@ -0,0 +1,92 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace fake_quant { + +// This file has reference implementation of FakeQuant. +enum KernelType { + kReference, +}; + +struct OpContext { + OpContext(TfLiteContext* context, TfLiteNode* node) { + input = GetInput(context, node, 0); + output = GetOutput(context, node, 0); + } + const TfLiteTensor* input; + TfLiteTensor* output; +}; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const auto* params = + reinterpret_cast(node->builtin_data); + + if (params->narrow_range) { + context->ReportError( + context, + "narrow_range FakeQuant is not currently supported at runtime. " + "narrow_range is only meant to be applied to weights, not activations"); + return kTfLiteError; + } + + OpContext op_context(context, node); + TfLiteIntArray* output_dims = TfLiteIntArrayCopy(op_context.input->dims); + op_context.output->type = op_context.input->type; + return context->ResizeTensor(context, op_context.output, output_dims); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + OpContext op_context(context, node); + + const auto* params = + reinterpret_cast(node->builtin_data); + + reference_ops::FakeQuant(GetTensorData(op_context.input), + GetTensorDims(op_context.input), params->min, + params->max, params->num_bits, + GetTensorData(op_context.output), + GetTensorDims(op_context.output)); + + return kTfLiteOk; +} + +} // namespace fake_quant + +TfLiteRegistration* Register_FAKE_QUANT_REF() { + static TfLiteRegistration r = {nullptr, nullptr, fake_quant::Prepare, + fake_quant::Eval}; + return &r; +} + +TfLiteRegistration* Register_FAKE_QUANT() { return Register_FAKE_QUANT_REF(); } + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/fake_quant_test.cc b/tensorflow/contrib/lite/kernels/fake_quant_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..11a02f7ed7474e05b887955c111179d2d403f0e6 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/fake_quant_test.cc @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class FakeQuantOpModel : public SingleOpModel { + public: + FakeQuantOpModel(const TensorData& input, const TensorType& output, float min, + float max, int num_bits) { + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_FAKE_QUANT, BuiltinOptions_FakeQuantOptions, + CreateFakeQuantOptions(builder_, min, max, num_bits).Union()); + BuildInterpreter({GetShape(input_)}); + } + + template + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int output_; +}; + +TEST(FakeQuantOpTest, FloatPositiveRange8Test) { + std::initializer_list data = {0.0, 1.0, 0.25, + 0.50, 0.4444444, 0.00001}; + FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, 0.0f, + 1.0f, 8); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({0, 1, 0.25098, 0.498039, 0.443137, 0}))); +} + +TEST(FakeQuantOpTest, FloatNegativeRange8Test) { + std::initializer_list data = {0.0, -0.9, 0.25, + 0.50, 0.4444444, -0.00001}; + FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, -0.9f, + 0.9f, 8); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0, -0.896471, 0.247059, 0.501176, 0.444706, 0}))); +} + +TEST(FakeQuantOpTest, FloatPositiveRange16Test) { + std::initializer_list data = {0.0, 1.0, 0.25, + 0.50, 0.4444444, 0.00001}; + FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, 0.0f, + 1.0f, 16); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0, 1, 0.250004, 0.500008, 0.44445, 1.5259e-05}))); +} + +TEST(FakeQuantOpTest, FloatNegativeRange16Test) { + std::initializer_list data = {0.0, -0.9, 0.25, + 0.50, 0.4444444, -0.00001}; + FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, -0.9f, + 0.9f, 16); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0, -0.900014, 0.249998, 0.499995, 0.444431, 0}))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc index 3b203dd480f95c5dc70a69aafce0bac6ab2cbc06..bc370608c092eeb5312dc40b56f47740f473c8ae 100644 --- a/tensorflow/contrib/lite/kernels/fully_connected.cc +++ b/tensorflow/contrib/lite/kernels/fully_connected.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -71,7 +70,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { // Instead, we allocate a new object to carry information from Prepare() to // Eval(). gemm_support::IncrementUsageCounter(context); - auto* op_data = new OpData; + auto* op_data = new OpData(); context->AddTensors(context, 1, &op_data->input_quantized_index); return op_data; } diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc index 6a2341461f2c627c78bd4783ee27579b59b5fde3..2b2a9e662051287fd1e3dbe8978f4689dc731064 100644 --- a/tensorflow/contrib/lite/kernels/gather.cc +++ b/tensorflow/contrib/lite/kernels/gather.cc @@ -40,10 +40,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // Only INT32 positions are supported. TF_LITE_ENSURE_EQ(context, positions->type, kTfLiteInt32); - // Check that input and output types match. - TF_LITE_ENSURE_EQ(context, input->type, output->type); - // TODO(mgubin): only 0D or 1D positions are currently supported. - TF_LITE_ENSURE(context, NumDimensions(positions) <= 1); + // Assign to output the input type. + output->type = input->type; // TODO(mgubin): Only default axis == 0 is supported. TF_LITE_ENSURE_EQ(context, params->axis, 0); // Check conditions for different types. @@ -102,6 +100,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_GATHER(int32_t, int32_t); break; case kTfLiteString: { + // TODO(mgubin): Currently support only for 1D output tensors. DynamicBuffer buffer; const int32* indexes = positions->data.i32; const int num_strings = GetStringCount(input); diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc index cdadbeda1884ba0186846826dd16be6ff69878d9..1d4292955cced59a47e0500833a86113cb9d3eb8 100644 --- a/tensorflow/contrib/lite/kernels/gather_test.cc +++ b/tensorflow/contrib/lite/kernels/gather_test.cc @@ -96,6 +96,15 @@ TEST(GatherOpTest, Test0DIndexWith0DResult) { EXPECT_TRUE(m.GetOutputShape().empty()); } +TEST(GatherOpTest, Test2DIndexWith2DResult) { + GatherOpModel m({3}, TensorType_FLOAT32, {1, 2}); + m.SetInputFloat({1.0, 2.0, 3.0}); + m.SetPositions({1, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray(ArrayFloatNear({2.0, 1.0}))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); +} + TEST(FloatGatherOpTest, Duplicate) { GatherOpModel m({1, 2, 2}, TensorType_FLOAT32, {2}); m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); diff --git a/tensorflow/contrib/lite/kernels/gemm_support.cc b/tensorflow/contrib/lite/kernels/gemm_support.cc index 95f45ea768be7f9bae9570563f161792afbff436..ed334af2da877edf9f591612478e22f04cf15931 100644 --- a/tensorflow/contrib/lite/kernels/gemm_support.cc +++ b/tensorflow/contrib/lite/kernels/gemm_support.cc @@ -14,57 +14,70 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/kernels/gemm_support.h" +#include + #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { namespace gemm_support { +namespace { -struct RefCountedGemmContext { - gemmlowp::GemmContext* gemm_context_ = nullptr; - int num_references_ = 0; +struct RefCountedGemmContext : public TfLiteExternalContext { + std::unique_ptr gemm_context; + int num_references = 0; }; +RefCountedGemmContext* GetGemmLowpContext(TfLiteContext* context) { + return reinterpret_cast( + context->GetExternalContext(context, kTfLiteGemmLowpContext)); +} + +TfLiteStatus Refresh(TfLiteContext* context) { + auto* ptr = GetGemmLowpContext(context); + if (ptr != nullptr) { + ptr->gemm_context->set_max_num_threads(context->recommended_num_threads); + } + return kTfLiteOk; +} + +} // namespace + void IncrementUsageCounter(TfLiteContext* context) { - auto* ptr = reinterpret_cast(context->gemm_context); + auto* ptr = GetGemmLowpContext(context); if (ptr == nullptr) { ptr = new RefCountedGemmContext; - ptr->gemm_context_ = new gemmlowp::GemmContext(); + ptr->type = kTfLiteGemmLowpContext; + ptr->Refresh = Refresh; + ptr->gemm_context.reset(new gemmlowp::GemmContext()); if (context->recommended_num_threads != -1) { - ptr->gemm_context_->set_max_num_threads(context->recommended_num_threads); + ptr->gemm_context->set_max_num_threads(context->recommended_num_threads); } - ptr->num_references_ = 0; - context->gemm_context = ptr; + ptr->num_references = 0; + context->SetExternalContext(context, kTfLiteGemmLowpContext, ptr); } - ptr->num_references_++; + ptr->num_references++; } void DecrementUsageCounter(TfLiteContext* context) { - auto* ptr = reinterpret_cast(context->gemm_context); + auto* ptr = GetGemmLowpContext(context); if (ptr == nullptr) { TF_LITE_FATAL( "Call to DecrementUsageCounter() not preceded by " "IncrementUsageCounter()"); } - if (--ptr->num_references_ == 0) { - delete ptr->gemm_context_; + if (--ptr->num_references == 0) { delete ptr; - context->gemm_context = nullptr; + context->SetExternalContext(context, kTfLiteGemmLowpContext, nullptr); } } gemmlowp::GemmContext* GetFromContext(TfLiteContext* context) { - auto* ptr = reinterpret_cast(context->gemm_context); + auto* ptr = GetGemmLowpContext(context); if (ptr == nullptr) { TF_LITE_FATAL( "Call to GetFromContext() not preceded by IncrementUsageCounter()"); } - return ptr->gemm_context_; -} - -void SetNumThreads(TfLiteContext* context, int num_threads) { - IncrementUsageCounter(context); - GetFromContext(context)->set_max_num_threads(num_threads); - DecrementUsageCounter(context); + return ptr->gemm_context.get(); } } // namespace gemm_support diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h index f033501cb6e341aa014fa4d956b531bd79aa555b..37af772c6846f2f8124faabf1a0f0987e2e9393d 100644 --- a/tensorflow/contrib/lite/kernels/gemm_support.h +++ b/tensorflow/contrib/lite/kernels/gemm_support.h @@ -45,9 +45,6 @@ void IncrementUsageCounter(TfLiteContext* context); // 'context'. If there are no more usages the GemmContext will be deleted. void DecrementUsageCounter(TfLiteContext* context); -// Set the number of threads that can be used by gemmlowp. -void SetNumThreads(TfLiteContext* context, int num_threads); - } // namespace gemm_support } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc index 41211d41aa85a5a2da6ae96dc6f0337c54fb1a45..f37c66acb33eb9995772e595b84df6616e8d9e6a 100644 --- a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc +++ b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc @@ -31,7 +31,6 @@ limitations under the License. // Each item indicates whether the corresponding lookup has a returned value. // 0 for missing key, 1 for found key. -#include #include #include #include diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index 7962fcbc9d6c839ea11d7355e955239194442e03..0d424071da23010afe5f15a61e0ea6e45b4e6742 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -232,6 +232,7 @@ cc_library( cc_test( name = "tensor_test", srcs = ["tensor_test.cc"], + tags = ["no_oss"], deps = [ ":reference", "@com_google_googletest//:gtest", @@ -260,6 +261,7 @@ cc_library( cc_test( name = "quantization_util_test", srcs = ["quantization_util_test.cc"], + tags = ["no_oss"], deps = [ ":quantization_util", "@com_google_googletest//:gtest", @@ -479,6 +481,9 @@ cc_library( ":darwin": [ ":neon_tensor_utils", ], + ":darwin_x86_64": [ + ":neon_tensor_utils", + ], "//conditions:default": [ ":portable_tensor_utils", ], @@ -505,7 +510,10 @@ cc_test( "//conditions:default": [], }), linkstatic = 1, - tags = ["tflite_not_portable_ios"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], deps = [ ":tensor_utils", "//tensorflow/contrib/lite:builtin_op_data", @@ -517,6 +525,7 @@ cc_test( cc_test( name = "depthwiseconv_float_test", srcs = ["depthwiseconv_float_test.cc"], + tags = ["no_oss"], deps = [ ":optimized_base", ":reference_base", @@ -529,6 +538,7 @@ cc_test( cc_test( name = "depthwiseconv_quantized_test", srcs = ["depthwiseconv_quantized_test.cc"], + tags = ["no_oss"], deps = [ ":optimized_base", ":reference_base", @@ -541,7 +551,10 @@ cc_test( cc_test( name = "resize_bilinear_test", srcs = ["resize_bilinear_test.cc"], - tags = ["tflite_not_portable"], + tags = [ + "no_oss", + "tflite_not_portable", + ], deps = [ ":optimized_base", ":reference_base", @@ -557,6 +570,7 @@ cc_test( srcs = [ "softmax_quantized_test.cc", ], + tags = ["no_oss"], deps = [ ":optimized_base", ":quantization_util", @@ -572,7 +586,10 @@ cc_test( srcs = [ "logsoftmax_quantized_test.cc", ], - tags = ["tflite_not_portable"], + tags = [ + "no_oss", + "tflite_not_portable", + ], deps = [ ":optimized_base", ":quantization_util", @@ -585,6 +602,7 @@ cc_test( cc_test( name = "log_quantized_test", srcs = ["log_quantized_test.cc"], + tags = ["no_oss"], deps = [ ":optimized_base", ":reference_base", @@ -611,6 +629,7 @@ cc_library( cc_test( name = "batch_to_space_nd_test", srcs = ["batch_to_space_nd_test.cc"], + tags = ["no_oss"], deps = [ ":optimized_base", "@com_google_googletest//:gtest_main", diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h index b86ca49c116875672c4516a2a47f7dae511a7116..eb4d0108bd0438dd27744a864d071cfc166a7a94 100644 --- a/tensorflow/contrib/lite/kernels/internal/common.h +++ b/tensorflow/contrib/lite/kernels/internal/common.h @@ -117,6 +117,9 @@ template int CountLeadingZeros(T integer_input) { static_assert(std::is_unsigned::value, "Only unsigned integer types handled."); +#if defined(__GNUC__) + return integer_input ? __builtin_clz(integer_input) : 0; +#else const T one_in_leading_positive = static_cast(1) << (std::numeric_limits::digits - 1); int leading_zeros = 0; @@ -125,6 +128,140 @@ int CountLeadingZeros(T integer_input) { ++leading_zeros; } return leading_zeros; +#endif +} + +// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING +// BROADCASTING. +// +// NdArrayDesc describes the shape and memory layout of an N-dimensional +// rectangular array of numbers. +// +// NdArrayDesc is basically identical to Dims defined in types.h. +// However, as Dims is to be deprecated, this class exists as an adaptor +// to enable simple unoptimized implementations of element-wise broadcasting +// operations. +template +struct NdArrayDesc { + // The "extent" of each dimension. Indices along dimension d must be in the + // half-open interval [0, extents[d]). + int extents[N]; + + // The number of *elements* (not bytes) between consecutive indices of each + // dimension. + int strides[N]; +}; + +// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING +// BROADCASTING. +// +// Same as Offset(), except takes as NdArrayDesc instead of Dims. +inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2, + int i3) { + TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]); + TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]); + TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]); + TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]); + return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] + + i3 * desc.strides[3]; +} + +// Given the dimensions of the operands for an element-wise binary broadcast, +// adjusts them so that they can be directly iterated over with simple loops. +// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and +// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr. +// +// This function assumes that the two input shapes are compatible up to +// broadcasting and the shorter one has already been prepended with 1s to be the +// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64), +// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that +// Dims refer to shapes in reverse order. In this case, input0_dims will be +// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1). +// +// When two shapes are compatible up to broadcasting, for each dimension d, +// the input extents are either equal, or one of them is 1. +// +// This function performs the following for each dimension d: +// - If the extents are equal, then do nothing since the loop that walks over +// both of the input arrays is correct. +// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1 +// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows +// array0 to be referenced *at any index* in dimension d and still access the +// same slice. +template +inline void NdArrayDescsForElementwiseBroadcast(const Dims& input0_dims, + const Dims& input1_dims, + NdArrayDesc* desc0_out, + NdArrayDesc* desc1_out) { + TFLITE_DCHECK(desc0_out != nullptr); + TFLITE_DCHECK(desc1_out != nullptr); + + // Copy dims to desc. + for (int i = 0; i < N; ++i) { + desc0_out->extents[i] = input0_dims.sizes[i]; + desc0_out->strides[i] = input0_dims.strides[i]; + desc1_out->extents[i] = input1_dims.sizes[i]; + desc1_out->strides[i] = input1_dims.strides[i]; + } + + // Walk over each dimension. If the extents are equal do nothing. + // Otherwise, set the desc with extent 1 to have extent equal to the other and + // stride 0. + for (int i = 0; i < N; ++i) { + const int extent0 = ArraySize(input0_dims, i); + const int extent1 = ArraySize(input1_dims, i); + if (extent0 != extent1) { + if (extent0 == 1) { + desc0_out->strides[i] = 0; + desc0_out->extents[i] = extent1; + } else { + TFLITE_DCHECK_EQ(extent1, 1); + desc1_out->strides[i] = 0; + desc1_out->extents[i] = extent0; + } + } + } +} + +template +inline void NdArrayDescsForElementwiseBroadcast( + const RuntimeShape& input0_shape, const RuntimeShape& input1_shape, + NdArrayDesc* desc0_out, NdArrayDesc* desc1_out) { + TFLITE_DCHECK(desc0_out != nullptr); + TFLITE_DCHECK(desc1_out != nullptr); + + auto extended_input0_shape = RuntimeShape::ExtendedShape(N, input0_shape); + auto extended_input1_shape = RuntimeShape::ExtendedShape(N, input1_shape); + + // Copy dims to desc, calculating strides. + int desc0_stride = 1; + int desc1_stride = 1; + for (int i = N - 1; i >= 0; --i) { + desc0_out->extents[i] = extended_input0_shape.Dims(i); + desc0_out->strides[i] = desc0_stride; + desc0_stride *= extended_input0_shape.Dims(i); + desc1_out->extents[i] = extended_input1_shape.Dims(i); + desc1_out->strides[i] = desc1_stride; + desc1_stride *= extended_input1_shape.Dims(i); + } + + // Walk over each dimension. If the extents are equal do nothing. + // Otherwise, set the desc with extent 1 to have extent equal to the other and + // stride 0. + for (int i = 0; i < N; ++i) { + const int extent0 = extended_input0_shape.Dims(i); + const int extent1 = extended_input1_shape.Dims(i); + if (extent0 != extent1) { + if (extent0 == 1) { + desc0_out->strides[i] = 0; + desc0_out->extents[i] = extent1; + } else { + TFLITE_DCHECK_EQ(extent1, 1); + desc1_out->strides[i] = 0; + desc1_out->extents[i] = extent0; + } + } + } } } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc index 36c25388e8bde721d7644dc83d5b7c490d37b4d3..200f2f151582c2361dd2403164d0bbe119cbed72 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -255,14 +255,6 @@ void LstmStep( output_state_ptr); } -// TODO(alanchiao): move this to tensor_utils. -void VectorMultiply(const int8_t* vector, const int v_size, const float scale, - float* result) { - for (int i = 0; i < v_size; ++i) { - *result++ = scale * *vector++; - } -} - void LstmStep( const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, float input_to_input_weights_scale, @@ -415,8 +407,9 @@ void LstmStep( // For each batch and cell: update input gate. if (!use_cifg) { if (use_peephole && !is_cell_state_all_zeros) { - VectorMultiply(cell_to_input_weights_ptr, n_cell, - 1. / cell_to_input_weights_scale, recovered_cell_weights); + tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell, + cell_to_input_weights_scale, + recovered_cell_weights); tensor_utils::VectorBatchVectorCwiseProductAccumulate( recovered_cell_weights, n_cell, cell_state_ptr, n_batch, input_gate_scratch); @@ -427,8 +420,9 @@ void LstmStep( // For each batch and cell: update forget gate. if (use_peephole && !is_cell_state_all_zeros) { - VectorMultiply(cell_to_forget_weights_ptr, n_cell, - 1. / cell_to_forget_weights_scale, recovered_cell_weights); + tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell, + cell_to_forget_weights_scale, + recovered_cell_weights); tensor_utils::VectorBatchVectorCwiseProductAccumulate( recovered_cell_weights, n_cell, cell_state_ptr, n_batch, forget_gate_scratch); @@ -459,8 +453,9 @@ void LstmStep( tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); // For each batch and cell: update the output gate. if (use_peephole && !is_cell_state_all_zeros) { - VectorMultiply(cell_to_output_weights_ptr, n_cell, - 1. / cell_to_output_weights_scale, recovered_cell_weights); + tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell, + cell_to_output_weights_scale, + recovered_cell_weights); tensor_utils::VectorBatchVectorCwiseProductAccumulate( recovered_cell_weights, n_cell, cell_state_ptr, n_batch, output_gate_scratch); diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h index 4cfaa0f36defa9c1f7d4a51af243c416bf09e331..0ce64f8c70d76f970df610f47947580a1efde720 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h @@ -3242,6 +3242,7 @@ inline void DepthwiseConv3x3Filter( int32 output_shift, int32 output_activation_min, int32 output_activation_max, uint8* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__); DepthwiseConvParams params; params.input_depth = ArraySize(input_dims, 0); params.input_width = ArraySize(input_dims, 1); diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h index d85e06a5d5af8d23235a08592d49754e4f493d34..250872c422a3ff9b3353d0055513ff1f7f03d68e 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h @@ -33,7 +33,7 @@ limitations under the License. #include #ifdef _WIN32 -#include +#include #elif defined(__APPLE__) #include #else diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h index 7816752132761d9523ffc1f45b3740c0817ed402..d5503073a7cfc0be137fde104815ca1a2a6bb438 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h @@ -55,15 +55,262 @@ inline void Relu(const float* input_data, const Dims<4>& input_dims, DimsToShape(output_dims)); } +// legacy, for compatibility with old checked-in code +template +void Add(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + tflite::ArithmeticParams op_params; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + Add(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +template +inline void Add(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, int input2_shift, + int32 output_offset, int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + constexpr int kReverseShift = -1; + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + + tflite::ArithmeticParams op_params; + op_params.left_shift = left_shift; + op_params.input1_offset = input1_offset; + op_params.input1_multiplier = input1_multiplier; + op_params.input1_shift = kReverseShift * input1_shift; + op_params.input2_offset = input2_offset; + op_params.input2_multiplier = input2_multiplier; + op_params.input2_shift = kReverseShift * input2_shift; + op_params.output_offset = output_offset; + op_params.output_multiplier = output_multiplier; + op_params.output_shift = kReverseShift * output_shift; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + Add(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +template +void Add(const int32* input1_data, const Dims<4>& input1_dims, + const int32* input2_data, const Dims<4>& input2_dims, + int32* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Add/int32"); + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + + tflite::ArithmeticParams op_params; + op_params.quantized_activation_min = std::numeric_limits::min(); + op_params.quantized_activation_max = std::numeric_limits::max(); + Add(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +template +void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T output_activation_min, T output_activation_max, + T* output_data, const Dims<4>& output_dims) { + tflite::ArithmeticParams op_params; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); +} + +template +inline void BroadcastAdd(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + constexpr int kReverseShift = -1; + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + + tflite::ArithmeticParams op_params; + op_params.left_shift = left_shift; + op_params.input1_offset = input1_offset; + op_params.input1_multiplier = input1_multiplier; + op_params.input1_shift = kReverseShift * input1_shift; + op_params.input2_offset = input2_offset; + op_params.input2_multiplier = input2_multiplier; + op_params.input2_shift = kReverseShift * input2_shift; + op_params.output_offset = output_offset; + op_params.output_multiplier = output_multiplier; + op_params.output_shift = kReverseShift * output_shift; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); +} + +template +inline void BroadcastAddFivefold( + int y0, int y1, int y2, int y3, int y4, int left_shift, + const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + constexpr int kReverseShift = -1; + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + tflite::ArithmeticParams op_params; + op_params.broadcast_category = + tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast; + op_params.left_shift = left_shift; + op_params.input1_offset = input1_offset; + op_params.input1_multiplier = input1_multiplier; + op_params.input1_shift = kReverseShift * input1_shift; + op_params.input2_offset = input2_offset; + op_params.input2_multiplier = input2_multiplier; + op_params.input2_shift = kReverseShift * input2_shift; + op_params.output_offset = output_offset; + op_params.output_multiplier = output_multiplier; + op_params.output_shift = kReverseShift * output_shift; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + op_params.broadcast_shape[4] = y0; + op_params.broadcast_shape[3] = y1; + op_params.broadcast_shape[2] = y2; + op_params.broadcast_shape[1] = y3; + op_params.broadcast_shape[0] = y4; + BroadcastAddFivefold(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); +} + +// legacy, for compatibility with old checked-in code +template +void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + T output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims, + output_activation_min, output_activation_max, output_data, + output_dims); +} + +template +inline void Add(const int16* input1_data, const Dims<4>& input1_dims, + int input1_shift, const int16* input2_data, + const Dims<4>& input2_dims, int input2_shift, + int16 output_activation_min, int16 output_activation_max, + int16* output_data, const Dims<4>& output_dims) { + constexpr int kReverseShift = -1; + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, -32768); + TFLITE_DCHECK_EQ(output_activation_max, 32767); + } + + tflite::ArithmeticParams op_params; + op_params.input1_shift = kReverseShift * input1_shift; + op_params.input2_shift = kReverseShift * input2_shift; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + Add(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +inline void Sub(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(FusedActivationFunctionType::kNone, + &output_activation_min, &output_activation_max); + tflite::ArithmeticParams op_params; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + Sub(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +template +void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, + const Dims<4>& input2_dims, T* output_data, + const Dims<4>& output_dims) { + T output_activation_min, output_activation_max; + GetActivationMinMax(FusedActivationFunctionType::kNone, + &output_activation_min, &output_activation_max); + tflite::ArithmeticParams op_params; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + Sub(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + inline void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride_width, int stride_height, int pad_width, int pad_height, int kwidth, int kheight, float output_activation_min, float output_activation_max, float* output_data, const Dims<4>& output_dims) { - AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height, - pad_width, pad_height, kwidth, kheight, output_activation_min, - output_activation_max, output_data, DimsToShape(output_dims)); + tflite::PoolParams params; + params.stride_height = stride_height; + params.stride_width = stride_width; + params.filter_height = kheight; + params.filter_width = kwidth; + params.padding_values.height = pad_height; + params.padding_values.width = pad_width; + params.float_activation_min = output_activation_min; + params.float_activation_max = output_activation_max; + AveragePool(params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); } // legacy, for compatibility with old checked-in code @@ -96,10 +343,17 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int32 output_activation_min, int32 output_activation_max, uint8* output_data, const Dims<4>& output_dims) { - AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height, - pad_width, pad_height, filter_width, filter_height, - output_activation_min, output_activation_max, output_data, - DimsToShape(output_dims)); + tflite::PoolParams params; + params.stride_height = stride_height; + params.stride_width = stride_width; + params.filter_height = filter_height; + params.filter_width = filter_width; + params.padding_values.height = pad_height; + params.padding_values.width = pad_width; + params.quantized_activation_min = output_activation_min; + params.quantized_activation_max = output_activation_max; + AveragePool(params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); } // legacy, for compatibility with old checked-in code @@ -140,9 +394,17 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims, int pad_height, int kwidth, int kheight, float output_activation_min, float output_activation_max, float* output_data, const Dims<4>& output_dims) { - MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height, - pad_width, pad_height, kwidth, kheight, output_activation_min, - output_activation_max, output_data, DimsToShape(output_dims)); + tflite::PoolParams params; + params.stride_height = stride_height; + params.stride_width = stride_width; + params.filter_height = kheight; + params.filter_width = kwidth; + params.padding_values.height = pad_height; + params.padding_values.width = pad_width; + params.float_activation_min = output_activation_min; + params.float_activation_max = output_activation_max; + MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims), + output_data); } // legacy, for compatibility with old checked-in code @@ -172,10 +434,17 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int pad_height, int filter_width, int filter_height, int32 output_activation_min, int32 output_activation_max, uint8* output_data, const Dims<4>& output_dims) { - MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height, - pad_width, pad_height, filter_width, filter_height, - output_activation_min, output_activation_max, output_data, - DimsToShape(output_dims)); + PoolParams params; + params.stride_height = stride_height; + params.stride_width = stride_width; + params.filter_height = filter_height; + params.filter_width = filter_width; + params.padding_values.height = pad_height; + params.padding_values.width = pad_width; + params.quantized_activation_min = output_activation_min; + params.quantized_activation_max = output_activation_max; + MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims), + output_data); } // legacy, for compatibility with old checked-in code @@ -215,10 +484,17 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims, int pad_height, int filter_width, int filter_height, float output_activation_min, float output_activation_max, float* output_data, const Dims<4>& output_dims) { - L2Pool(input_data, DimsToShape(input_dims), stride_width, stride_height, - pad_width, pad_height, filter_width, filter_height, - output_activation_min, output_activation_max, output_data, - DimsToShape(output_dims)); + PoolParams params; + params.stride_height = stride_height; + params.stride_width = stride_width; + params.filter_height = filter_height; + params.filter_width = filter_width; + params.padding_values.height = pad_height; + params.padding_values.width = pad_width; + params.float_activation_min = output_activation_min; + params.float_activation_max = output_activation_max; + L2Pool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims), + output_data); } // legacy, for compatibility with old checked-in code diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h index 27d9224512a835ea58911031f1b4d6dcf5482ba9..4a3545d47aca7d649061d39cbc23fa7ddf208156 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h @@ -35,35 +35,6 @@ limitations under the License. namespace tflite { namespace multithreaded_ops { -class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { - public: - explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {} - ~EigenThreadPoolWrapper() override {} - - void Schedule(std::function fn) override { - pool_->Schedule(std::move(fn)); - } - int NumThreads() const override { return pool_->NumThreads(); } - int CurrentThreadId() const override { return pool_->CurrentThreadId(); } - - private: - Eigen::ThreadPool* pool_ = nullptr; -}; - -// We have a single global threadpool for all convolution operations. This means -// that inferences started from different threads may block each other, but -// since the underlying resource of CPU cores should be consumed by the -// operations anyway, it shouldn't affect overall performance. -const Eigen::ThreadPoolDevice& GetThreadPoolDevice() { - const int thread_count = 4; - static Eigen::ThreadPool* tp = new Eigen::ThreadPool(thread_count); - static EigenThreadPoolWrapper* thread_pool_wrapper = - new EigenThreadPoolWrapper(tp); - static Eigen::ThreadPoolDevice* device = - new Eigen::ThreadPoolDevice(thread_pool_wrapper, thread_count); - return *device; -} - // Shorthands for the types we need when interfacing with the EigenTensor // library. typedef Eigen::TensorMap< @@ -113,14 +84,13 @@ class EigenTensorConvFunctor { } public: - void operator()(const T* input_data, T* im2col_buffer, int input_batches, - int input_height, int input_width, int input_depth, - const T* filter_data, int filter_height, int filter_width, - int filter_count, int stride_rows, int stride_cols, - int pad_width, int pad_height, TfLitePadding padding, - T* output_data, int output_height, int output_width) { - const Eigen::ThreadPoolDevice& device = GetThreadPoolDevice(); - + void operator()(const Eigen::ThreadPoolDevice& device, const T* input_data, + T* im2col_buffer, int input_batches, int input_height, + int input_width, int input_depth, const T* filter_data, + int filter_height, int filter_width, int filter_count, + int stride_rows, int stride_cols, int pad_width, + int pad_height, TfLitePadding padding, T* output_data, + int output_height, int output_width) { const bool is_1x1_kernel = (filter_height == 1 && filter_width == 1 && stride_rows == 1 && stride_cols == 1); if (is_1x1_kernel) { @@ -162,11 +132,11 @@ class EigenTensorConvFunctor { } }; -inline void Conv(const float* input_data, const Dims<4>& input_dims, - const float* filter_data, const Dims<4>& filter_dims, - const float* bias_data, const Dims<4>& bias_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, TfLitePadding padding, +inline void Conv(const Eigen::ThreadPoolDevice& device, const float* input_data, + const Dims<4>& input_dims, const float* filter_data, + const Dims<4>& filter_dims, const float* bias_data, + const Dims<4>& bias_dims, int stride_width, int stride_height, + int pad_width, int pad_height, TfLitePadding padding, float output_activation_min, float output_activation_max, float* output_data, const Dims<4>& output_dims, float* im2col_data, const Dims<4>& im2col_dims) { @@ -180,10 +150,11 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims, const int output_height = ArraySize(output_dims, 2); const int output_width = ArraySize(output_dims, 1); EigenTensorConvFunctor conv_functor; - conv_functor(input_data, im2col_data, batches, input_height, input_width, - input_depth, filter_data, filter_height, filter_width, - output_depth, stride_height, stride_width, pad_height, pad_width, - padding, output_data, output_height, output_width); + conv_functor(device, input_data, im2col_data, batches, input_height, + input_width, input_depth, filter_data, filter_height, + filter_width, output_depth, stride_height, stride_width, + pad_height, pad_width, padding, output_data, output_height, + output_width); optimized_ops::AddBiasAndEvalActivationFunction( bias_data, bias_dims, output_data, output_dims, output_activation_min, diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc index 38ad32c734a2286c7d23162810625169a4d8df43..420bc68b43dd8c135e95badcc7d18935449cfc73 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -55,83 +55,33 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, const int postamble_start = m_cols - (m_cols & (kFloatWeightsPerNeonLane - 1)); - // The arrays used to cache the vector. - void* aligned_vector_cache_free = nullptr; - float32x4_t* vector_cache_float32x4 = - reinterpret_cast(aligned_alloc( - sizeof(float32x4_t), (postamble_start >> 2) * sizeof(float32x4_t), - &aligned_vector_cache_free)); - - const int kUnrollSize = 2; for (int b = 0; b < n_batch; b++) { float* result_in_batch = result + b * m_rows * result_stride; const float* vector_in_batch = vector + b * m_cols; + const float* matrix_row = matrix; - const float* matrix_ptr0 = matrix; - // If there is only 1 row, we don't want to assign an illegal pointer. - const float* matrix_ptr1 = nullptr; - if (m_rows > 1) { - matrix_ptr1 = matrix + m_cols; - } - - // Cache the vector. - for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) { - vector_cache_float32x4[c >> 2] = vld1q_f32(vector_in_batch + c); - } - - // Main matrix by vector multiplication loop, which handles two rows of - // matrix by vector multiplication. - for (int r = 0; r < (m_rows & ~(kUnrollSize - 1)); r += kUnrollSize) { - float32x4_t acc0_32x4 = vmovq_n_f32(0.0); - float32x4_t acc1_32x4 = vmovq_n_f32(0.0); + // Main matrix by vector multiplication loop + for (int r = 0; r < m_rows; r++) { + float32x4_t acc_32x4 = vmovq_n_f32(0.0); for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) { - float32x4_t temp = vector_cache_float32x4[c >> 2]; - // Load 4 float values from vector1 and vector2 and accumulator. - float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c); - float32x4_t v1_f32x4 = vld1q_f32(matrix_ptr1 + c); - // Vector multiply-accumulate 4 float - acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp); - acc1_32x4 = vmlaq_f32(acc1_32x4, v1_f32x4, temp); + // Load 4 float values from vector and matrix row. + float32x4_t vector_f32x4 = vld1q_f32(vector_in_batch + c); + float32x4_t matrix_f32x4 = vld1q_f32(matrix_row + c); + // Multiply the vector and matrix row and add to accumulator. + acc_32x4 = vmlaq_f32(acc_32x4, matrix_f32x4, vector_f32x4); } // Add the 4 intermediate sum values to get the final dot-prod value for // this column. *result_in_batch += - (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) + - vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3)); - *(result_in_batch + result_stride) += - (vgetq_lane_f32(acc1_32x4, 0) + vgetq_lane_f32(acc1_32x4, 1) + - vgetq_lane_f32(acc1_32x4, 2) + vgetq_lane_f32(acc1_32x4, 3)); + (vgetq_lane_f32(acc_32x4, 0) + vgetq_lane_f32(acc_32x4, 1) + + vgetq_lane_f32(acc_32x4, 2) + vgetq_lane_f32(acc_32x4, 3)); for (int c = postamble_start; c < m_cols; c++) { - *result_in_batch += matrix_ptr0[c] * vector_in_batch[c]; - *(result_in_batch + result_stride) += - matrix_ptr1[c] * vector_in_batch[c]; + *result_in_batch += matrix_row[c] * vector_in_batch[c]; } - matrix_ptr0 += kUnrollSize * m_cols; - matrix_ptr1 += kUnrollSize * m_cols; - result_in_batch += kUnrollSize * result_stride; - } - for (int r = (m_rows & ~(kUnrollSize - 1)); r < m_rows; r++) { - float32x4_t acc0_32x4 = vmovq_n_f32(0.0); - for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) { - float32x4_t temp = vector_cache_float32x4[c >> 2]; - // Load 4 float values from vector1 and vector2 and accumulator. - float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c); - // Vector multiply-accumulate 4 float - acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp); - } - // Add the 4 intermediate sum values to get the final dot-prod value for - // this column. - *result_in_batch += - (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) + - vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3)); - for (int c = postamble_start; c < m_cols; c++) { - *result_in_batch += matrix_ptr0[c] * vector_in_batch[c]; - } - matrix_ptr0 += m_cols; + matrix_row += m_cols; result_in_batch += result_stride; } } - free(aligned_vector_cache_free); } void NeonMatrixBatchVectorMultiplyAccumulate( @@ -162,7 +112,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate( int batch, row, col; for (batch = 0; batch < n_batch; ++batch) { - const float batch_scaling_factor_inv = 1.0 / scaling_factors[batch]; + const float batch_scaling_factor = scaling_factors[batch]; // Copy the vector data to an aligned vector. memcpy(aligned_vec, vectors + batch * m_cols, sizeof(int8) * m_cols); // Compute dot-product for every column. @@ -232,7 +182,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate( int32 neon_sum = vgetq_lane_s64(pairwiseAdded, 0) + vgetq_lane_s64(pairwiseAdded, 1); - *result += ((neon_sum + postable_sum) * batch_scaling_factor_inv); + *result += ((neon_sum + postable_sum) * batch_scaling_factor); } // for row } // for batch @@ -296,17 +246,6 @@ void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector, const int postamble_start = v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); - // The arrays used to cache the vector. - void* aligned_vector_cache_free = nullptr; - float32x4_t* vector_cache_float32x4 = - reinterpret_cast(aligned_alloc( - sizeof(float32x4_t), (postamble_start >> 2) * sizeof(float32x4_t), - &aligned_vector_cache_free)); - - for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { - vector_cache_float32x4[v >> 2] = vld1q_f32(vector + v); - } - float* result_ptr = result; const float* batch_vector_ptr = batch_vector; for (int b = 0; b < n_batch; b++) { @@ -314,9 +253,9 @@ void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector, // Load from memory to vectors. float32x4_t result_f32x4 = vld1q_f32(result_ptr + v); float32x4_t batch_vector_f32x4 = vld1q_f32(batch_vector_ptr + v); + float32x4_t vector_f32x4 = vld1q_f32(vector + v); // Multiply-accumulate. - result_f32x4 = vmlaq_f32(result_f32x4, batch_vector_f32x4, - vector_cache_float32x4[v >> 2]); + result_f32x4 = vmlaq_f32(result_f32x4, batch_vector_f32x4, vector_f32x4); // Store. vst1q_f32(result_ptr + v, result_f32x4); } @@ -328,7 +267,6 @@ void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector, result_ptr += v_size; batch_vector_ptr += v_size; } - free(aligned_vector_cache_free); } void NeonSub1Vector(const float* vector, int v_size, float* result) { @@ -404,6 +342,77 @@ void NeonClipVector(const float* vector, int v_size, float abs_limit, } } +void NeonVectorScalarMultiply(const int8_t* vector, const int v_size, + const float scale, float* result) { + // Here the assumption is that each buffer is 4-byte aligned. + const int kWeightsPerUint32 = 4; + TFLITE_CHECK_EQ((intptr_t)(&vector[0]) & (kWeightsPerUint32 - 1), 0); + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int kWeightsPerNeonLane = 16; + const int postamble_start = v_size - (v_size & (kWeightsPerNeonLane - 1)); + + // Create a vector of 4 floats with the scale value. + const float32x4_t scale_f32x4 = vdupq_n_f32(scale); + int v = 0; + for (; v < postamble_start; v += kWeightsPerNeonLane) { + // Load int8 values, sixteen at a time. + const int8x16_t v_i8x16 = vld1q_s8(vector + v); + // Split it into two components of size eight. + const int8x8_t v0_i8x8 = vget_low_s8(v_i8x16); + const int8x8_t v1_i8x8 = vget_high_s8(v_i8x16); + // Convert both components to int16 first. + const int16x8_t v0_i16x8 = vmovl_s8(v0_i8x8); + const int16x8_t v1_i16x8 = vmovl_s8(v1_i8x8); + // Split each of them into two components each. + const int16x4_t v0_i16x4 = vget_low_s16(v0_i16x8); + const int16x4_t v1_i16x4 = vget_high_s16(v0_i16x8); + const int16x4_t v2_i16x4 = vget_low_s16(v1_i16x8); + const int16x4_t v3_i16x4 = vget_high_s16(v1_i16x8); + // Convert these to int32 and then to float. + float32x4_t v0_f32x4 = vcvtq_f32_s32(vmovl_s16(v0_i16x4)); + float32x4_t v1_f32x4 = vcvtq_f32_s32(vmovl_s16(v1_i16x4)); + float32x4_t v2_f32x4 = vcvtq_f32_s32(vmovl_s16(v2_i16x4)); + float32x4_t v3_f32x4 = vcvtq_f32_s32(vmovl_s16(v3_i16x4)); + // Vector multiply four floats at a time. + v0_f32x4 = vmulq_f32(v0_f32x4, scale_f32x4); + v1_f32x4 = vmulq_f32(v1_f32x4, scale_f32x4); + v2_f32x4 = vmulq_f32(v2_f32x4, scale_f32x4); + v3_f32x4 = vmulq_f32(v3_f32x4, scale_f32x4); + // Store the results. + vst1q_f32(result + v, v0_f32x4); + vst1q_f32(result + v + 4, v1_f32x4); + vst1q_f32(result + v + 8, v2_f32x4); + vst1q_f32(result + v + 12, v3_f32x4); + } + + if (v_size - postamble_start >= (kWeightsPerNeonLane >> 1)) { + // Load eight int8 values, if there is at least eight remaining. + const int8x8_t v_i8x8 = vld1_s8(vector + v); + // Convert them to int16 first. + const int16x8_t v_i16x8 = vmovl_s8(v_i8x8); + // Split it into two components. + const int16x4_t v0_i16x4 = vget_low_s16(v_i16x8); + const int16x4_t v1_i16x4 = vget_high_s16(v_i16x8); + // Convert the components two floats. + float32x4_t v0_f32x4 = vcvtq_f32_s32(vmovl_s16(v0_i16x4)); + float32x4_t v1_f32x4 = vcvtq_f32_s32(vmovl_s16(v1_i16x4)); + // Vector multiply four floats at a time. + v0_f32x4 = vmulq_f32(v0_f32x4, scale_f32x4); + v1_f32x4 = vmulq_f32(v1_f32x4, scale_f32x4); + // Store the results. + vst1q_f32(result + v, v0_f32x4); + vst1q_f32(result + v + 4, v1_f32x4); + v += (kWeightsPerNeonLane >> 1); + } + + // Postamble loop. + for (; v < v_size; v++) { + result[v] = scale * vector[v]; + } +} + void NeonSymmetricQuantizeFloats(const float* values, const int size, int8_t* quantized_values, float* min, float* max, float* scaling_factor) { @@ -418,13 +427,14 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size, *scaling_factor = 1; return; } - *scaling_factor = kScale / range; + *scaling_factor = range / kScale; + const float scaling_factor_inv = 1.0f / *scaling_factor; const int postamble_start = size - (size & (2 * kFloatWeightsPerNeonLane - 1)); // Vectorized constants. - const float32x4_t q_factor_f32x4 = vmovq_n_f32(*scaling_factor); + const float32x4_t q_factor_f32x4 = vmovq_n_f32(scaling_factor_inv); const float32x4_t point5_f32x4 = vmovq_n_f32(0.5); const float32x4_t zero_f32x4 = vmovq_n_f32(0.0); const int32x4_t scale_i32x4 = vmovq_n_s32(kScale); @@ -476,7 +486,7 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size, for (int i = postamble_start; i < size; ++i) { const int32 quantized_value = - static_cast(TfLiteRound(*scaling_factor * values[i])); + static_cast(TfLiteRound(scaling_factor_inv * values[i])); quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value)); } } diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h index 7a5a8fc54123946229963abd1720030d0bb358bf..63c89d1eeef47b206fc871929f1fb1295b2f70ff 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -105,16 +105,20 @@ bool IsZeroVector(const float* vector, int v_size) { return NEON_OR_PORTABLE(IsZeroVector, vector, v_size); } +void VectorScalarMultiply(const int8_t* vector, int v_size, float scale, + float* result) { + NEON_OR_PORTABLE(VectorScalarMultiply, vector, v_size, scale, result); +} void ClipVector(const float* vector, int v_size, float abs_limit, float* result) { NEON_OR_PORTABLE(ClipVector, vector, v_size, abs_limit, result); } void SymmetricQuantizeFloats(const float* values, const int size, - int8_t* quantized_values, float* min, float* max, - float* scaling_factor) { - NEON_OR_PORTABLE(SymmetricQuantizeFloats, values, size, quantized_values, min, - max, scaling_factor); + int8_t* quantized_values, float* min_value, + float* max_value, float* scaling_factor) { + NEON_OR_PORTABLE(SymmetricQuantizeFloats, values, size, quantized_values, + min_value, max_value, scaling_factor); } void VectorShiftLeft(float* vector, int v_size, float shift_value) { diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index 1b8a7205e6a6eac9f8f0f6580d41f7d6722fa9c2..6adb879c71e6a02007dacd5ed9f91b04b2094fe7 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -41,10 +41,13 @@ namespace optimized_ops { // Unoptimized reference ops: using reference_ops::ArgMax; +using reference_ops::ArgMinMax; +using reference_ops::BroadcastAdd4DSlow; using reference_ops::BroadcastGreater; using reference_ops::BroadcastGreaterEqual; using reference_ops::BroadcastLess; using reference_ops::BroadcastLessEqual; +using reference_ops::BroadcastSub4DSlow; using reference_ops::Concatenation; using reference_ops::DepthConcatenation; using reference_ops::Dequantize; @@ -59,6 +62,7 @@ using reference_ops::Mean; using reference_ops::RankOneSelect; using reference_ops::Relu1; using reference_ops::Relu6; +using reference_ops::ReluX; using reference_ops::Select; using reference_ops::SpaceToBatchND; using reference_ops::StridedSlice; @@ -164,6 +168,18 @@ ArrayMap MapAsArrayWithFirstDimAsRows(Scalar* data, return ArrayMap(data, rows, cols); } +// Copied from tensorflow/core/framework/tensor_types.h +template +struct TTypes { + // Rank-1 tensor (vector) of scalar type T. + typedef Eigen::TensorMap, + Eigen::Aligned> + Flat; + typedef Eigen::TensorMap< + Eigen::Tensor> + UnalignedConstMatrix; +}; + // TODO(b/62193649): this function is only needed as long // as we have the --variable_batch hack. template @@ -215,98 +231,6 @@ SaturatingRoundingMultiplyByPOTParam( SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent)); } -// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE -// BROADCASTING. -// -// NdArrayDesc describes the shape and memory layout of an N-dimensional -// rectangular array of numbers. -// -// NdArrayDesc is basically identical to Dims defined in types.h. -// However, as Dims is to be deprecated, this class exists as an adaptor -// to enable simple unoptimized implementations of element-wise broadcasting -// operations. -template -struct NdArrayDesc { - // The "extent" of each dimension. Indices along dimension d must be in the - // half-open interval [0, extents[d]). - int extents[N]; - - // The number of *elements* (not bytes) between consecutive indices of each - // dimension. - int strides[N]; -}; - -// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING -// ELEMENT-WISE BROADCASTING. -// -// Same as Offset(), except takes as NdArrayDesc instead of Dims. -inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2, - int i3) { - TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]); - TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]); - TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]); - TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]); - return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] + - i3 * desc.strides[3]; -} - -// Given the dimensions of the operands for an element-wise binary broadcast, -// adjusts them so that they can be directly iterated over with simple loops. -// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and -// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr. -// -// This function assumes that the two input shapes are compatible up to -// broadcasting and the shorter one has already been prepended with 1s to be the -// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64), -// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that -// Dims refer to shapes in reverse order. In this case, input0_dims will be -// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1). -// -// When two shapes are compatible up to broadcasting, for each dimension d, -// the input extents are either equal, or one of them is 1. -// -// This function performs the following for each dimension d: -// - If the extents are equal, then do nothing since the loop that walks over -// both of the input arrays is correct. -// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1 -// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows -// array0 to be referenced *at any index* in dimension d and still access the -// same slice. -template -inline void NdArrayDescsForElementwiseBroadcast(const Dims& input0_dims, - const Dims& input1_dims, - NdArrayDesc* desc0_out, - NdArrayDesc* desc1_out) { - TFLITE_DCHECK(desc0_out != nullptr); - TFLITE_DCHECK(desc1_out != nullptr); - - // Copy dims to desc. - for (int i = 0; i < N; ++i) { - desc0_out->extents[i] = input0_dims.sizes[i]; - desc0_out->strides[i] = input0_dims.strides[i]; - desc1_out->extents[i] = input1_dims.sizes[i]; - desc1_out->strides[i] = input1_dims.strides[i]; - } - - // Walk over each dimension. If the extents are equal do nothing. - // Otherwise, set the desc with extent 1 to have extent equal to the other and - // stride 0. - for (int i = 0; i < N; ++i) { - const int extent0 = ArraySize(input0_dims, i); - const int extent1 = ArraySize(input1_dims, i); - if (extent0 != extent1) { - if (extent0 == 1) { - desc0_out->strides[i] = 0; - desc0_out->extents[i] = extent1; - } else { - TFLITE_DCHECK_EQ(extent1, 1); - desc1_out->strides[i] = 0; - desc1_out->extents[i] = extent0; - } - } - } -} - inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) { for (int i = 0; i < 4; i++) { if (dims1.sizes[i] != dims2.sizes[i]) { @@ -1106,10 +1030,10 @@ inline void FullyConnectedAsGEMV( struct GemmlowpOutputPipeline { typedef gemmlowp::VectorMap ColVectorMap; - typedef std::tuple< - gemmlowp::OutputStageBiasAddition, - gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint, - gemmlowp::OutputStageClamp, gemmlowp::OutputStageSaturatingCastToUint8> + typedef std::tuple, + gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent, + gemmlowp::OutputStageClamp, + gemmlowp::OutputStageSaturatingCastToUint8> Pipeline; static Pipeline MakeExp(const int32* bias_data, int output_rows, int32 output_offset, int32 output_multiplier, @@ -1118,11 +1042,10 @@ struct GemmlowpOutputPipeline { ColVectorMap bias_vector(bias_data, output_rows); gemmlowp::OutputStageBiasAddition bias_addition_stage; bias_addition_stage.bias_vector = bias_vector; - gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint - quantize_down_stage; + gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage; quantize_down_stage.result_offset_after_shift = output_offset; quantize_down_stage.result_fixedpoint_multiplier = output_multiplier; - quantize_down_stage.result_shift = -output_left_shift; + quantize_down_stage.result_exponent = output_left_shift; gemmlowp::OutputStageClamp clamp_stage; clamp_stage.min = output_activation_min; clamp_stage.max = output_activation_max; @@ -2403,7 +2326,8 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input, ++*output_shift; } TFLITE_DCHECK_GT(input, 0); - const unsigned max_left_shift_bits = __builtin_clz(input) - 1; + const unsigned max_left_shift_bits = + CountLeadingZeros(static_cast(input)) - 1; const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2; const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1; *output_shift -= left_shift_bit_pairs; @@ -2476,20 +2400,17 @@ inline void L2Normalization(const uint8* input_data, } } -inline void Add(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { +inline void Add(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const float* input1_data, + const RuntimeShape& input2_shape, const float* input2_data, + const RuntimeShape& output_shape, float* output_data) { gemmlowp::ScopedProfilingLabel label("Add"); - TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); int i = 0; - const int size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape); #ifdef USE_NEON - const auto activation_min = vdupq_n_f32(output_activation_min); - const auto activation_max = vdupq_n_f32(output_activation_max); + const auto activation_min = vdupq_n_f32(params.float_activation_min); + const auto activation_max = vdupq_n_f32(params.float_activation_max); for (; i <= size - 16; i += 16) { auto a10 = vld1q_f32(input1_data + i); auto a11 = vld1q_f32(input1_data + i + 4); @@ -2528,29 +2449,26 @@ inline void Add(const float* input1_data, const Dims<4>& input1_dims, for (; i < size; i++) { auto x = input1_data[i] + input2_data[i]; - output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min, - output_activation_max); + output_data[i] = ActivationFunctionWithMinMax( + x, params.float_activation_min, params.float_activation_max); } } // Element-wise add that can often be used for inner loop of broadcast add as // well as the non-broadcast add. -inline void AddElementwise(int size, int left_shift, const uint8* input1_data, - int32 input1_offset, int32 input1_multiplier, - int input1_shift, const uint8* input2_data, - int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data) { +inline void AddElementwise(int size, const ArithmeticParams& params, + const uint8* input1_data, const uint8* input2_data, + uint8* output_data) { int i = 0; - TFLITE_DCHECK_GT(input1_offset, -256); - TFLITE_DCHECK_GT(input2_offset, -256); - TFLITE_DCHECK_LT(input1_offset, 256); - TFLITE_DCHECK_LT(input2_offset, 256); + TFLITE_DCHECK_GT(params.input1_offset, -256); + TFLITE_DCHECK_GT(params.input2_offset, -256); + TFLITE_DCHECK_LT(params.input1_offset, 256); + TFLITE_DCHECK_LT(params.input2_offset, 256); #ifdef USE_NEON - const auto output_activation_min_vector = vdup_n_u8(output_activation_min); - const auto output_activation_max_vector = vdup_n_u8(output_activation_max); + const auto output_activation_min_vector = + vdup_n_u8(params.quantized_activation_min); + const auto output_activation_max_vector = + vdup_n_u8(params.quantized_activation_max); for (; i <= size - 8; i += 8) { const auto input1_val_original = vld1_u8(input1_data + i); const auto input2_val_original = vld1_u8(input2_data + i); @@ -2559,9 +2477,9 @@ inline void AddElementwise(int size, int left_shift, const uint8* input1_data, const auto input2_val_s16 = vreinterpretq_s16_u16(vmovl_u8(input2_val_original)); const auto input1_val = - vaddq_s16(input1_val_s16, vdupq_n_s16(input1_offset)); + vaddq_s16(input1_val_s16, vdupq_n_s16(params.input1_offset)); const auto input2_val = - vaddq_s16(input2_val_s16, vdupq_n_s16(input2_offset)); + vaddq_s16(input2_val_s16, vdupq_n_s16(params.input2_offset)); const auto input1_val_high = vget_high_s16(input1_val); const auto input1_val_low = vget_low_s16(input1_val); const auto input2_val_high = vget_high_s16(input2_val); @@ -2570,32 +2488,32 @@ inline void AddElementwise(int size, int left_shift, const uint8* input1_data, auto x12 = vmovl_s16(input1_val_high); auto x21 = vmovl_s16(input2_val_low); auto x22 = vmovl_s16(input2_val_high); - const auto left_shift_dup = vdupq_n_s32(left_shift); + const auto left_shift_dup = vdupq_n_s32(params.left_shift); x11 = vshlq_s32(x11, left_shift_dup); x12 = vshlq_s32(x12, left_shift_dup); x21 = vshlq_s32(x21, left_shift_dup); x22 = vshlq_s32(x22, left_shift_dup); - x11 = vqrdmulhq_n_s32(x11, input1_multiplier); - x12 = vqrdmulhq_n_s32(x12, input1_multiplier); - x21 = vqrdmulhq_n_s32(x21, input2_multiplier); - x22 = vqrdmulhq_n_s32(x22, input2_multiplier); - const auto input1_shift_dup = vdupq_n_s32(-input1_shift); - const auto input2_shift_dup = vdupq_n_s32(-input2_shift); + x11 = vqrdmulhq_n_s32(x11, params.input1_multiplier); + x12 = vqrdmulhq_n_s32(x12, params.input1_multiplier); + x21 = vqrdmulhq_n_s32(x21, params.input2_multiplier); + x22 = vqrdmulhq_n_s32(x22, params.input2_multiplier); + const auto input1_shift_dup = vdupq_n_s32(params.input1_shift); + const auto input2_shift_dup = vdupq_n_s32(params.input2_shift); x11 = vshlq_s32(x11, input1_shift_dup); x12 = vshlq_s32(x12, input1_shift_dup); x21 = vshlq_s32(x21, input2_shift_dup); x22 = vshlq_s32(x22, input2_shift_dup); auto s1 = vaddq_s32(x11, x21); auto s2 = vaddq_s32(x12, x22); - s1 = vqrdmulhq_n_s32(s1, output_multiplier); - s2 = vqrdmulhq_n_s32(s2, output_multiplier); + s1 = vqrdmulhq_n_s32(s1, params.output_multiplier); + s2 = vqrdmulhq_n_s32(s2, params.output_multiplier); using gemmlowp::RoundingDivideByPOT; - s1 = RoundingDivideByPOT(s1, output_shift); - s2 = RoundingDivideByPOT(s2, output_shift); + s1 = RoundingDivideByPOT(s1, -params.output_shift); + s2 = RoundingDivideByPOT(s2, -params.output_shift); const auto s1_narrowed = vmovn_s32(s1); const auto s2_narrowed = vmovn_s32(s2); const auto s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed), - vdupq_n_s16(output_offset)); + vdupq_n_s16(params.output_offset)); const auto clamped = vmax_u8(output_activation_min_vector, vmin_u8(output_activation_max_vector, vqmovun_s16(s))); @@ -2604,101 +2522,74 @@ inline void AddElementwise(int size, int left_shift, const uint8* input1_data, #endif // NEON for (; i < size; ++i) { - const int32 input1_val = input1_offset + input1_data[i]; - const int32 input2_val = input2_offset + input2_data[i]; - const int32 shifted_input1_val = input1_val * (1 << left_shift); - const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 input1_val = params.input1_offset + input1_data[i]; + const int32 input2_val = params.input2_offset + input2_data[i]; + const int32 shifted_input1_val = input1_val * (1 << params.left_shift); + const int32 shifted_input2_val = input2_val * (1 << params.left_shift); const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, input1_multiplier, - kReverseShift * input1_shift); + shifted_input1_val, params.input1_multiplier, params.input1_shift); const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, input2_multiplier, - kReverseShift * input2_shift); + shifted_input2_val, params.input2_multiplier, params.input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; const int32 raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp( - raw_sum, output_multiplier, kReverseShift * output_shift) + - output_offset; - const int32 clamped_output = std::min( - output_activation_max, std::max(output_activation_min, raw_output)); + raw_sum, params.output_multiplier, params.output_shift) + + params.output_offset; + const int32 clamped_output = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, raw_output)); output_data[i] = static_cast(clamped_output); } } -// legacy, for compatibility with old checked-in code -template -void Add(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float* output_data, const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - - Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, - output_activation_max, output_data, output_dims); -} - -template -inline void Add(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, int input2_shift, - int32 output_offset, int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } +inline void Add(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const uint8* input1_data, + const RuntimeShape& input2_shape, const uint8* input2_data, + const RuntimeShape& output_shape, uint8* output_data) { + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); gemmlowp::ScopedProfilingLabel label("Add/8bit"); - const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); - TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); - - TFLITE_DCHECK_GT(input1_offset, -256); - TFLITE_DCHECK_GT(input2_offset, -256); - TFLITE_DCHECK_LT(input1_offset, 256); - TFLITE_DCHECK_LT(input2_offset, 256); - AddElementwise(flat_size, left_shift, input1_data, input1_offset, - input1_multiplier, input1_shift, input2_data, input2_offset, - input2_multiplier, input2_shift, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_data); + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); + + TFLITE_DCHECK_GT(params.input1_offset, -256); + TFLITE_DCHECK_GT(params.input2_offset, -256); + TFLITE_DCHECK_LT(params.input1_offset, 256); + TFLITE_DCHECK_LT(params.input2_offset, 256); + AddElementwise(flat_size, params, input1_data, input2_data, output_data); } -inline void Add(const int16* input1_data, const Dims<4>& input1_dims, - int input1_shift, const int16* input2_data, - const Dims<4>& input2_dims, int input2_shift, - int16 output_activation_min, int16 output_activation_max, - int16* output_data, const Dims<4>& output_dims) { +inline void Add(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const int16* input1_data, + const RuntimeShape& input2_shape, const int16* input2_data, + const RuntimeShape& output_shape, int16* output_data) { gemmlowp::ScopedProfilingLabel label("Add/Int16"); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - - const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); - - TFLITE_DCHECK(input1_shift == 0 || input2_shift == 0); - TFLITE_DCHECK_GE(input1_shift, 0); - TFLITE_DCHECK_GE(input2_shift, 0); + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); + + const int input1_shift = params.input1_shift; + const int flat_size = + MatchingFlatSize(output_shape, input1_shape, input2_shape); + const int16 output_activation_min = params.quantized_activation_min; + const int16 output_activation_max = params.quantized_activation_max; + + TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0); + TFLITE_DCHECK_LE(input1_shift, 0); + TFLITE_DCHECK_LE(params.input2_shift, 0); const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data; const int16* shift_input = input1_shift == 0 ? input2_data : input1_data; - const int input_shift = input1_shift == 0 ? input2_shift : input1_shift; + const int input_right_shift = + input1_shift == 0 ? -params.input2_shift : -input1_shift; for (int i = 0; i < flat_size; i++) { // F0 uses 0 integer bits, range [-1, 1]. using F0 = gemmlowp::FixedPoint; F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]); - F0 scaled_input = - F0::FromRaw(gemmlowp::RoundingDivideByPOT(shift_input[i], input_shift)); + F0 scaled_input = F0::FromRaw( + gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift)); F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled); const int16 raw_output = result.raw(); const int16 clamped_output = std::min( @@ -2707,195 +2598,59 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims, } } -inline void Add(const int32* input1_data, const Dims<4>& input1_dims, - const int32* input2_data, const Dims<4>& input2_dims, - int32 output_activation_min, int32 output_activation_max, - int32* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Add/int32"); - - const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); - for (int i = 0; i < flat_size; ++i) { - output_data[i] = ActivationFunctionWithMinMax( - input1_data[i] + input2_data[i], output_activation_min, - output_activation_max); - } -} - -template -inline void Add(const int16* input1_data, const Dims<4>& input1_dims, - int input1_shift, const int16* input2_data, - const Dims<4>& input2_dims, int input2_shift, - int16 output_activation_min, int16 output_activation_max, - int16* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, -32768); - TFLITE_DCHECK_EQ(output_activation_max, 32767); - } - - Add(input1_data, input1_dims, input1_shift, input2_data, input2_dims, - input2_shift, output_activation_min, output_activation_max, output_data, - output_dims); -} - -template -void Add(const int32* input1_data, const Dims<4>& input1_dims, - const int32* input2_data, const Dims<4>& input2_dims, - int32* output_data, const Dims<4>& output_dims) { +inline void Add(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const int32* input1_data, + const RuntimeShape& input2_shape, const int32* input2_data, + const RuntimeShape& output_shape, int32* output_data) { gemmlowp::ScopedProfilingLabel label("Add/int32"); - TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); - auto input1_map = MapAsVector(input1_data, input1_dims); - auto input2_map = MapAsVector(input2_data, input2_dims); - auto output_map = MapAsVector(output_data, output_dims); - if (AreSameDims(input1_dims, input2_dims)) { + auto input1_map = MapAsVector(input1_data, input1_shape); + auto input2_map = MapAsVector(input2_data, input2_shape); + auto output_map = MapAsVector(output_data, output_shape); + if (input1_shape == input2_shape) { output_map.array() = input1_map.array() + input2_map.array(); - } else if (FlatSize(input2_dims) == 1) { + } else if (input2_shape.FlatSize() == 1) { auto scalar = input2_data[0]; output_map.array() = input1_map.array() + scalar; - } else if (FlatSize(input1_dims) == 1) { + } else if (input1_shape.FlatSize() == 1) { auto scalar = input1_data[0]; output_map.array() = scalar + input2_map.array(); } else { // Should not come here. TFLITE_DCHECK(false); } + output_map = output_map.cwiseMax(params.quantized_activation_min); + output_map = output_map.cwiseMin(params.quantized_activation_max); } -// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary -// dimensionality if the runtime code does a single loop over one dimension -// that handles broadcasting as the base case. The code generator would then -// generate max(D1, D2) nested for loops. -// TODO(benoitjacob): BroadcastAdd is intentionally duplicated from -// reference_ops.h. Once an optimized version is implemented and NdArrayDesc -// is no longer referenced in this file, move NdArrayDesc from types.h to -// reference_ops.h. -template -void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - T output_activation_min, T output_activation_max, - T* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastAdd"); - - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); - - // In Tensorflow, the dimensions are canonically named (batch_number, row, - // col, channel), with extents (batches, height, width, depth), with the - // trailing dimension changing most rapidly (channels has the smallest stride, - // typically 1 element). - // - // In generated C code, we store arrays with the dimensions reversed. The - // first dimension has smallest stride. - // - // We name our variables by their Tensorflow convention, but generate C code - // nesting loops such that the innermost loop has the smallest stride for the - // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = - ActivationFunctionWithMinMax( - input1_data[SubscriptToIndex(desc1, c, x, y, b)] + - input2_data[SubscriptToIndex(desc2, c, x, y, b)], - output_activation_min, output_activation_max); - } - } - } - } -} - -// legacy, for compatibility with old checked-in code -template -void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - T* output_data, const Dims<4>& output_dims) { - T output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - - BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims, - output_activation_min, output_activation_max, output_data, - output_dims); -} - -inline void BroadcastAdd(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastAddGeneric/8bit"); - - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); - - // In Tensorflow, the dimensions are canonically named (batch_number, row, - // col, channel), with extents (batches, height, width, depth), with the - // trailing dimension changing most rapidly (channels has the smallest stride, - // typically 1 element). - // - // In generated C code, we store arrays with the dimensions reversed. The - // first dimension has smallest stride. - // - // We name our variables by their Tensorflow convention, but generate C code - // nesting loops such that the innermost loop has the smallest stride for the - // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - const int32 input1_val = - input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; - const int32 input2_val = - input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; - const int32 shifted_input1_val = input1_val * (1 << left_shift); - const int32 shifted_input2_val = input2_val * (1 << left_shift); - const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, input1_multiplier, - kReverseShift * input1_shift); - const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, input2_multiplier, - kReverseShift * input2_shift); - const int32 raw_sum = scaled_input1_val + scaled_input2_val; - const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - raw_sum, output_multiplier, kReverseShift * output_shift) + - output_offset; - const int32 clamped_output = - std::min(output_activation_max, - std::max(output_activation_min, raw_output)); - output_data[Offset(output_dims, c, x, y, b)] = - static_cast(clamped_output); - } - } - } - } -} - -inline void BroadcastAddFivefold( - int y0, int y1, int y2, int y3, int y4, int left_shift, - const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { +inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params, + const RuntimeShape& unswitched_input1_shape, + const uint8* unswitched_input1_data, + const RuntimeShape& unswitched_input2_shape, + const uint8* unswitched_input2_data, + const RuntimeShape& output_shape, + uint8* output_data) { gemmlowp::ScopedProfilingLabel label("BroadcastAddFivefold/8bit"); + ArithmeticParams switched_params = unswitched_params; + switched_params.input1_offset = unswitched_params.input2_offset; + switched_params.input1_multiplier = unswitched_params.input2_multiplier; + switched_params.input1_shift = unswitched_params.input2_shift; + switched_params.input2_offset = unswitched_params.input1_offset; + switched_params.input2_multiplier = unswitched_params.input1_multiplier; + switched_params.input2_shift = unswitched_params.input1_shift; + + const bool use_unswitched = + unswitched_params.broadcast_category == + tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast; + + const ArithmeticParams& params = + use_unswitched ? unswitched_params : switched_params; + const uint8* input1_data = + use_unswitched ? unswitched_input1_data : unswitched_input2_data; + const uint8* input2_data = + use_unswitched ? unswitched_input2_data : unswitched_input1_data; + // Fivefold nested loops. The second input resets its position for each // iteration of the second loop. The first input resets its position at the // beginning of the fourth loop. The innermost loop is an elementwise add of @@ -2903,82 +2658,29 @@ inline void BroadcastAddFivefold( uint8* output_data_ptr = output_data; const uint8* input1_data_ptr = input1_data; const uint8* input2_data_reset = input2_data; - for (int i4 = 0; i4 < y4; ++i4) { + int y0 = params.broadcast_shape[0]; + int y1 = params.broadcast_shape[1]; + int y2 = params.broadcast_shape[2]; + int y3 = params.broadcast_shape[3]; + int y4 = params.broadcast_shape[4]; + for (int i0 = 0; i0 < y0; ++i0) { const uint8* input2_data_ptr; - for (int i3 = 0; i3 < y3; ++i3) { + for (int i1 = 0; i1 < y1; ++i1) { input2_data_ptr = input2_data_reset; for (int i2 = 0; i2 < y2; ++i2) { - for (int i1 = 0; i1 < y1; ++i1) { - AddElementwise( - y0, left_shift, input1_data_ptr, input1_offset, input1_multiplier, - input1_shift, input2_data_ptr, input2_offset, input2_multiplier, - input2_shift, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_data_ptr); - input2_data_ptr += y0; - output_data_ptr += y0; + for (int i3 = 0; i3 < y3; ++i3) { + AddElementwise(y4, params, input1_data_ptr, input2_data_ptr, + output_data_ptr); + input2_data_ptr += y4; + output_data_ptr += y4; } - input1_data_ptr += y0; + input1_data_ptr += y4; } } input2_data_reset = input2_data_ptr; } } -template -inline void BroadcastAdd(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset, - input1_multiplier, input1_shift, input2_data, input2_dims, - input2_offset, input2_multiplier, input2_shift, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_data, output_dims); -} - -template -inline void BroadcastAddFivefold( - int y0, int y1, int y2, int y3, int y4, int left_shift, - const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - BroadcastAddFivefold(y0, y1, y2, y3, y4, left_shift, input1_data, input1_dims, - input1_offset, input1_multiplier, input1_shift, - input2_data, input2_dims, input2_offset, - input2_multiplier, input2_shift, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_data, output_dims); -} - inline void Mul(const float* input1_data, const Dims<4>& input1_dims, const float* input2_data, const Dims<4>& input2_dims, float output_activation_min, float output_activation_max, @@ -3052,6 +2754,20 @@ void Mul(const float* input1_data, const Dims<4>& input1_dims, output_activation_max, output_data, output_dims); } +inline void Mul(const int32* input1_data, const Dims<4>& input1_dims, + const int32* input2_data, const Dims<4>& input2_dims, + int32 output_activation_min, int32 output_activation_max, + int32* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Mul/int32"); + + const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax( + input1_data[i] * input2_data[i], output_activation_min, + output_activation_max); + } +} + template void Mul(const int32* input1_data, const Dims<4>& input1_dims, const int32* input2_data, const Dims<4>& input2_dims, @@ -3289,122 +3005,78 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims, } // TODO(aselle): This is not actually optimized yet. -inline void Sub(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Sub"); - const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); +inline void SubNonBroadcast(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const float* input1_data, + const RuntimeShape& input2_shape, + const float* input2_data, + const RuntimeShape& output_shape, + float* output_data) { + gemmlowp::ScopedProfilingLabel label("SubNonBroadcast"); + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( - input1_data[i] - input2_data[i], output_activation_min, - output_activation_max); + input1_data[i] - input2_data[i], params.float_activation_min, + params.float_activation_max); } } -// TODO(jiawen): We can implement BroadcastSub on buffers of arbitrary -// dimensionality if the runtime code does a single loop over one dimension -// that handles broadcasting as the base case. The code generator would then -// generate max(D1, D2) nested for loops. -// TODO(benoitjacob): BroadcastSub is intentionally duplicated from -// reference_ops.h. Once an optimized version is implemented and NdArrayDesc -// is no longer referenced in this file, move NdArrayDesc from types.h to -// reference_ops.h. -template -void BroadcastSub(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - T output_activation_min, T output_activation_max, - T* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastSub"); - - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); - - // In Tensorflow, the dimensions are canonically named (batch_number, row, - // col, channel), with extents (batches, height, width, depth), with the - // trailing dimension changing most rapidly (channels has the smallest stride, - // typically 1 element). - // - // In generated C code, we store arrays with the dimensions reversed. The - // first dimension has smallest stride. - // - // We name our variables by their Tensorflow convention, but generate C code - // nesting loops such that the innermost loop has the smallest stride for the - // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = - ActivationFunctionWithMinMax( - input1_data[SubscriptToIndex(desc1, c, x, y, b)] - - input2_data[SubscriptToIndex(desc2, c, x, y, b)], - output_activation_min, output_activation_max); - } - } - } +inline void SubWithActivation(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const int32* input1_data, + const RuntimeShape& input2_shape, + const int32* input2_data, + const RuntimeShape& output_shape, + int32* output_data) { + gemmlowp::ScopedProfilingLabel label("SubWithActivation/int32"); + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, input2_shape); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax( + input1_data[i] - input2_data[i], params.quantized_activation_min, + params.quantized_activation_max); } } -inline void BroadcastSub(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastSub/8bit"); +inline void SubWithActivation(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const float* input1_data, + const RuntimeShape& input2_shape, + const float* input2_data, + const RuntimeShape& output_shape, + float* output_data) { + gemmlowp::ScopedProfilingLabel label("SubWithActivation/float"); + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, input2_shape); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax( + input1_data[i] - input2_data[i], params.float_activation_min, + params.float_activation_max); + } +} - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); +template +void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape, + const T* input1_data, const RuntimeShape& input2_shape, + const T* input2_data, const RuntimeShape& output_shape, + T* output_data) { + gemmlowp::ScopedProfilingLabel label("Sub"); - // In Tensorflow, the dimensions are canonically named (batch_number, row, - // col, channel), with extents (batches, height, width, depth), with the - // trailing dimension changing most rapidly (channels has the smallest stride, - // typically 1 element). - // - // In generated C code, we store arrays with the dimensions reversed. The - // first dimension has smallest stride. - // - // We name our variables by their Tensorflow convention, but generate C code - // nesting loops such that the innermost loop has the smallest stride for the - // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - const int32 input1_val = - input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; - const int32 input2_val = - input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; - const int32 shifted_input1_val = input1_val * (1 << left_shift); - const int32 shifted_input2_val = input2_val * (1 << left_shift); - const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, input1_multiplier, - kReverseShift * input1_shift); - const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, input2_multiplier, - kReverseShift * input2_shift); - const int32 raw_sub = scaled_input1_val - scaled_input2_val; - const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - raw_sub, output_multiplier, kReverseShift * output_shift) + - output_offset; - const int32 clamped_output = - std::min(output_activation_max, - std::max(output_activation_min, raw_output)); - output_data[Offset(output_dims, c, x, y, b)] = - static_cast(clamped_output); - } - } - } + auto input1_map = MapAsVector(input1_data, input1_shape); + auto input2_map = MapAsVector(input2_data, input2_shape); + auto output_map = MapAsVector(output_data, output_shape); + if (input1_shape == input2_shape) { + output_map.array() = input1_map.array() - input2_map.array(); + } else if (input1_shape.FlatSize() == 1) { + auto scalar = input1_data[0]; + output_map.array() = scalar - input2_map.array(); + } else if (input2_shape.FlatSize() == 1) { + auto scalar = input2_data[0]; + output_map.array() = input1_map.array() - scalar; + } else { + BroadcastSub4DSlow(params, input1_shape, input1_data, input2_shape, + input2_data, output_shape, output_data); } } @@ -3770,21 +3442,20 @@ inline int NodeOffset(int b, int h, int w, int height, int width) { return (b * height + h) * width + w; } -inline void AveragePool(const float* input_data, - const RuntimeShape& input_shape, int stride_width, - int stride_height, int pad_width, int pad_height, - int kwidth, int kheight, float output_activation_min, - float output_activation_max, float* output_data, - const RuntimeShape& output_shape) { +inline void AveragePool(const PoolParams& params, + const RuntimeShape& input_shape, + const float* input_data, + const RuntimeShape& output_shape, float* output_data) { gemmlowp::ScopedProfilingLabel label("AveragePool"); TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); const int batches = MatchingDim(input_shape, 0, output_shape, 0); - const int depth = MatchingDim(input_shape, 3, output_shape, 3); const int input_height = input_shape.Dims(1); const int input_width = input_shape.Dims(2); const int output_height = output_shape.Dims(1); const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; // TODO(benoitjacob) make this a proper reference impl without Eigen! const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape); @@ -3799,12 +3470,15 @@ inline void AveragePool(const float* input_data, for (int w = 0; w < input_width; ++w) { // (h_start, h_end) * (w_start, w_end) is the range that the input // vector projects to. - int hpad = h + pad_height; - int wpad = w + pad_width; - int h_start = - (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1; + int hpad = h + params.padding_values.height; + int wpad = w + params.padding_values.width; + int h_start = (hpad < params.filter_height) + ? 0 + : (hpad - params.filter_height) / stride_height + 1; int h_end = std::min(hpad / stride_height + 1, output_height); - int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1; + int w_start = (wpad < params.filter_width) + ? 0 + : (wpad - params.filter_width) / stride_width + 1; int w_end = std::min(wpad / stride_width + 1, output_width); // compute elementwise sum for (int ph = h_start; ph < h_end; ++ph) { @@ -3822,29 +3496,21 @@ inline void AveragePool(const float* input_data, TFLITE_DCHECK_GT(out_count.minCoeff(), 0); out_mat.array().rowwise() /= out_count.transpose().array(); - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < output_height; ++y) { - for (int x = 0; x < output_width; ++x) { - for (int c = 0; c < depth; ++c) { - output_data[Offset(output_shape, b, y, x, c)] = - ActivationFunctionWithMinMax( - output_data[Offset(output_shape, b, y, x, c)], - output_activation_min, output_activation_max); - } - } - } + const int flat_size = output_shape.FlatSize(); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax(output_data[i], + params.float_activation_min, + params.float_activation_max); } } -inline void AveragePool(const uint8* input_data, - const RuntimeShape& input_shape, int stride_width, - int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const RuntimeShape& output_shape) { +inline void AveragePool(const PoolParams& params, + const RuntimeShape& input_shape, + const uint8* input_data, + const RuntimeShape& output_shape, uint8* output_data) { gemmlowp::ScopedProfilingLabel label("AveragePool/8bit"); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); const int batches = MatchingDim(input_shape, 0, output_shape, 0); @@ -3853,17 +3519,21 @@ inline void AveragePool(const uint8* input_data, const int input_width = input_shape.Dims(2); const int output_height = output_shape.Dims(1); const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; + const int in_x_origin = + (out_x * stride_width) - params.padding_values.width; + const int in_y_origin = + (out_y * stride_height) - params.padding_values.height; const int filter_x_start = std::max(0, -in_x_origin); const int filter_x_end = - std::min(filter_width, input_width - in_x_origin); + std::min(params.filter_width, input_width - in_x_origin); const int filter_y_start = std::max(0, -in_y_origin); const int filter_y_end = - std::min(filter_height, input_height - in_y_origin); + std::min(params.filter_height, input_height - in_y_origin); const int filter_count = (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start); // 1280 required by Inception v3 @@ -3911,18 +3581,18 @@ inline void AveragePool(const uint8* input_data, output_data + Offset(output_shape, batch, out_y, out_x, 0); int channel = 0; #ifdef USE_NEON -#define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \ - if (filter_count == FILTER_COUNT) { \ - for (; channel <= depth - 8; channel += 8) { \ - uint16 buf[8]; \ - for (int i = 0; i < 8; i++) { \ - buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \ - } \ - uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); \ - buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max)); \ - buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min)); \ - vst1_u8(output_ptr + channel, buf8); \ - } \ +#define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \ + if (filter_count == FILTER_COUNT) { \ + for (; channel <= depth - 8; channel += 8) { \ + uint16 buf[8]; \ + for (int i = 0; i < 8; i++) { \ + buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \ + } \ + uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); \ + buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max)); \ + buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min)); \ + vst1_u8(output_ptr + channel, buf8); \ + } \ } AVGPOOL_DIVIDING_BY(9) AVGPOOL_DIVIDING_BY(15) @@ -3933,15 +3603,15 @@ inline void AveragePool(const uint8* input_data, buf[i] = (acc[channel + i] + filter_count / 2) / filter_count; } uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); - buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max)); - buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min)); + buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max)); + buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min)); vst1_u8(output_ptr + channel, buf8); } #endif for (; channel < depth; ++channel) { uint16 a = (acc[channel] + filter_count / 2) / filter_count; - a = std::max(a, output_activation_min); - a = std::min(a, output_activation_max); + a = std::max(a, params.quantized_activation_min); + a = std::min(a, params.quantized_activation_max); output_ptr[channel] = static_cast(a); } } @@ -3949,20 +3619,19 @@ inline void AveragePool(const uint8* input_data, } } -inline void MaxPool(const float* input_data, const RuntimeShape& input_shape, - int stride_width, int stride_height, int pad_width, - int pad_height, int kwidth, int kheight, - float output_activation_min, float output_activation_max, - float* output_data, const RuntimeShape& output_shape) { +inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& output_shape, + float* output_data) { gemmlowp::ScopedProfilingLabel label("MaxPool"); TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); const int batches = MatchingDim(input_shape, 0, output_shape, 0); - const int depth = MatchingDim(input_shape, 3, output_shape, 3); const int input_height = input_shape.Dims(1); const int input_width = input_shape.Dims(2); const int output_height = output_shape.Dims(1); const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape); auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape); @@ -3973,12 +3642,15 @@ inline void MaxPool(const float* input_data, const RuntimeShape& input_shape, for (int w = 0; w < input_width; ++w) { // (h_start, h_end) * (w_start, w_end) is the range that the input // vector projects to. - int hpad = h + pad_height; - int wpad = w + pad_width; - int h_start = - (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1; + int hpad = h + params.padding_values.height; + int wpad = w + params.padding_values.width; + int h_start = (hpad < params.filter_height) + ? 0 + : (hpad - params.filter_height) / stride_height + 1; int h_end = std::min(hpad / stride_height + 1, output_height); - int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1; + int w_start = (wpad < params.filter_width) + ? 0 + : (wpad - params.filter_width) / stride_width + 1; int w_end = std::min(wpad / stride_width + 1, output_width); // compute elementwise sum for (int ph = h_start; ph < h_end; ++ph) { @@ -3993,28 +3665,20 @@ inline void MaxPool(const float* input_data, const RuntimeShape& input_shape, } } } - - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < output_height; ++y) { - for (int x = 0; x < output_width; ++x) { - for (int c = 0; c < depth; ++c) { - output_data[Offset(output_shape, b, y, x, c)] = - ActivationFunctionWithMinMax( - output_data[Offset(output_shape, b, y, x, c)], - output_activation_min, output_activation_max); - } - } - } + const int flat_size = output_shape.FlatSize(); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax(output_data[i], + params.float_activation_min, + params.float_activation_max); } } -inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const RuntimeShape& output_shape) { +inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, + const uint8* input_data, const RuntimeShape& output_shape, + uint8* output_data) { gemmlowp::ScopedProfilingLabel label("MaxPool/8bit"); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); const int batches = MatchingDim(input_shape, 0, output_shape, 0); @@ -4023,17 +3687,21 @@ inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape, const int input_width = input_shape.Dims(2); const int output_height = output_shape.Dims(1); const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; + const int in_x_origin = + (out_x * stride_width) - params.padding_values.width; + const int in_y_origin = + (out_y * stride_height) - params.padding_values.height; const int filter_x_start = std::max(0, -in_x_origin); const int filter_x_end = - std::min(filter_width, input_width - in_x_origin); + std::min(params.filter_width, input_width - in_x_origin); const int filter_y_start = std::max(0, -in_y_origin); const int filter_y_end = - std::min(filter_height, input_height - in_y_origin); + std::min(params.filter_height, input_height - in_y_origin); // 2048 required by Inception v3 static constexpr int kAccBufferMaxSize = 2048; TFLITE_DCHECK_LE(depth, kAccBufferMaxSize); @@ -4076,21 +3744,21 @@ inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape, #ifdef USE_NEON for (; channel <= depth - 16; channel += 16) { uint8x16_t a = vld1q_u8(acc + channel); - a = vminq_u8(a, vdupq_n_u8(output_activation_max)); - a = vmaxq_u8(a, vdupq_n_u8(output_activation_min)); + a = vminq_u8(a, vdupq_n_u8(params.quantized_activation_max)); + a = vmaxq_u8(a, vdupq_n_u8(params.quantized_activation_min)); vst1q_u8(output_ptr + channel, a); } for (; channel <= depth - 8; channel += 8) { uint8x8_t a = vld1_u8(acc + channel); - a = vmin_u8(a, vdup_n_u8(output_activation_max)); - a = vmax_u8(a, vdup_n_u8(output_activation_min)); + a = vmin_u8(a, vdup_n_u8(params.quantized_activation_max)); + a = vmax_u8(a, vdup_n_u8(params.quantized_activation_min)); vst1_u8(output_ptr + channel, a); } #endif for (; channel < depth; ++channel) { uint8 a = acc[channel]; - a = std::max(a, output_activation_min); - a = std::min(a, output_activation_max); + a = std::max(a, params.quantized_activation_min); + a = std::min(a, params.quantized_activation_max); output_ptr[channel] = static_cast(a); } } @@ -4098,11 +3766,9 @@ inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape, } } -inline void L2Pool(const float* input_data, const RuntimeShape& input_shape, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - float output_activation_min, float output_activation_max, - float* output_data, const RuntimeShape& output_shape) { +inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& output_shape, + float* output_data) { gemmlowp::ScopedProfilingLabel label("L2Pool"); TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); @@ -4111,6 +3777,8 @@ inline void L2Pool(const float* input_data, const RuntimeShape& input_shape, const int input_width = input_shape.Dims(2); const int output_height = output_shape.Dims(1); const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; // Actually carry out L2 Pool. Code is written in forward mode: we go through // the input values once, and write to all the pooled regions that it maps to. const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape); @@ -4125,15 +3793,17 @@ inline void L2Pool(const float* input_data, const RuntimeShape& input_shape, for (int w = 0; w < input_width; ++w) { // (h_start, h_end) * (w_start, w_end) is the range that the input // vector projects to. - const int hpad = h + pad_height; - const int wpad = w + pad_width; - const int h_start = (hpad < filter_height) - ? 0 - : (hpad - filter_height) / stride_height + 1; + const int hpad = h + params.padding_values.height; + const int wpad = w + params.padding_values.width; + const int h_start = + (hpad < params.filter_height) + ? 0 + : (hpad - params.filter_height) / stride_height + 1; const int h_end = std::min(hpad / stride_height + 1, output_height); - const int w_start = (wpad < filter_width) - ? 0 - : (wpad - filter_width) / stride_width + 1; + const int w_start = + (wpad < params.filter_width) + ? 0 + : (wpad - params.filter_width) / stride_width + 1; const int w_end = std::min(wpad / stride_width + 1, output_width); // pre-compute square const int in_offset = w + input_width * (h + input_height * b); @@ -4154,6 +3824,13 @@ inline void L2Pool(const float* input_data, const RuntimeShape& input_shape, out_count = out_count.array().inverse(); out_mat = (out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt(); + + const int flat_size = output_shape.FlatSize(); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax(output_data[i], + params.float_activation_min, + params.float_activation_max); + } } inline void LocalResponseNormalization(const float* input_data, @@ -4358,7 +4035,7 @@ inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape, // perform a division by the above-computed sum-of-exponentials. int32 fixed_sum_of_exps = sum_of_exps.raw(); int headroom_plus_one = - __builtin_clz(static_cast(fixed_sum_of_exps)); + CountLeadingZeros(static_cast(fixed_sum_of_exps)); // This is the number of bits to the left of the binary point above 1.0. // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and // no later adjustment will be needed. @@ -4504,7 +4181,7 @@ log_x_for_x_greater_than_or_equal_to_1_impl( // required shift "ourselves" instead of using, say, Rescale. FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw()); // z_a_pow_2 = input_integer_bits - z_a_headroom; - int z_a_headroom_plus_1 = __builtin_clz(static_cast(z_a.raw())); + int z_a_headroom_plus_1 = CountLeadingZeros(static_cast(z_a.raw())); FixedPoint0 r_a_tmp = SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1)); const int32 r_a_raw = @@ -4519,7 +4196,7 @@ log_x_for_x_greater_than_or_equal_to_1_impl( // z_b is treated like z_a, but premultiplying by sqrt(0.5). FixedPoint0 z_b = z_a * sqrt_half; - int z_b_headroom = __builtin_clz(static_cast(z_b.raw())) - 1; + int z_b_headroom = CountLeadingZeros(static_cast(z_b.raw())) - 1; const int32 r_b_raw = SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom); const FixedPointAccum z_b_pow_2_adj = SaturatingSub( @@ -5841,63 +5518,6 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims, } } -template -void GenericBroadcastSub(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - T* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("GenericBroadcastSub"); - - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); - - // In Tensorflow, the dimensions are canonically named (batch_number, row, - // col, channel), with extents (batches, height, width, depth), with the - // trailing dimension changing most rapidly (channels has the smallest stride, - // typically 1 element). - // - // In generated C code, we store arrays with the dimensions reversed. The - // first dimension has smallest stride. - // - // We name our variables by their Tensorflow convention, but generate C code - // nesting loops such that the innermost loop has the smallest stride for the - // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = - input1_data[SubscriptToIndex(desc1, c, x, y, b)] - - input2_data[SubscriptToIndex(desc2, c, x, y, b)]; - } - } - } - } -} - -template -void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, - const Dims<4>& input2_dims, T* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Sub"); - - auto input1_map = MapAsVector(input1_data, input1_dims); - auto input2_map = MapAsVector(input2_data, input2_dims); - auto output_map = MapAsVector(output_data, output_dims); - if (AreSameDims(input1_dims, input2_dims)) { - output_map.array() = input1_map.array() - input2_map.array(); - } else if (FlatSize(input1_dims) == 1) { - auto scalar = input1_data[0]; - output_map.array() = scalar - input2_map.array(); - } else if (FlatSize(input2_dims) == 1) { - auto scalar = input2_data[0]; - output_map.array() = input1_map.array() - scalar; - } else { - GenericBroadcastSub(input1_data, input1_dims, input2_data, input2_dims, - output_data, output_dims); - } -} - template void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, T* output_data, diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h index f14667090f5c3867c7992211272063239f3b92aa..010b40b901e2821c36367da7e2c987fac113de11 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h @@ -19,6 +19,10 @@ limitations under the License. // structure. #include "tensorflow/contrib/lite/builtin_op_data.h" +#if defined(_MSC_VER) +#define __restrict__ __restrict +#endif + #ifndef USE_NEON #if defined(__ARM_NEON__) || defined(__ARM_NEON) #define USE_NEON @@ -124,6 +128,12 @@ void PortableCopyVector(const float* vector, int v_size, float* result); // Fill vector with 0.f. void PortableZeroVector(float* vector, int v_size); +// Multiply all elements of vector with a scalar. +void PortableVectorScalarMultiply(const int8_t* vector, int v_size, float scale, + float* result); +void NeonVectorScalarMultiply(const int8_t* vector, int v_size, float scale, + float* result); + // Limit a float input f between +abs_limit and -abs_limit. float PortableClip(float f, float abs_limit); diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc index e224980493aa11f642da103ee7d7377b6c4b1da0..f882f9910e0c65d69eb5a86886bae4d3c881e6ab 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc @@ -109,12 +109,12 @@ int CalculateInputRadius(int input_integer_bits, int input_left_shift) { void NudgeQuantizationRange(const float min, const float max, const int quant_min, const int quant_max, float* nudged_min, float* nudged_max, - float* scale) { + float* nudged_scale) { // This code originates from tensorflow/core/kernels/fake_quant_ops_functor.h. const float quant_min_float = static_cast(quant_min); const float quant_max_float = static_cast(quant_max); - *scale = (max - min) / (quant_max_float - quant_min_float); - const float zero_point_from_min = quant_min_float - min / *scale; + *nudged_scale = (max - min) / (quant_max_float - quant_min_float); + const float zero_point_from_min = quant_min_float - min / *nudged_scale; uint16 nudged_zero_point; if (zero_point_from_min < quant_min_float) { nudged_zero_point = static_cast(quant_min); @@ -123,8 +123,25 @@ void NudgeQuantizationRange(const float min, const float max, } else { nudged_zero_point = static_cast(TfLiteRound(zero_point_from_min)); } - *nudged_min = (quant_min_float - nudged_zero_point) * (*scale); - *nudged_max = (quant_max_float - nudged_zero_point) * (*scale); + *nudged_min = (quant_min_float - nudged_zero_point) * (*nudged_scale); + *nudged_max = (quant_max_float - nudged_zero_point) * (*nudged_scale); +} + +void FakeQuantizeArray(const float nudged_scale, const float nudged_min, + const float nudged_max, const float* input_data, + float* output_data, const float size) { + // This code originates from tensorflow/core/kernels/fake_quant_ops_functor.h. + const float inv_nudged_scale = 1.0f / nudged_scale; + + for (int i = 0; i < size; i++) { + const float src_val = input_data[i]; + const float clamped = std::min(nudged_max, std::max(nudged_min, src_val)); + const float clamped_shifted = clamped - nudged_min; + const float dst_val = + TfLiteRound(clamped_shifted * inv_nudged_scale) * nudged_scale + + nudged_min; + output_data[i] = dst_val; + } } bool CheckedLog2(const float x, int* log2_result) { diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h index 525857a2e6f73276d0a6e64770947169033c7667..9ee4a47fbb5bba1a409830f99c7b9ba967325a0a 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h @@ -28,8 +28,9 @@ namespace tflite { // Given the min and max values of a float array, return // reasonable quantization parameters to use for this array. template -QuantizationParams ChooseQuantizationParams(double rmin, double rmax) { - const T qmin = std::numeric_limits::min(); +QuantizationParams ChooseQuantizationParams(double rmin, double rmax, + bool narrow_range) { + const T qmin = std::numeric_limits::min() + (narrow_range ? 1 : 0); const T qmax = std::numeric_limits::max(); const double qmin_double = qmin; const double qmax_double = qmax; @@ -97,6 +98,11 @@ QuantizationParams ChooseQuantizationParams(double rmin, double rmax) { return quantization_params; } +template +QuantizationParams ChooseQuantizationParams(double rmin, double rmax) { + return ChooseQuantizationParams(rmin, rmax, false); +} + // Converts a floating-point number to an integer. For all inputs x where // static_cast(x) is legal according to the C++ standard, the result // is identical to that cast (i.e. the result is x with its fractional part @@ -216,7 +222,15 @@ int CalculateInputRadius(int input_integer_bits, int input_left_shift); // Outputs nudged_min, nudged_max, nudged_scale. void NudgeQuantizationRange(const float min, const float max, const int quant_min, const int quant_max, - float* nudged_min, float* nudged_max, float* scale); + float* nudged_min, float* nudged_max, + float* nudged_scale); + +// Fake quantizes (quantizes and dequantizes) input_data using the scale, +// nudged_min, and nudged_max from NudgeQuantizationRange. This matches the code +// in TensorFlow's FakeQuantizeWithMinMaxVarsFunctor. +void FakeQuantizeArray(const float nudged_scale, const float nudged_min, + const float nudged_max, const float* input_data, + float* output_data, const float size); // If x is approximately a power of two (with any positive or negative // exponent), stores that exponent (i.e. log2(x)) in *log2_result, otherwise diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h index 878b2441b4f2828a014673f5bd80fb8aa29514db..bcf5e4e4f6593ec9bce7acd1fb7082955276ca32 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h @@ -63,15 +63,257 @@ inline void Relu6(const float* input_data, const Dims<4>& input_dims, DimsToShape(output_dims)); } +template +inline void Add(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, int input2_shift, + int32 output_offset, int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + constexpr int kReverseShift = -1; + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + + tflite::ArithmeticParams op_params; + op_params.left_shift = left_shift; + op_params.input1_offset = input1_offset; + op_params.input1_multiplier = input1_multiplier; + op_params.input1_shift = kReverseShift * input1_shift; + op_params.input2_offset = input2_offset; + op_params.input2_multiplier = input2_multiplier; + op_params.input2_shift = kReverseShift * input2_shift; + op_params.output_offset = output_offset; + op_params.output_multiplier = output_multiplier; + op_params.output_shift = kReverseShift * output_shift; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + Add(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +template +void Add(const int32* input1_data, const Dims<4>& input1_dims, + const int32* input2_data, const Dims<4>& input2_dims, + int32* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Add/int32"); + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + + tflite::ArithmeticParams op_params; + op_params.quantized_activation_min = std::numeric_limits::min(); + op_params.quantized_activation_max = std::numeric_limits::max(); + Add(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +template +inline void BroadcastAdd(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + constexpr int kReverseShift = -1; + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + + tflite::ArithmeticParams op_params; + op_params.left_shift = left_shift; + op_params.input1_offset = input1_offset; + op_params.input1_multiplier = input1_multiplier; + op_params.input1_shift = kReverseShift * input1_shift; + op_params.input2_offset = input2_offset; + op_params.input2_multiplier = input2_multiplier; + op_params.input2_shift = kReverseShift * input2_shift; + op_params.output_offset = output_offset; + op_params.output_multiplier = output_multiplier; + op_params.output_shift = kReverseShift * output_shift; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); +} + +template +void Add(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + tflite::ArithmeticParams op_params; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + Add(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +template +void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T output_activation_min, T output_activation_max, + T* output_data, const Dims<4>& output_dims) { + tflite::ArithmeticParams op_params; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); +} + +template +inline void BroadcastAddFivefold( + int y0, int y1, int y2, int y3, int y4, int left_shift, + const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + constexpr int kReverseShift = -1; + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + tflite::ArithmeticParams op_params; + op_params.broadcast_category = + tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast; + op_params.left_shift = left_shift; + op_params.input1_offset = input1_offset; + op_params.input1_multiplier = input1_multiplier; + op_params.input1_shift = kReverseShift * input1_shift; + op_params.input2_offset = input2_offset; + op_params.input2_multiplier = input2_multiplier; + op_params.input2_shift = kReverseShift * input2_shift; + op_params.output_offset = output_offset; + op_params.output_multiplier = output_multiplier; + op_params.output_shift = kReverseShift * output_shift; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + op_params.broadcast_shape[4] = y0; + op_params.broadcast_shape[3] = y1; + op_params.broadcast_shape[2] = y2; + op_params.broadcast_shape[1] = y3; + op_params.broadcast_shape[0] = y4; + BroadcastAddFivefold(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); +} + +// legacy, for compatibility with old checked-in code +template +void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + T output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims, + output_activation_min, output_activation_max, output_data, + output_dims); +} + +template +inline void Add(const int16* input1_data, const Dims<4>& input1_dims, + int input1_shift, const int16* input2_data, + const Dims<4>& input2_dims, int input2_shift, + int16 output_activation_min, int16 output_activation_max, + int16* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, -32768); + TFLITE_DCHECK_EQ(output_activation_max, 32767); + } + + tflite::ArithmeticParams op_params; + op_params.input1_shift = kReverseShift * input1_shift; + op_params.input2_shift = kReverseShift * input2_shift; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + Add(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +inline void Sub(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(FusedActivationFunctionType::kNone, + &output_activation_min, &output_activation_max); + tflite::ArithmeticParams op_params; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + Sub(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +template +void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, + const Dims<4>& input2_dims, T* output_data, + const Dims<4>& output_dims) { + tflite::ArithmeticParams op_params; + op_params.quantized_activation_min = std::numeric_limits::min(); + op_params.quantized_activation_max = std::numeric_limits::max(); + Sub(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + inline void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride_width, int stride_height, int pad_width, int pad_height, int kwidth, int kheight, float output_activation_min, float output_activation_max, float* output_data, const Dims<4>& output_dims) { - AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height, - pad_width, pad_height, kwidth, kheight, output_activation_min, - output_activation_max, output_data, DimsToShape(output_dims)); + tflite::PoolParams params; + params.stride_height = stride_height; + params.stride_width = stride_width; + params.filter_height = kheight; + params.filter_width = kwidth; + params.padding_values.height = pad_height; + params.padding_values.width = pad_width; + params.float_activation_min = output_activation_min; + params.float_activation_max = output_activation_max; + AveragePool(params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); } // legacy, for compatibility with old checked-in code @@ -104,10 +346,17 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int32 output_activation_min, int32 output_activation_max, uint8* output_data, const Dims<4>& output_dims) { - AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height, - pad_width, pad_height, filter_width, filter_height, - output_activation_min, output_activation_max, output_data, - DimsToShape(output_dims)); + tflite::PoolParams params; + params.stride_height = stride_height; + params.stride_width = stride_width; + params.filter_height = filter_height; + params.filter_width = filter_width; + params.padding_values.height = pad_height; + params.padding_values.width = pad_width; + params.quantized_activation_min = output_activation_min; + params.quantized_activation_max = output_activation_max; + AveragePool(params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); } // legacy, for compatibility with old checked-in code @@ -148,9 +397,17 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims, int pad_height, int kwidth, int kheight, float output_activation_min, float output_activation_max, float* output_data, const Dims<4>& output_dims) { - MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height, - pad_width, pad_height, kwidth, kheight, output_activation_min, - output_activation_max, output_data, DimsToShape(output_dims)); + tflite::PoolParams params; + params.stride_height = stride_height; + params.stride_width = stride_width; + params.filter_height = kheight; + params.filter_width = kwidth; + params.padding_values.height = pad_height; + params.padding_values.width = pad_width; + params.float_activation_min = output_activation_min; + params.float_activation_max = output_activation_max; + MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims), + output_data); } // legacy, for compatibility with old checked-in code @@ -180,10 +437,17 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int pad_height, int filter_width, int filter_height, int32 output_activation_min, int32 output_activation_max, uint8* output_data, const Dims<4>& output_dims) { - MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height, - pad_width, pad_height, filter_width, filter_height, - output_activation_min, output_activation_max, output_data, - DimsToShape(output_dims)); + PoolParams params; + params.stride_height = stride_height; + params.stride_width = stride_width; + params.filter_height = filter_height; + params.filter_width = filter_width; + params.padding_values.height = pad_height; + params.padding_values.width = pad_width; + params.quantized_activation_min = output_activation_min; + params.quantized_activation_max = output_activation_max; + MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims), + output_data); } // legacy, for compatibility with old checked-in code @@ -223,10 +487,17 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims, int pad_height, int filter_width, int filter_height, float output_activation_min, float output_activation_max, float* output_data, const Dims<4>& output_dims) { - L2Pool(input_data, DimsToShape(input_dims), stride_width, stride_height, - pad_width, pad_height, filter_width, filter_height, - output_activation_min, output_activation_max, output_data, - DimsToShape(output_dims)); + PoolParams params; + params.stride_height = stride_height; + params.stride_width = stride_width; + params.filter_height = filter_height; + params.filter_width = filter_width; + params.padding_values.height = pad_height; + params.padding_values.width = pad_width; + params.float_activation_min = output_activation_min; + params.float_activation_max = output_activation_max; + L2Pool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims), + output_data); } // legacy, for compatibility with old checked-in code diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc index f8c6f341f7e61529bbbac592f9caf115f6121e0c..a5f4addd5e9297124ff4f6d6011093fac101f9f0 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -14,12 +14,17 @@ limitations under the License. ==============================================================================*/ #include #include +#include #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/round.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" +#if defined(_MSC_VER) +#define __restrict__ __restrict +#endif + namespace tflite { namespace tensor_utils { @@ -37,24 +42,23 @@ bool PortableIsZeroVector(const float* vector, int v_size) { } void PortableSymmetricQuantizeFloats(const float* values, const int size, - int8_t* quantized_values, - float* __restrict__ min, - float* __restrict__ max, - float* __restrict__ scaling_factor) { + int8_t* quantized_values, float* min_value, + float* max_value, float* scaling_factor) { auto minmax = std::minmax_element(values, values + size); - *min = *minmax.first; - *max = *minmax.second; + *min_value = *minmax.first; + *max_value = *minmax.second; const int kScale = 127; - const float range = std::max(std::abs(*min), std::abs(*max)); + const float range = std::max(std::abs(*min_value), std::abs(*max_value)); if (range == 0) { memset(quantized_values, 0, size * sizeof(int8_t)); *scaling_factor = 1; return; } - *scaling_factor = kScale / range; + *scaling_factor = range / kScale; + const float scaling_factor_inv = 1.0f / *scaling_factor; for (int i = 0; i < size; ++i) { const int32_t quantized_value = - static_cast(TfLiteRound(*scaling_factor * values[i])); + static_cast(TfLiteRound(values[i] * scaling_factor_inv)); // Clamp: just in case some odd numeric offset. quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value)); } @@ -80,25 +84,26 @@ void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix, void PortableMatrixBatchVectorMultiplyAccumulate( const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, - const int8_t* __restrict__ vectors, - const float* __restrict__ scaling_factors, int n_batch, - float* __restrict__ result, int result_stride) { + const int8_t* __restrict__ vectors, const float* scaling_factors, + int n_batch, float* __restrict__ result, int result_stride) { int batch, row, col; for (batch = 0; batch < n_batch; ++batch, vectors += m_cols) { - const float batch_scaling_factor_inv = 1.0 / scaling_factors[batch]; + const float batch_scaling_factor = scaling_factors[batch]; // Get the address of the first row. const int8_t* row_ptr = matrix; for (row = 0; row < m_rows; ++row, result += result_stride) { // Initialize the dot product sum for the row to 0. int32_t dotprod = 0; +#if defined(__GNUC__) // Prefetch the row to cache. __builtin_prefetch(row_ptr, 0 /* prefetch for read */, 3 /* temporal locality */); +#endif // For every block of 16 8-bit elements (128-bit register) from each row. for (col = 0; col < m_cols; ++col, ++row_ptr) { dotprod += (*row_ptr) * (vectors[col]); } // for col - *result += (dotprod * batch_scaling_factor_inv); + *result += (dotprod * batch_scaling_factor); } // for row } // for batch } @@ -194,6 +199,13 @@ void PortableZeroVector(float* vector, int v_size) { memset(vector, 0, v_size * sizeof(float)); } +void PortableVectorScalarMultiply(const int8_t* vector, const int v_size, + const float scale, float* result) { + for (int v = 0; v < v_size; ++v) { + *result++ = scale * *vector++; + } +} + void PortableClipVector(const float* vector, int v_size, float abs_limit, float* result) { for (int v = 0; v < v_size; v++) { diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h index d2e1fecd25cf3d11d3daffcc566dc1d5df97128c..a375aaffa67ac19975cc8e371be11547d689dc72 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h @@ -19,6 +19,10 @@ limitations under the License. // structure. #include "tensorflow/contrib/lite/builtin_op_data.h" +#if defined(_MSC_VER) +#define __restrict__ __restrict +#endif + namespace tflite { namespace tensor_utils { @@ -28,8 +32,8 @@ float PortableClip(float f, float abs_limit); bool PortableIsZeroVector(const float* vector, int v_size); void PortableSymmetricQuantizeFloats(const float* values, const int size, - int8_t* quantized_values, float* min, - float* max, float* scaling_factor); + int8_t* quantized_values, float* min_value, + float* max_value, float* scaling_factor); // Multiply a matrix by a batch vector, and store results in a batch-size // vector. @@ -96,6 +100,10 @@ void PortableSub1Vector(const float* vector, int v_size, float* result); // Fill vector with 0.f. void PortableZeroVector(float* vector, int v_size); +// Multiply all elements of vector with a scalar. +void PortableVectorScalarMultiply(const int8_t* vector, int v_size, float scale, + float* result); + // Clip elements of a vector using a abs_limit value. void PortableClipVector(const float* vector, int v_size, float abs_limit, float* result); @@ -199,6 +207,12 @@ void ZeroVector(float* vector, int v_size) { PortableZeroVector(vector, v_size); } +// Multiply all elements of vector with a scalar. +void VectorScalarMultiply(const int8_t* vector, int v_size, float scale, + float* result) { + PortableVectorScalarMultiply(vector, v_size, scale, result); +} + void ClipVector(const float* vector, int v_size, float abs_limit, float* result) { PortableClipVector(vector, v_size, abs_limit, result); diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 16901a3e5336a5668ee1df700baa87b51fd2fe2e..ace3af2da06b31cea7f3d7e60d086f9ff6d7c0ce 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -158,98 +158,6 @@ SaturatingRoundingMultiplyByPOTParam( SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent)); } -// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE -// BROADCASTING. -// -// NdArrayDesc describes the shape and memory layout of an N-dimensional -// rectangular array of numbers. -// -// NdArrayDesc is basically identical to Dims defined in types.h. -// However, as Dims is to be deprecated, this class exists as an adaptor -// to enable simple unoptimized implementations of element-wise broadcasting -// operations. -template -struct NdArrayDesc { - // The "extent" of each dimension. Indices along dimension d must be in the - // half-open interval [0, extents[d]). - int extents[N]; - - // The number of *elements* (not bytes) between consecutive indices of each - // dimension. - int strides[N]; -}; - -// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING -// ELEMENT-WISE BROADCASTING. -// -// Same as Offset(), except takes as NdArrayDesc instead of Dims. -inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2, - int i3) { - TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]); - TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]); - TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]); - TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]); - return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] + - i3 * desc.strides[3]; -} - -// Given the dimensions of the operands for an element-wise binary broadcast, -// adjusts them so that they can be directly iterated over with simple loops. -// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and -// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr. -// -// This function assumes that the two input shapes are compatible up to -// broadcasting and the shorter one has already been prepended with 1s to be the -// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64), -// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that -// Dims refer to shapes in reverse order. In this case, input0_dims will be -// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1). -// -// When two shapes are compatible up to broadcasting, for each dimension d, -// the input extents are either equal, or one of them is 1. -// -// This function performs the following for each dimension d: -// - If the extents are equal, then do nothing since the loop that walks over -// both of the input arrays is correct. -// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1 -// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows -// array0 to be referenced *at any index* in dimension d and still access the -// same slice. -template -inline void NdArrayDescsForElementwiseBroadcast(const Dims& input0_dims, - const Dims& input1_dims, - NdArrayDesc* desc0_out, - NdArrayDesc* desc1_out) { - TFLITE_DCHECK(desc0_out != nullptr); - TFLITE_DCHECK(desc1_out != nullptr); - - // Copy dims to desc. - for (int i = 0; i < N; ++i) { - desc0_out->extents[i] = input0_dims.sizes[i]; - desc0_out->strides[i] = input0_dims.strides[i]; - desc1_out->extents[i] = input1_dims.sizes[i]; - desc1_out->strides[i] = input1_dims.strides[i]; - } - - // Walk over each dimension. If the extents are equal do nothing. - // Otherwise, set the desc with extent 1 to have extent equal to the other and - // stride 0. - for (int i = 0; i < N; ++i) { - const int extent0 = ArraySize(input0_dims, i); - const int extent1 = ArraySize(input1_dims, i); - if (extent0 != extent1) { - if (extent0 == 1) { - desc0_out->strides[i] = 0; - desc0_out->extents[i] = extent1; - } else { - TFLITE_DCHECK_EQ(extent1, 1); - desc1_out->strides[i] = 0; - desc1_out->extents[i] = extent0; - } - } - } -} - inline void Conv(const float* input_data, const Dims<4>& input_dims, const float* filter_data, const Dims<4>& filter_dims, const float* bias_data, const Dims<4>& bias_dims, @@ -414,8 +322,8 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, if (bias_data) { acc += bias_data[Offset(bias_dims, out_channel, 0, 0, 0)]; } - acc = MultiplyByQuantizedMultiplierSmallerThanOneExp( - acc, output_multiplier, kReverseShift * output_shift); + acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, + kReverseShift * output_shift); acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); @@ -951,6 +859,19 @@ inline void Relu6(const float* input_data, const RuntimeShape& input_shape, } } +inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data, + const RuntimeShape& input_shape, uint8* output_data, + const RuntimeShape& output_shape) { + gemmlowp::ScopedProfilingLabel label("Quantized ReluX (not fused)"); + const int flat_size = MatchingFlatSize(input_shape, output_shape); + for (int i = 0; i < flat_size; ++i) { + const uint8 val = input_data[i]; + const uint8 clamped = + val > max_value ? max_value : val < min_value ? min_value : val; + output_data[i] = clamped; + } +} + template void L2Normalization(const float* input_data, const RuntimeShape& input_shape, float* output_data, const RuntimeShape& output_shape) { @@ -982,7 +903,8 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input, ++*output_shift; } TFLITE_DCHECK_GT(input, 0); - const unsigned max_left_shift_bits = __builtin_clz(input) - 1; + const unsigned max_left_shift_bits = + CountLeadingZeros(static_cast(input)) - 1; const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2; const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1; *output_shift -= left_shift_bit_pairs; @@ -1052,114 +974,108 @@ inline void L2Normalization(const uint8* input_data, } template -inline void Add(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - T output_activation_min, T output_activation_max, - T* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); +inline void Add(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const T* input1_data, + const RuntimeShape& input2_shape, const T* input2_data, + const RuntimeShape& output_shape, T* output_data) { + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( - input1_data[i] + input2_data[i], output_activation_min, - output_activation_max); + input1_data[i] + input2_data[i], params.quantized_activation_min, + params.quantized_activation_max); } } -// legacy, for compatibility with old checked-in code -template -void Add(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float* output_data, const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - - Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, - output_activation_max, output_data, output_dims); +inline void Add(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const float* input1_data, + const RuntimeShape& input2_shape, const float* input2_data, + const RuntimeShape& output_shape, float* output_data) { + const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape); + for (int i = 0; i < size; i++) { + auto x = input1_data[i] + input2_data[i]; + output_data[i] = ActivationFunctionWithMinMax( + x, params.float_activation_min, params.float_activation_max); + } } -template -inline void Add(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, int input2_shift, - int32 output_offset, int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - const int batches = - MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); - const int height = - MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); - const int width = - MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); - const int depth = - MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - for (int c = 0; c < depth; ++c) { - const int32 input1_val = - input1_offset + input1_data[Offset(input1_dims, c, x, y, b)]; - const int32 input2_val = - input2_offset + input2_data[Offset(input2_dims, c, x, y, b)]; - const int32 shifted_input1_val = input1_val * (1 << left_shift); - const int32 shifted_input2_val = input2_val * (1 << left_shift); - const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, input1_multiplier, - kReverseShift * input1_shift); - const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, input2_multiplier, - kReverseShift * input2_shift); - const int32 raw_sum = scaled_input1_val + scaled_input2_val; - const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - raw_sum, output_multiplier, kReverseShift * output_shift) + - output_offset; - const int32 clamped_output = - std::min(output_activation_max, - std::max(output_activation_min, raw_output)); - output_data[Offset(output_dims, c, x, y, b)] = - static_cast(clamped_output); - } - } - } +// Element-wise add that can often be used for inner loop of broadcast add as +// well as the non-broadcast add. +inline void AddElementwise(int size, const ArithmeticParams& params, + const uint8* input1_data, const uint8* input2_data, + uint8* output_data) { + TFLITE_DCHECK_GT(params.input1_offset, -256); + TFLITE_DCHECK_GT(params.input2_offset, -256); + TFLITE_DCHECK_LT(params.input1_offset, 256); + TFLITE_DCHECK_LT(params.input2_offset, 256); + + for (int i = 0; i < size; ++i) { + const int32 input1_val = params.input1_offset + input1_data[i]; + const int32 input2_val = params.input2_offset + input2_data[i]; + const int32 shifted_input1_val = input1_val * (1 << params.left_shift); + const int32 shifted_input2_val = input2_val * (1 << params.left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, params.input1_multiplier, params.input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, params.input2_multiplier, params.input2_shift); + const int32 raw_sum = scaled_input1_val + scaled_input2_val; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, params.output_multiplier, params.output_shift) + + params.output_offset; + const int32 clamped_output = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, raw_output)); + output_data[i] = static_cast(clamped_output); } } -inline void Add(const int16* input1_data, const Dims<4>& input1_dims, - int input1_shift, const int16* input2_data, - const Dims<4>& input2_dims, int input2_shift, - int16 output_activation_min, int16 output_activation_max, - int16* output_data, const Dims<4>& output_dims) { - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - - const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); - - TFLITE_DCHECK(input1_shift == 0 || input2_shift == 0); - TFLITE_DCHECK_GE(input1_shift, 0); - TFLITE_DCHECK_GE(input2_shift, 0); +inline void Add(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const uint8* input1_data, + const RuntimeShape& input2_shape, const uint8* input2_data, + const RuntimeShape& output_shape, uint8* output_data) { + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); + + TFLITE_DCHECK_GT(params.input1_offset, -256); + TFLITE_DCHECK_GT(params.input2_offset, -256); + TFLITE_DCHECK_LT(params.input1_offset, 256); + TFLITE_DCHECK_LT(params.input2_offset, 256); + AddElementwise(flat_size, params, input1_data, input2_data, output_data); +} + +inline void Add(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const int16* input1_data, + const RuntimeShape& input2_shape, const int16* input2_data, + const RuntimeShape& output_shape, int16* output_data) { + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); + + const int input1_shift = params.input1_shift; + const int flat_size = + MatchingFlatSize(output_shape, input1_shape, input2_shape); + const int16 output_activation_min = params.quantized_activation_min; + const int16 output_activation_max = params.quantized_activation_max; + + TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0); + TFLITE_DCHECK_LE(input1_shift, 0); + TFLITE_DCHECK_LE(params.input2_shift, 0); const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data; const int16* shift_input = input1_shift == 0 ? input2_data : input1_data; - const int input_shift = input1_shift == 0 ? input2_shift : input1_shift; + const int input_right_shift = + input1_shift == 0 ? -params.input2_shift : -input1_shift; for (int i = 0; i < flat_size; i++) { // F0 uses 0 integer bits, range [-1, 1]. using F0 = gemmlowp::FixedPoint; F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]); - F0 scaled_input = - F0::FromRaw(gemmlowp::RoundingDivideByPOT(shift_input[i], input_shift)); + F0 scaled_input = F0::FromRaw( + gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift)); F0 result = gemmlowp::SaturatingAdd(scaled_input, input_ready_scaled); const int16 raw_output = result.raw(); const int16 clamped_output = std::min( @@ -1168,42 +1084,28 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims, } } -template -inline void Add(const int16* input1_data, const Dims<4>& input1_dims, - int input1_shift, const int16* input2_data, - const Dims<4>& input2_dims, int input2_shift, - int16 output_activation_min, int16 output_activation_max, - int16* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, -32768); - TFLITE_DCHECK_EQ(output_activation_max, 32767); - } - - Add(input1_data, input1_dims, input1_shift, input2_data, input2_dims, - input2_shift, output_activation_min, output_activation_max, output_data, - output_dims); -} - // TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then // generate max(D1, D2) nested for loops. -template -void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - T output_activation_min, T output_activation_max, - T* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastAdd"); - +// TODO(benoitjacob): BroadcastAdd is intentionally duplicated from +// reference_ops.h. Once an optimized version is implemented and NdArrayDesc +// is no longer referenced in this file, move NdArrayDesc from types.h to +// reference_ops.h. +inline void BroadcastAdd4DSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const float* input1_data, + const RuntimeShape& input2_shape, + const float* input2_data, + const RuntimeShape& output_shape, + float* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/float"); NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, + &desc2); + RuntimeShape extended_output_shape = + RuntimeShape::ExtendedShape(4, output_shape); // In Tensorflow, the dimensions are canonically named (batch_number, row, // col, channel), with extents (batches, height, width, depth), with the @@ -1216,49 +1118,77 @@ void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, // We name our variables by their Tensorflow convention, but generate C code // nesting loops such that the innermost loop has the smallest stride for the // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = + for (int b = 0; b < extended_output_shape.Dims(0); ++b) { + for (int y = 0; y < extended_output_shape.Dims(1); ++y) { + for (int x = 0; x < extended_output_shape.Dims(2); ++x) { + for (int c = 0; c < extended_output_shape.Dims(3); ++c) { + output_data[Offset(extended_output_shape, b, y, x, c)] = ActivationFunctionWithMinMax( - input1_data[SubscriptToIndex(desc1, c, x, y, b)] + - input2_data[SubscriptToIndex(desc2, c, x, y, b)], - output_activation_min, output_activation_max); + input1_data[SubscriptToIndex(desc1, b, y, x, c)] + + input2_data[SubscriptToIndex(desc2, b, y, x, c)], + params.float_activation_min, params.float_activation_max); } } } } } -// legacy, for compatibility with old checked-in code -template -void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - T* output_data, const Dims<4>& output_dims) { - T output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); +inline void BroadcastAdd4DSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const int32* input1_data, + const RuntimeShape& input2_shape, + const int32* input2_data, + const RuntimeShape& output_shape, + int32* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/int32"); + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, + &desc2); + RuntimeShape extended_output_shape = + RuntimeShape::ExtendedShape(4, output_shape); - BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims, - output_activation_min, output_activation_max, output_data, - output_dims); + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < extended_output_shape.Dims(0); ++b) { + for (int y = 0; y < extended_output_shape.Dims(1); ++y) { + for (int x = 0; x < extended_output_shape.Dims(2); ++x) { + for (int c = 0; c < extended_output_shape.Dims(3); ++c) { + output_data[Offset(extended_output_shape, b, y, x, c)] = + ActivationFunctionWithMinMax( + input1_data[SubscriptToIndex(desc1, b, y, x, c)] + + input2_data[SubscriptToIndex(desc2, b, y, x, c)], + params.quantized_activation_min, + params.quantized_activation_max); + } + } + } + } } -inline void BroadcastAdd(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastAdd/8bit"); - +inline void BroadcastAdd4DSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const uint8* input1_data, + const RuntimeShape& input2_shape, + const uint8* input2_data, + const RuntimeShape& output_shape, + uint8* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/uint8"); NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, + &desc2); + RuntimeShape extended_output_shape = + RuntimeShape::ExtendedShape(4, output_shape); // In Tensorflow, the dimensions are canonically named (batch_number, row, // col, channel), with extents (batches, height, width, depth), with the @@ -1271,33 +1201,37 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data, // We name our variables by their Tensorflow convention, but generate C code // nesting loops such that the innermost loop has the smallest stride for the // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + for (int b = 0; b < extended_output_shape.Dims(0); ++b) { + for (int y = 0; y < extended_output_shape.Dims(1); ++y) { + for (int x = 0; x < extended_output_shape.Dims(2); ++x) { + for (int c = 0; c < extended_output_shape.Dims(3); ++c) { const int32 input1_val = - input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + params.input1_offset + + input1_data[SubscriptToIndex(desc1, b, y, x, c)]; const int32 input2_val = - input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; - const int32 shifted_input1_val = input1_val * (1 << left_shift); - const int32 shifted_input2_val = input2_val * (1 << left_shift); + params.input2_offset + + input2_data[SubscriptToIndex(desc2, b, y, x, c)]; + const int32 shifted_input1_val = + input1_val * (1 << params.left_shift); + const int32 shifted_input2_val = + input2_val * (1 << params.left_shift); const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, input1_multiplier, - kReverseShift * input1_shift); + shifted_input1_val, params.input1_multiplier, + params.input1_shift); const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, input2_multiplier, - kReverseShift * input2_shift); + shifted_input2_val, params.input2_multiplier, + params.input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; const int32 raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp( - raw_sum, output_multiplier, kReverseShift * output_shift) + - output_offset; + raw_sum, params.output_multiplier, params.output_shift) + + params.output_offset; const int32 clamped_output = - std::min(output_activation_max, - std::max(output_activation_min, raw_output)); - output_data[Offset(output_dims, c, x, y, b)] = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, raw_output)); + output_data[Offset(extended_output_shape, b, y, x, c)] = static_cast(clamped_output); } } @@ -1305,121 +1239,67 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data, } } -inline void BroadcastAddFivefold( - int y0, int y1, int y2, int y3, int y4, int left_shift, - const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastAddFivefold/8bit"); - - int sb1 = y0; - int sa2 = y0; - int sb2 = y0 * y1; - int sa3 = y0 * y2; - int sa4 = y0 * y2 * y3; - int sb4 = y0 * y1 * y2; - +inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params, + const RuntimeShape& unswitched_input1_shape, + const uint8* unswitched_input1_data, + const RuntimeShape& unswitched_input2_shape, + const uint8* unswitched_input2_data, + const RuntimeShape& output_shape, + uint8* output_data) { + ArithmeticParams switched_params = unswitched_params; + switched_params.input1_offset = unswitched_params.input2_offset; + switched_params.input1_multiplier = unswitched_params.input2_multiplier; + switched_params.input1_shift = unswitched_params.input2_shift; + switched_params.input2_offset = unswitched_params.input1_offset; + switched_params.input2_multiplier = unswitched_params.input1_multiplier; + switched_params.input2_shift = unswitched_params.input1_shift; + + const bool use_unswitched = + unswitched_params.broadcast_category == + tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast; + + const ArithmeticParams& params = + use_unswitched ? unswitched_params : switched_params; + const uint8* input1_data = + use_unswitched ? unswitched_input1_data : unswitched_input2_data; + const uint8* input2_data = + use_unswitched ? unswitched_input2_data : unswitched_input1_data; + + // Fivefold nested loops. The second input resets its position for each + // iteration of the second loop. The first input resets its position at the + // beginning of the fourth loop. The innermost loop is an elementwise add of + // sections of the arrays. uint8* output_data_ptr = output_data; - for (int i4 = 0; i4 < y4; ++i4) { - for (int i3 = 0; i3 < y3; ++i3) { + const uint8* input1_data_ptr = input1_data; + const uint8* input2_data_reset = input2_data; + int y0 = params.broadcast_shape[0]; + int y1 = params.broadcast_shape[1]; + int y2 = params.broadcast_shape[2]; + int y3 = params.broadcast_shape[3]; + int y4 = params.broadcast_shape[4]; + for (int i0 = 0; i0 < y0; ++i0) { + const uint8* input2_data_ptr; + for (int i1 = 0; i1 < y1; ++i1) { + input2_data_ptr = input2_data_reset; for (int i2 = 0; i2 < y2; ++i2) { - for (int i1 = 0; i1 < y1; ++i1) { - for (int i0 = 0; i0 < y0; ++i0) { - const int32 input1_val = - input1_offset + - input1_data[i4 * sa4 + i3 * sa3 + i2 * sa2 + i0]; - const int32 input2_val = - input2_offset + - input2_data[i4 * sb4 + i2 * sb2 + i1 * sb1 + i0]; - const int32 shifted_input1_val = input1_val * (1 << left_shift); - const int32 shifted_input2_val = input2_val * (1 << left_shift); - const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, input1_multiplier, - kReverseShift * input1_shift); - const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, input2_multiplier, - kReverseShift * input2_shift); - const int32 raw_sum = scaled_input1_val + scaled_input2_val; - const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOneExp( - raw_sum, output_multiplier, kReverseShift * output_shift) + - output_offset; - const int32 clamped_output = - std::min(output_activation_max, - std::max(output_activation_min, raw_output)); - *output_data_ptr = static_cast(clamped_output); - ++output_data_ptr; - } + for (int i3 = 0; i3 < y3; ++i3) { + AddElementwise(y4, params, input1_data_ptr, input2_data_ptr, + output_data_ptr); + input2_data_ptr += y4; + output_data_ptr += y4; } + input1_data_ptr += y4; } } + input2_data_reset = input2_data_ptr; } } -template -inline void BroadcastAdd(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset, - input1_multiplier, input1_shift, input2_data, input2_dims, - input2_offset, input2_multiplier, input2_shift, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_data, output_dims); -} - -template -inline void BroadcastAddFivefold( - int y0, int y1, int y2, int y3, int y4, int left_shift, - const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - BroadcastAddFivefold(y0, y1, y2, y3, y4, left_shift, input1_data, input1_dims, - input1_offset, input1_multiplier, input1_shift, - input2_data, input2_dims, input2_offset, - input2_multiplier, input2_shift, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_data, output_dims); -} - -inline void Mul(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { +template +inline void Mul(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T output_activation_min, T output_activation_max, + T* output_data, const Dims<4>& output_dims) { const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( @@ -1640,10 +1520,11 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims, } } -inline void Div(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { +template +inline void Div(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T output_activation_min, T output_activation_max, + T* output_data, const Dims<4>& output_dims) { const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( @@ -1652,15 +1533,35 @@ inline void Div(const float* input1_data, const Dims<4>& input1_dims, } } -inline void Sub(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); +inline void SubNonBroadcast(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const float* input1_data, + const RuntimeShape& input2_shape, + const float* input2_data, + const RuntimeShape& output_shape, + float* output_data) { + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( - input1_data[i] - input2_data[i], output_activation_min, - output_activation_max); + input1_data[i] - input2_data[i], params.float_activation_min, + params.float_activation_max); + } +} + +inline void SubNonBroadcast(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const int32* input1_data, + const RuntimeShape& input2_shape, + const int32* input2_data, + const RuntimeShape& output_shape, + int32* output_data) { + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax( + input1_data[i] - input2_data[i], params.quantized_activation_min, + params.quantized_activation_max); } } @@ -1668,16 +1569,24 @@ inline void Sub(const float* input1_data, const Dims<4>& input1_dims, // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then // generate max(D1, D2) nested for loops. -template -void BroadcastSub(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - T output_activation_min, T output_activation_max, - T* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastSub"); - +// TODO(benoitjacob): BroadcastSub is intentionally duplicated from +// reference_ops.h. Once an optimized version is implemented and NdArrayDesc +// is no longer referenced in this file, move NdArrayDesc from types.h to +// reference_ops.h. +inline void BroadcastSub4DSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const float* input1_data, + const RuntimeShape& input2_shape, + const float* input2_data, + const RuntimeShape& output_shape, + float* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/float"); NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, + &desc2); + RuntimeShape extended_output_shape = + RuntimeShape::ExtendedShape(4, output_shape); // In Tensorflow, the dimensions are canonically named (batch_number, row, // col, channel), with extents (batches, height, width, depth), with the @@ -1690,36 +1599,35 @@ void BroadcastSub(const T* input1_data, const Dims<4>& input1_dims, // We name our variables by their Tensorflow convention, but generate C code // nesting loops such that the innermost loop has the smallest stride for the // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = + for (int b = 0; b < extended_output_shape.Dims(0); ++b) { + for (int y = 0; y < extended_output_shape.Dims(1); ++y) { + for (int x = 0; x < extended_output_shape.Dims(2); ++x) { + for (int c = 0; c < extended_output_shape.Dims(3); ++c) { + output_data[Offset(extended_output_shape, b, y, x, c)] = ActivationFunctionWithMinMax( - input1_data[SubscriptToIndex(desc1, c, x, y, b)] - - input2_data[SubscriptToIndex(desc2, c, x, y, b)], - output_activation_min, output_activation_max); + input1_data[SubscriptToIndex(desc1, b, y, x, c)] - + input2_data[SubscriptToIndex(desc2, b, y, x, c)], + params.float_activation_min, params.float_activation_max); } } } } } -inline void BroadcastSub(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastSub/8bit"); - +inline void BroadcastSub4DSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const uint8* input1_data, + const RuntimeShape& input2_shape, + const uint8* input2_data, + const RuntimeShape& output_shape, + uint8* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/uint8"); NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, + &desc2); + RuntimeShape extended_output_shape = + RuntimeShape::ExtendedShape(4, output_shape); // In Tensorflow, the dimensions are canonically named (batch_number, row, // col, channel), with extents (batches, height, width, depth), with the @@ -1732,33 +1640,37 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data, // We name our variables by their Tensorflow convention, but generate C code // nesting loops such that the innermost loop has the smallest stride for the // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + for (int b = 0; b < extended_output_shape.Dims(0); ++b) { + for (int y = 0; y < extended_output_shape.Dims(1); ++y) { + for (int x = 0; x < extended_output_shape.Dims(2); ++x) { + for (int c = 0; c < extended_output_shape.Dims(3); ++c) { const int32 input1_val = - input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + params.input1_offset + + input1_data[SubscriptToIndex(desc1, b, y, x, c)]; const int32 input2_val = - input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; - const int32 shifted_input1_val = input1_val * (1 << left_shift); - const int32 shifted_input2_val = input2_val * (1 << left_shift); + params.input2_offset + + input2_data[SubscriptToIndex(desc2, b, y, x, c)]; + const int32 shifted_input1_val = + input1_val * (1 << params.left_shift); + const int32 shifted_input2_val = + input2_val * (1 << params.left_shift); const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, input1_multiplier, - kReverseShift * input1_shift); + shifted_input1_val, params.input1_multiplier, + params.input1_shift); const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, input2_multiplier, - kReverseShift * input2_shift); + shifted_input2_val, params.input2_multiplier, + params.input2_shift); const int32 raw_sub = scaled_input1_val - scaled_input2_val; const int32 raw_output = MultiplyByQuantizedMultiplierSmallerThanOneExp( - raw_sub, output_multiplier, kReverseShift * output_shift) + - output_offset; + raw_sub, params.output_multiplier, params.output_shift) + + params.output_offset; const int32 clamped_output = - std::min(output_activation_max, - std::max(output_activation_min, raw_output)); - output_data[Offset(output_dims, c, x, y, b)] = + std::min(params.quantized_activation_max, + std::max(params.quantized_activation_min, raw_output)); + output_data[Offset(extended_output_shape, b, y, x, c)] = static_cast(clamped_output); } } @@ -1766,6 +1678,156 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data, } } +inline void BroadcastSub4DSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const int32* input1_data, + const RuntimeShape& input2_shape, + const int32* input2_data, + const RuntimeShape& output_shape, + int32* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/int32"); + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, + &desc2); + RuntimeShape extended_output_shape = + RuntimeShape::ExtendedShape(4, output_shape); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < extended_output_shape.Dims(0); ++b) { + for (int y = 0; y < extended_output_shape.Dims(1); ++y) { + for (int x = 0; x < extended_output_shape.Dims(2); ++x) { + for (int c = 0; c < extended_output_shape.Dims(3); ++c) { + output_data[Offset(extended_output_shape, b, y, x, c)] = + ActivationFunctionWithMinMax( + input1_data[SubscriptToIndex(desc1, b, y, x, c)] - + input2_data[SubscriptToIndex(desc2, b, y, x, c)], + params.quantized_activation_min, + params.quantized_activation_max); + } + } + } + } +} + +template +void BroadcastSub4DSlow(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const T* input1_data, + const RuntimeShape& input2_shape, const T* input2_data, + const RuntimeShape& output_shape, T* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/templated"); + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, + &desc2); + RuntimeShape extended_output_shape = + RuntimeShape::ExtendedShape(4, output_shape); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < extended_output_shape.Dims(0); ++b) { + for (int y = 0; y < extended_output_shape.Dims(1); ++y) { + for (int x = 0; x < extended_output_shape.Dims(2); ++x) { + for (int c = 0; c < extended_output_shape.Dims(3); ++c) { + output_data[Offset(extended_output_shape, b, y, x, c)] = + ActivationFunctionWithMinMax( + input1_data[SubscriptToIndex(desc1, b, y, x, c)] - + input2_data[SubscriptToIndex(desc2, b, y, x, c)], + params.quantized_activation_min, + params.quantized_activation_max); + } + } + } + } +} + +template +void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape, + const T* input1_data, const RuntimeShape& input2_shape, + const T* input2_data, const RuntimeShape& output_shape, + T* output_data) { + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, + &desc2); + RuntimeShape extended_output_shape = + RuntimeShape::ExtendedShape(4, output_shape); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < extended_output_shape.Dims(0); ++b) { + for (int y = 0; y < extended_output_shape.Dims(1); ++y) { + for (int x = 0; x < extended_output_shape.Dims(2); ++x) { + for (int c = 0; c < extended_output_shape.Dims(3); ++c) { + output_data[Offset(extended_output_shape, b, y, x, c)] = + input1_data[SubscriptToIndex(desc1, b, y, x, c)] - + input2_data[SubscriptToIndex(desc2, b, y, x, c)]; + } + } + } + } +} + +inline void SubWithActivation(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const int32* input1_data, + const RuntimeShape& input2_shape, + const int32* input2_data, + const RuntimeShape& output_shape, + int32* output_data) { + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, input2_shape); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax( + input1_data[i] - input2_data[i], params.quantized_activation_min, + params.quantized_activation_max); + } +} + +inline void SubWithActivation(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const float* input1_data, + const RuntimeShape& input2_shape, + const float* input2_data, + const RuntimeShape& output_shape, + float* output_data) { + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, input2_shape); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = ActivationFunctionWithMinMax( + input1_data[i] - input2_data[i], params.float_activation_min, + params.float_activation_max); + } +} + template void Concatenation(int concat_dim, const Scalar* const* input_data, const Dims<4>* const* input_dims, int inputs_count, @@ -1799,6 +1861,26 @@ void Concatenation(int concat_dim, const Scalar* const* input_data, } } +template +void Pack(int dim, const Scalar* const* input_data, + const Dims<4>* const* input_dims, int inputs_count, + Scalar* output_data, const Dims<4>& output_dims) { + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + int outer_size = 1; + for (int i = dim + 1; i < 4; i++) { + outer_size *= output_dims.sizes[i]; + } + Scalar* output_ptr = output_data; + const int copy_size = FlatSize(**input_dims) / outer_size; + for (int k = 0; k < outer_size; k++) { + for (int i = 0; i < inputs_count; ++i) { + memcpy(output_ptr, input_data[i] + k * copy_size, + copy_size * sizeof(Scalar)); + output_ptr += copy_size; + } + } +} + // TODO(prabhumk): This is the same as the optimized implementation. // TODO(prabhumk): The quantized implementation of concatentation isn't fully // quantized as it takes scale as a floating point value. This should be fixed @@ -2260,13 +2342,10 @@ inline int NodeOffset(int b, int h, int w, int height, int width) { return (b * height + h) * width + w; } -inline void AveragePool(const float* input_data, - const RuntimeShape& input_shape, int stride_width, - int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, - float output_activation_min, - float output_activation_max, float* output_data, - const RuntimeShape& output_shape) { +inline void AveragePool(const PoolParams& params, + const RuntimeShape& input_shape, + const float* input_data, + const RuntimeShape& output_shape, float* output_data) { TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); const int batches = MatchingDim(input_shape, 0, output_shape, 0); @@ -2275,20 +2354,24 @@ inline void AveragePool(const float* input_data, const int input_width = input_shape.Dims(2); const int output_height = output_shape.Dims(1); const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { for (int channel = 0; channel < depth; ++channel) { - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; + const int in_x_origin = + (out_x * stride_width) - params.padding_values.width; + const int in_y_origin = + (out_y * stride_height) - params.padding_values.height; // Compute the boundaries of the filter region clamped so as to // ensure that the filter window fits in the input array. const int filter_x_start = std::max(0, -in_x_origin); const int filter_x_end = - std::min(filter_width, input_width - in_x_origin); + std::min(params.filter_width, input_width - in_x_origin); const int filter_y_start = std::max(0, -in_y_origin); const int filter_y_end = - std::min(filter_height, input_height - in_y_origin); + std::min(params.filter_height, input_height - in_y_origin); float total = 0.f; float filter_count = 0; for (int filter_y = filter_y_start; filter_y < filter_y_end; @@ -2304,22 +2387,20 @@ inline void AveragePool(const float* input_data, } const float average = total / filter_count; output_data[Offset(output_shape, batch, out_y, out_x, channel)] = - ActivationFunctionWithMinMax(average, output_activation_min, - output_activation_max); + ActivationFunctionWithMinMax(average, params.float_activation_min, + params.float_activation_max); } } } } } -inline void AveragePool(const uint8* input_data, - const RuntimeShape& input_shape, int stride_width, - int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const RuntimeShape& output_shape) { - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); +inline void AveragePool(const PoolParams& params, + const RuntimeShape& input_shape, + const uint8* input_data, + const RuntimeShape& output_shape, uint8* output_data) { + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); const int batches = MatchingDim(input_shape, 0, output_shape, 0); @@ -2328,20 +2409,24 @@ inline void AveragePool(const uint8* input_data, const int input_width = input_shape.Dims(2); const int output_height = output_shape.Dims(1); const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { for (int channel = 0; channel < depth; ++channel) { - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; + const int in_x_origin = + (out_x * stride_width) - params.padding_values.width; + const int in_y_origin = + (out_y * stride_height) - params.padding_values.height; // Compute the boundaries of the filter region clamped so as to // ensure that the filter window fits in the input array. const int filter_x_start = std::max(0, -in_x_origin); const int filter_x_end = - std::min(filter_width, input_width - in_x_origin); + std::min(params.filter_width, input_width - in_x_origin); const int filter_y_start = std::max(0, -in_y_origin); const int filter_y_end = - std::min(filter_height, input_height - in_y_origin); + std::min(params.filter_height, input_height - in_y_origin); int32 acc = 0; int filter_count = 0; for (int filter_y = filter_y_start; filter_y < filter_y_end; @@ -2356,8 +2441,8 @@ inline void AveragePool(const uint8* input_data, } } acc = (acc + filter_count / 2) / filter_count; - acc = std::max(acc, output_activation_min); - acc = std::min(acc, output_activation_max); + acc = std::max(acc, params.quantized_activation_min); + acc = std::min(acc, params.quantized_activation_max); output_data[Offset(output_shape, batch, out_y, out_x, channel)] = static_cast(acc); } @@ -2366,11 +2451,9 @@ inline void AveragePool(const uint8* input_data, } } -inline void L2Pool(const float* input_data, const RuntimeShape& input_shape, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - float output_activation_min, float output_activation_max, - float* output_data, const RuntimeShape& output_shape) { +inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& output_shape, + float* output_data) { TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); const int batches = MatchingDim(input_shape, 0, output_shape, 0); @@ -2379,20 +2462,24 @@ inline void L2Pool(const float* input_data, const RuntimeShape& input_shape, const int input_width = input_shape.Dims(2); const int output_height = output_shape.Dims(1); const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { for (int channel = 0; channel < depth; ++channel) { - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; + const int in_x_origin = + (out_x * stride_width) - params.padding_values.width; + const int in_y_origin = + (out_y * stride_height) - params.padding_values.height; // Compute the boundaries of the filter region clamped so as to // ensure that the filter window fits in the input array. const int filter_x_start = std::max(0, -in_x_origin); const int filter_x_end = - std::min(filter_width, input_width - in_x_origin); + std::min(params.filter_width, input_width - in_x_origin); const int filter_y_start = std::max(0, -in_y_origin); const int filter_y_end = - std::min(filter_height, input_height - in_y_origin); + std::min(params.filter_height, input_height - in_y_origin); float sum_squares = 0.f; int filter_count = 0; for (int filter_y = filter_y_start; filter_y < filter_y_end; @@ -2409,19 +2496,18 @@ inline void L2Pool(const float* input_data, const RuntimeShape& input_shape, } const float l2pool_result = std::sqrt(sum_squares / filter_count); output_data[Offset(output_shape, batch, out_y, out_x, channel)] = - ActivationFunctionWithMinMax(l2pool_result, output_activation_min, - output_activation_max); + ActivationFunctionWithMinMax(l2pool_result, + params.float_activation_min, + params.float_activation_max); } } } } } -inline void MaxPool(const float* input_data, const RuntimeShape& input_shape, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - float output_activation_min, float output_activation_max, - float* output_data, const RuntimeShape& output_shape) { +inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& output_shape, + float* output_data) { TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); const int batches = MatchingDim(input_shape, 0, output_shape, 0); @@ -2430,20 +2516,24 @@ inline void MaxPool(const float* input_data, const RuntimeShape& input_shape, const int input_width = input_shape.Dims(2); const int output_height = output_shape.Dims(1); const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { for (int channel = 0; channel < depth; ++channel) { - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; + const int in_x_origin = + (out_x * stride_width) - params.padding_values.width; + const int in_y_origin = + (out_y * stride_height) - params.padding_values.height; // Compute the boundaries of the filter region clamped so as to // ensure that the filter window fits in the input array. const int filter_x_start = std::max(0, -in_x_origin); const int filter_x_end = - std::min(filter_width, input_width - in_x_origin); + std::min(params.filter_width, input_width - in_x_origin); const int filter_y_start = std::max(0, -in_y_origin); const int filter_y_end = - std::min(filter_height, input_height - in_y_origin); + std::min(params.filter_height, input_height - in_y_origin); float max = std::numeric_limits::lowest(); for (int filter_y = filter_y_start; filter_y < filter_y_end; ++filter_y) { @@ -2457,22 +2547,21 @@ inline void MaxPool(const float* input_data, const RuntimeShape& input_shape, } } output_data[Offset(output_shape, batch, out_y, out_x, channel)] = - ActivationFunctionWithMinMax(max, output_activation_min, - output_activation_max); + ActivationFunctionWithMinMax(max, params.float_activation_min, + params.float_activation_max); } } } } } -inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const RuntimeShape& output_shape) { - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - TFLITE_DCHECK_GE(output_activation_min, 0); - TFLITE_DCHECK_LE(output_activation_max, 255); +inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, + const uint8* input_data, const RuntimeShape& output_shape, + uint8* output_data) { + TFLITE_DCHECK_LE(params.quantized_activation_min, + params.quantized_activation_max); + TFLITE_DCHECK_GE(params.quantized_activation_min, 0); + TFLITE_DCHECK_LE(params.quantized_activation_max, 255); TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); const int batches = MatchingDim(input_shape, 0, output_shape, 0); @@ -2481,20 +2570,24 @@ inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape, const int input_width = input_shape.Dims(2); const int output_height = output_shape.Dims(1); const int output_width = output_shape.Dims(2); + const int stride_height = params.stride_height; + const int stride_width = params.stride_width; for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { for (int channel = 0; channel < depth; ++channel) { - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; + const int in_x_origin = + (out_x * stride_width) - params.padding_values.width; + const int in_y_origin = + (out_y * stride_height) - params.padding_values.height; // Compute the boundaries of the filter region clamped so as to // ensure that the filter window fits in the input array. const int filter_x_start = std::max(0, -in_x_origin); const int filter_x_end = - std::min(filter_width, input_width - in_x_origin); + std::min(params.filter_width, input_width - in_x_origin); const int filter_y_start = std::max(0, -in_y_origin); const int filter_y_end = - std::min(filter_height, input_height - in_y_origin); + std::min(params.filter_height, input_height - in_y_origin); uint8 max = 0; for (int filter_y = filter_y_start; filter_y < filter_y_end; ++filter_y) { @@ -2507,8 +2600,8 @@ inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape, input_data[Offset(input_shape, batch, in_y, in_x, channel)]); } } - max = std::max(max, output_activation_min); - max = std::min(max, output_activation_max); + max = std::max(max, params.quantized_activation_min); + max = std::min(max, params.quantized_activation_max); output_data[Offset(output_shape, batch, out_y, out_x, channel)] = static_cast(max); } @@ -3063,18 +3156,9 @@ inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, float nudged_min, nudged_max, nudged_scale; NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min, &nudged_max, &nudged_scale); - const float inv_nudged_scale = 1.0f / nudged_scale; - const int flat_size = MatchingFlatSize(output_dims, input_dims); - for (int i = 0; i < flat_size; i++) { - const float src_val = input_data[i]; - const float clamped = std::min(nudged_max, std::max(nudged_min, src_val)); - const float clamped_shifted = clamped - nudged_min; - const float dst_val = - TfLiteRound(clamped_shifted * inv_nudged_scale) * nudged_scale + - nudged_min; - output_data[i] = dst_val; - } + FakeQuantizeArray(nudged_scale, nudged_min, nudged_max, input_data, + output_data, flat_size); } template @@ -3192,7 +3276,8 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, const Dims<4>& block_shape_dims, const int32* paddings_data, const Dims<4>& paddings_dims, T* output_data, - const Dims<4>& output_dims) { + const Dims<4>& output_dims, + const int32_t pad_value) { const int output_batch_size = ArraySize(output_dims, 3); const int output_height = ArraySize(output_dims, 2); const int output_width = ArraySize(output_dims, 1); @@ -3217,7 +3302,7 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, padding_top + input_height || out_w * block_shape_width + shift_w < padding_left || out_w * block_shape_width + shift_w >= padding_left + input_width) { - memset(out, 0, depth * sizeof(T)); + memset(out, pad_value, depth * sizeof(T)); } else { const T* in = input_data + @@ -3232,6 +3317,17 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, } } +template +inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, + const int32* paddings_data, + const Dims<4>& paddings_dims, T* output_data, + const Dims<4>& output_dims) { + SpaceToBatchND(input_data, input_dims, block_shape_data, block_shape_dims, + paddings_data, paddings_dims, output_data, output_dims, 0); +} + template inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, const int32* block_shape_data, @@ -3442,9 +3538,9 @@ inline bool Reduce(const In* input_data, const int* input_dims, const int* output_dims, const int input_num_dims, const int output_num_dims, const int* axis, const int num_axis, int* input_iter, - Out reducer(Out current, const In in), Out* output_data) { + Out reducer(const Out current, const In in), + Out* output_data) { // Reset input iterator. - TFLITE_DCHECK(input_num_dims > 0); for (int idx = 0; idx < input_num_dims; ++idx) { input_iter[idx] = 0; } @@ -3460,11 +3556,16 @@ inline bool Reduce(const In* input_data, const int* input_dims, return true; } -inline bool ResolveAxis(const int num_dims, const int* axis, const int num_axis, - int* out_axis, int* out_num_axis) { +inline bool ResolveAxis(const int num_dims, const int* axis, + const int64_t num_axis, int* out_axis, + int* out_num_axis) { *out_num_axis = 0; // Just in case. + // Short-circuit axis resolution for scalars; the axis will go unused. + if (num_dims == 0) { + return true; + } // o(n^2) is fine since out_num_axis should be really small, mostly <= 4 - for (int idx = 0; idx < num_axis; ++idx) { + for (int64_t idx = 0; idx < num_axis; ++idx) { // Handle negative index. int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx]; TFLITE_DCHECK(current >= 0 && current < num_dims); @@ -3490,7 +3591,7 @@ inline bool ReduceSumImpl(const In* input_data, const int* input_dims, const int output_num_dims, const int* axis, const int num_axis, int* input_iter, Out* output_data) { - auto reducer = [](Out current, const In in) -> Out { + auto reducer = [](const Out current, const In in) -> Out { const Out actual_in = static_cast(in); return current + actual_in; }; @@ -3499,6 +3600,24 @@ inline bool ReduceSumImpl(const In* input_data, const int* input_dims, output_data); } +template +inline bool InitTensorDataForReduce(const int* dims, const int num_dims, + const T init_value, T* data) { + size_t num_elements = 1; + for (int idx = 0; idx < num_dims; ++idx) { + size_t current = static_cast(dims[idx]); + // Overflow prevention. + if (num_elements > std::numeric_limits::max() / current) { + return false; + } + num_elements *= current; + } + for (size_t idx = 0; idx < num_elements; ++idx) { + data[idx] = init_value; + } + return true; +} + // Computes the sum of elements across dimensions given in axis. template inline bool Sum(const T* input_data, const int* input_dims, @@ -3507,17 +3626,9 @@ inline bool Sum(const T* input_data, const int* input_dims, const int* axis, const int num_axis_dimensions, bool keep_dims, int* temp_index, int* resolved_axis) { // Reset output data. - size_t num_outputs = 1; - for (int idx = 0; idx < output_num_dims; ++idx) { - size_t current = static_cast(output_dims[idx]); - // Overflow prevention. - if (num_outputs > std::numeric_limits::max() / current) { - return false; - } - num_outputs *= current; - } - for (size_t idx = 0; idx < num_outputs; ++idx) { - output_data[idx] = T(); + if (!InitTensorDataForReduce(output_dims, output_num_dims, static_cast(0), + output_data)) { + return false; } // Resolve axis. @@ -3532,6 +3643,61 @@ inline bool Sum(const T* input_data, const int* input_dims, num_resolved_axis, temp_index, output_data); } +// Computes the max of elements across dimensions given in axis. +template +inline bool ReduceMax(const T* input_data, const int* input_dims, + const int input_num_dims, T* output_data, + const int* output_dims, const int output_num_dims, + const int* axis, const int64_t num_axis_dimensions, + bool keep_dims, int* temp_index, int* resolved_axis) { + T init_value = std::numeric_limits::lowest(); + // Reset output data. + if (!InitTensorDataForReduce(output_dims, output_num_dims, init_value, + output_data)) { + return false; + } + + // Resolve axis. + int num_resolved_axis = 0; + if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis, + &num_resolved_axis)) { + return false; + } + + auto reducer = [](const T current, const T in) -> T { + return (in > current) ? in : current; + }; + return Reduce(input_data, input_dims, output_dims, input_num_dims, + output_num_dims, resolved_axis, num_resolved_axis, + temp_index, reducer, output_data); +} + +// Computes the prod of elements across dimensions given in axis. +template +inline bool ReduceProd(const T* input_data, const int* input_dims, + const int input_num_dims, T* output_data, + const int* output_dims, const int output_num_dims, + const int* axis, const int64_t num_axis_dimensions, + bool keep_dims, int* temp_index, int* resolved_axis) { + // Reset output data. + if (!InitTensorDataForReduce(output_dims, output_num_dims, static_cast(1), + output_data)) { + return false; + } + + // Resolve axis. + int num_resolved_axis = 0; + if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis, + &num_resolved_axis)) { + return false; + } + + auto reducer = [](const T current, const T in) -> T { return in * current; }; + return Reduce(input_data, input_dims, output_dims, input_num_dims, + output_num_dims, resolved_axis, num_resolved_axis, + temp_index, reducer, output_data); +} + // Computes the mean of elements across dimensions given in axis. // It does so in two stages, first calculates the sum of elements along the axis // then divides it by the number of element in axis. @@ -3623,38 +3789,6 @@ inline void Mean(const T* input_data, const Dims<4>& input_dims, } } -template -void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, - const Dims<4>& input2_dims, T* output_data, - const Dims<4>& output_dims) { - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); - - // In Tensorflow, the dimensions are canonically named (batch_number, row, - // col, channel), with extents (batches, height, width, depth), with the - // trailing dimension changing most rapidly (channels has the smallest stride, - // typically 1 element). - // - // In generated C code, we store arrays with the dimensions reversed. The - // first dimension has smallest stride. - // - // We name our variables by their Tensorflow convention, but generate C code - // nesting loops such that the innermost loop has the smallest stride for the - // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = - input1_data[SubscriptToIndex(desc1, c, x, y, b)] - - input2_data[SubscriptToIndex(desc2, c, x, y, b)]; - } - } - } - } -} - template void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, T* output_data, @@ -3704,9 +3838,9 @@ void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims, } } -template -void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, - T2* output_data, const Dims<4>& output_dims) { +template +void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, + T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) { // The current ArgMax implemention can only determine the index of the maximum // value in the last dimension. So the axis argument is ignored. @@ -3719,19 +3853,28 @@ void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, const int depth = ArraySize(input_dims, 0); for (int i = 0; i < outer_size; ++i) { - auto max_value = input_data[i * depth]; - int max_index = 0; + auto min_max_value = input_data[i * depth]; + int min_max_index = 0; for (int d = 1; d < depth; ++d) { const auto& curr_value = input_data[i * depth + d]; - if (curr_value > max_value) { - max_value = curr_value; - max_index = d; + if (cmp(curr_value, min_max_value)) { + min_max_value = curr_value; + min_max_index = d; } } - output_data[i] = max_index; + output_data[i] = min_max_index; } } +// TODO(renjieliu): Remove this one. +template +void ArgMax(const T3* axis, const T1* input_data, + const tflite::Dims<4>& input_dims, T2* output_data, + const tflite::Dims<4>& output_dims) { + ArgMinMax(axis, input_data, input_dims, output_data, output_dims, + std::greater()); +} + template void Transpose(const T* input, const Dims<4>& input_dims, T* output, const Dims<4>& output_dims, const int* permuted_axes) { @@ -4039,8 +4182,8 @@ inline void RankOneSelect(const D* input_condition_data, } // For easy implementation, the indices is always a vector of size-4 vectors. -template -inline void SparseToDense(const std::vector>& indices, +template +inline void SparseToDense(const std::vector>& indices, const T* values, T default_value, T* output_data, const Dims<4>& output_dims, bool value_is_scalar) { const int value_count = indices.size(); @@ -4055,7 +4198,7 @@ inline void SparseToDense(const std::vector>& indices, // condition within the loop every time. if (value_is_scalar) { for (int i = 0; i < value_count; ++i) { - const std::vector& index = indices[i]; + const std::vector& index = indices[i]; TFLITE_DCHECK_EQ(index.size(), 4); const T value = *values; // just use the first value. output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] = @@ -4066,7 +4209,7 @@ inline void SparseToDense(const std::vector>& indices, // Go through the values and indices to fill the sparse values. for (int i = 0; i < value_count; ++i) { - const std::vector& index = indices[i]; + const std::vector& index = indices[i]; TFLITE_DCHECK_EQ(index.size(), 4); const T value = values[i]; output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] = @@ -4104,6 +4247,65 @@ inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims, } } +inline void Logical(const bool* input1_data, const Dims<4>& input1_dims, + const bool* input2_data, const Dims<4>& input2_dims, + bool* output_data, const Dims<4>& output_dims, + const std::function& func) { + const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = func(input1_data[i], input2_data[i]); + } +} + +inline void BroadcastLogical(const bool* input1_data, + const Dims<4>& input1_dims, + const bool* input2_data, + const Dims<4>& input2_dims, bool* output_data, + const Dims<4>& output_dims, + const std::function& func) { + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + func(input1_data[SubscriptToIndex(desc1, c, x, y, b)], + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + +// TODO(ycling): Refactoring. Remove BroadcastLogical and use the more +// generalized and efficient BroadcastBinaryFunction. +// +// R: Result type. T1: Input 1 type. T2: Input 2 type. +template +inline void BroadcastBinaryFunction(const T1* input1_data, + const Dims<4>& input1_dims, + const T2* input2_data, + const Dims<4>& input2_dims, R* output_data, + const Dims<4>& output_dims, + R (*func)(T1, T2)) { + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + func(input1_data[SubscriptToIndex(desc1, c, x, y, b)], + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + } // namespace reference_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/spectrogram.cc b/tensorflow/contrib/lite/kernels/internal/spectrogram.cc index 4eddf7bf0a2cbca695dae20ba8ba56a9cd72e4ba..20abcb725859d03f83c969369bddf1429895e0ba 100644 --- a/tensorflow/contrib/lite/kernels/internal/spectrogram.cc +++ b/tensorflow/contrib/lite/kernels/internal/spectrogram.cc @@ -43,13 +43,13 @@ bool Spectrogram::Initialize(int window_length, int step_length) { return Initialize(window, step_length); } -inline int Log2Floor(uint n) { +inline int Log2Floor(uint32_t n) { if (n == 0) return -1; int log = 0; - uint value = n; + uint32_t value = n; for (int i = 4; i >= 0; --i) { int shift = (1 << i); - uint x = value >> shift; + uint32_t x = value >> shift; if (x != 0) { value = x; log += shift; @@ -58,7 +58,7 @@ inline int Log2Floor(uint n) { return log; } -inline int Log2Ceiling(uint n) { +inline int Log2Ceiling(uint32_t n) { int floor = Log2Floor(n); if (n == (n & ~(n - 1))) // zero or a power of two return floor; @@ -66,7 +66,7 @@ inline int Log2Ceiling(uint n) { return floor + 1; } -inline uint NextPowerOfTwo(uint value) { +inline uint32_t NextPowerOfTwo(uint32_t value) { int exponent = Log2Ceiling(value); // DCHECK_LT(exponent, std::numeric_limits::digits); return 1 << exponent; diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h index 5160e22307ae0894fabd0e9c4f7b9cd38b00840e..1ff8cfe39c9aed7e9241815dc8eff7ab4d9fd585 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h @@ -17,6 +17,10 @@ limitations under the License. #include "tensorflow/contrib/lite/builtin_op_data.h" +#if defined(_MSC_VER) +#define __restrict__ __restrict +#endif + namespace tflite { namespace tensor_utils { @@ -31,8 +35,8 @@ bool IsZeroVector(const float* vector, int v_size); // It also outputs the range (min, max) of the floating point buffer, and the // scaling factor used to quantize the values. void SymmetricQuantizeFloats(const float* values, const int size, - int8_t* quantized_values, float* min, float* max, - float* scaling_factor); + int8_t* quantized_values, float* min_value, + float* max_value, float* scaling_factor); // Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch // dimension composed by input vectors independent from each other). The result @@ -124,6 +128,10 @@ void Sub1Vector(const float* vector, int v_size, float* result); // Fill vector with 0.f. void ZeroVector(float* vector, int v_size); +// Multiply all elements of vector with a scalar. +void VectorScalarMultiply(const int8_t* vector, int v_size, float scale, + float* result); + // Clip elements of a vector using a abs_limit value. void ClipVector(const float* vector, int v_size, float abs_limit, float* result); diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc index 14ee528394b6872d9e79969db0e431658277f56b..372a6efec5c69e53d558edf8c822f638a4d33d81 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc @@ -32,6 +32,22 @@ TEST(uKernels, ClipTest) { {0.0, -0.5, 1.0, -1.5, 2.0, -2.0, 2.0, -2.0, 2.0, -2.0}))); } +TEST(uKernels, VectorScalarMultiply) { + constexpr int kVectorSize = 29; + static int8_t input[kVectorSize]; + for (int i = 0; i < 29; ++i) { + input[i] = static_cast(i - 14); + } + const float scale = 0.1f; + std::vector output(kVectorSize, 0.0f); + VectorScalarMultiply(input, kVectorSize, scale, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear( + {-1.4, -1.3, -1.2, -1.1, -1.0, -0.9, -0.8, -0.7, -0.6, -0.5, + -0.4, -0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3, 0.4, 0.5, + 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4}))); +} + TEST(uKernels, IsZeroTest) { constexpr int kVectorSize = 21; static float zeros[kVectorSize] = {0.0}; @@ -63,7 +79,8 @@ TEST(uKernels, SymmetricQuantizeFloatsTest) { EXPECT_EQ(min, -640); EXPECT_EQ(max, 1000); - EXPECT_NEAR(scaling_factor, 0.127, 1e-6); // EQ won't work due to fpoint. + // EQ won't work due to fpoint. + EXPECT_NEAR(scaling_factor, 1000 / 127.0, 1e-6); EXPECT_THAT(output, testing::ElementsAreArray({-81, -81, -80, 1, 0, -1, -1, 0, 127})); } @@ -95,7 +112,7 @@ TEST(uKernels, SymmetricQuantizeFloatsAllAlmostZeroTest) { EXPECT_NEAR(min, -9e-05, 1e-6); EXPECT_NEAR(max, 0.0002, 1e-6); - EXPECT_EQ(scaling_factor, 635000); + EXPECT_NEAR(scaling_factor, 1.57e-6, 1e-6); EXPECT_THAT(output, testing::ElementsAreArray({-6, 19, -4, -57, 1, 25, 6, 127, 0})); } diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index fa2420713fea4faa3596251a95c2ed9606878b98..c44698b677a862bc41c947ea46fe204710b79668 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -23,7 +23,12 @@ limitations under the License. namespace tflite { enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu }; -enum class PaddingType { kNone, kSame, kValid }; +enum class PaddingType : uint8 { kNone, kSame, kValid }; + +struct PaddingValues { + int8 width; + int8 height; +}; // This enumeration allows for non-default formats for the weights array // of a fully-connected operator, allowing the use of special optimized @@ -114,6 +119,8 @@ class RuntimeShape { // larger shapes are separately allocated. static constexpr int kMaxSmallSize = 4; + RuntimeShape& operator=(RuntimeShape const&) = delete; + RuntimeShape() : size_(0) {} explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) { @@ -130,6 +137,20 @@ class RuntimeShape { BuildFrom(init_list); } + // Avoid using this constructor. We should be able to delete it when C++17 + // rolls out. + RuntimeShape(RuntimeShape const& other) : size_(other.DimensionsCount()) { + if (size_ > kMaxSmallSize) { + dims_pointer_ = new int32[size_]; + } + std::memcpy(DimsData(), other.DimsData(), sizeof(int32) * size_); + } + + bool operator==(const RuntimeShape& comp) const { + return this->size_ == comp.size_ && + std::memcmp(DimsData(), comp.DimsData(), size_ * sizeof(int32)) == 0; + } + ~RuntimeShape() { if (size_ > kMaxSmallSize) { delete[] dims_pointer_; @@ -186,6 +207,16 @@ class RuntimeShape { } } + // This will probably be factored out. Old code made substantial use of 4-D + // shapes, and so this function is used to extend smaller shapes. Note that + // (a) as Dims<4>-dependent code is eliminated, the reliance on this should be + // reduced, and (b) some kernels are stricly 4-D, but then the shapes of their + // inputs should already be 4-D, so this function should not be needed. + inline static RuntimeShape ExtendedShape(int new_shape_size, + const RuntimeShape& shape) { + return RuntimeShape(new_shape_size, shape, 1); + } + inline void BuildFrom(const std::initializer_list init_list) { BuildFrom>(init_list); } @@ -203,7 +234,25 @@ class RuntimeShape { return buffer_size; } + bool operator!=(const RuntimeShape& comp) const { return !((*this) == comp); } + private: + // For use only by ExtendFrom(), written to guarantee (return-value) copy + // elision in C++17. + // This creates a shape padded to the desired size with the specified value. + RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value) + : size_(0) { + TFLITE_CHECK_GE(new_shape_size, shape.DimensionsCount()); + TFLITE_CHECK_LE(new_shape_size, kMaxSmallSize); + Resize(new_shape_size); + const int size_increase = new_shape_size - shape.DimensionsCount(); + for (int i = 0; i < size_increase; ++i) { + SetDim(i, pad_value); + } + std::memcpy(DimsData() + size_increase, shape.DimsData(), + sizeof(int32) * shape.DimensionsCount()); + } + int32 size_; union { int32 dims_[kMaxSmallSize]; @@ -229,7 +278,9 @@ inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) { // Gets next index to iterate through a multidimensional array. inline bool NextIndex(const int num_dims, const int* dims, int* current) { - TFLITE_DCHECK_GT(num_dims, 0); + if (num_dims == 0) { + return false; + } TFLITE_DCHECK(dims != nullptr); TFLITE_DCHECK(current != nullptr); int carry = 1; @@ -256,7 +307,9 @@ inline bool NextIndex(const int num_dims, const int* dims, int* current) { inline size_t ReducedOutputOffset(const int num_dims, const int* dims, const int* index, const int num_axis, const int* axis) { - TFLITE_DCHECK_GT(num_dims, 0); + if (num_dims == 0) { + return 0; + } TFLITE_DCHECK(dims != nullptr); TFLITE_DCHECK(index != nullptr); size_t offset = 0; @@ -359,6 +412,7 @@ inline int RequiredBufferSizeForDims(const Dims<4>& dims) { // arrays. inline int MatchingFlatSize(const RuntimeShape& shape, const RuntimeShape& check_shape_0) { + TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount()); const int dims_count = shape.DimensionsCount(); for (int i = 0; i < dims_count; ++i) { TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); @@ -369,6 +423,7 @@ inline int MatchingFlatSize(const RuntimeShape& shape, inline int MatchingFlatSize(const RuntimeShape& shape, const RuntimeShape& check_shape_0, const RuntimeShape& check_shape_1) { + TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount()); const int dims_count = shape.DimensionsCount(); for (int i = 0; i < dims_count; ++i) { TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); @@ -380,6 +435,7 @@ inline int MatchingFlatSize(const RuntimeShape& shape, const RuntimeShape& check_shape_0, const RuntimeShape& check_shape_1, const RuntimeShape& check_shape_2) { + TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount()); const int dims_count = shape.DimensionsCount(); for (int i = 0; i < dims_count; ++i) { TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); @@ -392,6 +448,7 @@ inline int MatchingFlatSize(const RuntimeShape& shape, const RuntimeShape& check_shape_1, const RuntimeShape& check_shape_2, const RuntimeShape& check_shape_3) { + TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount()); const int dims_count = shape.DimensionsCount(); for (int i = 0; i < dims_count; ++i) { TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); @@ -588,6 +645,82 @@ void ComputeStrides(Dims* dims) { } } +struct PoolParams { + FusedActivationFunctionType activation; + PaddingType padding_type; + PaddingValues padding_values; + int stride_height; + int stride_width; + int filter_height; + int filter_width; + // uint8, etc, activation params. + int32 quantized_activation_min; + int32 quantized_activation_max; + // float activation params. + float float_activation_min; + float float_activation_max; +}; + +enum class BroadcastableOpCategory : uint8 { + kNone, + kNonBroadcast, // Matching input shapes. + kFirstInputBroadcastsFast, // Fivefold nested loops. + kSecondInputBroadcastsFast, // Fivefold nested loops. + kGenericBroadcast, // Fall-back. +}; + +// For Add, Sub, Mul ops. +struct ArithmeticParams { + // Shape dependent / common to data / op types. + BroadcastableOpCategory broadcast_category; + // uint8 inference params. + int32 input1_offset; + int32 input2_offset; + int32 output_offset; + int32 output_multiplier; + int output_shift; + // Add / Sub, not Mul, uint8 inference params. + int left_shift; + int32 input1_multiplier; + int input1_shift; + int32 input2_multiplier; + int input2_shift; + // uint8, etc, activation params. + int32 quantized_activation_min; + int32 quantized_activation_max; + // float activation params. + float float_activation_min; + float float_activation_max; + + // Processed output dimensions. + // Let input "a" be the one that broadcasts in the faster-changing dimension. + // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and + // {b0, b1, b2, b3, b4}, + // broadcast_shape[4] = b0 = a0. + // broadcast_shape[3] = b1; a1 = 1. + // broadcast_shape[2] = b2 = a2. + // broadcast_shape[1] = a3; b3 = 1. + // broadcast_shape[0] = b4 = a4. + int broadcast_shape[5]; +}; + +template +inline void SetActivationParams(T min, T max, ArithmeticParams* params); + +template <> +inline void SetActivationParams(float min, float max, + ArithmeticParams* params) { + params->float_activation_min = min; + params->float_activation_max = max; +} + +template <> +inline void SetActivationParams(int32 min, int32 max, + ArithmeticParams* params) { + params->quantized_activation_min = min; + params->quantized_activation_max = max; +} + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ diff --git a/tensorflow/contrib/lite/kernels/logical.cc b/tensorflow/contrib/lite/kernels/logical.cc new file mode 100644 index 0000000000000000000000000000000000000000..87c2fee667ccaf7bfdc4e2316309d2abc35b5324 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/logical.cc @@ -0,0 +1,134 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace logical { +namespace { + +// Input/output tensor index. +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +// Op data for logical op. +struct OpData { + bool requires_broadcast; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + data->requires_broadcast = false; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + // Reinterprete the opaque data provided by user. + OpData* data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, input1->type, input2->type); + + const TfLiteType type = input1->type; + if (type != kTfLiteBool) { + context->ReportError(context, "Logical ops only support bool type."); + return kTfLiteError; + } + output->type = type; + + data->requires_broadcast = !HaveSameShapes(input1, input2); + + TfLiteIntArray* output_size = nullptr; + if (data->requires_broadcast) { + TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast( + context, input1, input2, &output_size)); + } else { + output_size = TfLiteIntArrayCopy(input1->dims); + } + + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node, + const std::function& func) { + OpData* data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (data->requires_broadcast) { + reference_ops::BroadcastLogical( + GetTensorData(input1), GetTensorDims(input1), + GetTensorData(input2), GetTensorDims(input2), + GetTensorData(output), GetTensorDims(output), func); + } else { + reference_ops::Logical(GetTensorData(input1), GetTensorDims(input1), + GetTensorData(input2), GetTensorDims(input2), + GetTensorData(output), GetTensorDims(output), + func); + } + + return kTfLiteOk; +} + +TfLiteStatus LogicalOrEval(TfLiteContext* context, TfLiteNode* node) { + const auto logical_or_func = std::logical_or(); + return LogicalImpl(context, node, logical_or_func); +} + +TfLiteStatus LogicalAndEval(TfLiteContext* context, TfLiteNode* node) { + const auto logical_and_func = std::logical_and(); + return LogicalImpl(context, node, logical_and_func); +} + +} // namespace +} // namespace logical + +TfLiteRegistration* Register_LOGICAL_OR() { + // Init, Free, Prepare, Eval are satisfying the Interface required by + // TfLiteRegistration. + static TfLiteRegistration r = {logical::Init, logical::Free, logical::Prepare, + logical::LogicalOrEval}; + return &r; +} + +TfLiteRegistration* Register_LOGICAL_AND() { + // Init, Free, Prepare, Eval are satisfying the Interface required by + // TfLiteRegistration. + static TfLiteRegistration r = {logical::Init, logical::Free, logical::Prepare, + logical::LogicalAndEval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/logical_test.cc b/tensorflow/contrib/lite/kernels/logical_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..206cbde98fa48ec5f7c863bbced9dccc9cab5207 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/logical_test.cc @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; + +class LogicalOpModel : public SingleOpModel { + public: + LogicalOpModel(std::initializer_list input1_shape, + std::initializer_list input2_shape, BuiltinOperator op) { + input1_ = AddInput(TensorType_BOOL); + input2_ = AddInput(TensorType_BOOL); + output_ = AddOutput(TensorType_BOOL); + ConfigureBuiltinOp(op); + BuildInterpreter({input1_shape, input2_shape}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input1_; + int input2_; + int output_; + + void ConfigureBuiltinOp(BuiltinOperator op) { + switch (op) { + case BuiltinOperator_LOGICAL_OR: { + SetBuiltinOp(op, BuiltinOptions_LogicalOrOptions, + CreateLogicalOrOptions(builder_).Union()); + break; + } + case BuiltinOperator_LOGICAL_AND: { + SetBuiltinOp(op, BuiltinOptions_LogicalAndOptions, + CreateLogicalAndOptions(builder_).Union()); + break; + } + default: { FAIL() << "We shouldn't get here."; } + } + } +}; + +TEST(LogicalTest, LogicalOr) { + LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, BuiltinOperator_LOGICAL_OR); + model.PopulateTensor(model.input1(), {true, false, false, true}); + model.PopulateTensor(model.input2(), {true, false, true, false}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(LogicalTest, BroadcastLogicalOr) { + LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, BuiltinOperator_LOGICAL_OR); + model.PopulateTensor(model.input1(), {true, false, false, true}); + model.PopulateTensor(model.input2(), {false}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(LogicalTest, LogicalAnd) { + LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, BuiltinOperator_LOGICAL_AND); + model.PopulateTensor(model.input1(), {true, false, false, true}); + model.PopulateTensor(model.input2(), {true, false, true, false}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(LogicalTest, BroadcastLogicalAnd) { + LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, BuiltinOperator_LOGICAL_AND); + model.PopulateTensor(model.input1(), {true, false, false, true}); + model.PopulateTensor(model.input2(), {true}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/lsh_projection.cc b/tensorflow/contrib/lite/kernels/lsh_projection.cc index 25d2dc2cdd699b4d9c8e83eb848fce0df3c59c15..69523b02cce0547fe87873e924deabb50cbeb4e5 100644 --- a/tensorflow/contrib/lite/kernels/lsh_projection.cc +++ b/tensorflow/contrib/lite/kernels/lsh_projection.cc @@ -50,7 +50,6 @@ limitations under the License. // Output.Dim == { Tensor[0].Dim[0] * Tensor[0].Dim[1] } // A flattened tensor represents projected bit vectors. -#include #include #include #include diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index 3577ae6caa1e02ce2e5db2e8054ba9c2fccbe93e..ba251c451e549a09d265fc43fed7dc7eb6896d61 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -97,7 +96,7 @@ constexpr int kCellStateTensor = 1; constexpr int kOutputTensor = 2; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* op_data = new OpData; + auto* op_data = new OpData(); op_data->kernel_type = kTfLiteLSTMFullKernel; context->AddTensors(context, /*tensors_to_add=*/7, &op_data->scratch_tensor_index); @@ -306,7 +305,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int n_output = recurrent_to_output_weights->dims->data[1]; // Check that input tensor dimensions matches with each other. - CheckInputTensorDimensions(context, node, n_input, n_output, n_cell); + TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input, + n_output, n_cell)); // Get the pointer to output, activation_state and cell_state tensors. TfLiteTensor* output = GetOutput(context, node, kOutputTensor); @@ -846,7 +846,7 @@ enum OutputTensor { }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* op_data = new OpData; + auto* op_data = new OpData(); op_data->kernel_type = kTfLiteLSTMBasicKernel; // `scratch_tensor_index` is unused in this kernel. op_data->scratch_tensor_index = -1; diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc index 0b7c56133e3cbb3d85f75657b6141620a8019e61..0266f5fe57e6c60ea19ad5f8de05e879e7da9304 100644 --- a/tensorflow/contrib/lite/kernels/lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/lstm_test.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // Unit test for TFLite LSTM op. +// +// TODO(alanchiao): add unit test with invalid input dimensions for this and its +// variants. #include #include diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc index 1f72f3a3c7af4f9e042c9b2ac09252fab5de1a4f..349f3e672611b76ba9eb0019bbd55a5881ed6535 100644 --- a/tensorflow/contrib/lite/kernels/mul.cc +++ b/tensorflow/contrib/lite/kernels/mul.cc @@ -100,29 +100,44 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } template -void EvalFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteMulParams* params, const OpData* data, - const TfLiteTensor* input1, const TfLiteTensor* input2, - TfLiteTensor* output) { - float output_activation_min, output_activation_max; - CalculateActivationRange(params->activation, &output_activation_min, - &output_activation_max); -#define TF_LITE_MUL(type, opname) \ - type::opname(GetTensorData(input1), GetTensorDims(input1), \ - GetTensorData(input2), GetTensorDims(input2), \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)) - if (kernel_type == kReference) { - if (data->requires_broadcast) { - TF_LITE_MUL(reference_ops, BroadcastMul); +void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params, + const OpData* data, const TfLiteTensor* input1, + const TfLiteTensor* input2, TfLiteTensor* output) { +#define TF_LITE_MUL(type, opname, data_type) \ + data_type output_activation_min, output_activation_max; \ + CalculateActivationRange(params->activation, &output_activation_min, \ + &output_activation_max); \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (output->type == kTfLiteInt32) { + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_MUL(reference_ops, BroadcastMul, int32_t); + } else { + TF_LITE_MUL(reference_ops, Mul, int32_t); + } } else { - TF_LITE_MUL(reference_ops, Mul); + if (data->requires_broadcast) { + TF_LITE_MUL(optimized_ops, BroadcastMul, int32_t); + } else { + TF_LITE_MUL(optimized_ops, Mul, int32_t); + } } - } else { - if (data->requires_broadcast) { - TF_LITE_MUL(optimized_ops, BroadcastMul); + } else if (output->type == kTfLiteFloat32) { + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_MUL(reference_ops, BroadcastMul, float); + } else { + TF_LITE_MUL(reference_ops, Mul, float); + } } else { - TF_LITE_MUL(optimized_ops, Mul); + if (data->requires_broadcast) { + TF_LITE_MUL(optimized_ops, BroadcastMul, float); + } else { + TF_LITE_MUL(optimized_ops, Mul, float); + } } } #undef TF_LITE_MUL @@ -194,17 +209,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - if (output->type == kTfLiteFloat32) { - EvalFloat(context, node, params, data, input1, input2, output); + if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) { + EvalMul(context, node, params, data, input1, input2, output); } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) { TF_LITE_ENSURE_OK( context, EvalQuantized(context, node, params, data, input1, input2, output)); } else { - context->ReportError( - context, - "Mul only supports FLOAT32 and quantized UINT8 and INT16 now, got %d.", - output->type); + context->ReportError(context, + "Mul only supports FLOAT32, INT32 and quantized UINT8 " + "and INT16 now, got %d.", + output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/contrib/lite/kernels/mul_test.cc index 43d56e50d2686ff2624f36a0c5d8e43279a572cc..2807550a6b07f3f9f1f1e3f72acc9882c76d166a 100644 --- a/tensorflow/contrib/lite/kernels/mul_test.cc +++ b/tensorflow/contrib/lite/kernels/mul_test.cc @@ -52,6 +52,13 @@ class FloatMulOpModel : public BaseMulOpModel { std::vector GetOutput() { return ExtractVector(output_); } }; +class IntegerMulOpModel : public BaseMulOpModel { + public: + using BaseMulOpModel::BaseMulOpModel; + + std::vector GetOutput() { return ExtractVector(output_); } +}; + // For quantized Mul, the error shouldn't exceed (2*step + step^2). // The param min=-1.0 & max=1.0 is used in the following tests. // The tolerance value is ~0.0157. @@ -133,6 +140,57 @@ TEST(FloatMulOpTest, WithBroadcast) { } } +TEST(IntegerMulOpTest, NoActivation) { + IntegerMulOpModel m({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}}, + ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8}); + m.PopulateTensor(m.input2(), {1, 2, 3, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-20, 4, 21, 40})); +} + +TEST(IntegerMulOpTest, ActivationRELU_N1_TO_1) { + IntegerMulOpModel m({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}}, + ActivationFunctionType_RELU_N1_TO_1); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8}); + m.PopulateTensor(m.input2(), {1, 2, 3, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 1, 1, 1})); +} + +TEST(IntegerMulOpTest, VariousInputShapes) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + IntegerMulOpModel m({TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8, 11, 20}); + m.PopulateTensor(m.input2(), {1, 2, 3, 5, 11, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-20, 4, 21, 40, 121, 20})) + << "With shape number " << i; + } +} + +TEST(IntegerMulOpTest, WithBroadcast) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + IntegerMulOpModel m({TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, {}}, // always a scalar + {TensorType_INT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8, 11, 20}); + m.PopulateTensor(m.input2(), {1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-20, 2, 7, 8, 11, 20}))) + << "With shape number " << i; + } +} + TEST(QuantizedMulOpTest, NoActivation) { QuantizedMulOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, diff --git a/tensorflow/contrib/lite/kernels/one_hot.cc b/tensorflow/contrib/lite/kernels/one_hot.cc new file mode 100644 index 0000000000000000000000000000000000000000..9ff3dca932d4284321b299cfc79571c43fce7155 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/one_hot.cc @@ -0,0 +1,199 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace one_hot { + +constexpr int kIndicesTensor = 0; +constexpr int kDepthTensor = 1; +constexpr int kOnValueTensor = 2; +constexpr int kOffValueTensor = 3; +constexpr int kOutputTensor = 0; + +// Convenience utility for destructuring a node into the appropriate tensors and +// data for the op. Note that this destructuring is quite cheap, so we can avoid +// allocating op-specific, persistent data on the heap. +struct OneHotContext { + OneHotContext(TfLiteContext* context, TfLiteNode* node) { + indices = GetInput(context, node, kIndicesTensor); + depth = GetInput(context, node, kDepthTensor); + on_value = GetInput(context, node, kOnValueTensor); + off_value = GetInput(context, node, kOffValueTensor); + output = GetOutput(context, node, kOutputTensor); + + const auto* params = + reinterpret_cast(node->builtin_data); + const int indices_dims = indices->dims->size; + axis = (params->axis == -1) ? indices_dims : params->axis; + output_dims = indices_dims + 1; + dtype = on_value->type; + } + + const TfLiteTensor* indices; + const TfLiteTensor* depth; + const TfLiteTensor* on_value; + const TfLiteTensor* off_value; + TfLiteTensor* output; + int axis; + int output_dims; + TfLiteType dtype; +}; + +template +void OneHotComputeImpl(const OneHotContext& op_context) { + // prefix_dim_size == # of elements before the axis + // depth == # of elements per axis + // suffix_dim_size == # of elements after the axis + int prefix_dim_size = 1; + for (int i = 0; i < op_context.axis; ++i) { + prefix_dim_size *= op_context.indices->dims->data[i]; + } + const int suffix_dim_size = NumElements(op_context.indices) / prefix_dim_size; + const int depth = *op_context.depth->data.i32; + + const T on_value = *GetTensorData(op_context.on_value); + const T off_value = *GetTensorData(op_context.off_value); + + // View the indices as a matrix of size: + // prefix_dim_size x suffix_dim_size + // View the output as a matrix of size: + // prefix_dim_size x depth x suffix_dim_size + // Then the output is: + // output(i, j, k) == (indices(i, k) == j) ? on : off + T* output = GetTensorData(op_context.output); + const TI* indices = GetTensorData(op_context.indices); + for (int i = 0; i < prefix_dim_size; ++i) { + for (int j = 0; j < depth; ++j) { + for (int k = 0; k < suffix_dim_size; ++k, ++output) { + *output = static_cast(indices[i * suffix_dim_size + k]) == j + ? on_value + : off_value; + } + } + } +} + +template +void OneHotCompute(const OneHotContext& op_context) { + if (op_context.indices->type == kTfLiteInt64) { + OneHotComputeImpl(op_context); + } else { + OneHotComputeImpl(op_context); + } +} + +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + const OneHotContext& op_context) { + TF_LITE_ENSURE(context, *op_context.depth->data.i32 >= 0); + TfLiteIntArray* output_size = TfLiteIntArrayCreate(op_context.output_dims); + for (int i = 0; i < op_context.output_dims; ++i) { + if (i < op_context.axis) { + output_size->data[i] = op_context.indices->dims->data[i]; + } else if (i == op_context.axis) { + output_size->data[i] = *op_context.depth->data.i32; + } else { + output_size->data[i] = op_context.indices->dims->data[i - 1]; + } + } + return context->ResizeTensor(context, op_context.output, output_size); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + OneHotContext op_context{context, node}; + switch (op_context.dtype) { + // TODO(b/111744875): Support uint8 and quantization. + case kTfLiteFloat32: + case kTfLiteInt16: + case kTfLiteInt32: + case kTfLiteInt64: + case kTfLiteBool: + op_context.output->type = op_context.dtype; + break; + default: + context->ReportError(context, "Unknown output data type: %d", + op_context.dtype); + return kTfLiteError; + } + + TF_LITE_ENSURE(context, op_context.indices->type == kTfLiteInt32 || + op_context.indices->type == kTfLiteInt64); + TF_LITE_ENSURE(context, op_context.axis >= 0 && + op_context.axis < op_context.output_dims); + TF_LITE_ENSURE_EQ(context, NumElements(op_context.depth), 1); + TF_LITE_ENSURE_EQ(context, NumElements(op_context.on_value), 1); + TF_LITE_ENSURE_EQ(context, NumElements(op_context.off_value), 1); + TF_LITE_ENSURE_EQ(context, op_context.on_value->type, op_context.dtype); + TF_LITE_ENSURE_EQ(context, op_context.off_value->type, op_context.dtype); + + if (!IsConstantTensor(op_context.depth)) { + SetTensorToDynamic(op_context.output); + return kTfLiteOk; + } + + return ResizeOutputTensor(context, op_context); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + OneHotContext op_context{context, node}; + + if (IsDynamicTensor(op_context.output)) { + ResizeOutputTensor(context, op_context); + } + + switch (op_context.output->type) { + case kTfLiteFloat32: + OneHotCompute(op_context); + break; + case kTfLiteInt32: + OneHotCompute(op_context); + break; + case kTfLiteInt64: + OneHotCompute(op_context); + break; + case kTfLiteBool: + OneHotCompute(op_context); + break; + default: + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace one_hot + +TfLiteRegistration* Register_ONE_HOT() { + static TfLiteRegistration r = { + nullptr, + nullptr, + one_hot::Prepare, + one_hot::Eval, + }; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/one_hot_test.cc b/tensorflow/contrib/lite/kernels/one_hot_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6b604ec7a7f86b333805d91a95cb5054f0257ae4 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/one_hot_test.cc @@ -0,0 +1,182 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +template +class OneHotOpModel : public SingleOpModel { + public: + OneHotOpModel(std::initializer_list input_shape, int depth_value, + TensorType dtype, int axis = -1, T on_value = 1, + T off_value = 0, TensorType indices_type = TensorType_INT32) { + indices_ = AddInput(indices_type); + int depth = AddInput(TensorType_INT32); + int on = AddInput(dtype); + int off = AddInput(dtype); + output_ = AddOutput(dtype); + SetBuiltinOp(BuiltinOperator_ONE_HOT, BuiltinOptions_OneHotOptions, + CreateOneHotOptions(builder_, axis).Union()); + BuildInterpreter({input_shape}); + + PopulateTensor(depth, {depth_value}); + PopulateTensor(on, {on_value}); + PopulateTensor(off, {off_value}); + } + + template + void SetIndices(std::initializer_list data) { + PopulateTensor(indices_, data); + } + + TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); } + + int32_t GetOutputSize() { return GetTensorSize(output_); } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int indices_; + int output_; +}; + +TEST(OneHotOpTest, BasicFloat) { + const int depth = 3; + OneHotOpModel model({3}, depth, TensorType_FLOAT32); + model.SetIndices({0, 1, 2}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3})); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f})); +} + +TEST(OneHotOpTest, BasicInt) { + const int depth = 3; + OneHotOpModel model({3}, depth, TensorType_INT32); + model.SetIndices({0, 1, 2}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 1, 0, 0, 0, 1})); +} + +TEST(OneHotOpTest, BasicBool) { + const int depth = 3; + OneHotOpModel model({3}, depth, TensorType_BOOL); + model.SetIndices({0, 1, 2}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3})); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({true, false, false, false, true, false, false, + false, true})); +} + +TEST(OneHotOpTest, SmallDepth) { + const int depth = 1; + OneHotOpModel model({3}, depth, TensorType_INT32); + model.SetIndices({0, 1, 2}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 1})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0})); +} + +TEST(OneHotOpTest, BigDepth) { + const int depth = 4; + OneHotOpModel model({2}, depth, TensorType_INT32); + model.SetIndices({0, 1}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 0, 1, 0, 0})); +} + +TEST(OneHotOpTest, OnOffValues) { + const int depth = 3; + const int axis = -1; + const int on = 5; + const int off = 0; + OneHotOpModel model({4}, depth, TensorType_INT32, axis, on, off); + model.SetIndices({0, 2, -1, 1}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 3})); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({5, 0, 0, 0, 0, 5, 0, 0, 0, 0, 5, 0})); +} + +TEST(OneHotOpTest, ZeroAxis) { + const int depth = 3; + const int axis = 0; + const int on = 5; + const int off = 0; + OneHotOpModel model({4}, depth, TensorType_INT32, axis, on, off); + model.SetIndices({0, 2, -1, 1}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 4})); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({5, 0, 0, 0, 0, 0, 0, 5, 0, 5, 0, 0})); +} + +TEST(OneHotOpTest, MultiDimensionalIndices) { + const int depth = 3; + const int axis = -1; + const float on = 2; + const float off = 0; + OneHotOpModel model({2, 2}, depth, TensorType_FLOAT32, axis, on, off); + model.SetIndices({0, 2, 1, -1}); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 3})); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({2, 0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0})); +} + +TEST(OneHotOpTest, Int64Indices) { + const int depth = 3; + const int axis = -1; + const int on = 1; + const int off = 0; + OneHotOpModel model({3}, depth, TensorType_INT32, axis, on, off, + TensorType_INT64); + std::initializer_list indices = {0, 1, 2}; + model.SetIndices(indices); + model.Invoke(); + + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3})); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 1, 0, 0, 0, 1})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/pack.cc b/tensorflow/contrib/lite/kernels/pack.cc new file mode 100644 index 0000000000000000000000000000000000000000..bb3416f6a6ca60250f137986e479e8f1085e2558 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pack.cc @@ -0,0 +1,131 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace pack { +namespace { + +constexpr int kOutputTensor = 0; + +// Op data for pack op. +struct OpData { + int values_count; + int axis; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + data->axis = 0; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const OpData* data = reinterpret_cast(node->builtin_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), data->values_count); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* input0 = GetInput(context, node, 0); + TF_LITE_ENSURE(context, NumDimensions(input0) < 4); + TF_LITE_ENSURE(context, NumDimensions(input0) >= data->axis); + // TODO(renjieliu): Support negative axis. + TF_LITE_ENSURE(context, data->axis >= 0); + if (input0->type != kTfLiteInt32 && input0->type != kTfLiteFloat32) { + context->ReportError(context, + "Currently pack only supports int32 and float32."); + return kTfLiteError; + } + // Make sure all inputs have the same shape and type. + for (int i = 1; i < data->values_count; ++i) { + const TfLiteTensor* input = GetInput(context, node, i); + TF_LITE_ENSURE(context, HaveSameShapes(input0, input)); + TF_LITE_ENSURE_EQ(context, input0->type, input->type); + } + + // Resize output. rank R will become rank R + 1 + const int dimension_size = NumDimensions(input0) + 1; + const TfLiteIntArray* input_shape = input0->dims; + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(dimension_size); + int i = 0; + for (int index = 0; index < dimension_size; ++index) { + if (index == data->axis) { + output_shape->data[index] = data->values_count; + } else { + output_shape->data[index] = input_shape->data[i++]; + } + } + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TF_LITE_ENSURE_EQ(context, output->type, input0->type); + + return context->ResizeTensor(context, output, output_shape); +} + +template +void PackImpl(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* output, + int values_count, int axis) { + VectorOfTensors all_inputs(*context, *node->inputs); + reference_ops::Pack(RemapDim(NumDimensions(output), axis), + all_inputs.data(), all_inputs.dims(), values_count, + GetTensorData(output), GetTensorDims(output)); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const OpData* data = reinterpret_cast(node->builtin_data); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + switch (output->type) { + case kTfLiteFloat32: { + PackImpl(context, node, output, data->values_count, data->axis); + break; + } + case kTfLiteInt32: { + PackImpl(context, node, output, data->values_count, data->axis); + break; + } + default: { + context->ReportError(context, + "Currently pack only supports int32 and float32."); + return kTfLiteError; + } + } + + return kTfLiteOk; +} + +} // namespace +} // namespace pack + +TfLiteRegistration* Register_PACK() { + static TfLiteRegistration r = {pack::Init, pack::Free, pack::Prepare, + pack::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/pack_test.cc b/tensorflow/contrib/lite/kernels/pack_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..485a50ad3ac493fd02f619f7d7cbaf10d3a6aff0 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pack_test.cc @@ -0,0 +1,120 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +template +class PackOpModel : public SingleOpModel { + public: + PackOpModel(const TensorData& input_template, int axis, int values_count) { + std::vector> all_input_shapes; + for (int i = 0; i < values_count; ++i) { + all_input_shapes.push_back(input_template.shape); + AddInput(input_template); + } + output_ = AddOutput({input_template.type, /*shape=*/{}, input_template.min, + input_template.max}); + SetBuiltinOp(BuiltinOperator_PACK, BuiltinOptions_PackOptions, + CreatePackOptions(builder_, values_count, axis).Union()); + BuildInterpreter(all_input_shapes); + } + + void SetInput(int index, std::initializer_list data) { + PopulateTensor(index, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int output_; +}; + +TEST(PackOpTest, FloatThreeInputs) { + PackOpModel model({TensorType_FLOAT32, {2}}, 0, 3); + model.SetInput(0, {1, 4}); + model.SetInput(1, {2, 5}); + model.SetInput(2, {3, 6}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 2)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6})); +} + +TEST(PackOpTest, FloatThreeInputsDifferentAxis) { + PackOpModel model({TensorType_FLOAT32, {2}}, 1, 3); + model.SetInput(0, {1, 4}); + model.SetInput(1, {2, 5}); + model.SetInput(2, {3, 6}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(PackOpTest, FloatMultilDimensions) { + PackOpModel model({TensorType_FLOAT32, {2, 3}}, 1, 2); + model.SetInput(0, {1, 2, 3, 4, 5, 6}); + model.SetInput(1, {7, 8, 9, 10, 11, 12}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2, 3)); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); +} + +TEST(PackOpTest, IntThreeInputs) { + PackOpModel model({TensorType_INT32, {2}}, 0, 3); + model.SetInput(0, {1, 4}); + model.SetInput(1, {2, 5}); + model.SetInput(2, {3, 6}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(3, 2)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6})); +} + +TEST(PackOpTest, IntThreeInputsDifferentAxis) { + PackOpModel model({TensorType_INT32, {2}}, 1, 3); + model.SetInput(0, {1, 4}); + model.SetInput(1, {2, 5}); + model.SetInput(2, {3, 6}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 3)); + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(PackOpTest, IntMultilDimensions) { + PackOpModel model({TensorType_INT32, {2, 3}}, 1, 2); + model.SetInput(0, {1, 2, 3, 4, 5, 6}); + model.SetInput(1, {7, 8, 9, 10, 11, 12}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(2, 2, 3)); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); +} +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc index 7240fe04ccdadfb7b9703c3f2775c4b3502bd1d9..29a5be068368365e67ad0653b775afe1e976f23a 100644 --- a/tensorflow/contrib/lite/kernels/pooling.cc +++ b/tensorflow/contrib/lite/kernels/pooling.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -126,13 +125,19 @@ void AverageEvalFloat(TfLiteContext* context, TfLiteNode* node, float activation_min, activation_max; CalculateActivationRange(params->activation, &activation_min, &activation_max); -#define TF_LITE_AVERAGE_POOL(type) \ - type::AveragePool(GetTensorData(input), GetTensorShape(input), \ - params->stride_width, params->stride_height, \ - data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, \ - activation_min, activation_max, \ - GetTensorData(output), GetTensorShape(output)) +#define TF_LITE_AVERAGE_POOL(type) \ + tflite::PoolParams op_params; \ + op_params.stride_height = params->stride_height; \ + op_params.stride_width = params->stride_width; \ + op_params.filter_height = params->filter_height; \ + op_params.filter_width = params->filter_width; \ + op_params.padding_values.height = data->padding.height; \ + op_params.padding_values.width = data->padding.width; \ + op_params.float_activation_min = activation_min; \ + op_params.float_activation_max = activation_max; \ + type::AveragePool(op_params, GetTensorShape(input), \ + GetTensorData(input), GetTensorShape(output), \ + GetTensorData(output)) if (kernel_type == kReference) { TF_LITE_AVERAGE_POOL(reference_ops); } else { @@ -149,13 +154,19 @@ void AverageEvalQuantized(TfLiteContext* context, TfLiteNode* node, int32_t activation_max; CalculateActivationRangeUint8(params->activation, output, &activation_min, &activation_max); -#define TF_LITE_AVERAGE_POOL(type) \ - type::AveragePool(GetTensorData(input), GetTensorShape(input), \ - params->stride_width, params->stride_height, \ - data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, \ - activation_min, activation_max, \ - GetTensorData(output), GetTensorShape(output)) +#define TF_LITE_AVERAGE_POOL(type) \ + tflite::PoolParams op_params; \ + op_params.stride_height = params->stride_height; \ + op_params.stride_width = params->stride_width; \ + op_params.filter_height = params->filter_height; \ + op_params.filter_width = params->filter_width; \ + op_params.padding_values.height = data->padding.height; \ + op_params.padding_values.width = data->padding.width; \ + op_params.quantized_activation_min = activation_min; \ + op_params.quantized_activation_max = activation_max; \ + type::AveragePool(op_params, GetTensorShape(input), \ + GetTensorData(input), GetTensorShape(output), \ + GetTensorData(output)) if (kernel_type == kReference) { TF_LITE_AVERAGE_POOL(reference_ops); } else { @@ -171,13 +182,18 @@ void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node, float activation_min, activation_max; CalculateActivationRange(params->activation, &activation_min, &activation_max); -#define TF_LITE_MAX_POOL(type) \ - type::MaxPool(GetTensorData(input), GetTensorShape(input), \ - params->stride_width, params->stride_height, \ - data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, activation_min, \ - activation_max, GetTensorData(output), \ - GetTensorShape(output)) +#define TF_LITE_MAX_POOL(type) \ + tflite::PoolParams op_params; \ + op_params.stride_height = params->stride_height; \ + op_params.stride_width = params->stride_width; \ + op_params.filter_height = params->filter_height; \ + op_params.filter_width = params->filter_width; \ + op_params.padding_values.height = data->padding.height; \ + op_params.padding_values.width = data->padding.width; \ + op_params.float_activation_min = activation_min; \ + op_params.float_activation_max = activation_max; \ + type::MaxPool(op_params, GetTensorShape(input), GetTensorData(input), \ + GetTensorShape(output), GetTensorData(output)) if (kernel_type == kReference) { TF_LITE_MAX_POOL(reference_ops); } else { @@ -194,13 +210,19 @@ void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node, int32_t activation_max; CalculateActivationRangeUint8(params->activation, output, &activation_min, &activation_max); -#define TF_LITE_MAX_POOL(type) \ - type::MaxPool(GetTensorData(input), GetTensorShape(input), \ - params->stride_width, params->stride_height, \ - data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, activation_min, \ - activation_max, GetTensorData(output), \ - GetTensorShape(output)) +#define TF_LITE_MAX_POOL(type) \ + tflite::PoolParams op_params; \ + op_params.stride_height = params->stride_height; \ + op_params.stride_width = params->stride_width; \ + op_params.filter_height = params->filter_height; \ + op_params.filter_width = params->filter_width; \ + op_params.padding_values.height = data->padding.height; \ + op_params.padding_values.width = data->padding.width; \ + op_params.quantized_activation_min = activation_min; \ + op_params.quantized_activation_max = activation_max; \ + type::MaxPool(op_params, GetTensorShape(input), \ + GetTensorData(input), GetTensorShape(output), \ + GetTensorData(output)) if (kernel_type == kReference) { TF_LITE_MAX_POOL(reference_ops); } else { @@ -216,13 +238,18 @@ void L2EvalFloat(TfLiteContext* context, TfLiteNode* node, float activation_min, activation_max; CalculateActivationRange(params->activation, &activation_min, &activation_max); -#define TF_LITE_L2_POOL(type) \ - type::L2Pool(GetTensorData(input), GetTensorShape(input), \ - params->stride_width, params->stride_height, \ - data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, activation_min, \ - activation_max, GetTensorData(output), \ - GetTensorShape(output)) +#define TF_LITE_L2_POOL(type) \ + tflite::PoolParams op_params; \ + op_params.stride_height = params->stride_height; \ + op_params.stride_width = params->stride_width; \ + op_params.filter_height = params->filter_height; \ + op_params.filter_width = params->filter_width; \ + op_params.padding_values.height = data->padding.height; \ + op_params.padding_values.width = data->padding.width; \ + op_params.float_activation_min = activation_min; \ + op_params.float_activation_max = activation_max; \ + type::L2Pool(op_params, GetTensorShape(input), GetTensorData(input), \ + GetTensorShape(output), GetTensorData(output)) if (kernel_type == kReference) { TF_LITE_L2_POOL(reference_ops); } else { diff --git a/tensorflow/contrib/lite/kernels/pow_test.cc b/tensorflow/contrib/lite/kernels/pow_test.cc index 474d323bc3a1a0f224aa0575a5bbd35394aa2f53..74b3aef5bd39d8bdb6375f24bd00d793889deef8 100644 --- a/tensorflow/contrib/lite/kernels/pow_test.cc +++ b/tensorflow/contrib/lite/kernels/pow_test.cc @@ -50,22 +50,22 @@ class PowOpModel : public SingleOpModel { }; TEST(PowOpModel, Simple) { - PowOpModel model({TensorType_INT32, {1, 2, 2, 1}}, - {TensorType_INT32, {1, 2, 2, 1}}, - {TensorType_INT32, {}}); - model.PopulateTensor(model.input1(), {12, 2, 7, 8}); - model.PopulateTensor(model.input2(), {1, 2, 3, 1}); + PowOpModel model({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {}}); + model.PopulateTensor(model.input1(), {12, 2, 7, 8}); + model.PopulateTensor(model.input2(), {1, 2, 3, 1}); model.Invoke(); EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); EXPECT_THAT(model.GetOutput(), ElementsAre(12, 4, 343, 8)); } TEST(PowOpModel, NegativeAndZeroValue) { - PowOpModel model({TensorType_INT32, {1, 2, 2, 1}}, - {TensorType_INT32, {1, 2, 2, 1}}, - {TensorType_INT32, {}}); - model.PopulateTensor(model.input1(), {0, 2, -7, 8}); - model.PopulateTensor(model.input2(), {1, 2, 3, 0}); + PowOpModel model({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {}}); + model.PopulateTensor(model.input1(), {0, 2, -7, 8}); + model.PopulateTensor(model.input2(), {1, 2, 3, 0}); model.Invoke(); EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); EXPECT_THAT(model.GetOutput(), ElementsAre(0, 4, -343, 1)); @@ -98,10 +98,10 @@ TEST(PowOpModel, NegativeFloatTest) { } TEST(PowOpModel, BroadcastTest) { - PowOpModel model({TensorType_INT32, {1, 2, 2, 1}}, - {TensorType_INT32, {1}}, {TensorType_INT32, {}}); - model.PopulateTensor(model.input1(), {12, 2, 7, 8}); - model.PopulateTensor(model.input2(), {4}); + PowOpModel model({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1}}, {TensorType_INT32, {}}); + model.PopulateTensor(model.input1(), {12, 2, 7, 8}); + model.PopulateTensor(model.input2(), {4}); model.Invoke(); EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); EXPECT_THAT(model.GetOutput(), ElementsAre(20736, 16, 2401, 4096)); diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/contrib/lite/kernels/reduce.cc index 31c331a8c61ded203af9ff2ae127cb6f985e2932..e99f67c7258c555903069dff67a86a3703249c7c 100644 --- a/tensorflow/contrib/lite/kernels/reduce.cc +++ b/tensorflow/contrib/lite/kernels/reduce.cc @@ -78,6 +78,10 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, OpContext* op_context) { size_t num_axis = NumElements(op_context->axis); const TfLiteIntArray* input_dims = op_context->input->dims; int input_num_dims = NumDimensions(op_context->input); + if (input_num_dims == 0) { + return context->ResizeTensor(context, op_context->output, + TfLiteIntArrayCreate(0)); + } const int* axis = GetTensorData(op_context->axis); if (op_context->params->keep_dims) { TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_num_dims); @@ -315,6 +319,99 @@ TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +template +TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) { + OpContext op_context(context, node); + int64_t num_axis = NumElements(op_context.axis); + TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0); + TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, + ResizeTempAxis(context, &op_context, resolved_axis)); + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } + +#define TF_LITE_PROD(kernel_type, data_type) \ + kernel_type::ReduceProd<>( \ + GetTensorData(op_context.input), \ + op_context.input->dims->data, op_context.input->dims->size, \ + GetTensorData(op_context.output), \ + op_context.output->dims->data, op_context.output->dims->size, \ + GetTensorData(op_context.axis), num_axis, \ + op_context.params->keep_dims, GetTensorData(temp_index), \ + GetTensorData(resolved_axis)) + + if (kernel_type == kReference) { + switch (op_context.input->type) { + case kTfLiteFloat32: + TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, float)); + break; + case kTfLiteInt32: + TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, int)); + break; + case kTfLiteInt64: + TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, int64_t)); + break; + case kTfLiteUInt8: + // TODO(wangtz): uint8 reduce_prod is not yet supported. + default: + return kTfLiteError; + } + } +#undef TF_LITE_PROD + return kTfLiteOk; +} + +template +TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) { + OpContext op_context(context, node); + int64_t num_axis = NumElements(op_context.axis); + TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0); + TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, + ResizeTempAxis(context, &op_context, resolved_axis)); + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } + +#define TF_LITE_MAX(kernel_type, data_type) \ + kernel_type::ReduceMax<>( \ + GetTensorData(op_context.input), \ + op_context.input->dims->data, op_context.input->dims->size, \ + GetTensorData(op_context.output), \ + op_context.output->dims->data, op_context.output->dims->size, \ + GetTensorData(op_context.axis), num_axis, \ + op_context.params->keep_dims, GetTensorData(temp_index), \ + GetTensorData(resolved_axis)) + + if (kernel_type == kReference) { + switch (op_context.input->type) { + case kTfLiteFloat32: + TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, float)); + break; + case kTfLiteInt32: + TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, int)); + break; + case kTfLiteInt64: + TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, int64_t)); + break; + case kTfLiteUInt8: + TF_LITE_ENSURE_EQ(context, op_context.input->params.scale, + op_context.output->params.scale); + TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point, + op_context.output->params.zero_point); + TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, uint8_t)); + break; + default: + return kTfLiteError; + } + } +#undef TF_LITE_MAX + return kTfLiteOk; +} + } // namespace reduce TfLiteRegistration* Register_MEAN_REF() { @@ -331,9 +428,27 @@ TfLiteRegistration* Register_SUM_REF() { return &r; } +TfLiteRegistration* Register_REDUCE_PROD_REF() { + static TfLiteRegistration r = {reduce::Init, reduce::Free, + reduce::PrepareSimple, + reduce::EvalProd}; + return &r; +} + +TfLiteRegistration* Register_REDUCE_MAX_REF() { + static TfLiteRegistration r = {reduce::Init, reduce::Free, + reduce::PrepareSimple, + reduce::EvalMax}; + return &r; +} + // TODO(kanlig): add optimized implementation of Mean. TfLiteRegistration* Register_MEAN() { return Register_MEAN_REF(); } TfLiteRegistration* Register_SUM() { return Register_SUM_REF(); } +TfLiteRegistration* Register_REDUCE_PROD() { + return Register_REDUCE_PROD_REF(); +} +TfLiteRegistration* Register_REDUCE_MAX() { return Register_REDUCE_MAX_REF(); } } // namespace builtin } // namespace ops diff --git a/tensorflow/contrib/lite/kernels/reduce_test.cc b/tensorflow/contrib/lite/kernels/reduce_test.cc index 9e946822c686f6f20505d60b6161239624c94696..5d432d34ef5118e7164d7f767dad6017aa640e51 100644 --- a/tensorflow/contrib/lite/kernels/reduce_test.cc +++ b/tensorflow/contrib/lite/kernels/reduce_test.cc @@ -22,13 +22,14 @@ namespace tflite { namespace { using ::testing::ElementsAreArray; +using ::testing::IsEmpty; class BaseOpModel : public SingleOpModel { public: - void SetAxis(std::initializer_list data) { PopulateTensor(axis_, data); } + void SetAxis(const std::vector& data) { PopulateTensor(axis_, data); } template - void SetInput(std::initializer_list data) { + void SetInput(std::vector data) { PopulateTensor(input_, data); } @@ -110,14 +111,72 @@ class SumOpDynamicModel : public BaseOpModel { } }; +// Model for the tests case where axis is a const tensor. +class ProdOpConstModel : public BaseOpModel { + public: + ProdOpConstModel(const TensorData& input, const TensorData& output, + std::initializer_list axis_shape, + std::initializer_list axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddConstInput(TensorType_INT32, axis, axis_shape); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_REDUCE_PROD, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// Model for the tests case where axis is a dynamic tensor. +class ProdOpDynamicModel : public BaseOpModel { + public: + ProdOpDynamicModel(const TensorData& input, const TensorData& output, + const TensorData& axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddInput(axis); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_REDUCE_PROD, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// Model for the tests case where axis is a const tensor. +class MaxOpConstModel : public BaseOpModel { + public: + MaxOpConstModel(const TensorData& input, const TensorData& output, + std::initializer_list axis_shape, + std::initializer_list axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddConstInput(TensorType_INT32, axis, axis_shape); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_REDUCE_MAX, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// Model for the tests case where axis is a dynamic tensor. +class MaxOpDynamicModel : public BaseOpModel { + public: + MaxOpDynamicModel(const TensorData& input, const TensorData& output, + const TensorData& axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddInput(axis); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_REDUCE_MAX, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + // for quantized Add, the error shouldn't exceed step float GetTolerance(int min, int max) { return (max - min) / 255.0; } // Tests for reduce_mean TEST(ConstFloatMeanOpTest, NotKeepDims) { - std::initializer_list data = { - 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, - 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, {4}, {1, 0, -3, -3}, false); m.SetInput(data); @@ -127,9 +186,9 @@ TEST(ConstFloatMeanOpTest, NotKeepDims) { } TEST(ConstFloatMeanOpTest, KeepDims) { - std::initializer_list data = { - 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, - 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, {2}, {0, 2}, true); m.SetInput(data); @@ -139,14 +198,24 @@ TEST(ConstFloatMeanOpTest, KeepDims) { ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5}))); } +TEST(ConstFloatMeanOpTest, Scalar) { + std::vector data = {3.27}; + MeanOpConstModel m({TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, {}, + {0}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3.27}))); +} + TEST(DynamicFloatMeanOpTest, NotKeepDims) { - std::initializer_list data = { - 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, - 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}}, false); - std::initializer_list axis = {1, 0, -3, -3}; + std::vector axis = {1, 0, -3, -3}; m.SetAxis(axis); m.SetInput(data); m.Invoke(); @@ -155,13 +224,13 @@ TEST(DynamicFloatMeanOpTest, NotKeepDims) { } TEST(DynamicFloatMeanOpTest, KeepDims) { - std::initializer_list data = { - 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, - 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, true); - std::initializer_list axis = {0, 2}; + std::vector axis = {0, 2}; m.SetAxis(axis); m.SetInput(data); m.Invoke(); @@ -171,10 +240,10 @@ TEST(DynamicFloatMeanOpTest, KeepDims) { } TEST(DynamicFloatMeanOpTest, Scale) { - std::initializer_list data = {9.527}; + std::vector data = {9.527}; MeanOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}}, {TensorType_INT32, {1}}, true); - std::initializer_list axis = {0}; + std::vector axis = {0}; m.SetAxis(axis); m.SetInput(data); m.Invoke(); @@ -185,7 +254,7 @@ TEST(DynamicFloatMeanOpTest, Scale) { TEST(ConstUint8MeanOpTest, NotKeepDims) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); - std::initializer_list data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + std::vector data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; MeanOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0}, {TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false); m.QuantizeAndPopulate(m.Input(), data); @@ -197,7 +266,7 @@ TEST(ConstUint8MeanOpTest, NotKeepDims) { TEST(ConstUint8MeanOpTest, KeepDims) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); - std::initializer_list data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + std::vector data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; MeanOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0}, {TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true); m.QuantizeAndPopulate(m.Input(), data); @@ -210,11 +279,11 @@ TEST(ConstUint8MeanOpTest, KeepDims) { TEST(DynamicUint8MeanOpTest, NotKeepDims) { float kQuantizedTolerance = GetTolerance(-5.0, 2.0); - std::initializer_list data = {1.3, -4.8, -3.6, 0.24}; + std::vector data = {1.3, -4.8, -3.6, 0.24}; MeanOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0}, {TensorType_UINT8, {2}, -5.0, 2.0}, {TensorType_INT32, {1}}, false); - std::initializer_list axis = {1}; + std::vector axis = {1}; m.SetAxis(axis); m.QuantizeAndPopulate(m.Input(), data); m.Invoke(); @@ -226,11 +295,11 @@ TEST(DynamicUint8MeanOpTest, NotKeepDims) { TEST(DynamicUint8MeanOpTest, KeepDims) { float kQuantizedTolerance = GetTolerance(-10.0, 12.0); - std::initializer_list data = {11.14, -0.14, 7.423, 0.879}; + std::vector data = {11.14, -0.14, 7.423, 0.879}; MeanOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0}, {TensorType_UINT8, {2}, -10.0, 12.0}, {TensorType_INT32, {1}}, true); - std::initializer_list axis = {0}; + std::vector axis = {0}; m.SetAxis(axis); m.QuantizeAndPopulate(m.Input(), data); m.Invoke(); @@ -243,9 +312,9 @@ TEST(DynamicUint8MeanOpTest, KeepDims) { // Tests for reduce_sum TEST(ConstFloatSumOpTest, NotKeepDims) { - std::initializer_list data = { - 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, - 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; SumOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, {4}, {1, 0, -3, -3}, false); m.SetInput(data); @@ -256,9 +325,9 @@ TEST(ConstFloatSumOpTest, NotKeepDims) { } TEST(ConstFloatSumOpTest, KeepDims) { - std::initializer_list data = { - 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, - 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; SumOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, {2}, {0, 2}, true); m.SetInput(data); @@ -269,13 +338,13 @@ TEST(ConstFloatSumOpTest, KeepDims) { } TEST(DynamicFloatSumOpTest, NotKeepDims) { - std::initializer_list data = { - 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, - 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; SumOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}}, false); - std::initializer_list axis = {1, 0, -3, -3}; + std::vector axis = {1, 0, -3, -3}; m.SetAxis(axis); m.SetInput(data); m.Invoke(); @@ -284,13 +353,23 @@ TEST(DynamicFloatSumOpTest, NotKeepDims) { ElementsAreArray(ArrayFloatNear({144, 156}))); } +TEST(ConstFloatSumOpTest, Scalar) { + std::vector data = {17.}; + SumOpConstModel m({TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, {}, {0}, + false); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({17.}))); +} + TEST(DynamicFloatSumOpTest, KeepDims) { - std::initializer_list data = { - 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, - 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; SumOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, true); - std::initializer_list axis = {0, 2}; + std::vector axis = {0, 2}; m.SetAxis(axis); m.SetInput(data); m.Invoke(); @@ -300,10 +379,10 @@ TEST(DynamicFloatSumOpTest, KeepDims) { } TEST(DynamicFloatSumOpTest, Scale) { - std::initializer_list data = {9.527}; + std::vector data = {9.527}; SumOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}}, {TensorType_INT32, {1}}, true); - std::initializer_list axis = {0}; + std::vector axis = {0}; m.SetAxis(axis); m.SetInput(data); m.Invoke(); @@ -313,7 +392,7 @@ TEST(DynamicFloatSumOpTest, Scale) { TEST(ConstUint8SumOpTest, NotKeepDims) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); - std::initializer_list data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + std::vector data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; SumOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0}, {TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false); m.QuantizeAndPopulate(m.Input(), data); @@ -326,7 +405,7 @@ TEST(ConstUint8SumOpTest, NotKeepDims) { TEST(ConstUint8SumOpTest, KeepDims) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); - std::initializer_list data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + std::vector data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; SumOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0}, {TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true); m.QuantizeAndPopulate(m.Input(), data); @@ -339,11 +418,11 @@ TEST(ConstUint8SumOpTest, KeepDims) { TEST(DynamicUint8SumOpTest, NotKeepDims) { float kQuantizedTolerance = GetTolerance(-5.0, 2.0); - std::initializer_list data = {1.3, -4.8, -3.6, 0.24}; + std::vector data = {1.3, -4.8, -3.6, 0.24}; SumOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0}, {TensorType_UINT8, {2}, -5.0, 2.0}, {TensorType_INT32, {1}}, false); - std::initializer_list axis = {1}; + std::vector axis = {1}; m.SetAxis(axis); m.QuantizeAndPopulate(m.Input(), data); m.Invoke(); @@ -355,11 +434,11 @@ TEST(DynamicUint8SumOpTest, NotKeepDims) { TEST(DynamicUint8SumOpTest, KeepDims) { float kQuantizedTolerance = GetTolerance(-10.0, 12.0); - std::initializer_list data = {11.14, -0.14, 7.423, 0.879}; + std::vector data = {11.14, -0.14, 7.423, 0.879}; SumOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0}, {TensorType_UINT8, {2}, -10.0, 12.0}, {TensorType_INT32, {1}}, true); - std::initializer_list axis = {0}; + std::vector axis = {0}; m.SetAxis(axis); m.QuantizeAndPopulate(m.Input(), data); m.Invoke(); @@ -369,6 +448,223 @@ TEST(DynamicUint8SumOpTest, KeepDims) { ElementsAreArray(ArrayFloatNear({6.47059, 10.698}, kQuantizedTolerance))); } +// Tests for reduce_prod + +TEST(ConstFloatProdOpTest, NotKeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + ProdOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, + {4}, {1, 0, -3, -3}, false); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3.162341376e+11, 1.9619905536e+12}))); +} + +TEST(ConstFloatProdOpTest, KeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + ProdOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, + {2}, {0, 2}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray( + ArrayFloatNear({7.74592e+06, 1.197504e+08, 6.6889152e+08}))); +} + +TEST(DynamicFloatProdOpTest, NotKeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + ProdOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}}, + false); + std::vector axis = {1, 0, -3, -3}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3.16234143225e+11, 1.9619905536e+12}))); +} + +TEST(DynamicFloatProdOpTest, KeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + ProdOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, + true); + std::vector axis = {0, 2}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray( + ArrayFloatNear({7.74592e+06, 1.197504e+08, 6.6889152e+08}))); +} + +TEST(DynamicFloatProdOpTest, Scale) { + std::vector data = {9.527}; + ProdOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}}, + {TensorType_INT32, {1}}, true); + std::vector axis = {0}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({9.527}))); +} + +// Tests for reduce_max + +TEST(ConstFloatMaxOpTest, NotKeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MaxOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, + {4}, {1, 0, -3, -3}, false); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({23, 24}))); +} + +TEST(ConstFloatMaxOpTest, KeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MaxOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, + {2}, {0, 2}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({20, 22, 24}))); +} + +TEST(DynamicFloatMaxOpTest, NotKeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MaxOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}}, + false); + std::vector axis = {1, 0, -3, -3}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({23, 24}))); +} + +TEST(DynamicFloatMaxOpTest, KeepDims) { + std::vector data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MaxOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, true); + std::vector axis = {0, 2}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({20, 22, 24}))); +} + +TEST(DynamicFloatMaxOpTest, Scale) { + std::vector data = {9.527}; + MaxOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}}, + {TensorType_INT32, {1}}, true); + std::vector axis = {0}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({9.527}))); +} + +TEST(ConstUint8MaxOpTest, NotKeepDims) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::vector data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + MaxOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0}, + {TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({0.501961, 0.603922}, kQuantizedTolerance))); +} + +TEST(ConstUint8MaxOpTest, KeepDims) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::vector data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + MaxOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0}, + {TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({0.4, 0.4, 0.603922}, kQuantizedTolerance))); +} + +TEST(DynamicUint8MaxOpTest, NotKeepDims) { + float kQuantizedTolerance = GetTolerance(-5.0, 2.0); + std::vector data = {1.3, -4.8, -3.6, 0.24}; + MaxOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0}, + {TensorType_UINT8, {2}, -5.0, 2.0}, + {TensorType_INT32, {1}}, false); + std::vector axis = {1}; + m.SetAxis(axis); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({1.2902, 0.247059}, kQuantizedTolerance))); +} + +TEST(DynamicUint8MaxOpTest, KeepDims) { + float kQuantizedTolerance = GetTolerance(-10.0, 12.0); + std::vector data = {11.14, -0.14, 7.423, 0.879}; + MaxOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0}, + {TensorType_UINT8, {2}, -10.0, 12.0}, + {TensorType_INT32, {1}}, true); + std::vector axis = {0}; + m.SetAxis(axis); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({11.1294, 0.862745}, kQuantizedTolerance))); +} + +TEST(DynamicUint8MaxOpTest, Scalar) { + float kQuantizedTolerance = GetTolerance(-10.0, 12.0); + std::vector data = {11.14}; + MaxOpDynamicModel m({TensorType_UINT8, {}, -10.0, 12.0}, + {TensorType_UINT8, {}, -10.0, 12.0}, + {TensorType_INT32, {1}}, true); + std::vector axis = {0}; + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({11.1294}, kQuantizedTolerance))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 0ca08cd8f38216549b4383ebaacbf4c54442cd97..8d2c108116e1666f342392ada44854190a5b80ee 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -82,6 +82,7 @@ TfLiteRegistration* Register_PRELU(); TfLiteRegistration* Register_MAXIMUM(); TfLiteRegistration* Register_MINIMUM(); TfLiteRegistration* Register_ARG_MAX(); +TfLiteRegistration* Register_ARG_MIN(); TfLiteRegistration* Register_GREATER(); TfLiteRegistration* Register_GREATER_EQUAL(); TfLiteRegistration* Register_LESS(); @@ -90,6 +91,8 @@ TfLiteRegistration* Register_FLOOR(); TfLiteRegistration* Register_TILE(); TfLiteRegistration* Register_NEG(); TfLiteRegistration* Register_SUM(); +TfLiteRegistration* Register_REDUCE_PROD(); +TfLiteRegistration* Register_REDUCE_MAX(); TfLiteRegistration* Register_SELECT(); TfLiteRegistration* Register_SLICE(); TfLiteRegistration* Register_SIN(); @@ -102,6 +105,39 @@ TfLiteRegistration* Register_SQRT(); TfLiteRegistration* Register_RSQRT(); TfLiteRegistration* Register_SHAPE(); TfLiteRegistration* Register_POW(); +TfLiteRegistration* Register_FAKE_QUANT(); +TfLiteRegistration* Register_PACK(); +TfLiteRegistration* Register_ONE_HOT(); +TfLiteRegistration* Register_LOGICAL_OR(); +TfLiteRegistration* Register_LOGICAL_AND(); +TfLiteRegistration* Register_LOGICAL_NOT(); + +TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) { + context->ReportError( + context, + "Regular TensorFlow ops are not supported by this interpreter. Make sure " + "you invoke the Eager delegate before inference."); + return kTfLiteError; +} + +const TfLiteRegistration* BuiltinOpResolver::FindOp(tflite::BuiltinOperator op, + int version) const { + return MutableOpResolver::FindOp(op, version); +} + +const TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op, + int version) const { + // Return the NULL Op for all ops whose name start with "Eager:", allowing + // the interpreter to delegate their execution. + if (string(op).find("Eager:") == 0) { + static TfLiteRegistration null_op{ + nullptr, nullptr, &UnsupportedTensorFlowOp, + nullptr, nullptr, BuiltinOperator_CUSTOM, + "Eager", 1}; + return &null_op; + } + return MutableOpResolver::FindOp(op, version); +} BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -167,6 +203,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM()); AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM()); AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX()); + AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN()); AddBuiltin(BuiltinOperator_GREATER, Register_GREATER()); AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL()); AddBuiltin(BuiltinOperator_LESS, Register_LESS()); @@ -179,6 +216,8 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV()); AddBuiltin(BuiltinOperator_TILE, Register_TILE()); AddBuiltin(BuiltinOperator_SUM, Register_SUM()); + AddBuiltin(BuiltinOperator_REDUCE_PROD, Register_REDUCE_PROD()); + AddBuiltin(BuiltinOperator_REDUCE_MAX, Register_REDUCE_MAX()); AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS()); AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE()); AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL()); @@ -187,6 +226,12 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT()); AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE()); AddBuiltin(BuiltinOperator_POW, Register_POW()); + AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2); + AddBuiltin(BuiltinOperator_PACK, Register_PACK()); + AddBuiltin(BuiltinOperator_ONE_HOT, Register_ONE_HOT()); + AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR()); + AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND()); + AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h index 940718d67e70b7206227b891ea529cb9e9619161..0296152d68d6836fd592a65eeea69a7d4ebbb6ef 100644 --- a/tensorflow/contrib/lite/kernels/register.h +++ b/tensorflow/contrib/lite/kernels/register.h @@ -26,6 +26,10 @@ namespace builtin { class BuiltinOpResolver : public MutableOpResolver { public: BuiltinOpResolver(); + + const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, + int version) const override; + const TfLiteRegistration* FindOp(const char* op, int version) const override; }; } // namespace builtin diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc index 3287040695140e3e7921c9f517450b9416b050b6..49ba0571e2f214c0b2407240753fcec0661c71bf 100644 --- a/tensorflow/contrib/lite/kernels/reshape.cc +++ b/tensorflow/contrib/lite/kernels/reshape.cc @@ -25,16 +25,11 @@ namespace builtin { namespace reshape { constexpr int kInputTensor = 0; +constexpr int kShapeTensor = 1; constexpr int kOutputTensor = 0; -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); - - // TODO(ahentz): we are often given a tensor with the shape but we only pay - // attention to what the shape specified in 'params'. - TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); - TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - +TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node, + TfLiteIntArray* output_shape) { const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); @@ -42,37 +37,84 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // special -1 value, meaning it will be calculated automatically based on the // input. Here we calculate what that dimension should be so that the number // of output elements in the same as the number of input elements. - int num_input_elements = 1; - for (int i = 0; i < NumDimensions(input); ++i) { - num_input_elements *= SizeOfDimension(input, i); - } + int num_input_elements = NumElements(input); - TfLiteIntArray* output_size = TfLiteIntArrayCreate(params->num_dimensions); int num_output_elements = 1; int stretch_dim = -1; - for (int i = 0; i < params->num_dimensions; ++i) { - int value = params->shape[i]; + for (int i = 0; i < output_shape->size; ++i) { + int value = output_shape->data[i]; if (value == -1) { TF_LITE_ENSURE_EQ(context, stretch_dim, -1); stretch_dim = i; } else { num_output_elements *= value; - output_size->data[i] = value; } } if (stretch_dim != -1) { - output_size->data[stretch_dim] = num_input_elements / num_output_elements; - num_output_elements *= output_size->data[stretch_dim]; + output_shape->data[stretch_dim] = num_input_elements / num_output_elements; + num_output_elements *= output_shape->data[stretch_dim]; } TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements); - return context->ResizeTensor(context, output, output_size); + return context->ResizeTensor(context, output, output_shape); +} + +TfLiteStatus ResizeOutputWithShapeTensor(TfLiteContext* context, + TfLiteNode* node) { + const TfLiteTensor* shape = GetInput(context, node, kShapeTensor); + + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape->dims->data[0]); + for (int i = 0; i < output_shape->size; ++i) { + output_shape->data[i] = shape->data.i32[i]; + } + return ResizeOutput(context, node, output_shape); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + // Attempt to use shape tensor if it exists. + if (NumInputs(node) == 2) { + const TfLiteTensor* shape = GetInput(context, node, kShapeTensor); + // Check if the shape tensor is valid. + if (shape->dims->size == 1 && shape->type == kTfLiteInt32) { + // Set the output tensor as dynamic if the shape isn't constnat. + if (!IsConstantTensor(shape)) { + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + SetTensorToDynamic(output); + return kTfLiteOk; + } + // Shape is constant. Resize now. + return ResizeOutputWithShapeTensor(context, node); + } + } + // The function is returned above this line if the shape tensor is usable. + // Now fallback to the shape parameter in `TfLiteReshapeParams`. + int num_dimensions = params->num_dimensions; + if (num_dimensions == 1 && params->shape[0] == 0) { + // Legacy tflite models use a shape parameter of [0] to indicate scalars, + // so adjust accordingly. TODO(b/111614235): Allow zero-sized buffers during + // toco conversion. + num_dimensions = 0; + } + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions); + for (int i = 0; i < num_dimensions; ++i) { + output_shape->data[i] = params->shape[i]; + } + return ResizeOutput(context, node, output_shape); } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + if (IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, ResizeOutputWithShapeTensor(context, node)); + } + memcpy(output->data.raw, input->data.raw, input->bytes); return kTfLiteOk; diff --git a/tensorflow/contrib/lite/kernels/reshape_test.cc b/tensorflow/contrib/lite/kernels/reshape_test.cc index aecbd0399f7454045e8189072f45b695b0525204..52d71350d3ba9a27bf9a8df7a194161c4fb7f87c 100644 --- a/tensorflow/contrib/lite/kernels/reshape_test.cc +++ b/tensorflow/contrib/lite/kernels/reshape_test.cc @@ -22,18 +22,27 @@ namespace tflite { namespace { using ::testing::ElementsAreArray; +using ::testing::IsEmpty; class ReshapeOpModel : public SingleOpModel { public: ReshapeOpModel(std::initializer_list input_shape, - std::initializer_list new_shape) { + std::initializer_list new_shape, + bool use_shape_input_tensor = false) { input_ = AddInput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); + int shape_input_tensor = + use_shape_input_tensor ? AddInput(TensorType_INT32) : -1; SetBuiltinOp( BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions, CreateReshapeOptions(builder_, builder_.CreateVector(new_shape)) .Union()); - BuildInterpreter({input_shape}); + if (use_shape_input_tensor) { + BuildInterpreter({input_shape, GetShape(shape_input_tensor)}); + PopulateTensor(shape_input_tensor, new_shape); + } else { + BuildInterpreter({input_shape}); + } } void SetInput(std::initializer_list data) { @@ -71,6 +80,14 @@ TEST(ReshapeOpTest, SimpleTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); } +TEST(ReshapeOpTest, ShapeTensorInput) { + ReshapeOpModel m({1, 2, 4, 1}, {2, 2, 2}, /*use_shape_input_tensor=*/true); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); +} + TEST(ReshapeOpTest, WithStretchDimension) { ReshapeOpModel m({1, 2, 4, 1}, {2, 1, -1}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); @@ -79,6 +96,22 @@ TEST(ReshapeOpTest, WithStretchDimension) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 4})); } +TEST(ReshapeOpTest, ScalarOutput) { + ReshapeOpModel m({1}, {}); + m.SetInput({3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); +} + +TEST(ReshapeOpTest, LegacyScalarOutput) { + ReshapeOpModel m({1}, {0}); + m.SetInput({3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc index 10caffea03ebcec7862df1627541ac3d076b04e4..f4289105f7931ae572f219a61b5479287aff926a 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc @@ -247,7 +247,7 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) { 3, 6, // 9, 12, // 4, 10, // - 10, 16 // + 12, 16 // }); m.SetSize({3, 3}); m.Invoke(); @@ -256,8 +256,8 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) { 7, 9, 10, // 9, 11, 12, // 4, 8, 10, // - 8, 12, 14, // - 10, 13, 16, // + 9, 12, 14, // + 12, 14, 16, // }))); ResizeBilinearOpModel const_m({TensorType_UINT8, {2, 2, 2, 1}}, {3, 3}); @@ -265,7 +265,7 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) { 3, 6, // 9, 12, // 4, 10, // - 10, 16 // + 12, 16 // }); const_m.Invoke(); EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ @@ -273,35 +273,35 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) { 7, 9, 10, // 9, 11, 12, // 4, 8, 10, // - 8, 12, 14, // - 10, 13, 16, // + 9, 12, 14, // + 12, 14, 16, // }))); } TEST(ResizeBilinearOpTest, ThreeDimensionalResize8Bit) { ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 2}}); m.SetInput({ - 3, 4, 6, 10, // - 9, 10, 12, 16, // + 3, 4, 6, 10, // + 10, 12, 14, 16, // }); m.SetSize({3, 3}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 4, 5, 8, 6, 10, // - 7, 8, 9, 12, 10, 14, // - 9, 10, 11, 13, 12, 16, // + 3, 4, 5, 8, 6, 10, // + 7, 9, 10, 12, 11, 14, // + 10, 12, 12, 14, 14, 16, // }))); ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3}); const_m.SetInput({ - 3, 4, 6, 10, // - 9, 10, 12, 16, // + 3, 4, 6, 10, // + 10, 12, 14, 16, // }); const_m.Invoke(); EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 4, 5, 8, 6, 10, // - 7, 8, 9, 12, 10, 14, // - 9, 10, 11, 13, 12, 16, // + 3, 4, 5, 8, 6, 10, // + 7, 9, 10, 12, 11, 14, // + 10, 12, 12, 14, 14, 16, // }))); } } // namespace diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc index 9b6cee3cb55bf93b987fa8e59bdf9c591f5c0372..3cdb5db2090a3cb3eeb43c6e20a4fec09fe8a069 100644 --- a/tensorflow/contrib/lite/kernels/select.cc +++ b/tensorflow/contrib/lite/kernels/select.cc @@ -89,6 +89,9 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) { case kTfLiteUInt8: \ TF_LITE_SELECT(uint8_t, op); \ break; \ + case kTfLiteInt16: \ + TF_LITE_SELECT(int16_t, op); \ + break; \ case kTfLiteInt32: \ TF_LITE_SELECT(int32_t, op); \ break; \ diff --git a/tensorflow/contrib/lite/kernels/select_test.cc b/tensorflow/contrib/lite/kernels/select_test.cc index 4664b9acb444747167f991944ddc120e9941ccd6..5b2e61cd29a7fd7c699fd81cb81e5f9a12c4b18f 100644 --- a/tensorflow/contrib/lite/kernels/select_test.cc +++ b/tensorflow/contrib/lite/kernels/select_test.cc @@ -96,6 +96,19 @@ TEST(SelectOpTest, SelectUInt8) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); } +TEST(SelectOpTest, SelectInt16) { + SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4}, + TensorType_INT16); + + model.PopulateTensor(model.input1(), {false, true, false, false}); + model.PopulateTensor(model.input2(), {1, 2, 3, 4}); + model.PopulateTensor(model.input3(), {5, 6, 7, 8}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 2, 7, 8})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); +} + TEST(SelectOpTest, SelectInt32) { SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc index c9269599e58f095ded4788e2ab064583ae0a708c..03079f1c3b4110da9193f91ed22940594152b10f 100644 --- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc +++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc @@ -113,7 +113,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); } -#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar) \ +#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar, pad_value) \ type::SpaceToBatchND(GetTensorData(op_context.input), \ GetTensorDims(op_context.input), \ GetTensorData(op_context.block_shape), \ @@ -121,34 +121,36 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTensorData(op_context.paddings), \ GetTensorDims(op_context.paddings), \ GetTensorData(op_context.output), \ - GetTensorDims(op_context.output)) + GetTensorDims(op_context.output), pad_value) switch (op_context.input->type) { // Already know in/out types are same. case kTfLiteFloat32: if (kernel_type == kReference) { - TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float); + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float, 0); } else { - TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float); + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float, 0); } break; case kTfLiteUInt8: if (kernel_type == kReference) { - TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t); + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t, + op_context.output->params.zero_point); } else { - TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t); + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t, + op_context.output->params.zero_point); } break; case kTfLiteInt32: if (kernel_type == kReference) { - TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t); + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t, 0); } else { - TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t); + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t, 0); } break; case kTfLiteInt64: if (kernel_type == kReference) { - TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t); + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t, 0); } else { - TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t); + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t, 0); } break; default: diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc index 92a4a037d5873e608ee7bdbdfc5eaa5e9b62bc8c..5756573629a51917e39a312117a1fcd29c150dc0 100644 --- a/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc +++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc @@ -23,6 +23,7 @@ namespace tflite { namespace { using ::testing::ElementsAreArray; +using ::testing::Matcher; class SpaceToBatchNDOpModel : public SingleOpModel { public: @@ -30,6 +31,10 @@ class SpaceToBatchNDOpModel : public SingleOpModel { PopulateTensor(input_, data); } + void SetQuantizedInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + void SetBlockShape(std::initializer_list data) { PopulateTensor(block_shape_, data); } @@ -41,6 +46,11 @@ class SpaceToBatchNDOpModel : public SingleOpModel { std::vector GetOutput() { return ExtractVector(output_); } std::vector GetOutputShape() { return GetTensorShape(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } + protected: int input_; int block_shape_; @@ -56,18 +66,19 @@ class SpaceToBatchNDOpModel : public SingleOpModel { // m.Invoke(); class SpaceToBatchNDOpConstModel : public SpaceToBatchNDOpModel { public: - SpaceToBatchNDOpConstModel(std::initializer_list input_shape, + SpaceToBatchNDOpConstModel(const TensorData& input, std::initializer_list block_shape, - std::initializer_list paddings) { - input_ = AddInput(TensorType_FLOAT32); + std::initializer_list paddings, + const TensorData& output) { + input_ = AddInput(input); block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2}); paddings_ = AddConstInput(TensorType_INT32, paddings, {2, 2}); - output_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(output); SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND, BuiltinOptions_SpaceToBatchNDOptions, CreateSpaceToBatchNDOptions(builder_).Union()); - BuildInterpreter({input_shape}); + BuildInterpreter({input.shape}); } }; @@ -81,26 +92,30 @@ class SpaceToBatchNDOpConstModel : public SpaceToBatchNDOpModel { // m.Invoke(); class SpaceToBatchNDOpDynamicModel : public SpaceToBatchNDOpModel { public: - SpaceToBatchNDOpDynamicModel(std::initializer_list input_shape) { - input_ = AddInput(TensorType_FLOAT32); + SpaceToBatchNDOpDynamicModel(const TensorData& input, + const TensorData& output) { + input_ = AddInput(input); block_shape_ = AddInput(TensorType_INT32); paddings_ = AddInput(TensorType_INT32); - output_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(output); SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND, BuiltinOptions_SpaceToBatchNDOptions, CreateSpaceToBatchNDOptions(builder_).Union()); - BuildInterpreter({input_shape, {2}, {2, 2}}); + BuildInterpreter({input.shape, {2}, {2, 2}}); } }; TEST(SpaceToBatchNDOpTest, InvalidShapeTest) { - EXPECT_DEATH(SpaceToBatchNDOpConstModel({1, 3, 3, 1}, {2, 2}, {0, 0, 0, 0}), - "Cannot allocate tensors"); + EXPECT_DEATH( + SpaceToBatchNDOpConstModel({TensorType_FLOAT32, {1, 3, 3, 1}}, {2, 2}, + {0, 0, 0, 0}, {TensorType_FLOAT32}), + "Cannot allocate tensors"); } TEST(SpaceToBatchNDOpTest, SimpleConstTest) { - SpaceToBatchNDOpConstModel m({1, 4, 4, 1}, {2, 2}, {0, 0, 0, 0}); + SpaceToBatchNDOpConstModel m({TensorType_FLOAT32, {1, 4, 4, 1}}, {2, 2}, + {0, 0, 0, 0}, {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 2, 1})); @@ -109,7 +124,8 @@ TEST(SpaceToBatchNDOpTest, SimpleConstTest) { } TEST(SpaceToBatchNDOpTest, SimpleDynamicTest) { - SpaceToBatchNDOpDynamicModel m({1, 4, 4, 1}); + SpaceToBatchNDOpDynamicModel m({TensorType_FLOAT32, {1, 4, 4, 1}}, + {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.SetBlockShape({2, 2}); m.SetPaddings({0, 0, 0, 0}); @@ -120,7 +136,8 @@ TEST(SpaceToBatchNDOpTest, SimpleDynamicTest) { } TEST(SpaceToBatchNDOpTest, MultipleInputBatchesConstTest) { - SpaceToBatchNDOpConstModel m({2, 2, 4, 1}, {2, 2}, {0, 0, 0, 0}); + SpaceToBatchNDOpConstModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, {2, 2}, + {0, 0, 0, 0}, {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({8, 1, 2, 1})); @@ -129,7 +146,8 @@ TEST(SpaceToBatchNDOpTest, MultipleInputBatchesConstTest) { } TEST(SpaceToBatchNDOpTest, MultipleInputBatchesDynamicTest) { - SpaceToBatchNDOpDynamicModel m({2, 2, 4, 1}); + SpaceToBatchNDOpDynamicModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, + {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.SetBlockShape({2, 2}); m.SetPaddings({0, 0, 0, 0}); @@ -140,7 +158,8 @@ TEST(SpaceToBatchNDOpTest, MultipleInputBatchesDynamicTest) { } TEST(SpaceToBatchNDOpTest, SimplePaddingConstTest) { - SpaceToBatchNDOpConstModel m({1, 5, 2, 1}, {3, 2}, {1, 0, 2, 0}); + SpaceToBatchNDOpConstModel m({TensorType_FLOAT32, {1, 5, 2, 1}}, {3, 2}, + {1, 0, 2, 0}, {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1})); @@ -151,7 +170,8 @@ TEST(SpaceToBatchNDOpTest, SimplePaddingConstTest) { } TEST(SpaceToBatchNDOpTest, SimplePaddingDynamicTest) { - SpaceToBatchNDOpDynamicModel m({1, 5, 2, 1}); + SpaceToBatchNDOpDynamicModel m({TensorType_FLOAT32, {1, 5, 2, 1}}, + {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); m.SetBlockShape({3, 2}); m.SetPaddings({1, 0, 2, 0}); @@ -164,7 +184,8 @@ TEST(SpaceToBatchNDOpTest, SimplePaddingDynamicTest) { } TEST(SpaceToBatchNDOpTest, ComplexPaddingConstTest) { - SpaceToBatchNDOpConstModel m({1, 4, 2, 1}, {3, 2}, {1, 1, 2, 4}); + SpaceToBatchNDOpConstModel m({TensorType_FLOAT32, {1, 4, 2, 1}}, {3, 2}, + {1, 1, 2, 4}, {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1})); @@ -176,7 +197,8 @@ TEST(SpaceToBatchNDOpTest, ComplexPaddingConstTest) { } TEST(SpaceToBatchNDOpTest, ComplexPaddingDynamicTest) { - SpaceToBatchNDOpDynamicModel m({1, 4, 2, 1}); + SpaceToBatchNDOpDynamicModel m({TensorType_FLOAT32, {1, 4, 2, 1}}, + {TensorType_FLOAT32}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); m.SetBlockShape({3, 2}); m.SetPaddings({1, 1, 2, 4}); @@ -189,6 +211,88 @@ TEST(SpaceToBatchNDOpTest, ComplexPaddingDynamicTest) { })); } +class QuantizedSpaceToBatchNDOpTest : public ::testing::Test { + protected: + std::vector> DequantizedArrayNear( + const std::vector& values, const float min, const float max) { + const float quantization_tolerance = (max - min) / 255.0; + return ArrayFloatNear(values, quantization_tolerance); + } +}; + +TEST_F(QuantizedSpaceToBatchNDOpTest, ZeroNotInQuantizationRange) { + // The test_util and actual quantization code currently ensure that the range + // must include zero, but if that ever changes, this test will catch it. + EXPECT_DEATH(SpaceToBatchNDOpConstModel m( + {TensorType_UINT8, {1, 2, 2, 1}, 1.0, 2.0}, {4, 2}, + {0, 0, 1, 1, 1, 1, 0, 0}, {TensorType_UINT8, {}, 1.0, 2.0}), + ".*Check failed: f_min <= 0.*"); +} + +TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingConstTest) { + SpaceToBatchNDOpConstModel m({TensorType_UINT8, {1, 5, 2, 1}, -1.0, 1.0}, + {3, 2}, {1, 0, 2, 0}, + {TensorType_UINT8, {}, -1.0, 1.0}); + m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + {0, 0, 0, -0.5, 0, 0, 0, 0.6, 0, -0.1, 0, -0.7, + 0, 0.2, 0, 0.8, 0, -0.3, 0, -0.9, 0, 0.4, 0, 0.1}, + -1.0, 1.0))); +} + +TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingDynamicTest) { + SpaceToBatchNDOpDynamicModel m({TensorType_UINT8, {1, 5, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}); + m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1}); + m.SetBlockShape({3, 2}); + m.SetPaddings({1, 0, 2, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + {0, 0, 0, -0.5, 0, 0, 0, 0.6, 0, -0.1, 0, -0.7, + 0, 0.2, 0, 0.8, 0, -0.3, 0, -0.9, 0, 0.4, 0, 0.1}, + -1.0, 1.0))); +} + +TEST_F(QuantizedSpaceToBatchNDOpTest, ComplexPaddingConstTest) { + SpaceToBatchNDOpConstModel m({TensorType_UINT8, {1, 4, 2, 1}, -1.0, 1.0}, + {3, 2}, {1, 1, 2, 4}, + {TensorType_UINT8, {}, -1.0, 1.0}); + m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + { + 0, 0, 0, 0, 0, -0.5, 0, 0, 0, 0, 0, 0, 0, 0.6, 0, 0, + 0, -0.1, 0, 0, 0, -0.7, 0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0, + 0, -0.3, 0, 0, 0, 0, 0, 0, 0, 0.4, 0, 0, 0, 0, 0, 0, + }, + -1.0, 1.0))); +} + +TEST_F(QuantizedSpaceToBatchNDOpTest, ComplexPaddingDynamicTest) { + SpaceToBatchNDOpDynamicModel m({TensorType_UINT8, {1, 4, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}); + m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8}); + m.SetBlockShape({3, 2}); + m.SetPaddings({1, 1, 2, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(DequantizedArrayNear( + { + 0, 0, 0, 0, 0, -0.5, 0, 0, 0, 0, 0, 0, 0, 0.6, 0, 0, + 0, -0.1, 0, 0, 0, -0.7, 0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0, + 0, -0.3, 0, 0, 0, 0, 0, 0, 0, 0.4, 0, 0, 0, 0, 0, 0, + }, + -1.0, 1.0))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc index 404c32ad9ca8b9f1e467b747708ccb451f2a5118..fec2a6f0d97ae48e0c49d82c726278a46d96a7fc 100644 --- a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc +++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -188,7 +187,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return ResizeOutputShape(context, output_shape, output); } -template +template TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor); const TfLiteTensor* output_shape = @@ -205,10 +204,10 @@ TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) { const int num_indices = SizeOfDimension(indices, 0); const bool value_is_scalar = NumDimensions(values) == 0; - std::vector> indices_vector; + std::vector> indices_vector; indices_vector.reserve(num_indices); - TF_LITE_ENSURE_OK(context, GetIndicesVector(context, indices, num_indices, - &indices_vector)); + TF_LITE_ENSURE_OK(context, GetIndicesVector(context, indices, num_indices, + &indices_vector)); reference_ops::SparseToDense(indices_vector, GetTensorData(values), *GetTensorData(default_value), GetTensorData(output), GetTensorDims(output), diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc index 1247525d416e8166a9e2e1d67c7907c00b0f6723..77a1f596898bb7fa99a7509a25229c627d762bdd 100644 --- a/tensorflow/contrib/lite/kernels/sub.cc +++ b/tensorflow/contrib/lite/kernels/sub.cc @@ -78,29 +78,47 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } template -void EvalFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteSubParams* params, const OpData* data, - const TfLiteTensor* input1, const TfLiteTensor* input2, - TfLiteTensor* output) { - float output_activation_min, output_activation_max; - CalculateActivationRange(params->activation, &output_activation_min, - &output_activation_max); -#define TF_LITE_SUB(type, opname) \ - type::opname(GetTensorData(input1), GetTensorDims(input1), \ - GetTensorData(input2), GetTensorDims(input2), \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)) - if (kernel_type == kReference) { - if (data->requires_broadcast) { - TF_LITE_SUB(reference_ops, BroadcastSub); +void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params, + const OpData* data, const TfLiteTensor* input1, + const TfLiteTensor* input2, TfLiteTensor* output) { +#define TF_LITE_SUB(type, opname, data_type) \ + data_type output_activation_min, output_activation_max; \ + CalculateActivationRange(params->activation, &output_activation_min, \ + &output_activation_max); \ + tflite::ArithmeticParams op_params; \ + SetActivationParams(output_activation_min, output_activation_max, \ + &op_params); \ + type::opname(op_params, GetTensorShape(input1), \ + GetTensorData(input1), GetTensorShape(input2), \ + GetTensorData(input2), GetTensorShape(output), \ + GetTensorData(output)) + if (output->type == kTfLiteInt32) { + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_SUB(reference_ops, BroadcastSub4DSlow, int32_t); + } else { + TF_LITE_SUB(reference_ops, SubWithActivation, int32_t); + } } else { - TF_LITE_SUB(reference_ops, Sub); + if (data->requires_broadcast) { + TF_LITE_SUB(optimized_ops, BroadcastSub4DSlow, int32_t); + } else { + TF_LITE_SUB(optimized_ops, SubWithActivation, int32_t); + } } - } else { - if (data->requires_broadcast) { - TF_LITE_SUB(optimized_ops, BroadcastSub); + } else if (output->type == kTfLiteFloat32) { + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_SUB(reference_ops, BroadcastSub4DSlow, float); + } else { + TF_LITE_SUB(reference_ops, SubWithActivation, float); + } } else { - TF_LITE_SUB(optimized_ops, Sub); + if (data->requires_broadcast) { + TF_LITE_SUB(optimized_ops, BroadcastSub4DSlow, float); + } else { + TF_LITE_SUB(optimized_ops, SubWithActivation, float); + } } } #undef TF_LITE_SUB @@ -128,36 +146,43 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, int input1_shift; QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, &input1_multiplier, &input1_shift); - input1_shift *= -1; int32 input2_multiplier; int input2_shift; QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, &input2_multiplier, &input2_shift); - input2_shift *= -1; int32 output_multiplier; int output_shift; QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, &output_multiplier, &output_shift); - output_shift *= -1; int32 output_activation_min, output_activation_max; CalculateActivationRangeUint8(params->activation, output, &output_activation_min, &output_activation_max); -#define TF_LITE_SUB(type, opname) \ - type::opname(left_shift, GetTensorData(input1), \ - GetTensorDims(input1), input1_offset, input1_multiplier, \ - input1_shift, GetTensorData(input2), \ - GetTensorDims(input2), input2_offset, input2_multiplier, \ - input2_shift, output_offset, output_multiplier, output_shift, \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)); +#define TF_LITE_SUB(type, opname) \ + tflite::ArithmeticParams op_params; \ + op_params.left_shift = left_shift; \ + op_params.input1_offset = input1_offset; \ + op_params.input1_multiplier = input1_multiplier; \ + op_params.input1_shift = input1_shift; \ + op_params.input2_offset = input2_offset; \ + op_params.input2_multiplier = input2_multiplier; \ + op_params.input2_shift = input2_shift; \ + op_params.output_offset = output_offset; \ + op_params.output_multiplier = output_multiplier; \ + op_params.output_shift = output_shift; \ + SetActivationParams(output_activation_min, output_activation_max, \ + &op_params); \ + type::opname(op_params, GetTensorShape(input1), \ + GetTensorData(input1), GetTensorShape(input2), \ + GetTensorData(input2), GetTensorShape(output), \ + GetTensorData(output)) // The quantized version of Sub doesn't support activations, so we // always use BroadcastSub. if (kernel_type == kReference) { - TF_LITE_SUB(reference_ops, BroadcastSub); + TF_LITE_SUB(reference_ops, BroadcastSub4DSlow); } else { - TF_LITE_SUB(optimized_ops, BroadcastSub); + TF_LITE_SUB(optimized_ops, BroadcastSub4DSlow); } #undef TF_LITE_SUB } @@ -171,14 +196,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - if (output->type == kTfLiteFloat32) { - EvalFloat(context, node, params, data, input1, input2, output); + if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) { + EvalSub(context, node, params, data, input1, input2, output); } else if (output->type == kTfLiteUInt8) { EvalQuantized(context, node, params, data, input1, input2, output); } else { context->ReportError( - context, "output type %d is not supported, requires float|uint8 types.", + context, + "output type %d is not supported, requires float|uint8|int32 types.", output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/sub_test.cc b/tensorflow/contrib/lite/kernels/sub_test.cc index ff07aeec49dbfcc0e1f65df3d674d5ec30f1b54c..5978c574d35492eda6b903fd83d95ecbd6b62148 100644 --- a/tensorflow/contrib/lite/kernels/sub_test.cc +++ b/tensorflow/contrib/lite/kernels/sub_test.cc @@ -52,6 +52,13 @@ class FloatSubOpModel : public BaseSubOpModel { std::vector GetOutput() { return ExtractVector(output_); } }; +class IntegerSubOpModel : public BaseSubOpModel { + public: + using BaseSubOpModel::BaseSubOpModel; + + std::vector GetOutput() { return ExtractVector(output_); } +}; + class QuantizedSubOpModel : public BaseSubOpModel { public: using BaseSubOpModel::BaseSubOpModel; @@ -125,6 +132,57 @@ TEST(FloatSubOpModel, WithBroadcast) { } } +TEST(IntegerSubOpModel, NoActivation) { + IntegerSubOpModel m({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}}, + ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8}); + m.PopulateTensor(m.input2(), {1, 2, 3, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-21, 0, 4, 3})); +} + +TEST(IntegerSubOpModel, ActivationRELU_N1_TO_1) { + IntegerSubOpModel m({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}}, + ActivationFunctionType_RELU_N1_TO_1); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8}); + m.PopulateTensor(m.input2(), {1, 2, 3, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 0, 1, 1})); +} + +TEST(IntegerSubOpModel, VariousInputShapes) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + IntegerSubOpModel m({TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8, 11, 20}); + m.PopulateTensor(m.input2(), {1, 2, 3, 5, 11, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-21, 0, 4, 3, 0, 19})) + << "With shape number " << i; + } +} + +TEST(IntegerSubOpModel, WithBroadcast) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + IntegerSubOpModel m({TensorType_INT32, test_shapes[i]}, + {TensorType_INT32, {}}, // always a scalar + {TensorType_INT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-20, 2, 7, 8, 11, 20}); + m.PopulateTensor(m.input2(), {1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-21, 1, 6, 7, 10, 19}))) + << "With shape number " << i; + } +} + TEST(QuantizedSubOpModel, QuantizedTestsNoActivation) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); std::vector> inputs1 = { diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc index 43ac3a2ce86df6dc9a0dd914851174aaf33b25be..6d4912ce3aa40bf95dc1e26572b8a07fb6362744 100644 --- a/tensorflow/contrib/lite/kernels/svdf.cc +++ b/tensorflow/contrib/lite/kernels/svdf.cc @@ -16,7 +16,6 @@ limitations under the License. // SVDF op that compresses a fully connected op via low-rank matrix // factorization. See https://research.google.com/pubs/archive/43813.pdf for // details. -#include #include #include #include @@ -105,7 +104,7 @@ constexpr int kStateTensor = 0; constexpr int kOutputTensor = 1; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* op_data = new OpData; + auto* op_data = new OpData(); op_data->float_weights_time_initialized = false; context->AddTensors(context, /*tensors_to_add=*/4, &op_data->scratch_tensor_index); @@ -382,11 +381,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // the Eval function. // TODO(alanchiao): refactor logic out into dequantize function. if (!op_data->float_weights_time_initialized) { - const float inv_scale = 1.0 / weights_time->params.scale; + const float dequantization_scale = weights_time->params.scale; const int8_t* weights_time_ptr = reinterpret_cast(weights_time->data.uint8); for (int i = 0; i < NumElements(float_weights_time); ++i) { - float_weights_time->data.f[i] = weights_time_ptr[i] * inv_scale; + float_weights_time->data.f[i] = + weights_time_ptr[i] * dequantization_scale; } op_data->float_weights_time_initialized = true; } diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc index 06df509d32dacc25fbcf84606b5218697c831e96..5af3ff85004ce43c5b75c6f12761f121c0d8deca 100644 --- a/tensorflow/contrib/lite/kernels/svdf_test.cc +++ b/tensorflow/contrib/lite/kernels/svdf_test.cc @@ -342,7 +342,7 @@ TEST_F(SVDFOpTest, BlackBoxTestHybridRank1) { svdf.ResetState(); VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input), &svdf, - /*tolerance=*/0.00294435); + /*tolerance=*/0.002945); } TEST_F(SVDFOpTest, BlackBoxTestHybridRank2) { diff --git a/tensorflow/contrib/lite/kernels/tile.cc b/tensorflow/contrib/lite/kernels/tile.cc index af77f074742eb3fef10a74616ff679255911fbb2..5181a8f89a376302bad02913e3c7c1d094821da8 100644 --- a/tensorflow/contrib/lite/kernels/tile.cc +++ b/tensorflow/contrib/lite/kernels/tile.cc @@ -87,8 +87,9 @@ std::pair TileOneDimension(const TfLiteIntArray& in_dimensions, if (dimension == in_dimensions.size - 1) { CopyMultipleTimes(in_data, dimension_size, multipliers[dimension], out_data); - return std::make_pair(dimension_size, - dimension_size * multipliers[dimension]); + return std::make_pair( + dimension_size, + dimension_size * static_cast(multipliers[dimension])); } int total_stride_size = 0, total_tiled_stride_size = 0; const T* copy_from_data = in_data; diff --git a/tensorflow/contrib/lite/kernels/topk_v2.cc b/tensorflow/contrib/lite/kernels/topk_v2.cc index fb0e49c90c41747f9b7e53570276c8b8045030fd..2dd760bbfebd1faa8b7ff9158bc1a1b1d4647525 100644 --- a/tensorflow/contrib/lite/kernels/topk_v2.cc +++ b/tensorflow/contrib/lite/kernels/topk_v2.cc @@ -56,11 +56,13 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) { output_values_shape->data[num_dimensions - 1] = k; TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes); TfLiteTensor* output_values = GetOutput(context, node, kOutputValues); + // Force output types. + output_indexes->type = kTfLiteInt32; + output_values->type = input->type; auto resize_tensor = [context](TfLiteTensor* tensor, TfLiteIntArray* new_size, TfLiteIntArray* delete_on_error) { TfLiteStatus status = context->ResizeTensor(context, tensor, new_size); if (status != kTfLiteOk) { - TfLiteIntArrayFree(new_size); if (delete_on_error != nullptr) { TfLiteIntArrayFree(delete_on_error); } diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc index 8b9deeed20d761876d526c07eb78b602ca7314dc..a9baa5c6988877ccc2e007e5fefdc980d7a3a679 100644 --- a/tensorflow/contrib/lite/kernels/transpose_conv.cc +++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc index 1c28123a24edd9886476bf8e9ea3ba4c692baa2b..0acd705950cb262bbb2625aa6143f88b429a6562 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -70,9 +69,21 @@ constexpr int kOutputStateTensor = 0; constexpr int kCellStateTensor = 1; constexpr int kOutputTensor = 2; +// Temporary tensors +enum TemporaryTensor { + kScratchBuffer = 0, + kInputQuantized = 1, + kOutputStateQuantized = 2, + kCellStateQuantized = 3, + kScalingFactors = 4, + kProductScalingFactors = 5, + kRecoveredCellWeights = 6, + kNumTemporaryTensors = 7 +}; + void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* scratch_tensor_index = new int; - context->AddTensors(context, 1, scratch_tensor_index); + context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index); return scratch_tensor_index; } @@ -84,7 +95,7 @@ void Free(TfLiteContext* context, void* buffer) { TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TfLiteNode* node, int n_input, int n_output, int n_cell) { - auto* params = reinterpret_cast(node->builtin_data); + const auto* params = reinterpret_cast(node->builtin_data); // Making sure clipping parameters have valid values. // == 0 means no clipping @@ -242,6 +253,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Inferring batch size, number of outputs and sequence length and // number of cells from the input tensors. const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TF_LITE_ENSURE(context, input->dims->size > 1); const int max_time = input->dims->data[0]; const int n_batch = input->dims->data[1]; @@ -261,7 +273,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int n_output = recurrent_to_output_weights->dims->data[1]; // Check that input tensor dimensions matches with each other. - CheckInputTensorDimensions(context, node, n_input, n_output, n_cell); + TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input, + n_output, n_cell)); // Get the pointer to output, output_state and cell_state buffer tensors. TfLiteTensor* output = GetOutput(context, node, kOutputTensor); @@ -288,86 +301,156 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, cell_state, cell_size)); - // Create a scratch buffer tensor. + // Mark state tensors as persistent tensors. + output_state->allocation_type = kTfLiteArenaRwPersistent; + cell_state->allocation_type = kTfLiteArenaRwPersistent; + + // The weights are of consistent type, so it suffices to check one. + // TODO(mirkov): create a utility/macro for this check, so all Ops can use it. + const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 && + input->type == kTfLiteFloat32); + TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(1); + if (is_hybrid_op) { + node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors); + } else { + node->temporaries = TfLiteIntArrayCreate(1); + } node->temporaries->data[0] = *scratch_tensor_index; - TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); + + // Create a scratch buffer tensor. + TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer); scratch_buffer->type = input->type; scratch_buffer->allocation_type = kTfLiteArenaRw; - // Mark state tensors as persistent tensors. - output_state->allocation_type = kTfLiteArenaRwPersistent; - cell_state->allocation_type = kTfLiteArenaRwPersistent; - const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); const bool use_cifg = (input_to_input_weights == nullptr); + TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); + scratch_buffer_size->data[0] = n_batch; if (use_cifg) { - TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); - scratch_buffer_size->data[0] = n_batch; // Reserving space for Cell, Forget, Output gates scratch_buffer_size->data[1] = n_cell * 3; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, - scratch_buffer_size)); } else { - TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); - scratch_buffer_size->data[0] = n_batch; // Reserving space for Input, Cell, Forget, Output gates scratch_buffer_size->data[1] = n_cell * 4; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, - scratch_buffer_size)); + } + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, + scratch_buffer_size)); + + if (is_hybrid_op) { + // Allocate temporary tensors to store quantized values of input, + // output_state and cell_state tensors. + node->temporaries->data[kInputQuantized] = + *scratch_tensor_index + kInputQuantized; + TfLiteTensor* input_quantized = + GetTemporary(context, node, kInputQuantized); + input_quantized->type = kTfLiteUInt8; + input_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { + TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, + input_quantized_size)); + } + node->temporaries->data[kOutputStateQuantized] = + *scratch_tensor_index + kOutputStateQuantized; + TfLiteTensor* output_state_quantized = + GetTemporary(context, node, kOutputStateQuantized); + output_state_quantized->type = kTfLiteUInt8; + output_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(output_state_quantized->dims, + output_state->dims)) { + TfLiteIntArray* output_state_quantized_size = + TfLiteIntArrayCopy(output_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output_state_quantized, + output_state_quantized_size)); + } + node->temporaries->data[kCellStateQuantized] = + *scratch_tensor_index + kCellStateQuantized; + TfLiteTensor* cell_state_quantized = + GetTemporary(context, node, kCellStateQuantized); + cell_state_quantized->type = kTfLiteUInt8; + cell_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) { + TfLiteIntArray* cell_state_quantized_size = + TfLiteIntArrayCopy(cell_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, cell_state_quantized, + cell_state_quantized_size)); + } + + // Allocate temporary tensors to store scaling factors and product scaling + // factors. The latter is a convenience storage which allows to quantize + // a vector once (which produces the scaling factors) and multiply it with + // different matrices (which requires multiplying the scaling factors with + // the scaling factor of the matrix). + node->temporaries->data[kScalingFactors] = + *scratch_tensor_index + kScalingFactors; + TfLiteTensor* scaling_factors = + GetTemporary(context, node, kScalingFactors); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } + node->temporaries->data[kProductScalingFactors] = + *scratch_tensor_index + kProductScalingFactors; + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, kProductScalingFactors); + prod_scaling_factors->type = kTfLiteFloat32; + prod_scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1); + prod_scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(prod_scaling_factors->dims, + prod_scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, prod_scaling_factors, + prod_scaling_factors_size)); + } + + // Allocate a temporary tensor to store the recovered cell weights. Since + // this is used for diagonal matrices, only need to store n_cell values. + node->temporaries->data[kRecoveredCellWeights] = + *scratch_tensor_index + kRecoveredCellWeights; + TfLiteTensor* recovered_cell_weights = + GetTemporary(context, node, kRecoveredCellWeights); + recovered_cell_weights->type = kTfLiteFloat32; + recovered_cell_weights->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1); + recovered_cell_weights_size->data[0] = n_cell; + if (!TfLiteIntArrayEqual(recovered_cell_weights->dims, + recovered_cell_weights_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, recovered_cell_weights, + recovered_cell_weights_size)); + } } return kTfLiteOk; } // The LSTM Op engine. -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - - const TfLiteTensor* input_to_input_weights = - GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); - const TfLiteTensor* input_to_forget_weights = - GetInput(context, node, kInputToForgetWeightsTensor); - const TfLiteTensor* input_to_cell_weights = - GetInput(context, node, kInputToCellWeightsTensor); - const TfLiteTensor* input_to_output_weights = - GetInput(context, node, kInputToOutputWeightsTensor); - - const TfLiteTensor* recurrent_to_input_weights = - GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); - const TfLiteTensor* recurrent_to_forget_weights = - GetInput(context, node, kRecurrentToForgetWeightsTensor); - const TfLiteTensor* recurrent_to_cell_weights = - GetInput(context, node, kRecurrentToCellWeightsTensor); - const TfLiteTensor* recurrent_to_output_weights = - GetInput(context, node, kRecurrentToOutputWeightsTensor); - - const TfLiteTensor* cell_to_input_weights = - GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); - const TfLiteTensor* cell_to_forget_weights = - GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); - const TfLiteTensor* cell_to_output_weights = - GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); - - const TfLiteTensor* input_gate_bias = - GetOptionalInputTensor(context, node, kInputGateBiasTensor); - const TfLiteTensor* forget_gate_bias = - GetInput(context, node, kForgetGateBiasTensor); - const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); - const TfLiteTensor* output_gate_bias = - GetInput(context, node, kOutputGateBiasTensor); - - const TfLiteTensor* projection_weights = - GetOptionalInputTensor(context, node, kProjectionWeightsTensor); - const TfLiteTensor* projection_bias = - GetOptionalInputTensor(context, node, kProjectionBiasTensor); - - TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); - TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - +TfLiteStatus EvalFloat( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, + TfLiteTensor* output_state, TfLiteTensor* cell_state, + TfLiteTensor* output) { const int max_time = input->dims->data[0]; const int n_batch = input->dims->data[1]; const int n_input = input->dims->data[2]; @@ -380,8 +463,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const bool use_cifg = (input_to_input_weights == nullptr); const bool use_peephole = (cell_to_output_weights != nullptr); - // Index the scratch buffers pointers to the global scratch buffer. - TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); float* input_gate_scratch = nullptr; float* cell_scratch = nullptr; float* forget_gate_scratch = nullptr; @@ -432,6 +513,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { float* output_state_ptr = output_state->data.f; float* cell_state_ptr = cell_state->data.f; + // Feed the sequence into the LSTM step-by-step. for (int t = 0; t < max_time; t++) { const float* input_ptr_batch = input->data.f + t * n_batch * n_input; float* output_ptr_batch = output->data.f + t * n_batch * n_output; @@ -452,6 +534,262 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus EvalHybrid( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, + TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors, + TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized, + TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized, + TfLiteTensor* output_state, TfLiteTensor* cell_state, + TfLiteTensor* output) { + const int max_time = input->dims->data[0]; + const int n_batch = input->dims->data[1]; + const int n_input = input->dims->data[2]; + // n_cell and n_output will be the same size when there is no projection. + const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existence of only one to get the condition. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + + float* input_gate_scratch = nullptr; + float* cell_scratch = nullptr; + float* forget_gate_scratch = nullptr; + float* output_gate_scratch = nullptr; + if (use_cifg) { + cell_scratch = scratch_buffer->data.f; + forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + } else { + input_gate_scratch = scratch_buffer->data.f; + cell_scratch = scratch_buffer->data.f + n_cell * n_batch; + forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; + } + + // Check optional tensors, the respective pointers can be null. + int8_t* input_to_input_weights_ptr = nullptr; + float input_to_input_weights_scale = 1.0f; + int8_t* recurrent_to_input_weights_ptr = nullptr; + float recurrent_to_input_weights_scale = 1.0f; + float* input_gate_bias_ptr = nullptr; + if (!use_cifg) { + input_to_input_weights_ptr = + reinterpret_cast(input_to_input_weights->data.uint8); + recurrent_to_input_weights_ptr = + reinterpret_cast(recurrent_to_input_weights->data.uint8); + input_gate_bias_ptr = input_gate_bias->data.f; + input_to_input_weights_scale = input_to_input_weights->params.scale; + recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale; + } + + int8_t* cell_to_input_weights_ptr = nullptr; + int8_t* cell_to_forget_weights_ptr = nullptr; + int8_t* cell_to_output_weights_ptr = nullptr; + float cell_to_input_weights_scale = 1.0f; + float cell_to_forget_weights_scale = 1.0f; + float cell_to_output_weights_scale = 1.0f; + if (use_peephole) { + if (!use_cifg) { + cell_to_input_weights_ptr = + reinterpret_cast(cell_to_input_weights->data.uint8); + cell_to_input_weights_scale = cell_to_input_weights->params.scale; + } + cell_to_forget_weights_ptr = + reinterpret_cast(cell_to_forget_weights->data.uint8); + cell_to_output_weights_ptr = + reinterpret_cast(cell_to_output_weights->data.uint8); + cell_to_forget_weights_scale = cell_to_forget_weights->params.scale; + cell_to_output_weights_scale = cell_to_output_weights->params.scale; + } + + const int8_t* projection_weights_ptr = + (projection_weights == nullptr) + ? nullptr + : reinterpret_cast(projection_weights->data.uint8); + float projection_weights_scale = + (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale; + const float* projection_bias_ptr = + (projection_bias == nullptr) ? nullptr : projection_bias->data.f; + + // Required tensors, pointers are non-null. + const int8_t* input_to_forget_weights_ptr = + reinterpret_cast(input_to_forget_weights->data.uint8); + const float input_to_forget_weights_scale = + input_to_forget_weights->params.scale; + const int8_t* input_to_cell_weights_ptr = + reinterpret_cast(input_to_cell_weights->data.uint8); + const float input_to_cell_weights_scale = input_to_cell_weights->params.scale; + const int8_t* input_to_output_weights_ptr = + reinterpret_cast(input_to_output_weights->data.uint8); + const float input_to_output_weights_scale = + input_to_output_weights->params.scale; + const int8_t* recurrent_to_forget_weights_ptr = + reinterpret_cast(recurrent_to_forget_weights->data.uint8); + const float recurrent_to_forget_weights_scale = + recurrent_to_forget_weights->params.scale; + const int8_t* recurrent_to_cell_weights_ptr = + reinterpret_cast(recurrent_to_cell_weights->data.uint8); + const float recurrent_to_cell_weights_scale = + recurrent_to_cell_weights->params.scale; + const int8_t* recurrent_to_output_weights_ptr = + reinterpret_cast(recurrent_to_output_weights->data.uint8); + const float recurrent_to_output_weights_scale = + recurrent_to_output_weights->params.scale; + const float* forget_gate_bias_ptr = forget_gate_bias->data.f; + const float* cell_bias_ptr = cell_bias->data.f; + const float* output_gate_bias_ptr = output_gate_bias->data.f; + + float* output_state_ptr = output_state->data.f; + float* cell_state_ptr = cell_state->data.f; + + // Temporary storage for quantized values and scaling factors. + int8_t* quantized_input_ptr = + reinterpret_cast(input_quantized->data.uint8); + int8_t* quantized_output_state_ptr = + reinterpret_cast(output_state_quantized->data.uint8); + int8_t* quantized_cell_state_ptr = + reinterpret_cast(cell_state_quantized->data.uint8); + float* scaling_factors_ptr = scaling_factors->data.f; + float* prod_scaling_factors_ptr = prod_scaling_factors->data.f; + float* recovered_cell_weights_ptr = recovered_cell_weights->data.f; + + // Feed the sequence into the LSTM step-by-step. + for (int t = 0; t < max_time; t++) { + const float* input_ptr_batch = input->data.f + t * n_batch * n_input; + float* output_ptr_batch = output->data.f + t * n_batch * n_output; + + kernel_utils::LstmStep( + input_ptr_batch, input_to_input_weights_ptr, + input_to_input_weights_scale, input_to_forget_weights_ptr, + input_to_forget_weights_scale, input_to_cell_weights_ptr, + input_to_cell_weights_scale, input_to_output_weights_ptr, + input_to_output_weights_scale, recurrent_to_input_weights_ptr, + recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr, + recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr, + recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr, + recurrent_to_output_weights_scale, cell_to_input_weights_ptr, + cell_to_input_weights_scale, cell_to_forget_weights_ptr, + cell_to_forget_weights_scale, cell_to_output_weights_ptr, + cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr, + cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, + projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell, + n_input, n_output, input_gate_scratch, forget_gate_scratch, + cell_scratch, output_gate_scratch, scaling_factors_ptr, + prod_scaling_factors_ptr, recovered_cell_weights_ptr, + quantized_input_ptr, quantized_output_state_ptr, + quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, + output_ptr_batch); + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + + const TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + const TfLiteTensor* input_to_forget_weights = + GetInput(context, node, kInputToForgetWeightsTensor); + const TfLiteTensor* input_to_cell_weights = + GetInput(context, node, kInputToCellWeightsTensor); + const TfLiteTensor* input_to_output_weights = + GetInput(context, node, kInputToOutputWeightsTensor); + + const TfLiteTensor* recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); + const TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, kRecurrentToForgetWeightsTensor); + const TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, kRecurrentToCellWeightsTensor); + const TfLiteTensor* recurrent_to_output_weights = + GetInput(context, node, kRecurrentToOutputWeightsTensor); + + const TfLiteTensor* cell_to_input_weights = + GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); + const TfLiteTensor* cell_to_forget_weights = + GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); + const TfLiteTensor* cell_to_output_weights = + GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); + + const TfLiteTensor* input_gate_bias = + GetOptionalInputTensor(context, node, kInputGateBiasTensor); + const TfLiteTensor* forget_gate_bias = + GetInput(context, node, kForgetGateBiasTensor); + const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + const TfLiteTensor* output_gate_bias = + GetInput(context, node, kOutputGateBiasTensor); + + const TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + const TfLiteTensor* projection_bias = + GetOptionalInputTensor(context, node, kProjectionBiasTensor); + + // Index the scratch buffers pointers to the global scratch buffer. + TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); + + TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); + TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + switch (input_to_output_weights->type) { + case kTfLiteFloat32: { + return EvalFloat(input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, + cell_to_output_weights, input_gate_bias, + forget_gate_bias, cell_bias, output_gate_bias, + projection_weights, projection_bias, params, + scratch_buffer, output_state, cell_state, output); + } + case kTfLiteUInt8: { + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + TfLiteTensor* output_state_quantized = + GetTemporary(context, node, /*index=*/2); + TfLiteTensor* cell_state_quantized = + GetTemporary(context, node, /*index=*/3); + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4); + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, /*index=*/5); + TfLiteTensor* recovered_cell_weights = + GetTemporary(context, node, /*index=*/6); + return EvalHybrid( + input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, + input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, + projection_weights, projection_bias, params, scratch_buffer, + scaling_factors, prod_scaling_factors, recovered_cell_weights, + input_quantized, output_state_quantized, cell_state_quantized, + output_state, cell_state, output); + } + default: + context->ReportError(context, "Type %d is not currently supported.", + input_to_output_weights->type); + return kTfLiteError; + } + return kTfLiteOk; +} } // namespace unidirectional_sequence_lstm TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() { diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc index 5881ced7c7a616ef2c24db60892cbbf9eec7c42e..de38bdef6fd1b019c7790a664b29cd45d29e5dcc 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ // Unit test for TFLite Sequential LSTM op. -#include #include #include @@ -37,7 +36,8 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { bool use_peephole, bool use_projection_weights, bool use_projection_bias, float cell_clip, float proj_clip, - const std::vector>& input_shapes) + const std::vector>& input_shapes, + const TensorType& weights_type = TensorType_FLOAT32) : n_batch_(n_batch), n_input_(n_input), n_cell_(n_cell), @@ -48,31 +48,31 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { if (use_cifg) { input_to_input_weights_ = AddNullInput(); } else { - input_to_input_weights_ = AddInput(TensorType_FLOAT32); + input_to_input_weights_ = AddInput(weights_type); } - input_to_forget_weights_ = AddInput(TensorType_FLOAT32); - input_to_cell_weights_ = AddInput(TensorType_FLOAT32); - input_to_output_weights_ = AddInput(TensorType_FLOAT32); + input_to_forget_weights_ = AddInput(weights_type); + input_to_cell_weights_ = AddInput(weights_type); + input_to_output_weights_ = AddInput(weights_type); if (use_cifg) { recurrent_to_input_weights_ = AddNullInput(); } else { - recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_input_weights_ = AddInput(weights_type); } - recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32); - recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32); - recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_forget_weights_ = AddInput(weights_type); + recurrent_to_cell_weights_ = AddInput(weights_type); + recurrent_to_output_weights_ = AddInput(weights_type); if (use_peephole) { if (use_cifg) { cell_to_input_weights_ = AddNullInput(); } else { - cell_to_input_weights_ = AddInput(TensorType_FLOAT32); + cell_to_input_weights_ = AddInput(weights_type); } - cell_to_forget_weights_ = AddInput(TensorType_FLOAT32); - cell_to_output_weights_ = AddInput(TensorType_FLOAT32); + cell_to_forget_weights_ = AddInput(weights_type); + cell_to_output_weights_ = AddInput(weights_type); } else { cell_to_input_weights_ = AddNullInput(); cell_to_forget_weights_ = AddNullInput(); @@ -89,7 +89,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { output_gate_bias_ = AddInput(TensorType_FLOAT32); if (use_projection_weights) { - projection_weights_ = AddInput(TensorType_FLOAT32); + projection_weights_ = AddInput(weights_type); if (use_projection_bias) { projection_bias_ = AddInput(TensorType_FLOAT32); } else { @@ -196,8 +196,9 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { zero_buffer.get() + zero_buffer_size); } - void SetInput(int offset, float* begin, float* end) { - PopulateTensor(input_, offset, begin, end); + void SetInput(int offset, const float* begin, const float* end) { + PopulateTensor(input_, offset, const_cast(begin), + const_cast(end)); } std::vector GetOutput() { return ExtractVector(output_); } @@ -208,7 +209,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { int num_batches() { return n_batch_; } int sequence_length() { return sequence_length_; } - private: + protected: int input_; int input_to_input_weights_; int input_to_forget_weights_; @@ -243,7 +244,183 @@ class UnidirectionalLSTMOpModel : public SingleOpModel { int sequence_length_; }; -TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { +// The hybrid model has quantized weights. +class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel { + public: + HybridUnidirectionalLSTMOpModel( + int n_batch, int n_input, int n_cell, int n_output, int sequence_length, + bool use_cifg, bool use_peephole, bool use_projection_weights, + bool use_projection_bias, float cell_clip, float proj_clip, + const std::vector>& input_shapes) + : UnidirectionalLSTMOpModel( + n_batch, n_input, n_cell, n_output, sequence_length, use_cifg, + use_peephole, use_projection_weights, use_projection_bias, + cell_clip, proj_clip, input_shapes, TensorType_UINT8) {} + + void SetInputToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_output_weights_, f); + } + + void SetProjectionWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(projection_weights_, f); + } +}; + +class BaseLstmTest : public ::testing::Test { + protected: + // Weights of the LSTM model. Some are optional. + std::initializer_list input_to_input_weights_; + std::initializer_list input_to_cell_weights_; + std::initializer_list input_to_forget_weights_; + std::initializer_list input_to_output_weights_; + std::initializer_list input_gate_bias_; + std::initializer_list cell_gate_bias_; + std::initializer_list forget_gate_bias_; + std::initializer_list output_gate_bias_; + std::initializer_list recurrent_to_input_weights_; + std::initializer_list recurrent_to_cell_weights_; + std::initializer_list recurrent_to_forget_weights_; + std::initializer_list recurrent_to_output_weights_; + std::initializer_list cell_to_input_weights_; + std::initializer_list cell_to_forget_weights_; + std::initializer_list cell_to_output_weights_; + std::initializer_list projection_weights_; + + // LSTM input is stored as num_batch x num_inputs vector. + std::vector> lstm_input_; + // LSTM output is stored as num_batch x num_outputs vector. + std::vector> lstm_golden_output_; + + // Compares output up to tolerance to the result of the lstm given the input. + void VerifyGoldens(const std::vector>& input, + const std::vector>& output, + UnidirectionalLSTMOpModel* lstm, float tolerance = 1e-5) { + const int num_batches = input.size(); + EXPECT_GT(num_batches, 0); + const int num_inputs = lstm->num_inputs(); + EXPECT_GT(num_inputs, 0); + const int input_sequence_size = input[0].size() / num_inputs; + EXPECT_GT(input_sequence_size, 0); + // Feed the whole sequence as input. + for (int i = 0; i < input_sequence_size; ++i) { + for (int b = 0; b < num_batches; ++b) { + const float* batch_start = input[b].data() + i * num_inputs; + const float* batch_end = batch_start + num_inputs; + + lstm->SetInput(((i * num_batches) + b) * lstm->num_inputs(), + batch_start, batch_end); + } + } + + lstm->Invoke(); + + const int num_outputs = lstm->num_outputs(); + EXPECT_GT(num_outputs, 0); + std::vector expected; + for (int i = 0; i < input_sequence_size; ++i) { + for (int b = 0; b < num_batches; ++b) { + const float* golden_start_batch = output[b].data() + i * num_outputs; + const float* golden_end_batch = golden_start_batch + num_outputs; + + expected.insert(expected.end(), golden_start_batch, golden_end_batch); + } + } + + EXPECT_THAT(lstm->GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + } +}; + +class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, + -0.34856534, 0.43890524}; + input_to_cell_weights_ = {-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, -0.29909778}; + input_to_forget_weights_ = {0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}; + input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, + -0.1556896, 0.19487578}; + input_gate_bias_ = {0., 0., 0., 0.}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_input_weights_ = { + -0.0063535, -0.2042388, 0.31454784, -0.35746509, + 0.28902304, 0.08183324, -0.16555229, 0.02286911, + -0.13566875, 0.03034258, 0.48091322, -0.12528998, + 0.24077177, -0.51332325, -0.33502164, 0.10629296}; + + recurrent_to_cell_weights_ = { + -0.3407414, 0.24443203, -0.2078532, 0.26320225, + 0.05695659, -0.00123841, -0.4744786, -0.35869038, + -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}; + + recurrent_to_forget_weights_ = { + -0.48684245, -0.06655136, 0.42224967, 0.2112639, + 0.27654213, 0.20864892, -0.07646349, 0.45877004, + 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}; + + recurrent_to_output_weights_ = { + 0.43385774, -0.17194885, 0.2718237, 0.09215671, + 0.24107647, -0.39835793, 0.18212086, 0.01301402, + 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765, + -0.03716109, 0.12507336, 0.41193449, -0.20860538, + -0.15053082, 0.09120187, 0.24278517, -0.12222792}}; + } +}; + +TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -252,9 +429,11 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { const int sequence_length = 3; UnidirectionalLSTMOpModel lstm( - n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, - /*use_peephole=*/false, /*use_projection_weights=*/false, - /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + n_batch, n_input, n_cell, n_output, sequence_length, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, { {sequence_length, n_batch, n_input}, // input tensor @@ -281,77 +460,138 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, - -0.34550029, 0.04266912, -0.15680569, - -0.34856534, 0.43890524}); + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, - -0.20583314, 0.44344562, 0.22077113, - -0.29909778}); + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, - -0.31343272, -0.40032279, 0.44781327, - 0.01387155, -0.35593212}); + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} - lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, - 0.40525138, 0.44272184, 0.03897077, -0.1556896, - 0.19487578}); +TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + const int sequence_length = 3; - lstm.SetInputGateBias({0., 0., 0., 0.}); + HybridUnidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor - lstm.SetCellBias({0., 0., 0., 0.}); + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor - lstm.SetForgetGateBias({1., 1., 1., 1.}); + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor - lstm.SetOutputGateBias({0., 0., 0., 0.}); + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor - lstm.SetRecurrentToInputWeights( - {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, - -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, - -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor - lstm.SetRecurrentToCellWeights( - {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, - -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, - -0.46367589, 0.26016325, -0.03894562, -0.16368064}); + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); - lstm.SetRecurrentToForgetWeights( - {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, - -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, - 0.28053468, 0.01560611, -0.20127171, -0.01140004}); + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetRecurrentToOutputWeights( - {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, - 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, - -0.51818722, -0.15390486, 0.0468148, 0.39922136}); + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - // Input should have n_input * sequence_length many values. - static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; - static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126, - -0.15358765, -0.03716109, 0.12507336, - 0.41193449, -0.20860538, -0.15053082, - 0.09120187, 0.24278517, -0.12222792}; + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - float* batch0_start = lstm_input; - float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, + /*tolerance=*/0.0157651); +} - lstm.SetInput(0, batch0_start, batch0_end); +class CifgPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726, + 0.05100781, 0.04717243, 0.48944736, + -0.38535351, -0.17212132}; - lstm.Invoke(); + input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, + 0.24407166, 0.33826375}; - float* golden_start = lstm_golden_output; - float* golden_end = - golden_start + lstm.num_outputs() * lstm.sequence_length(); - std::vector expected; - expected.insert(expected.end(), golden_start, golden_end); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); -} + input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_cell_weights_ = { + 0.54066205, -0.32668582, -0.43562764, -0.56094903, + 0.42957711, 0.01841056, -0.32764608, -0.33027974, + -0.10826075, 0.20675004, 0.19069612, -0.03026325, + -0.54532051, 0.33003211, 0.44901288, 0.21193194}; + + recurrent_to_forget_weights_ = { + -0.13832897, -0.0515101, -0.2359007, -0.16661474, + -0.14340827, 0.36986142, 0.23414481, 0.55899, + 0.10798943, -0.41174671, 0.17751795, -0.34484994, + -0.35874045, -0.11352962, 0.27268326, 0.54058349}; + + recurrent_to_output_weights_ = { + 0.41613156, 0.42610586, -0.16495961, -0.5663873, + 0.30579174, -0.05115908, -0.33941799, 0.23364776, + 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}; + + cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408, + 0.31544167}; + cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703, + -0.77109635}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646, + -0.42312205, -0.01218222, 0.24201041, -0.08124574, + -0.358325, -0.04621704, 0.21641694, -0.06471302}}; + } +}; -TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { +TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -360,9 +600,11 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { const int sequence_length = 3; UnidirectionalLSTMOpModel lstm( - n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true, - /*use_peephole=*/true, /*use_projection_weights=*/false, - /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + n_batch, n_input, n_cell, n_output, sequence_length, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, { {sequence_length, n_batch, n_input}, // input tensor @@ -389,71 +631,690 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, - 0.04717243, 0.48944736, -0.38535351, - -0.17212132}); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, - -0.3633365, -0.22755712, 0.28253698, 0.24407166, - 0.33826375}); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, - -0.09426838, -0.44257352, 0.54939759, - 0.01533556, 0.42751634}); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} + +TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + const int sequence_length = 3; + + HybridUnidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor - lstm.SetCellBias({0., 0., 0., 0.}); + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor - lstm.SetForgetGateBias({1., 1., 1., 1.}); + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor - lstm.SetOutputGateBias({0., 0., 0., 0.}); + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor - lstm.SetRecurrentToCellWeights( - {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, - 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, - 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, - 0.21193194}); + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); - lstm.SetRecurrentToForgetWeights( - {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, - 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, - -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetRecurrentToOutputWeights( - {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, - -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, - 0.50248802, 0.26114327, -0.43736315, 0.33149987}); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - lstm.SetCellToForgetWeights( - {0.47485286, -0.51955009, -0.24458408, 0.31544167}); - lstm.SetCellToOutputWeights( - {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; - static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585, - -0.05163646, -0.42312205, -0.01218222, - 0.24201041, -0.08124574, -0.358325, - -0.04621704, 0.21641694, -0.06471302}; + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - float* batch0_start = lstm_input; - float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); - - lstm.SetInput(0, batch0_start, batch0_end); - - lstm.Invoke(); - - float* golden_start = lstm_golden_output; - float* golden_end = - golden_start + lstm.num_outputs() * lstm.sequence_length(); - std::vector expected; - expected.insert(expected.end(), golden_start, golden_end); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); } -TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { +class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = { + 0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, + 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, + -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, + -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, + -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, + -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, + -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, + 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, + 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, + 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, + -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, + 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, + -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, + -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, + -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, + 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, + -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, + -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, + -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, + -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}; + + input_to_forget_weights_ = { + -0.0018401089, -0.004852237, 0.03698424, 0.014181704, + 0.028273236, -0.016726194, -0.05249759, -0.10204261, + 0.00861066, -0.040979505, -0.009899187, 0.01923892, + -0.028177269, -0.08535103, -0.14585495, 0.10662567, + -0.01909731, -0.017883534, -0.0047269356, -0.045103323, + 0.0030784295, 0.076784775, 0.07463696, 0.094531395, + 0.0814421, -0.12257899, -0.033945758, -0.031303465, + 0.045630626, 0.06843887, -0.13492945, -0.012480007, + -0.0811829, -0.07224499, -0.09628791, 0.045100946, + 0.0012300825, 0.013964662, 0.099372394, 0.02543059, + 0.06958324, 0.034257296, 0.0482646, 0.06267997, + 0.052625068, 0.12784666, 0.07077897, 0.025725935, + 0.04165009, 0.07241905, 0.018668644, -0.037377294, + -0.06277783, -0.08833636, -0.040120605, -0.011405586, + -0.007808335, -0.010301386, -0.005102167, 0.027717464, + 0.05483423, 0.11449111, 0.11289652, 0.10939839, + 0.13396506, -0.08402166, -0.01901462, -0.044678304, + -0.07720565, 0.014350063, -0.11757958, -0.0652038, + -0.08185733, -0.076754324, -0.092614375, 0.10405491, + 0.052960336, 0.035755895, 0.035839386, -0.012540553, + 0.036881298, 0.02913376, 0.03420159, 0.05448447, + -0.054523353, 0.02582715, 0.02327355, -0.011857179, + -0.0011980024, -0.034641717, -0.026125094, -0.17582615, + -0.15923657, -0.27486774, -0.0006143371, 0.0001771948, + -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}; + + input_to_cell_weights_ = { + -0.04580283, -0.09549462, -0.032418985, -0.06454633, + -0.043528453, 0.043018587, -0.049152344, -0.12418144, + -0.078985475, -0.07596889, 0.019484362, -0.11434962, + -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, + -0.025034338, -0.0028890965, 0.048929527, 0.06235075, + 0.10665918, -0.032036792, -0.08505916, -0.10843358, + -0.13002433, -0.036816437, -0.02130134, -0.016518239, + 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, + -0.10652836, -0.1037554, -0.13056071, -0.03266643, + -0.033702414, -0.006473424, -0.04611692, 0.014419339, + -0.025174323, 0.0396852, 0.081777506, 0.06157468, + 0.10210095, -0.009658194, 0.046511717, 0.03603906, + 0.0069369148, 0.015960095, -0.06507666, 0.09551598, + 0.053568836, 0.06408714, 0.12835667, -0.008714329, + -0.20211966, -0.12093674, 0.029450472, 0.2849013, + -0.029227901, 0.1164364, -0.08560263, 0.09941786, + -0.036999565, -0.028842626, -0.0033637602, -0.017012902, + -0.09720865, -0.11193351, -0.029155117, -0.017936034, + -0.009768936, -0.04223324, -0.036159635, 0.06505112, + -0.021742892, -0.023377212, -0.07221364, -0.06430552, + 0.05453865, 0.091149814, 0.06387331, 0.007518393, + 0.055960953, 0.069779344, 0.046411168, 0.10509911, + 0.07463894, 0.0075130584, 0.012850982, 0.04555431, + 0.056955688, 0.06555285, 0.050801456, -0.009862683, + 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}; + + input_to_output_weights_ = { + -0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, + -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, + 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, + -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, + -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, + 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, + -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, + -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, + -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, + -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, + 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, + 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, + 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, + -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, + 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, + 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, + -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, + 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, + -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, + -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}; + + input_gate_bias_ = {0.02234832, 0.14757581, 0.18176508, 0.10380666, + 0.053110216, -0.06928846, -0.13942584, -0.11816189, + 0.19483899, 0.03652339, -0.10250295, 0.036714908, + -0.18426876, 0.036065217, 0.21810818, 0.02383196, + -0.043370757, 0.08690144, -0.04444982, 0.00030581196}; + + forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696, + 0.11098921, 0.15378423, 0.09263801, 0.09790885, + 0.09508917, 0.061199076, 0.07665568, -0.015443159, + -0.03499149, 0.046190713, 0.08895977, 0.10899629, + 0.40694186, 0.06030037, 0.012413437, -0.06108739}; + + cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132, 0.033463873, + -0.1483596, -0.10639995, -0.091433935, 0.058573797, + -0.06809782, -0.07889636, -0.043246906, -0.09829136, + -0.4279842, 0.034901652, 0.18797937, 0.0075234566, + 0.016178843, 0.1749513, 0.13975595, 0.92058027}; + + output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469, 0.12648113, + 0.027195795, 0.35373217, -0.018957434, 0.008907322, + -0.0762701, 0.12018895, 0.04216877, 0.0022856654, + 0.040952638, 0.3147856, 0.08225149, -0.057416286, + -0.14995944, -0.008040261, 0.13208859, 0.029760877}; + + recurrent_to_input_weights_ = { + -0.001374326, -0.078856036, 0.10672688, 0.029162422, + -0.11585556, 0.02557986, -0.13446963, -0.035785314, + -0.01244275, 0.025961924, -0.02337298, -0.044228926, + -0.055839065, -0.046598054, -0.010546039, -0.06900766, + 0.027239809, 0.022582639, -0.013296484, -0.05459212, + 0.08981, -0.045407712, 0.08682226, -0.06867011, + -0.14390695, -0.02916037, 0.000996957, 0.091420636, + 0.14283475, -0.07390571, -0.06402044, 0.062524505, + -0.093129106, 0.04860203, -0.08364217, -0.08119002, + 0.009352075, 0.22920375, 0.0016303885, 0.11583097, + -0.13732095, 0.012405723, -0.07551853, 0.06343048, + 0.12162708, -0.031923793, -0.014335606, 0.01790974, + -0.10650317, -0.0724401, 0.08554849, -0.05727212, + 0.06556731, -0.042729504, -0.043227166, 0.011683251, + -0.013082158, -0.029302018, -0.010899579, -0.062036745, + -0.022509435, -0.00964907, -0.01567329, 0.04260106, + -0.07787477, -0.11576462, 0.017356863, 0.048673786, + -0.017577527, -0.05527947, -0.082487635, -0.040137455, + -0.10820036, -0.04666372, 0.022746278, -0.07851417, + 0.01068115, 0.032956902, 0.022433773, 0.0026891115, + 0.08944216, -0.0685835, 0.010513544, 0.07228705, + 0.02032331, -0.059686817, -0.0005566496, -0.086984694, + 0.040414046, -0.1380399, 0.094208956, -0.05722982, + 0.012092817, -0.04989123, -0.086576, -0.003399834, + -0.04696032, -0.045747425, 0.10091314, 0.048676282, + -0.029037097, 0.031399418, -0.0040285117, 0.047237843, + 0.09504992, 0.041799378, -0.049185462, -0.031518843, + -0.10516937, 0.026374253, 0.10058866, -0.0033195973, + -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, + -0.10167381, 0.042500053, -0.01447153, 0.06464186, + -0.017142897, 0.03312627, 0.009205989, 0.024138335, + -0.011337001, 0.035530265, -0.010912711, 0.0706555, + -0.005894094, 0.051841937, -0.1401738, -0.02351249, + 0.0365468, 0.07590991, 0.08838724, 0.021681072, + -0.10086113, 0.019608743, -0.06195883, 0.077335775, + 0.023646897, -0.095322326, 0.02233014, 0.09756986, + -0.048691444, -0.009579111, 0.07595467, 0.11480546, + -0.09801813, 0.019894179, 0.08502348, 0.004032281, + 0.037211012, 0.068537936, -0.048005626, -0.091520436, + -0.028379958, -0.01556313, 0.06554592, -0.045599163, + -0.01672207, -0.020169014, -0.011877351, -0.20212261, + 0.010889619, 0.0047078193, 0.038385306, 0.08540671, + -0.017140968, -0.0035865551, 0.016678626, 0.005633034, + 0.015963363, 0.00871737, 0.060130805, 0.028611384, + 0.10109069, -0.015060172, -0.07894427, 0.06401885, + 0.011584063, -0.024466386, 0.0047652307, -0.09041358, + 0.030737216, -0.0046374933, 0.14215417, -0.11823516, + 0.019899689, 0.006106124, -0.027092824, 0.0786356, + 0.05052217, -0.058925, -0.011402121, -0.024987547, + -0.0013661642, -0.06832946, -0.015667673, -0.1083353, + -0.00096863037, -0.06988685, -0.053350925, -0.027275559, + -0.033664223, -0.07978348, -0.025200296, -0.017207067, + -0.058403496, -0.055697463, 0.005798788, 0.12965427, + -0.062582195, 0.0013350133, -0.10482091, 0.0379771, + 0.072521195, -0.0029455067, -0.13797039, -0.03628521, + 0.013806405, -0.017858358, -0.01008298, -0.07700066, + -0.017081132, 0.019358726, 0.0027079724, 0.004635139, + 0.062634714, -0.02338735, -0.039547626, -0.02050681, + 0.03385117, -0.083611414, 0.002862572, -0.09421313, + 0.058618143, -0.08598433, 0.00972939, 0.023867095, + -0.053934585, -0.023203006, 0.07452513, -0.048767887, + -0.07314807, -0.056307215, -0.10433547, -0.06440842, + 0.04328182, 0.04389765, -0.020006588, -0.09076438, + -0.11652589, -0.021705797, 0.03345259, -0.010329105, + -0.025767034, 0.013057034, -0.07316461, -0.10145612, + 0.06358255, 0.18531723, 0.07759293, 0.12006465, + 0.1305557, 0.058638252, -0.03393652, 0.09622831, + -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, + -0.005644518, 0.06857898, -0.12598175, -0.035084512, + 0.03156317, -0.12794146, -0.031963028, 0.04692781, + 0.030070418, 0.0071660685, -0.095516115, -0.004643372, + 0.040170413, -0.062104587, -0.0037324072, 0.0554317, + 0.08184801, -0.019164372, 0.06791302, 0.034257166, + -0.10307039, 0.021943003, 0.046745934, 0.0790918, + -0.0265588, -0.007824208, 0.042546265, -0.00977924, + -0.0002440307, -0.017384544, -0.017990116, 0.12252321, + -0.014512694, -0.08251313, 0.08861942, 0.13589665, + 0.026351685, 0.012641483, 0.07466548, 0.044301085, + -0.045414884, -0.051112458, 0.03444247, -0.08502782, + -0.04106223, -0.028126027, 0.028473156, 0.10467447}; + + recurrent_to_cell_weights_ = { + -0.037322544, 0.018592842, 0.0056175636, -0.06253426, + 0.055647098, -0.05713207, -0.05626563, 0.005559383, + 0.03375411, -0.025757805, -0.088049285, 0.06017052, + -0.06570978, 0.007384076, 0.035123326, -0.07920549, + 0.053676967, 0.044480428, -0.07663568, 0.0071805613, + 0.08089997, 0.05143358, 0.038261272, 0.03339287, + -0.027673481, 0.044746667, 0.028349208, 0.020090483, + -0.019443132, -0.030755889, -0.0040000007, 0.04465846, + -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, + -0.10893326, 0.076739706, -0.08509834, -0.027997585, + 0.037871376, 0.01449768, -0.09002357, -0.06111149, + -0.046195522, 0.0422062, -0.005683705, -0.1253618, + -0.012925729, -0.04890792, 0.06985068, 0.037654128, + 0.03398274, -0.004781977, 0.007032333, -0.031787455, + 0.010868644, -0.031489216, 0.09525667, 0.013939797, + 0.0058680447, 0.0167067, 0.02668468, -0.04797466, + -0.048885044, -0.12722108, 0.035304096, 0.06554885, + 0.00972396, -0.039238118, -0.05159735, -0.11329045, + 0.1613692, -0.03750952, 0.06529313, -0.071974665, + -0.11769596, 0.015524369, -0.0013754242, -0.12446318, + 0.02786344, -0.014179351, 0.005264273, 0.14376344, + 0.015983658, 0.03406988, -0.06939408, 0.040699873, + 0.02111075, 0.09669095, 0.041345075, -0.08316494, + -0.07684199, -0.045768797, 0.032298047, -0.041805092, + 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, + -0.024950314, 0.11574242, 0.04508852, -0.04335324, + 0.06760663, -0.027437469, 0.07216407, 0.06977076, + -0.05438599, 0.034033038, -0.028602652, 0.05346137, + 0.043184172, -0.037189785, 0.10420091, 0.00882477, + -0.054019816, -0.074273005, -0.030617684, -0.0028467078, + 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, + 0.04361412, -0.007001822, 0.09631092, -0.06702025, + -0.042049985, -0.035070654, -0.04103342, -0.10273396, + 0.0544271, 0.037184782, -0.13150354, -0.0058036847, + -0.008264958, 0.042035464, 0.05891794, 0.029673764, + 0.0063542654, 0.044788733, 0.054816857, 0.062257513, + -0.00093483756, 0.048938446, -0.004952862, -0.007730018, + -0.04043371, -0.017094059, 0.07229206, -0.023670016, + -0.052195564, -0.025616996, -0.01520939, 0.045104615, + -0.007376126, 0.003533447, 0.006570588, 0.056037236, + 0.12436656, 0.051817212, 0.028532185, -0.08686856, + 0.11868599, 0.07663395, -0.07323171, 0.03463402, + -0.050708205, -0.04458982, -0.11590894, 0.021273347, + 0.1251325, -0.15313013, -0.12224372, 0.17228661, + 0.023029093, 0.086124025, 0.006445803, -0.03496501, + 0.028332196, 0.04449512, -0.042436164, -0.026587414, + -0.006041347, -0.09292539, -0.05678812, 0.03897832, + 0.09465633, 0.008115513, -0.02171956, 0.08304309, + 0.071401566, 0.019622514, 0.032163795, -0.004167056, + 0.02295182, 0.030739572, 0.056506045, 0.004612461, + 0.06524936, 0.059999723, 0.046395954, -0.0045512207, + -0.1335546, -0.030136576, 0.11584653, -0.014678886, + 0.0020118146, -0.09688814, -0.0790206, 0.039770417, + -0.0329582, 0.07922767, 0.029322514, 0.026405897, + 0.04207835, -0.07073373, 0.063781224, 0.0859677, + -0.10925287, -0.07011058, 0.048005477, 0.03438226, + -0.09606514, -0.006669445, -0.043381985, 0.04240257, + -0.06955775, -0.06769346, 0.043903265, -0.026784198, + -0.017840602, 0.024307009, -0.040079936, -0.019946516, + 0.045318738, -0.12233574, 0.026170589, 0.0074471775, + 0.15978073, 0.10185836, 0.10298046, -0.015476589, + -0.039390966, -0.072174534, 0.0739445, -0.1211869, + -0.0347889, -0.07943156, 0.014809798, -0.12412325, + -0.0030663363, 0.039695457, 0.0647603, -0.08291318, + -0.018529687, -0.004423833, 0.0037507233, 0.084633216, + -0.01514876, -0.056505352, -0.012800942, -0.06994386, + 0.012962922, -0.031234352, 0.07029052, 0.016418684, + 0.03618972, 0.055686004, -0.08663945, -0.017404709, + -0.054761406, 0.029065743, 0.052404847, 0.020238016, + 0.0048197987, -0.0214882, 0.07078733, 0.013016777, + 0.06262858, 0.009184685, 0.020785125, -0.043904778, + -0.0270329, -0.03299152, -0.060088247, -0.015162964, + -0.001828936, 0.12642565, -0.056757294, 0.013586685, + 0.09232601, -0.035886683, 0.06000002, 0.05229691, + -0.052580316, -0.082029596, -0.010794592, 0.012947712, + -0.036429964, -0.085508935, -0.13127148, -0.017744139, + 0.031502828, 0.036232427, -0.031581745, 0.023051167, + -0.05325106, -0.03421577, 0.028793324, -0.034633752, + -0.009881397, -0.043551125, -0.018609839, 0.0019097115, + -0.008799762, 0.056595087, 0.0022273948, 0.055752404}; + + recurrent_to_forget_weights_ = { + -0.057784554, -0.026057621, -0.068447545, -0.022581743, + 0.14811787, 0.10826372, 0.09471067, 0.03987225, + -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, + 0.08414449, -0.022036452, -0.00066928595, -0.09203576, + 0.032950465, -0.10985798, -0.023809856, 0.0021431844, + -0.02196096, -0.00326074, 0.00058621005, -0.074678116, + -0.06193199, 0.055729095, 0.03736828, 0.020123724, + 0.061878487, -0.04729229, 0.034919553, -0.07585433, + -0.04421272, -0.044019096, 0.085488975, 0.04058006, + -0.06890133, -0.030951202, -0.024628663, -0.07672815, + 0.034293607, 0.08556707, -0.05293577, -0.033561368, + -0.04899627, 0.0241671, 0.015736353, -0.095442444, + -0.029564252, 0.016493602, -0.035026584, 0.022337519, + -0.026871363, 0.004780428, 0.0077918363, -0.03601621, + 0.016435321, -0.03263031, -0.09543275, -0.047392778, + 0.013454138, 0.028934088, 0.01685226, -0.086110644, + -0.046250615, -0.01847454, 0.047608484, 0.07339695, + 0.034546845, -0.04881143, 0.009128804, -0.08802852, + 0.03761666, 0.008096139, -0.014454086, 0.014361001, + -0.023502491, -0.0011840804, -0.07607001, 0.001856849, + -0.06509276, -0.006021153, -0.08570962, -0.1451793, + 0.060212336, 0.055259194, 0.06974018, 0.049454916, + -0.027794661, -0.08077226, -0.016179763, 0.1169753, + 0.17213494, -0.0056326236, -0.053934924, -0.0124349, + -0.11520337, 0.05409887, 0.088759385, 0.0019655675, + 0.0042065294, 0.03881498, 0.019844765, 0.041858196, + -0.05695512, 0.047233116, 0.038937137, -0.06542224, + 0.014429736, -0.09719407, 0.13908425, -0.05379757, + 0.012321099, 0.082840554, -0.029899208, 0.044217527, + 0.059855383, 0.07711018, -0.045319796, 0.0948846, + -0.011724666, -0.0033288454, -0.033542685, -0.04764985, + -0.13873616, 0.040668588, 0.034832682, -0.015319203, + -0.018715994, 0.046002675, 0.0599172, -0.043107376, + 0.0294216, -0.002314414, -0.022424703, 0.0030315618, + 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, + 0.12375372, -0.0006038222, 0.029104086, 0.087442465, + 0.052958444, 0.07558703, 0.04817258, 0.044462286, + -0.015213451, -0.08783778, -0.0561384, -0.003008196, + 0.047060397, -0.002058388, 0.03429439, -0.018839769, + 0.024734668, 0.024614193, -0.042046934, 0.09597743, + -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, + -0.02558259, -0.022822596, -0.023273505, -0.02464396, + -0.10991725, -0.006240552, 0.0074488563, 0.024044557, + 0.04383914, -0.046476185, 0.028658995, 0.060410924, + 0.050786525, 0.009452605, -0.0073054377, -0.024810238, + 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, + 0.015898481, 0.021362653, -0.030262267, 0.016587038, + -0.011442813, 0.041154444, -0.007631438, -0.03423484, + -0.010977775, 0.036152758, 0.0066366293, 0.11915515, + 0.02318443, -0.041350313, 0.021485701, -0.10906167, + -0.028218046, -0.00954771, 0.020531068, -0.11995105, + -0.03672871, 0.024019798, 0.014255957, -0.05221243, + -0.00661567, -0.04630967, 0.033188973, 0.10107534, + -0.014027541, 0.030796422, -0.10270911, -0.035999842, + 0.15443139, 0.07684145, 0.036571592, -0.035900835, + -0.0034699554, 0.06209149, 0.015920248, -0.031122351, + -0.03858649, 0.01849943, 0.13872518, 0.01503974, + 0.069941424, -0.06948533, -0.0088794185, 0.061282158, + -0.047401894, 0.03100163, -0.041533746, -0.10430945, + 0.044574402, -0.01425562, -0.024290353, 0.034563623, + 0.05866852, 0.023947537, -0.09445152, 0.035450947, + 0.02247216, -0.0042998926, 0.061146557, -0.10250651, + 0.020881841, -0.06747029, 0.10062043, -0.0023941975, + 0.03532124, -0.016341697, 0.09685456, -0.016764693, + 0.051808182, 0.05875331, -0.04536488, 0.001626336, + -0.028892258, -0.01048663, -0.009793449, -0.017093895, + 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, + -0.001845119, -0.03551521, 0.0018358806, 0.05763657, + -0.01769146, 0.040995963, 0.02235177, -0.060430344, + 0.11475477, -0.023854522, 0.10071741, 0.0686208, + -0.014250481, 0.034261297, 0.047418304, 0.08562733, + -0.030519066, 0.0060542435, 0.014653856, -0.038836084, + 0.04096551, 0.032249358, -0.08355519, -0.026823482, + 0.056386515, -0.010401743, -0.028396193, 0.08507674, + 0.014410365, 0.020995233, 0.17040324, 0.11511526, + 0.02459721, 0.0066619175, 0.025853224, -0.023133837, + -0.081302024, 0.017264642, -0.009585969, 0.09491168, + -0.051313367, 0.054532815, -0.014298593, 0.10657464, + 0.007076659, 0.10964551, 0.0409152, 0.008275321, + -0.07283536, 0.07937492, 0.04192024, -0.1075027}; + + recurrent_to_output_weights_ = { + 0.025825322, -0.05813119, 0.09495884, -0.045984812, + -0.01255415, -0.0026479573, -0.08196161, -0.054914974, + -0.0046604523, -0.029587349, -0.044576716, -0.07480124, + -0.082868785, 0.023254942, 0.027502948, -0.0039728214, + -0.08683098, -0.08116779, -0.014675607, -0.037924774, + -0.023314456, -0.007401714, -0.09255757, 0.029460307, + -0.08829125, -0.005139627, -0.08989442, -0.0555066, + 0.13596267, -0.025062224, -0.048351806, -0.03850004, + 0.07266485, -0.022414139, 0.05940088, 0.075114764, + 0.09597592, -0.010211725, -0.0049794707, -0.011523867, + -0.025980417, 0.072999895, 0.11091378, -0.081685916, + 0.014416728, 0.043229222, 0.034178585, -0.07530371, + 0.035837382, -0.085607, -0.007721233, -0.03287832, + -0.043848954, -0.06404588, -0.06632928, -0.073643476, + 0.008214239, -0.045984086, 0.039764922, 0.03474462, + 0.060612556, -0.080590084, 0.049127717, 0.04151091, + -0.030063879, 0.008801774, -0.023021035, -0.019558564, + 0.05158114, -0.010947698, -0.011825728, 0.0075720972, + 0.0699727, -0.0039981045, 0.069350146, 0.08799282, + 0.016156472, 0.035502106, 0.11695009, 0.006217345, + 0.13392477, -0.037875112, 0.025745004, 0.08940699, + -0.00924166, 0.0046702605, -0.036598757, -0.08811812, + 0.10522024, -0.032441203, 0.008176899, -0.04454919, + 0.07058152, 0.0067963637, 0.039206743, 0.03259838, + 0.03725492, -0.09515802, 0.013326398, -0.052055415, + -0.025676316, 0.03198509, -0.015951829, -0.058556724, + 0.036879618, 0.043357447, 0.028362012, -0.05908629, + 0.0059240665, -0.04995891, -0.019187413, 0.0276265, + -0.01628143, 0.0025863599, 0.08800015, 0.035250366, + -0.022165963, -0.07328642, -0.009415526, -0.07455109, + 0.11690406, 0.0363299, 0.07411125, 0.042103454, + -0.009660886, 0.019076364, 0.018299393, -0.046004917, + 0.08891175, 0.0431396, -0.026327137, -0.051502608, + 0.08979574, -0.051670972, 0.04940282, -0.07491107, + -0.021240504, 0.022596184, -0.034280192, 0.060163025, + -0.058211457, -0.051837247, -0.01349775, -0.04639988, + -0.035936575, -0.011681591, 0.064818054, 0.0073146066, + -0.021745546, -0.043124277, -0.06471268, -0.07053354, + -0.029321948, -0.05330136, 0.016933719, -0.053782392, + 0.13747959, -0.1361751, -0.11569455, 0.0033329215, + 0.05693899, -0.053219706, 0.063698, 0.07977434, + -0.07924483, 0.06936997, 0.0034815092, -0.007305279, + -0.037325785, -0.07251102, -0.033633437, -0.08677009, + 0.091591336, -0.14165086, 0.021752775, 0.019683983, + 0.0011612234, -0.058154266, 0.049996935, 0.0288841, + -0.0024567875, -0.14345716, 0.010955264, -0.10234828, + 0.1183656, -0.0010731248, -0.023590032, -0.072285876, + -0.0724771, -0.026382286, -0.0014920527, 0.042667855, + 0.0018776858, 0.02986552, 0.009814309, 0.0733756, + 0.12289186, 0.018043943, -0.0458958, 0.049412545, + 0.033632483, 0.05495232, 0.036686596, -0.013781798, + -0.010036754, 0.02576849, -0.08307328, 0.010112348, + 0.042521734, -0.05869831, -0.071689695, 0.03876447, + -0.13275425, -0.0352966, -0.023077697, 0.10285965, + 0.084736146, 0.15568255, -0.00040734606, 0.027835453, + -0.10292561, -0.032401145, 0.10053256, -0.026142767, + -0.08271222, -0.0030240538, -0.016368777, 0.1070414, + 0.042672627, 0.013456989, -0.0437609, -0.022309763, + 0.11576483, 0.04108048, 0.061026827, -0.0190714, + -0.0869359, 0.037901703, 0.0610107, 0.07202949, + 0.01675338, 0.086139716, -0.08795751, -0.014898893, + -0.023771819, -0.01965048, 0.007955471, -0.043740474, + 0.03346837, -0.10549954, 0.090567775, 0.042013682, + -0.03176985, 0.12569028, -0.02421228, -0.029526481, + 0.023851605, 0.031539805, 0.05292009, -0.02344001, + -0.07811758, -0.08834428, 0.10094801, 0.16594367, + -0.06861939, -0.021256343, -0.041093912, -0.06669611, + 0.035498552, 0.021757556, -0.09302526, -0.015403468, + -0.06614931, -0.051798206, -0.013874718, 0.03630673, + 0.010412845, -0.08077351, 0.046185967, 0.0035662893, + 0.03541868, -0.094149634, -0.034814864, 0.003128424, + -0.020674974, -0.03944324, -0.008110165, -0.11113267, + 0.08484226, 0.043586485, 0.040582247, 0.0968012, + -0.065249965, -0.028036479, 0.0050708856, 0.0017462453, + 0.0326779, 0.041296225, 0.09164146, -0.047743853, + -0.015952192, -0.034451712, 0.084197424, -0.05347844, + -0.11768019, 0.085926116, -0.08251791, -0.045081906, + 0.0948852, 0.068401024, 0.024856757, 0.06978981, + -0.057309967, -0.012775832, -0.0032452994, 0.01977615, + -0.041040014, -0.024264973, 0.063464895, 0.05431621, + }; + + cell_to_input_weights_ = { + 0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, + -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, + -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, + 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}; + + cell_to_forget_weights_ = { + -0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, + -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, + -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, + 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}; + + cell_to_output_weights_ = { + 0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, + -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, + -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, + 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}; + + projection_weights_ = { + -0.009802181, 0.09401916, 0.0717386, -0.13895074, + 0.09641832, 0.060420845, 0.08539281, 0.054285463, + 0.061395317, 0.034448683, -0.042991187, 0.019801661, + -0.16840284, -0.015726732, -0.23041931, -0.024478018, + -0.10959692, -0.013875541, 0.18600968, -0.061274476, + 0.0138165, -0.08160894, -0.07661644, 0.032372914, + 0.16169067, 0.22465782, -0.03993472, -0.004017731, + 0.08633481, -0.28869787, 0.08682067, 0.17240396, + 0.014975425, 0.056431185, 0.031037588, 0.16702051, + 0.0077946745, 0.15140012, 0.29405436, 0.120285, + -0.188994, -0.027265169, 0.043389652, -0.022061434, + 0.014777949, -0.20203483, 0.094781205, 0.19100232, + 0.13987629, -0.036132768, -0.06426278, -0.05108664, + 0.13221376, 0.009441198, -0.16715929, 0.15859416, + -0.040437475, 0.050779544, -0.022187516, 0.012166504, + 0.027685808, -0.07675938, -0.0055694645, -0.09444123, + 0.0046453946, 0.050794356, 0.10770313, -0.20790008, + -0.07149004, -0.11425117, 0.008225835, -0.035802525, + 0.14374903, 0.15262283, 0.048710253, 0.1847461, + -0.007487823, 0.11000021, -0.09542012, 0.22619456, + -0.029149994, 0.08527916, 0.009043713, 0.0042746216, + 0.016261552, 0.022461696, 0.12689082, -0.043589946, + -0.12035478, -0.08361797, -0.050666027, -0.1248618, + -0.1275799, -0.071875185, 0.07377272, 0.09944291, + -0.18897448, -0.1593054, -0.06526116, -0.040107165, + -0.004618631, -0.067624845, -0.007576253, 0.10727444, + 0.041546922, -0.20424393, 0.06907816, 0.050412357, + 0.00724631, 0.039827548, 0.12449835, 0.10747581, + 0.13708383, 0.09134148, -0.12617786, -0.06428341, + 0.09956831, 0.1208086, -0.14676677, -0.0727722, + 0.1126304, 0.010139365, 0.015571211, -0.038128063, + 0.022913318, -0.042050496, 0.16842307, -0.060597885, + 0.10531834, -0.06411776, -0.07451711, -0.03410368, + -0.13393489, 0.06534304, 0.003620307, 0.04490757, + 0.05970546, 0.05197996, 0.02839995, 0.10434969, + -0.013699693, -0.028353551, -0.07260381, 0.047201227, + -0.024575593, -0.036445823, 0.07155557, 0.009672501, + -0.02328883, 0.009533515, -0.03606021, -0.07421458, + -0.028082801, -0.2678904, -0.13221288, 0.18419984, + -0.13012612, -0.014588381, -0.035059117, -0.04824723, + 0.07830115, -0.056184657, 0.03277091, 0.025466874, + 0.14494097, -0.12522776, -0.098633975, -0.10766018, + -0.08317623, 0.08594209, 0.07749552, 0.039474737, + 0.1776665, -0.07409566, -0.0477268, 0.29323658, + 0.10801441, 0.1154011, 0.013952499, 0.10739139, + 0.10708251, -0.051456142, 0.0074137426, -0.10430189, + 0.10034707, 0.045594677, 0.0635285, -0.0715442, + -0.089667566, -0.10811871, 0.00026344223, 0.08298446, + -0.009525053, 0.006585689, -0.24567553, -0.09450807, + 0.09648481, 0.026996298, -0.06419476, -0.04752702, + -0.11063944, -0.23441927, -0.17608605, -0.052156363, + 0.067035615, 0.19271925, -0.0032889997, -0.043264326, + 0.09663576, -0.057112187, -0.10100678, 0.0628376, + 0.04447668, 0.017961001, -0.10094388, -0.10190601, + 0.18335468, 0.10494553, -0.052095775, -0.0026118709, + 0.10539724, -0.04383912, -0.042349473, 0.08438151, + -0.1947263, 0.02251204, 0.11216432, -0.10307853, + 0.17351969, -0.039091777, 0.08066188, -0.00561982, + 0.12633002, 0.11335965, -0.0088127935, -0.019777594, + 0.06864014, -0.059751723, 0.016233567, -0.06894641, + -0.28651384, -0.004228674, 0.019708522, -0.16305895, + -0.07468996, -0.0855457, 0.099339016, -0.07580735, + -0.13775392, 0.08434318, 0.08330512, -0.12131499, + 0.031935584, 0.09180414, -0.08876437, -0.08049874, + 0.008753825, 0.03498998, 0.030215185, 0.03907079, + 0.089751154, 0.029194152, -0.03337423, -0.019092513, + 0.04331237, 0.04299654, -0.036394123, -0.12915532, + 0.09793732, 0.07512415, -0.11319543, -0.032502122, + 0.15661901, 0.07671967, -0.005491124, -0.19379048, + -0.218606, 0.21448623, 0.017840758, 0.1416943, + -0.07051762, 0.19488361, 0.02664691, -0.18104725, + -0.09334311, 0.15026465, -0.15493552, -0.057762887, + -0.11604192, -0.262013, -0.01391798, 0.012185008, + 0.11156489, -0.07483202, 0.06693364, -0.26151478, + 0.046425626, 0.036540434, -0.16435726, 0.17338543, + -0.21401681, -0.11385144, -0.08283257, -0.069031075, + 0.030635102, 0.010969227, 0.11109743, 0.010919218, + 0.027526086, 0.13519906, 0.01891392, -0.046839405, + -0.040167913, 0.017953383, -0.09700955, 0.0061885654, + -0.07000971, 0.026893595, -0.038844477, 0.14543656}; + + lstm_input_ = { + {// Batch0: 4 (input_sequence_size) * 5 (n_input) + 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, // step 0 + 0.596268, 0.998386, 0.568695, 0.864524, 0.571277, // step 1 + 0.073204, 0.296072, 0.743333, 0.069199, 0.045348, // step 2 + 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, // step 3 + + {// Batch1: 4 (input_sequence_size) * 5 (n_input) + 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, // step 0 + 0.642421, 0.524260, 0.134799, 0.003639, 0.162482, // step 1 + 0.640394, 0.930399, 0.050782, 0.432485, 0.988078, // step 2 + 0.082922, 0.563329, 0.865614, 0.333232, 0.259916} // step 3 + }; + + lstm_golden_output_ = { + {// Batch0: 4 (input_sequence_size) * 16 (n_output) + -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, + -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, + -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, + 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, + -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, + -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, + 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, + 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, + 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, + 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, + -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, + -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, + 0.0286833, 0.00824207, 0.0264887, 0.0305169}, + {// Batch1: 4 (input_sequence_size) * 16 (n_output) + -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, + -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, + 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, + 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, + -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, + -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, + 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, + 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, + 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, + 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, + -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, + -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, + 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + } +}; + +TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 2; const int n_input = 5; const int n_cell = 20; @@ -461,8 +1322,9 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { const int sequence_length = 4; UnidirectionalLSTMOpModel lstm( - n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, - /*use_peephole=*/true, /*use_projection_weights=*/true, + n_batch, n_input, n_cell, n_output, sequence_length, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, { @@ -491,588 +1353,99 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToInputWeights( - {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, - 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, - -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, - -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, - -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, - -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, - -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, - 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, - 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, - 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, - -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, - 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, - -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, - -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, - -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, - 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, - -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, - -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, - -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, - -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}); - - lstm.SetInputToForgetWeights( - {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236, - -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505, - -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495, - 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323, - 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421, - -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887, - -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791, - 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059, - 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068, - 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905, - 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605, - -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464, - 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506, - -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063, - -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375, - 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553, - 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353, - 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717, - -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371, - 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}); - - lstm.SetInputToCellWeights( - {-0.04580283, -0.09549462, -0.032418985, -0.06454633, - -0.043528453, 0.043018587, -0.049152344, -0.12418144, - -0.078985475, -0.07596889, 0.019484362, -0.11434962, - -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, - -0.025034338, -0.0028890965, 0.048929527, 0.06235075, - 0.10665918, -0.032036792, -0.08505916, -0.10843358, - -0.13002433, -0.036816437, -0.02130134, -0.016518239, - 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, - -0.10652836, -0.1037554, -0.13056071, -0.03266643, - -0.033702414, -0.006473424, -0.04611692, 0.014419339, - -0.025174323, 0.0396852, 0.081777506, 0.06157468, - 0.10210095, -0.009658194, 0.046511717, 0.03603906, - 0.0069369148, 0.015960095, -0.06507666, 0.09551598, - 0.053568836, 0.06408714, 0.12835667, -0.008714329, - -0.20211966, -0.12093674, 0.029450472, 0.2849013, - -0.029227901, 0.1164364, -0.08560263, 0.09941786, - -0.036999565, -0.028842626, -0.0033637602, -0.017012902, - -0.09720865, -0.11193351, -0.029155117, -0.017936034, - -0.009768936, -0.04223324, -0.036159635, 0.06505112, - -0.021742892, -0.023377212, -0.07221364, -0.06430552, - 0.05453865, 0.091149814, 0.06387331, 0.007518393, - 0.055960953, 0.069779344, 0.046411168, 0.10509911, - 0.07463894, 0.0075130584, 0.012850982, 0.04555431, - 0.056955688, 0.06555285, 0.050801456, -0.009862683, - 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}); - - lstm.SetInputToOutputWeights( - {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, - -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, - 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, - -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, - -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, - 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, - -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, - -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, - -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, - -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, - 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, - 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, - 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, - -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, - 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, - 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, - -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, - 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, - -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, - -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}); - - lstm.SetInputGateBias( - {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216, - -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339, - -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818, - 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196}); - - lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696, - 0.11098921, 0.15378423, 0.09263801, 0.09790885, - 0.09508917, 0.061199076, 0.07665568, -0.015443159, - -0.03499149, 0.046190713, 0.08895977, 0.10899629, - 0.40694186, 0.06030037, 0.012413437, -0.06108739}); - - lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873, - -0.1483596, -0.10639995, -0.091433935, 0.058573797, - -0.06809782, -0.07889636, -0.043246906, -0.09829136, - -0.4279842, 0.034901652, 0.18797937, 0.0075234566, - 0.016178843, 0.1749513, 0.13975595, 0.92058027}); - - lstm.SetOutputGateBias( - {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795, - 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895, - 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149, - -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877}); - - lstm.SetRecurrentToInputWeights( - {-0.001374326, -0.078856036, 0.10672688, 0.029162422, - -0.11585556, 0.02557986, -0.13446963, -0.035785314, - -0.01244275, 0.025961924, -0.02337298, -0.044228926, - -0.055839065, -0.046598054, -0.010546039, -0.06900766, - 0.027239809, 0.022582639, -0.013296484, -0.05459212, - 0.08981, -0.045407712, 0.08682226, -0.06867011, - -0.14390695, -0.02916037, 0.000996957, 0.091420636, - 0.14283475, -0.07390571, -0.06402044, 0.062524505, - -0.093129106, 0.04860203, -0.08364217, -0.08119002, - 0.009352075, 0.22920375, 0.0016303885, 0.11583097, - -0.13732095, 0.012405723, -0.07551853, 0.06343048, - 0.12162708, -0.031923793, -0.014335606, 0.01790974, - -0.10650317, -0.0724401, 0.08554849, -0.05727212, - 0.06556731, -0.042729504, -0.043227166, 0.011683251, - -0.013082158, -0.029302018, -0.010899579, -0.062036745, - -0.022509435, -0.00964907, -0.01567329, 0.04260106, - -0.07787477, -0.11576462, 0.017356863, 0.048673786, - -0.017577527, -0.05527947, -0.082487635, -0.040137455, - -0.10820036, -0.04666372, 0.022746278, -0.07851417, - 0.01068115, 0.032956902, 0.022433773, 0.0026891115, - 0.08944216, -0.0685835, 0.010513544, 0.07228705, - 0.02032331, -0.059686817, -0.0005566496, -0.086984694, - 0.040414046, -0.1380399, 0.094208956, -0.05722982, - 0.012092817, -0.04989123, -0.086576, -0.003399834, - -0.04696032, -0.045747425, 0.10091314, 0.048676282, - -0.029037097, 0.031399418, -0.0040285117, 0.047237843, - 0.09504992, 0.041799378, -0.049185462, -0.031518843, - -0.10516937, 0.026374253, 0.10058866, -0.0033195973, - -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, - -0.10167381, 0.042500053, -0.01447153, 0.06464186, - -0.017142897, 0.03312627, 0.009205989, 0.024138335, - -0.011337001, 0.035530265, -0.010912711, 0.0706555, - -0.005894094, 0.051841937, -0.1401738, -0.02351249, - 0.0365468, 0.07590991, 0.08838724, 0.021681072, - -0.10086113, 0.019608743, -0.06195883, 0.077335775, - 0.023646897, -0.095322326, 0.02233014, 0.09756986, - -0.048691444, -0.009579111, 0.07595467, 0.11480546, - -0.09801813, 0.019894179, 0.08502348, 0.004032281, - 0.037211012, 0.068537936, -0.048005626, -0.091520436, - -0.028379958, -0.01556313, 0.06554592, -0.045599163, - -0.01672207, -0.020169014, -0.011877351, -0.20212261, - 0.010889619, 0.0047078193, 0.038385306, 0.08540671, - -0.017140968, -0.0035865551, 0.016678626, 0.005633034, - 0.015963363, 0.00871737, 0.060130805, 0.028611384, - 0.10109069, -0.015060172, -0.07894427, 0.06401885, - 0.011584063, -0.024466386, 0.0047652307, -0.09041358, - 0.030737216, -0.0046374933, 0.14215417, -0.11823516, - 0.019899689, 0.006106124, -0.027092824, 0.0786356, - 0.05052217, -0.058925, -0.011402121, -0.024987547, - -0.0013661642, -0.06832946, -0.015667673, -0.1083353, - -0.00096863037, -0.06988685, -0.053350925, -0.027275559, - -0.033664223, -0.07978348, -0.025200296, -0.017207067, - -0.058403496, -0.055697463, 0.005798788, 0.12965427, - -0.062582195, 0.0013350133, -0.10482091, 0.0379771, - 0.072521195, -0.0029455067, -0.13797039, -0.03628521, - 0.013806405, -0.017858358, -0.01008298, -0.07700066, - -0.017081132, 0.019358726, 0.0027079724, 0.004635139, - 0.062634714, -0.02338735, -0.039547626, -0.02050681, - 0.03385117, -0.083611414, 0.002862572, -0.09421313, - 0.058618143, -0.08598433, 0.00972939, 0.023867095, - -0.053934585, -0.023203006, 0.07452513, -0.048767887, - -0.07314807, -0.056307215, -0.10433547, -0.06440842, - 0.04328182, 0.04389765, -0.020006588, -0.09076438, - -0.11652589, -0.021705797, 0.03345259, -0.010329105, - -0.025767034, 0.013057034, -0.07316461, -0.10145612, - 0.06358255, 0.18531723, 0.07759293, 0.12006465, - 0.1305557, 0.058638252, -0.03393652, 0.09622831, - -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, - -0.005644518, 0.06857898, -0.12598175, -0.035084512, - 0.03156317, -0.12794146, -0.031963028, 0.04692781, - 0.030070418, 0.0071660685, -0.095516115, -0.004643372, - 0.040170413, -0.062104587, -0.0037324072, 0.0554317, - 0.08184801, -0.019164372, 0.06791302, 0.034257166, - -0.10307039, 0.021943003, 0.046745934, 0.0790918, - -0.0265588, -0.007824208, 0.042546265, -0.00977924, - -0.0002440307, -0.017384544, -0.017990116, 0.12252321, - -0.014512694, -0.08251313, 0.08861942, 0.13589665, - 0.026351685, 0.012641483, 0.07466548, 0.044301085, - -0.045414884, -0.051112458, 0.03444247, -0.08502782, - -0.04106223, -0.028126027, 0.028473156, 0.10467447}); - - lstm.SetRecurrentToForgetWeights( - {-0.057784554, -0.026057621, -0.068447545, -0.022581743, - 0.14811787, 0.10826372, 0.09471067, 0.03987225, - -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, - 0.08414449, -0.022036452, -0.00066928595, -0.09203576, - 0.032950465, -0.10985798, -0.023809856, 0.0021431844, - -0.02196096, -0.00326074, 0.00058621005, -0.074678116, - -0.06193199, 0.055729095, 0.03736828, 0.020123724, - 0.061878487, -0.04729229, 0.034919553, -0.07585433, - -0.04421272, -0.044019096, 0.085488975, 0.04058006, - -0.06890133, -0.030951202, -0.024628663, -0.07672815, - 0.034293607, 0.08556707, -0.05293577, -0.033561368, - -0.04899627, 0.0241671, 0.015736353, -0.095442444, - -0.029564252, 0.016493602, -0.035026584, 0.022337519, - -0.026871363, 0.004780428, 0.0077918363, -0.03601621, - 0.016435321, -0.03263031, -0.09543275, -0.047392778, - 0.013454138, 0.028934088, 0.01685226, -0.086110644, - -0.046250615, -0.01847454, 0.047608484, 0.07339695, - 0.034546845, -0.04881143, 0.009128804, -0.08802852, - 0.03761666, 0.008096139, -0.014454086, 0.014361001, - -0.023502491, -0.0011840804, -0.07607001, 0.001856849, - -0.06509276, -0.006021153, -0.08570962, -0.1451793, - 0.060212336, 0.055259194, 0.06974018, 0.049454916, - -0.027794661, -0.08077226, -0.016179763, 0.1169753, - 0.17213494, -0.0056326236, -0.053934924, -0.0124349, - -0.11520337, 0.05409887, 0.088759385, 0.0019655675, - 0.0042065294, 0.03881498, 0.019844765, 0.041858196, - -0.05695512, 0.047233116, 0.038937137, -0.06542224, - 0.014429736, -0.09719407, 0.13908425, -0.05379757, - 0.012321099, 0.082840554, -0.029899208, 0.044217527, - 0.059855383, 0.07711018, -0.045319796, 0.0948846, - -0.011724666, -0.0033288454, -0.033542685, -0.04764985, - -0.13873616, 0.040668588, 0.034832682, -0.015319203, - -0.018715994, 0.046002675, 0.0599172, -0.043107376, - 0.0294216, -0.002314414, -0.022424703, 0.0030315618, - 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, - 0.12375372, -0.0006038222, 0.029104086, 0.087442465, - 0.052958444, 0.07558703, 0.04817258, 0.044462286, - -0.015213451, -0.08783778, -0.0561384, -0.003008196, - 0.047060397, -0.002058388, 0.03429439, -0.018839769, - 0.024734668, 0.024614193, -0.042046934, 0.09597743, - -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, - -0.02558259, -0.022822596, -0.023273505, -0.02464396, - -0.10991725, -0.006240552, 0.0074488563, 0.024044557, - 0.04383914, -0.046476185, 0.028658995, 0.060410924, - 0.050786525, 0.009452605, -0.0073054377, -0.024810238, - 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, - 0.015898481, 0.021362653, -0.030262267, 0.016587038, - -0.011442813, 0.041154444, -0.007631438, -0.03423484, - -0.010977775, 0.036152758, 0.0066366293, 0.11915515, - 0.02318443, -0.041350313, 0.021485701, -0.10906167, - -0.028218046, -0.00954771, 0.020531068, -0.11995105, - -0.03672871, 0.024019798, 0.014255957, -0.05221243, - -0.00661567, -0.04630967, 0.033188973, 0.10107534, - -0.014027541, 0.030796422, -0.10270911, -0.035999842, - 0.15443139, 0.07684145, 0.036571592, -0.035900835, - -0.0034699554, 0.06209149, 0.015920248, -0.031122351, - -0.03858649, 0.01849943, 0.13872518, 0.01503974, - 0.069941424, -0.06948533, -0.0088794185, 0.061282158, - -0.047401894, 0.03100163, -0.041533746, -0.10430945, - 0.044574402, -0.01425562, -0.024290353, 0.034563623, - 0.05866852, 0.023947537, -0.09445152, 0.035450947, - 0.02247216, -0.0042998926, 0.061146557, -0.10250651, - 0.020881841, -0.06747029, 0.10062043, -0.0023941975, - 0.03532124, -0.016341697, 0.09685456, -0.016764693, - 0.051808182, 0.05875331, -0.04536488, 0.001626336, - -0.028892258, -0.01048663, -0.009793449, -0.017093895, - 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, - -0.001845119, -0.03551521, 0.0018358806, 0.05763657, - -0.01769146, 0.040995963, 0.02235177, -0.060430344, - 0.11475477, -0.023854522, 0.10071741, 0.0686208, - -0.014250481, 0.034261297, 0.047418304, 0.08562733, - -0.030519066, 0.0060542435, 0.014653856, -0.038836084, - 0.04096551, 0.032249358, -0.08355519, -0.026823482, - 0.056386515, -0.010401743, -0.028396193, 0.08507674, - 0.014410365, 0.020995233, 0.17040324, 0.11511526, - 0.02459721, 0.0066619175, 0.025853224, -0.023133837, - -0.081302024, 0.017264642, -0.009585969, 0.09491168, - -0.051313367, 0.054532815, -0.014298593, 0.10657464, - 0.007076659, 0.10964551, 0.0409152, 0.008275321, - -0.07283536, 0.07937492, 0.04192024, -0.1075027}); - - lstm.SetRecurrentToCellWeights( - {-0.037322544, 0.018592842, 0.0056175636, -0.06253426, - 0.055647098, -0.05713207, -0.05626563, 0.005559383, - 0.03375411, -0.025757805, -0.088049285, 0.06017052, - -0.06570978, 0.007384076, 0.035123326, -0.07920549, - 0.053676967, 0.044480428, -0.07663568, 0.0071805613, - 0.08089997, 0.05143358, 0.038261272, 0.03339287, - -0.027673481, 0.044746667, 0.028349208, 0.020090483, - -0.019443132, -0.030755889, -0.0040000007, 0.04465846, - -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, - -0.10893326, 0.076739706, -0.08509834, -0.027997585, - 0.037871376, 0.01449768, -0.09002357, -0.06111149, - -0.046195522, 0.0422062, -0.005683705, -0.1253618, - -0.012925729, -0.04890792, 0.06985068, 0.037654128, - 0.03398274, -0.004781977, 0.007032333, -0.031787455, - 0.010868644, -0.031489216, 0.09525667, 0.013939797, - 0.0058680447, 0.0167067, 0.02668468, -0.04797466, - -0.048885044, -0.12722108, 0.035304096, 0.06554885, - 0.00972396, -0.039238118, -0.05159735, -0.11329045, - 0.1613692, -0.03750952, 0.06529313, -0.071974665, - -0.11769596, 0.015524369, -0.0013754242, -0.12446318, - 0.02786344, -0.014179351, 0.005264273, 0.14376344, - 0.015983658, 0.03406988, -0.06939408, 0.040699873, - 0.02111075, 0.09669095, 0.041345075, -0.08316494, - -0.07684199, -0.045768797, 0.032298047, -0.041805092, - 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, - -0.024950314, 0.11574242, 0.04508852, -0.04335324, - 0.06760663, -0.027437469, 0.07216407, 0.06977076, - -0.05438599, 0.034033038, -0.028602652, 0.05346137, - 0.043184172, -0.037189785, 0.10420091, 0.00882477, - -0.054019816, -0.074273005, -0.030617684, -0.0028467078, - 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, - 0.04361412, -0.007001822, 0.09631092, -0.06702025, - -0.042049985, -0.035070654, -0.04103342, -0.10273396, - 0.0544271, 0.037184782, -0.13150354, -0.0058036847, - -0.008264958, 0.042035464, 0.05891794, 0.029673764, - 0.0063542654, 0.044788733, 0.054816857, 0.062257513, - -0.00093483756, 0.048938446, -0.004952862, -0.007730018, - -0.04043371, -0.017094059, 0.07229206, -0.023670016, - -0.052195564, -0.025616996, -0.01520939, 0.045104615, - -0.007376126, 0.003533447, 0.006570588, 0.056037236, - 0.12436656, 0.051817212, 0.028532185, -0.08686856, - 0.11868599, 0.07663395, -0.07323171, 0.03463402, - -0.050708205, -0.04458982, -0.11590894, 0.021273347, - 0.1251325, -0.15313013, -0.12224372, 0.17228661, - 0.023029093, 0.086124025, 0.006445803, -0.03496501, - 0.028332196, 0.04449512, -0.042436164, -0.026587414, - -0.006041347, -0.09292539, -0.05678812, 0.03897832, - 0.09465633, 0.008115513, -0.02171956, 0.08304309, - 0.071401566, 0.019622514, 0.032163795, -0.004167056, - 0.02295182, 0.030739572, 0.056506045, 0.004612461, - 0.06524936, 0.059999723, 0.046395954, -0.0045512207, - -0.1335546, -0.030136576, 0.11584653, -0.014678886, - 0.0020118146, -0.09688814, -0.0790206, 0.039770417, - -0.0329582, 0.07922767, 0.029322514, 0.026405897, - 0.04207835, -0.07073373, 0.063781224, 0.0859677, - -0.10925287, -0.07011058, 0.048005477, 0.03438226, - -0.09606514, -0.006669445, -0.043381985, 0.04240257, - -0.06955775, -0.06769346, 0.043903265, -0.026784198, - -0.017840602, 0.024307009, -0.040079936, -0.019946516, - 0.045318738, -0.12233574, 0.026170589, 0.0074471775, - 0.15978073, 0.10185836, 0.10298046, -0.015476589, - -0.039390966, -0.072174534, 0.0739445, -0.1211869, - -0.0347889, -0.07943156, 0.014809798, -0.12412325, - -0.0030663363, 0.039695457, 0.0647603, -0.08291318, - -0.018529687, -0.004423833, 0.0037507233, 0.084633216, - -0.01514876, -0.056505352, -0.012800942, -0.06994386, - 0.012962922, -0.031234352, 0.07029052, 0.016418684, - 0.03618972, 0.055686004, -0.08663945, -0.017404709, - -0.054761406, 0.029065743, 0.052404847, 0.020238016, - 0.0048197987, -0.0214882, 0.07078733, 0.013016777, - 0.06262858, 0.009184685, 0.020785125, -0.043904778, - -0.0270329, -0.03299152, -0.060088247, -0.015162964, - -0.001828936, 0.12642565, -0.056757294, 0.013586685, - 0.09232601, -0.035886683, 0.06000002, 0.05229691, - -0.052580316, -0.082029596, -0.010794592, 0.012947712, - -0.036429964, -0.085508935, -0.13127148, -0.017744139, - 0.031502828, 0.036232427, -0.031581745, 0.023051167, - -0.05325106, -0.03421577, 0.028793324, -0.034633752, - -0.009881397, -0.043551125, -0.018609839, 0.0019097115, - -0.008799762, 0.056595087, 0.0022273948, 0.055752404}); - - lstm.SetRecurrentToOutputWeights({ - 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415, - -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349, - -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948, - -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774, - -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125, - -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224, - -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088, - 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867, - -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728, - 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607, - -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928, - -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462, - 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879, - 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698, - -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146, - 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345, - 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166, - 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203, - 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743, - 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415, - -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618, - 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891, - -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015, - 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109, - 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886, - 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396, - -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282, - -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025, - -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575, - -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277, - -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719, - -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215, - 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483, - 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102, - -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775, - 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841, - -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656, - -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286, - -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309, - 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545, - 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754, - 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831, - -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697, - 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453, - -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222, - -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989, - -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827, - -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949, - 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819, - -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954, - 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228, - -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001, - -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939, - -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556, - -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718, - 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893, - 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974, - -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485, - 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856, - 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853, - -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019, - 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024, - 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994, - 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621, - }); - - lstm.SetCellToInputWeights( - {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, - -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, - -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, - 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}); - - lstm.SetCellToForgetWeights( - {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, - -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, - -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, - 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}); - - lstm.SetCellToOutputWeights( - {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, - -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, - -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, - 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}); - - lstm.SetProjectionWeights( - {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832, - 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683, - -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931, - -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476, - 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067, - 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787, - 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588, - 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285, - -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949, - -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768, - -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929, - 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504, - 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946, - 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117, - 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253, - 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456, - -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552, - 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797, - -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272, - 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165, - -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922, - -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548, - 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786, - -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722, - 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318, - -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776, - -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307, - 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969, - -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593, - -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515, - -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288, - 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723, - 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097, - -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209, - 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268, - 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139, - 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707, - 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871, - 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553, - -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702, - -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615, - 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187, - -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388, - -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709, - 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263, - 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777, - 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935, - -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641, - -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996, - -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318, - 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437, - -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079, - 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237, - 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415, - -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124, - -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943, - -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311, - 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013, - -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364, - -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543, - -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102, - 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906, - 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955, - 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656}); - - static float lstm_input[][20] = { - {// Batch0: 4 (input_sequence_size) * 5 (n_input) - 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386, - 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199, - 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, - - {// Batch1: 4 (input_sequence_size) * 5 (n_input) - 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260, - 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485, - 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}}; - - static float lstm_golden_output[][64] = { - {// Batch0: 4 (input_sequence_size) * 16 (n_output) - -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, - -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, - -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, - 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, - -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, - -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, - 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, - 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, - 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, - 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, - -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, - -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, - 0.0286833, 0.00824207, 0.0264887, 0.0305169}, - {// Batch1: 4 (input_sequence_size) * 16 (n_output) - -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, - -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, - 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, - 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, - -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, - -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, - 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, - 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, - 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, - 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, - -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, - -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, - 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToInputWeights(cell_to_input_weights_); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + lstm.SetProjectionWeights(projection_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - for (int i = 0; i < lstm.sequence_length(); i++) { - float* batch0_start = lstm_input[0] + i * lstm.num_inputs(); - float* batch0_end = batch0_start + lstm.num_inputs(); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} - lstm.SetInput(2 * i * lstm.num_inputs(), batch0_start, batch0_end); +TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 20; + const int n_output = 16; + const int sequence_length = 4; - float* batch1_start = lstm_input[1] + i * lstm.num_inputs(); - float* batch1_end = batch1_start + lstm.num_inputs(); - lstm.SetInput((2 * i + 1) * lstm.num_inputs(), batch1_start, batch1_end); - } + HybridUnidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor - lstm.Invoke(); + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor - std::vector expected; - for (int i = 0; i < lstm.sequence_length(); i++) { - float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs(); - float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs(); - float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs(); - float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs(); - expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); - expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); - } - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToInputWeights(cell_to_input_weights_); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + lstm.SetProjectionWeights(projection_weights_); + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467); } } // namespace diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc index 164a0cbd08d6ce82a413f12ba6b1703087a30aba..0d6d29a171735a00a8dcc6cd0213a859b9f8094a 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include diff --git a/tensorflow/contrib/lite/mmap_allocation.cc b/tensorflow/contrib/lite/mmap_allocation.cc new file mode 100644 index 0000000000000000000000000000000000000000..fa9a3cd1d839b07149bb80c3b7714b32b5eda235 --- /dev/null +++ b/tensorflow/contrib/lite/mmap_allocation.cc @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/error_reporter.h" + +namespace tflite { + +MMAPAllocation::MMAPAllocation(const char* filename, + ErrorReporter* error_reporter) + : Allocation(error_reporter), mmapped_buffer_(MAP_FAILED) { + mmap_fd_ = open(filename, O_RDONLY); + if (mmap_fd_ == -1) { + error_reporter_->Report("Could not open '%s'.", filename); + return; + } + struct stat sb; + fstat(mmap_fd_, &sb); + buffer_size_bytes_ = sb.st_size; + mmapped_buffer_ = + mmap(nullptr, buffer_size_bytes_, PROT_READ, MAP_SHARED, mmap_fd_, 0); + if (mmapped_buffer_ == MAP_FAILED) { + error_reporter_->Report("Mmap of '%s' failed.", filename); + return; + } +} + +MMAPAllocation::~MMAPAllocation() { + if (valid()) { + munmap(const_cast(mmapped_buffer_), buffer_size_bytes_); + } + if (mmap_fd_ != -1) close(mmap_fd_); +} + +const void* MMAPAllocation::base() const { return mmapped_buffer_; } + +size_t MMAPAllocation::bytes() const { return buffer_size_bytes_; } + +bool MMAPAllocation::valid() const { return mmapped_buffer_ != MAP_FAILED; } + +bool MMAPAllocation::IsSupported() { return true; } + +} // namespace tflite diff --git a/tensorflow/contrib/lite/mmap_allocation_disabled.cc b/tensorflow/contrib/lite/mmap_allocation_disabled.cc new file mode 100644 index 0000000000000000000000000000000000000000..f3d4cf1a257d43ebd56cc9b8831de0bb1994d40c --- /dev/null +++ b/tensorflow/contrib/lite/mmap_allocation_disabled.cc @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/allocation.h" + +#include + +namespace tflite { + +MMAPAllocation::MMAPAllocation(const char* filename, + ErrorReporter* error_reporter) + : Allocation(error_reporter), mmapped_buffer_(nullptr) { + // The disabled variant should never be created. + assert(false); +} + +MMAPAllocation::~MMAPAllocation() {} + +const void* MMAPAllocation::base() const { return nullptr; } + +size_t MMAPAllocation::bytes() const { return 0; } + +bool MMAPAllocation::valid() const { return false; } + +bool MMAPAllocation::IsSupported() { return false; } + +} // namespace tflite diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index f54db3af878b811a77f516417d6306323de4550b..9edf5ba38f4c6506524074bc0a3ebe7e068c7ee3 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -16,16 +16,16 @@ limitations under the License. #include #include #include -#include #include #include -#include #include "tensorflow/contrib/lite/allocation.h" #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/model.h" +#ifndef TFLITE_MCU #include "tensorflow/contrib/lite/nnapi_delegate.h" +#endif #include "tensorflow/contrib/lite/version.h" namespace tflite { @@ -74,6 +74,7 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, return kTfLiteOk; } +#ifndef TFLITE_MCU // Loads a model from `filename`. If `mmap_file` is true then use mmap, // otherwise make a copy of the model in a buffer. std::unique_ptr GetAllocationFromFile(const char* filename, @@ -81,8 +82,8 @@ std::unique_ptr GetAllocationFromFile(const char* filename, ErrorReporter* error_reporter, bool use_nnapi) { std::unique_ptr allocation; - if (mmap_file) { - if (use_nnapi && NNAPIExists()) + if (mmap_file && MMAPAllocation::IsSupported()) { + if (use_nnapi && NNAPIDelegate::IsSupported()) allocation.reset(new NNAPIAllocation(filename, error_reporter)); else allocation.reset(new MMAPAllocation(filename, error_reporter)); @@ -121,6 +122,7 @@ std::unique_ptr FlatBufferModel::VerifyAndBuildFromFile( if (!model->initialized()) model.reset(); return model; } +#endif std::unique_ptr FlatBufferModel::BuildFromBuffer( const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) { @@ -186,6 +188,8 @@ InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model, op_resolver_(op_resolver), error_reporter_(ValidateErrorReporter(error_reporter)) {} +InterpreterBuilder::~InterpreterBuilder() {} + TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { TfLiteStatus status = kTfLiteOk; auto opcodes = model_->operator_codes(); @@ -204,8 +208,9 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { } else if (builtin_code != BuiltinOperator_CUSTOM) { registration = op_resolver_.FindOp(builtin_code, version); if (registration == nullptr) { - error_reporter_->Report("Didn't find op for builtin opcode '%s'\n", - EnumNameBuiltinOperator(builtin_code)); + error_reporter_->Report( + "Didn't find op for builtin opcode '%s' version '%d'\n", + EnumNameBuiltinOperator(builtin_code), version); status = kTfLiteError; } } else if (!opcode->custom_code()) { @@ -613,6 +618,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_MEAN: + case BuiltinOperator_REDUCE_MAX: + case BuiltinOperator_REDUCE_PROD: case BuiltinOperator_SUM: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_ReducerOptions()) { @@ -661,6 +668,15 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_ARG_MIN: { + auto* params = MallocPOD(); + if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) { + ConvertTensorType(schema_params->output_type(), ¶ms->output_type, + error_reporter); + } + *builtin_data = reinterpret_cast(params); + break; + } case BuiltinOperator_TRANSPOSE_CONV: { TfLiteTransposeConvParams* params = MallocPOD(); @@ -692,11 +708,39 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = static_cast(params); break; } + case BuiltinOperator_PACK: { + TfLitePackParams* params = MallocPOD(); + if (auto* pack_params = op->builtin_options_as_PackOptions()) { + params->values_count = pack_params->values_count(); + params->axis = pack_params->axis(); + } + *builtin_data = reinterpret_cast(params); + break; + } case BuiltinOperator_DELEGATE: { // TODO(ycling): Revisit when supporting saving delegated models. error_reporter->Report("DELEGATE op shouldn't exist in model."); return kTfLiteError; } + case BuiltinOperator_FAKE_QUANT: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) { + params->min = schema_params->min(); + params->max = schema_params->max(); + params->num_bits = schema_params->num_bits(); + params->narrow_range = schema_params->narrow_range(); + } + *builtin_data = static_cast(params); + break; + } + case BuiltinOperator_ONE_HOT: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_OneHotOptions()) { + params->axis = schema_params->axis(); + } + *builtin_data = static_cast(params); + break; + } // Below are the ops with no builtin_data strcture. case BuiltinOperator_BATCH_TO_SPACE_ND: @@ -739,6 +783,9 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_TOPK_V2: case BuiltinOperator_TRANSPOSE: case BuiltinOperator_POW: + case BuiltinOperator_LOGICAL_OR: + case BuiltinOperator_LOGICAL_AND: + case BuiltinOperator_LOGICAL_NOT: break; } return kTfLiteOk; @@ -991,7 +1038,7 @@ TfLiteStatus InterpreterBuilder::operator()( variables.push_back(i); } } - (**interpreter).SetVariables(variables); + (**interpreter).SetVariables(std::move(variables)); return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h index 3946b490417104f620ecb55bb22d4ef99fd33bb7..8bc9ecd7ce9725c3d88985ccd92d48aed169fe31 100644 --- a/tensorflow/contrib/lite/model.h +++ b/tensorflow/contrib/lite/model.h @@ -156,6 +156,7 @@ class InterpreterBuilder { InterpreterBuilder(const ::tflite::Model* model, const OpResolver& op_resolver, ErrorReporter* error_reporter = DefaultErrorReporter()); + ~InterpreterBuilder(); InterpreterBuilder(const InterpreterBuilder&) = delete; InterpreterBuilder& operator=(const InterpreterBuilder&) = delete; TfLiteStatus operator()(std::unique_ptr* interpreter); diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc index 15bae21a411c1241cf71ab4d3f0e0289eaac8ef3..df4f60d4ad4eb71f48eb3ad364f95f93b84f3d75 100644 --- a/tensorflow/contrib/lite/model_test.cc +++ b/tensorflow/contrib/lite/model_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include "tensorflow/contrib/lite/model.h" @@ -242,14 +241,6 @@ TEST(BasicFlatBufferModel, TestWithNullVerifier) { "tensorflow/contrib/lite/testdata/test_model.bin", nullptr)); } -struct TestErrorReporter : public ErrorReporter { - int Report(const char* format, va_list args) override { - calls++; - return 0; - } - int calls = 0; -}; - // This makes sure the ErrorReporter is marshalled from FlatBufferModel to // the Interpreter. TEST(BasicFlatBufferModel, TestCustomErrorReporter) { @@ -263,7 +254,7 @@ TEST(BasicFlatBufferModel, TestCustomErrorReporter) { TrivialResolver resolver; InterpreterBuilder(*model, resolver)(&interpreter); ASSERT_NE(interpreter->Invoke(), kTfLiteOk); - ASSERT_EQ(reporter.calls, 1); + ASSERT_EQ(reporter.num_calls(), 1); } // This makes sure the ErrorReporter is marshalled from FlatBufferModel to diff --git a/tensorflow/contrib/lite/models/smartreply/BUILD b/tensorflow/contrib/lite/models/smartreply/BUILD index 8b5fa240ac31d9ee61879c42aee3c5d449ae60db..9d88c396ba69948e3ae285c913a4499a1409b93a 100644 --- a/tensorflow/contrib/lite/models/smartreply/BUILD +++ b/tensorflow/contrib/lite/models/smartreply/BUILD @@ -47,6 +47,7 @@ cc_test( name = "extract_feature_op_test", size = "small", srcs = ["ops/extract_feature_test.cc"], + tags = ["no_oss"], deps = [ ":custom_ops", "//tensorflow/contrib/lite:framework", @@ -61,6 +62,7 @@ cc_test( name = "normalize_op_test", size = "small", srcs = ["ops/normalize_test.cc"], + tags = ["no_oss"], deps = [ ":custom_ops", "//tensorflow/contrib/lite:framework", @@ -75,6 +77,7 @@ cc_test( name = "predict_op_test", size = "small", srcs = ["ops/predict_test.cc"], + tags = ["no_oss"], deps = [ ":custom_ops", "//tensorflow/contrib/lite:framework", diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.h b/tensorflow/contrib/lite/models/smartreply/predictor.h index 90260c8d620b0e756f72089d3f4d8d9f92d44fbe..3151192d9277b6df513a76afb08af30d0379b7b1 100644 --- a/tensorflow/contrib/lite/models/smartreply/predictor.h +++ b/tensorflow/contrib/lite/models/smartreply/predictor.h @@ -65,9 +65,9 @@ struct SmartReplyConfig { float backoff_confidence; // Backoff responses are used when predicted responses cannot fulfill the // list. - const std::vector& backoff_responses; + std::vector backoff_responses; - SmartReplyConfig(std::vector backoff_responses) + SmartReplyConfig(const std::vector& backoff_responses) : num_response(kDefaultNumResponse), backoff_confidence(kDefaultBackoffConfidence), backoff_responses(backoff_responses) {} diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index 7627d89c091d08390021bb47c640749956d8796d..13325a8c7c62142fcbc8ca37af1216e8943b905b 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -29,27 +29,46 @@ limitations under the License. namespace tflite { -// TODO(aselle): FATAL leaves resources hanging. -void FATAL(const char* format, ...) { +void logError(const char* format, ...) { + // TODO(mikie): use android logging, stderr is not captured for Java + // applications va_list args; va_start(args, format); vfprintf(stderr, format, args); va_end(args); + fprintf(stderr, "\n"); fflush(stderr); - exit(1); } +#define FATAL(...) \ + logError(__VA_ARGS__); \ + exit(1); + // TODO(aselle): Change the error model to use status codes. -#define CHECK_TFLITE_SUCCESS(x) \ - if (x != kTfLiteOk) { \ - FATAL("Aborting since tflite returned failure."); \ +#define CHECK_TFLITE_SUCCESS(x) \ + if (x != kTfLiteOk) { \ + FATAL("Aborting since tflite returned failure nnapi_delegate.cc:%d.", \ + __LINE__); \ } -#define CHECK_NN(x) \ - if (x != ANEURALNETWORKS_NO_ERROR) { \ - FATAL("Aborting since tflite returned failure."); \ +#define CHECK_NN(x) \ + if (x != ANEURALNETWORKS_NO_ERROR) { \ + FATAL("Aborting since NNAPI returned failure nnapi_delegate.cc:%d", \ + __LINE__); \ } +#define RETURN_ERROR_IF_NN_FAILED(x) \ + if (x != ANEURALNETWORKS_NO_ERROR) { \ + logError( \ + "Returning error since NNAPI returned failure nnapi_delegate.cc:%d.", \ + __LINE__); \ + return kTfLiteError; \ + } + +// Tracking of NNAPI operand ids +static const int64_t kOperandIdNotSet = -1; +static const int64_t kOperandNotNeeded = -2; + namespace { int32_t GetAndroidSdkVersion() { @@ -104,21 +123,16 @@ NNAPIDelegate::~NNAPIDelegate() { } // Adds the tensors of the interpreter to the NN API model. -// Returns the number of operands added. -uint32_t addTensorOperands(tflite::Interpreter* interpreter, - ANeuralNetworksModel* nn_model, - const std::vector& skip_list) { +TfLiteStatus addTensorOperands(tflite::Interpreter* interpreter, + ANeuralNetworksModel* nn_model, + uint32_t* no_of_operands_added, + std::vector* nnapi_ids) { uint32_t next_id = 0; for (size_t i = 0; i < interpreter->tensors_size(); i++) { - // skip temporaries tensors. - bool shouldSkip = false; - for (auto skip_idx : skip_list) { - if (i == skip_idx) { - shouldSkip = true; - break; - } - } - if (shouldSkip) continue; + // Skip temporaries and RNN back-edges. + if ((*nnapi_ids)[i] == kOperandNotNeeded) continue; + + (*nnapi_ids)[i] = int64_t(next_id); int32_t nn_type = 0; // NNAPI requires 32-bit float scale to be zero, tflite doesn't care @@ -144,7 +158,18 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter, zeroPoint = tensor->params.zero_point; break; default: - FATAL("Unsupported type."); + logError("Unsupported tensor type %d", tensor->type); + return kTfLiteError; + } + if (tensor->dims->size == 0) { + logError("NNAPI doesn't support tensors with rank 0 (index %d name %s)", + i, tensor->name); + return kTfLiteError; + } + if (tensor->dims->size > 4) { + logError("NNAPI doesn't support tensors with rank > 4 (index %d name %s)", + i, tensor->name); + return kTfLiteError; } // TODO(aselle): Note, many of these are intermediate results. Do I need // to ever specify these sizes. I am currently below doing setValue @@ -154,36 +179,53 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter, ANeuralNetworksOperandType operand_type{ nn_type, static_cast(tensor->dims->size), reinterpret_cast(tensor->dims->data), scale, zeroPoint}; - CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)); + RETURN_ERROR_IF_NN_FAILED( + ANeuralNetworksModel_addOperand(nn_model, &operand_type)); // TODO(aselle): Based on Michael's suggestion, limiting this to read // only memory if (tensor->allocation_type == kTfLiteMmapRo) { if (const NNAPIAllocation* alloc = dynamic_cast( static_cast(tensor->allocation))) { - CHECK_NN(ANeuralNetworksModel_setOperandValueFromMemory( - nn_model, next_id, alloc->memory(), alloc->offset(tensor->data.raw), - tensor->bytes)); + RETURN_ERROR_IF_NN_FAILED( + ANeuralNetworksModel_setOperandValueFromMemory( + nn_model, next_id, alloc->memory(), + alloc->offset(tensor->data.raw), tensor->bytes)); } else { - CHECK_NN(ANeuralNetworksModel_setOperandValue( + RETURN_ERROR_IF_NN_FAILED(ANeuralNetworksModel_setOperandValue( nn_model, next_id, tensor->data.raw, tensor->bytes)); } } else if (tensor->bytes == 0) { // These size 0 tensors are optional tensors reserved. - CHECK_NN( + RETURN_ERROR_IF_NN_FAILED( ANeuralNetworksModel_setOperandValue(nn_model, next_id, nullptr, 0)); } ++next_id; } - return next_id; + *no_of_operands_added = next_id; + return kTfLiteOk; +} + +void MapAndAddTensorIds(const int* from_ids_buf, size_t from_ids_count, + std::vector* into, + const std::vector& map) { + for (size_t i = 0; i < from_ids_count; i++) { + int from_id = from_ids_buf[i]; + if (from_id == kOptionalTensor) { + into->push_back(from_id); + } else { + into->push_back(map[from_id]); + } + } } // Adds the operations and their parameters to the NN API model. // 'next-id' is the operand ID of the next operand of the model. -void AddOpsAndParams(tflite::Interpreter* interpreter, - ANeuralNetworksModel* nn_model, uint32_t next_id, - std::vector* model_state_inputs, - std::vector* model_state_outputs) { +TfLiteStatus AddOpsAndParams( + tflite::Interpreter* interpreter, ANeuralNetworksModel* nn_model, + uint32_t next_id, std::vector* model_state_inputs, + std::vector* model_state_outputs, + const std::vector& tensor_id_to_nnapi_id) { for (size_t i = 0; i < interpreter->nodes_size(); i++) { const auto* node_and_registration = interpreter->node_and_registration(i); const TfLiteNode& node = node_and_registration->first; @@ -192,10 +234,11 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, static_cast(registration.builtin_code); // Add the parameters. - std::vector augmented_inputs( - node.inputs->data, node.inputs->data + node.inputs->size); - std::vector augmented_outputs( - node.outputs->data, node.outputs->data + node.outputs->size); + std::vector augmented_inputs, augmented_outputs; + MapAndAddTensorIds(node.inputs->data, node.inputs->size, &augmented_inputs, + tensor_id_to_nnapi_id); + MapAndAddTensorIds(node.outputs->data, node.outputs->size, + &augmented_outputs, tensor_id_to_nnapi_id); auto add_scalar_int32 = [&nn_model, &augmented_inputs, &next_id](int value) { @@ -244,42 +287,54 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, model_state_outputs->push_back(tensor_id); next_id++; }; + auto check_and_add_activation = [&add_scalar_int32](int activation) { + if (activation > kTfLiteActRelu6) { + FATAL("NNAPI only supports RELU, RELU1 and RELU6 activations"); + } + add_scalar_int32(activation); + }; auto add_add_params = [&add_scalar_int32](void* data) { auto* builtin = reinterpret_cast(data); + if (builtin->activation > kTfLiteActRelu6) { + FATAL("NNAPI only supports RELU, RELU1 and RELU6 activations"); + } add_scalar_int32(builtin->activation); }; - auto add_pooling_params = [&add_scalar_int32](void* data) { + auto add_pooling_params = [&add_scalar_int32, + &check_and_add_activation](void* data) { auto builtin = reinterpret_cast(data); add_scalar_int32(builtin->padding); add_scalar_int32(builtin->stride_width); add_scalar_int32(builtin->stride_height); add_scalar_int32(builtin->filter_width); add_scalar_int32(builtin->filter_height); - add_scalar_int32(builtin->activation); + check_and_add_activation(builtin->activation); }; - auto add_convolution_params = [&add_scalar_int32](void* data) { + auto add_convolution_params = [&add_scalar_int32, + &check_and_add_activation](void* data) { auto builtin = reinterpret_cast(data); add_scalar_int32(builtin->padding); add_scalar_int32(builtin->stride_width); add_scalar_int32(builtin->stride_height); - add_scalar_int32(builtin->activation); + check_and_add_activation(builtin->activation); }; - auto add_depthwise_conv_params = [&add_scalar_int32](void* data) { + auto add_depthwise_conv_params = [&add_scalar_int32, + &check_and_add_activation](void* data) { auto builtin = reinterpret_cast(data); add_scalar_int32(builtin->padding); add_scalar_int32(builtin->stride_width); add_scalar_int32(builtin->stride_height); add_scalar_int32(builtin->depth_multiplier); - add_scalar_int32(builtin->activation); + check_and_add_activation(builtin->activation); }; - auto add_fully_connected_params = [&add_scalar_int32](void* data) { + auto add_fully_connected_params = [&check_and_add_activation](void* data) { auto builtin = reinterpret_cast(data); - add_scalar_int32(builtin->activation); + check_and_add_activation(builtin->activation); }; auto add_concatenation_params = [&add_scalar_int32](void* data) { @@ -311,6 +366,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, // LSTM in NNAPI requires scratch tensor as an output operand. auto add_lstm_scratch_tensor_float32 = [interpreter, &node, &nn_model, &next_id, &augmented_outputs]() { + if (node.temporaries->size == 0) return; int scratch_buffer_index = node.temporaries->data[0]; const TfLiteTensor* tensor = interpreter->tensor(scratch_buffer_index); ANeuralNetworksOperandType operand_type{ @@ -385,7 +441,14 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, add_pooling_params(node.builtin_data); nn_op_type = ANEURALNETWORKS_L2_POOL_2D; break; - case tflite::BuiltinOperator_CONV_2D: + case tflite::BuiltinOperator_CONV_2D: { + auto builtin = reinterpret_cast(node.builtin_data); + if (builtin->dilation_width_factor != 1 || + builtin->dilation_height_factor != 1 || node.inputs->size != 3) { + logError("NNAPI does not support dilated Conv2D."); + return kTfLiteError; + } + } add_convolution_params(node.builtin_data); nn_op_type = ANEURALNETWORKS_CONV_2D; break; @@ -429,6 +492,10 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, nn_op_type = ANEURALNETWORKS_SPACE_TO_DEPTH; break; case tflite::BuiltinOperator_LSTM: { + if (node.inputs->size + /* no of params */ 3 != 21) { + logError("NNAPI only supports 21-input LSTMs"); + return kTfLiteError; + } duplicate_state_tensor_float32( node.outputs->data[/*kOutputStateTensor*/ 0]); duplicate_state_tensor_float32( @@ -467,25 +534,56 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_DIV: nnapi_version = 11; // require NNAPI 1.1 nn_op_type = ANEURALNETWORKS_DIV; + check_and_add_activation( + reinterpret_cast(node.builtin_data)->activation); break; case tflite::BuiltinOperator_SUB: nnapi_version = 11; // require NNAPI 1.1 nn_op_type = ANEURALNETWORKS_SUB; + check_and_add_activation( + reinterpret_cast(node.builtin_data)->activation); break; case tflite::BuiltinOperator_SQUEEZE: nnapi_version = 11; // requires NNAPI 1.1 add_squeeze_params(node.builtin_data); nn_op_type = ANEURALNETWORKS_SQUEEZE; break; + case tflite::BuiltinOperator_TRANSPOSE: + // The permutation input tensor value dictates the output dimensions. + // TODO(b/110888333): Support dynamically-sized tensors in delegates. + if ((node.inputs->size > 1) && + (interpreter->tensor(node.inputs->data[1])->allocation_type != + kTfLiteMmapRo)) { + logError("NNAPI does not yet support dynamic tensors."); + return kTfLiteError; + } + nnapi_version = 11; // require NNAPI 1.1 + nn_op_type = ANEURALNETWORKS_TRANSPOSE; + break; + case tflite::BuiltinOperator_L2_NORMALIZATION: + nn_op_type = ANEURALNETWORKS_L2_NORMALIZATION; + if (reinterpret_cast(node.builtin_data) + ->activation != kTfLiteActNone) { + FATAL( + "NNAPI does not support L2Normalization with fused activations"); + } + break; + case tflite::BuiltinOperator_HASHTABLE_LOOKUP: + if (interpreter->tensor(node.outputs->data[0])->type != + kTfLiteFloat32) { + logError("NNAPI only support HASHTABLE_LOOKUP with float32 output", + builtin); + return kTfLiteError; + } + nn_op_type = ANEURALNETWORKS_HASHTABLE_LOOKUP; + break; case tflite::BuiltinOperator_CONCAT_EMBEDDINGS: case tflite::BuiltinOperator_LSH_PROJECTION: - case tflite::BuiltinOperator_HASHTABLE_LOOKUP: case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: case tflite::BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: - case tflite::BuiltinOperator_L2_NORMALIZATION: case tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: case tflite::BuiltinOperator_PADV2: case tflite::BuiltinOperator_RESIZE_BILINEAR: @@ -496,7 +594,6 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_SPACE_TO_BATCH_ND: case tflite::BuiltinOperator_BATCH_TO_SPACE_ND: case tflite::BuiltinOperator_TOPK_V2: - case tflite::BuiltinOperator_TRANSPOSE: case tflite::BuiltinOperator_SPLIT: case tflite::BuiltinOperator_STRIDED_SLICE: case tflite::BuiltinOperator_EXP: @@ -508,6 +605,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_MAXIMUM: case tflite::BuiltinOperator_MINIMUM: case tflite::BuiltinOperator_ARG_MAX: + case tflite::BuiltinOperator_ARG_MIN: case tflite::BuiltinOperator_GREATER: case tflite::BuiltinOperator_GREATER_EQUAL: case tflite::BuiltinOperator_LESS: @@ -524,16 +622,24 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_EQUAL: case tflite::BuiltinOperator_NOT_EQUAL: case tflite::BuiltinOperator_SUM: + case tflite::BuiltinOperator_REDUCE_MAX: + case tflite::BuiltinOperator_REDUCE_PROD: case tflite::BuiltinOperator_SQRT: case tflite::BuiltinOperator_RSQRT: case tflite::BuiltinOperator_SHAPE: case tflite::BuiltinOperator_POW: - FATAL("Op code %d is currently not delegated to NNAPI", builtin); - nn_op_type = -1; // set to invalid + case tflite::BuiltinOperator_FAKE_QUANT: + case tflite::BuiltinOperator_PACK: + case tflite::BuiltinOperator_LOGICAL_OR: + case tflite::BuiltinOperator_ONE_HOT: + case tflite::BuiltinOperator_LOGICAL_AND: + case tflite::BuiltinOperator_LOGICAL_NOT: + logError("Op code %d is currently not delegated to NNAPI", builtin); + return kTfLiteError; break; case tflite::BuiltinOperator_CUSTOM: - FATAL("Custom operations are not supported when using NNAPI."); - nn_op_type = -1; // set to invalid + logError("Custom operations are not supported when using NNAPI."); + return kTfLiteError; break; } @@ -542,47 +648,70 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, } // Add the operation. - CHECK_NN(ANeuralNetworksModel_addOperation( + RETURN_ERROR_IF_NN_FAILED(ANeuralNetworksModel_addOperation( nn_model, nn_op_type, static_cast(augmented_inputs.size()), augmented_inputs.data(), static_cast(augmented_outputs.size()), reinterpret_cast(augmented_outputs.data()))); } + return kTfLiteOk; } TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) { - // TODO(aselle): This is not correct. need to handle resize invalidation. - if (nn_model_ && nn_compiled_model_) return kTfLiteOk; + if (nn_model_ && nn_compiled_model_) return model_status_; + // TODO(aselle): This is not correct. need to handle resize invalidation. if (!nn_model_) { CHECK_NN(ANeuralNetworksModel_create(&nn_model_)); - // Find all the temporary tensors and put them in a skip_list. - std::vector skip_list; + // Find which tensors should be added to NNAPI. TFLite has temporaries + // and RNN back-edges which are are not valid for NNAPI. We look through all + // inputs and outputs and mark the mapping in tensor_id_to_nnapi_id with + // kOperandIdNotSet. addTensorOperands will replace those with the + // corresponding NNAPI operand ids and skip kOperandNotNeeded entries. + std::vector tensor_id_to_nnapi_id(interpreter->tensors_size(), + kOperandNotNeeded); + auto set_ids_to_not_set = [&tensor_id_to_nnapi_id](const int* buf, + size_t count) { + for (int j = 0; j < count; j++) { + auto tensor_id = buf[j]; + if (tensor_id != kOptionalTensor) { + tensor_id_to_nnapi_id[tensor_id] = kOperandIdNotSet; + } + } + }; for (size_t i = 0; i < interpreter->nodes_size(); i++) { const auto* node_and_registration = interpreter->node_and_registration(i); const TfLiteNode& node = node_and_registration->first; - if (node.temporaries != nullptr) { - for (int j = 0; j < node.temporaries->size; j++) { - skip_list.push_back(static_cast(node.temporaries->data[j])); - } - } + set_ids_to_not_set(node.inputs->data, node.inputs->size); + set_ids_to_not_set(node.outputs->data, node.outputs->size); } - - uint32_t next_id = addTensorOperands(interpreter, nn_model_, skip_list); - AddOpsAndParams(interpreter, nn_model_, next_id, &model_states_inputs_, - &model_states_outputs_); - - std::vector augmented_inputs = interpreter->inputs(); - std::vector augmented_outputs = interpreter->outputs(); - - // All state tensors input/output need to be treated as model input/output. + set_ids_to_not_set(interpreter->inputs().data(), + interpreter->inputs().size()); + set_ids_to_not_set(interpreter->outputs().data(), + interpreter->outputs().size()); + + uint32_t next_id = 0; + RETURN_ERROR_IF_NN_FAILED(addTensorOperands( + interpreter, nn_model_, &next_id, &tensor_id_to_nnapi_id)); + RETURN_ERROR_IF_NN_FAILED( + AddOpsAndParams(interpreter, nn_model_, next_id, &model_states_inputs_, + &model_states_outputs_, tensor_id_to_nnapi_id)); + + std::vector augmented_inputs; + MapAndAddTensorIds(interpreter->inputs().data(), + interpreter->inputs().size(), &augmented_inputs, + tensor_id_to_nnapi_id); augmented_inputs.insert(augmented_inputs.end(), model_states_inputs_.begin(), model_states_inputs_.end()); - augmented_outputs.insert(augmented_outputs.end(), - model_states_outputs_.begin(), - model_states_outputs_.end()); + std::vector augmented_outputs; + MapAndAddTensorIds(interpreter->outputs().data(), + interpreter->outputs().size(), &augmented_outputs, + tensor_id_to_nnapi_id); + MapAndAddTensorIds(model_states_outputs_.data(), + model_states_outputs_.size(), &augmented_outputs, + tensor_id_to_nnapi_id); CHECK_NN(ANeuralNetworksModel_identifyInputsAndOutputs( nn_model_, static_cast(augmented_inputs.size()), @@ -600,7 +729,13 @@ TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) { TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) { if (!nn_model_) { - TF_LITE_ENSURE_STATUS(BuildGraph(interpreter)); + model_status_ = BuildGraph(interpreter); + if (model_status_ != kTfLiteOk) { + logError("Failed to build graph for NNAPI"); + } + } + if (model_status_ != kTfLiteOk) { + return model_status_; } ANeuralNetworksExecution* execution = nullptr; @@ -664,4 +799,6 @@ TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) { return kTfLiteOk; } +bool NNAPIDelegate::IsSupported() { return NNAPIExists(); } + } // namespace tflite diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h index 94dea4f9b23f208fddbacd3c77d889ea753a8a1d..2bdb2cc5c8211a48ea07e7ec45f9eebc0a3f7c10 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.h +++ b/tensorflow/contrib/lite/nnapi_delegate.h @@ -19,9 +19,10 @@ limitations under the License. #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" -class ANeuralNetworsModel; +class ANeuralNetworksModel; +class ANeuralNetworksMemory; +class ANeuralNetworksCompilation; namespace tflite { @@ -54,19 +55,24 @@ class NNAPIDelegate { // Run TfLiteStatus Invoke(Interpreter* interpreter); + // Whether the current platform supports NNAPI delegation. + static bool IsSupported(); + private: // The NN API model handle ANeuralNetworksModel* nn_model_ = nullptr; // The NN API compilation handle ANeuralNetworksCompilation* nn_compiled_model_ = nullptr; + // Model status + TfLiteStatus model_status_ = kTfLiteOk; // List of state tensors for LSTM, RNN, SVDF. // NN API does not allow ops to maintain states across multiple // invocations. We need to manually create state input tensors from // corresponding state output tensors of TFLite operations, and map them // correctly. - std::vector model_states_inputs_; - std::vector model_states_outputs_; + std::vector model_states_inputs_; // holds NNAPI operand ids + std::vector model_states_outputs_; // holds TFLite tensor ids }; } // namespace tflite diff --git a/tensorflow/contrib/lite/nnapi_delegate_disabled.cc b/tensorflow/contrib/lite/nnapi_delegate_disabled.cc new file mode 100644 index 0000000000000000000000000000000000000000..efde72b1a76a86728f4cccd8782ca0e993dd0338 --- /dev/null +++ b/tensorflow/contrib/lite/nnapi_delegate_disabled.cc @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/nnapi_delegate.h" + +#include + +namespace tflite { + +NNAPIAllocation::NNAPIAllocation(const char* filename, + ErrorReporter* error_reporter) + : MMAPAllocation(filename, error_reporter) { + // The disabled variant should never be created. + assert(false); +} + +NNAPIAllocation::~NNAPIAllocation() {} + +NNAPIDelegate::~NNAPIDelegate() {} + +TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) { + return kTfLiteError; +} + +TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) { + return kTfLiteError; +} + +bool NNAPIDelegate::IsSupported() { return false; } + +} // namespace tflite diff --git a/tensorflow/contrib/lite/profiling/BUILD b/tensorflow/contrib/lite/profiling/BUILD index a162b87b8f98576ec7c3b2623d1d34f2baef6cce..1172722f7a70771af73eb07571349e431755471c 100644 --- a/tensorflow/contrib/lite/profiling/BUILD +++ b/tensorflow/contrib/lite/profiling/BUILD @@ -58,6 +58,7 @@ cc_test( name = "profile_summarizer_test", srcs = ["profile_summarizer_test.cc"], copts = common_copts, + tags = ["no_oss"], deps = [ ":profile_summarizer", "//tensorflow/contrib/lite:framework", diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.cc b/tensorflow/contrib/lite/profiling/profile_summarizer.cc index c37a0965884a803e82da536f73a8f32a28691651..720bd717b9e3b0c45cbdbaaad2b6900edacc3051 100644 --- a/tensorflow/contrib/lite/profiling/profile_summarizer.cc +++ b/tensorflow/contrib/lite/profiling/profile_summarizer.cc @@ -23,8 +23,6 @@ namespace tflite { namespace profiling { namespace { -using Detail = tensorflow::StatsCalculator::Detail; - struct OperatorDetails { std::string name; std::vector inputs; @@ -83,7 +81,7 @@ OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter, OperatorDetails details; details.name = op_name; if (profiling_string) { - details.name += ":" + string(profiling_string); + details.name += ":" + std::string(profiling_string); } details.inputs = GetTensorNames(interpreter, inputs); details.outputs = GetTensorNames(interpreter, outputs); @@ -125,28 +123,17 @@ void ProfileSummarizer::ProcessProfiles( int64_t base_start_us = events[0]->begin_timestamp_us; int node_num = 0; int64_t curr_total_us = 0; - std::map details; for (auto event : events) { auto op_details = GetOperatorDetails(interpreter, event->event_metadata); auto node_name = ToString(op_details.outputs); - auto result = details.emplace(node_name, Detail()); - Detail* detail = &(result.first->second); - detail->start_us.UpdateStat(event->begin_timestamp_us - base_start_us); + int64_t start_us = event->begin_timestamp_us - base_start_us; int64_t node_exec_time = event->end_timestamp_us - event->begin_timestamp_us; - detail->rel_end_us.UpdateStat(node_exec_time); + stats_calculator_->AddNodeStats(node_name, op_details.name, node_num, + start_us, node_exec_time, 0 /*memory */); curr_total_us += node_exec_time; ++node_num; - - if (result.second) { - detail->name = node_name; - detail->type = op_details.name; - detail->run_order = node_num; - detail->times_called = 0; - } - ++detail->times_called; } - stats_calculator_->UpdateDetails(details); stats_calculator_->UpdateRunTotalUs(curr_total_us); } } // namespace profiling diff --git a/tensorflow/contrib/lite/profiling/time.cc b/tensorflow/contrib/lite/profiling/time.cc index 446660bb747cd6e3b694669b64ac1d95cf415fbe..875ddb02bcfc30f4c2ef543fe1c15bec467e5410 100644 --- a/tensorflow/contrib/lite/profiling/time.cc +++ b/tensorflow/contrib/lite/profiling/time.cc @@ -14,16 +14,34 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/profiling/time.h" +#if defined(_MSC_VER) +#include // NOLINT(build/c++11) +#else #include +#endif namespace tflite { namespace profiling { namespace time { + +#if defined(_MSC_VER) + +uint64_t NowMicros() { + return std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); +} + +#else + uint64_t NowMicros() { struct timeval tv; gettimeofday(&tv, nullptr); return static_cast(tv.tv_sec) * 1000000 + tv.tv_usec; } + +#endif // defined(_MSC_VER) + } // namespace time } // namespace profiling } // namespace tflite diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index 27909a9458f6b09f96cb556a5254f01e54f46e05..860aff9e7e2de9616dea40f42a33bc1e4ee9f400 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -19,6 +19,8 @@ py_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/contrib/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper", + "//tensorflow/python:util", + "//third_party/py/numpy", ], ) @@ -30,9 +32,10 @@ py_test( tags = ["no_oss"], deps = [ ":interpreter", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:platform_test", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform", + "//third_party/py/numpy", ], ) @@ -69,7 +72,10 @@ py_test( srcs = ["lite_test.py"], data = [":interpreter_test_data"], srcs_version = "PY2AND3", - tags = ["no_windows"], + tags = [ + "no_oss", + "no_windows", + ], deps = [ ":lite", ], @@ -161,7 +167,10 @@ py_test( name = "convert_saved_model_test", srcs = ["convert_saved_model_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], + tags = [ + "no_oss", + "no_windows", + ], visibility = ["//visibility:public"], deps = [ ":convert_saved_model", diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py index 0ea2630f711727787332f207bdff6383aac8097c..ec49738fb5365a16c41cc6737198b5707508a3e2 100644 --- a/tensorflow/contrib/lite/python/convert.py +++ b/tensorflow/contrib/lite/python/convert.py @@ -115,6 +115,7 @@ def build_toco_convert_protos(input_tensors, inference_type=lite_constants.FLOAT, inference_input_type=None, input_format=lite_constants.TENSORFLOW_GRAPHDEF, + input_shapes=None, output_format=lite_constants.TFLITE, quantized_input_stats=None, default_ranges_stats=None, @@ -141,6 +142,8 @@ def build_toco_convert_protos(input_tensors, Must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) input_format: Type of data to read Currently must be `{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF) + input_shapes: Input array shape. It needs to be a list of the same length + as `input_tensors`, or None. (default None) output_format: Output file format. Currently must be `{TFLITE, GRAPHVIZ_DOT}`. (default TFLITE) quantized_input_stats: List of tuples of integers representing the mean and @@ -209,7 +212,11 @@ def build_toco_convert_protos(input_tensors, if inference_type == lite_constants.QUANTIZED_UINT8: input_array.mean_value, input_array.std_value = quantized_input_stats[idx] input_array.name = tensor_name(input_tensor) - input_array.shape.dims.extend(map(int, input_tensor.get_shape())) + if input_shapes is None: + shape = input_tensor.get_shape() + else: + shape = input_shapes[idx] + input_array.shape.dims.extend(map(int, shape)) for output_tensor in output_tensors: model.output_arrays.append(tensor_name(output_tensor)) diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py index fd908234254185e0a0639618e936ca8ff58631da..3243bddac879b8eb0ca7a03d28b2f6094f905983 100644 --- a/tensorflow/contrib/lite/python/interpreter.py +++ b/tensorflow/contrib/lite/python/interpreter.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import sys +import numpy as np from tensorflow.python.util.lazy_loader import LazyLoader # Lazy load since some of the performance benchmark skylark rules @@ -56,9 +57,6 @@ class Interpreter(object): self._interpreter = ( _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer( model_content)) - if not self._interpreter: - raise ValueError( - 'Failed to create model from {} bytes'.format(len(model_content))) elif not model_path and not model_path: raise ValueError('`model_path` or `model_content` must be specified.') else: @@ -66,8 +64,7 @@ class Interpreter(object): def allocate_tensors(self): self._ensure_safe() - if not self._interpreter.AllocateTensors(): - raise ValueError('Failed to allocate tensors') + return self._interpreter.AllocateTensors() def _safe_to_run(self): """Returns true if there exist no numpy array buffers. @@ -152,8 +149,7 @@ class Interpreter(object): Raises: ValueError: If the interpreter could not set the tensor. """ - if not self._interpreter.SetTensor(tensor_index, value): - raise ValueError('Failed to set tensor') + self._interpreter.SetTensor(tensor_index, value) def resize_tensor_input(self, input_index, tensor_size): """Resizes an input tensor. @@ -167,8 +163,10 @@ class Interpreter(object): ValueError: If the interpreter could not resize the input tensor. """ self._ensure_safe() - if not self._interpreter.ResizeInputTensor(input_index, tensor_size): - raise ValueError('Failed to resize input') + # `ResizeInputTensor` now only accepts int32 numpy array as `tensor_size + # parameter. + tensor_size = np.array(tensor_size, dtype=np.int32) + self._interpreter.ResizeInputTensor(input_index, tensor_size) def get_output_details(self): """Gets model output details. @@ -181,7 +179,9 @@ class Interpreter(object): ] def get_tensor(self, tensor_index): - """Gets the value of the input tensor. Note this makes a copy so prefer `tensor()`. + """Gets the value of the input tensor (get a copy). + + If you wish to avoid the copy, use `tensor()`. Args: tensor_index: Tensor index of tensor to get. This value can be gotten from @@ -208,7 +208,7 @@ class Interpreter(object): for i in range(10): input().fill(3.) interpreter.invoke() - print("inference %s" % output) + print("inference %s" % output()) Notice how this function avoids making a numpy array directly. This is because it is important to not hold actual numpy views to the data longer @@ -247,5 +247,7 @@ class Interpreter(object): ValueError: When the underlying interpreter fails raise ValueError. """ self._ensure_safe() - if not self._interpreter.Invoke(): - raise ValueError('Failed to invoke TFLite model') + self._interpreter.Invoke() + + def reset_all_variables_to_zero(self): + return self._interpreter.ResetVariableTensorsToZero() diff --git a/tensorflow/contrib/lite/python/interpreter_test.py b/tensorflow/contrib/lite/python/interpreter_test.py index 5f1fa26c3b7f76309a6f1f80aa3c1e4889781528..e77d52ca9950ec42300264bb56ebce253d4982b1 100644 --- a/tensorflow/contrib/lite/python/interpreter_test.py +++ b/tensorflow/contrib/lite/python/interpreter_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import io import numpy as np +import six from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper from tensorflow.python.framework import test_util @@ -82,7 +83,7 @@ class InterpreterTest(test_util.TensorFlowTestCase): test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8) expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8) interpreter.resize_tensor_input(input_details[0]['index'], - np.array(test_input.shape, dtype=np.int32)) + test_input.shape) interpreter.allocate_tensors() interpreter.set_tensor(input_details[0]['index'], test_input) interpreter.invoke() @@ -91,6 +92,28 @@ class InterpreterTest(test_util.TensorFlowTestCase): self.assertTrue((expected_output == output_data).all()) +class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase): + + def testInvalidModelContent(self): + with self.assertRaisesRegexp(ValueError, + 'Model provided has model identifier \''): + interpreter_wrapper.Interpreter(model_content=six.b('garbage')) + + def testInvalidModelFile(self): + with self.assertRaisesRegexp( + ValueError, 'Could not open \'totally_invalid_file_name\''): + interpreter_wrapper.Interpreter( + model_path='totally_invalid_file_name') + + def testInvokeBeforeReady(self): + interpreter = interpreter_wrapper.Interpreter( + model_path=resource_loader.get_path_to_datafile( + 'testdata/permute_float.tflite')) + with self.assertRaisesRegexp(RuntimeError, + 'Invoke called on model that is not ready'): + interpreter.invoke() + + class InterpreterTensorAccessorTest(test_util.TensorFlowTestCase): def setUp(self): diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD index 634c2a1e1f5005208b4eea5c853a43cccf4d244c..69ee95c320b72b68052c6f76f32c1493707f34b1 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD @@ -13,7 +13,6 @@ cc_library( deps = [ "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/core:lib", "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", "@com_google_absl//absl/memory", diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 5554d08fa08fdc6ddcb042d12f979164a144e337..9ab05f3068494a573ffa5b46f84be66a12d54e46 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -14,13 +14,13 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h" +#include #include #include "absl/memory/memory.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" -#include "tensorflow/core/platform/logging.h" // Disallow Numpy 1.7 deprecated symbols. #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION @@ -38,9 +38,58 @@ limitations under the License. #define CPP_TO_PYSTRING PyString_FromStringAndSize #endif +#define TFLITE_PY_CHECK(x) \ + if ((x) != kTfLiteOk) { \ + return error_reporter_->exception(); \ + } + +#define TFLITE_PY_TENSOR_BOUNDS_CHECK(i) \ + if (i >= interpreter_->tensors_size() || i < 0) { \ + PyErr_Format(PyExc_ValueError, \ + "Invalid tensor index %d exceeds max tensor index %lu", i, \ + interpreter_->tensors_size()); \ + return nullptr; \ + } + +#define TFLITE_PY_ENSURE_VALID_INTERPRETER() \ + if (!interpreter_) { \ + PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \ + return nullptr; \ + } + namespace tflite { namespace interpreter_wrapper { +class PythonErrorReporter : public tflite::ErrorReporter { + public: + PythonErrorReporter() {} + + // Report an error message + int Report(const char* format, va_list args) override { + char buf[1024]; + int formatted = vsnprintf(buf, sizeof(buf), format, args); + buffer_ << buf; + return formatted; + } + + // Set's a Python runtime exception with the last error. + PyObject* exception() { + std::string last_message = message(); + PyErr_SetString(PyExc_RuntimeError, last_message.c_str()); + return nullptr; + } + + // Gets the last error message and clears the buffer. + std::string message() { + std::string value = buffer_.str(); + buffer_.clear(); + return value; + } + + private: + std::stringstream buffer_; +}; + namespace { // Calls PyArray's initialization to initialize all the API pointers. Note that @@ -59,19 +108,8 @@ std::unique_ptr CreateInterpreter( ImportNumpy(); std::unique_ptr interpreter; - tflite::InterpreterBuilder(*model, resolver)(&interpreter); - if (interpreter) { - for (const int input_index : interpreter->inputs()) { - const TfLiteTensor* tensor = interpreter->tensor(input_index); - CHECK(tensor); - const TfLiteIntArray* dims = tensor->dims; - if (!dims) { - continue; - } - - std::vector input_dims(dims->data, dims->data + dims->size); - interpreter->ResizeInputTensor(input_index, input_dims); - } + if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) { + return nullptr; } return interpreter; } @@ -95,10 +133,10 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { case kTfLiteComplex64: return NPY_COMPLEX64; case kTfLiteNoType: - return -1; + return NPY_NOTYPE; + // Avoid default so compiler errors created when new types are made. } - LOG(ERROR) << "Unknown TfLiteType " << tf_lite_type; - return -1; + return NPY_NOTYPE; } TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { @@ -122,8 +160,8 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { return kTfLiteString; case NPY_COMPLEX64: return kTfLiteComplex64; + // Avoid default so compiler errors created when new types are made. } - LOG(ERROR) << "Unknown PyArray dtype " << pyarray_type; return kTfLiteNoType; } @@ -146,33 +184,54 @@ PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) { } // namespace +InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper( + std::unique_ptr model, + std::unique_ptr error_reporter, + std::string* error_msg) { + if (!model) { + *error_msg = error_reporter->message(); + return nullptr; + } + + auto resolver = absl::make_unique(); + auto interpreter = CreateInterpreter(model.get(), *resolver); + if (!interpreter) { + *error_msg = error_reporter->message(); + return nullptr; + } + + InterpreterWrapper* wrapper = + new InterpreterWrapper(std::move(model), std::move(error_reporter), + std::move(resolver), std::move(interpreter)); + return wrapper; +} + InterpreterWrapper::InterpreterWrapper( - std::unique_ptr model) + std::unique_ptr model, + std::unique_ptr error_reporter, + std::unique_ptr resolver, + std::unique_ptr interpreter) : model_(std::move(model)), - resolver_(absl::make_unique()), - interpreter_(CreateInterpreter(model_.get(), *resolver_)) {} + error_reporter_(std::move(error_reporter)), + resolver_(std::move(resolver)), + interpreter_(std::move(interpreter)) {} InterpreterWrapper::~InterpreterWrapper() {} -bool InterpreterWrapper::AllocateTensors() { - if (!interpreter_) { - LOG(ERROR) << "Cannot allocate tensors: invalid interpreter."; - return false; - } - - if (interpreter_->AllocateTensors() != kTfLiteOk) { - LOG(ERROR) << "Unable to allocate tensors."; - return false; - } - - return true; +PyObject* InterpreterWrapper::AllocateTensors() { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_CHECK(interpreter_->AllocateTensors()); + Py_RETURN_NONE; } -bool InterpreterWrapper::Invoke() { - return interpreter_ ? (interpreter_->Invoke() == kTfLiteOk) : false; +PyObject* InterpreterWrapper::Invoke() { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_CHECK(interpreter_->Invoke()); + Py_RETURN_NONE; } PyObject* InterpreterWrapper::InputIndices() const { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); PyObject* np_array = PyArrayFromIntVector(interpreter_->inputs().data(), interpreter_->inputs().size()); @@ -186,35 +245,36 @@ PyObject* InterpreterWrapper::OutputIndices() const { return PyArray_Return(reinterpret_cast(np_array)); } -bool InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) { - if (!interpreter_) { - LOG(ERROR) << "Invalid interpreter."; - return false; - } +PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); std::unique_ptr array_safe( PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr)); if (!array_safe) { - LOG(ERROR) << "Failed to convert value into readable tensor."; - return false; + PyErr_SetString(PyExc_ValueError, + "Failed to convert numpy value into readable tensor."); + return nullptr; } PyArrayObject* array = reinterpret_cast(array_safe.get()); if (PyArray_NDIM(array) != 1) { - LOG(ERROR) << "Expected 1-D defining input shape."; - return false; + PyErr_Format(PyExc_ValueError, "Shape should be 1D instead of %d.", + PyArray_NDIM(array)); + return nullptr; } if (PyArray_TYPE(array) != NPY_INT32) { - LOG(ERROR) << "Shape must be an int32 array"; - return false; + PyErr_Format(PyExc_ValueError, "Shape must be type int32 (was %d).", + PyArray_TYPE(array)); + return nullptr; } std::vector dims(PyArray_SHAPE(array)[0]); memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int)); - return (interpreter_->ResizeInputTensor(i, dims) == kTfLiteOk); + TFLITE_PY_CHECK(interpreter_->ResizeInputTensor(i, dims)); + Py_RETURN_NONE; } std::string InterpreterWrapper::TensorName(int i) const { @@ -227,21 +287,21 @@ std::string InterpreterWrapper::TensorName(int i) const { } PyObject* InterpreterWrapper::TensorType(int i) const { - if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) { - return nullptr; - } + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(i); const TfLiteTensor* tensor = interpreter_->tensor(i); - int typenum = TfLiteTypeToPyArrayType(tensor->type); - return PyArray_TypeObjectFromType(typenum); + int code = TfLiteTypeToPyArrayType(tensor->type); + if (code == -1) { + PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code); + return nullptr; + } + return PyArray_TypeObjectFromType(code); } PyObject* InterpreterWrapper::TensorSize(int i) const { - if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) { - Py_INCREF(Py_None); - return Py_None; - } - + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(i); const TfLiteTensor* tensor = interpreter_->tensor(i); PyObject* np_array = PyArrayFromIntVector(tensor->dims->data, tensor->dims->size); @@ -250,100 +310,87 @@ PyObject* InterpreterWrapper::TensorSize(int i) const { } PyObject* InterpreterWrapper::TensorQuantization(int i) const { - if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) { - Py_INCREF(Py_None); - return Py_None; - } - + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(i); const TfLiteTensor* tensor = interpreter_->tensor(i); return PyTupleFromQuantizationParam(tensor->params); } -bool InterpreterWrapper::SetTensor(int i, PyObject* value) { - if (!interpreter_) { - LOG(ERROR) << "Invalid interpreter."; - return false; - } - - if (i >= interpreter_->tensors_size()) { - LOG(ERROR) << "Invalid tensor index: " << i << " exceeds max tensor index " - << interpreter_->tensors_size(); - return false; - } +PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(i); std::unique_ptr array_safe( PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr)); if (!array_safe) { - LOG(ERROR) << "Failed to convert value into readable tensor."; - return false; + PyErr_SetString(PyExc_ValueError, + "Failed to convert value into readable tensor."); + return nullptr; } PyArrayObject* array = reinterpret_cast(array_safe.get()); const TfLiteTensor* tensor = interpreter_->tensor(i); if (TfLiteTypeFromPyArray(array) != tensor->type) { - LOG(ERROR) << "Cannot set tensor:" - << " Got tensor of type " << TfLiteTypeFromPyArray(array) - << " but expected type " << tensor->type << " for input " << i; - return false; + PyErr_Format(PyExc_ValueError, + "Cannot set tensor:" + " Got tensor of type %d" + " but expected type %d for input %d ", + TfLiteTypeFromPyArray(array), tensor->type, i); + return nullptr; } if (PyArray_NDIM(array) != tensor->dims->size) { - LOG(ERROR) << "Cannot set tensor: Dimension mismatch"; - return false; + PyErr_SetString(PyExc_ValueError, "Cannot set tensor: Dimension mismatch"); + return nullptr; } for (int j = 0; j < PyArray_NDIM(array); j++) { if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) { - LOG(ERROR) << "Cannot set tensor: Dimension mismatch"; - return false; + PyErr_SetString(PyExc_ValueError, + "Cannot set tensor: Dimension mismatch"); + return nullptr; } } size_t size = PyArray_NBYTES(array); - DCHECK_EQ(size, tensor->bytes); + if (size != tensor->bytes) { + PyErr_Format(PyExc_ValueError, + "numpy array had %zu bytes but expected %zu bytes.", size, + tensor->bytes); + return nullptr; + } memcpy(tensor->data.raw, PyArray_DATA(array), size); - return true; + Py_RETURN_NONE; } namespace { -PyObject* CheckGetTensorArgs(Interpreter* interpreter, int tensor_index, +// Checks to see if a tensor access can succeed (returns nullptr on error). +// Otherwise returns Py_None. +PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index, TfLiteTensor** tensor, int* type_num) { - if (!interpreter) { - LOG(ERROR) << "Invalid interpreter."; - Py_INCREF(Py_None); - return Py_None; - } + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_TENSOR_BOUNDS_CHECK(tensor_index); - if (tensor_index >= interpreter->tensors_size() || tensor_index < 0) { - LOG(ERROR) << "Invalid tensor index: " << tensor_index - << " exceeds max tensor index " << interpreter->inputs().size(); - Py_INCREF(Py_None); - return Py_None; - } - - *tensor = interpreter->tensor(tensor_index); + *tensor = interpreter_->tensor(tensor_index); if ((*tensor)->bytes == 0) { - LOG(ERROR) << "Invalid tensor size"; - Py_INCREF(Py_None); - return Py_None; + PyErr_SetString(PyExc_ValueError, "Invalid tensor size."); + return nullptr; } *type_num = TfLiteTypeToPyArrayType((*tensor)->type); if (*type_num == -1) { - LOG(ERROR) << "Unknown tensor type " << (*tensor)->type; - Py_INCREF(Py_None); - return Py_None; + PyErr_SetString(PyExc_ValueError, "Unknown tensor type."); + return nullptr; } if (!(*tensor)->data.raw) { - LOG(ERROR) << "Tensor data is null."; - Py_INCREF(Py_None); - return Py_None; + PyErr_SetString(PyExc_ValueError, "Tensor data is null."); + return nullptr; } - return nullptr; + Py_RETURN_NONE; } } // namespace @@ -352,19 +399,20 @@ PyObject* InterpreterWrapper::GetTensor(int i) const { // Sanity check accessor TfLiteTensor* tensor = nullptr; int type_num = 0; - if (PyObject* pynone_or_nullptr = - CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) { - return pynone_or_nullptr; - } + + PyObject* check_result = + CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num); + if (check_result == nullptr) return check_result; + Py_XDECREF(check_result); + std::vector dims(tensor->dims->data, tensor->dims->data + tensor->dims->size); // Make a buffer copy but we must tell Numpy It owns that data or else // it will leak. void* data = malloc(tensor->bytes); if (!data) { - LOG(ERROR) << "Malloc to copy tensor failed."; - Py_INCREF(Py_None); - return Py_None; + PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed."); + return nullptr; } memcpy(data, tensor->data.raw, tensor->bytes); PyObject* np_array = @@ -378,10 +426,11 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) { // Sanity check accessor TfLiteTensor* tensor = nullptr; int type_num = 0; - if (PyObject* pynone_or_nullptr = - CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) { - return pynone_or_nullptr; - } + + PyObject* check_result = + CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num); + if (check_result == nullptr) return check_result; + Py_XDECREF(check_result); std::vector dims(tensor->dims->data, tensor->dims->data + tensor->dims->size); @@ -394,22 +443,33 @@ PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) { } InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( - const char* model_path) { + const char* model_path, std::string* error_msg) { + std::unique_ptr error_reporter(new PythonErrorReporter); std::unique_ptr model = - tflite::FlatBufferModel::BuildFromFile(model_path); - return model ? new InterpreterWrapper(std::move(model)) : nullptr; + tflite::FlatBufferModel::BuildFromFile(model_path, error_reporter.get()); + return CreateInterpreterWrapper(std::move(model), std::move(error_reporter), + error_msg); } InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( - PyObject* data) { + PyObject* data, std::string* error_msg) { char * buf = nullptr; Py_ssize_t length; + std::unique_ptr error_reporter(new PythonErrorReporter); if (PY_TO_CPPSTRING(data, &buf, &length) == -1) { return nullptr; } std::unique_ptr model = - tflite::FlatBufferModel::BuildFromBuffer(buf, length); - return model ? new InterpreterWrapper(std::move(model)) : nullptr; + tflite::FlatBufferModel::BuildFromBuffer(buf, length, + error_reporter.get()); + return CreateInterpreterWrapper(std::move(model), std::move(error_reporter), + error_msg); +} + +PyObject* InterpreterWrapper::ResetVariableTensorsToZero() { + TFLITE_PY_ENSURE_VALID_INTERPRETER(); + TFLITE_PY_CHECK(interpreter_->ResetVariableTensorsToZero()); + Py_RETURN_NONE; } } // namespace interpreter_wrapper diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h index e7343cb388d657e472464f69fa8cd0c6ddc60923..641dd93db5b9df292e03e9704a218299f48b14fb 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -21,6 +21,9 @@ limitations under the License. // Place `` before to avoid build failures in macOS. #include + +// The empty line above is on purpose as otherwise clang-format will +// automatically move before . #include // We forward declare TFLite classes here to avoid exposing them to SWIG. @@ -36,41 +39,63 @@ class Interpreter; namespace interpreter_wrapper { +class PythonErrorReporter; + class InterpreterWrapper { public: // SWIG caller takes ownership of pointer. - static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path); + static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path, + std::string* error_msg); // SWIG caller takes ownership of pointer. - static InterpreterWrapper* CreateWrapperCPPFromBuffer(PyObject* data); + static InterpreterWrapper* CreateWrapperCPPFromBuffer(PyObject* data, + std::string* error_msg); ~InterpreterWrapper(); - bool AllocateTensors(); - bool Invoke(); + PyObject* AllocateTensors(); + PyObject* Invoke(); PyObject* InputIndices() const; PyObject* OutputIndices() const; - bool ResizeInputTensor(int i, PyObject* value); + PyObject* ResizeInputTensor(int i, PyObject* value); std::string TensorName(int i) const; PyObject* TensorType(int i) const; PyObject* TensorSize(int i) const; PyObject* TensorQuantization(int i) const; - bool SetTensor(int i, PyObject* value); + PyObject* SetTensor(int i, PyObject* value); PyObject* GetTensor(int i) const; + PyObject* ResetVariableTensorsToZero(); + // Returns a reference to tensor index i as a numpy array. The base_object // should be the interpreter object providing the memory. PyObject* tensor(PyObject* base_object, int i); private: - InterpreterWrapper(std::unique_ptr model); + // Helper function to construct an `InterpreterWrapper` object. + // It only returns InterpreterWrapper if it can construct an `Interpreter`. + // Otherwise it returns `nullptr`. + static InterpreterWrapper* CreateInterpreterWrapper( + std::unique_ptr model, + std::unique_ptr error_reporter, + std::string* error_msg); + + InterpreterWrapper( + std::unique_ptr model, + std::unique_ptr error_reporter, + std::unique_ptr resolver, + std::unique_ptr interpreter); // InterpreterWrapper is not copyable or assignable. We avoid the use of // InterpreterWrapper() = delete here for SWIG compatibility. InterpreterWrapper(); InterpreterWrapper(const InterpreterWrapper& rhs); + // The public functions which creates `InterpreterWrapper` should ensure all + // these member variables are initialized successfully. Otherwise it should + // report the error and return `nullptr`. const std::unique_ptr model_; + const std::unique_ptr error_reporter_; const std::unique_ptr resolver_; const std::unique_ptr interpreter_; }; diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i index 7f51f9f00d1b2fe057052f7b7bd52bcb65231164..afb2092eacab1d8dcccf8c75cee1d8d5c34d7e75 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i @@ -18,8 +18,51 @@ limitations under the License. %{ #define SWIG_FILE_WITH_INIT +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h" %} %include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h" + +namespace tflite { +namespace interpreter_wrapper { +%extend InterpreterWrapper { + + // Version of the constructor that handles producing Python exceptions + // that propagate strings. + static PyObject* CreateWrapperCPPFromFile(const char* model_path) { + std::string error; + if(tflite::interpreter_wrapper::InterpreterWrapper* ptr = + tflite::interpreter_wrapper::InterpreterWrapper + ::CreateWrapperCPPFromFile( + model_path, &error)) { + return SWIG_NewPointerObj( + ptr, SWIGTYPE_p_tflite__interpreter_wrapper__InterpreterWrapper, 1); + } else { + PyErr_SetString(PyExc_ValueError, error.c_str()); + return nullptr; + } + } + + // Version of the constructor that handles producing Python exceptions + // that propagate strings. + static PyObject* CreateWrapperCPPFromBuffer( + PyObject* data) { + std::string error; + if(tflite::interpreter_wrapper::InterpreterWrapper* ptr = + tflite::interpreter_wrapper::InterpreterWrapper + ::CreateWrapperCPPFromBuffer( + data, &error)) { + return SWIG_NewPointerObj( + ptr, SWIGTYPE_p_tflite__interpreter_wrapper__InterpreterWrapper, 1); + } else { + PyErr_SetString(PyExc_ValueError, error.c_str()); + return nullptr; + } + } +} + +} // namespace interpreter_wrapper +} // namespace tflite diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index a4229f91f5bfe2fd0a3030e6c45f9f0a5f23b7b8..2f9b9d469a27cc8910cb61c0da14769e5ff0baf0 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -40,24 +40,23 @@ from google.protobuf import text_format as _text_format from google.protobuf.message import DecodeError from tensorflow.contrib.lite.python import lite_constants as constants from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import -from tensorflow.contrib.lite.python.convert import tensor_name +from tensorflow.contrib.lite.python.convert import tensor_name as _tensor_name from tensorflow.contrib.lite.python.convert import toco_convert from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import -from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model -from tensorflow.contrib.lite.python.convert_saved_model import get_tensors_from_tensor_names -from tensorflow.contrib.lite.python.convert_saved_model import set_tensor_shapes +from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model +from tensorflow.contrib.lite.python.convert_saved_model import get_tensors_from_tensor_names as _get_tensors_from_tensor_names +from tensorflow.contrib.lite.python.convert_saved_model import set_tensor_shapes as _set_tensor_shapes from tensorflow.contrib.lite.python.interpreter import Interpreter # pylint: disable=unused-import from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import from tensorflow.contrib.lite.python.op_hint import OpHint # pylint: disable=unused-import from tensorflow.core.framework import graph_pb2 as _graph_pb2 from tensorflow.python import keras as _keras from tensorflow.python.client import session as _session -from tensorflow.python.framework import graph_util as tf_graph_util -from tensorflow.python.framework.importer import import_graph_def -from tensorflow.python.ops.variables import global_variables_initializer -from tensorflow.python.saved_model import signature_constants -from tensorflow.python.saved_model import tag_constants -# from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import graph_util as _tf_graph_util +from tensorflow.python.framework.importer import import_graph_def as _import_graph_def +from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer +from tensorflow.python.saved_model import signature_constants as _signature_constants +from tensorflow.python.saved_model import tag_constants as _tag_constants class TocoConverter(object): @@ -132,7 +131,7 @@ class TocoConverter(object): Args: - graph_def: TensorFlow GraphDef. + graph_def: Frozen TensorFlow GraphDef. input_tensors: List of input tensors. Type and shape are computed using `foo.get_shape()` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). @@ -178,7 +177,7 @@ class TocoConverter(object): """Creates a TocoConverter class from a file containing a frozen GraphDef. Args: - graph_def_file: Full filepath of file containing TensorFlow GraphDef. + graph_def_file: Full filepath of file containing frozen GraphDef. input_arrays: List of input tensors to freeze graph with. output_arrays: List of output tensors to freeze graph with. input_shapes: Dict of strings representing input tensor names to list of @@ -196,7 +195,7 @@ class TocoConverter(object): input_arrays or output_arrays contains an invalid tensor name. """ with _session.Session() as sess: - sess.run(global_variables_initializer()) + sess.run(_global_variables_initializer()) # Read GraphDef from file. graph_def = _graph_pb2.GraphDef() @@ -218,12 +217,12 @@ class TocoConverter(object): raise ValueError( "Unable to parse input file '{}'.".format(graph_def_file)) sess.graph.as_default() - import_graph_def(graph_def, name="") + _import_graph_def(graph_def, name="") # Get input and output tensors. - input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays) - output_tensors = get_tensors_from_tensor_names(sess.graph, output_arrays) - set_tensor_shapes(input_tensors, input_shapes) + input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays) + output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays) + _set_tensor_shapes(input_tensors, input_shapes) # Check if graph is frozen. if not _is_frozen_graph(sess): @@ -261,12 +260,12 @@ class TocoConverter(object): TocoConverter class. """ if tag_set is None: - tag_set = set([tag_constants.SERVING]) + tag_set = set([_tag_constants.SERVING]) if signature_key is None: - signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - result = freeze_saved_model(saved_model_dir, input_arrays, input_shapes, - output_arrays, tag_set, signature_key) + result = _freeze_saved_model(saved_model_dir, input_arrays, input_shapes, + output_arrays, tag_set, signature_key) return cls( graph_def=result[0], input_tensors=result[1], output_tensors=result[2]) @@ -299,15 +298,15 @@ class TocoConverter(object): # Get input and output tensors. if input_arrays: - input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays) + input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays) else: input_tensors = keras_model.inputs if output_arrays: - output_tensors = get_tensors_from_tensor_names(sess.graph, output_arrays) + output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays) else: output_tensors = keras_model.outputs - set_tensor_shapes(input_tensors, input_shapes) + _set_tensor_shapes(input_tensors, input_shapes) graph_def = _freeze_graph(sess, output_tensors) return cls(graph_def, input_tensors, output_tensors) @@ -328,12 +327,12 @@ class TocoConverter(object): for tensor in self._input_tensors: if not tensor.get_shape(): raise ValueError("Provide an input shape for input array '{0}'.".format( - tensor_name(tensor))) + _tensor_name(tensor))) shape = tensor.get_shape().as_list() if None in shape[1:]: raise ValueError( "None is only supported in the 1st dimension. Tensor '{0}' has " - "invalid shape '{1}'.".format(tensor_name(tensor), shape)) + "invalid shape '{1}'.".format(_tensor_name(tensor), shape)) elif shape[0] is None: self._set_batch_size(batch_size=1) @@ -343,7 +342,7 @@ class TocoConverter(object): quantized_stats = [] invalid_stats = [] for tensor in self._input_tensors: - name = tensor_name(tensor) + name = _tensor_name(tensor) if name in self.quantized_input_stats: quantized_stats.append(self.quantized_input_stats[name]) else: @@ -381,7 +380,7 @@ class TocoConverter(object): Returns: List of strings. """ - return [tensor_name(tensor) for tensor in self._input_tensors] + return [_tensor_name(tensor) for tensor in self._input_tensors] def _set_batch_size(self, batch_size): """Sets the first dimension of the input tensor to `batch_size`. @@ -428,11 +427,9 @@ def _freeze_graph(sess, output_tensors): Frozen GraphDef. """ if not _is_frozen_graph(sess): - sess.run(global_variables_initializer()) - output_arrays = [tensor_name(tensor) for tensor in output_tensors] - return tf_graph_util.convert_variables_to_constants(sess, sess.graph_def, - output_arrays) + sess.run(_global_variables_initializer()) + output_arrays = [_tensor_name(tensor) for tensor in output_tensors] + return _tf_graph_util.convert_variables_to_constants( + sess, sess.graph_def, output_arrays) else: return sess.graph_def - -# remove_undocumented(__name__) diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py index 0a60477c6de851ab858701a971dcc8131f15f7c8..d17482e60113da5bad3a76fa2ab634ae0ffb89fd 100644 --- a/tensorflow/contrib/lite/python/tflite_convert.py +++ b/tensorflow/contrib/lite/python/tflite_convert.py @@ -225,7 +225,7 @@ def run_main(_): input_file_group.add_argument( "--graph_def_file", type=str, - help="Full filepath of file containing TensorFlow GraphDef.") + help="Full filepath of file containing frozen TensorFlow GraphDef.") input_file_group.add_argument( "--saved_model_dir", type=str, @@ -257,7 +257,7 @@ def run_main(_): parser.add_argument( "--input_arrays", type=str, - help="Names of the output arrays, comma-separated.") + help="Names of the input arrays, comma-separated.") parser.add_argument( "--input_shapes", type=str, diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD index 9717a4a1a496b888348514584888e62c4e3703b4..b616e449e6ddae6467a6b86269cd108c7eec0c26 100644 --- a/tensorflow/contrib/lite/schema/BUILD +++ b/tensorflow/contrib/lite/schema/BUILD @@ -30,7 +30,10 @@ py_test( size = "small", srcs = ["upgrade_schema_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + tags = [ + "no_oss", + "no_pip", + ], deps = [ ":upgrade_schema", "//tensorflow/python:client_testlib", @@ -64,7 +67,9 @@ cc_test( "schema_v3.fbs", ], tags = [ + "no_oss", "tflite_not_portable_android", + "tflite_not_portable_ios", ], deps = [ "//tensorflow/core:lib_platform", diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/BUILD b/tensorflow/contrib/lite/schema/builtin_ops_header/BUILD index 0148149a6adc141d67e82808f7e8c72ddb7e309a..4a627761daf45b0fddd7b99e8a9c3d0d0ed2ee5e 100644 --- a/tensorflow/contrib/lite/schema/builtin_ops_header/BUILD +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/BUILD @@ -24,6 +24,7 @@ cc_binary( cc_test( name = "generator_test", srcs = ["generator_test.cc"], + tags = ["no_oss"], deps = [ ":generator", "@com_google_googletest//:gtest", @@ -36,6 +37,7 @@ cc_test( data = [ "//tensorflow/contrib/lite:builtin_ops.h", ], + tags = ["no_oss"], deps = [ ":generator", "@com_google_googletest//:gtest", diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 15fb8bbdb8f100201750faf706eb45b697319dfb..14f88b4c009e4f7cd913c2a27799ab418562fb1f 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -44,7 +44,7 @@ enum TensorType : byte { table QuantizationParameters { min:[float]; // For importing back into tensorflow. max:[float]; // For importing back into tensorflow. - scale:[float]; + scale:[float]; // For dequantizing the tensor's values. zero_point:[long]; } @@ -155,11 +155,20 @@ enum BuiltinOperator : byte { EQUAL = 71, NOT_EQUAL = 72, LOG = 73, - SUM=74, + SUM = 74, SQRT = 75, RSQRT = 76, SHAPE = 77, POW = 78, + ARG_MIN = 79, + FAKE_QUANT = 80, + REDUCE_PROD = 81, + REDUCE_MAX = 82, + PACK = 83, + LOGICAL_OR = 84, + ONE_HOT = 85, + LOGICAL_AND = 86, + LOGICAL_NOT = 87, } // Options for the builtin operators. @@ -220,6 +229,13 @@ union BuiltinOptions { NotEqualOptions, ShapeOptions, PowOptions, + ArgMinOptions, + FakeQuantOptions, + PackOptions, + LogicalOrOptions, + OneHotOptions, + LogicalAndOptions, + LogicalNotOptions, } enum Padding : byte { SAME, VALID } @@ -469,6 +485,10 @@ table ArgMaxOptions { output_type : TensorType; } +table ArgMinOptions { + output_type : TensorType; +} + table GreaterOptions { } @@ -517,6 +537,34 @@ table ShapeOptions { table PowOptions { } +table FakeQuantOptions { + // Parameters supported by version 1: + min:float; + max:float; + num_bits:int; + + // Parameters supported by version 2: + narrow_range:bool; +} + +table PackOptions { + values_count:int; + axis:int; +} + +table LogicalOrOptions { +} + +table OneHotOptions { + axis:int; +} + +table LogicalAndOptions { +} + +table LogicalNotOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index fe0ff9a7a5ba0764475f4a7c14cd875b3cdb2aa8..3efa153e2cfd98dcac9352ff0ef4d8eb9bb6b66a 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -157,6 +157,9 @@ struct TileOptionsT; struct ArgMaxOptions; struct ArgMaxOptionsT; +struct ArgMinOptions; +struct ArgMinOptionsT; + struct GreaterOptions; struct GreaterOptionsT; @@ -199,6 +202,24 @@ struct ShapeOptionsT; struct PowOptions; struct PowOptionsT; +struct FakeQuantOptions; +struct FakeQuantOptionsT; + +struct PackOptions; +struct PackOptionsT; + +struct LogicalOrOptions; +struct LogicalOrOptionsT; + +struct OneHotOptions; +struct OneHotOptionsT; + +struct LogicalAndOptions; +struct LogicalAndOptionsT; + +struct LogicalNotOptions; +struct LogicalNotOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -343,11 +364,20 @@ enum BuiltinOperator { BuiltinOperator_RSQRT = 76, BuiltinOperator_SHAPE = 77, BuiltinOperator_POW = 78, + BuiltinOperator_ARG_MIN = 79, + BuiltinOperator_FAKE_QUANT = 80, + BuiltinOperator_REDUCE_PROD = 81, + BuiltinOperator_REDUCE_MAX = 82, + BuiltinOperator_PACK = 83, + BuiltinOperator_LOGICAL_OR = 84, + BuiltinOperator_ONE_HOT = 85, + BuiltinOperator_LOGICAL_AND = 86, + BuiltinOperator_LOGICAL_NOT = 87, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_POW + BuiltinOperator_MAX = BuiltinOperator_LOGICAL_NOT }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[78] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[87] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -426,7 +456,16 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[78] { BuiltinOperator_SQRT, BuiltinOperator_RSQRT, BuiltinOperator_SHAPE, - BuiltinOperator_POW + BuiltinOperator_POW, + BuiltinOperator_ARG_MIN, + BuiltinOperator_FAKE_QUANT, + BuiltinOperator_REDUCE_PROD, + BuiltinOperator_REDUCE_MAX, + BuiltinOperator_PACK, + BuiltinOperator_LOGICAL_OR, + BuiltinOperator_ONE_HOT, + BuiltinOperator_LOGICAL_AND, + BuiltinOperator_LOGICAL_NOT }; return values; } @@ -512,6 +551,15 @@ inline const char **EnumNamesBuiltinOperator() { "RSQRT", "SHAPE", "POW", + "ARG_MIN", + "FAKE_QUANT", + "REDUCE_PROD", + "REDUCE_MAX", + "PACK", + "LOGICAL_OR", + "ONE_HOT", + "LOGICAL_AND", + "LOGICAL_NOT", nullptr }; return names; @@ -580,11 +628,18 @@ enum BuiltinOptions { BuiltinOptions_NotEqualOptions = 54, BuiltinOptions_ShapeOptions = 55, BuiltinOptions_PowOptions = 56, + BuiltinOptions_ArgMinOptions = 57, + BuiltinOptions_FakeQuantOptions = 58, + BuiltinOptions_PackOptions = 59, + BuiltinOptions_LogicalOrOptions = 60, + BuiltinOptions_OneHotOptions = 61, + BuiltinOptions_LogicalAndOptions = 62, + BuiltinOptions_LogicalNotOptions = 63, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_PowOptions + BuiltinOptions_MAX = BuiltinOptions_LogicalNotOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[57] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[64] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -642,7 +697,14 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[57] { BuiltinOptions_EqualOptions, BuiltinOptions_NotEqualOptions, BuiltinOptions_ShapeOptions, - BuiltinOptions_PowOptions + BuiltinOptions_PowOptions, + BuiltinOptions_ArgMinOptions, + BuiltinOptions_FakeQuantOptions, + BuiltinOptions_PackOptions, + BuiltinOptions_LogicalOrOptions, + BuiltinOptions_OneHotOptions, + BuiltinOptions_LogicalAndOptions, + BuiltinOptions_LogicalNotOptions }; return values; } @@ -706,6 +768,13 @@ inline const char **EnumNamesBuiltinOptions() { "NotEqualOptions", "ShapeOptions", "PowOptions", + "ArgMinOptions", + "FakeQuantOptions", + "PackOptions", + "LogicalOrOptions", + "OneHotOptions", + "LogicalAndOptions", + "LogicalNotOptions", nullptr }; return names; @@ -944,6 +1013,34 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_PowOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ArgMinOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_FakeQuantOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_PackOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LogicalOrOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_OneHotOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LogicalAndOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LogicalNotOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -1423,6 +1520,62 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_PowOptions ? reinterpret_cast(value) : nullptr; } + ArgMinOptionsT *AsArgMinOptions() { + return type == BuiltinOptions_ArgMinOptions ? + reinterpret_cast(value) : nullptr; + } + const ArgMinOptionsT *AsArgMinOptions() const { + return type == BuiltinOptions_ArgMinOptions ? + reinterpret_cast(value) : nullptr; + } + FakeQuantOptionsT *AsFakeQuantOptions() { + return type == BuiltinOptions_FakeQuantOptions ? + reinterpret_cast(value) : nullptr; + } + const FakeQuantOptionsT *AsFakeQuantOptions() const { + return type == BuiltinOptions_FakeQuantOptions ? + reinterpret_cast(value) : nullptr; + } + PackOptionsT *AsPackOptions() { + return type == BuiltinOptions_PackOptions ? + reinterpret_cast(value) : nullptr; + } + const PackOptionsT *AsPackOptions() const { + return type == BuiltinOptions_PackOptions ? + reinterpret_cast(value) : nullptr; + } + LogicalOrOptionsT *AsLogicalOrOptions() { + return type == BuiltinOptions_LogicalOrOptions ? + reinterpret_cast(value) : nullptr; + } + const LogicalOrOptionsT *AsLogicalOrOptions() const { + return type == BuiltinOptions_LogicalOrOptions ? + reinterpret_cast(value) : nullptr; + } + OneHotOptionsT *AsOneHotOptions() { + return type == BuiltinOptions_OneHotOptions ? + reinterpret_cast(value) : nullptr; + } + const OneHotOptionsT *AsOneHotOptions() const { + return type == BuiltinOptions_OneHotOptions ? + reinterpret_cast(value) : nullptr; + } + LogicalAndOptionsT *AsLogicalAndOptions() { + return type == BuiltinOptions_LogicalAndOptions ? + reinterpret_cast(value) : nullptr; + } + const LogicalAndOptionsT *AsLogicalAndOptions() const { + return type == BuiltinOptions_LogicalAndOptions ? + reinterpret_cast(value) : nullptr; + } + LogicalNotOptionsT *AsLogicalNotOptions() { + return type == BuiltinOptions_LogicalNotOptions ? + reinterpret_cast(value) : nullptr; + } + const LogicalNotOptionsT *AsLogicalNotOptions() const { + return type == BuiltinOptions_LogicalNotOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -4486,6 +4639,60 @@ inline flatbuffers::Offset CreateArgMaxOptions( flatbuffers::Offset CreateArgMaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct ArgMinOptionsT : public flatbuffers::NativeTable { + typedef ArgMinOptions TableType; + TensorType output_type; + ArgMinOptionsT() + : output_type(TensorType_FLOAT32) { + } +}; + +struct ArgMinOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ArgMinOptionsT NativeTableType; + enum { + VT_OUTPUT_TYPE = 4 + }; + TensorType output_type() const { + return static_cast(GetField(VT_OUTPUT_TYPE, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_OUTPUT_TYPE) && + verifier.EndTable(); + } + ArgMinOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ArgMinOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ArgMinOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_output_type(TensorType output_type) { + fbb_.AddElement(ArgMinOptions::VT_OUTPUT_TYPE, static_cast(output_type), 0); + } + explicit ArgMinOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ArgMinOptionsBuilder &operator=(const ArgMinOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateArgMinOptions( + flatbuffers::FlatBufferBuilder &_fbb, + TensorType output_type = TensorType_FLOAT32) { + ArgMinOptionsBuilder builder_(_fbb); + builder_.add_output_type(output_type); + return builder_.Finish(); +} + +flatbuffers::Offset CreateArgMinOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct GreaterOptionsT : public flatbuffers::NativeTable { typedef GreaterOptions TableType; GreaterOptionsT() { @@ -5112,6 +5319,336 @@ inline flatbuffers::Offset CreatePowOptions( flatbuffers::Offset CreatePowOptions(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct FakeQuantOptionsT : public flatbuffers::NativeTable { + typedef FakeQuantOptions TableType; + float min; + float max; + int32_t num_bits; + bool narrow_range; + FakeQuantOptionsT() + : min(0.0f), + max(0.0f), + num_bits(0), + narrow_range(false) { + } +}; + +struct FakeQuantOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef FakeQuantOptionsT NativeTableType; + enum { + VT_MIN = 4, + VT_MAX = 6, + VT_NUM_BITS = 8, + VT_NARROW_RANGE = 10 + }; + float min() const { + return GetField(VT_MIN, 0.0f); + } + float max() const { + return GetField(VT_MAX, 0.0f); + } + int32_t num_bits() const { + return GetField(VT_NUM_BITS, 0); + } + bool narrow_range() const { + return GetField(VT_NARROW_RANGE, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_MIN) && + VerifyField(verifier, VT_MAX) && + VerifyField(verifier, VT_NUM_BITS) && + VerifyField(verifier, VT_NARROW_RANGE) && + verifier.EndTable(); + } + FakeQuantOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(FakeQuantOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct FakeQuantOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_min(float min) { + fbb_.AddElement(FakeQuantOptions::VT_MIN, min, 0.0f); + } + void add_max(float max) { + fbb_.AddElement(FakeQuantOptions::VT_MAX, max, 0.0f); + } + void add_num_bits(int32_t num_bits) { + fbb_.AddElement(FakeQuantOptions::VT_NUM_BITS, num_bits, 0); + } + void add_narrow_range(bool narrow_range) { + fbb_.AddElement(FakeQuantOptions::VT_NARROW_RANGE, static_cast(narrow_range), 0); + } + explicit FakeQuantOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + FakeQuantOptionsBuilder &operator=(const FakeQuantOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateFakeQuantOptions( + flatbuffers::FlatBufferBuilder &_fbb, + float min = 0.0f, + float max = 0.0f, + int32_t num_bits = 0, + bool narrow_range = false) { + FakeQuantOptionsBuilder builder_(_fbb); + builder_.add_num_bits(num_bits); + builder_.add_max(max); + builder_.add_min(min); + builder_.add_narrow_range(narrow_range); + return builder_.Finish(); +} + +flatbuffers::Offset CreateFakeQuantOptions(flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct PackOptionsT : public flatbuffers::NativeTable { + typedef PackOptions TableType; + int32_t values_count; + int32_t axis; + PackOptionsT() + : values_count(0), + axis(0) { + } +}; + +struct PackOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef PackOptionsT NativeTableType; + enum { + VT_VALUES_COUNT = 4, + VT_AXIS = 6 + }; + int32_t values_count() const { + return GetField(VT_VALUES_COUNT, 0); + } + int32_t axis() const { + return GetField(VT_AXIS, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_VALUES_COUNT) && + VerifyField(verifier, VT_AXIS) && + verifier.EndTable(); + } + PackOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(PackOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const PackOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct PackOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_values_count(int32_t values_count) { + fbb_.AddElement(PackOptions::VT_VALUES_COUNT, values_count, 0); + } + void add_axis(int32_t axis) { + fbb_.AddElement(PackOptions::VT_AXIS, axis, 0); + } + explicit PackOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + PackOptionsBuilder &operator=(const PackOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreatePackOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t values_count = 0, + int32_t axis = 0) { + PackOptionsBuilder builder_(_fbb); + builder_.add_axis(axis); + builder_.add_values_count(values_count); + return builder_.Finish(); +} + +flatbuffers::Offset CreatePackOptions(flatbuffers::FlatBufferBuilder &_fbb, const PackOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LogicalOrOptionsT : public flatbuffers::NativeTable { + typedef LogicalOrOptions TableType; + LogicalOrOptionsT() { + } +}; + +struct LogicalOrOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef LogicalOrOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + LogicalOrOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LogicalOrOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogicalOrOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LogicalOrOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit LogicalOrOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + LogicalOrOptionsBuilder &operator=(const LogicalOrOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateLogicalOrOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + LogicalOrOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateLogicalOrOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalOrOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct OneHotOptionsT : public flatbuffers::NativeTable { + typedef OneHotOptions TableType; + int32_t axis; + OneHotOptionsT() + : axis(0) { + } +}; + +struct OneHotOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef OneHotOptionsT NativeTableType; + enum { + VT_AXIS = 4 + }; + int32_t axis() const { + return GetField(VT_AXIS, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_AXIS) && + verifier.EndTable(); + } + OneHotOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(OneHotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct OneHotOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_axis(int32_t axis) { + fbb_.AddElement(OneHotOptions::VT_AXIS, axis, 0); + } + explicit OneHotOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + OneHotOptionsBuilder &operator=(const OneHotOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateOneHotOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t axis = 0) { + OneHotOptionsBuilder builder_(_fbb); + builder_.add_axis(axis); + return builder_.Finish(); +} + +flatbuffers::Offset CreateOneHotOptions(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LogicalAndOptionsT : public flatbuffers::NativeTable { + typedef LogicalAndOptions TableType; + LogicalAndOptionsT() { + } +}; + +struct LogicalAndOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef LogicalAndOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + LogicalAndOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LogicalAndOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogicalAndOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LogicalAndOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit LogicalAndOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + LogicalAndOptionsBuilder &operator=(const LogicalAndOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateLogicalAndOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + LogicalAndOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateLogicalAndOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalAndOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LogicalNotOptionsT : public flatbuffers::NativeTable { + typedef LogicalNotOptions TableType; + LogicalNotOptionsT() { + } +}; + +struct LogicalNotOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef LogicalNotOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + LogicalNotOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LogicalNotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LogicalNotOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit LogicalNotOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + LogicalNotOptionsBuilder &operator=(const LogicalNotOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateLogicalNotOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + LogicalNotOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateLogicalNotOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -5413,6 +5950,27 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const PowOptions *builtin_options_as_PowOptions() const { return builtin_options_type() == BuiltinOptions_PowOptions ? static_cast(builtin_options()) : nullptr; } + const ArgMinOptions *builtin_options_as_ArgMinOptions() const { + return builtin_options_type() == BuiltinOptions_ArgMinOptions ? static_cast(builtin_options()) : nullptr; + } + const FakeQuantOptions *builtin_options_as_FakeQuantOptions() const { + return builtin_options_type() == BuiltinOptions_FakeQuantOptions ? static_cast(builtin_options()) : nullptr; + } + const PackOptions *builtin_options_as_PackOptions() const { + return builtin_options_type() == BuiltinOptions_PackOptions ? static_cast(builtin_options()) : nullptr; + } + const LogicalOrOptions *builtin_options_as_LogicalOrOptions() const { + return builtin_options_type() == BuiltinOptions_LogicalOrOptions ? static_cast(builtin_options()) : nullptr; + } + const OneHotOptions *builtin_options_as_OneHotOptions() const { + return builtin_options_type() == BuiltinOptions_OneHotOptions ? static_cast(builtin_options()) : nullptr; + } + const LogicalAndOptions *builtin_options_as_LogicalAndOptions() const { + return builtin_options_type() == BuiltinOptions_LogicalAndOptions ? static_cast(builtin_options()) : nullptr; + } + const LogicalNotOptions *builtin_options_as_LogicalNotOptions() const { + return builtin_options_type() == BuiltinOptions_LogicalNotOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -5668,6 +6226,34 @@ template<> inline const PowOptions *Operator::builtin_options_as() c return builtin_options_as_PowOptions(); } +template<> inline const ArgMinOptions *Operator::builtin_options_as() const { + return builtin_options_as_ArgMinOptions(); +} + +template<> inline const FakeQuantOptions *Operator::builtin_options_as() const { + return builtin_options_as_FakeQuantOptions(); +} + +template<> inline const PackOptions *Operator::builtin_options_as() const { + return builtin_options_as_PackOptions(); +} + +template<> inline const LogicalOrOptions *Operator::builtin_options_as() const { + return builtin_options_as_LogicalOrOptions(); +} + +template<> inline const OneHotOptions *Operator::builtin_options_as() const { + return builtin_options_as_OneHotOptions(); +} + +template<> inline const LogicalAndOptions *Operator::builtin_options_as() const { + return builtin_options_as_LogicalAndOptions(); +} + +template<> inline const LogicalNotOptions *Operator::builtin_options_as() const { + return builtin_options_as_LogicalNotOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -7333,6 +7919,32 @@ inline flatbuffers::Offset CreateArgMaxOptions(flatbuffers::FlatB _output_type); } +inline ArgMinOptionsT *ArgMinOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ArgMinOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ArgMinOptions::UnPackTo(ArgMinOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = output_type(); _o->output_type = _e; }; +} + +inline flatbuffers::Offset ArgMinOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateArgMinOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateArgMinOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ArgMinOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _output_type = _o->output_type; + return tflite::CreateArgMinOptions( + _fbb, + _output_type); +} + inline GreaterOptionsT *GreaterOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new GreaterOptionsT(); UnPackTo(_o, _resolver); @@ -7670,6 +8282,165 @@ inline flatbuffers::Offset CreatePowOptions(flatbuffers::FlatBufferB _fbb); } +inline FakeQuantOptionsT *FakeQuantOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new FakeQuantOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void FakeQuantOptions::UnPackTo(FakeQuantOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = min(); _o->min = _e; }; + { auto _e = max(); _o->max = _e; }; + { auto _e = num_bits(); _o->num_bits = _e; }; + { auto _e = narrow_range(); _o->narrow_range = _e; }; +} + +inline flatbuffers::Offset FakeQuantOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateFakeQuantOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateFakeQuantOptions(flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FakeQuantOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _min = _o->min; + auto _max = _o->max; + auto _num_bits = _o->num_bits; + auto _narrow_range = _o->narrow_range; + return tflite::CreateFakeQuantOptions( + _fbb, + _min, + _max, + _num_bits, + _narrow_range); +} + +inline PackOptionsT *PackOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new PackOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void PackOptions::UnPackTo(PackOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = values_count(); _o->values_count = _e; }; + { auto _e = axis(); _o->axis = _e; }; +} + +inline flatbuffers::Offset PackOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PackOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreatePackOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreatePackOptions(flatbuffers::FlatBufferBuilder &_fbb, const PackOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PackOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _values_count = _o->values_count; + auto _axis = _o->axis; + return tflite::CreatePackOptions( + _fbb, + _values_count, + _axis); +} + +inline LogicalOrOptionsT *LogicalOrOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new LogicalOrOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void LogicalOrOptions::UnPackTo(LogicalOrOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset LogicalOrOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogicalOrOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateLogicalOrOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateLogicalOrOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalOrOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LogicalOrOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateLogicalOrOptions( + _fbb); +} + +inline OneHotOptionsT *OneHotOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new OneHotOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void OneHotOptions::UnPackTo(OneHotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = axis(); _o->axis = _e; }; +} + +inline flatbuffers::Offset OneHotOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateOneHotOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateOneHotOptions(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OneHotOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _axis = _o->axis; + return tflite::CreateOneHotOptions( + _fbb, + _axis); +} + +inline LogicalAndOptionsT *LogicalAndOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new LogicalAndOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void LogicalAndOptions::UnPackTo(LogicalAndOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset LogicalAndOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogicalAndOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateLogicalAndOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateLogicalAndOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalAndOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LogicalAndOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateLogicalAndOptions( + _fbb); +} + +inline LogicalNotOptionsT *LogicalNotOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new LogicalNotOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void LogicalNotOptions::UnPackTo(LogicalNotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset LogicalNotOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateLogicalNotOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateLogicalNotOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LogicalNotOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateLogicalNotOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -8083,6 +8854,34 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_ArgMinOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_FakeQuantOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_PackOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LogicalOrOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_OneHotOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LogicalAndOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LogicalNotOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -8325,6 +9124,34 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_ArgMinOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_FakeQuantOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_PackOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LogicalOrOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_OneHotOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LogicalAndOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LogicalNotOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -8555,6 +9382,34 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreatePowOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_ArgMinOptions: { + auto ptr = reinterpret_cast(value); + return CreateArgMinOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_FakeQuantOptions: { + auto ptr = reinterpret_cast(value); + return CreateFakeQuantOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_PackOptions: { + auto ptr = reinterpret_cast(value); + return CreatePackOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LogicalOrOptions: { + auto ptr = reinterpret_cast(value); + return CreateLogicalOrOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_OneHotOptions: { + auto ptr = reinterpret_cast(value); + return CreateOneHotOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LogicalAndOptions: { + auto ptr = reinterpret_cast(value); + return CreateLogicalAndOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LogicalNotOptions: { + auto ptr = reinterpret_cast(value); + return CreateLogicalNotOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -8785,6 +9640,34 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new PowOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_ArgMinOptions: { + value = new ArgMinOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_FakeQuantOptions: { + value = new FakeQuantOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_PackOptions: { + value = new PackOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LogicalOrOptions: { + value = new LogicalOrOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_OneHotOptions: { + value = new OneHotOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LogicalAndOptions: { + value = new LogicalAndOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LogicalNotOptions: { + value = new LogicalNotOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -9072,6 +9955,41 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_ArgMinOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_FakeQuantOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_PackOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LogicalOrOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_OneHotOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LogicalAndOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LogicalNotOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/contrib/lite/simple_memory_arena.cc b/tensorflow/contrib/lite/simple_memory_arena.cc index 4eaf6f1bfe76efc1e6737d03d58be9bc87bb849d..cd0f1f7c17a50f6ce61fa2033e5d13580399f5cf 100644 --- a/tensorflow/contrib/lite/simple_memory_arena.cc +++ b/tensorflow/contrib/lite/simple_memory_arena.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/lite/simple_memory_arena.h" +#include #include #include #include @@ -34,7 +35,7 @@ namespace tflite { TfLiteStatus SimpleMemoryArena::Allocate(TfLiteContext* context, size_t alignment, size_t size, ArenaAlloc* new_alloc) { - TF_LITE_ENSURE(context, alignment < arena_alignment_); + TF_LITE_ENSURE(context, alignment <= arena_alignment_); if (size == 0) { new_alloc->offset = 0; diff --git a/tensorflow/contrib/lite/testdata/add.bin b/tensorflow/contrib/lite/testdata/add.bin new file mode 100644 index 0000000000000000000000000000000000000000..aef0fe3d82c9d92dc444076d3b46e05af1923f46 Binary files /dev/null and b/tensorflow/contrib/lite/testdata/add.bin differ diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index b823c97f38e7660652aa0ce3538b11de59dc9aea..a788d41ba7b370cd0e84c343202f1dca090180f3 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -140,6 +140,7 @@ cc_test( cc_library( name = "join", hdrs = ["join.h"], + deps = ["//tensorflow/contrib/lite:string"], ) cc_test( @@ -172,6 +173,7 @@ cc_test( data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"], tags = [ "tflite_not_portable_android", + "tflite_not_portable_ios", ], deps = [ ":tflite_driver", @@ -208,6 +210,10 @@ cc_library( cc_library( name = "util", hdrs = ["util.h"], + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string", + ], ) cc_test( @@ -251,6 +257,7 @@ cc_test( srcs = ["tf_driver_test.cc"], data = ["//tensorflow/contrib/lite:testdata/multi_add.pb"], tags = [ + "no_oss", "tflite_not_portable", ], deps = [ @@ -267,6 +274,7 @@ cc_library( ":join", ":split", ":tf_driver", + "//tensorflow/contrib/lite:string", "//tensorflow/core:framework", ], ) @@ -276,6 +284,7 @@ cc_test( size = "small", srcs = ["generate_testspec_test.cc"], tags = [ + "no_oss", "tflite_not_portable", ], deps = [ @@ -332,7 +341,7 @@ tf_cc_test( ], tags = [ "no_cuda_on_cpu_tap", - "no_oss", + "no_oss", # needs test data "tflite_not_portable", ], deps = [ diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 1360f1a27383a709accc1abbd723601854d48a12..52ef0d5b86524d605b2f5d6dbae98d4c343ad6a0 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -90,12 +90,10 @@ TEST_INPUT_DEPTH = 3 # matching the expression will be considered due to the corresponding bug. KNOWN_BUGS = { # TOCO doesn't support scalars as input. - r"relu.*input_shape=\[\]": "67587484", - r"sigmoid.*input_shape=\[\]": "67645668", # Concat doesn't work with a single input tensor r"concat.*num_tensors=1": "67378344", - # Transposition in MatMul is not supported. - r"fully_connected.*transpose_.=True": "67586970", + # Transposition in MatMul is not fully supported. + "fully_connected.*transpose_a=True": "67586970", # Softmax graphs are too complex. r"softmax.*dim=0": "67749831", # BatchToSpaceND only supports 4D tensors. @@ -228,7 +226,9 @@ _TF_TYPE_INFO = { tf.float16: (np.float16, "FLOAT"), tf.int32: (np.int32, "INT32"), tf.uint8: (np.uint8, "QUANTIZED_UINT8"), + tf.int16: (np.int16, "QUANTIZED_INT16"), tf.int64: (np.int64, "INT64"), + tf.bool: (np.bool, "BOOL"), } @@ -240,9 +240,12 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100): if dtype in (tf.float32, tf.float16): value = (max_value-min_value)*np.random.random_sample(shape)+min_value - elif dtype in (tf.int32, tf.uint8, tf.int64): + elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16): value = np.random.randint(min_value, max_value+1, shape) - return value.astype(dtype) + elif dtype == tf.bool: + value = np.random.choice([True, False], size=shape) + return np.dtype(dtype).type(value) if np.isscalar(value) else value.astype( + dtype) def create_scalar_data(dtype, min_value=-100, max_value=100): @@ -253,7 +256,7 @@ def create_scalar_data(dtype, min_value=-100, max_value=100): if dtype in (tf.float32, tf.float16): value = (max_value - min_value) * np.random.random() + min_value - elif dtype in (tf.int32, tf.uint8, tf.int64): + elif dtype in (tf.int32, tf.uint8, tf.int64, tf.int16): value = np.random.randint(min_value, max_value + 1) return np.array(value, dtype=dtype) @@ -479,7 +482,7 @@ def make_zip_of_tests(zip_path, else report_lib.FAILED) report["toco_log"] = toco_log - if FLAGS.save_graphdefs: + if True or FLAGS.save_graphdefs: archive.writestr(label + ".pbtxt", text_format.MessageToString(graph_def), zipfile.ZIP_DEFLATED) @@ -678,6 +681,63 @@ def make_relu6_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_prelu_tests(zip_path): + """Make a set of tests to do PReLU.""" + + test_parameters = [ + { + # The canonical case for image processing is having a 4D `input` + # (NHWC)and `shared_axes`=[1, 2], so the alpha parameter is per + # channel. + "input_shape": [[1, 10, 10, 3], [3, 3, 3, 3]], + "shared_axes": [[1, 2], [1]], + }, + { + # 2D-3D example. Share the 2nd axis. + "input_shape": [[20, 20], [20, 20, 20]], + "shared_axes": [[1]], + } + ] + + def build_graph(parameters): + """Build the graph for the test case.""" + + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + prelu = tf.keras.layers.PReLU(shared_axes=parameters["shared_axes"]) + out = prelu(input_tensor) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + """Build the inputs for the test case.""" + + input_shape = parameters["input_shape"] + input_values = create_tensor_data( + np.float32, input_shape, min_value=-10, max_value=10) + shared_axes = parameters["shared_axes"] + + alpha_shape = [] + for dim in range(1, len(input_shape)): + alpha_shape.append(1 if dim in shared_axes else input_shape[dim]) + + alpha_values = create_tensor_data(np.float32, alpha_shape) + + # There should be only 1 trainable variable tensor. + variables = tf.all_variables() + assert len(variables) == 1 + sess.run(variables[0].assign(alpha_values)) + + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + use_frozen_graph=True) + + # This function tests various TensorFLow functions that generates Const op, # including `tf.ones`, `tf.zeros` and random functions. def make_constant_tests(zip_path): @@ -685,21 +745,22 @@ def make_constant_tests(zip_path): test_parameters = [{ "dtype": [tf.float32, tf.int32], - "input_shape": [[1], [2], [1, 1, 1, 1], [2, 2, 2, 2]], + "input_shape": [[], [1], [2], [1, 1, 1, 1], [2, 2, 2, 2]], }] def build_graph(parameters): - # Since Toco & Tflite can't have a single constant op in the entire graph, - # this test adds a zero tensor with a constant op tensor. - input1 = tf.placeholder(dtype=parameters["dtype"], name="input1", - shape=parameters["input_shape"]) - out = tf.ones(parameters["input_shape"], dtype=parameters["dtype"]) + input1 - return [input1], [out] + dummy_input = tf.placeholder( + dtype=parameters["dtype"], + name="input1", + shape=parameters["input_shape"]) + out = tf.constant( + create_tensor_data(parameters["dtype"], parameters["input_shape"])) + return [dummy_input], [out] def build_inputs(parameters, sess, inputs, outputs): - input1 = np.zeros(parameters["input_shape"], - dtype=_TF_TYPE_INFO[parameters["dtype"]][0]) - return [input1], sess.run(outputs, feed_dict={inputs[0]: input1}) + dummy_input = np.zeros( + parameters["input_shape"], dtype=_TF_TYPE_INFO[parameters["dtype"]][0]) + return [dummy_input], sess.run(outputs, feed_dict={inputs[0]: dummy_input}) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) @@ -723,6 +784,11 @@ def make_binary_op_tests(zip_path, binary_operator): "input_shape_1": [[1, 3, 4, 3]], "input_shape_2": [[3]], "activation": [True] + }, { + "dtype": [tf.float32], + "input_shape_1": [[]], + "input_shape_2": [[]], + "activation": [False] }] def build_graph(parameters): @@ -755,11 +821,13 @@ def make_binary_op_tests(zip_path, binary_operator): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) -def make_reduce_tests(reduce_op): +def make_reduce_tests(reduce_op, min_value=-10, max_value=10): """Make a set of tests to do reduce operation. Args: reduce_op: TensorFlow reduce operation to test, i.e. `tf.reduce_mean`. + min_value: min value for created tensor data. + max_value: max value for created tensor data. Returns: a function representing the true generator with `reduce_op_in` curried. @@ -772,7 +840,7 @@ def make_reduce_tests(reduce_op): "input_dtype": [tf.float32, tf.int32, tf.int64], "input_shape": [[3, 2, 4]], "axis": [ - None, 0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2], [1, 0], [2, 0], + 0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2], [1, 0], [2, 0], [2, 1], [2, 1, 0], [2, 0, 1], -1, -2, -3, [1, -1], [0, -1], [-1, 0], [-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3] ], @@ -782,13 +850,19 @@ def make_reduce_tests(reduce_op): "input_dtype": [tf.float32], "input_shape": [[1, 8, 8, 3]], "axis": [ - None, 0, 1, 2, 3, [1, 2], [0, 3], [1, 2, 3], [0, 1, 2, 3], + 0, 1, 2, 3, [1, 2], [0, 3], [1, 2, 3], [0, 1, 2, 3], [3, 2, 1, 0], [3, 1, 0, 2], [2, 0], [3, 0], [3, 1], [1, 0], -1, -2, -3, -4, [0, -2], [2, 3, -1, 0], [3, 1, 2, -3], [3, -4], [2, 2, 2], [2, 2, 3], [-3, -3, -4], [-3, 2, 1] ], "const_axis": [True, False], "keepdims": [True, False], + }, { + "input_dtype": [tf.float32], + "input_shape": [[], [1, 8, 8, 3], [3, 2, 4]], + "axis": [None], + "const_axis": [True], + "keepdims": [True, False], }] def build_graph(parameters): @@ -806,7 +880,7 @@ def make_reduce_tests(reduce_op): if isinstance(parameters["axis"], list): shape = [len(parameters["axis"])] else: - shape = [0] # shape for None or integers. + shape = [] # shape for None or integers. axis = tf.placeholder(dtype=tf.int32, name="axis", shape=shape) input_tensors = [input_tensor, axis] @@ -816,11 +890,14 @@ def make_reduce_tests(reduce_op): def build_inputs(parameters, sess, inputs, outputs): values = [ - create_tensor_data(parameters["input_dtype"], - parameters["input_shape"])] + create_tensor_data( + parameters["input_dtype"], + parameters["input_shape"], + min_value=min_value, + max_value=max_value) + ] if not parameters["const_axis"]: - if parameters["axis"]: - values.append(np.array(parameters["axis"])) + values.append(np.array(parameters["axis"])) return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) @@ -830,22 +907,31 @@ def make_reduce_tests(reduce_op): def make_mean_tests(zip_path): """Make a set of tests to do mean.""" - return make_reduce_tests(tf.reduce_mean)(zip_path) def make_sum_tests(zip_path): """Make a set of tests to do sum.""" - return make_reduce_tests(tf.reduce_sum)(zip_path) +def make_reduce_prod_tests(zip_path): + """Make a set of tests to do prod.""" + # set min max value to be -2, 2 to avoid overflow. + return make_reduce_tests(tf.reduce_prod, -2, 2)(zip_path) + + +def make_reduce_max_tests(zip_path): + """Make a set of tests to do max.""" + return make_reduce_tests(tf.reduce_max)(zip_path) + + def make_exp_tests(zip_path): """Make a set of tests to do exp.""" test_parameters = [{ "input_dtype": [tf.float32], - "input_shape": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + "input_shape": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], }] def build_graph(parameters): @@ -904,8 +990,8 @@ def make_maximum_tests(zip_path): test_parameters = [{ "input_dtype": [tf.float32], - "input_shape_1": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], - "input_shape_2": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + "input_shape_1": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + "input_shape_2": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], }] def build_graph(parameters): @@ -939,8 +1025,8 @@ def make_minimum_tests(zip_path): test_parameters = [{ "input_dtype": [tf.float32], - "input_shape_1": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], - "input_shape_2": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + "input_shape_1": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + "input_shape_2": [[], [3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], }] def build_graph(parameters): @@ -1271,6 +1357,7 @@ def make_concat_tests(zip_path): "base_shape": [[1, 3, 4, 3], [3, 4]], "num_tensors": [1, 2, 3, 4, 5, 6], "axis": [0, 1, 2, 3, -3, -2, -1], + "type": [tf.float32, tf.uint8, tf.int32, tf.int64], }] def get_shape(parameters, delta): @@ -1286,7 +1373,8 @@ def make_concat_tests(zip_path): def build_graph(parameters): all_tensors = [] for n in range(0, parameters["num_tensors"]): - input_tensor = tf.placeholder(dtype=tf.float32, name=("input%d" % n), + input_tensor = tf.placeholder(dtype=parameters["type"], + name=("input%d" % n), shape=get_shape(parameters, n)) all_tensors.append(input_tensor) out = tf.concat(all_tensors, parameters["axis"]) @@ -1295,8 +1383,8 @@ def make_concat_tests(zip_path): def build_inputs(parameters, sess, inputs, outputs): all_values = [] for n in range(0, parameters["num_tensors"]): - input_values = create_tensor_data(np.float32, - get_shape(parameters, n)) + input_values = create_tensor_data( + parameters["type"], get_shape(parameters, n)) all_values.append(input_values) return all_values, sess.run( outputs, feed_dict=dict(zip(inputs, all_values))) @@ -1325,6 +1413,12 @@ def make_fully_connected_tests(zip_path): "transpose_a": [False], "transpose_b": [False], "constant_filter": [True, False], + }, { + "shape1": [[40, 37]], + "shape2": [[40, 37]], + "transpose_a": [False], + "transpose_b": [True], + "constant_filter": [True, False], }] def build_graph(parameters): @@ -1532,19 +1626,39 @@ def make_reshape_tests(zip_path): "dtype": [tf.float32, tf.int32], "input_shape": [[3, 4, 5, 7], [4, 105], [21, 5, 2, 2], [420]], "output_shape": [[15, 28], [420], [1, -1, 5, 7], [-1]], + "constant_shape": [True, False], + }, { + "dtype": [tf.float32], + "input_shape": [[1]], + "output_shape": [[]], + "constant_shape": [True, False], }] def build_graph(parameters): input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", shape=parameters["input_shape"]) - out = tf.reshape(input_tensor, shape=parameters["output_shape"]) - return [input_tensor], [out] + + # Get shape as either a placeholder or constants. + if parameters["constant_shape"]: + output_shape = parameters["output_shape"] + input_tensors = [input_tensor] + else: + # The shape of the shape tensor. + shape_tensor_shape = [len(parameters["output_shape"])] + output_shape = tf.placeholder( + dtype=tf.int32, name="output_shape", shape=shape_tensor_shape) + input_tensors = [input_tensor, output_shape] + out = tf.reshape(input_tensor, shape=output_shape) + return input_tensors, [out] def build_inputs(parameters, sess, inputs, outputs): - input_values = create_tensor_data(parameters["dtype"], - parameters["input_shape"]) - return [input_values], sess.run( - outputs, feed_dict=dict(zip(inputs, [input_values]))) + values = [ + create_tensor_data(parameters["dtype"], parameters["input_shape"]) + ] + if not parameters["constant_shape"]: + values.append(np.array(parameters["output_shape"])) + + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) @@ -1559,7 +1673,7 @@ def make_shape_tests(zip_path): }] def build_graph(parameters): - """Build the topk op testing graph.""" + """Build the shape op testing graph.""" # Note that we intentionally leave out the shape from the input placeholder # to prevent the Shape operation from being optimized out during conversion. input_value = tf.placeholder(dtype=parameters["input_dtype"], name="input") @@ -1575,6 +1689,65 @@ def make_shape_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_one_hot_tests(zip_path): + """Make a set of tests to do one_hot.""" + + test_parameters = [{ + "indices_type": [tf.int32, tf.int64], + "indices_shape": [[3], [4, 4], [1, 5], [5, 1]], + "axis": [0, 1], + "dtype": [tf.int32, tf.int64, tf.float32], + "provide_optional_inputs": [True, False], + }] + + def build_graph(parameters): + indices = tf.placeholder( + dtype=parameters["indices_type"], + name="indices", + shape=parameters["indices_shape"]) + depth = tf.placeholder(dtype=tf.int32, name="depth", shape=()) + + if not parameters["provide_optional_inputs"]: + out = tf.one_hot(indices=indices, depth=depth) + return [indices, depth], [out] + + on_value = tf.placeholder( + dtype=parameters["dtype"], name="on_value", shape=()) + off_value = tf.placeholder( + dtype=parameters["dtype"], name="off_value", shape=()) + out = tf.one_hot( + indices=indices, + depth=depth, + on_value=on_value, + off_value=off_value, + axis=parameters["axis"], + dtype=parameters["dtype"]) + return [indices, depth, on_value, off_value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = [ + create_tensor_data( + parameters["indices_type"], + shape=parameters["indices_shape"], + min_value=-1, + max_value=10), + create_tensor_data(tf.int32, shape=None, min_value=1, max_value=10), + ] + + if parameters["provide_optional_inputs"]: + input_values.append( + create_tensor_data( + parameters["dtype"], shape=None, min_value=1, max_value=10)) + input_values.append( + create_tensor_data( + parameters["dtype"], shape=None, min_value=-1, max_value=0)) + + return input_values, sess.run( + outputs, feed_dict=dict(zip(inputs, input_values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_resize_bilinear_tests(zip_path): """Make a set of tests to do resize_bilinear.""" @@ -2148,6 +2321,7 @@ def make_topk_tests(zip_path): test_parameters = [{ "input_dtype": [tf.float32, tf.int32], "input_shape": [[10], [5, 20]], + "input_k": [None, 1, 3], }] def build_graph(parameters): @@ -2156,27 +2330,36 @@ def make_topk_tests(zip_path): dtype=parameters["input_dtype"], name="input", shape=parameters["input_shape"]) - k = tf.constant(3, name="k") + if parameters["input_k"] is not None: + k = tf.placeholder(dtype=tf.int32, name="input_k", shape=[]) + else: + k = tf.constant(3, name="k") out = tf.nn.top_k(input_value, k) - return [input_value], [out[1]] + return [input_value, k], [out[1]] def build_inputs(parameters, sess, inputs, outputs): input_value = create_tensor_data(parameters["input_dtype"], parameters["input_shape"]) - return [input_value], sess.run( - outputs, feed_dict=dict(zip(inputs, [input_value]))) + if parameters["input_k"] is not None: + k = np.array(parameters["input_k"], dtype=np.int32) + return [input_value, k], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value, k]))) + else: + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) -def make_arg_max_tests(zip_path): +def make_arg_min_max_tests(zip_path): """Make a set of tests to do arg_max.""" test_parameters = [{ "input_dtype": [tf.float32, tf.int32], - "input_shape": [[1, 1, 1, 3], [2, 3, 4, 5], [2, 3, 3], [5, 5], [10]], + "input_shape": [[], [1, 1, 1, 3], [2, 3, 4, 5], [2, 3, 3], [5, 5], [10]], "output_type": [tf.int32, tf.int64], "axis_is_last_dim": [True, False], + "is_arg_max": [True], }] def build_graph(parameters): @@ -2189,7 +2372,10 @@ def make_arg_max_tests(zip_path): axis = len(parameters["input_shape"]) - 1 else: axis = random.randint(0, max(len(parameters["input_shape"]) - 2, 0)) - out = tf.arg_max(input_value, axis, output_type=parameters["output_type"]) + if parameters["is_arg_max"]: + out = tf.arg_max(input_value, axis, output_type=parameters["output_type"]) + else: + out = tf.arg_min(input_value, axis, output_type=parameters["output_type"]) return [input_value], [out] def build_inputs(parameters, sess, inputs, outputs): @@ -2206,7 +2392,8 @@ def make_equal_tests(zip_path): test_parameters = [{ "input_dtype": [tf.float32, tf.int32, tf.int64], - "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]), + "input_shape_pair": [([], []), + ([1, 1, 1, 3], [1, 1, 1, 3]), ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), ([5, 5], [1]), ([10], [2, 4, 10])], }] @@ -2463,7 +2650,7 @@ def _make_elementwise_tests(op): """Actual function that generates examples.""" test_parameters = [{ "input_dtype": [tf.float32], - "input_shape": [[1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]], + "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]], }] def build_graph(parameters): @@ -2785,6 +2972,95 @@ def make_sparse_to_dense_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_pack_tests(zip_path): + """Make a set of tests to do stack.""" + + test_parameters = [{ + "base_shape": [[3, 4, 3], [3, 4], [5]], + "num_tensors": [1, 2, 3, 4, 5, 6], + "axis": [0, 1, 2, 3], + "additional_shape": [1, 2, 3], + }] + + def get_shape(parameters): + """Return a tweaked version of 'base_shape'.""" + axis = parameters["axis"] + shape = parameters["base_shape"][:] + if axis < len(shape): + shape[axis] += parameters["additional_shape"] + return shape + + def build_graph(parameters): + all_tensors = [] + for n in range(0, parameters["num_tensors"]): + input_tensor = tf.placeholder( + dtype=tf.float32, name=("input%d" % n), shape=get_shape(parameters)) + all_tensors.append(input_tensor) + out = tf.stack(all_tensors, parameters["axis"]) + return all_tensors, [out] + + def build_inputs(parameters, sess, inputs, outputs): + all_values = [] + for _ in range(0, parameters["num_tensors"]): + input_values = create_tensor_data(np.float32, get_shape(parameters)) + all_values.append(input_values) + return all_values, sess.run( + outputs, feed_dict=dict(zip(inputs, all_values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def _make_logical_tests(op): + """Make a set of tests to do logical operations.""" + + def logical(zip_path): + """Generate examples.""" + test_parameters = [{ + "input_shape_pair": [([], []), ([1, 1, 1, 3], [1, 1, 1, 3]), + ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), + ([5, 5], [1]), ([10], [2, 4, 10])], + }] + + def build_graph(parameters): + """Build the logical testing graph.""" + input_value1 = tf.placeholder( + dtype=tf.bool, name="input1", shape=parameters["input_shape_pair"][0]) + input_value2 = tf.placeholder( + dtype=tf.bool, name="input2", shape=parameters["input_shape_pair"][1]) + out = op(input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data(tf.bool, + parameters["input_shape_pair"][0]) + input_value2 = create_tensor_data(tf.bool, + parameters["input_shape_pair"][1]) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + return logical + + +def make_logical_or_tests(zip_path): + """Make a set of tests to do logical_or.""" + return _make_logical_tests(tf.logical_or)(zip_path) + + +def make_logical_and_tests(zip_path): + """Make a set of tests to do logical_and.""" + return _make_logical_tests(tf.logical_and)(zip_path) + + +def make_logical_xor_tests(zip_path): + """Make a set of tests to do logical_xor. + + Test logical_not as well. + """ + return _make_logical_tests(tf.logical_xor)(zip_path) + + # Toco binary path provided by the generate rule. bin_path = None diff --git a/tensorflow/contrib/lite/testing/generate_testspec.cc b/tensorflow/contrib/lite/testing/generate_testspec.cc index c0c861ff6da2fc144b9303dfdd48f19794cebeca..f29c188e6c2c55bdb13d257c70e23c2943abfa4a 100644 --- a/tensorflow/contrib/lite/testing/generate_testspec.cc +++ b/tensorflow/contrib/lite/testing/generate_testspec.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/contrib/lite/testing/generate_testspec.h" #include "tensorflow/contrib/lite/testing/join.h" #include "tensorflow/contrib/lite/testing/split.h" @@ -25,7 +27,7 @@ namespace testing { template void GenerateCsv(const std::vector& shape, float min, float max, string* out) { - auto random_float = [](int min, int max) { + auto random_float = [](float min, float max) { static unsigned int seed; return min + (max - min) * static_cast(rand_r(&seed)) / RAND_MAX; }; @@ -37,16 +39,10 @@ void GenerateCsv(const std::vector& shape, float min, float max, *out = Join(data.data(), data.size(), ","); } -bool GenerateTestSpecFromTensorflowModel( - std::iostream& stream, const string& tensorflow_model_path, - const string& tflite_model_path, const std::vector& input_layer, +std::vector GenerateInputValues( + const std::vector& input_layer, const std::vector& input_layer_type, - const std::vector& input_layer_shape, - const std::vector& output_layer) { - CHECK_EQ(input_layer.size(), input_layer_type.size()); - CHECK_EQ(input_layer.size(), input_layer_shape.size()); - - // Generate inputs. + const std::vector& input_layer_shape) { std::vector input_values; input_values.resize(input_layer.size()); for (int i = 0; i < input_layer.size(); i++) { @@ -73,57 +69,82 @@ bool GenerateTestSpecFromTensorflowModel( default: fprintf(stderr, "Unsupported type %d (%s) when generating testspec.\n", type, input_layer_type[i].c_str()); - return false; + input_values.clear(); + return input_values; } } + return input_values; +} + +bool GenerateTestSpecFromTensorflowModel( + std::iostream& stream, const string& tensorflow_model_path, + const string& tflite_model_path, int num_invocations, + const std::vector& input_layer, + const std::vector& input_layer_type, + const std::vector& input_layer_shape, + const std::vector& output_layer) { + CHECK_EQ(input_layer.size(), input_layer_type.size()); + CHECK_EQ(input_layer.size(), input_layer_shape.size()); // Invoke tensorflow model. TfDriver runner(input_layer, input_layer_type, input_layer_shape, output_layer); if (!runner.IsValid()) { - cerr << runner.GetErrorMessage() << endl; + std::cerr << runner.GetErrorMessage() << std::endl; return false; } runner.LoadModel(tensorflow_model_path); if (!runner.IsValid()) { - cerr << runner.GetErrorMessage() << endl; + std::cerr << runner.GetErrorMessage() << std::endl; return false; } - for (int i = 0; i < input_values.size(); i++) { - runner.SetInput(i, input_values[i]); - if (!runner.IsValid()) { - cerr << runner.GetErrorMessage() << endl; - return false; - } - } - - runner.Invoke(); - if (!runner.IsValid()) { - cerr << runner.GetErrorMessage() << endl; - return false; - } - - // Write test spec. + // Write first part of test spec, defining model and input shapes. stream << "load_model: " << tflite_model_path << "\n"; stream << "reshape {\n"; for (const auto& shape : input_layer_shape) { stream << " input: \"" << shape << "\"\n"; } stream << "}\n"; - stream << "invoke {\n"; - for (const auto& value : input_values) { - stream << " input: \"" << value << "\"\n"; - } - for (int i = 0; i < output_layer.size(); i++) { - stream << " output: \"" << runner.ReadOutput(i) << "\"\n"; + + // Generate inputs. + for (int i = 0; i < num_invocations; ++i) { + // Note that the input values are random, so each invocation will have a + // different set. + std::vector input_values = + GenerateInputValues(input_layer, input_layer_type, input_layer_shape); + if (input_values.empty()) return false; + + // Run TensorFlow. + for (int j = 0; j < input_values.size(); j++) { + runner.SetInput(j, input_values[j]); + if (!runner.IsValid()) { + std::cerr << runner.GetErrorMessage() << std::endl; + return false; + } + } + + runner.Invoke(); if (!runner.IsValid()) { - cerr << runner.GetErrorMessage() << endl; + std::cerr << runner.GetErrorMessage() << std::endl; return false; } + + // Write second part of test spec, with inputs and outputs. + stream << "invoke {\n"; + for (const auto& value : input_values) { + stream << " input: \"" << value << "\"\n"; + } + for (int j = 0; j < output_layer.size(); j++) { + stream << " output: \"" << runner.ReadOutput(j) << "\"\n"; + if (!runner.IsValid()) { + std::cerr << runner.GetErrorMessage() << std::endl; + return false; + } + } + stream << "}\n"; } - stream << "}\n"; return true; } diff --git a/tensorflow/contrib/lite/testing/generate_testspec.h b/tensorflow/contrib/lite/testing/generate_testspec.h index 6e31a853c3f7f82a89126ff83af784ffd418741a..b3d0db31c01a8cb1b8f34ff6dbb00c77de29b131 100644 --- a/tensorflow/contrib/lite/testing/generate_testspec.h +++ b/tensorflow/contrib/lite/testing/generate_testspec.h @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "tensorflow/contrib/lite/string.h" + namespace tflite { namespace testing { @@ -30,13 +32,15 @@ namespace testing { // stream: mutable iostream that contains the contents of test spec. // tensorflow_model_path: path to TensorFlow model. // tflite_model_path: path to tflite_model_path that the test spec runs +// num_invocations: how many pairs of inputs and outputs will be generated. // against. input_layer: names of input tensors. Example: input1 // input_layer_type: datatypes of input tensors. Example: float // input_layer_shape: shapes of input tensors, separated by comma. example: // 1,3,4 output_layer: names of output tensors. Example: output bool GenerateTestSpecFromTensorflowModel( std::iostream& stream, const string& tensorflow_model_path, - const string& tflite_model_path, const std::vector& input_layer, + const string& tflite_model_path, int num_invocations, + const std::vector& input_layer, const std::vector& input_layer_type, const std::vector& input_layer_shape, const std::vector& output_layer); diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index a86cd5c6ccfc980e8b3a83526714fa11dcc3a4a9..e475f256c01c95755aabb0550153ff3c225aeb3b 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -42,6 +42,7 @@ string* FLAGS_unzip_binary_path = new string("/usr/bin/unzip"); string* FLAGS_unzip_binary_path = new string("/system/bin/unzip"); #endif bool FLAGS_use_nnapi = false; +bool FLAGS_ignore_unsupported_nnapi = false; } // namespace // TensorFlow system environment for file system called. @@ -52,10 +53,6 @@ tensorflow::Env* env = tensorflow::Env::Default(); // Key is a substring of the test name and value is a bug number. // TODO(ahentz): make sure we clean this list up frequently. std::map kBrokenTests = { - {R"(^\/mul.*int32)", "68808744"}, - {R"(^\/div.*int32)", "68808744"}, - {R"(^\/sub.*int32)", "68808744"}, - // Pad and PadV2 only supports 4D tensors. {R"(^\/pad.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])", "70527055"}, @@ -89,18 +86,16 @@ std::map kBrokenTests = { // Transpose only supports 1D-4D input tensors. {R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"}, - // PRelu only supports 4D input with (1, 1, channels) 3D alpha now. - {R"(^\/prelu.*shared_axes=\[1\])", "75975192"}, - // No support for axis!=0 in GatherV2. {R"(^\/gather.*axis=1)", "76910444"}, // No support for arbitrary dimensions in ArgMax. - {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.,.,.\])", + {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.,.,.\])", + "77546240"}, + {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.,.\])", "77546240"}, - {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.,.\])", + {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.\])", "77546240"}, - {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.\])", "77546240"}, }; // Allows test data to be unzipped into a temporary directory and makes @@ -225,16 +220,22 @@ TEST_P(OpsTest, RunZipTests) { } bool result = tflite::testing::ParseAndRunTests(&tflite_stream, &test_driver); + string message = test_driver.GetErrorMessage(); if (bug_number.empty()) { - EXPECT_TRUE(result) << test_driver.GetErrorMessage(); + if (FLAGS_use_nnapi && FLAGS_ignore_unsupported_nnapi && !result) { + EXPECT_EQ(message, string("Failed to invoke NNAPI interpreter")) + << message; + } else { + EXPECT_TRUE(result) << message; + } } else { if (FLAGS_ignore_known_bugs) { EXPECT_FALSE(result) << "Test was expected to fail but is now passing; " "you can mark http://b/" << bug_number << " as fixed! Yay!"; } else { - EXPECT_TRUE(result) << test_driver.GetErrorMessage() - << ": Possibly due to http://b/" << bug_number; + EXPECT_TRUE(result) << message << ": Possibly due to http://b/" + << bug_number; } } } @@ -277,8 +278,11 @@ int main(int argc, char** argv) { tflite::testing::FLAGS_unzip_binary_path, "Required: Location of a suitable unzip binary."), tensorflow::Flag("use_nnapi", &tflite::testing::FLAGS_use_nnapi, - "Whether to enable the NNAPI delegate")}; - + "Whether to enable the NNAPI delegate"), + tensorflow::Flag("ignore_unsupported_nnapi", + &tflite::testing::FLAGS_ignore_unsupported_nnapi, + "Don't fail tests just because delegation to NNAPI " + "is not possible")}; bool success = tensorflow::Flags::Parse(&argc, argv, flags); if (!success || (argc == 2 && !strcmp(argv[1], "--helpfull"))) { fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str()); diff --git a/tensorflow/contrib/lite/testing/join.h b/tensorflow/contrib/lite/testing/join.h index 1edee01cf97da3c53be1895e667b005551ac2991..4be19ad7569c3333b6647b91adbc6e77ff088f10 100644 --- a/tensorflow/contrib/lite/testing/join.h +++ b/tensorflow/contrib/lite/testing/join.h @@ -17,7 +17,8 @@ limitations under the License. #include #include -#include + +#include "tensorflow/contrib/lite/string.h" namespace tflite { namespace testing { diff --git a/tensorflow/contrib/lite/testing/test_runner.h b/tensorflow/contrib/lite/testing/test_runner.h index 96ab6be54e528334f9e4a8cc259e44f99878fefb..fac7d01aab4b1e4c251213041eb4b823cd7d66aa 100644 --- a/tensorflow/contrib/lite/testing/test_runner.h +++ b/tensorflow/contrib/lite/testing/test_runner.h @@ -90,7 +90,7 @@ class TestRunner { // Invalidate the test runner, preventing it from executing any further. void Invalidate(const string& error_message) { - cerr << error_message << std::endl; + std::cerr << error_message << std::endl; error_message_ = error_message; } bool IsValid() const { return error_message_.empty(); } diff --git a/tensorflow/contrib/lite/testing/tf_driver.cc b/tensorflow/contrib/lite/testing/tf_driver.cc index 3b27f6f3da92ce80c3830feb7c6af095e7c48e9c..ec435ca60d959a11a9392b6fbab99b0561f50942 100644 --- a/tensorflow/contrib/lite/testing/tf_driver.cc +++ b/tensorflow/contrib/lite/testing/tf_driver.cc @@ -28,8 +28,8 @@ namespace { tensorflow::Tensor CreateTensor(const tensorflow::DataType type, const std::vector& dim) { - tensorflow::TensorShape shape{gtl::ArraySlice{ - reinterpret_cast(dim.data()), dim.size()}}; + tensorflow::TensorShape shape{tensorflow::gtl::ArraySlice{ + reinterpret_cast(dim.data()), dim.size()}}; return {type, shape}; } @@ -179,7 +179,7 @@ void TfDriver::Invoke() { auto status = session_->Run({input_tensors_.begin(), input_tensors_.end()}, output_names_, {}, &output_tensors_); if (!status.ok()) { - Invalidate("Failed to invoke interpreter"); + Invalidate("Failed to run input data on graph"); } } diff --git a/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc b/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc index 5afa0f800cdaa8bf70a11cb6e2ac64ace8138e79..f2c49fe389763110279b3dd1e4f13b1522de0460 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc +++ b/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc @@ -20,12 +20,29 @@ int main(int argc, char** argv) { ::tflite::testing::DiffOptions options = ::tflite::testing::ParseTfliteDiffFlags(&argc, argv); if (options.tensorflow_model.empty()) return 1; + int failure_count = 0; - for (int i = 0; i < 100; i++) { - if (!tflite::testing::RunDiffTest(options)) { + for (int i = 0; i < options.num_runs_per_pass; i++) { + if (!tflite::testing::RunDiffTest(options, /*num_invocations=*/1)) { ++failure_count; } } - fprintf(stderr, "Num errors: %d\n", failure_count); + int failures_in_first_pass = failure_count; + + if (failure_count == 0) { + // Let's try again with num_invocations > 1 to make sure we can do multiple + // invocations without resetting the interpreter. + for (int i = 0; i < options.num_runs_per_pass; i++) { + if (!tflite::testing::RunDiffTest(options, /*num_invocations=*/2)) { + ++failure_count; + } + } + } + + fprintf(stderr, "Num errors in single-inference pass: %d\n", + failures_in_first_pass); + fprintf(stderr, "Num errors in multi-inference pass : %d\n", + failure_count - failures_in_first_pass); + return failure_count != 0 ? 1 : 0; } diff --git a/tensorflow/contrib/lite/testing/tflite_diff_flags.h b/tensorflow/contrib/lite/testing/tflite_diff_flags.h index 706108ed73bb3fd9bd784cffffe322d6981433e6..695c2a3de6c5d7c74a943134f0c97390710ef1e7 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_flags.h +++ b/tensorflow/contrib/lite/testing/tflite_diff_flags.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_FLAGS_H_ #define TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_FLAGS_H_ +#include + #include "tensorflow/contrib/lite/testing/split.h" #include "tensorflow/contrib/lite/testing/tflite_diff_util.h" #include "tensorflow/core/util/command_line_flags.h" @@ -30,6 +32,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { string input_layer_type; string input_layer_shape; string output_layer; + int32_t num_runs_per_pass = 100; } values; std::vector flags = { @@ -49,6 +52,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { tensorflow::Flag("output_layer", &values.output_layer, "Names of output tensors, separated by comma. Example " "output_1,output_2"), + tensorflow::Flag("num_runs_per_pass", &values.num_runs_per_pass, + "Number of full runs in each pass."), }; bool no_inputs = *argc == 1; @@ -63,7 +68,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { Split(values.input_layer, ","), Split(values.input_layer_type, ","), Split(values.input_layer_shape, ":"), - Split(values.output_layer, ",")}; + Split(values.output_layer, ","), + values.num_runs_per_pass}; } } // namespace testing diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.cc b/tensorflow/contrib/lite/testing/tflite_diff_util.cc index f601d3752ddb5df9f2b5ac73d9bc303efaade4a5..19f34c0a51e442804bf2824adc3a1d8bde1eb4b0 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_util.cc +++ b/tensorflow/contrib/lite/testing/tflite_diff_util.cc @@ -25,13 +25,14 @@ limitations under the License. namespace tflite { namespace testing { -bool RunDiffTest(const DiffOptions& options) { +bool RunDiffTest(const DiffOptions& options, int num_invocations) { std::stringstream tflite_stream; if (!GenerateTestSpecFromTensorflowModel( tflite_stream, options.tensorflow_model, options.tflite_model, - options.input_layer, options.input_layer_type, - options.input_layer_shape, options.output_layer)) + num_invocations, options.input_layer, options.input_layer_type, + options.input_layer_shape, options.output_layer)) { return false; + } TfLiteDriver tflite_driver(/*use_nnapi=*/true); tflite_driver.LoadModel(options.tflite_model); return tflite::testing::ParseAndRunTests(&tflite_stream, &tflite_driver); diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.h b/tensorflow/contrib/lite/testing/tflite_diff_util.h index 326fa6c3e28000dee9b6eb9cc5b3a6c5c87e28d0..4ab2f230fdcdfe4616ab1706aa41f0e806665f66 100644 --- a/tensorflow/contrib/lite/testing/tflite_diff_util.h +++ b/tensorflow/contrib/lite/testing/tflite_diff_util.h @@ -40,10 +40,14 @@ struct DiffOptions { // Names of output tensors. // Example output_1,output_2 std::vector output_layer; + // Number of full runs (from building interpreter to checking outputs) in + // each of the passes. The first pass has a single inference, while the + // second pass does multiple inferences back to back. + int num_runs_per_pass; }; // Run a single TensorFLow Lite diff test with a given options. -bool RunDiffTest(const DiffOptions& options); +bool RunDiffTest(const DiffOptions& options, int num_invocations); } // namespace testing } // namespace tflite diff --git a/tensorflow/contrib/lite/testing/util.h b/tensorflow/contrib/lite/testing/util.h index 6d20aec141c7c3a3e48af290edb169c6fd7254cf..8aa639157b8b68061f9ee8c3483959a79cb5794e 100644 --- a/tensorflow/contrib/lite/testing/util.h +++ b/tensorflow/contrib/lite/testing/util.h @@ -15,8 +15,39 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_ #define TENSORFLOW_CONTRIB_LITE_TESTING_UTIL_H_ +#include + +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/string.h" + namespace tflite { +// An ErrorReporter that collects error message in a string, in addition +// to printing to stderr. +class TestErrorReporter : public ErrorReporter { + public: + int Report(const char* format, va_list args) override { + char buffer[1024]; + int size = vsnprintf(buffer, sizeof(buffer), format, args); + fprintf(stderr, "%s", buffer); + error_messages_ += buffer; + num_calls_++; + return size; + } + + void Reset() { + num_calls_ = 0; + error_messages_.clear(); + } + + int num_calls() const { return num_calls_; } + const string& error_messages() const { return error_messages_; } + + private: + int num_calls_ = 0; + string error_messages_; +}; + inline void LogToStderr() { #ifdef PLATFORM_GOOGLE FLAGS_logtostderr = true; diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index f74fc45330e825a41c0ec9d93033fea60bb4de09..7243e584e9c26ce64bfa95e63842cda2087c655c 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -11,6 +11,7 @@ load( "//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", + "tf_copts", ) tf_proto_library_cc( @@ -93,6 +94,7 @@ cc_library( ":runtime", ":toco_port", "//tensorflow/core:lib", + "@com_google_absl//absl/types:optional", ], ) @@ -176,7 +178,7 @@ cc_library( "graph_transformations/convert_reorder_axes.cc", "graph_transformations/convert_squeeze_to_reshape.cc", "graph_transformations/convert_trivial_addn_to_add.cc", - "graph_transformations/convert_trivial_stack_to_reshape.cc", + "graph_transformations/convert_trivial_pack_to_reshape.cc", "graph_transformations/convert_trivial_tile_to_concat.cc", "graph_transformations/convert_trivial_transpose_to_reshape.cc", "graph_transformations/create_im2col_arrays.cc", @@ -202,6 +204,7 @@ cc_library( "graph_transformations/lstm_utils.cc", "graph_transformations/make_initial_dequantize_operator.cc", "graph_transformations/merge_reshape_into_preceding_transpose.cc", + "graph_transformations/move_binary_operator_before_reshape.cc", "graph_transformations/propagate_activation_function_into_constants.cc", "graph_transformations/propagate_array_data_types.cc", "graph_transformations/propagate_default_min_max.cc", @@ -211,7 +214,7 @@ cc_library( "graph_transformations/quantization_util.h", "graph_transformations/quantize.cc", "graph_transformations/quantize_weights.cc", - "graph_transformations/read_fake_quant_min_max.cc", + "graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc", "graph_transformations/remove_final_dequantize_op.cc", "graph_transformations/remove_tensorflow_assert.cc", "graph_transformations/remove_tensorflow_identity.cc", @@ -235,19 +238,21 @@ cc_library( "graph_transformations/resolve_constant_fake_quant.cc", "graph_transformations/resolve_constant_fill.cc", "graph_transformations/resolve_constant_gather.cc", + "graph_transformations/resolve_constant_pack.cc", "graph_transformations/resolve_constant_random_uniform.cc", "graph_transformations/resolve_constant_range.cc", "graph_transformations/resolve_constant_reshape.cc", "graph_transformations/resolve_constant_shape_or_rank.cc", "graph_transformations/resolve_constant_slice.cc", - "graph_transformations/resolve_constant_stack.cc", "graph_transformations/resolve_constant_strided_slice.cc", "graph_transformations/resolve_constant_transpose.cc", "graph_transformations/resolve_constant_unary.cc", - "graph_transformations/resolve_mean_attributes.cc", + "graph_transformations/resolve_fake_quant_args_from_vars.cc", + "graph_transformations/resolve_gather_attributes.cc", "graph_transformations/resolve_multiply_by_zero.cc", "graph_transformations/resolve_pad_attributes.cc", "graph_transformations/resolve_padv2_attributes.cc", + "graph_transformations/resolve_reduce_attributes.cc", "graph_transformations/resolve_reorder_axes.cc", "graph_transformations/resolve_reshape_attributes.cc", "graph_transformations/resolve_slice_attributes.cc", @@ -301,7 +306,7 @@ cc_library( "tensorflow_util.h", "toco_tooling.h", ], - copts = select({ + copts = tf_copts() + select({ "//tensorflow:darwin": ["-DTOCO_SUPPORT_PORTABLE_PROTOS=0"], "//conditions:default": [], }), @@ -335,6 +340,7 @@ cc_library( tf_cc_test( name = "import_tensorflow_test", srcs = ["import_tensorflow_test.cc"], + tags = ["no_oss"], deps = [ ":toco_tooling", "//tensorflow/core:framework", @@ -374,6 +380,7 @@ cc_library( tf_cc_test( name = "tooling_util_test", srcs = ["tooling_util_test.cc"], + tags = ["no_oss"], deps = [ ":model", ":tooling_util", @@ -408,6 +415,7 @@ tf_cc_test( data = [ "toco_port_test.cc", ], + tags = ["no_oss"], deps = [ ":toco_port", "@com_google_googletest//:gtest_main", diff --git a/tensorflow/contrib/lite/toco/README.md b/tensorflow/contrib/lite/toco/README.md index ee83c7a6e3253d02fb1a2c791fc22428473c1832..2db6a627ab59604a99cafe3b38df08b70092d989 100644 --- a/tensorflow/contrib/lite/toco/README.md +++ b/tensorflow/contrib/lite/toco/README.md @@ -17,11 +17,12 @@ Usage information is given in these documents: Once an application developer has a trained TensorFlow model, TOCO will accept that model and generate a TensorFlow Lite [FlatBuffer](https://google.github.io/flatbuffers/) file. TOCO currently supports -[SavedModels](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators) -and frozen graphs (models generated via -[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)). -The TensorFlow Lite FlatBuffer file can be shipped to client devices, generally -mobile devices, where the TensorFlow Lite interpreter handles them on-device. -This flow is represented in the diagram below. +[SavedModels](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators), +frozen graphs (models generated via +[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)), +and `tf.Keras` model files. The TensorFlow Lite FlatBuffer file can be shipped +to client devices, generally mobile devices, where the TensorFlow Lite +interpreter handles them on-device. This flow is represented in the diagram +below. ![drawing](g3doc/toco_landscape.svg) diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc index 6877fb237c0514a972589ac0301647104f5ed7ed..30525efd2391bb63afd7035b8134e5858add45f2 100644 --- a/tensorflow/contrib/lite/toco/dump_graphviz.cc +++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc @@ -167,7 +167,7 @@ NodeProperties GetPropertiesForArray(const Model& model, node_properties.label += "]"; int buffer_size = 0; - if (IsValid(array.shape())) { + if (IsNonEmpty(array.shape())) { buffer_size = RequiredBufferSizeForShape(array.shape()); node_properties.log2_buffer_size = std::log2(static_cast(buffer_size)); diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 6be6b25f9318deb08bd427d5e3166909fae8f3ea..02671f0408f55726df730dbe0fe9a4f936d22632 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -215,6 +215,30 @@ void ConvertFloatTensorConst(const Model& model, const string& name, LegacyScalarPolicy::kAvoidLegacyScalars); } +void ConvertBoolTensorConst(const Model& model, const string& name, + GraphDef* tensorflow_graph) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + CHECK(model.HasArray(name)); + const auto& array = model.GetArray(name); + tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + (*const_op->mutable_attr())["dtype"].set_type(DT_BOOL); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_BOOL); + const auto& data = array.GetBuffer().data; + for (auto index : data) { + tensor->add_bool_val(index); + } + const auto& array_shape = array.shape(); + auto* shape = tensor->mutable_tensor_shape(); + for (int i = 0; i < array_shape.dimensions_count(); i++) { + shape->add_dim()->set_size(array_shape.dims(i)); + } +} + void ConvertIntTensorConst(const Model& model, const string& name, GraphDef* tensorflow_graph) { if (HasAlreadyExportedConst(name, *tensorflow_graph)) { @@ -621,7 +645,8 @@ void ConvertAddOperator(const Model& model, const AddOperator& src_op, CHECK_EQ(src_op.inputs.size(), 2); *add_op->add_input() = src_op.inputs[0]; *add_op->add_input() = src_op.inputs[1]; - (*add_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*add_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); } void ConvertAddNOperator(const Model& model, const AddNOperator& src_op, @@ -633,27 +658,42 @@ void ConvertAddNOperator(const Model& model, const AddNOperator& src_op, *add_op->add_input() = input; } (*add_op->mutable_attr())["N"].set_i(src_op.inputs.size()); - (*add_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*add_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); } void ConvertMulOperator(const Model& model, const MulOperator& src_op, GraphDef* tensorflow_graph) { - tensorflow::NodeDef* add_op = tensorflow_graph->add_node(); - add_op->set_op("Mul"); - add_op->set_name(src_op.outputs[0]); + tensorflow::NodeDef* mul_op = tensorflow_graph->add_node(); + mul_op->set_op("Mul"); + mul_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); - *add_op->add_input() = src_op.inputs[0]; - *add_op->add_input() = src_op.inputs[1]; - (*add_op->mutable_attr())["T"].set_type(DT_FLOAT); + *mul_op->add_input() = src_op.inputs[0]; + *mul_op->add_input() = src_op.inputs[1]; + (*mul_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); } -void ConvertReluOperator(const ReluOperator& src_op, +void ConvertDivOperator(const Model& model, const DivOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* div_op = tensorflow_graph->add_node(); + div_op->set_op("Div"); + div_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *div_op->add_input() = src_op.inputs[0]; + *div_op->add_input() = src_op.inputs[1]; + (*div_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); +} + +void ConvertReluOperator(const Model& model, const ReluOperator& src_op, GraphDef* tensorflow_graph) { tensorflow::NodeDef* relu_op = tensorflow_graph->add_node(); relu_op->set_op("Relu"); relu_op->set_name(src_op.outputs[0]); *relu_op->add_input() = src_op.inputs[0]; - (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*relu_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); } void ConvertRelu1Operator(const Relu1Operator& src_op, @@ -884,6 +924,9 @@ void ConvertFakeQuantOperator(const FakeQuantOperator& src_op, if (src_op.num_bits) { (*fakequant_op->mutable_attr())["num_bits"].set_i(src_op.num_bits); } + if (src_op.narrow_range) { + (*fakequant_op->mutable_attr())["narrow_range"].set_b(src_op.narrow_range); + } } void ConvertMaxPoolOperator(const MaxPoolOperator& src_op, @@ -1107,13 +1150,27 @@ void ConvertFloorOperator(const Model& model, const FloorOperator& src_op, void ConvertGatherOperator(const Model& model, const GatherOperator& src_op, GraphDef* tensorflow_graph) { tensorflow::NodeDef* gather_op = tensorflow_graph->add_node(); - gather_op->set_op("Gather"); + gather_op->set_op("GatherV2"); gather_op->set_name(src_op.outputs[0]); - CHECK_EQ(src_op.inputs.size(), 2); *gather_op->add_input() = src_op.inputs[0]; *gather_op->add_input() = src_op.inputs[1]; + if (!src_op.axis) { + // Dynamic axis. + CHECK_EQ(src_op.inputs.size(), 3); + *gather_op->add_input() = src_op.inputs[2]; + } else { + // Constant axis. + CHECK_EQ(src_op.inputs.size(), 2); + const string gather_axis = + AvailableArrayName(model, gather_op->name() + "/axis"); + CreateIntTensorConst(gather_axis, {src_op.axis.value()}, {}, + tensorflow_graph); + *gather_op->add_input() = gather_axis; + } + (*gather_op->mutable_attr())["Tindices"].set_type(DT_INT32); + (*gather_op->mutable_attr())["Taxis"].set_type(DT_INT32); const tensorflow::DataType params_type = GetTensorFlowDataType(model, src_op.inputs[0]); (*gather_op->mutable_attr())["Tparams"].set_type(params_type); @@ -1135,6 +1192,22 @@ void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op, GetTensorFlowDataType(model, src_op.outputs[0])); } +void ConvertArgMinOperator(const Model& model, const ArgMinOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* argmin_op = tensorflow_graph->add_node(); + argmin_op->set_op("ArgMin"); + argmin_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *argmin_op->add_input() = src_op.inputs[0]; + *argmin_op->add_input() = src_op.inputs[1]; + (*argmin_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); + (*argmin_op->mutable_attr())["Tidx"].set_type( + GetTensorFlowDataType(model, src_op.inputs[1])); + (*argmin_op->mutable_attr())["output_type"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); +} + void ConvertTransposeOperator(const Model& model, const TransposeOperator& src_op, GraphDef* tensorflow_graph) { @@ -1188,17 +1261,17 @@ void ConvertRangeOperator(const Model& model, const RangeOperator& src_op, GetTensorFlowDataType(src_op.dtype)); } -void ConvertStackOperator(const Model& model, const StackOperator& src_op, - GraphDef* tensorflow_graph) { - tensorflow::NodeDef* stack_op = tensorflow_graph->add_node(); - stack_op->set_op("Stack"); - stack_op->set_name(src_op.outputs[0]); +void ConvertPackOperator(const Model& model, const PackOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* pack_op = tensorflow_graph->add_node(); + pack_op->set_op("Pack"); + pack_op->set_name(src_op.outputs[0]); for (const auto& input : src_op.inputs) { - *stack_op->add_input() = input; + *pack_op->add_input() = input; } - (*stack_op->mutable_attr())["elem_type"].set_type( - GetTensorFlowDataType(model, src_op.outputs[0])); - (*stack_op->mutable_attr())["axis"].set_i(src_op.axis); + (*pack_op->mutable_attr())["axis"].set_i(src_op.axis); + (*pack_op->mutable_attr())["N"].set_i(src_op.inputs.size()); + (*pack_op->mutable_attr())["T"].set_type(GetTensorFlowDataType(src_op.dtype)); } void ConvertFillOperator(const Model& model, const FillOperator& src_op, @@ -1255,6 +1328,20 @@ void ConvertResizeBilinearOperator(const Model& model, (*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners); } +void ConvertOneHotOperator(const Model& model, const OneHotOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* onehot_op = tensorflow_graph->add_node(); + onehot_op->set_op("OneHot"); + onehot_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 4); + for (const auto& input : src_op.inputs) { + *onehot_op->add_input() = input; + } + (*onehot_op->mutable_attr())["T"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); + (*onehot_op->mutable_attr())["axis"].set_i(src_op.axis); +} + namespace { // TODO(aselle): Remove when available in absl absl::string_view FindLongestCommonPrefix(absl::string_view a, @@ -1604,10 +1691,11 @@ void ConvertSliceOperator(const Model& model, const SliceOperator& src_op, CreateSliceInput(src_op.inputs[2], src_op.size, tensorflow_graph); } -void ConvertMeanOperator(const Model& model, const MeanOperator& src_op, - GraphDef* tensorflow_graph) { +template +void ConvertReduceOperator(const Model& model, const T& src_op, + GraphDef* tensorflow_graph, const string& op_name) { tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); - new_op->set_op("Mean"); + new_op->set_op(op_name); new_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *new_op->add_input() = src_op.inputs[0]; @@ -1616,6 +1704,9 @@ void ConvertMeanOperator(const Model& model, const MeanOperator& src_op, const tensorflow::DataType params_type = GetTensorFlowDataType(model, src_op.inputs[0]); (*new_op->mutable_attr())["T"].set_type(params_type); + const tensorflow::DataType indices_type = + GetTensorFlowDataType(model, src_op.inputs[1]); + (*new_op->mutable_attr())["Tidx"].set_type(indices_type); if (src_op.keep_dims) { (*new_op->mutable_attr())["keep_dims"].set_b(true); @@ -1672,43 +1763,43 @@ void ConvertSubOperator(const Model& model, const SubOperator& src_op, void ConvertTensorFlowMinimumOperator(const Model& model, const TensorFlowMinimumOperator& src_op, GraphDef* tensorflow_graph) { - tensorflow::NodeDef* sub_op = tensorflow_graph->add_node(); - sub_op->set_op("Minimum"); - sub_op->set_name(src_op.outputs[0]); + tensorflow::NodeDef* min_op = tensorflow_graph->add_node(); + min_op->set_op("Minimum"); + min_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); - *sub_op->add_input() = src_op.inputs[0]; - *sub_op->add_input() = src_op.inputs[1]; + *min_op->add_input() = src_op.inputs[0]; + *min_op->add_input() = src_op.inputs[1]; const tensorflow::DataType data_type = GetTensorFlowDataType(model, src_op.inputs[0]); - (*sub_op->mutable_attr())["T"].set_type(data_type); + (*min_op->mutable_attr())["T"].set_type(data_type); } void ConvertTensorFlowMaximumOperator(const Model& model, const TensorFlowMaximumOperator& src_op, GraphDef* tensorflow_graph) { - tensorflow::NodeDef* sub_op = tensorflow_graph->add_node(); - sub_op->set_op("Maximum"); - sub_op->set_name(src_op.outputs[0]); + tensorflow::NodeDef* max_op = tensorflow_graph->add_node(); + max_op->set_op("Maximum"); + max_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); - *sub_op->add_input() = src_op.inputs[0]; - *sub_op->add_input() = src_op.inputs[1]; + *max_op->add_input() = src_op.inputs[0]; + *max_op->add_input() = src_op.inputs[1]; const tensorflow::DataType data_type = GetTensorFlowDataType(model, src_op.inputs[0]); - (*sub_op->mutable_attr())["T"].set_type(data_type); + (*max_op->mutable_attr())["T"].set_type(data_type); } void ConvertSelectOperator(const Model& model, const SelectOperator& src_op, GraphDef* tensorflow_graph) { - tensorflow::NodeDef* sub_op = tensorflow_graph->add_node(); - sub_op->set_op("Select"); - sub_op->set_name(src_op.outputs[0]); + tensorflow::NodeDef* select_op = tensorflow_graph->add_node(); + select_op->set_op("Select"); + select_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 3); - *sub_op->add_input() = src_op.inputs[0]; - *sub_op->add_input() = src_op.inputs[1]; - *sub_op->add_input() = src_op.inputs[2]; + *select_op->add_input() = src_op.inputs[0]; + *select_op->add_input() = src_op.inputs[1]; + *select_op->add_input() = src_op.inputs[2]; const tensorflow::DataType data_type = GetTensorFlowDataType(model, src_op.inputs[1]); - (*sub_op->mutable_attr())["T"].set_type(data_type); + (*select_op->mutable_attr())["T"].set_type(data_type); } void ConvertTileOperator(const Model& model, @@ -1731,11 +1822,14 @@ void ConvertTileOperator(const Model& model, void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op, GraphDef* tensorflow_graph) { tensorflow::NodeDef* topk_op = tensorflow_graph->add_node(); - topk_op->set_op("TOPKV2"); + topk_op->set_op("TopKV2"); topk_op->set_name(src_op.outputs[0]); CHECK_EQ(src_op.inputs.size(), 2); *topk_op->add_input() = src_op.inputs[0]; *topk_op->add_input() = src_op.inputs[1]; + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*topk_op->mutable_attr())["T"].set_type(data_type); (*topk_op->mutable_attr())["sorted"].set_b(true); } @@ -1806,6 +1900,73 @@ void ConvertPowOperator(const Model& model, const PowOperator& src_op, (*pow_op->mutable_attr())["T"].set_type(data_type); } +void ConvertAnyOperator(const Model& model, const AnyOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* any_op = tensorflow_graph->add_node(); + any_op->set_op("Any"); + any_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + for (int i = 0; i < 2; ++i) { + *any_op->add_input() = src_op.inputs[i]; + } + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[1]); + (*any_op->mutable_attr())["Tidx"].set_type(data_type); + (*any_op->mutable_attr())["keep_dims"].set_b(src_op.keep_dims); +} + +void ConvertLogicalAndOperator(const Model& model, + const LogicalAndOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* logical_op = tensorflow_graph->add_node(); + logical_op->set_op("LogicalAnd"); + logical_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + for (int i = 0; i < 2; ++i) { + *logical_op->add_input() = src_op.inputs[i]; + } +} + +void ConvertLogicalNotOperator(const Model& model, + const LogicalNotOperator& src_op, + GraphDef* tensorflow_graph) { + tensorflow::NodeDef* logical_op = tensorflow_graph->add_node(); + logical_op->set_op("LogicalNot"); + logical_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *logical_op->add_input() = src_op.inputs[0]; +} + +void ConvertLogicalOrOperator(const Model& model, + const LogicalOrOperator& src_op, + const char* op_name, GraphDef* tensorflow_graph) { + tensorflow::NodeDef* logical_or_op = tensorflow_graph->add_node(); + logical_or_op->set_op(op_name); + logical_or_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + for (int i = 0; i < 2; ++i) { + *logical_or_op->add_input() = src_op.inputs[i]; + } + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*logical_or_op->mutable_attr())["T"].set_type(data_type); +} + +void ConvertCTCBeamSearchDecoderOperator( + const Model& model, const CTCBeamSearchDecoderOperator& src_op, + const char* op_name, GraphDef* tensorflow_graph) { + auto* op = tensorflow_graph->add_node(); + op->set_op(op_name); + op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + for (int i = 0; i < 2; ++i) { + *op->add_input() = src_op.inputs[i]; + } + (*op->mutable_attr())["beam_width"].set_i(src_op.beam_width); + (*op->mutable_attr())["top_paths"].set_i(src_op.top_paths); + (*op->mutable_attr())["merge_repeated"].set_b(src_op.merge_repeated); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -1841,8 +2002,11 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kMul) { ConvertMulOperator(model, static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kDiv) { + ConvertDivOperator(model, static_cast(src_op), + tensorflow_graph); } else if (src_op.type == OperatorType::kRelu) { - ConvertReluOperator(static_cast(src_op), + ConvertReluOperator(model, static_cast(src_op), tensorflow_graph); } else if (src_op.type == OperatorType::kRelu1) { ConvertRelu1Operator(static_cast(src_op), @@ -1942,8 +2106,24 @@ void ConvertOperator(const Model& model, const Operator& src_op, model, static_cast(src_op), tensorflow_graph); } else if (src_op.type == OperatorType::kMean) { - ConvertMeanOperator(model, static_cast(src_op), - tensorflow_graph); + ConvertReduceOperator(model, static_cast(src_op), + tensorflow_graph, "Mean"); + } else if (src_op.type == OperatorType::kSum) { + ConvertReduceOperator(model, + static_cast(src_op), + tensorflow_graph, "Sum"); + } else if (src_op.type == OperatorType::kReduceProd) { + ConvertReduceOperator(model, + static_cast(src_op), + tensorflow_graph, "Prod"); + } else if (src_op.type == OperatorType::kReduceMin) { + ConvertReduceOperator(model, + static_cast(src_op), + tensorflow_graph, "Min"); + } else if (src_op.type == OperatorType::kReduceMax) { + ConvertReduceOperator(model, + static_cast(src_op), + tensorflow_graph, "Max"); } else if (src_op.type == OperatorType::kSub) { ConvertSubOperator(model, static_cast(src_op), tensorflow_graph); @@ -1964,6 +2144,9 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kArgMax) { ConvertArgMaxOperator(model, static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kArgMin) { + ConvertArgMinOperator(model, static_cast(src_op), + tensorflow_graph); } else if (src_op.type == OperatorType::kTopK_V2) { ConvertTopKV2Operator(model, static_cast(src_op), tensorflow_graph); @@ -1980,9 +2163,9 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kRange) { ConvertRangeOperator(model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kStack) { - ConvertStackOperator(model, static_cast(src_op), - tensorflow_graph); + } else if (src_op.type == OperatorType::kPack) { + ConvertPackOperator(model, static_cast(src_op), + tensorflow_graph); } else if (src_op.type == OperatorType::kFill) { ConvertFillOperator(model, static_cast(src_op), tensorflow_graph); @@ -2023,6 +2206,28 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kPow) { ConvertPowOperator(model, static_cast(src_op), "Pow", tensorflow_graph); + } else if (src_op.type == OperatorType::kAny) { + ConvertAnyOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kLogicalAnd) { + ConvertLogicalAndOperator(model, + static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kLogicalNot) { + ConvertLogicalNotOperator(model, + static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kOneHot) { + ConvertOneHotOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kLogicalOr) { + ConvertLogicalOrOperator(model, + static_cast(src_op), + "LogicalOr", tensorflow_graph); + } else if (src_op.type == OperatorType::kCTCBeamSearchDecoder) { + ConvertCTCBeamSearchDecoderOperator( + model, static_cast(src_op), + "CTCBeamSearchDecoder", tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } @@ -2101,6 +2306,9 @@ void ExportTensorFlowGraphDefImplementation(const Model& model, const auto& array = *array_pair.second; if (array.buffer) { switch (array.data_type) { + case ArrayDataType::kBool: + ConvertBoolTensorConst(model, array_name, tensorflow_graph); + break; case ArrayDataType::kFloat: ConvertFloatTensorConst(model, array_name, tensorflow_graph); break; diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md index 0ab024c6187ffed2acf860505812f16ab12a32f5..4bf47aa3c4d1b682808ab8175c4d07d8a347067a 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md @@ -11,8 +11,10 @@ Table of contents: * [Command-line tools](#tools) * [Converting models prior to TensorFlow 1.9.](#pre-tensorflow-1.9) -* [Convert a TensorFlow GraphDef](#graphdef) -* [Convert a TensorFlow SavedModel](#savedmodel) +* [Basic examples](#basic) + * [Convert a TensorFlow GraphDef](#graphdef) + * [Convert a TensorFlow SavedModel](#savedmodel) + * [Convert a tf.keras model](#keras) * [Quantization](#quantization) * [Convert a TensorFlow GraphDef for quantized inference](#graphdef-quant) * [Use "dummy-quantization" to try out quantized inference on a float @@ -34,7 +36,7 @@ There are two approaches to running TOCO via command line. * `tflite_convert`: Starting from TensorFlow 1.9, the command-line tool `tflite_convert` will be installed as part of the Python package. All of the examples below use `tflite_convert` for simplicity. - * Example: `tflite --output_file=...` + * Example: `tflite_convert --output_file=...` * `bazel`: In order to run the latest version of TOCO, [clone the TensorFlow repository](https://www.tensorflow.org/install/install_sources#clone_the_tensorflow_repository) and use `bazel`. This is the recommended approach for converting models that @@ -51,7 +53,12 @@ API](python_api.md#pre-tensorflow-1.9). If a command line tool is desired, the Terminal for additional details on the command-line flags available. There were no command line tools in TensorFlow 1.8. -## Convert a TensorFlow GraphDef +## Basic examples + +The following section shows examples of how to convert a basic float-point model +from each of the supported data formats into a TensorFlow Lite FlatBuffers. + +### Convert a TensorFlow GraphDef The follow example converts a basic TensorFlow GraphDef (frozen by [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)) @@ -70,7 +77,7 @@ tflite_convert \ The value for `input_shapes` is automatically determined whenever possible. -## Convert a TensorFlow SavedModel +### Convert a TensorFlow SavedModel The follow example converts a basic TensorFlow SavedModel into a Tensorflow Lite FlatBuffer to perform floating-point inference. @@ -95,6 +102,17 @@ There is currently no support for MetaGraphDefs without a SignatureDef or for MetaGraphDefs that use the [`assets/` directory](https://www.tensorflow.org/guide/saved_model#structure_of_a_savedmodel_directory). +### Convert a tf.Keras model + +The following example converts a `tf.keras` model into a TensorFlow Lite +Flatbuffer. The `tf.keras` file must contain both the model and the weights. + +``` +tflite_convert \ + --output_file=/tmp/foo.tflite \ + --keras_model_file=/tmp/keras_model.h5 +``` + ## Quantization ### Convert a TensorFlow GraphDef for quantized inference diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md index 2d44b871c6616dbb415bdec550eed867ab49657b..decc8a45a40ffba2a27320ce8391b1916391d744 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md @@ -19,7 +19,7 @@ Table of contents: The following high level flags specify the details of the input and output files. The flag `--output_file` is always required. Additionally, either -`--graph_def_file` or `--saved_model_dir` is required. +`--graph_def_file`, `--saved_model_dir` or `--keras_model_file` is required. * `--output_file`. Type: string. Specifies the full path of the output file. * `--graph_def_file`. Type: string. Specifies the full path of the input @@ -27,6 +27,8 @@ files. The flag `--output_file` is always required. Additionally, either [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py). * `--saved_model_dir`. Type: string. Specifies the full path to the directory containing the SavedModel. +* `--keras_model_file`. Type: string. Specifies the full path of the HDF5 file + containing the tf.keras model. * `--output_format`. Type: string. Default: `TFLITE`. Specifies the format of the output file. Allowed values: * `TFLITE`: TensorFlow Lite FlatBuffer format. diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md index b04d166f89eb5e0f761a795f1c662d7e9cb300c5..3799eac0a1181afe3b63d2f8651745c2ec61f5e0 100644 --- a/tensorflow/contrib/lite/toco/g3doc/python_api.md +++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md @@ -41,9 +41,11 @@ is `tf.contrib.lite.TocoConverter`. The API for calling the Python intepreter is `TocoConverter` provides class methods based on the original format of the model. `TocoConverter.from_session()` is available for GraphDefs. -`TocoConverter.from_saved_model()` is available for SavedModels. Example usages -for simple float-point models are shown in [Basic Examples](#basic). Examples -usages for more complex models is shown in [Complex Examples](#complex). +`TocoConverter.from_saved_model()` is available for SavedModels. +`TocoConverter.from_keras_model_file()` is available for `tf.Keras` files. +Example usages for simple float-point models are shown in [Basic +Examples](#basic). Examples usages for more complex models is shown in [Complex +Examples](#complex). **NOTE**: Currently, `TocoConverter` will cause a fatal error to the Python interpreter when the conversion fails. This will be remedied as soon as @@ -117,7 +119,7 @@ available by running `help(tf.contrib.lite.TocoConverter)`. ### Exporting a tf.keras File -The following example shows how to convert a tf.keras model into a TensorFlow +The following example shows how to convert a `tf.keras` model into a TensorFlow Lite FlatBuffer. ```python @@ -128,7 +130,7 @@ tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) ``` -The tf.keras file must contain both the model and the weights. A comprehensive +The `tf.keras` file must contain both the model and the weights. A comprehensive example including model construction can be seen below. ```python diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc index 56f48d47de4e86ece76ceef1d09a25f50957a8dc..310a88484c246b8035aa73b5e04ad677d575e4c4 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc @@ -40,11 +40,6 @@ bool ConvertExpandDimsToReshape::Run(Model* model, std::size_t op_index) { // Yield until input dims have been resolved. return false; } - if (input_array.shape().dimensions_count() == 0) { - // Input array cannot be 0-D. - // (Unsure if this is TF behavior, but was required to get a test to pass.) - return false; - } const auto& axis_array = model->GetArray(expand_op->inputs[1]); if (!axis_array.has_shape()) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_stack_to_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc similarity index 72% rename from tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_stack_to_reshape.cc rename to tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc index 0615b5e6c6db910ee847188427b416fd812aa141..75113a2a8c7c446bd13de8b5c1a8d8ef3cf7fdd6 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_stack_to_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc @@ -25,19 +25,19 @@ limitations under the License. namespace toco { -bool ConvertTrivialStackToReshape::Run(Model* model, std::size_t op_index) { - auto stack_it = model->operators.begin() + op_index; - if (stack_it->get()->type != OperatorType::kStack) { +bool ConvertTrivialPackToReshape::Run(Model* model, std::size_t op_index) { + auto pack_it = model->operators.begin() + op_index; + if (pack_it->get()->type != OperatorType::kPack) { return false; } - auto* stack_op = static_cast(stack_it->get()); - if (stack_op->inputs.size() > 1) { + auto* pack_op = static_cast(pack_it->get()); + if (pack_op->inputs.size() > 1) { // Not trivial. return false; } - CHECK_EQ(stack_op->outputs.size(), 1); + CHECK_EQ(pack_op->outputs.size(), 1); - const auto& input_array = model->GetArray(stack_op->inputs[0]); + const auto& input_array = model->GetArray(pack_op->inputs[0]); if (!input_array.has_shape()) { // Yield until input dims have been resolved. return false; @@ -48,16 +48,16 @@ bool ConvertTrivialStackToReshape::Run(Model* model, std::size_t op_index) { return false; } - AddMessageF("Converting trivial %s to a reshape", LogName(*stack_op)); + AddMessageF("Converting trivial %s to a reshape", LogName(*pack_op)); // Note that we could convert to ExpandDims but toco prefers reshapes. auto* reshape_op = new TensorFlowReshapeOperator; - reshape_op->inputs = {stack_op->inputs[0]}; - reshape_op->outputs = stack_op->outputs; + reshape_op->inputs = {pack_op->inputs[0]}; + reshape_op->outputs = pack_op->outputs; // Create shape param. string shape_array_name = - AvailableArrayName(*model, stack_op->outputs[0] + "_shape"); + AvailableArrayName(*model, pack_op->outputs[0] + "_shape"); Array& shape_array = model->GetOrCreateArray(shape_array_name); *(shape_array.mutable_shape()->mutable_dims()) = { 1 + input_array.shape().dimensions_count()}; @@ -70,10 +70,10 @@ bool ConvertTrivialStackToReshape::Run(Model* model, std::size_t op_index) { } // Replace the operator in the graph. - const auto reshape_it = model->operators.emplace(stack_it, reshape_op); - stack_it = reshape_it + 1; - CHECK_EQ(stack_it->get(), stack_op); - model->operators.erase(stack_it); + const auto reshape_it = model->operators.emplace(pack_it, reshape_op); + pack_it = reshape_it + 1; + CHECK_EQ(pack_it->get(), pack_op); + model->operators.erase(pack_it); return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc index 2c7ffe488477ef1a544dfe6f36a6e0d1ac40aa96..1688586733b0434c7fc98686a19f0ceb8092f33b 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc @@ -159,6 +159,7 @@ bool DequantizeArray(const string& array_name, new_array.GetOrCreateMinMax() = array->GetMinMax(); fakequant_op->minmax.reset(new MinMax); *fakequant_op->minmax = array->GetMinMax(); + fakequant_op->narrow_range = array->narrow_range; if (must_insert_fakequant_before) { for (const auto& op : model->operators) { for (string& output : op->outputs) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc index 75642bbc37be6b3140e5b79a463ca70b5786d772..c13fc0de7502a9edc80dc399354708a5b1b96b02 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc @@ -181,7 +181,7 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model, // future without worrying. static constexpr int kMinDistanceBetweenBadValues = 16; if (distance < kMinDistanceBetweenBadValues) { - if (allow_nudging_weights()) { + if (allow_nudging_weights() || has_default_ranges_flag()) { buffer_data[i] = 1; changed = true; continue; @@ -200,6 +200,15 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model, } if (changed) { + if (has_default_ranges_flag()) { + std::cerr + << "Since the specified values of --default_ranges_min and " + "--default_ranges_max result in values incompatible with TFLite's " + "fast int8 kernels, " + "--allow_nudging_weights_to_use_fast_gemm_kernel " + "has been enabled. This may affect the accuracy of the model." + << std::endl; + } AddMessageF("Tweaked weights values for %s", LogName(op)); } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index 4025fede6f160d7ad0fb09be99c246adb93b43a6..8d9a4c4700e12ac1a187038a0a5efc1b033d4e57 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -116,7 +116,7 @@ DECLARE_GRAPH_TRANSFORMATION(ConvertExpandDimsToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise) DECLARE_GRAPH_TRANSFORMATION(ConvertSqueezeToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialAddNToAdd) -DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialStackToReshape) +DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialPackToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTileToConcat) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTransposeToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes) @@ -135,6 +135,7 @@ DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1) DECLARE_GRAPH_TRANSFORMATION(IdentifyPRelu) DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv) DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator) +DECLARE_GRAPH_TRANSFORMATION(MoveBinaryOperatorBeforeReshape) DECLARE_GRAPH_TRANSFORMATION(PropagateActivationFunctionIntoConstants) DECLARE_GRAPH_TRANSFORMATION(PropagateArrayDataTypes) DECLARE_GRAPH_TRANSFORMATION(PropagateFakeQuantNumBits); @@ -158,7 +159,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantBinaryOperator) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantUnaryOperator) DECLARE_GRAPH_TRANSFORMATION(CreateIm2colArrays) DECLARE_GRAPH_TRANSFORMATION(DropIm2colArrays) -DECLARE_GRAPH_TRANSFORMATION(ReadFakeQuantMinMax) +DECLARE_GRAPH_TRANSFORMATION(ReadArrayMinmaxAndNarrowRangeFromFakeQuant) DECLARE_GRAPH_TRANSFORMATION(ReorderElementwiseUnary) DECLARE_GRAPH_TRANSFORMATION(ReorderReshapeTranspose) DECLARE_GRAPH_TRANSFORMATION(ResolveReorderAxes) @@ -179,13 +180,13 @@ DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolvePadV2Attributes) DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes) -DECLARE_GRAPH_TRANSFORMATION(ResolveMeanAttributes) +DECLARE_GRAPH_TRANSFORMATION(ResolveReduceAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveTransposeAttributes) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantPack) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRandomUniform) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRange) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRank) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantSlice) -DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStack) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStridedSlice) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantGather) @@ -193,6 +194,8 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveMultiplyByZero) DECLARE_GRAPH_TRANSFORMATION(Dequantize) DECLARE_GRAPH_TRANSFORMATION(UnpartitionEmbeddingLookup) DECLARE_GRAPH_TRANSFORMATION(ShuffleFCWeights) +DECLARE_GRAPH_TRANSFORMATION(ResolveFakeQuantArgsFromVars) +DECLARE_GRAPH_TRANSFORMATION(ResolveGatherAttributes) class PropagateDefaultMinMax : public GraphTransformation { public: @@ -259,8 +262,12 @@ class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation { bool allow_nudging_weights() const { return allow_nudging_weights_; } void set_allow_nudging_weights(bool val) { allow_nudging_weights_ = val; } + bool has_default_ranges_flag() const { return has_default_ranges_flag_; } + void set_has_default_ranges_flag(bool val) { has_default_ranges_flag_ = val; } + private: bool allow_nudging_weights_ = false; + bool has_default_ranges_flag_ = false; }; #undef DECLARE_GRAPH_TRANSFORMATION diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc index 39f55208e453bdd946cfc8bbbacdc05b793c5d99..d26c3b2878b8499fcbabc5448de9ec045eb07879 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -228,6 +228,14 @@ bool HardcodeMinMaxForOutput(Model* model, Operator* op, double min, return true; } +bool MinMaxApproximatelyEqual(const MinMax& minmax1, const MinMax& minmax2) { + const double magnitude = + std::min(minmax1.max - minmax1.min, minmax2.max - minmax2.min); + const double tolerated = 1e-6 * magnitude; + return std::abs(minmax1.min - minmax2.min) < tolerated && + std::abs(minmax1.max - minmax2.max) < tolerated; +} + // Propagates MinMax from any of the listed arrays, to all others. // If multiple of these arrays have MinMax, then these are required // to agree with each other. @@ -250,7 +258,7 @@ bool PropagateMinMaxAmongArrays(Model* model, for (const string& array_name : array_names) { auto& array = model->GetArray(array_name); if (array.minmax) { - CHECK(*array.minmax == *reference_minmax) + CHECK(MinMaxApproximatelyEqual(*array.minmax, *reference_minmax)) << "Both the following arrays have minmax, and they disagree: " << reference_array_name << " (" << reference_minmax->min << "," << reference_minmax->max << ") and " << array_name << " (" @@ -363,12 +371,26 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { case OperatorType::kStridedSlice: case OperatorType::kSqueeze: case OperatorType::kReshape: + case OperatorType::kExpandDims: case OperatorType::kPad: case OperatorType::kGather: case OperatorType::kTranspose: case OperatorType::kMean: changed = HardcodeMinMaxFromFirstInput(model, op); break; + case OperatorType::kSum: + // reduce_sum is expected to change the output range. Hence + // a fake_quant op is necessary in the output to minimize error. However + // in special circumstances like when computing expected value using + // reduce_sum the input range and the output range matches. Hence the + // below code would act as a fallback. If a fake_quant node is observed in + // the output that takes precendence over the hard coding logic below. + changed = HardcodeMinMaxFromFirstInput(model, op); + if (changed) { + LOG(WARNING) << "Using the input range for output in reduce_sum op." + << "This could have an impact on your model accuracy."; + } + break; case OperatorType::kSelect: changed = HardcodeMinMaxForSelect(model, op); break; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc index 3ca7f53512bb7e307f9a2bc5cfb7c27b45cc052c..c0b014b45eb1df25173ce3ca3fa488b0655c3c76 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc @@ -35,6 +35,26 @@ std::vector>::iterator FindOperator( return it; } +bool ValidateSourceOp(const Model& model, const string& array_name, + OperatorType op_type, Operator** source_op) { + if (op_type == OperatorType::kNone) { + CHECK(!source_op); + } else { + CHECK(source_op); + *source_op = GetOpWithOutput(model, array_name); + if (*source_op == nullptr) { + return false; + } + + // Check that first operator, if connected, is of correct type + if ((*source_op)->type != op_type) { + return false; + } + } + + return true; +} + // Returns true if the given operator has exactly 1 input, and is connected to // the given op_type. // We use kNone to indicate an input unattached to an operator output. Usually @@ -47,24 +67,10 @@ bool MatchOperatorInputs(const Operator& op, const Model& model, } // Check if first input is disconnected/connected to an operator - Operator* x = GetOpWithOutput(model, op.inputs[0]); - if ((op_type == OperatorType::kNone) && (x != nullptr)) { - return false; - } - if ((op_type != OperatorType::kNone) && (x == nullptr)) { + if (!ValidateSourceOp(model, op.inputs[0], op_type, connected_op)) { return false; } - // Check that first operator, if connected, is of correct type - if ((x != nullptr) && (x->type != op_type)) { - return false; - } - - // Successfully matched. Optionally return matching input operators. - if (connected_op) { - *connected_op = x; - } - return true; } @@ -81,40 +87,15 @@ bool MatchOperatorInputs(const Operator& op, const Model& model, } // Check if first input is disconnected/connected to an operator - Operator* x = GetOpWithOutput(model, op.inputs[0]); - if ((a_op_type == OperatorType::kNone) && (x != nullptr)) { - return false; - } - if ((a_op_type != OperatorType::kNone) && (x == nullptr)) { - return false; - } - - // Check that first operator, if connected, is of correct type - if ((x != nullptr) && (x->type != a_op_type)) { + if (!ValidateSourceOp(model, op.inputs[0], a_op_type, a_op)) { return false; } // Check if second input is disconnected/connected to an operator - Operator* y = GetOpWithOutput(model, op.inputs[1]); - if ((b_op_type == OperatorType::kNone) && (y != nullptr)) { - return false; - } - if ((b_op_type != OperatorType::kNone) && (y == nullptr)) { + if (!ValidateSourceOp(model, op.inputs[1], b_op_type, b_op)) { return false; } - // Check that second operator, if connected, is of correct type - if ((y != nullptr) && (y->type != b_op_type)) { - return false; - } - - // Successfully matched. Optionally return matching input operators. - if (a_op != nullptr) { - *a_op = x; - } - if (b_op != nullptr) { - *b_op = y; - } return true; } @@ -132,57 +113,20 @@ bool MatchOperatorInputs(const Operator& op, const Model& model, } // Check if first input is disconnected/connected to an operator - Operator* x = GetOpWithOutput(model, op.inputs[0]); - if ((a_op_type == OperatorType::kNone) && (x != nullptr)) { - return false; - } - if ((a_op_type != OperatorType::kNone) && (x == nullptr)) { - return false; - } - - // Check that first operator, if connected, is of correct type - if ((x != nullptr) && (x->type != a_op_type)) { + if (!ValidateSourceOp(model, op.inputs[0], a_op_type, a_op)) { return false; } // Check if second input is disconnected/connected to an operator - Operator* y = GetOpWithOutput(model, op.inputs[1]); - if ((b_op_type == OperatorType::kNone) && (y != nullptr)) { - return false; - } - if ((b_op_type != OperatorType::kNone) && (y == nullptr)) { - return false; - } - - // Check that second operator, if connected, is of correct type - if ((y != nullptr) && (y->type != b_op_type)) { + if (!ValidateSourceOp(model, op.inputs[1], b_op_type, b_op)) { return false; } // Check if third input is disconnected/connected to an operator - Operator* z = GetOpWithOutput(model, op.inputs[2]); - if ((c_op_type == OperatorType::kNone) && (z != nullptr)) { - return false; - } - if ((c_op_type != OperatorType::kNone) && (z == nullptr)) { + if (!ValidateSourceOp(model, op.inputs[2], c_op_type, c_op)) { return false; } - // Check that third operator, if connected, is of correct type - if ((z != nullptr) && (z->type != c_op_type)) { - return false; - } - - // Successfully matched. Optionally return matching input operators. - if (a_op != nullptr) { - *a_op = x; - } - if (b_op != nullptr) { - *b_op = y; - } - if (c_op != nullptr) { - *c_op = z; - } return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc index 30be4ac0aa5e9f639bbf0630e142c2806faa3260..b90a156a0dcfcd77c3e2b47bb0d77e246f2fc625 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc @@ -74,14 +74,30 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) { const auto* relu_neg_input_op = GetOpWithOutput(*model, mul_op->inputs[1]); if (relu_neg_input_op == nullptr || - relu_neg_input_op->type != OperatorType::kNeg || - relu_neg_input_op->fused_activation_function != - FusedActivationFunctionType::kRelu || relu_neg_input_op->inputs.size() != 1) { return false; } - if (relu_input_op->inputs[0] != relu_neg_input_op->inputs[0]) { + const Operator* final_input_op; + if (relu_neg_input_op->type == OperatorType::kNeg && + relu_neg_input_op->fused_activation_function == + FusedActivationFunctionType::kRelu) { + // This detects a Neg op with fused Relu activation function. + final_input_op = relu_neg_input_op; + } else { + // This detects a Neg op followed by a separated Relu op. + const auto* neg_input_op = + GetOpWithOutput(*model, relu_neg_input_op->inputs[0]); + if (neg_input_op == nullptr || neg_input_op->inputs.size() != 1 || + relu_neg_input_op->type != OperatorType::kRelu || + relu_neg_input_op->fused_activation_function != + FusedActivationFunctionType::kNone) { + return false; + } + final_input_op = neg_input_op; + } + + if (relu_input_op->inputs[0] != final_input_op->inputs[0]) { return false; } @@ -112,7 +128,6 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) { // intermediate tensors aren't used by other ops, those will be removed by // other graph transformation rules. model->operators.erase(FindOp(*model, add_op)); - return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc index 45d9f73a1e6416b8f3fe3936c740da637961b7fc..f684de08abf72d05d4408bf6341fa5a3c2ed11cd 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc @@ -85,15 +85,8 @@ bool AddDequantizeOperatorToInput(const string& input_name, const Operator* op, dequantized_input_minmax = input_minmax; auto& input_qparams = input_array.GetOrCreateQuantizationParams(); input_array.data_type = input_array.final_data_type; - if (input_array.data_type == ArrayDataType::kUint8) { - GetQuantizationParamsFromMinMax(input_minmax, - &input_qparams); - } else if (input_array.data_type == ArrayDataType::kInt16) { - GetQuantizationParamsFromMinMax(input_minmax, - &input_qparams); - } else { - LOG(FATAL) << "unhandled data type"; - } + ChooseQuantizationParamsForArrayAndQuantizedDataType( + input_array, input_array.data_type, &input_qparams); transformation->AddMessageF( "Created %s" diff --git a/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc new file mode 100644 index 0000000000000000000000000000000000000000..7f44c65285bdef6ba314b16122fdd550bfa47e6a --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc @@ -0,0 +1,178 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +namespace { + +bool IsTailOfShape(const Shape& tail, const Shape& shape) { + // Return true if 'tail' dimensions are the same as the ending dimensions of + // 'shape'. + + int shape_end = shape.dimensions_count() - 1; + int tail_end = tail.dimensions_count() - 1; + + if (tail_end > shape_end) { + // tail cannot be longer than shape. + return false; + } + + // Walk dimensions back to front and compare + for (int i = 0; i <= tail_end; i++) { + if (shape.dims(shape_end - i) != tail.dims(tail_end - i)) { + return false; + } + } + return true; +} + +} // namespace + +// If a binary operator is doing a broadcast operation from a constant array, +// and the constant array shape is the tail of both the other input shape, and a +// subsequent reshape op's output shape, we can swap their order. Since we +// prefer to have reshape ops after mathematic ops, this can allow for the +// collapsing of some reshapes. The WaveNet model in particular benefits from +// this transformation. +// +// Note we are testing for one particular case of a broader set of possible +// binary-reshape op transformations. This transformation could be generalized. +bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) { + const auto binary_it = model->operators.begin() + op_index; + Operator* binary_op = binary_it->get(); + if (binary_op->type != OperatorType::kAdd && + binary_op->type != OperatorType::kMul && + binary_op->type != OperatorType::kSub && + binary_op->type != OperatorType::kDiv && + binary_op->type != OperatorType::kFloorDiv && + binary_op->type != OperatorType::kFloorMod && + binary_op->type != OperatorType::kMinimum && + binary_op->type != OperatorType::kMaximum && + binary_op->type != OperatorType::kLess && + binary_op->type != OperatorType::kLessEqual && + binary_op->type != OperatorType::kGreater && + binary_op->type != OperatorType::kGreaterEqual) { + return false; + } + + // BINARY OP INPUT CHECKS + CHECK_EQ(binary_op->inputs.size(), 2); + const bool input_is_const[2] = { + IsConstantParameterArray(*model, binary_op->inputs[0]), + IsConstantParameterArray(*model, binary_op->inputs[1]), + }; + if (!input_is_const[0] && !input_is_const[1]) { + // To limit our scope, we require one constant input. Though there's no + // reason this transformation wouldn't work with all variable inputs. + return false; + } + if (input_is_const[0] && input_is_const[1]) { + // Both inputs are constants. Leave this for constants propagation. + return false; + } + const int constant_input_idx = input_is_const[0] ? 0 : 1; + const int variable_input_idx = input_is_const[0] ? 1 : 0; + CHECK(input_is_const[constant_input_idx]); + CHECK(!input_is_const[variable_input_idx]); + + const auto& variable_input_array = + model->GetArray(binary_op->inputs[variable_input_idx]); + if (!variable_input_array.has_shape()) { + AddMessageF( + "Not moving %s because it's non-constant input shape is not resolved.", + LogName(*binary_op)); + return false; + } + if (!IsTailOfShape( + model->GetArray(binary_op->inputs[constant_input_idx]).shape(), + model->GetArray(binary_op->inputs[variable_input_idx]).shape())) { + // Constant array shape must be the latter part of the variable shape. + return false; + } + + // RESHAPE OP CHECKS + auto reshape_it = + FindOpWithOutput(*model, binary_op->inputs[variable_input_idx]); + if (reshape_it == model->operators.end()) { + AddMessageF("Not moving %s because it's variable input is not connected.", + LogName(*binary_op)); + return false; + } + Operator* reshape_op = reshape_it->get(); + if (reshape_op->type != OperatorType::kReshape) { + AddMessageF("Not moving %s because the preceding %s is not a reshape op", + LogName(*binary_op), LogName(*reshape_op)); + return false; + } + const auto& reshape_input_array = model->GetArray(reshape_op->inputs[0]); + if (!reshape_input_array.has_shape()) { + AddMessageF( + "Not moving %s because it's non-constant input shape is not resolved " + "yet", + LogName(*binary_op)); + return false; + } + if (!IsTailOfShape( + model->GetArray(binary_op->inputs[constant_input_idx]).shape(), + model->GetArray(reshape_op->outputs[0]).shape())) { + // Constant array shape must be the latter part of the binary op output + // shape. + return false; + } + + // EXTRA CHECKS ON CONNECTING ARRAY + for (const string& output_array : model->flags.output_arrays()) { + if (binary_op->inputs[variable_input_idx] == output_array) { + AddMessageF( + "Not moving %s because the output of reshape op %s is an output op.", + LogName(*binary_op), LogName(*reshape_op)); + return false; + } + } + int count_ops_consuming_output = + CountOpsWithInput(*model, binary_op->inputs[variable_input_idx]); + DCHECK_GE(count_ops_consuming_output, 1); + if (count_ops_consuming_output > 1) { + AddMessageF( + "Not moving %s because the output of reshape op %s is consumed by " + "another op", + LogName(*binary_op), LogName(*reshape_op)); + return false; + } + + // SWAP ORDER OF BINARY AND RESHAPE OPS + AddMessageF("Moving op %s before reshape op %s", LogName(*binary_op), + LogName(*reshape_op)); + + // Swap op input and outputs + std::iter_swap(reshape_op->inputs.begin(), + binary_op->inputs.begin() + variable_input_idx); + std::iter_swap(reshape_op->outputs.begin(), binary_op->outputs.begin()); + + // Swap operator ordering + std::iter_swap(binary_it, reshape_it); + + // Clear binary output shape so it will be re-propagated + model->GetArray(binary_op->outputs[0]).clear_shape(); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 00ab7cbaa90b399ca08bdfba82991fbd5d2c9f7e..c8310161cb33bcc7137e8b163ea6469698ed2fd7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -62,6 +62,10 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { case OperatorType::kGreaterEqual: case OperatorType::kEqual: case OperatorType::kNotEqual: + case OperatorType::kAny: + case OperatorType::kLogicalAnd: + case OperatorType::kLogicalNot: + case OperatorType::kLogicalOr: // These operators unconditionally produce bool outputs SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool); break; @@ -100,6 +104,13 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type; break; } + case OperatorType::kArgMin: { + // Data type of the ArgMin op is specified. + CHECK_EQ(op->outputs.size(), 1); + auto* argmin_op = static_cast(op); + model->GetArray(op->outputs[0]).data_type = argmin_op->output_data_type; + break; + } case OperatorType::kRange: { auto* range_op = static_cast(op); // Output type of the Range op can be set via an attribute @@ -131,7 +142,8 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { CHECK_EQ(op->inputs.size(), 2); CHECK_EQ(op->outputs.size(), 2); CHECK(model->GetArray(op->inputs[1]).data_type == ArrayDataType::kInt32); - model->GetArray(op->outputs[0]).data_type = model->GetArray(op->inputs[0]).data_type; + model->GetArray(op->outputs[0]).data_type = + model->GetArray(op->inputs[0]).data_type; model->GetArray(op->outputs[1]).data_type = ArrayDataType ::kInt32; break; } @@ -144,8 +156,8 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { return false; } for (int i = 0; i < op->outputs.size(); ++i) { - auto output = op->outputs[i]; - auto data_type = unsupported_op->output_data_types[i]; + const string& output = op->outputs[i]; + const ArrayDataType data_type = unsupported_op->output_data_types[i]; model->GetArray(output).data_type = data_type; } break; @@ -183,6 +195,38 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { SetDataTypeForAllOutputs(model, op, data_type); break; } + case OperatorType::kPack: { + const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type; + for (const auto& input : op->inputs) { + CHECK(data_type == model->GetArray(input).data_type); + } + SetDataTypeForAllOutputs(model, op, data_type); + break; + } + case OperatorType::kOneHot: { + CHECK_EQ(op->inputs.size(), 4); + CHECK_EQ(op->outputs.size(), 1); + const ArrayDataType on_value_type = + model->GetArray(op->inputs[OneHotOperator::ON_VALUE_INPUT]).data_type; + const ArrayDataType off_value_type = + model->GetArray(op->inputs[OneHotOperator::OFF_VALUE_INPUT]) + .data_type; + CHECK(on_value_type == off_value_type); + model->GetArray(op->outputs[0]).data_type = on_value_type; + break; + } + case OperatorType::kCTCBeamSearchDecoder: { + CHECK_EQ(op->inputs.size(), 2); + // All outputs (sparse tensors) are int32s (although tf uses int64s) + // except the last one (log probabilities) is float. + const int output_size = op->outputs.size(); + for (int i = 0; i < output_size - 1; ++i) { + model->GetArray(op->outputs[i]).data_type = ArrayDataType::kInt32; + } + model->GetArray(op->outputs[output_size - 1]).data_type = + ArrayDataType::kFloat; + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc index 50b90e7c2bfddb0382a4d44ad6c90fc7f7701273..cd078ef189e922682098a0ec8dc4743060181aac 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc @@ -25,6 +25,14 @@ limitations under the License. namespace toco { +namespace { + +bool SupportsMinMax(const Array& array) { + return array.data_type == ArrayDataType::kFloat; +} + +} // namespace + // Propagates default min/max values to any operator input/output array that // is missing them. // @@ -39,14 +47,16 @@ bool PropagateDefaultMinMax::Run(Model* model, std::size_t op_index) { for (const auto& input : op->inputs) { auto& input_array = model->GetArray(input); - if (!input_array.minmax && !input_array.buffer) { + if (!input_array.minmax && !input_array.buffer && + SupportsMinMax(input_array)) { did_change |= SetArrayMinMax(input, &input_array); } } for (const auto& output : op->outputs) { auto& output_array = model->GetArray(output); - if (!output_array.minmax && !output_array.buffer) { + if (!output_array.minmax && !output_array.buffer && + SupportsMinMax(output_array)) { did_change |= SetArrayMinMax(output, &output_array); } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc index 0f2592d05f6e01599735c5138c53ba7779ce805d..3ad6b0ec6f7a3c4a9a0ab3964c1198ee757ea4b5 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc @@ -30,15 +30,9 @@ namespace { bool ChangeArrayDataType(GraphTransformation* transformation, Array* array, ArrayDataType new_data_type, const MinMax* new_minmax) { - // The code below assumes kInt16, see - // GetQuantizationParamsFromMinMax - if (new_data_type != ArrayDataType::kInt16) { - return false; - } - - bool changed = false; // Ensure the array ends up in the new type (if it hasn't yet been quantized). - if ((array->final_data_type != new_data_type)) { + bool changed = false; + if (array->final_data_type != new_data_type) { array->final_data_type = new_data_type; changed = true; } @@ -72,12 +66,10 @@ bool ChangeArrayDataType(GraphTransformation* transformation, Array* array, "Rescaling min/max from %g,%g (%s) to %g,%g (%s)", array_minmax.min, array_minmax.max, ArrayDataTypeName(array->data_type), min, max, ArrayDataTypeName(new_data_type)); - array_minmax.min = min; array_minmax.max = max; - GetQuantizationParamsFromMinMax( - array_minmax, array->quantization_params.get()); - + ChooseQuantizationParamsForArrayAndQuantizedDataType( + *array, new_data_type, array->quantization_params.get()); // Directly change the type as the array was already quantized. array->data_type = new_data_type; changed = true; @@ -95,6 +87,7 @@ bool ChangeArrayDataType(GraphTransformation* transformation, Array* array, changed = true; } } + return changed; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 82b3ab96fe07a7385e678cc9ccfd68ca1d7ce330..91e290439ae4bfd491c8201b02b161fe2caf2f8d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -437,6 +437,7 @@ void ProcessTensorFlowReshapeOperator(Model* model, product_non_wildcard_dims *= shape_data[i]; } } + const int input_flat_size = RequiredBufferSizeForShape(input_shape); if (has_wildcard) { CHECK_GE(input_flat_size, product_non_wildcard_dims) @@ -445,6 +446,12 @@ void ProcessTensorFlowReshapeOperator(Model* model, << op->outputs[0] << "\". Are your input shapes correct?"; shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims; } + + if (shape_data.size() == 1 && shape_data[0] == 0) { + // We have reshaped a scalar, so preserve as a scalar. + shape_data.clear(); + } + auto& output_shape = *output_array.mutable_shape(); *output_shape.mutable_dims() = shape_data; CHECK_EQ(input_flat_size, RequiredBufferSizeForShape(output_shape)) @@ -522,12 +529,14 @@ void ProcessAddNOperator(Model* model, Operator* op) { bool KeepDims(const Operator& op) { switch (op.type) { - case OperatorType::kMin: // Reduction Min + case OperatorType::kReduceMin: // Reduction Min return static_cast(op).keep_dims; - case OperatorType::kMax: // Reduction Max + case OperatorType::kReduceMax: // Reduction Max return static_cast(op).keep_dims; case OperatorType::kSum: return static_cast(op).keep_dims; + case OperatorType::kReduceProd: + return static_cast(op).keep_dims; case OperatorType::kMean: return static_cast(op).keep_dims; default: @@ -1034,20 +1043,28 @@ void ProcessGatherOperator(Model* model, GatherOperator* op) { return; } + // Yield until the axis has been resolved. + if (!op->axis) { + return; + } + int axis = op->axis.value(); + const auto& input_shape = input_array.shape(); const auto& indices_shape = indices_array.shape(); QCHECK_GE(input_shape.dimensions_count(), 1); op->input_rank = input_shape.dimensions_count(); + QCHECK_LT(axis, op->input_rank); - // We only support 1-D indices. - QCHECK_EQ(indices_shape.dimensions_count(), 1); - - // Copy the input dimensions to the output except for dimension 0, + // Copy the input dimensions to the output except for the axis dimensions // where the dimension of indices_shape is used. - // TODO(mgubin): if axis != 0 this is not true, change when it's supported. auto output_dims = output_array.mutable_shape()->mutable_dims(); - output_dims->push_back(indices_shape.dims(0)); - for (int dim = 1; dim < input_shape.dimensions_count(); dim++) { + for (int dim = 0; dim < axis; ++dim) { + output_dims->push_back(input_shape.dims(dim)); + } + for (int dim = 0; dim < indices_shape.dimensions_count(); ++dim) { + output_dims->push_back(indices_shape.dims(dim)); + } + for (int dim = axis + 1; dim < input_shape.dimensions_count(); ++dim) { output_dims->push_back(input_shape.dims(dim)); } } @@ -1065,27 +1082,23 @@ void ProcessTopkV2Operator(Model* model, TopKV2Operator* op) { } // Yield until input dims have been resolved. - if (!input_values.has_shape()) { + if (!input_values.has_shape() || !input_k.has_shape()) { return; } - const auto& input_values_shape = input_values.shape(); - auto output_indexes_dims = output_indexes.mutable_shape()->mutable_dims(); - auto output_values_dims = output_values.mutable_shape()->mutable_dims(); - for (int dim = 0; dim < input_values_shape.dimensions_count() - 1; dim++) { - output_indexes_dims->push_back(input_values_shape.dims(dim)); - output_values_dims->push_back(input_values_shape.dims(dim)); - } // If the value is initialized, we can specify the last dimension, otherwise // unknown. if (input_k.buffer) { + const auto& input_values_shape = input_values.shape(); + auto output_indexes_dims = output_indexes.mutable_shape()->mutable_dims(); + auto output_values_dims = output_values.mutable_shape()->mutable_dims(); + for (int dim = 0; dim < input_values_shape.dimensions_count() - 1; dim++) { + output_indexes_dims->push_back(input_values_shape.dims(dim)); + output_values_dims->push_back(input_values_shape.dims(dim)); + } const int32_t k_value = input_k.GetBuffer().data[0]; output_indexes_dims->push_back(k_value); output_values_dims->push_back(k_value); - - } else { - output_indexes_dims->push_back(0); - output_values_dims->push_back(0); } } @@ -1193,7 +1206,7 @@ void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) { output_shape->ReplaceDims({input_array.shape().dimensions_count()}); } -void ProcessStackOperator(Model* model, StackOperator* op) { +void ProcessPackOperator(Model* model, PackOperator* op) { CHECK_GE(op->inputs.size(), 1); CHECK_EQ(op->outputs.size(), 1); auto& output_array = model->GetArray(op->outputs[0]); @@ -1202,7 +1215,7 @@ void ProcessStackOperator(Model* model, StackOperator* op) { return; } - std::unique_ptr stacked_shape; + std::unique_ptr packed_shape; for (const auto& input : op->inputs) { const auto& input_array = model->GetArray(input); if (!input_array.has_shape()) { @@ -1211,23 +1224,23 @@ void ProcessStackOperator(Model* model, StackOperator* op) { } Shape shape = input_array.shape(); - if (!stacked_shape) { - stacked_shape.reset(new Shape(shape)); + if (!packed_shape) { + packed_shape.reset(new Shape(shape)); } else { - CHECK(*stacked_shape == shape) << "All input arrays to Stack operators " - "must have the same shape. Input \"" - << input << "\" is different."; + CHECK(*packed_shape == shape) << "All input arrays to Pack operators " + "must have the same shape. Input \"" + << input << "\" is different."; } } int axis = op->axis; if (axis < 0) { // Handle negative axis - axis += stacked_shape->dims().size() + 1; + axis += packed_shape->dims().size() + 1; } - stacked_shape->mutable_dims()->insert( - stacked_shape->mutable_dims()->begin() + axis, op->inputs.size()); - output_array.copy_shape(*stacked_shape); + packed_shape->mutable_dims()->insert( + packed_shape->mutable_dims()->begin() + axis, op->inputs.size()); + output_array.copy_shape(*packed_shape); } void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { @@ -1407,7 +1420,8 @@ void ProcessTransposeOperator(Model* model, TransposeOperator* op) { } } -void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) { +template +void ProcessArgMinMaxOperator(Model* model, Op* op) { CHECK_EQ(op->inputs.size(), 2); const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. @@ -1501,6 +1515,120 @@ void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) { } } +void ProcessAnyOperator(Model* model, AnyOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.has_shape()) { + // We have already run. + return; + } + + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.has_shape()) { + // Yield until input dims have been resolved. + return; + } + const auto& input_shape = input_array.shape(); + + auto& reduction_indices_array = model->GetArray(op->inputs[1]); + if (!reduction_indices_array.has_shape()) { + // Yield until reduction indices shape been resolved. + return; + } + if (!reduction_indices_array.buffer) { + // Yield until the reduction indices are constant. + return; + } + CHECK(reduction_indices_array.data_type == ArrayDataType::kInt32) + << "Any reduction input must be int32"; + + int input_rank = input_shape.dimensions_count(); + std::set true_indices; + const auto& reduction_indices = + reduction_indices_array.GetBuffer().data; + for (int i = 0; i < reduction_indices.size(); ++i) { + const int32 reduction_index = reduction_indices[i]; + if (reduction_index < -input_rank || reduction_index >= input_rank) { + CHECK(false) << "Invalid reduction dimension " << reduction_index + << " for input with " << input_rank << " dimensions"; + } + int32 wrapped_index = reduction_index; + if (wrapped_index < 0) { + wrapped_index += input_rank; + } + true_indices.insert(wrapped_index); + } + + auto* mutable_dims = output_array.mutable_shape()->mutable_dims(); + mutable_dims->clear(); + for (int i = 0; i < input_rank; ++i) { + if (true_indices.count(i) > 0) { + if (op->keep_dims) { + mutable_dims->emplace_back(1); + } + } else { + mutable_dims->emplace_back(input_shape.dims(i)); + } + } +} + +void ProcessOneHotOperator(Model* model, OneHotOperator* op) { + CHECK_EQ(op->inputs.size(), 4); + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.has_shape()) { + // Shape already propagated + return; + } + + // Yield until indices dims have been resolved. + const auto& indices_array = + model->GetArray(op->inputs[OneHotOperator::INDICES_INPUT]); + if (!indices_array.has_shape()) { + return; + } + + // Yield until depth is constant and dims have been resolved. + if (!IsConstantParameterArray(*model, + op->inputs[OneHotOperator::DEPTH_INPUT])) { + return; + } + const auto& depth_array = + model->GetArray(op->inputs[OneHotOperator::DEPTH_INPUT]); + if (!depth_array.has_shape()) { + return; + } + + CHECK(depth_array.data_type == ArrayDataType::kInt32) + << "Depth array must be int32."; + CHECK_EQ(RequiredBufferSizeForShape(depth_array.shape()), 1) + << "Depth array must be scalar."; + + const int depth = depth_array.GetBuffer().data[0]; + CHECK_GE(depth, 0) << "Depth must be non-negative."; + + const int indices_dims = indices_array.shape().dimensions_count(); + const int output_dims = indices_dims + 1; + const int axis = op->axis == -1 ? indices_dims : op->axis; + CHECK_GE(axis, 0) << "Resolved axis must be non-negative."; + + auto* mutable_dims = output_array.mutable_shape()->mutable_dims(); + mutable_dims->resize(output_dims); + for (int i = 0; i < output_dims; ++i) { + int dim = 0; + if (i < axis) { + dim = indices_array.shape().dims(i); + } else if (i == axis) { + dim = depth; + } else { + dim = indices_array.shape().dims(i - 1); + } + (*mutable_dims)[i] = dim; + } +} + } // namespace bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { @@ -1539,6 +1667,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kFloor: case OperatorType::kExp: case OperatorType::kSin: + case OperatorType::kLogicalAnd: + case OperatorType::kLogicalNot: + case OperatorType::kLogicalOr: ProcessSimpleOperator(model, op, 0); break; case OperatorType::kGather: @@ -1607,9 +1738,10 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kL2Pool: ProcessL2PoolOperator(model, static_cast(op)); break; - case OperatorType::kMin: // Reduction Min - case OperatorType::kMax: // Reduction Max + case OperatorType::kReduceMin: // Reduction Min + case OperatorType::kReduceMax: // Reduction Max case OperatorType::kSum: + case OperatorType::kReduceProd: case OperatorType::kMean: ProcessTensorFlowReductionOperator(model, op); break; @@ -1658,8 +1790,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kShape: ProcessShapeOperator(model, static_cast(op)); break; - case OperatorType::kStack: - ProcessStackOperator(model, static_cast(op)); + case OperatorType::kPack: + ProcessPackOperator(model, static_cast(op)); break; case OperatorType::kReorderAxes: ProcessReorderAxesOperator(model, static_cast(op)); @@ -1699,10 +1831,26 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { static_cast(op)); break; case OperatorType::kArgMax: - ProcessArgMaxOperator(model, static_cast(op)); + ProcessArgMinMaxOperator( + model, static_cast(op)); break; - case OperatorType::kUnsupported: + case OperatorType::kArgMin: + ProcessArgMinMaxOperator( + model, static_cast(op)); break; + case OperatorType::kUnsupported: { + const auto* unsupported_op = + static_cast(op); + // Attribute can be not specified, ignore it. + if (unsupported_op->output_shapes.size() < op->outputs.size()) { + return false; + } + for (int i = 0; i < op->outputs.size(); ++i) { + const string& output = op->outputs[i]; + model->GetArray(output).copy_shape(unsupported_op->output_shapes.at(i)); + } + break; + } case OperatorType::kSvdf: ProcessSvdfOperator(model, static_cast(op)); break; @@ -1726,6 +1874,12 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kTile: ProcessTileOperator(model, static_cast(op)); break; + case OperatorType::kAny: + ProcessAnyOperator(model, static_cast(op)); + break; + case OperatorType::kOneHot: + ProcessOneHotOperator(model, static_cast(op)); + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc index d74cad9a626b3a472e2740d6bdaaaf7aab5bd484..44733391f5a1d9ebf9a24f4f31b425a35354e1fc 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc @@ -74,46 +74,54 @@ ArrayDataType GetQuantizedDataType(const Array& array, } } -void GetQuantizationParams(ArrayDataType data_type, const MinMax& minmax, - QuantizationParams* quantization_params) { - switch (data_type) { +template +void ChooseQuantizationParamsForArrayAndQuantizedDataType( + const Array& array, QuantizationParams* quantization_params) { + *quantization_params = ::tflite::ChooseQuantizationParams>( + array.minmax->min, array.minmax->max, array.narrow_range); +} + +void ChooseQuantizationParamsForArrayAndQuantizedDataType( + const Array& array, ArrayDataType quantized_data_type, + QuantizationParams* quantization_params) { + switch (quantized_data_type) { case ArrayDataType::kInt8: - GetQuantizationParamsFromMinMax( - minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType< + ArrayDataType::kInt8>(array, quantization_params); break; case ArrayDataType::kUint8: - GetQuantizationParamsFromMinMax( - minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType< + ArrayDataType::kUint8>(array, quantization_params); break; case ArrayDataType::kInt16: - GetQuantizationParamsFromMinMax( - minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType< + ArrayDataType::kInt16>(array, quantization_params); break; case ArrayDataType::kUint16: - GetQuantizationParamsFromMinMax( - minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType< + ArrayDataType::kUint16>(array, quantization_params); break; case ArrayDataType::kInt32: - GetQuantizationParamsFromMinMax( - minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType< + ArrayDataType::kInt32>(array, quantization_params); break; case ArrayDataType::kUint32: - GetQuantizationParamsFromMinMax( - minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType< + ArrayDataType::kUint32>(array, quantization_params); break; case ArrayDataType::kInt64: - GetQuantizationParamsFromMinMax( - minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType< + ArrayDataType::kInt64>(array, quantization_params); break; case ArrayDataType::kUint64: - GetQuantizationParamsFromMinMax( - minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType< + ArrayDataType::kUint64>(array, quantization_params); break; case ArrayDataType::kFloat: case ArrayDataType::kNone: default: LOG(FATAL) << "Unhandled final quantization type " - << static_cast(data_type); + << static_cast(quantized_data_type); } } @@ -121,8 +129,8 @@ namespace { template std::unique_ptr QuantizeBuffer( - const GenericBuffer& buffer, - const QuantizationParams& quantization_params) { + const Array& array, const QuantizationParams& quantization_params) { + const GenericBuffer& buffer = *array.buffer; const auto inverse_scale = 1. / quantization_params.scale; CHECK(buffer.type == ArrayDataType::kFloat); const auto& float_buffer = @@ -140,8 +148,15 @@ std::unique_ptr QuantizeBuffer( } else { scaled_val = quantization_params.zero_point + inverse_scale * src_val; } - quantized_buffer->data[i] = - tflite::SafeCast>(std::round(scaled_val)); + auto integer_val = tflite::SafeCast>(std::round(scaled_val)); + // In addition to its effect on the choice of quantization params upstream + // of here, narrow_range also means nudge the min quantized value by +1, + // so e.g. uint8 values get constrained to [1, 255]. + if (integer_val == std::numeric_limits>::min() && + array.narrow_range) { + integer_val++; + } + quantized_buffer->data[i] = integer_val; } return std::unique_ptr(quantized_buffer); } @@ -155,7 +170,7 @@ void QuantizeArray(GraphTransformation* transformation, Model* model, CHECK(!array.quantization_params); array.GetOrCreateQuantizationParams() = quantization_params; if (array.buffer) { - array.buffer = QuantizeBuffer(*array.buffer, quantization_params); + array.buffer = QuantizeBuffer(array, quantization_params); } array.data_type = A; array.final_data_type = A; @@ -210,8 +225,8 @@ bool IsArrayQuantizedRangeSubset(GraphTransformation* transformation, } else { // Work around cases where we are asking for this prior to the Quantize // transformation having added the quantization_params. - GetQuantizationParams(quantized_data_type, *array.minmax, - &quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType( + array, quantized_data_type, &quantization_params); transformation->AddMessageF( "No quantization params - infering from data type %s with minmax " "%g,%g as zero_point=%g, scale=%g", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h index 79a2ce7e50887b4608b278471da0e5e63b5673e3..cf093c6f17b45839156dae0d06ca2fc7e5e2f3c6 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h @@ -38,21 +38,11 @@ bool GetQuantizedDataTypeNumericalRange(ArrayDataType data_type, ArrayDataType GetQuantizedDataType(const Array& array, ArrayDataType default_type); -// Returns the quantization params for the array with the given data type and -// minmax. -void GetQuantizationParams(ArrayDataType data_type, const MinMax& minmax, - QuantizationParams* quantization_params); - -// Returns the quantization params for the data type and minmax values. -template -void GetQuantizationParamsFromMinMax(const MinMax& minmax, - QuantizationParams* quantization_params) { - using Integer = DataType; - const double rmin = minmax.min; - const double rmax = minmax.max; - *quantization_params = - ::tflite::ChooseQuantizationParams(rmin, rmax); -} +// Chooses the quantization params for a given array and a given target +// quantized data type (which may not be the array's current data type). +void ChooseQuantizationParamsForArrayAndQuantizedDataType( + const Array& array, ArrayDataType quantized_data_type, + QuantizationParams* quantization_params); // Quantizes an array by setting its data type and (if constant) quantizing // all values in the array. diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index 38699a62b5b18e9fbb24060f3eb4ac373f175c09..8d22ae2eb1356b8c9c9430c517acddfc971b9f57 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -50,6 +50,7 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kSqueeze || type == OperatorType::kPad || type == OperatorType::kPadV2 || type == OperatorType::kReshape || type == OperatorType::kTanh || type == OperatorType::kMul || + type == OperatorType::kBatchToSpaceND || type == OperatorType::kSum || type == OperatorType::kSpaceToBatchND || type == OperatorType::kSpaceToDepth || type == OperatorType::kStridedSlice || @@ -59,9 +60,21 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kGreater || type == OperatorType::kGreaterEqual || type == OperatorType::kLess || type == OperatorType::kLessEqual || type == OperatorType::kSelect || - type == OperatorType::kArgMax; + type == OperatorType::kArgMax || type == OperatorType::kRelu || + type == OperatorType::kRelu1 || type == OperatorType::kRelu6 || + type == OperatorType::kShape || type == OperatorType::kExpandDims; } +// The quantized op allows output arrays of type float using +// the attribute support_output_type_float_in_quantized_op +bool SupportOutputTypeFloatInQuantizedOp(const Operator& op) { + auto type = op.type; + if (type == OperatorType::kUnsupported) { + auto* unsupported = static_cast(&op); + return unsupported->support_output_type_float_in_quantized_op; + } + return false; +} const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) { auto& array = model->GetArray(array_name); // Normally we should have a MinMax recorded on this Array, @@ -211,13 +224,15 @@ bool ChooseQuantizationForOperatorInput( if (op.type == OperatorType::kLstmCell) { if (input_index == LstmCellOperator::PREV_STATE_INPUT) { *quantized_data_type = ArrayDataType::kInt16; - GetQuantizationParams(*quantized_data_type, minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType( + array, *quantized_data_type, quantization_params); return true; } } *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8); - GetQuantizationParams(*quantized_data_type, minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType( + array, *quantized_data_type, quantization_params); transformation->AddMessageF( "For input array %s with min=%g, max=%g, chose to quantize as %s (f=%s) " "with zero_point=%d, scale=%g", @@ -325,12 +340,13 @@ bool ChooseQuantizationForOperatorOutput( output, OperatorTypeName(op.type)); return true; } - if ((op.type == OperatorType::kDepthToSpace) || - (op.type == OperatorType::kSpaceToDepth) || - (op.type == OperatorType::kReshape) || - (op.type == OperatorType::kSplit) || - (op.type == OperatorType::kConcatenation && - model->flags.change_concat_input_ranges())) { + if ((op.type == OperatorType::kConcatenation && + model->flags.change_concat_input_ranges()) || + op.type == OperatorType::kDepthToSpace || + op.type == OperatorType::kSpaceToDepth || + op.type == OperatorType::kReshape || op.type == OperatorType::kSplit || + op.type == OperatorType::kRelu || op.type == OperatorType::kRelu1 || + op.type == OperatorType::kRelu6) { int data_input_index = 0; if (op.type == OperatorType::kSplit) { data_input_index = 1; @@ -356,12 +372,14 @@ bool ChooseQuantizationForOperatorOutput( if (output_index == LstmCellOperator::STATE_OUTPUT || output_index == LstmCellOperator::ACTIV_TEMP) { *quantized_data_type = ArrayDataType::kInt16; - GetQuantizationParams(*quantized_data_type, minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType( + array, *quantized_data_type, quantization_params); return true; } } *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8); - GetQuantizationParams(*quantized_data_type, minmax, quantization_params); + ChooseQuantizationParamsForArrayAndQuantizedDataType( + array, *quantized_data_type, quantization_params); transformation->AddMessageF( "For output array %s with min=%g, max=%g" ", chose to quantize as %s with zero_point=%d" @@ -577,61 +595,67 @@ bool Quantize::Run(Model* model, std::size_t op_index) { } // Quantize outputs, add Dequantize ops as needed on the outputs side - for (std::size_t output_index = 0; output_index < op.outputs.size(); - output_index++) { - ArrayDataType quantized_data_type; - QuantizationParams quantization_params; - if (ChooseQuantizationForOperatorOutput(this, model, op, output_index, - &quantized_data_type, - &quantization_params)) { - changed = true; - const auto& output = op.outputs[output_index]; - auto& output_array = model->GetArray(output); - - // Fix up the min/max information on the output array to match the chosen - // quantization parameters. - CHECK(output_array.minmax) - << "Output array named " << output << " lacks minmax"; - auto& output_minmax = output_array.GetMinMax(); - FixMinMaxPostQuantization(this, quantized_data_type, quantization_params, - &output_minmax); - - QuantizeArray(this, model, output, quantized_data_type, - quantization_params); - - const auto& dequantized_output = - AvailableArrayName(*model, output + "_dequantized"); - auto& dequantized_output_array = - model->GetOrCreateArray(dequantized_output); - dequantized_output_array.data_type = ArrayDataType::kFloat; - dequantized_output_array.final_data_type = output_array.data_type; - auto& dequantized_output_minmax = - dequantized_output_array.GetOrCreateMinMax(); - dequantized_output_minmax.min = output_minmax.min; - dequantized_output_minmax.max = output_minmax.max; - for (const auto& other_op : model->operators) { - for (auto& other_op_input : other_op->inputs) { - if (other_op_input == output) { - other_op_input = dequantized_output; + if (SupportOutputTypeFloatInQuantizedOp(op)) { + LOG(WARNING) + << HelpfulOperatorTypeName(op) << " is a quantized op" + << "but it has a model flag that sets the output arrays to float."; + } else { + for (std::size_t output_index = 0; output_index < op.outputs.size(); + output_index++) { + QuantizationParams quantization_params; + ArrayDataType quantized_data_type; + if (ChooseQuantizationForOperatorOutput(this, model, op, output_index, + &quantized_data_type, + &quantization_params)) { + changed = true; + const auto& output = op.outputs[output_index]; + auto& output_array = model->GetArray(output); + + // Fix up the min/max information on the output array to match the + // chosen quantization parameters. + CHECK(output_array.minmax) + << "Output array named " << output << " lacks minmax"; + auto& output_minmax = output_array.GetMinMax(); + FixMinMaxPostQuantization(this, quantized_data_type, + quantization_params, &output_minmax); + + QuantizeArray(this, model, output, quantized_data_type, + quantization_params); + + const auto& dequantized_output = + AvailableArrayName(*model, output + "_dequantized"); + auto& dequantized_output_array = + model->GetOrCreateArray(dequantized_output); + dequantized_output_array.data_type = ArrayDataType::kFloat; + dequantized_output_array.final_data_type = output_array.data_type; + auto& dequantized_output_minmax = + dequantized_output_array.GetOrCreateMinMax(); + dequantized_output_minmax.min = output_minmax.min; + dequantized_output_minmax.max = output_minmax.max; + for (const auto& other_op : model->operators) { + for (auto& other_op_input : other_op->inputs) { + if (other_op_input == output) { + other_op_input = dequantized_output; + } } } - } - auto* dequantize_op = new DequantizeOperator; - dequantize_op->inputs = {output}; - dequantize_op->outputs = {dequantized_output}; - for (int i = 0; i < model->flags.output_arrays_size(); i++) { - if (model->flags.output_arrays(i) == output) { - // TODO(b/78013785): never rename output arrays. - AddMessageF( - "Renaming output array %d after inserting dequant op %s: %s -> " - "%s", - i, LogName(*dequantize_op), model->flags.output_arrays(i), - dequantized_output); - model->flags.set_output_arrays(i, dequantized_output); + auto* dequantize_op = new DequantizeOperator; + dequantize_op->inputs = {output}; + dequantize_op->outputs = {dequantized_output}; + for (int i = 0; i < model->flags.output_arrays_size(); i++) { + if (model->flags.output_arrays(i) == output) { + // TODO(b/78013785): never rename output arrays. + AddMessageF( + "Renaming output array %d after inserting dequant op %s: %s -> " + "%s", + i, LogName(*dequantize_op), model->flags.output_arrays(i), + dequantized_output); + model->flags.set_output_arrays(i, dequantized_output); + } } + const auto op_it = FindOp(*model, &op); + model->operators.emplace(op_it + 1, dequantize_op); } - const auto op_it = FindOp(*model, &op); - model->operators.emplace(op_it + 1, dequantize_op); } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc index 88ea0945e7dd15ba325d34ea3fdbf34ff7d91381..7a8515f6d12f96d464ea0764907f9cc2a487d3e7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc @@ -36,10 +36,8 @@ void GetQuantizationParamsFromArray(const Array& array, const std::vector& float_vals = array.GetBuffer().data; auto minmax = std::minmax_element(float_vals.begin(), float_vals.end()); - MinMax toco_minmax; - toco_minmax.min = *minmax.first; - toco_minmax.max = *minmax.second; - GetQuantizationParams(ArrayDataType::kUint8, toco_minmax, params); + *params = tflite::ChooseQuantizationParams( + *minmax.first, *minmax.second, array.narrow_range); } } // namespace diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc new file mode 100644 index 0000000000000000000000000000000000000000..5b41c49bfaff245d599d26989e4ed3f9b0d582cf --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool ApplyAttrsToArray(GraphTransformation* transformation, Model* model, + const FakeQuantOperator& fq_op, + const string& array_name) { + bool changed = false; + auto& annotated_array = model->GetArray(array_name); + if (!annotated_array.minmax) { + const MinMax& minmax = *fq_op.minmax; + annotated_array.GetOrCreateMinMax() = minmax; + transformation->AddMessageF( + "Read min/max annotation for array %s: min=%g, max=%g", array_name, + minmax.min, minmax.max); + changed = true; + } + if (fq_op.narrow_range && !annotated_array.narrow_range) { + annotated_array.narrow_range = true; + transformation->AddMessageF("Read narrow_range annotation for array %s", + array_name); + changed = true; + } + return changed; +} + +} // end namespace + +bool ReadArrayMinmaxAndNarrowRangeFromFakeQuant::Run(Model* model, + std::size_t op_index) { + const auto fakequant_it = model->operators.begin() + op_index; + auto* fakequant_base_op = fakequant_it->get(); + if (fakequant_base_op->type != OperatorType::kFakeQuant) { + return false; + } + auto* fq_op = static_cast(fakequant_base_op); + + if (!fq_op->minmax) { + // Need to be resolved first by ResolveFakeQuantArgsFromVars. + return false; + } + + // At this point, this FakeQuantOperator should have a MinMax + // attached to it, and should only have 1 input (it should not have + // 2nd and 3rd input arrays giving min and max anymore). + CHECK(fq_op->minmax); + CHECK_EQ(1, fq_op->inputs.size()); + + return ApplyAttrsToArray(this, model, *fq_op, fq_op->inputs[0]) || + ApplyAttrsToArray(this, model, *fq_op, fq_op->outputs[0]); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc deleted file mode 100644 index bdcca5b7caf61a62203debaa32c4d7a9b2eb43fa..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc +++ /dev/null @@ -1,112 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include -#include -#include -#include -#include - -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" - -namespace toco { - -namespace { - -bool ApplyMinMaxToArray(GraphTransformation* transformation, Model* model, - const MinMax& minmax, const string& array_name) { - auto& annotated_array = model->GetArray(array_name); - if (annotated_array.minmax) { - return false; - } - annotated_array.GetOrCreateMinMax() = minmax; - transformation->AddMessageF( - "Read min/max annotation for array %s: min=%g, max=%g", array_name, - minmax.min, minmax.max); - return true; -} - -} // end namespace - -bool ReadFakeQuantMinMax::Run(Model* model, std::size_t op_index) { - const auto fakequant_it = model->operators.begin() + op_index; - auto* fakequant_base_op = fakequant_it->get(); - if (fakequant_base_op->type != OperatorType::kFakeQuant) { - return false; - } - auto* fakequant_op = static_cast(fakequant_base_op); - - bool changed = false; - - if (!fakequant_op->minmax) { - CHECK_EQ(fakequant_op->inputs.size(), 3); - // We need to yield until the min and max parameters have been - // resolved to constant arrays. - for (int i = 1; i <= 2; i++) { - if (!IsConstantParameterArray(*model, fakequant_op->inputs[1])) { - return false; - } - } - - // Obtain the final min/max values - const auto& min_array = model->GetArray(fakequant_op->inputs[1]); - const auto& max_array = model->GetArray(fakequant_op->inputs[2]); - CHECK_EQ(RequiredBufferSizeForShape(min_array.shape()), 1); - CHECK_EQ(RequiredBufferSizeForShape(max_array.shape()), 1); - fakequant_op->minmax.reset(new MinMax); - MinMax& minmax = *fakequant_op->minmax; - minmax.min = min_array.GetBuffer().data[0]; - minmax.max = max_array.GetBuffer().data[0]; - // We always want [min, max] to contain 0. - if (minmax.min > 0 || minmax.max < 0) { - LOG(ERROR) << "For " << LogName(*fakequant_op) << " the MinMax range " - << "[" << minmax.min << ", " << minmax.max - << "] does not contain 0. " - << "Proceeding by tweaking it to contain 0, which will result " - "in poor accuracy."; - } - minmax.min = std::min(minmax.min, 0.); - minmax.max = std::max(minmax.max, 0.); - - // We won't use the input arrays that provided these min and max - // values, anymore. Delete them unless they are used by something - // else. - for (int i = 1; i <= 2; i++) { - if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) { - model->EraseArray(fakequant_op->inputs[i]); - } - } - fakequant_op->inputs.resize(1); - changed = true; - } - - // At this point, this FakeQuantOperator should have a MinMax - // attached to it, and should only have 1 input (it should not have - // 2nd and 3rd input arrays giving min and max anymore). - CHECK(fakequant_op->minmax); - CHECK_EQ(1, fakequant_op->inputs.size()); - - const MinMax& minmax = *fakequant_op->minmax; - - // Record the MinMax info on the input and output arrays - changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->inputs[0]); - changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->outputs[0]); - - return changed; -} - -} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc index 404f27e067402474484d3ee8e23595fb9f93a6c9..5295eeccecb05b05232922f4b5e4ef75a2b04672 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc @@ -59,6 +59,15 @@ bool IsReshapeTrivial(const Model& model, const Operator& op, if (CountOpsWithInput(model, op.outputs[0]) == 1) { const auto* next_op = GetOpWithInput(model, op.outputs[0]); if (next_op->type == OperatorType::kReshape) { + if (!IsDiscardableArray(model, next_op->outputs[0])) { + // If the |next_op| output is used as a model output we need to preserve + // its shape. + transformation->AddMessageF( + "%s cannot be merged into following reshape %s as it is " + "non-discardable and must keep the specified shape", + LogName(op), LogName(*next_op)); + return false; + } transformation->AddMessageF( "%s is trivial because its output is only consumed by another " "Reshape op %s", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc index efb7bb218421dd045e3e8e2a38b9c70989f222e1..d395d7a6a0862d93fd4f52bb8b8d8d3ea7f8dc1e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc @@ -25,6 +25,40 @@ limitations under the License. namespace toco { +template +void GetBoundsForQuantizedDataType(float* min, float* max) { + using limits = std::numeric_limits>; + *min = limits::min(); + *max = limits::max(); +} + +void GetBoundsForQuantizedDataType(ArrayDataType quantized_data_type, + float* min, float* max) { + // It is important for matching accuracy between TF training and TFLite + // inference, that the min and max values are float to match TF's + // FakeQuantWithMinMaxVarsFunctor. + switch (quantized_data_type) { + case ArrayDataType::kUint8: + return GetBoundsForQuantizedDataType(min, max); + case ArrayDataType::kInt8: + return GetBoundsForQuantizedDataType(min, max); + case ArrayDataType::kUint16: + return GetBoundsForQuantizedDataType(min, max); + case ArrayDataType::kInt16: + return GetBoundsForQuantizedDataType(min, max); + case ArrayDataType::kUint32: + return GetBoundsForQuantizedDataType(min, max); + case ArrayDataType::kInt32: + return GetBoundsForQuantizedDataType(min, max); + case ArrayDataType::kUint64: + return GetBoundsForQuantizedDataType(min, max); + case ArrayDataType::kInt64: + return GetBoundsForQuantizedDataType(min, max); + default: + LOG(FATAL) << "unhandled quantized data type"; + } +} + bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) { const auto fakequant_it = model->operators.begin() + op_index; const auto* fakequant_base_op = fakequant_it->get(); @@ -76,18 +110,25 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) { const int size = input_buffer.data.size(); output_buffer.data.resize(size); QuantizationParams qparams; - GetQuantizationParamsFromMinMax(*fakequant_op->minmax, - &qparams); - for (int i = 0; i < size; i++) { - const double src_val = input_buffer.data[i]; - const double unclamped_quantized_val = - std::round(qparams.zero_point + src_val / qparams.scale); - const double quantized_val = - std::min(255., std::max(0., unclamped_quantized_val)); - const double dst_val = qparams.scale * (quantized_val - qparams.zero_point); - output_buffer.data[i] = dst_val; + ChooseQuantizationParamsForArrayAndQuantizedDataType( + output_array, quantized_data_type, &qparams); + float quantized_min, quantized_max; + GetBoundsForQuantizedDataType(quantized_data_type, &quantized_min, + &quantized_max); + if (fakequant_op->narrow_range) { + quantized_min++; } + // It is important for matching accuracy between TF training and TFLite + // inference, that the following variables are float to match TF's + // FakeQuantWithMinMaxVarsFunctor. + const float scale = qparams.scale; + const float nudged_min = (quantized_min - qparams.zero_point) * scale; + const float nudged_max = (quantized_max - qparams.zero_point) * scale; + tflite::FakeQuantizeArray(scale, nudged_min, nudged_max, + input_buffer.data.data(), output_buffer.data.data(), + size); + if (IsDiscardableArray(*model, fakequant_op->inputs[0]) && CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) { model->EraseArray(fakequant_op->inputs[0]); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc index debe298a5a93034bcb928d7384b5ec1fc7439e47..36d7dad0ce9de81ec132ef992538b6022916bfbd 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc @@ -69,7 +69,7 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) { } const auto* op = static_cast(base_op); - CHECK_EQ(op->inputs.size(), 2); + CHECK_GE(op->inputs.size(), 2); CHECK_EQ(op->outputs.size(), 1); auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { @@ -81,10 +81,14 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) { return false; } - // Only handling axis=0 for now. - if (op->axis != 0) { + if (!op->axis) { + // Yield until axis has been set by ResolveGatherAttributes. + return false; + } + if (op->axis.value() != 0) { + // Only handling axis=0 for now. AddMessageF("%s has axis %d; only axis=0 is supported", LogName(*op), - op->axis); + op->axis.value()); return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc similarity index 82% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc rename to tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc index a4d5f1923a1dffdff1ef51eb5317fa5794a8bc27..e86616574d5a0f1345cde167d4ce0d41665d5a02 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc @@ -24,7 +24,7 @@ namespace toco { namespace { template -void Stack(Model* model, StackOperator const& op) { +void Pack(Model* model, PackOperator const& op) { auto& output_array = model->GetArray(op.outputs[0]); CHECK(output_array.data_type == Type); @@ -33,8 +33,8 @@ void Stack(Model* model, StackOperator const& op) { output_array.GetMutableBuffer().data; output_data.resize(RequiredBufferSizeForShape(output_array.shape())); - // Stack inputs into buffer - CHECK_EQ(op.axis, 0) << "Stacking only supported along first axis"; + // Pack inputs into buffer + CHECK_EQ(op.axis, 0) << "Packing only supported along first axis"; int dst_offset = 0; for (int i = 0; i < op.inputs.size(); i++) { // Append array data to output for each input array @@ -49,13 +49,13 @@ void Stack(Model* model, StackOperator const& op) { } // namespace -bool ResolveConstantStack::Run(Model* model, std::size_t op_index) { +bool ResolveConstantPack::Run(Model* model, std::size_t op_index) { auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); - if (base_op->type != OperatorType::kStack) { + if (base_op->type != OperatorType::kPack) { return false; } - const auto* op = static_cast(base_op); + const auto* op = static_cast(base_op); CHECK_GE(op->inputs.size(), 1); CHECK_EQ(op->outputs.size(), 1); @@ -82,24 +82,24 @@ bool ResolveConstantStack::Run(Model* model, std::size_t op_index) { // Handle negative axis axis += model->GetArray(op->inputs[0]).shape().dims().size(); } - CHECK_EQ(axis, 0) << "Stacking only supported along 0th axis"; + CHECK_EQ(axis, 0) << "Packing only supported along 0th axis"; CHECK(!output_array.buffer); switch (output_array.data_type) { case ArrayDataType::kFloat: - Stack(model, *op); + Pack(model, *op); break; case ArrayDataType::kUint8: - Stack(model, *op); + Pack(model, *op); break; case ArrayDataType::kInt32: - Stack(model, *op); + Pack(model, *op); break; case ArrayDataType::kInt64: - Stack(model, *op); + Pack(model, *op); break; default: - LOG(FATAL) << "Unsupported data type given to Stack op with output \"" + LOG(FATAL) << "Unsupported data type given to Pack op with output \"" << op->outputs[0] << "\""; break; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc index f89ef85fdb63ca4906c7f016e86bb1f9d8a7099a..fe3882c28df893080846b24ffa3cac7267f08ae2 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc @@ -57,8 +57,8 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { case OperatorType::kSqrt: case OperatorType::kSquare: case OperatorType::kSum: - case OperatorType::kMin: // Reduction Min - case OperatorType::kMax: // Reduction Max + case OperatorType::kReduceMin: // Reduction Min + case OperatorType::kReduceMax: // Reduction Max case OperatorType::kReshape: case OperatorType::kRelu6: case OperatorType::kRelu1: @@ -196,7 +196,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { } output_float_data[i] = sum; } - } else if (unary_op->type == OperatorType::kMin) { + } else if (unary_op->type == OperatorType::kReduceMin) { // At the moment only full reduction across all dimensions is supported. // TODO(starka): Output should not be padded. for (int i = 0; i < output_dims_count; i++) { @@ -207,7 +207,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { min = std::min(min, (*input_float_data)[i]); } output_float_data[0] = min; - } else if (unary_op->type == OperatorType::kMax) { + } else if (unary_op->type == OperatorType::kReduceMax) { // At the moment only full reduction across all dimensions is supported. // TODO(starka): Output should not be padded. for (int i = 0; i < output_dims_count; i++) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc new file mode 100644 index 0000000000000000000000000000000000000000..0dda1fd0b35fb0cdc3c605360df5126c52c05403 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_fake_quant_args_from_vars.cc @@ -0,0 +1,80 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveFakeQuantArgsFromVars::Run(Model* model, std::size_t op_index) { + const auto fakequant_it = model->operators.begin() + op_index; + auto* fakequant_base_op = fakequant_it->get(); + if (fakequant_base_op->type != OperatorType::kFakeQuant) { + return false; + } + auto* fakequant_op = static_cast(fakequant_base_op); + + if (fakequant_op->minmax) { + // Already resolved. + return false; + } + + CHECK_EQ(fakequant_op->inputs.size(), 3); + // We need to yield until the min and max parameters have been + // resolved to constant arrays. + for (int i = 1; i <= 2; i++) { + if (!IsConstantParameterArray(*model, fakequant_op->inputs[i])) { + return false; + } + } + + // Obtain the final min/max values + const auto& min_array = model->GetArray(fakequant_op->inputs[1]); + const auto& max_array = model->GetArray(fakequant_op->inputs[2]); + CHECK_EQ(RequiredBufferSizeForShape(min_array.shape()), 1); + CHECK_EQ(RequiredBufferSizeForShape(max_array.shape()), 1); + fakequant_op->minmax.reset(new MinMax); + MinMax& minmax = *fakequant_op->minmax; + minmax.min = min_array.GetBuffer().data[0]; + minmax.max = max_array.GetBuffer().data[0]; + // We always want [min, max] to contain 0. + if (minmax.min > 0 || minmax.max < 0) { + LOG(ERROR) << "For " << LogName(*fakequant_op) << " the MinMax range " + << "[" << minmax.min << ", " << minmax.max + << "] does not contain 0. " + << "Proceeding by tweaking it to contain 0, which will result " + "in poor accuracy."; + } + minmax.min = std::min(minmax.min, 0.); + minmax.max = std::max(minmax.max, 0.); + + // We won't use the input arrays that provided these min and max + // values, anymore. Delete them unless they are used by something + // else. + for (int i = 1; i <= 2; i++) { + DeleteArrayIfUsedOnce(fakequant_op->inputs[i], model); + } + fakequant_op->inputs.resize(1); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc new file mode 100644 index 0000000000000000000000000000000000000000..ce825c91af428c866ca9f83b765399f209606af9 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_gather_attributes.cc @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveGatherAttributes::Run(Model* model, std::size_t op_index) { + auto* gather_op = model->operators[op_index].get(); + if (gather_op->type != OperatorType::kGather) return false; + auto* op = static_cast(gather_op); + + if (op->axis) { + // Attributes already resolved + return false; + } + if (op->inputs.size() != 3) return false; + if (!IsConstantParameterArray(*model, op->inputs[2])) return false; + + const auto& indices_array = model->GetArray(op->inputs[2]); + if (!indices_array.has_shape()) return false; + const auto& axis_data = indices_array.GetBuffer().data; + CHECK_EQ(axis_data.size(), 1) + << "Multidimensional gather not supported on " << LogName(*op); + op->axis = {axis_data[0]}; + + // Drop the axis array as we no longer need it. + DeleteArrayIfUsedOnce(op->inputs[2], model); + op->inputs.resize(2); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc similarity index 60% rename from tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc rename to tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc index 013b50ac9ba8a51c23b19953d987b2fbf63fcea1..7d456af2fbc69352662b798cf1314f1653ef9f98 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc @@ -24,11 +24,8 @@ limitations under the License. namespace toco { -bool ResolveMeanAttributes::Run(Model* model, std::size_t op_index) { - auto* mean_op = model->operators[op_index].get(); - if (mean_op->type != OperatorType::kMean) return false; - auto* op = static_cast(mean_op); - +template +bool ResolveAttributes(Model* model, T* op) { if (!op->axis.empty()) { // Attributes already resolved return false; @@ -36,10 +33,28 @@ bool ResolveMeanAttributes::Run(Model* model, std::size_t op_index) { if (op->inputs.size() != 2) return false; if (!IsConstantParameterArray(*model, op->inputs[1])) return false; - const auto& indices_array = model->GetArray(op->inputs[1]); + const Array& indices_array = model->GetArray(op->inputs[1]); if (!indices_array.has_shape()) return false; op->axis = indices_array.GetBuffer().data; return true; } +bool ResolveReduceAttributes::Run(Model* model, std::size_t op_index) { + Operator* op = model->operators[op_index].get(); + switch (op->type) { + case OperatorType::kMean: + return ResolveAttributes(model, static_cast(op)); + case OperatorType::kSum: + return ResolveAttributes(model, static_cast(op)); + case OperatorType::kReduceProd: + return ResolveAttributes(model, static_cast(op)); + case OperatorType::kReduceMin: + return ResolveAttributes(model, static_cast(op)); + case OperatorType::kReduceMax: + return ResolveAttributes(model, static_cast(op)); + default: + return false; + } +} + } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc index d496f5ae5eeeca5063e23b25498b0ac450e9f946..fcf30bd34725fc59bb819e75deda0dadf330f372 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc @@ -32,21 +32,34 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { const auto* matmul_op = static_cast(matmul_it->get()); + // Handling transposition of the first input here isn't very simple because + // we need to know the actual shape in order to produce a proper + // TransposeOperator. However, the second input is supposed to be 2D, so we + // can actually handle transposition of that matrix, which happens to be more + // common anyway. + CHECK(!matmul_op->transpose_a); + // Reorder the axes on the second input. TensorFlow uses row-major ordering // on both inputs, however this is inefficient for the FullyConnected // operator. We'll transpose the second input to be in column-major order now // and let constant propagation optimize things (if possible). - auto* transpose_op = new TransposeOperator; - transpose_op->inputs = { - matmul_op->inputs[1], - CreateInt32Array( - model, - AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose/perm"), - {1, 0})}; - transpose_op->outputs = { - AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose")}; - model->GetOrCreateArray(transpose_op->outputs[0]); - model->operators.emplace(matmul_it, transpose_op); + string input_lhs = matmul_op->inputs[0]; + string input_rhs = matmul_op->inputs[1]; + if (!matmul_op->transpose_b) { + auto* transpose_op = new TransposeOperator; + transpose_op->inputs = { + matmul_op->inputs[1], + CreateInt32Array(model, + AvailableArrayName( + *model, matmul_op->inputs[1] + "/transpose/perm"), + {1, 0})}; + transpose_op->outputs = { + AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose")}; + model->GetOrCreateArray(transpose_op->outputs[0]); + model->operators.emplace(matmul_it, transpose_op); + + input_rhs = transpose_op->outputs[0]; + } // Refresh iterator. matmul_it = model->operators.begin(); @@ -57,9 +70,6 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { } DCHECK_EQ(matmul_it->get(), matmul_op); - string input_lhs = matmul_op->inputs[0]; - string input_rhs = transpose_op->outputs[0]; - // Construct the new FullyConnectedOperator. auto* fc_op = new FullyConnectedOperator; fc_op->outputs = matmul_op->outputs; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD index 95e8433be2a332cfce5175f4f65ea0b83d5638c5..e163fc9ae1422504ef1b0a3c567c420f649f0827 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD @@ -10,6 +10,7 @@ load( tf_cc_test( name = "lstm_utils_test", srcs = ["lstm_utils_test.cc"], + tags = ["no_oss"], deps = [ "//tensorflow/contrib/lite/toco:graph_transformations", "//tensorflow/contrib/lite/toco:model", @@ -21,6 +22,7 @@ tf_cc_test( tf_cc_test( name = "quantize_weights_test", srcs = ["quantize_weights_test.cc"], + tags = ["no_oss"], deps = [ "//tensorflow/contrib/lite/toco:graph_transformations", "//tensorflow/contrib/lite/toco:model", @@ -33,6 +35,7 @@ tf_cc_test( tf_cc_test( name = "resolve_constant_concatenation_test", srcs = ["resolve_constant_concatenation_test.cc"], + tags = ["no_oss"], deps = [ "//tensorflow/contrib/lite/toco:graph_transformations", "//tensorflow/contrib/lite/toco:model", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc index 2c7046c8c77c94a89fc05a26d7d72b3661380475..69bad2fa89cb89cd74e3a4bca98da906a322a670 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc @@ -64,7 +64,14 @@ bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) { const string& tmp_array_name = AvailableArrayName(*model, op->outputs[0] + "_unfused"); CHECK(!model->HasArray(tmp_array_name)); - model->GetOrCreateArray(tmp_array_name); + + const auto& output_array = model->GetArray(op->outputs[0]); + auto& tmp_array = model->GetOrCreateArray(tmp_array_name); + if (output_array.quantization_params) { + tmp_array.GetOrCreateQuantizationParams() = + output_array.GetQuantizationParams(); + } + ac_op->inputs = {tmp_array_name}; op->outputs = {tmp_array_name}; return true; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc b/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc index cbea39bcc09ea6787c055d5aaca7f291c2b47a7f..dd9e26e68bd7e6d5cb751fdbf705b861c3f2f188 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc @@ -187,6 +187,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { AvailableArrayName(*model, gather_ops[0]->inputs[0] + "_permuted/perm")); gather_params_permute_op->outputs.push_back( AvailableArrayName(*model, gather_ops[0]->inputs[0] + "_permuted")); + gather_params_permute_op->axis = {0}; op_it = model->operators.emplace(op_it, gather_params_permute_op) + 1; model->GetOrCreateArray(gather_params_permute_op->outputs[0]); const auto& partition_array = model->GetArray(gather_ops[0]->inputs[0]); @@ -212,6 +213,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { mod_op->inputs[0]}; merged_gather_op->outputs = {stitch_op->outputs[0]}; merged_gather_op->input_rank = partition_array.shape().dimensions_count(); + merged_gather_op->axis = {0}; model->operators.emplace(op_it, merged_gather_op); AddMessageF( diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc index da81ea2ff3b4ab0bee0550874a9c4ea1044a3579..5f0cece67a49de6d50fd08896d14d3f27df46b44 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/unroll_batch_matmul.cc @@ -76,7 +76,7 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) { AddMessageF("Unrolling BatchMatMul %s %d times", LogName(*batch_op), batch_count); auto tail_it = batch_op_it; - std::vector stack_inputs; + std::vector pack_inputs; for (int batch = 0; batch < batch_count; ++batch) { std::string batch_name = std::string(batch_op->outputs[0]) + "_b" + std::to_string(batch); @@ -146,15 +146,15 @@ bool UnrollBatchMatMul::Run(Model* model, std::size_t op_index) { tail_it = model->operators.emplace(tail_it, matmul_op) + 1; // Add to stack. - stack_inputs.push_back(matmul_op->outputs[0]); + pack_inputs.push_back(matmul_op->outputs[0]); } - // The stack that will join all the individual matmul results together. - auto* stack_op = new StackOperator; - stack_op->inputs = stack_inputs; - stack_op->outputs = {batch_op->outputs[0]}; - stack_op->axis = 0; - model->operators.emplace(tail_it, stack_op); + // The pack that will join all the individual matmul results together. + auto* pack_op = new PackOperator; + pack_op->inputs = pack_inputs; + pack_op->outputs = {batch_op->outputs[0]}; + pack_op->axis = 0; + model->operators.emplace(tail_it, pack_op); // Remove the old batch matmul now that we've unrolled. batch_op_it = model->operators.begin(); diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 55e39d963f97eb35790b460ed8c634b32abf490f..b7fffbce2223a71ac1e16ec1ce18ba9f610cc2ac 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -215,7 +215,7 @@ tensorflow::Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_FLOAT); const auto& input_shape = input_tensor.tensor_shape(); - CHECK_LE(input_shape.dim_size(), 4); + CHECK_LE(input_shape.dim_size(), 6); int input_flat_size; auto status = ImportShape(input_shape.dim(), &input_flat_size, output_array->mutable_shape()); @@ -253,7 +253,7 @@ tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_QUINT8); const auto& input_shape = input_tensor.tensor_shape(); - CHECK_LE(input_shape.dim_size(), 4); + CHECK_LE(input_shape.dim_size(), 6); int input_flat_size; auto status = ImportShape(input_shape.dim(), &input_flat_size, output_array->mutable_shape()); @@ -290,7 +290,7 @@ tensorflow::Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT32); const auto& input_shape = input_tensor.tensor_shape(); - CHECK_LE(input_shape.dim_size(), 4); + CHECK_LE(input_shape.dim_size(), 6); int input_flat_size; auto status = ImportShape(input_shape.dim(), &input_flat_size, output_array->mutable_shape()); @@ -326,7 +326,7 @@ tensorflow::Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT64); const auto& input_shape = input_tensor.tensor_shape(); - CHECK_LE(input_shape.dim_size(), 4); + CHECK_LE(input_shape.dim_size(), 6); int input_flat_size; auto status = ImportShape(input_shape.dim(), &input_flat_size, output_array->mutable_shape()); @@ -363,7 +363,7 @@ tensorflow::Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_BOOL); const auto& input_shape = input_tensor.tensor_shape(); - CHECK_LE(input_shape.dim_size(), 4); + CHECK_LE(input_shape.dim_size(), 6); int input_flat_size; auto status = ImportShape(input_shape.dim(), &input_flat_size, output_array->mutable_shape()); @@ -409,7 +409,7 @@ tensorflow::Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_STRING); const auto& input_shape = input_tensor.tensor_shape(); - CHECK_LE(input_shape.dim_size(), 4); + CHECK_LE(input_shape.dim_size(), 6); int input_flat_size; auto status = ImportShape(input_shape.dim(), &input_flat_size, output_array->mutable_shape()); @@ -755,6 +755,9 @@ tensorflow::Status ConvertFakeQuantWithMinMaxArgs( op->outputs.push_back(node.name()); // tf.fake_quant_with_min_max_args num_bits defaults to 8. op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8; + if (HasAttr(node, "narrow_range")) { + op->narrow_range = GetBoolAttr(node, "narrow_range"); + } model->operators.emplace_back(op); return tensorflow::Status::OK(); } @@ -774,6 +777,9 @@ tensorflow::Status ConvertFakeQuantWithMinMaxVars( } op->outputs.push_back(node.name()); op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8; + if (HasAttr(node, "narrow_range")) { + op->narrow_range = GetBoolAttr(node, "narrow_range"); + } model->operators.emplace_back(op); return tensorflow::Status::OK(); } @@ -799,22 +805,6 @@ tensorflow::Status ConvertSqueezeOperator( return tensorflow::Status::OK(); } -tensorflow::Status ConvertSumOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sum"); - TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); - auto* op = new TensorFlowSumOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); - if (HasAttr(node, "keep_dims")) { - op->keep_dims = GetBoolAttr(node, "keep_dims"); - } - return tensorflow::Status::OK(); -} - tensorflow::Status ConvertSplitOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -984,18 +974,19 @@ tensorflow::Status ConvertMatMulOperator( Model* model) { TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); - // Transpose flags should be easy to support, but we don't have a - // GraphDef with them to test on at the moment. - CHECK_EQ(HasAttr(node, "transpose_a") && GetBoolAttr(node, "transpose_a"), - false); - CHECK_EQ(HasAttr(node, "transpose_b") && GetBoolAttr(node, "transpose_b"), - false); CHECK(!HasAttr(node, "adjoint_a") || (GetBoolAttr(node, "adjoint_a") == false)); CHECK(!HasAttr(node, "adjoint_b") || (GetBoolAttr(node, "adjoint_b") == false)); auto* matmul = new TensorFlowMatMulOperator; + if (HasAttr(node, "transpose_a")) { + matmul->transpose_a = GetBoolAttr(node, "transpose_a"); + } + if (HasAttr(node, "transpose_b")) { + matmul->transpose_b = GetBoolAttr(node, "transpose_b"); + } + matmul->inputs = {node.input(0), node.input(1)}; matmul->outputs = {node.name()}; model->operators.emplace_back(matmul); @@ -1051,41 +1042,16 @@ tensorflow::Status ConvertSimpleOperator( return ConvertSimpleOperator(node, tf_import_flags, model); } -tensorflow::Status ConvertMaxOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Max"); - TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); - auto* op = new TensorFlowMaxOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); - if (HasAttr(node, "keep_dims")) { - op->keep_dims = GetBoolAttr(node, "keep_dims"); - } - return tensorflow::Status::OK(); -} - -tensorflow::Status ConvertMinOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Min"); - TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); - auto* op = new TensorFlowMinOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); - if (HasAttr(node, "keep_dims")) { - op->keep_dims = GetBoolAttr(node, "keep_dims"); - } - return tensorflow::Status::OK(); -} - tensorflow::Status ConvertUnsupportedOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { + // Names of special attributes in TF graph that are used by Toco. + static constexpr char kAttrOutputQuantized[] = "_output_quantized"; + static constexpr char kAttrOutputTypes[] = "_output_types"; + static constexpr char kAttrOutputShapes[] = "_output_shapes"; + static constexpr char kAttrSupportOutputTypeFloatInQuantizedOp[] = + "_support_output_type_float_in_quantized_op"; + LOG(INFO) << "Converting unsupported operation: " << node.op(); auto* op = new TensorFlowUnsupportedOperator; const int num_inputs = GetInputsCount(node, tf_import_flags); @@ -1096,11 +1062,17 @@ tensorflow::Status ConvertUnsupportedOperator( op->tensorflow_op = node.op(); node.SerializeToString(&op->tensorflow_node_def); model->operators.emplace_back(op); - if (HasAttr(node, "_output_quantized")) { - op->quantized = GetBoolAttr(node, "_output_quantized"); + // Parse if the op supports quantization + if (HasAttr(node, kAttrOutputQuantized)) { + op->quantized = GetBoolAttr(node, kAttrOutputQuantized); + } + // Parse if the quantized op allows output arrays of type float + if (HasAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp)) { + op->support_output_type_float_in_quantized_op = + GetBoolAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp); } - if (HasAttr(node, "_output_types")) { - const auto& output_types = GetListAttr(node, "_output_types"); + if (HasAttr(node, kAttrOutputTypes)) { + const auto& output_types = GetListAttr(node, kAttrOutputTypes); for (int i = 0; i < output_types.type_size(); ++i) { op->output_data_types.push_back(ConvertDataType(output_types.type(i))); } @@ -1108,6 +1080,19 @@ tensorflow::Status ConvertUnsupportedOperator( const auto& output_type = GetDataTypeAttr(node, "Tout"); op->output_data_types.push_back(ConvertDataType(output_type)); } + if (HasAttr(node, kAttrOutputShapes)) { + const auto& output_shapes = GetListAttr(node, kAttrOutputShapes); + Shape output_shape; + for (int i = 0; i < output_shapes.shape_size(); ++i) { + const auto status = + ImportShape(output_shapes.shape(i).dim(), /*input_flat_size=*/nullptr, + &output_shape); + if (!status.ok()) { + return status; + } + op->output_shapes.push_back(output_shape); + } + } return tensorflow::Status::OK(); } @@ -1222,17 +1207,26 @@ tensorflow::Status ConvertGatherOperator( auto* op = new GatherOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); - // TODO(ahentz): we currently ignore the third tensor in GatherV2 but we - // should read it an pass it on to the TF Lite Interpreter. + if (node.input_size() >= 3) { + // GatherV2 form where we are provided an axis. It may be either a constant + // or runtime defined value, so we just wire up the array and let + // ResolveGatherAttributes take care of it later on. + const auto axis_data_type = GetDataTypeAttr(node, "Taxis"); + CHECK(axis_data_type == DT_INT32 || axis_data_type == DT_INT64); + op->inputs.push_back(node.input(2)); + } else { + // Gather form that assumes axis=0. + op->axis = {0}; + } op->outputs.push_back(node.name()); model->operators.emplace_back(op); return tensorflow::Status::OK(); } -tensorflow::Status ConvertArgMaxOperator( +template +tensorflow::Status ConvertArgMinMaxOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { - CHECK_EQ(node.op(), "ArgMax"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); const auto axis_data_type = HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32; @@ -1241,7 +1235,7 @@ tensorflow::Status ConvertArgMaxOperator( : DT_INT64; CHECK(axis_data_type == DT_INT64 || axis_data_type == DT_INT32); CHECK(output_type == DT_INT64 || output_type == DT_INT32); - auto* op = new ArgMaxOperator; + auto* op = new Op; op->output_data_type = ConvertDataType(output_type); op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -1250,6 +1244,20 @@ tensorflow::Status ConvertArgMaxOperator( return tensorflow::Status::OK(); } +tensorflow::Status ConvertArgMaxOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "ArgMax"); + return ConvertArgMinMaxOperator(node, tf_import_flags, model); +} + +tensorflow::Status ConvertArgMinOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "ArgMin"); + return ConvertArgMinMaxOperator(node, tf_import_flags, model); +} + tensorflow::Status ConvertResizeBilinearOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1404,12 +1412,12 @@ tensorflow::Status ConvertBatchToSpaceNDOperator( return tensorflow::Status::OK(); } -tensorflow::Status ConvertMeanOperator( +template +tensorflow::Status ConvertReduceOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { - CHECK_EQ(node.op(), "Mean"); TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); - auto* op = new MeanOperator; + auto* op = new T; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); op->outputs.push_back(node.name()); @@ -1542,11 +1550,15 @@ tensorflow::Status ConvertRangeOperator( return tensorflow::Status::OK(); } -tensorflow::Status ConvertStackOperator( +// Note that it's easy to confuse/conflate "Stack" and "Pack" operators, but +// they aren't the same thing. tf.stack results in a "Pack" operator. "Stack" +// operators also exist, but involve manipulating the TF runtime stack, and are +// not directly related to tf.stack() usage. +tensorflow::Status ConvertPackOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { - CHECK((node.op() == "Stack") || (node.op() == "Pack")); - auto* op = new StackOperator; + CHECK_EQ(node.op(), "Pack"); + auto op = absl::make_unique(); const int num_inputs = GetInputsCount(node, tf_import_flags); QCHECK_GE(num_inputs, 1) << node.op() @@ -1556,10 +1568,11 @@ tensorflow::Status ConvertStackOperator( for (int i = 0; i < num_inputs; ++i) { op->inputs.push_back(node.input(i)); } - // Both "Stack" and "Pack" have the "axis" attribute. + op->values_count = HasAttr(node, "N") ? GetIntAttr(node, "N") : num_inputs; op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0; + op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T")); op->outputs.push_back(node.name()); - model->operators.emplace_back(op); + model->operators.emplace_back(std::move(op)); return tensorflow::Status::OK(); } @@ -1605,6 +1618,24 @@ tensorflow::Status ConvertShapeOperator( return tensorflow::Status::OK(); } +tensorflow::Status ConvertAnyOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "Any"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); + const auto idx_type = + HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32; + CHECK(idx_type == DT_INT32); + auto op = absl::make_unique(); + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + op->keep_dims = + HasAttr(node, "keep_dims") ? GetBoolAttr(node, "keep_dims") : false; + model->operators.push_back(std::move(op)); + return tensorflow::Status::OK(); +} + void StripCaretFromArrayNames(Model* model) { for (auto& op : model->operators) { for (auto& input : op->inputs) { @@ -1823,6 +1854,55 @@ tensorflow::Status ConvertSparseToDenseOperator( return tensorflow::Status::OK(); } +tensorflow::Status ConvertOneHotOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "OneHot"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4)); + + const auto dtype = GetDataTypeAttr(node, "T"); + // TODO(b/111744875): Support DT_UINT8 and quantization. + CHECK(dtype == DT_INT32 || dtype == DT_INT64 || dtype == DT_FLOAT || + dtype == DT_BOOL); + + auto op = absl::make_unique(); + op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : -1; + for (const string& input : node.input()) { + op->inputs.push_back(input); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op.release()); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertCTCBeamSearchDecoderOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "CTCBeamSearchDecoder"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); + + auto* op = new CTCBeamSearchDecoderOperator; + for (const string& input : node.input()) { + op->inputs.push_back(input); + } + + op->beam_width = + HasAttr(node, "beam_width") ? GetIntAttr(node, "beam_width") : 1; + op->top_paths = + HasAttr(node, "top_paths") ? GetIntAttr(node, "top_paths") : 1; + op->merge_repeated = HasAttr(node, "merge_repeated") + ? GetBoolAttr(node, "merge_repeated") + : true; + + // There are top_paths + 1 outputs. + op->outputs.push_back(node.name()); // Implicit :0. + for (int i = 0; i < op->top_paths; ++i) { + op->outputs.push_back(node.name() + ":" + std::to_string(i + 1)); + } + model->operators.emplace_back(op); + return tensorflow::Status::OK(); +} + } // namespace namespace internal { @@ -1837,7 +1917,9 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Add", ConvertSimpleOperator}, {"AddN", ConvertSimpleOperator}, {"All", ConvertSimpleOperator}, + {"Any", ConvertAnyOperator}, {"ArgMax", ConvertArgMaxOperator}, + {"ArgMin", ConvertArgMinOperator}, {"Assert", ConvertSimpleOperator}, {"AvgPool", ConvertAvgPoolOperator}, {"BatchMatMul", ConvertBatchMatMulOperator}, @@ -1852,6 +1934,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Const", ConvertConstOperator}, {"Conv2D", ConvertConvOperator}, {"Conv2DBackpropInput", ConvertTransposeConvOperator}, + {"CTCBeamSearchDecoder", ConvertCTCBeamSearchDecoderOperator}, {"DepthToSpace", ConvertDepthToSpaceOperator}, {"DepthwiseConv2dNative", ConvertDepthwiseConvOperator}, {"Div", ConvertSimpleOperator}, @@ -1878,28 +1961,32 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Less", ConvertSimpleOperator}, {"LessEqual", ConvertSimpleOperator}, {"Log", ConvertSimpleOperator}, - {"Log", ConvertSimpleOperator}, + {"LogicalAnd", ConvertSimpleOperator}, + {"LogicalOr", ConvertSimpleOperator}, + {"LogicalNot", ConvertSimpleOperator}, {"LogSoftmax", ConvertSimpleOperator}, {"MatMul", ConvertMatMulOperator}, - {"Max", ConvertMaxOperator}, + {"Max", ConvertReduceOperator}, {"MaxPool", ConvertMaxPoolOperator}, {"Maximum", ConvertSimpleOperator}, - {"Mean", ConvertMeanOperator}, + {"Mean", ConvertReduceOperator}, {"Merge", ConvertSimpleOperator}, - {"Min", ConvertMinOperator}, + {"Min", ConvertReduceOperator}, {"Minimum", ConvertSimpleOperator}, {"Mul", ConvertSimpleOperator}, {"Neg", ConvertSimpleOperator}, {"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge}, {"NoOp", ConvertNoOpOperator}, {"NotEqual", ConvertSimpleOperator}, - {"Pack", ConvertStackOperator}, + {"OneHot", ConvertOneHotOperator}, + {"Pack", ConvertPackOperator}, {"Pad", ConvertSimpleOperator}, {"PadV2", ConvertSimpleOperator}, {"ParallelDynamicStitch", ConvertDynamicStitchOperator}, {"Placeholder", ConvertPlaceholderOperator}, {"PlaceholderWithDefault", ConvertIdentityOperator}, {"Pow", ConvertSimpleOperator}, + {"Prod", ConvertReduceOperator}, {"RandomUniform", ConvertRandomUniform}, {"Range", ConvertRangeOperator}, {"Rank", ConvertSimpleOperator}, @@ -1922,11 +2009,10 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Sqrt", ConvertSimpleOperator}, {"Square", ConvertSimpleOperator}, {"Squeeze", ConvertSqueezeOperator}, - {"Stack", ConvertStackOperator}, {"StopGradient", ConvertIdentityOperator}, {"StridedSlice", ConvertStridedSliceOperator}, {"Sub", ConvertSimpleOperator}, - {"Sum", ConvertSumOperator}, + {"Sum", ConvertReduceOperator}, {"Svdf", ConvertSvdfOperator}, {"Switch", ConvertSwitchOperator}, {"Tanh", ConvertSimpleOperator}, diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index abe0bf3c54460709dc67a4d5835df77ca8a83575..412e14c4ada3280dafcd2fcfa59e2908dd785f9f 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/types/optional.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/runtime/types.h" #include "tensorflow/contrib/lite/toco/toco_port.h" @@ -63,6 +64,7 @@ enum class OperatorType : uint8 { kMaxPool, kFakeQuant, kMul, + kOneHot, kRandomUniform, kRange, kRank, @@ -81,10 +83,11 @@ enum class OperatorType : uint8 { kResizeBilinear, kSin, kSpaceToBatchND, - kStack, + kPack, kBatchToSpaceND, kPad, kPadV2, + kReduceProd, // Reduction product kStridedSlice, kSlice, kSqueeze, @@ -106,10 +109,10 @@ enum class OperatorType : uint8 { kIdentity, kLess, kLessEqual, - kMax, // Reduction Max - kMaximum, // Element-wise Maximum - kMin, // Reduction Min - kMinimum, // Element-wise Minimum + kReduceMax, // Reduction Max + kMaximum, // Element-wise Maximum + kReduceMin, // Reduction Min + kMinimum, // Element-wise Minimum kMatMul, kMerge, kNeg, @@ -140,6 +143,12 @@ enum class OperatorType : uint8 { kEqual, kNotEqual, kPow, + kArgMin, + kAny, + kLogicalAnd, + kLogicalNot, + kLogicalOr, + kCTCBeamSearchDecoder, }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -286,6 +295,46 @@ struct Buffer : GenericBuffer { std::vector> data; }; +class Shape { + public: + // For Shape, we stick to half-way encapsulation for now: + // we hide the raw dims_ member, but expose it raw by accessors + // because from some brainstorming, it's not at all easy to + // anticipate which flavor of more hermetic encapsulation would + // actually buy us future-proof-ness without being needlessly + // cumbersome. + Shape() {} + Shape(std::initializer_list dim_list) : dims_(dim_list) {} + + void ReplaceDims(std::initializer_list dim_list) { + dims_ = std::vector(dim_list); + } + + const std::vector& dims() const { return dims_; } + std::vector* mutable_dims() { return &dims_; } + const int dimensions_count() const { return dims_.size(); } + + // We still have that one convenience accessor to avoid + // the awkward double bracket issue: shape.dims()[i]. + int dims(int i) const { + // Always check for out-of-bounds accesses, even in optimized builds where + // standard assertions are disabled. Out-of-bounds access here is a common + // occurrence. + CHECK_GE(i, 0); + CHECK_GT(dims_.size(), i); + return dims_[i]; + } + + bool operator==(const Shape& comp) const { + return (this->dims_ == comp.dims()); + } + + bool operator!=(const Shape& comp) const { return !((*this) == comp); } + + private: + std::vector dims_; +}; + // Base class for all operator classes. struct Operator { // Non-default-constructible: only OperatorType-specific subclass @@ -390,6 +439,28 @@ struct ConvOperator : Operator { int dilation_height_factor = 1; }; +// CTCBeamSearchDecoder operator: +// +// Inputs: +// inputs[0]: required: the logits. +// inputs[1]: required: sequence length. +// inputs[2]: optional: beam width. +// inputs[3]: optional: top paths. +// inputs[4]: optional: merge repeated. +// +// Outputs: +// outputs[0]: deocoded. +// outputs[1]: log probability. +// +// TensorFlow equivalent: CTCBeamSearchDecoder +struct CTCBeamSearchDecoderOperator : Operator { + CTCBeamSearchDecoderOperator() + : Operator(OperatorType::kCTCBeamSearchDecoder) {} + int beam_width; + int top_paths; + bool merge_repeated = true; +}; + // Depthwise-separable convolution operator. // // Inputs: @@ -790,6 +861,7 @@ struct FakeQuantOperator : Operator { FakeQuantOperator() : Operator(OperatorType::kFakeQuant) {} std::unique_ptr minmax; int num_bits = 8; + bool narrow_range = false; }; // Element-wise division operator. @@ -837,6 +909,8 @@ struct BatchMatMulOperator : Operator { // TensorFlow equivalent: MatMul struct TensorFlowMatMulOperator : Operator { TensorFlowMatMulOperator() : Operator(OperatorType::kMatMul) {} + bool transpose_a = false; + bool transpose_b = false; }; // Padding operator. Pads a tensor with zeros. @@ -1152,10 +1226,12 @@ struct TensorFlowRsqrtOperator : Operator { // Inputs: this operator accepts any number >= 1 of inputs. // inputs[i]: the i-th array to merge. // -// TensorFlow equivalent: Stack or Pack -struct StackOperator : Operator { - StackOperator() : Operator(OperatorType::kStack) {} +// TensorFlow equivalent: Pack +struct PackOperator : Operator { + PackOperator() : Operator(OperatorType::kPack) {} + int values_count; int axis = 0; + ArrayDataType dtype = ArrayDataType::kNone; }; // Shape operator. Extracts the shape of the tensor. @@ -1225,6 +1301,19 @@ struct SubOperator : Operator { // TensorFlow equivalent: Sum struct TensorFlowSumOperator : Operator { TensorFlowSumOperator() : Operator(OperatorType::kSum) {} + std::vector axis; + bool keep_dims = false; +}; + +// Prod reduction: computes the product of all of entries across the axes. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Prod +struct TensorFlowProdOperator : Operator { + TensorFlowProdOperator() : Operator(OperatorType::kReduceProd) {} + std::vector axis; bool keep_dims = false; }; @@ -1384,29 +1473,27 @@ struct TensorFlowNotEqualOperator : Operator { TensorFlowNotEqualOperator() : Operator(OperatorType::kNotEqual) {} }; -// Global max reduction: computes the max of all of entries in the input array. -// Thus the output is "0-dimensional": it consists of a single scalar value. +// Max reduction: computes the max of all of entries across the axes. // // Inputs: // inputs[0]: required: the input array // -// TensorFlow equivalent: Max --- except that we only support the special case -// of global reduction across all dimensions. +// TensorFlow equivalent: Max struct TensorFlowMaxOperator : Operator { - TensorFlowMaxOperator() : Operator(OperatorType::kMax) {} + TensorFlowMaxOperator() : Operator(OperatorType::kReduceMax) {} + std::vector axis; bool keep_dims = false; }; -// Global min reduction: computes the min of all of entries in the input array. -// Thus the output is "0-dimensional": it consists of a single scalar value. +// Min reduction: computes the min of all of entries across the axes. // // Inputs: // inputs[0]: required: the input array // -// TensorFlow equivalent: Min --- except that we only support the special case -// of global reduction across all dimensions. +// TensorFlow equivalent: Min struct TensorFlowMinOperator : Operator { - TensorFlowMinOperator() : Operator(OperatorType::kMin) {} + TensorFlowMinOperator() : Operator(OperatorType::kReduceMin) {} + std::vector axis; bool keep_dims = false; }; @@ -1445,8 +1532,13 @@ struct TensorFlowUnsupportedOperator : Operator { string tensorflow_node_def; // A boolean indicating if the unsupported op should be treated as quantized. bool quantized = false; + // A boolean indicating if the unsupported op output should allow float values + // in quantized mode. + bool support_output_type_float_in_quantized_op = false; // Output data types std::vector output_data_types; + // Output shapes. + std::vector output_shapes; }; // Softmax activation function. @@ -1507,11 +1599,15 @@ struct FloorOperator : Operator { // Inputs: // inputs[0]: required: the params array // inputs[1]: required: the indices to gather +// inputs[2]: optional: axis // // TensorFlow equivalent: Gather struct GatherOperator : Operator { GatherOperator() : Operator(OperatorType::kGather) {} - int axis = 0; + // Axis is populated explicitly or implicitly from the axis input by + // ResolveGatherAttributes. An empty axis indicates that the axis has not yet + // be resolved. + absl::optional axis; int input_rank = 0; }; @@ -1526,6 +1622,17 @@ struct ArgMaxOperator : Operator { ArrayDataType output_data_type = ArrayDataType::kInt64; }; +// ArgMin operator. It returns the index of the minimum value along axis. +// +// Inputs: +// inputs[0]: required: the input tensor +// +// TensorFlow equivalent: ArgMin +struct ArgMinOperator : Operator { + ArgMinOperator() : Operator(OperatorType::kArgMin) {} + ArrayDataType output_data_type = ArrayDataType::kInt64; +}; + // ResizeBilinear operator. It resizes input images with bilinear interpolation. // It does not support align_corners at the moment. // @@ -1656,6 +1763,71 @@ struct PowOperator : Operator { PowOperator() : Operator(OperatorType::kPow) {} }; +// Any operator: +// +// Inputs: +// Inputs[0]: required: A boolean input tensor. +// Inputs[1]: required: reduction_indices. +// +// TensorFlow equivalent: tf.reduce_any. +struct AnyOperator : Operator { + AnyOperator() : Operator(OperatorType::kAny) {} + bool keep_dims = false; +}; + +// LogicalAnd operator: +// +// Inputs: +// Inputs[0]: required: A boolean tensor. +// Inputs[1]: required: A boolean tensor. +// +// TensorFlow equivalent: tf.logical_and. +struct LogicalAndOperator : Operator { + LogicalAndOperator() : Operator(OperatorType::kLogicalAnd) {} +}; + +// LogicalNot operator: +// +// Inputs: +// Inputs[0]: required: A boolean tensor. +// +// TensorFlow equivalent: tf.logical_not. +struct LogicalNotOperator : Operator { + LogicalNotOperator() : Operator(OperatorType::kLogicalNot) {} +}; + +// OneHot operator: +// +// Inputs: +// Inputs[0]: required: indices. +// Inputs[1]: required: depth. +// Inputs[2]: required: on_value. +// Inputs[3]: required: off_value. +// +// TensorFlow equivalent: OneHot. +struct OneHotOperator : Operator { + enum Inputs { + INDICES_INPUT = 0, + DEPTH_INPUT = 1, + ON_VALUE_INPUT = 2, + OFF_VALUE_INPUT = 3, + }; + + OneHotOperator() : Operator(OperatorType::kOneHot) {} + int axis = -1; +}; + +// LogicalOr operator: +// +// Inputs: +// Inputs[0]: required: A Bool tensor. +// Inputs[1]: required: A Bool tensor. +// +// TensorFlow equivalent: LogicalOr. +struct LogicalOrOperator : Operator { + LogicalOrOperator() : Operator(OperatorType::kLogicalOr) {} +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand. The 'start' and 'end' values are @@ -1669,46 +1841,6 @@ inline bool operator<(const Alloc& a, const Alloc& b) { return a.start < b.start; } -class Shape { - public: - // For Shape, we stick to half-way encapsulation for now: - // we hide the raw dims_ member, but expose it raw by accessors - // because from some brainstorming, it's not at all easy to - // anticipate which flavor of more hermetic encapsulation would - // actually buy us future-proof-ness without being needlessly - // cumbersome. - Shape() {} - Shape(std::initializer_list dim_list) : dims_(dim_list) {} - - void ReplaceDims(std::initializer_list dim_list) { - dims_ = std::vector(dim_list); - } - - const std::vector& dims() const { return dims_; } - std::vector* mutable_dims() { return &dims_; } - const int dimensions_count() const { return dims_.size(); } - - // We still have that one convenience accessor to avoid - // the awkward double bracket issue: shape.dims()[i]. - int dims(int i) const { - // Always check for out-of-bounds accesses, even in optimized builds where - // standard assertions are disabled. Out-of-bounds access here is a common - // occurrence. - CHECK_GE(i, 0); - CHECK_GT(dims_.size(), i); - return dims_[i]; - } - - bool operator==(const Shape& comp) const { - return (this->dims_ == comp.dims()); - } - - bool operator!=(const Shape& comp) const { return !((*this) == comp); } - - private: - std::vector dims_; -}; - // Array represents an array (either a constant parameter array or an // activations array) in a Model. struct Array { @@ -1840,6 +1972,40 @@ struct Array { // If this is non-null, then these quantization parameters are to be used // to assign a meaning as real numbers to the elements of this array. std::unique_ptr quantization_params; + // narrow_range is a detail of how toco handles FakeQuant operators with + // narrow_range, see + // https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_vars + // + // For more context about what that is useful for, see the big comment in + // graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc + // + // The narrow_range flag applies only to quantized arrays, and changes + // their quantization in the following way when it is set to 'true': + // 1. The computation of {zero_point, scale} from {min, max} needs to be + // amended so that the real min value will get quantized to + // (min_quantized_value + 1) instead of just (min_quantized_value). + // E.g. for uint8 quantization, the real min value should get quantized to + // the uint8 value 1, not 0. + // 2. Quantized values should get clamped to the interval + // [min_quantized_value + 1, max_value]. Equivalently, the + // min_quantized_value should get nudged to (min_quantized_value + 1). + // The reason why 1. does not imply 2. is that real values may not belong to + // the stated [min, max] interval. Concretely, weights recorded at the last + // learning step may not fall in the [min, max] interval recorded over + // previous learning steps, as the values evolve across learning steps. + // + // Rationale why this is directly a field on Array: + // - This can't be just a field on FakeQuantOperator, because + // FakeQuantOperators are gone (DropFakeQuant) before we get to using that + // information (Quantize). We need a place to store that bit in the interim. + // - This can't be in QuantizationParams because we need to record this + // ahead of quantization, and QuantizationParams are only created during + // quantization. + // - This could be in MinMax, but that would be an abuse of what MinMax is + // about, and would break existing code that assumes that a MinMax is just + // a min and a max. Unlike MinMax which is agnostic as to the quantized + // data type, narrow_range refers to values in the quantized data type. + bool narrow_range = false; private: std::unique_ptr array_shape; @@ -1905,7 +2071,7 @@ class Model { std::size_t transient_data_size = 0; // For code-generation only: required alignment of the transient_data buffer std::size_t transient_data_alignment = 0; - // Arithmatic operations performed in the model. + // Arithmetic operations performed in the model. int64 ops_count = 0; private: diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc index 06072d1fcb0612ed8193b3a0be1317923fe95bcc..d34da63e43eee3b48e575c33ddb6c89f7701865c 100644 --- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc @@ -322,6 +322,10 @@ void ReadModelFlagsFromCommandLineFlags( for (int i = 0; i < input_shapes.size(); ++i) { auto* shape = model_flags->mutable_input_arrays(i)->mutable_shape(); shape->clear_dims(); + // Treat an empty input shape as a scalar. + if (input_shapes[i].empty()) { + continue; + } for (const auto& dim_str : absl::StrSplit(input_shapes[i], ',')) { int size; CHECK(absl::SimpleAtoi(dim_str, &size)) diff --git a/tensorflow/contrib/lite/toco/python/BUILD b/tensorflow/contrib/lite/toco/python/BUILD index 93fe756a55d378fa205ff88be5e18aff586e5dca..33c5b164622cee94d7ba16e7b1a3006dbacb9ca9 100644 --- a/tensorflow/contrib/lite/toco/python/BUILD +++ b/tensorflow/contrib/lite/toco/python/BUILD @@ -53,5 +53,8 @@ tf_py_test( data = [ ":toco_from_protos", ], - tags = ["no_pip"], + tags = [ + "no_oss", + "no_pip", + ], ) diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD index 336e94de1ed3238d64f521cf1347acc8f0737de7..ea1fc2827ead7e7442bbf7f569e3ea88c3b0de57 100644 --- a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD @@ -60,6 +60,7 @@ cc_library( tf_cc_test( name = "resolve_svdf_test", srcs = ["resolve_svdf_test.cc"], + tags = ["no_oss"], deps = [ ":cluster", ":cluster_utils", diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD index a02f90988b2863900b6a735fd69aa1975a762338..83e977d7b3b0a4d572faee3ba7e36690896ac8e8 100644 --- a/tensorflow/contrib/lite/toco/tflite/BUILD +++ b/tensorflow/contrib/lite/toco/tflite/BUILD @@ -37,6 +37,7 @@ tf_cc_test( srcs = [ "operator_test.cc", ], + tags = ["no_oss"], deps = [ ":operator", "//tensorflow/contrib/lite/toco:tooling_util", @@ -66,6 +67,7 @@ tf_cc_test( srcs = [ "types_test.cc", ], + tags = ["no_oss"], deps = [ ":types", "@com_google_googletest//:gtest_main", @@ -98,6 +100,7 @@ tf_cc_test( srcs = [ "export_test.cc", ], + tags = ["no_oss"], deps = [ ":export", "//tensorflow/contrib/lite/schema:schema_fbs", @@ -131,6 +134,7 @@ tf_cc_test( srcs = [ "import_test.cc", ], + tags = ["no_oss"], deps = [ ":import", "//tensorflow/contrib/lite:schema_fbs_version", diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc index 19722468079a32b76f6952db6ca818da470a03ac..5ad307af14a0613188482ae17aed491dea06f984 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -336,17 +336,13 @@ void Export( auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map, &builder, &error_summary); - const string fake_quant_operation_name = "FAKE_QUANT"; - - if (error_summary.count(fake_quant_operation_name) != 0) { - LOG(ERROR) - << fake_quant_operation_name - << " operation was not converted. If running quantized make sure you " - "are passing --inference_type=QUANTIZED_UINT8 and values for " - "--std_values and --mean_values."; - // Remove the fake quant operation from the errors, since it shouldn't - // be provided a custom implementation. - error_summary.erase(fake_quant_operation_name); + for (const auto& op : model.operators) { + if (op->type == OperatorType::kFakeQuant) { + LOG(WARNING) << "FAKE_QUANT operation " << LogName(*op) + << " was not converted. If running quantized make sure you " + "are passing --inference_type=QUANTIZED_UINT8 and values " + "for --std_values and --mean_values."; + } } if (!allow_custom_ops && !error_summary.empty()) { // Remove ExpandDims and ReorderAxes from unimplemented list unless they diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc index d1fdbcb8e9131e1d65fa32ca0395bbc17b2014e7..a95937ba0f4f66fedfab6c1528c8dc4e417297d0 100644 --- a/tensorflow/contrib/lite/toco/tflite/export_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc @@ -262,7 +262,7 @@ TEST_F(VersionedOpExportTest, Export) { EXPECT_EQ(1, (*operators)[1]->opcode_index()); } -// TODO(ahentz): tests for tensors, inputs, outpus, opcodes and operators. +// TODO(ahentz): tests for tensors, inputs, outputs, opcodes and operators. } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 7e55ae92bd57447cc821b21b40ba289cb484a9ed..9ff89e9a653173fa1edc691b17b60f709af0a435 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -282,25 +282,31 @@ class DepthToSpace : public CustomOperator { int GetVersion(const Operator& op) const override { return 1; } }; -class FakeQuant : public CustomOperator { +class FakeQuant + : public BuiltinOperator { public: - using CustomOperator::CustomOperator; - void WriteOptions(const TocoOperator& op, - flexbuffers::Builder* fbb) const override { - fbb->Float("min", op.minmax->min); - fbb->Float("max", op.minmax->max); - fbb->Int("num_bits", op.num_bits); + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateFakeQuantOptions( + *builder, op.minmax->min, op.minmax->max, op.num_bits, op.narrow_range); } - void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { auto* minmax = new MinMax; - minmax->min = m["min"].AsFloat(); - minmax->max = m["max"].AsFloat(); + minmax->min = options.min(); + minmax->max = options.max(); op->minmax.reset(minmax); - const auto& num_bits = m["num_bits"]; - op->num_bits = num_bits.IsInt() ? num_bits.AsInt32() : 8; + op->num_bits = options.num_bits(); + op->narrow_range = options.narrow_range(); } - int GetVersion(const Operator& op) const override { return 1; } + int GetVersion(const Operator& op) const override { + const auto& fq_op = static_cast(op); + return fq_op.narrow_range ? 2 : 1; + } }; class FullyConnected @@ -364,12 +370,13 @@ class Gather : public BuiltinOperator WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { - return ::tflite::CreateGatherOptions(*builder, op.axis); + int axis = op.axis ? op.axis.value() : 0; + return ::tflite::CreateGatherOptions(*builder, axis); } void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override { - op->axis = options.axis(); + op->axis = {options.axis()}; } int GetVersion(const Operator& op) const override { return 1; } @@ -761,6 +768,44 @@ class Sum int GetVersion(const Operator& op) const override { return 1; } }; +class ReduceMax + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateReducerOptions(*builder, op.keep_dims); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->keep_dims = options.keep_dims(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + +class ReduceProd + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateReducerOptions(*builder, op.keep_dims); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->keep_dims = options.keep_dims(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + class ResizeBilinear : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateArgMinOptions( + *builder, DataType::Serialize(op.output_data_type)); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->output_data_type = DataType::Deserialize(options.output_type()); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + class TransposeConv : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreatePackOptions(*builder, op.values_count, op.axis); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->values_count = options.values_count(); + op->axis = options.axis(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + class Shape : public BuiltinOperator { @@ -969,6 +1053,44 @@ class Shape int GetVersion(const Operator& op) const override { return 1; } }; +class OneHot : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateOneHotOptions(*builder, op.axis); + } + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->axis = options.axis(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + +class CTCBeamSearchDecoder + : public CustomOperator { + public: + using CustomOperator::CustomOperator; + + void WriteOptions(const TocoOperator& op, + flexbuffers::Builder* fbb) const override { + fbb->Int("beam_width", op.beam_width); + fbb->Int("top_paths", op.top_paths); + fbb->Bool("merge_repeated", op.merge_repeated); + } + + void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { + op->beam_width = m["beam_width"].AsInt32(); + op->top_paths = m["top_paths"].AsInt32(); + op->merge_repeated = m["merge_repeated"].AsBool(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + class TensorFlowUnsupported : public BaseOperator { public: using BaseOperator::BaseOperator; @@ -1078,6 +1200,12 @@ class TensorFlowUnsupported : public BaseOperator { break; case flexbuffers::TYPE_BOOL: (*attr)[key].set_b(value.AsBool()); + if (string(key) == "_output_quantized") { + op->quantized = value.AsBool(); + } + if (string(key) == "_support_output_type_float_in_quantized_op") { + op->support_output_type_float_in_quantized_op = value.AsBool(); + } break; case flexbuffers::TYPE_VECTOR_INT: { auto* list = (*attr)[key].mutable_list(); @@ -1158,6 +1286,10 @@ std::vector> BuildOperatorList() { ops.emplace_back( new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean)); ops.emplace_back(new Sum(::tflite::BuiltinOperator_SUM, OperatorType::kSum)); + ops.emplace_back(new ReduceProd(::tflite::BuiltinOperator_REDUCE_PROD, + OperatorType::kReduceProd)); + ops.emplace_back(new ReduceMax(::tflite::BuiltinOperator_REDUCE_MAX, + OperatorType::kReduceMax)); ops.emplace_back(new ResizeBilinear(::tflite::BuiltinOperator_RESIZE_BILINEAR, OperatorType::kResizeBilinear)); ops.emplace_back( @@ -1174,6 +1306,8 @@ std::vector> BuildOperatorList() { new Cast(::tflite::BuiltinOperator_CAST, OperatorType::kCast)); ops.emplace_back( new ArgMax(::tflite::BuiltinOperator_ARG_MAX, OperatorType::kArgMax)); + ops.emplace_back( + new ArgMin(::tflite::BuiltinOperator_ARG_MIN, OperatorType::kArgMin)); ops.emplace_back( new Tile(::tflite::BuiltinOperator_TILE, OperatorType::kTile)); ops.emplace_back(new ExpandDims(::tflite::BuiltinOperator_EXPAND_DIMS, @@ -1184,11 +1318,18 @@ std::vector> BuildOperatorList() { OperatorType::kSparseToDense)); ops.emplace_back( new Shape(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape)); + ops.emplace_back(new FakeQuant(::tflite::BuiltinOperator_FAKE_QUANT, + OperatorType::kFakeQuant)); + ops.emplace_back( + new Pack(::tflite::BuiltinOperator_PACK, OperatorType::kPack)); + ops.emplace_back( + new OneHot(::tflite::BuiltinOperator_ONE_HOT, OperatorType::kOneHot)); // Custom Operators. ops.emplace_back( new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); - ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant)); + ops.emplace_back(new CTCBeamSearchDecoder( + "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder)); ops.emplace_back(new TensorFlowUnsupported("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported)); @@ -1238,6 +1379,12 @@ std::vector> BuildOperatorList() { ops.emplace_back( new SimpleOperator("SLICE", OperatorType::kSlice)); ops.emplace_back(new SimpleOperator("POW", OperatorType::kPow)); + ops.emplace_back(new SimpleOperator( + "LOGICAL_OR", OperatorType::kLogicalOr)); + ops.emplace_back(new SimpleOperator( + "LOGICAL_AND", OperatorType::kLogicalAnd)); + ops.emplace_back(new SimpleOperator( + "LOGICAL_NOT", OperatorType::kLogicalNot)); // Element-wise operator ops.emplace_back(new SimpleOperator("SIN", OperatorType::kSin)); ops.emplace_back(new SimpleOperator("LOG", OperatorType::kLog)); diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 8b6808d3c78d8c51c1b33d09eb4082326100b028..fc854461b4e816e12e12590479501b6542258fef 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -127,6 +127,12 @@ TEST_F(OperatorTest, SimpleOperators) { CheckSimpleOperator("SQRT", OperatorType::kSqrt); CheckSimpleOperator("RSQRT", OperatorType::kRsqrt); CheckSimpleOperator("POW", OperatorType::kPow); + CheckSimpleOperator("LOGICAL_OR", + OperatorType::kLogicalOr); + CheckSimpleOperator("LOGICAL_AND", + OperatorType::kLogicalAnd); + CheckSimpleOperator("LOGICAL_NOT", + OperatorType::kLogicalNot); } TEST_F(OperatorTest, BuiltinAdd) { @@ -416,6 +422,13 @@ TEST_F(OperatorTest, BuiltinArgMax) { EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type); } +TEST_F(OperatorTest, BuiltinArgMin) { + ArgMinOperator op; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("ARG_MIN", OperatorType::kArgMin), op); + EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type); +} + TEST_F(OperatorTest, BuiltinTransposeConv) { TransposeConvOperator op; op.stride_width = 123; @@ -445,6 +458,38 @@ TEST_F(OperatorTest, BuiltinSparseToDense) { EXPECT_EQ(op.validate_indices, output_toco_op->validate_indices); } +TEST_F(OperatorTest, BuiltinPack) { + PackOperator op; + op.values_count = 3; + op.axis = 1; + std::unique_ptr output_toco_op = + SerializeAndDeserialize(GetOperator("PACK", OperatorType::kPack), op); + EXPECT_EQ(op.values_count, output_toco_op->values_count); + EXPECT_EQ(op.axis, output_toco_op->axis); +} + +TEST_F(OperatorTest, BuiltinOneHot) { + OneHotOperator op; + op.axis = 2; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("ONE_HOT", OperatorType::kOneHot), op); + EXPECT_EQ(op.axis, output_toco_op->axis); +} + +TEST_F(OperatorTest, CustomCTCBeamSearchDecoder) { + CTCBeamSearchDecoderOperator op; + op.beam_width = 3; + op.top_paths = 2; + op.merge_repeated = false; + std::unique_ptr output_toco_op = + SerializeAndDeserialize(GetOperator("CTC_BEAM_SEARCH_DECODER", + OperatorType::kCTCBeamSearchDecoder), + op); + EXPECT_EQ(op.beam_width, output_toco_op->beam_width); + EXPECT_EQ(op.top_paths, output_toco_op->top_paths); + EXPECT_EQ(op.merge_repeated, output_toco_op->merge_repeated); +} + TEST_F(OperatorTest, TensorFlowUnsupported) { TensorFlowUnsupportedOperator op; op.tensorflow_op = "MyCustomUnsupportedOp"; diff --git a/tensorflow/contrib/lite/toco/toco_port.cc b/tensorflow/contrib/lite/toco/toco_port.cc index de76fd4032d24eff8a6c2fd0c16a911b9c00186b..14168fa33f77a75706a52f00ddfa6b1120d90626 100644 --- a/tensorflow/contrib/lite/toco/toco_port.cc +++ b/tensorflow/contrib/lite/toco/toco_port.cc @@ -38,7 +38,8 @@ void CopyToBuffer(const Cord& src, char* dest) { src.CopyToArray(dest); } } // namespace port } // namespace toco -#if defined(PLATFORM_GOOGLE) && !defined(__APPLE__) && !defined(__ANDROID__) +#if defined(PLATFORM_GOOGLE) && !defined(__APPLE__) && \ + !defined(__ANDROID__) && !defined(_WIN32) // Wrap Google file operations. @@ -115,9 +116,12 @@ string JoinPath(const string& a, const string& b) { } // namespace port } // namespace toco -#else // (__APPLE__ || __ANDROID__) +#else // !PLATFORM_GOOGLE || __APPLE__ || __ANDROID__ || _WIN32 #include +#if defined(_WIN32) +#include // for _close, _open, _read +#endif #include #include #include @@ -130,6 +134,19 @@ string JoinPath(const string& a, const string& b) { namespace toco { namespace port { +#if defined(_WIN32) +#define close _close +#define open _open +#define read _read +#define O_RDONLY _O_RDONLY +#define O_CREAT _O_CREAT +#define O_WRONLY _O_WRONLY +// Windows does not support the same set of file permissions as other platforms. +constexpr int kFileCreateMode = _S_IREAD | _S_IWRITE; +#else +constexpr int kFileCreateMode = 0664; +#endif // _WIN32 + static bool port_initialized = false; void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags) { @@ -209,7 +226,7 @@ tensorflow::Status GetContents(const string& path, string* output, tensorflow::Status SetContents(const string& filename, const string& contents, const file::Options& options) { - int fd = open(filename.c_str(), O_WRONLY | O_CREAT, 0664); + int fd = open(filename.c_str(), O_WRONLY | O_CREAT, kFileCreateMode); if (fd == -1) { return tensorflow::errors::Internal("can't open() for write"); } @@ -243,4 +260,4 @@ string JoinPath(const string& base, const string& filename) { } // namespace port } // namespace toco -#endif // (__APPLE || __ANDROID__) +#endif // !PLATFORM_GOOGLE || __APPLE || __ANDROID__ || _WIN32 diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index a057dcef121a9a17b15d0b19ca908d12d89b0367..fcd3cbab07c06737f43d822e5b16f7c188f56b1a 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -55,7 +55,7 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ConvertExpandDimsToReshape); transformations->Add(new ConvertSqueezeToReshape); transformations->Add(new ConvertTrivialAddNToAdd); - transformations->Add(new ConvertTrivialStackToReshape); + transformations->Add(new ConvertTrivialPackToReshape); transformations->Add(new ConvertTrivialTileToConcat); transformations->Add(new ConvertTrivialTransposeToReshape); transformations->Add(new ConvertReorderAxes); @@ -79,17 +79,18 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new FuseBinaryIntoFollowingAffine); transformations->Add(new FuseBroadcastIntoFollowingBinary); transformations->Add(new MergeReshapeIntoPrecedingTranspose); + transformations->Add(new MoveBinaryOperatorBeforeReshape); transformations->Add(new ReorderElementwiseUnary); transformations->Add(new ReorderReshapeTranspose); transformations->Add(new ResolveBatchNormalization); transformations->Add(new ResolveConstantBinaryOperator); transformations->Add(new ResolveConstantFill); transformations->Add(new ResolveConstantGather); + transformations->Add(new ResolveConstantPack); transformations->Add(new ResolveConstantRandomUniform); transformations->Add(new ResolveConstantRange); transformations->Add(new ResolveConstantReshape); transformations->Add(new ResolveConstantSlice); - transformations->Add(new ResolveConstantStack); transformations->Add(new ResolveConstantStridedSlice); transformations->Add(new ResolveConstantTranspose); transformations->Add(new ResolveConstantUnaryOperator); @@ -104,17 +105,19 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new IdentifyRelu1); transformations->Add(new IdentifyPRelu); transformations->Add(new RemoveTrivialBinaryOperator); - transformations->Add(new ReadFakeQuantMinMax); + transformations->Add(new ResolveFakeQuantArgsFromVars); + transformations->Add(new ReadArrayMinmaxAndNarrowRangeFromFakeQuant); transformations->Add(new ResolveSpaceToBatchNDAttributes); transformations->Add(new ResolveBatchToSpaceNDAttributes); transformations->Add(new ResolvePadAttributes); transformations->Add(new ResolvePadV2Attributes); transformations->Add(new ResolveStridedSliceAttributes); transformations->Add(new ResolveSliceAttributes); - transformations->Add(new ResolveMeanAttributes); + transformations->Add(new ResolveReduceAttributes); transformations->Add(new ResolveConstantShapeOrRank); transformations->Add(new MakeInitialDequantizeOperator); transformations->Add(new UnpartitionEmbeddingLookup); + transformations->Add(new ResolveGatherAttributes); } bool SupportsQuantization(FileFormat format) { @@ -272,13 +275,16 @@ void Transform(const TocoFlags& toco_flags, Model* model) { transformations.Add(new toco::MergeLstmCellInputs); } } - if (toco_flags.quantize_weights()) { - transformations.Add(new QuantizeWeights); - } transformations.Add(new ResolveConstantConcatenation); RunGraphTransformations(model, "general graph transformations", transformations); + if (toco_flags.quantize_weights()) { + // Run the quantize weights transformation after batchnorms have been + // folded into the weights. + RunGraphTransformations(model, "quantize weights transformation", + {new QuantizeWeights}); + } if (quantize_output) { if (toco_flags.propagate_fake_quant_num_bits()) { RunGraphTransformations(model, @@ -303,8 +309,9 @@ void Transform(const TocoFlags& toco_flags, Model* model) { // HardcodeMinMax to move changes through the graph as we make changes. auto propagate_default_min_max = absl::make_unique(); - if (toco_flags.has_default_ranges_min() && - toco_flags.has_default_ranges_max()) { + bool has_default_ranges_flag = (toco_flags.has_default_ranges_min() && + toco_flags.has_default_ranges_max()); + if (has_default_ranges_flag) { propagate_default_min_max->DefineTypeRange( ArrayDataType::kUint8, toco_flags.default_ranges_min(), toco_flags.default_ranges_max()); @@ -329,6 +336,8 @@ void Transform(const TocoFlags& toco_flags, Model* model) { new EnsureUint8WeightsSafeForFastInt8Kernels; ensure_safe_for_int8_kernels->set_allow_nudging_weights( toco_flags.allow_nudging_weights_to_use_fast_gemm_kernel()); + ensure_safe_for_int8_kernels->set_has_default_ranges_flag( + has_default_ranges_flag); RunGraphTransformations(model, "quantization graph transformations", { new RemoveTrivialQuantizedActivationFunc, diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 7dc1af9f1dc13fdc0f166e12bfd616cfdacc06c9..2ad27198119b4a8150a7381c047a4edb51aebfe6 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -350,16 +350,17 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Less) HANDLE_OPERATORTYPENAME_CASE(LessEqual) HANDLE_OPERATORTYPENAME_CASE(MatMul) - HANDLE_OPERATORTYPENAME_CASE(Max) // Reduction Max - HANDLE_OPERATORTYPENAME_CASE(Maximum) // Element-wise Maximum + HANDLE_OPERATORTYPENAME_CASE(ReduceMax) // Reduction Max + HANDLE_OPERATORTYPENAME_CASE(Maximum) // Element-wise Maximum HANDLE_OPERATORTYPENAME_CASE(Merge) - HANDLE_OPERATORTYPENAME_CASE(Min) // Reduction Min - HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum + HANDLE_OPERATORTYPENAME_CASE(ReduceMin) // Reduction Min + HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum HANDLE_OPERATORTYPENAME_CASE(Neg) + HANDLE_OPERATORTYPENAME_CASE(OneHot) + HANDLE_OPERATORTYPENAME_CASE(Pack) HANDLE_OPERATORTYPENAME_CASE(Pad) HANDLE_OPERATORTYPENAME_CASE(PadV2) HANDLE_OPERATORTYPENAME_CASE(StridedSlice) - HANDLE_OPERATORTYPENAME_CASE(Stack) HANDLE_OPERATORTYPENAME_CASE(Range) HANDLE_OPERATORTYPENAME_CASE(Rank) HANDLE_OPERATORTYPENAME_CASE(Reshape) @@ -385,8 +386,10 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(SpaceToBatchND) HANDLE_OPERATORTYPENAME_CASE(BatchToSpaceND) HANDLE_OPERATORTYPENAME_CASE(Mean) + HANDLE_OPERATORTYPENAME_CASE(ReduceProd) HANDLE_OPERATORTYPENAME_CASE(Svdf) HANDLE_OPERATORTYPENAME_CASE(ArgMax) + HANDLE_OPERATORTYPENAME_CASE(ArgMin) HANDLE_OPERATORTYPENAME_CASE(TopK_V2) HANDLE_OPERATORTYPENAME_CASE(Unsupported) HANDLE_OPERATORTYPENAME_CASE(Exp) @@ -397,6 +400,11 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Equal) HANDLE_OPERATORTYPENAME_CASE(NotEqual) HANDLE_OPERATORTYPENAME_CASE(Pow) + HANDLE_OPERATORTYPENAME_CASE(Any) + HANDLE_OPERATORTYPENAME_CASE(LogicalAnd) + HANDLE_OPERATORTYPENAME_CASE(LogicalNot) + HANDLE_OPERATORTYPENAME_CASE(LogicalOr) + HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE @@ -447,8 +455,12 @@ void LogSummary(int log_level, const Model& model) { } void LogArray(int log_level, const Model& model, const string& name) { - const auto& array = model.GetArray(name); VLOG(log_level) << "Array: " << name; + if (!model.HasArray(name)) { + VLOG(log_level) << " DOES NOT EXIST"; + return; + } + const auto& array = model.GetArray(name); VLOG(log_level) << " Data type: " << ArrayDataTypeName(array.data_type); VLOG(log_level) << " Final type: " << ArrayDataTypeName(array.final_data_type); @@ -590,14 +602,33 @@ void UnextendShape(Shape* shape, int new_shape_size) { shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction); } -bool IsValid(const Shape& shape) { +// In general, zero-sized dimensions are disallowed, but there are exceptions, +// e.g., if the tensor data itself represents a scalar (rank 0) shape, its +// shape will have dimensions [0]. CheckNonEmptyShapeDimensions is more +// strict, and is appropriate for ops and comparisons where an empty shape +// doesn't make sense. +template +void CheckValidShapeDimensions(const Dims& dims) { + if (dims.size() == 1 && dims[0] == 0) { + return; + } + for (const auto& dim : dims) { + CHECK_GE(dim, 1); + } +} + +void CheckValidShape(const Shape& shape) { + CheckValidShapeDimensions(shape.dims()); +} + +bool IsNonEmpty(const Shape& shape) { for (int i = 0; i < shape.dimensions_count(); ++i) { if (shape.dims(i) < 1) return false; } return true; } -void CheckShapeDimensions(const Shape& shape) { +void CheckNonEmptyShapeDimensions(const Shape& shape) { for (int i = 0; i < shape.dimensions_count(); ++i) { CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index << " << i << ". shape = " << ShapeToString(shape); @@ -605,8 +636,8 @@ void CheckShapeDimensions(const Shape& shape) { } bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) { - CheckShapeDimensions(shape0); - CheckShapeDimensions(shape1); + CheckNonEmptyShapeDimensions(shape0); + CheckNonEmptyShapeDimensions(shape1); const Shape* longer = &shape0; const Shape* shorter = &shape1; @@ -633,8 +664,8 @@ bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) { } bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) { - CheckShapeDimensions(shape0); - CheckShapeDimensions(shape1); + CheckNonEmptyShapeDimensions(shape0); + CheckNonEmptyShapeDimensions(shape1); const Shape* longer = &shape0; const Shape* shorter = &shape1; @@ -671,9 +702,9 @@ bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) { } int RequiredBufferSizeForShape(const Shape& shape) { + CheckValidShape(shape); int max_offset = 1; for (const auto& dim : shape.dims()) { - CHECK_GE(dim, 1); max_offset *= dim; } return max_offset; @@ -934,9 +965,7 @@ void CheckEachArray(const Model& model) { // shape. CHECK(array->has_shape()); // Constant buffer should has a valid shape. - for (int d : array->shape().dims()) { - CHECK_GE(d, 1); - } + CheckValidShape(array->shape()); // The shape flat-size should agree with the buffer length. CHECK_EQ(array->buffer->Length(), RequiredBufferSizeForShape(array->shape())); @@ -1261,8 +1290,13 @@ void InsertCopyOperator(Model* model, const string& source_array_name, auto* copy_op = new TensorFlowReshapeOperator; copy_op->inputs = { source_array_name, - CreateInt32Array(model, target_array_name + "_copy_shape", shape)}; + CreateInt32Array( + model, AvailableArrayName(*model, target_array_name + "_copy_shape"), + shape)}; copy_op->outputs = {target_array_name}; + if (target_array.has_shape()) { + copy_op->shape = target_array.shape().dims(); + } model->operators.emplace_back(copy_op); } @@ -1523,8 +1557,8 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { if (!input_array.has_shape()) { if (input_array_proto.has_shape()) { auto& input_array_dims = *input_array.mutable_shape()->mutable_dims(); + CheckValidShapeDimensions(input_array_proto.shape().dims()); for (auto dim : input_array_proto.shape().dims()) { - CHECK_GE(dim, 1); input_array_dims.push_back(dim); } } @@ -1567,11 +1601,6 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { model); } - for (const auto& input_array : model->flags.input_arrays()) { - if (input_array.has_shape()) { - CHECK(input_array.shape().dims_size()); - } - } model->flags.set_change_concat_input_ranges( model_flags.change_concat_input_ranges()); model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays()); @@ -1604,11 +1633,12 @@ void CheckIsReadyForQuantization(const Model& model) { << "Array " << input << ", which is an input to the " << HelpfulOperatorTypeName(*op) << " operator producing the output " << "array " << op->outputs[0] << ", is lacking min/max data, " - << "which is necessary for quantization. Either target a " - << "non-quantized output format, or change the input graph to " - << "contain min/max information, or pass --default_ranges_min= and " - << "--default_ranges_max= if you do not care about the accuracy of " - << "results."; + << "which is necessary for quantization. If accuracy matters, either " + << "target a non-quantized output format, or run quantized training " + << "with your model from a floating point checkpoint to change the " + << "input graph to contain min/max information. If you don't care " + << "about accuracy, you can pass --default_ranges_min= and " + << "--default_ranges_max= for easy experimentation."; } } } diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index 5dbfa54fa0369676dce638aec171b409a468da9f..b99e6111fe92be178b5ff8b83477f1ce10c20926 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -115,10 +115,9 @@ void ExtendShape(Shape* shape, int new_shape_size); // TODO(b/36075966): Clean up when dims superseded by array shape. void UnextendShape(Shape* shape, int new_shape_size); -// Checks that all dimensions of 'shape' are at least 1. -bool IsValid(const Shape& shape); -// Same as above, but reports error using CHECK. -void CheckShapeDimensions(const Shape& shape); +// Checks that all dimensions of 'shape' are at least 1. Note that scalars, +// lacking dimensions, satisfy this condition and are considered non-empty. +bool IsNonEmpty(const Shape& shape); // Given two shapes with potentially different dimensionality and dimension // arrays d0 and d1. Without loss of generality, assume that shape0 may have diff --git a/tensorflow/contrib/lite/toco/tooling_util_test.cc b/tensorflow/contrib/lite/toco/tooling_util_test.cc index 8609e5beddd200be4e5ebfe1fb2a79048e0e60ab..eb495646a2df0d0295eab54fcc5a5bf156a59d39 100644 --- a/tensorflow/contrib/lite/toco/tooling_util_test.cc +++ b/tensorflow/contrib/lite/toco/tooling_util_test.cc @@ -39,6 +39,8 @@ std::vector CreateShapePairs() { {Shape({256, 256, 3}), Shape({256, 256, 3}), Agreement::kBroadcast}, {Shape({256, 256, 3}), Shape({3}), Agreement::kBroadcast}, {Shape({8, 1, 6, 1}), Shape({7, 1, 5}), Agreement::kBroadcast}, + {Shape({}), Shape({3}), Agreement::kBroadcast}, + {Shape({}), Shape({3, 1}), Agreement::kBroadcast}, // These extend (and therefore broadcast). {Shape({3}), Shape({3}), Agreement::kExtend}, @@ -54,6 +56,7 @@ std::vector CreateShapePairs() { {Shape({15, 3, 5}), Shape({15, 1, 5}), Agreement::kBroadcastNotExtend}, {Shape({15, 3, 5}), Shape({3, 5}), Agreement::kBroadcastNotExtend}, {Shape({15, 3, 5}), Shape({3, 1}), Agreement::kBroadcastNotExtend}, + {Shape({3, 1}), Shape({}), Agreement::kBroadcastNotExtend}, // These do not broadcast (and therefore also do not extend). {Shape({3}), Shape({4}), Agreement::kNeither}, @@ -175,6 +178,20 @@ TEST(NumElementsTest, UnsignedInt64) { EXPECT_EQ(status.error_message(), kLargeTensorMessage); } +TEST(NumElementsTest, Scalar) { + tensorflow::Status status = tensorflow::Status::OK(); + + int32_t count; + status = NumElements(std::vector{}, &count); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(count, 1); + + uint64_t countu64; + status = NumElements(std::vector{}, &countu64); + EXPECT_TRUE(status.ok()); + EXPECT_EQ(countu64, 1ULL); +} + TEST(FusedActivationTest, DefaultsToUnfused) { EXPECT_TRUE(OperatorSupportsFusedActivation(OperatorType::kAdd)); EXPECT_FALSE(OperatorSupportsFusedActivation(OperatorType::kNone)); diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD index 5913847329eeae7373d0d21834dd37327e4068c4..0b268264031f4f1e86b2956a75bde173a945ddf4 100644 --- a/tensorflow/contrib/lite/tools/BUILD +++ b/tensorflow/contrib/lite/tools/BUILD @@ -14,6 +14,7 @@ py_binary( srcs = ["visualize.py"], data = [ "//tensorflow/contrib/lite/schema:schema.fbs", + "//tensorflow/python:platform", "@flatbuffers//:flatc", ], srcs_version = "PY2AND3", @@ -52,7 +53,9 @@ cc_test( "//tensorflow/contrib/lite:testdata/test_model_broken.bin", ], tags = [ + "no_oss", "tflite_not_portable_android", + "tflite_not_portable_ios", ], deps = [ ":gen_op_registration", @@ -77,6 +80,7 @@ cc_test( size = "small", srcs = ["verifier_test.cc"], tags = [ + "no_oss", "tflite_not_portable", ], deps = [ diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/contrib/lite/tools/benchmark/BUILD index 183a545295f690decec47f1c31aa473667408a3d..2cb07eb6ec9405a5fefec9cc49f3b1aaff663e4b 100644 --- a/tensorflow/contrib/lite/tools/benchmark/BUILD +++ b/tensorflow/contrib/lite/tools/benchmark/BUILD @@ -10,11 +10,16 @@ load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") common_copts = ["-Wall"] + tflite_copts() +cc_library( + name = "logging", + hdrs = ["logging.h"], + copts = common_copts, +) + cc_binary( name = "benchmark_model", srcs = [ "benchmark_main.cc", - "logging.h", ], copts = common_copts, linkopts = tflite_linkopts() + select({ @@ -26,6 +31,26 @@ cc_binary( }), deps = [ ":benchmark_tflite_model_lib", + ":logging", + ], +) + +cc_test( + name = "benchmark_test", + srcs = ["benchmark_test.cc"], + args = [ + "--graph=$(location //tensorflow/contrib/lite:testdata/multi_add.bin)", + ], + data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"], + tags = [ + "tflite_not_portable_android", + "tflite_not_portable_ios", + ], + deps = [ + ":benchmark_tflite_model_lib", + ":command_line_flags", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", ], ) @@ -58,6 +83,7 @@ cc_library( copts = common_copts, deps = [ ":benchmark_model_lib", + ":logging", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite/kernels:builtin_ops", @@ -70,23 +96,23 @@ cc_library( name = "benchmark_params", srcs = [ "benchmark_params.cc", - "logging.h", ], hdrs = ["benchmark_params.h"], copts = common_copts, + deps = [":logging"], ) cc_library( name = "benchmark_model_lib", srcs = [ "benchmark_model.cc", - "logging.h", ], hdrs = ["benchmark_model.h"], copts = common_copts, deps = [ ":benchmark_params", ":command_line_flags", + ":logging", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite/kernels:builtin_ops", diff --git a/tensorflow/contrib/lite/tools/benchmark/README.md b/tensorflow/contrib/lite/tools/benchmark/README.md index 93769305bde210b58f3b2cb668a9d8c1ad0ce396..f1e257ad104885a23cd7f17b9c21317c0881ccc0 100644 --- a/tensorflow/contrib/lite/tools/benchmark/README.md +++ b/tensorflow/contrib/lite/tools/benchmark/README.md @@ -115,7 +115,7 @@ E.g. for running the benchmark on big cores on Pixel 2 with a single thread one can use the following command: ``` -adb shell tasket f0 /data/local/tmp/benchmark_model \ +adb shell taskset f0 /data/local/tmp/benchmark_model \ --graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \ --input_layer="input" \ --input_layer_shape="1,224,224,3" \ diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc index 08648bcfe26365d180d984fde8f8e04b22eb45dd..f86c0445b0525cd053c733b18bb7f1205d310d43 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc @@ -84,7 +84,7 @@ std::vector BenchmarkModel::GetFlags() { }; } -void BenchmarkModel::LogFlags() { +void BenchmarkModel::LogParams() { TFLITE_LOG(INFO) << "Num runs: [" << params_.Get("num_runs") << "]"; TFLITE_LOG(INFO) << "Inter-run delay (seconds): [" << params_.Get("run_delay") << "]"; @@ -98,10 +98,13 @@ void BenchmarkModel::LogFlags() { << "]"; } +void BenchmarkModel::PrepareInputsAndOutputs() {} + Stat BenchmarkModel::Run(int num_times, RunType run_type) { Stat run_stats; TFLITE_LOG(INFO) << "Running benchmark for " << num_times << " iterations "; for (int run = 0; run < num_times; run++) { + PrepareInputsAndOutputs(); listeners_.OnSingleRunStart(run_type); int64_t start_us = profiling::time::NowMicros(); RunImpl(); @@ -119,12 +122,18 @@ Stat BenchmarkModel::Run(int num_times, RunType run_type) { return run_stats; } +bool BenchmarkModel::ValidateParams() { return true; } + void BenchmarkModel::Run(int argc, char **argv) { if (!ParseFlags(argc, argv)) { return; } + Run(); +} - LogFlags(); +void BenchmarkModel::Run() { + ValidateParams(); + LogParams(); listeners_.OnBenchmarkStart(params_); int64_t initialization_start_us = profiling::time::NowMicros(); @@ -152,7 +161,7 @@ bool BenchmarkModel::ParseFlags(int argc, char **argv) { TFLITE_LOG(ERROR) << usage; return false; } - return ValidateFlags(); + return true; } } // namespace benchmark diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h index 942e21f67a7f864f16b7b1b85b2599d5c872b5c7..677a1ee68c247fb016c7ede4e1a614bacb7a0a15 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h @@ -137,19 +137,21 @@ class BenchmarkModel { BenchmarkModel(); BenchmarkModel(BenchmarkParams params) : params_(std::move(params)) {} virtual ~BenchmarkModel() {} - bool ParseFlags(int argc, char** argv); virtual void Init() = 0; void Run(int argc, char** argv); + virtual void Run(); void AddListener(BenchmarkListener* listener) { listeners_.AddListener(listener); } protected: - virtual void LogFlags(); - virtual bool ValidateFlags() { return true; } + virtual void LogParams(); + virtual bool ValidateParams(); + bool ParseFlags(int argc, char** argv); virtual std::vector GetFlags(); virtual uint64_t ComputeInputBytes() = 0; virtual tensorflow::Stat Run(int num_times, RunType run_type); + virtual void PrepareInputsAndOutputs(); virtual void RunImpl() = 0; BenchmarkParams params_; BenchmarkListeners listeners_; diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h index 33448dd1623577fdfda6316c588cc60ccbaa1994..c98f47bb0d89864dff54d7cdebe764e56e4cfda2 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h @@ -31,6 +31,8 @@ class TypedBenchmarkParam; class BenchmarkParam { protected: enum class ParamType { TYPE_INT32, TYPE_FLOAT, TYPE_BOOL, TYPE_STRING }; + template + static ParamType GetValueType(); public: template @@ -49,8 +51,6 @@ class BenchmarkParam { private: static void AssertHasSameType(ParamType a, ParamType b); - template - static ParamType GetValueType(); const ParamType type_; }; diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_test.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b697bb394db9b967dfaaff649517dcc23e85ccb0 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_test.cc @@ -0,0 +1,74 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h" +#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h" + +namespace { +const std::string* g_model_path = nullptr; +} + +namespace tflite { +namespace benchmark { +namespace { + +BenchmarkParams CreateParams() { + BenchmarkParams params; + params.AddParam("num_runs", BenchmarkParam::Create(2)); + params.AddParam("run_delay", BenchmarkParam::Create(-1.0f)); + params.AddParam("num_threads", BenchmarkParam::Create(1)); + params.AddParam("benchmark_name", BenchmarkParam::Create("")); + params.AddParam("output_prefix", BenchmarkParam::Create("")); + params.AddParam("warmup_runs", BenchmarkParam::Create(1)); + params.AddParam("graph", BenchmarkParam::Create(*g_model_path)); + params.AddParam("input_layer", BenchmarkParam::Create("")); + params.AddParam("input_layer_shape", BenchmarkParam::Create("")); + params.AddParam("use_nnapi", BenchmarkParam::Create(false)); + return params; +} + +TEST(BenchmarkTest, DoesntCrash) { + ASSERT_THAT(g_model_path, testing::NotNull()); + + BenchmarkTfLiteModel benchmark(CreateParams()); + benchmark.Run(); +} + +} // namespace +} // namespace benchmark +} // namespace tflite + +int main(int argc, char** argv) { + std::string model_path; + std::vector flags = { + tflite::Flag::CreateFlag("graph", &model_path, "Path to model file.")}; + g_model_path = &model_path; + const bool parse_result = + tflite::Flags::Parse(&argc, const_cast(argv), flags); + if (!parse_result) { + std::cerr << tflite::Flags::Usage(argv[0], flags); + return 1; + } + + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc index 73affc26b034f415ae2a2101e0b558cdb94d8d5b..7f97f5d0cd6c412653f6d510406daf86b7baa3f7 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc @@ -198,8 +198,8 @@ std::vector BenchmarkTfLiteModel::GetFlags() { return flags; } -void BenchmarkTfLiteModel::LogFlags() { - BenchmarkModel::LogFlags(); +void BenchmarkTfLiteModel::LogParams() { + BenchmarkModel::LogParams(); TFLITE_LOG(INFO) << "Graph: [" << params_.Get("graph") << "]"; TFLITE_LOG(INFO) << "Input layers: [" << params_.Get("input_layer") << "]"; @@ -208,7 +208,7 @@ void BenchmarkTfLiteModel::LogFlags() { TFLITE_LOG(INFO) << "Use nnapi : [" << params_.Get("use_nnapi") << "]"; } -bool BenchmarkTfLiteModel::ValidateFlags() { +bool BenchmarkTfLiteModel::ValidateParams() { if (params_.Get("graph").empty()) { TFLITE_LOG(ERROR) << "Please specify the name of your TF Lite input file with --graph"; diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h index 50cc3f24b3bd2f31555eac69ff208fa2480449b9..9931dcbafe06cb9f8673462858244f6f2793b29d 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h @@ -54,8 +54,8 @@ class BenchmarkTfLiteModel : public BenchmarkModel { BenchmarkTfLiteModel(BenchmarkParams params); std::vector GetFlags() override; - void LogFlags() override; - bool ValidateFlags() override; + void LogParams() override; + bool ValidateParams() override; uint64_t ComputeInputBytes() override; void Init() override; void RunImpl() override; diff --git a/tensorflow/contrib/lite/tools/visualize.py b/tensorflow/contrib/lite/tools/visualize.py index f571dd59da0a3f4aff264b48fba3e41f75b50404..597dede63b0c089da21f4b0ede065189d8bbe1d8 100644 --- a/tensorflow/contrib/lite/tools/visualize.py +++ b/tensorflow/contrib/lite/tools/visualize.py @@ -28,11 +28,24 @@ import json import os import sys +from tensorflow.python.platform import resource_loader + # Schema to use for flatbuffers _SCHEMA = "third_party/tensorflow/contrib/lite/schema/schema.fbs" -# Where the binary will be once built in for the flatc converter -_BINARY = "third_party/flatbuffers/flatc" +# TODO(angerson): fix later when rules are simplified.. +_SCHEMA = resource_loader.get_path_to_datafile("../schema/schema.fbs") +_BINARY = resource_loader.get_path_to_datafile("../../../../flatbuffers/flatc") +# Account for different package positioning internal vs. external. +if not os.path.exists(_BINARY): + _BINARY = resource_loader.get_path_to_datafile( + "../../../../../flatbuffers/flatc") + +if not os.path.exists(_SCHEMA): + raise RuntimeError("Sorry, schema file cannot be found at %r" % _SCHEMA) +if not os.path.exists(_BINARY): + raise RuntimeError("Sorry, flatc is not available at %r" % _BINARY) + # A CSS description for making the visualizer _CSS = """ @@ -321,7 +334,7 @@ def CreateHtmlFile(tflite_input, html_output): for key, mapping in toplevel_stuff: if not mapping: mapping = lambda x: x - html += "

\n" % (key, mapping(data[key])) + html += "\n" % (key, mapping(data.get(key))) html += "
- Mobilenet_1.0_224(float) + Mobilenet_1.0_224(float) Pixel 2 166.5 ms (2.6 ms)
- Mobilenet_1.0_224 (quant) + Mobilenet_1.0_224 (quant) Pixel 2 69.5 ms (0.9 ms)
- Mobilenet_1.0_224(float) + Mobilenet_1.0_224(float) iPhone 8 32.2 ms (0.8 ms)
- Mobilenet_1.0_224 (quant) + Mobilenet_1.0_224 (quant) iPhone 8 24.4 ms (0.8 ms)
%s%s
%s%s
\n" # Spec on what keys to display diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h index 89d9b4f5cffa99e708f391fd8fe19208009b5e79..3c4801183bad834e5789c97a56416cdf4668f897 100644 --- a/tensorflow/contrib/lite/util.h +++ b/tensorflow/contrib/lite/util.h @@ -26,12 +26,17 @@ limitations under the License. namespace tflite { -// Converts a `std::vector` to a `TfLiteIntArray`. +// Converts a `std::vector` to a `TfLiteIntArray`. The caller takes ownership +// of the returned pointer. TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector& input); +// Converts an array (of the given size) to a `TfLiteIntArray`. The caller +// takes ownership of the returned pointer, and must make sure 'dims' has at +// least 'rank' elemnts. TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int rank, const int* dims); // Checks whether a `TfLiteIntArray` and an int array have matching elements. +// The caller must guarantee that 'b' has at least 'b_size' elements. bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size, const int* b); diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 889accdd5aafae2931048ffdd26408cccb3c874e..8d510ede5827df3889307c0f38572bece84f102e 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -280,6 +280,21 @@ class HashTableOpTest(test.TestCase): table.init.run() self.assertAllEqual(3, table.size().eval()) + def testHashTableInt32String(self): + with self.test_session(): + default_val = "n/a" + keys = constant_op.constant([0, 1, 2], dtypes.int32) + values = constant_op.constant(["brain", "salad", "surgery"]) + table = lookup.HashTable( + lookup.KeyValueTensorInitializer(keys, values), default_val) + table.init.run() + + input_tensor = constant_op.constant([0, 1, -1]) + output = table.lookup(input_tensor) + + result = output.eval() + self.assertAllEqual([b"brain", b"salad", b"n/a"], result) + class MutableHashTableOpTest(test.TestCase): diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index 48953e2e3843ff92744514d28bd725cc0d72f3a8..448ae6d22e65fcd9129e27e6321d3081abf7d1ac 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -30,7 +30,11 @@ EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" -PROTOBUF_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" +# Note: The Protobuf source in `tensorflow/workspace.bzl` in TensorFlow +# 1.10 branch does not work. `make distclean` fails and blocks the build +# process. For now we're hardcoding to the version which is used by +# TensorFlow 1.9. +PROTOBUF_URL="https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz" RE2_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" DOUBLE_CONVERSION_URL="$(grep -o "https.*google/double-conversion.*\.zip" "${BZL_FILE_PATH}" | head -n1)" diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt index 76428bc1d4e682e000998a6e28fc290e218c2341..7d26429f9c3b26bcd8819e92cbc15daed60ea9f4 100644 --- a/tensorflow/contrib/makefile/proto_text_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt @@ -35,6 +35,7 @@ tensorflow/core/lib/random/random.cc tensorflow/core/lib/random/distribution_sampler.cc tensorflow/core/lib/io/zlib_outputbuffer.cc tensorflow/core/lib/io/zlib_inputstream.cc +tensorflow/core/lib/io/zlib_compression_options.cc tensorflow/core/lib/io/two_level_iterator.cc tensorflow/core/lib/io/table_builder.cc tensorflow/core/lib/io/table.cc diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 89db9ee2794ddf0a99951dca327e74c5d9694d23..ecf2e120df98d82cca068e186f95e91e71ebc66d 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -92,6 +92,7 @@ tensorflow/core/kernels/reduction_ops_common.cc tensorflow/core/kernels/reduction_ops_any.cc tensorflow/core/kernels/reduction_ops_all.cc tensorflow/core/kernels/roll_op.cc +tensorflow/core/kernels/queue_op.cc tensorflow/core/kernels/queue_ops.cc tensorflow/core/kernels/queue_base.cc tensorflow/core/kernels/pooling_ops_common.cc @@ -228,6 +229,8 @@ tensorflow/core/kernels/cast_op_impl_int32.cc tensorflow/core/kernels/cast_op_impl_int64.cc tensorflow/core/kernels/cast_op_impl_int8.cc tensorflow/core/kernels/cast_op_impl_uint16.cc +tensorflow/core/kernels/cast_op_impl_uint32.cc +tensorflow/core/kernels/cast_op_impl_uint64.cc tensorflow/core/kernels/cast_op_impl_uint8.cc tensorflow/core/kernels/boosted_trees/prediction_ops.cc tensorflow/core/kernels/boosted_trees/resource_ops.cc diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD index 66cb493e5c5bb9b8645e87dc7f5b274d916f64fc..21cd34f73ffbbf615a81c18b9d365bffa61397f4 100644 --- a/tensorflow/contrib/metrics/BUILD +++ b/tensorflow/contrib/metrics/BUILD @@ -31,6 +31,7 @@ py_library( "//tensorflow/python:check_ops", "//tensorflow/python:confusion_matrix", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:distribute", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:histogram_ops", "//tensorflow/python:init_ops", diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py index 5effea3596bb83a08e0a8627e411684262aef5f7..88798d61b71388de63e492ba69284a72303d32ab 100644 --- a/tensorflow/contrib/metrics/__init__.py +++ b/tensorflow/contrib/metrics/__init__.py @@ -63,6 +63,7 @@ See the @{$python/contrib.metrics} guide. @@aggregate_metrics @@aggregate_metric_map @@confusion_matrix +@@f1_score @@set_difference @@set_intersection @@set_size diff --git a/tensorflow/contrib/metrics/python/metrics/classification.py b/tensorflow/contrib/metrics/python/metrics/classification.py index 26aba1cc51446e589856013d69526007fbe9d921..e5536122698a50852c4cb96f12ce52ab5d5f6e39 100644 --- a/tensorflow/contrib/metrics/python/metrics/classification.py +++ b/tensorflow/contrib/metrics/python/metrics/classification.py @@ -22,6 +22,9 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics_impl +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import distribute as distribute_lib # TODO(nsilberman): move into metrics/python/ops/ @@ -62,3 +65,121 @@ def accuracy(predictions, labels, weights=None, name=None): return math_ops.div(math_ops.reduce_sum(is_correct), math_ops.reduce_sum(num_values)) return math_ops.reduce_mean(is_correct) + + +def f1_score(labels, predictions, weights=None, num_thresholds=200, + metrics_collections=None, updates_collections=None, name=None): + """Computes the approximately best F1-score across different thresholds. + + The f1_score function applies a range of thresholds to the predictions to + convert them from [0, 1] to bool. Precision and recall are computed by + comparing them to the labels. The F1-Score is then defined as + 2 * precision * recall / (precision + recall). The best one across the + thresholds is returned. + + Disclaimer: In practice it may be desirable to choose the best threshold on + the validation set and evaluate the F1 score with this threshold on a + separate test set. Or it may be desirable to use a fixed threshold (e.g. 0.5). + + This function internally creates four local variables, `true_positives`, + `true_negatives`, `false_positives` and `false_negatives` that are used to + compute the pairs of recall and precision values for a linearly spaced set of + thresholds from which the best f1-score is derived. + + This value is ultimately returned as `f1-score`, an idempotent operation that + computes the F1-score (computed using the aforementioned variables). The + `num_thresholds` variable controls the degree of discretization with larger + numbers of thresholds more closely approximating the true best F1-score. + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables and returns the F1-score. + + Example usage with a custom estimator: + def model_fn(features, labels, mode): + predictions = make_predictions(features) + loss = make_loss(predictions, labels) + train_op = tf.contrib.training.create_train_op( + total_loss=loss, + optimizer='Adam') + eval_metric_ops = {'f1': f1_score(labels, predictions)} + return tf.estimator.EstimatorSpec( + mode=mode, + predictions=predictions, + loss=loss, + train_op=train_op, + eval_metric_ops=eval_metric_ops, + export_outputs=export_outputs) + estimator = tf.estimator.Estimator(model_fn=model_fn) + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + labels: A `Tensor` whose shape matches `predictions`. Will be cast to + `bool`. + predictions: A floating point `Tensor` of arbitrary shape and whose values + are in the range `[0, 1]`. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + num_thresholds: The number of thresholds to use when discretizing the roc + curve. + metrics_collections: An optional list of collections that `f1_score` should + be added to. + updates_collections: An optional list of collections that `update_op` should + be added to. + name: An optional variable_scope name. + + Returns: + f1_score: A scalar `Tensor` representing the current best f1-score across + different thresholds. + update_op: An operation that increments the `true_positives`, + `true_negatives`, `false_positives` and `false_negatives` variables + appropriately and whose value matches the `f1_score`. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + """ + with variable_scope.variable_scope( + name, 'f1', (labels, predictions, weights)): + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access + predictions=predictions, labels=labels, weights=weights) + # To account for floating point imprecisions / avoid division by zero. + epsilon = 1e-7 + thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) + for i in range(num_thresholds - 2)] + thresholds = [0.0 - epsilon] + thresholds + [1.0 + epsilon] + + # Confusion matrix. + values, update_ops = metrics_impl._confusion_matrix_at_thresholds( # pylint: disable=protected-access + labels, predictions, thresholds, weights, includes=('tp', 'fp', 'fn')) + + # Compute precision and recall at various thresholds. + def compute_best_f1_score(tp, fp, fn, name): + precision_at_t = math_ops.div(tp, epsilon + tp + fp, + name='precision_' + name) + recall_at_t = math_ops.div(tp, epsilon + tp + fn, name='recall_' + name) + # Compute F1 score. + f1_at_thresholds = ( + 2.0 * precision_at_t * recall_at_t / + (precision_at_t + recall_at_t + epsilon)) + return math_ops.reduce_max(f1_at_thresholds) + + def f1_across_towers(_, values): + best_f1 = compute_best_f1_score(tp=values['tp'], fp=values['fp'], + fn=values['fn'], name='value') + if metrics_collections: + ops.add_to_collections(metrics_collections, best_f1) + return best_f1 + + best_f1 = distribute_lib.get_tower_context().merge_call( + f1_across_towers, values) + + update_op = compute_best_f1_score(tp=update_ops['tp'], fp=update_ops['fp'], + fn=update_ops['fn'], name='update') + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return best_f1, update_op diff --git a/tensorflow/contrib/metrics/python/metrics/classification_test.py b/tensorflow/contrib/metrics/python/metrics/classification_test.py index fa0f12d029620ad6427f715f035ff69f15c133e7..3d0b81c1bed02dae013141367fb052e16d31fe08 100644 --- a/tensorflow/contrib/metrics/python/metrics/classification_test.py +++ b/tensorflow/contrib/metrics/python/metrics/classification_test.py @@ -18,9 +18,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.metrics.python.metrics import classification +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -108,5 +115,200 @@ class ClassificationTest(test.TestCase): self.assertEqual(result, 0.5) +class F1ScoreTest(test.TestCase): + + def setUp(self): + super(F1ScoreTest, self).setUp() + np.random.seed(1) + + def testVars(self): + classification.f1_score( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + num_thresholds=3) + expected = {'f1/true_positives:0', 'f1/false_positives:0', + 'f1/false_negatives:0'} + self.assertEquals( + expected, set(v.name for v in variables.local_variables())) + self.assertEquals( + set(expected), set(v.name for v in variables.local_variables())) + self.assertEquals( + set(expected), + set(v.name for v in ops.get_collection(ops.GraphKeys.METRIC_VARIABLES))) + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + f1, _ = classification.f1_score( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + num_thresholds=3, + metrics_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [f1]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, f1_op = classification.f1_score( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + num_thresholds=3, + updates_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [f1_op]) + + def testValueTensorIsIdempotent(self): + predictions = random_ops.random_uniform( + (10, 3), maxval=1, dtype=dtypes.float32, seed=1) + labels = random_ops.random_uniform( + (10, 3), maxval=2, dtype=dtypes.int64, seed=2) + f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + + # Run several updates. + for _ in range(10): + sess.run([f1_op]) + + # Then verify idempotency. + initial_f1 = f1.eval() + for _ in range(10): + self.assertAllClose(initial_f1, f1.eval()) + + def testAllCorrect(self): + inputs = np.random.randint(0, 2, size=(100, 1)) + + with self.test_session() as sess: + predictions = constant_op.constant(inputs, dtype=dtypes.float32) + labels = constant_op.constant(inputs) + f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3) + + sess.run(variables.local_variables_initializer()) + sess.run([f1_op]) + + self.assertEqual(1, f1.eval()) + + def testSomeCorrect(self): + predictions = constant_op.constant( + [1, 0, 1, 0], shape=(1, 4), dtype=dtypes.float32) + labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) + f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=1) + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + sess.run([f1_op]) + # Threshold 0 will have around 0.5 precision and 1 recall yielding an F1 + # score of 2 * 0.5 * 1 / (1 + 0.5). + self.assertAlmostEqual(2 * 0.5 * 1 / (1 + 0.5), f1.eval()) + + def testAllIncorrect(self): + inputs = np.random.randint(0, 2, size=(10000, 1)) + + with self.test_session() as sess: + predictions = constant_op.constant(inputs, dtype=dtypes.float32) + labels = constant_op.constant(1 - inputs, dtype=dtypes.float32) + f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3) + + sess.run(variables.local_variables_initializer()) + sess.run([f1_op]) + + # Threshold 0 will have around 0.5 precision and 1 recall yielding an F1 + # score of 2 * 0.5 * 1 / (1 + 0.5). + self.assertAlmostEqual(2 * 0.5 * 1 / (1 + 0.5), f1.eval(), places=2) + + def testWeights1d(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes.float32) + labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) + weights = constant_op.constant( + [[0], [1]], shape=(2, 1), dtype=dtypes.float32) + f1, f1_op = classification.f1_score(predictions, labels, weights, + num_thresholds=3) + sess.run(variables.local_variables_initializer()) + sess.run([f1_op]) + + self.assertAlmostEqual(1.0, f1.eval(), places=5) + + def testWeights2d(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes.float32) + labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) + weights = constant_op.constant( + [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes.float32) + f1, f1_op = classification.f1_score(predictions, labels, weights, + num_thresholds=3) + sess.run(variables.local_variables_initializer()) + sess.run([f1_op]) + + self.assertAlmostEqual(1.0, f1.eval(), places=5) + + def testZeroLabelsPredictions(self): + with self.test_session() as sess: + predictions = array_ops.zeros([4], dtype=dtypes.float32) + labels = array_ops.zeros([4]) + f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3) + sess.run(variables.local_variables_initializer()) + sess.run([f1_op]) + + self.assertAlmostEqual(0.0, f1.eval(), places=5) + + def testWithMultipleUpdates(self): + num_samples = 1000 + batch_size = 10 + num_batches = int(num_samples / batch_size) + + # Create the labels and data. + labels = np.random.randint(0, 2, size=(num_samples, 1)) + noise = np.random.normal(0.0, scale=0.2, size=(num_samples, 1)) + predictions = 0.4 + 0.2 * labels + noise + predictions[predictions > 1] = 1 + predictions[predictions < 0] = 0 + thresholds = [-0.01, 0.5, 1.01] + + expected_max_f1 = -1.0 + for threshold in thresholds: + tp = 0 + fp = 0 + fn = 0 + tn = 0 + for i in range(num_samples): + if predictions[i] >= threshold: + if labels[i] == 1: + tp += 1 + else: + fp += 1 + else: + if labels[i] == 1: + fn += 1 + else: + tn += 1 + epsilon = 1e-7 + expected_prec = tp / (epsilon + tp + fp) + expected_rec = tp / (epsilon + tp + fn) + expected_f1 = (2 * expected_prec * expected_rec / + (epsilon + expected_prec + expected_rec)) + if expected_f1 > expected_max_f1: + expected_max_f1 = expected_f1 + + labels = labels.astype(np.float32) + predictions = predictions.astype(np.float32) + tf_predictions, tf_labels = (dataset_ops.Dataset + .from_tensor_slices((predictions, labels)) + .repeat() + .batch(batch_size) + .make_one_shot_iterator() + .get_next()) + f1, f1_op = classification.f1_score(tf_labels, tf_predictions, + num_thresholds=3) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + for _ in range(num_batches): + sess.run([f1_op]) + # Since this is only approximate, we can't expect a 6 digits match. + # Although with higher number of samples/thresholds we should see the + # accuracy improving + self.assertAlmostEqual(expected_max_f1, f1.eval(), 2) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index b14202ff9ec38016f926ee37c8acbd2bbb4c6ef5..a328670526089988c181a8e1146c911309640009 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -3715,6 +3715,7 @@ def count(values, name=None): """Computes the number of examples, or sum of `weights`. + This metric keeps track of the denominator in `tf.metrics.mean`. When evaluating some metric (e.g. mean) on one or more subsets of the data, this auxiliary metric is useful for keeping track of how many examples there are in each subset. @@ -3741,15 +3742,21 @@ def count(values, ValueError: If `weights` is not `None` and its shape doesn't match `values`, or if either `metrics_collections` or `updates_collections` are not a list or tuple. + RuntimeError: If eager execution is enabled. """ + if context.executing_eagerly(): + raise RuntimeError('tf.contrib.metrics.count is not supported when eager ' + 'execution is enabled.') with variable_scope.variable_scope(name, 'count', (values, weights)): + count_ = metrics_impl.metric_variable([], dtypes.float32, name='count') if weights is None: num_values = math_ops.to_float(array_ops.size(values)) else: - _, _, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access + values = math_ops.to_float(values) + values, _, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions=values, labels=None, weights=weights) @@ -3758,15 +3765,14 @@ def count(values, num_values = math_ops.reduce_sum(weights) with ops.control_dependencies([values]): - update_op = state_ops.assign_add(count_, num_values) + update_count_op = state_ops.assign_add(count_, num_values) - if metrics_collections: - ops.add_to_collections(metrics_collections, count_) + count_ = metrics_impl._aggregate_variable(count_, metrics_collections) # pylint: disable=protected-access if updates_collections: - ops.add_to_collections(updates_collections, update_op) + ops.add_to_collections(updates_collections, update_count_op) - return count_, update_op + return count_, update_count_op def cohen_kappa(labels, diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index a09fc4abd461323d67e914c70932688816fed764..401fedcbed8fef12308d563d108725a418dfef17 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -6854,6 +6854,11 @@ class CountTest(test.TestCase): array_ops.ones([4, 3]), updates_collections=[my_collection_name]) self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) + def testReturnType(self): + c, op = metrics.count(array_ops.ones([4, 3])) + self.assertTrue(isinstance(c, ops.Tensor)) + self.assertTrue(isinstance(op, ops.Operation) or isinstance(op, ops.Tensor)) + def testBasic(self): with self.test_session() as sess: values_queue = data_flow_ops.FIFOQueue( diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py index ef34f7bf7bf3eba047b50ce8abf883b0ed741a63..93050a3ae373603c516c7eb72c22f327f4a60a00 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py @@ -77,7 +77,7 @@ class LossScaleOptimizer(optimizer.Optimizer): If gradients clipping is applied, one can call `optimizer.compute_gradients()` and `optimizer.apply_gradients()` - seperately. + separately. Notice the following way of using LossScaleOptimizer is not intended. Always use `loss_scale_optimizer.compute_gradients()` to compute gradients instead of diff --git a/tensorflow/contrib/model_pruning/BUILD b/tensorflow/contrib/model_pruning/BUILD index 54bd39afacbec07f054f61b72eda0a3654858aa7..16ddc38f5a5ba88485e18b136b2b1081b0e2ff0f 100644 --- a/tensorflow/contrib/model_pruning/BUILD +++ b/tensorflow/contrib/model_pruning/BUILD @@ -95,6 +95,22 @@ py_library( ], ) +py_library( + name = "strip_pruning_vars_lib", + srcs = ["python/strip_pruning_vars_lib.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":pruning", + "//tensorflow/python:client", + "//tensorflow/python:framework", + "//tensorflow/python:platform", + "//tensorflow/python:training", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + py_test( name = "pruning_utils_test", size = "small", @@ -129,6 +145,31 @@ py_test( ], ) +py_test( + name = "strip_pruning_vars_test", + size = "small", + srcs = ["python/strip_pruning_vars_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":layers", + ":pruning", + ":rnn_cells", + ":strip_pruning_vars_lib", + "//tensorflow/python:client_testlib", + ], +) + +py_binary( + name = "strip_pruning_vars", + srcs = ["python/strip_pruning_vars.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":strip_pruning_vars_lib", + "//tensorflow/python:platform", + ], +) + py_library( name = "init_py", srcs = ["__init__.py"], @@ -145,5 +186,6 @@ py_library( ":learning", ":pruning", ":rnn_cells", + ":strip_pruning_vars_lib", ], ) diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md index 86f4fd6adf60d8fa54c13989bf4087e28f1e006f..a5267fd90482287a65a4c38ae257a0af349523e8 100644 --- a/tensorflow/contrib/model_pruning/README.md +++ b/tensorflow/contrib/model_pruning/README.md @@ -4,7 +4,15 @@ This document describes the API that facilitates magnitude-based pruning of neural network's weight tensors. The API helps inject necessary tensorflow op into the training graph so the model can be pruned while it is being trained. -### Model creation +## Table of contents +1. [Model creation](#model-creation) +2. [Hyperparameters for pruning](#hyperparameters) + - [Block sparsity](#block-sparsity) +3. [Adding pruning ops to the training graph](#adding-pruning-ops) +4. [Removing pruning ops from trained model](#remove) +5. [Example](#example) + +### Model creation The first step involves adding mask and threshold variables to the layers that need to undergo pruning. The variable mask is the same shape as the layer's @@ -33,7 +41,7 @@ auxiliary variables built-in (see * [rnn_cells.MaskedLSTMCell](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py?l=154) -### Adding pruning ops to the training graph +### Pruning-related hyperparameters The pruning library allows for specification of the following hyper parameters: @@ -42,7 +50,7 @@ The pruning library allows for specification of the following hyper parameters: | name | string | model_pruning | Name of the pruning specification. Used for adding summaries and ops under a common tensorflow name_scope | | begin_pruning_step | integer | 0 | The global step at which to begin pruning | | end_pruning_step | integer | -1 | The global step at which to terminate pruning. Defaults to -1 implying that pruning continues till the training stops | -| do_not_prune | list of strings | [""] | list of layers names that are not pruned | +| weight_sparsity_map | list of strings | [""] | list of weight variable name (or layer name):target sparsity pairs. Eg. [conv1:0.9,conv2/kernel:0.8]. For layers/weights not in this list, sparsity as specified by the target_sparsity hyperparameter is used. | | threshold_decay | float | 0.9 | The decay factor to use for exponential decay of the thresholds | | pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) | | nbins | integer | 256 | Number of bins to use for histogram computation | @@ -64,12 +72,18 @@ is divided into $$n$$ intervals of size equal to the pruning_frequency ($$\Delta t$$). $$s_f$$ is the target_sparsity, $$s_i$$ is the initial_sparsity, $$t_0$$ is the sparsity_function_begin_step. In this equation, the sparsity_function_exponent is set to 3. -### Adding pruning ops to the training graph -The final step involves adding ops to the training graph that monitors the -distribution of the layer's weight magnitudes and determines the layer threshold -such masking all the weights below this threshold achieves the sparsity level -desired for the current training step. This can be achieved as follows: +#### Block Sparsity + +For some hardware architectures, it may be beneficial to induce spatially correlated sparsity. To train models in which the weight tensors have block sparse structure, set *block_height* and *block_width* hyperparameters to the desired block configuration (2x2, 4x4, 4x1, 1x8, etc). Currently, block sparsity is only supported for weight tensors which can be squeezed to rank 2. The matrix is partitioned into non-overlapping blocks of size *[block_height, block_dim]* and the either the average or max absolute value in this block is taken as a proxy for the entire block (set by *block_pooling_function* hyperparameter). +The convolution layer tensors are always pruned used block dimensions of [1,1]. + +### Adding pruning ops to the training graph + +The final step involves adding ops to the training graph that monitor the +distribution of the layer's weight magnitudes and determine the layer threshold, +such that masking all the weights below this threshold achieves the sparsity +level desired for the current training step. This can be achieved as follows: ```python tf.app.flags.DEFINE_string( @@ -79,7 +93,7 @@ tf.app.flags.DEFINE_string( with tf.graph.as_default(): # Create global step variable - global_step = tf.train.get_global_step() + global_step = tf.train.get_or_create_global_step() # Parse pruning hyperparameters pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams) @@ -103,8 +117,21 @@ with tf.graph.as_default(): mon_sess.run(mask_update_op) ``` +Ensure that `global_step` is being [incremented](https://www.tensorflow.org/api_docs/python/tf/train/Optimizer#minimize), otherwise pruning will not work! + +### Removing pruning ops from the trained graph +Once the model is trained, it is necessary to remove the auxiliary variables (mask, threshold) and pruning ops added to the graph in the steps above. This can be accomplished using the `strip_pruning_vars` utility. + +This utility generates a binary GraphDef in which the variables have been converted to constants. In particular, the threshold variables are removed from the graph and the mask variable is fused with the corresponding weight tensor to produce a `masked_weight` tensor. This tensor is sparse, has the same size as the weight tensor, and the sparsity is as set by the `target_sparsity` or the `weight_sparsity_map` hyperparameters above. + +```shell +$ bazel build -c opt contrib/model_pruning:strip_pruning_vars +$ bazel-bin/contrib/model_pruning/strip_pruning_vars --checkpoint_dir=/path/to/checkpoints/ --output_node_names=graph_node1,graph_node2 --output_dir=/tmp --filename=pruning_stripped.pb +``` + +For now, it is assumed that the underlying hardware platform will provide mechanisms for compressing the sparse tensors and/or accelerating the sparse tensor computations. -## Example: Pruning and training deep CNNs on the cifar10 dataset +## Example: Pruning and training deep CNNs on the cifar10 dataset Please see https://www.tensorflow.org/tutorials/deep_cnn for details on neural network architecture, setting up inputs etc. The additional changes needed to @@ -120,7 +147,7 @@ incorporate pruning are captured in the following: To train the pruned version of cifar10: -```bash +```shell $ examples_dir=contrib/model_pruning/examples $ bazel build -c opt $examples_dir/cifar10:cifar10_{train,eval} $ bazel-bin/$examples_dir/cifar10/cifar10_train --pruning_hparams=name=cifar10_pruning,begin_pruning_step=10000,end_pruning_step=100000,target_sparsity=0.9,sparsity_function_begin_step=10000,sparsity_function_end_step=100000 @@ -132,10 +159,14 @@ Eval: $ bazel-bin/$examples_dir/cifar10/cifar10_eval --run_once ``` -### Block Sparsity +Removing pruning nodes from the trained graph: -For some hardware architectures, it may be beneficial to induce spatially correlated sparsity. To train models in which the weight tensors have block sparse structure, set *block_height* and *block_width* hyperparameters to the desired block configuration (2x2, 4x4, 4x1, 1x8, etc). Currently, block sparsity is only supported for weight tensors which can be squeezed to rank 2. The matrix is partitioned into non-overlapping blocks of size *[block_height, block_dim]* and the either the average or max absolute value in this block is taken as a proxy for the entire block (set by *block_pooling_function* hyperparameter). -The convolution layer tensors are always pruned used block dimensions of [1,1]. +```shell +$ bazel build -c opt contrib/model_pruning:strip_pruning_vars +$ bazel-bin/contrib/model_pruning/strip_pruning_vars --checkpoint_path=/tmp/cifar10_train --output_node_names=softmax_linear/softmax_linear_2 --filename=cifar_pruned.pb +``` + +The generated GraphDef (cifar_pruned.pb) may be visualized using the [`import_pb_to_tensorboard`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/tools/import_pb_to_tensorboard.py) utility ## References diff --git a/tensorflow/contrib/model_pruning/__init__.py b/tensorflow/contrib/model_pruning/__init__.py index d32bedbcd6b63bc8e473a9e9d1c8e0753877e6f8..6eca54aaee186f5873a84ef2cb3ff3c7cfb42cd4 100644 --- a/tensorflow/contrib/model_pruning/__init__.py +++ b/tensorflow/contrib/model_pruning/__init__.py @@ -33,6 +33,9 @@ from tensorflow.contrib.model_pruning.python.pruning import get_thresholds from tensorflow.contrib.model_pruning.python.pruning import get_weight_sparsity from tensorflow.contrib.model_pruning.python.pruning import get_weights from tensorflow.contrib.model_pruning.python.pruning import Pruning +from tensorflow.contrib.model_pruning.python.strip_pruning_vars_lib import graph_def_from_checkpoint +from tensorflow.contrib.model_pruning.python.strip_pruning_vars_lib import strip_pruning_vars_fn + # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented @@ -41,7 +44,8 @@ _allowed_symbols = [ 'masked_convolution', 'masked_conv2d', 'masked_fully_connected', 'MaskedBasicLSTMCell', 'MaskedLSTMCell', 'train', 'apply_mask', 'get_masked_weights', 'get_masks', 'get_pruning_hparams', 'get_thresholds', - 'get_weights', 'get_weight_sparsity', 'Pruning' + 'get_weights', 'get_weight_sparsity', 'Pruning', 'strip_pruning_vars_fn', + 'graph_def_from_checkpoint' ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/model_pruning/python/layers/layers.py b/tensorflow/contrib/model_pruning/python/layers/layers.py index 466daf204a1ae86a7f37107342046305ea7249fc..d453e350f05c8e66df13c3861959980d69a564e8 100644 --- a/tensorflow/contrib/model_pruning/python/layers/layers.py +++ b/tensorflow/contrib/model_pruning/python/layers/layers.py @@ -139,7 +139,7 @@ def masked_convolution(inputs, with "NC". num_outputs: Integer, the number of output filters. kernel_size: A sequence of N positive integers specifying the spatial - dimensions of of the filters. Can be a single integer to specify the same + dimensions of the filters. Can be a single integer to specify the same value for all spatial dimensions. stride: A sequence of N positive integers specifying the stride at which to compute output. Can be a single integer to specify the same value for all diff --git a/tensorflow/contrib/model_pruning/python/learning.py b/tensorflow/contrib/model_pruning/python/learning.py index 2b79c23cefe961b1c4056d41b5fcc0a0521efec6..26695237c27cc4fbe4e9fbaa2666d55836ed39b8 100644 --- a/tensorflow/contrib/model_pruning/python/learning.py +++ b/tensorflow/contrib/model_pruning/python/learning.py @@ -33,11 +33,14 @@ to support training of pruned models # Create the train_op train_op = slim.learning.create_train_op(total_loss, optimizer) - # Set up sparsity - sparsity = pruning.setup_gradual_sparsity(self.global_step) + # Parse pruning hyperparameters + pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams) - # Create mask update op - mask_update_op = pruning.add_mask_update_ip(sparsity) + # Create a pruning object using the pruning_hparams + p = pruning.Pruning(pruning_hparams) + + # Add mask update ops to the graph + mask_update_op = p.conditional_mask_update_op() # Run training. learning.train(train_op, diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py index 4b7af18b3316950afdb90c344ce777848c63e4c1..cd58526ed3620d4bd880cf36d806afac70c4bff7 100644 --- a/tensorflow/contrib/model_pruning/python/pruning.py +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -152,8 +152,11 @@ def get_pruning_hparams(): end_pruning_step: integer the global step at which to terminate pruning. Defaults to -1 implying that pruning continues till the training stops - do_not_prune: list of strings - list of layers that are not pruned + weight_sparsity_map: list of strings + comma separed list of weight variable name:target sparsity pairs. + For layers/weights not in this list, sparsity as specified by the + target_sparsity hyperparameter is used. + Eg. [conv1:0.9,conv2/kernel:0.8] threshold_decay: float the decay factor to use for exponential decay of the thresholds pruning_frequency: integer @@ -200,7 +203,7 @@ def get_pruning_hparams(): name='model_pruning', begin_pruning_step=0, end_pruning_step=-1, - do_not_prune=[''], + weight_sparsity_map=[''], threshold_decay=0.9, pruning_frequency=10, nbins=256, @@ -234,6 +237,9 @@ class Pruning(object): # Pruning specification self._spec = spec if spec else get_pruning_hparams() + # Sanity check for pruning hparams + self._validate_spec() + # A tensorflow variable that tracks the sparsity function. # If not provided as input, the graph must already contain the global_step # variable before calling this constructor. @@ -256,6 +262,37 @@ class Pruning(object): # Block pooling function self._block_pooling_function = self._spec.block_pooling_function + # Mapping of weight names and target sparsity + self._weight_sparsity_map = self._get_weight_sparsity_map() + + def _validate_spec(self): + spec = self._spec + if spec.begin_pruning_step < 0: + raise ValueError('Illegal value for begin_pruning_step') + + if spec.begin_pruning_step >= spec.end_pruning_step: + if spec.end_pruning_step != -1: + raise ValueError( + 'Pruning must begin before it can end. begin_step=%d, end_step=%d.' + 'Set end_pruning_step to -1 if pruning is required till training' + 'stops' % (spec.begin_pruning_step, spec.end_pruning_step)) + + if spec.sparsity_function_begin_step < 0: + raise ValueError('Illegal value for sparsity_function_begin_step') + + if spec.sparsity_function_begin_step >= spec.sparsity_function_end_step: + raise ValueError( + 'Sparsity function requires begin_step < end_step') + + if not 0.0 <= spec.threshold_decay < 1.0: + raise ValueError('threshold_decay must be in range [0,1)') + + if not 0.0 <= spec.initial_sparsity < 1.0: + raise ValueError('initial_sparsity must be in range [0,1)') + + if not 0.0 <= spec.target_sparsity < 1.0: + raise ValueError('target_sparsity must be in range [0,1)') + def _setup_global_step(self, global_step): graph_global_step = global_step if graph_global_step is None: @@ -270,11 +307,6 @@ class Pruning(object): target_sparsity = self._spec.target_sparsity exponent = self._spec.sparsity_function_exponent - if begin_step >= end_step: - raise ValueError( - 'Pruning must begin before it can end. begin_step=%d, end_step=%d' % - (begin_step, end_step)) - with ops.name_scope(self._spec.name): p = math_ops.minimum( 1.0, @@ -306,15 +338,36 @@ class Pruning(object): 'last_mask_update_step', dtype=dtypes.int32) return last_update_step - def _exists_in_do_not_prune_list(self, tensor_name): - do_not_prune_list = self._spec.do_not_prune - if not do_not_prune_list[0]: - return False - for layer_name in do_not_prune_list: - if tensor_name.find(layer_name) != -1: - return True - - return False + def _get_weight_sparsity_map(self): + """Return the map of weight_name:sparsity parsed from the hparams.""" + weight_sparsity_map = {} + val_list = self._spec.weight_sparsity_map + filtered_val_list = [l for l in val_list if l] + for val in filtered_val_list: + weight_name, sparsity = val.split(':') + if float(sparsity) >= 1.0: + raise ValueError('Weight sparsity can not exceed 1.0') + weight_sparsity_map[weight_name] = float(sparsity) + + return weight_sparsity_map + + def _get_sparsity(self, weight_name): + """Return target sparsity for the given layer/weight name.""" + target_sparsity = [ + sparsity for name, sparsity in self._weight_sparsity_map.items() + if weight_name.find(name) != -1 + ] + if not target_sparsity: + return self._sparsity + + if len(target_sparsity) > 1: + raise ValueError( + 'Multiple matches in weight_sparsity_map for weight %s' % weight_name) + # TODO(suyoggupta): This will work when initial_sparsity = 0. Generalize + # to handle other cases as well. + return math_ops.mul( + self._sparsity, + math_ops.div(target_sparsity[0], self._spec.target_sparsity)) def _update_mask(self, weights, threshold): """Updates the mask for a given weight tensor. @@ -342,6 +395,8 @@ class Pruning(object): if self._sparsity is None: raise ValueError('Sparsity variable undefined') + sparsity = self._get_sparsity(weights.op.name) + with ops.name_scope(weights.op.name + '_pruning_ops'): abs_weights = math_ops.abs(weights) max_value = math_ops.reduce_max(abs_weights) @@ -354,7 +409,7 @@ class Pruning(object): math_ops.div( math_ops.reduce_sum( math_ops.cast( - math_ops.less(norm_cdf, self._sparsity), dtypes.float32)), + math_ops.less(norm_cdf, sparsity), dtypes.float32)), float(self._spec.nbins)), max_value) smoothed_threshold = math_ops.add_n([ @@ -453,10 +508,6 @@ class Pruning(object): if is_partitioned: weight = weight.as_tensor() - if self._spec.do_not_prune: - if self._exists_in_do_not_prune_list(mask.name): - continue - new_threshold, new_mask = self._maybe_update_block_mask(weight, threshold) self._assign_ops.append( pruning_utils.variable_assign(threshold, new_threshold)) @@ -507,22 +558,15 @@ class Pruning(object): no_update_op) def add_pruning_summaries(self): - """Adds summaries for this pruning spec. - - Args: none - - Returns: none - """ + """Adds summaries of weight sparsities and thresholds.""" with ops.name_scope(self._spec.name + '_summaries'): summary.scalar('sparsity', self._sparsity) summary.scalar('last_mask_update_step', self._last_update_step) masks = get_masks() thresholds = get_thresholds() - for index, mask in enumerate(masks): - if not self._exists_in_do_not_prune_list(mask.name): - summary.scalar(mask.name + '/sparsity', nn_impl.zero_fraction(mask)) - summary.scalar(thresholds[index].op.name + '/threshold', - thresholds[index]) + for mask, threshold in zip(masks, thresholds): + summary.scalar(mask.op.name + '/sparsity', nn_impl.zero_fraction(mask)) + summary.scalar(threshold.op.name + '/threshold', threshold) def print_hparams(self): logging.info(self._spec.to_json()) diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py index f80b7c52c000f13b5ce98dd442ff21abfac37761..33c4ad58bd7f57422935fc839ddfc64d5e1f00f5 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_test.py @@ -35,8 +35,8 @@ from tensorflow.python.training import training_util class PruningHParamsTest(test.TestCase): PARAM_LIST = [ "name=test", "threshold_decay=0.9", "pruning_frequency=10", - "do_not_prune=[conv1,conv2]", "sparsity_function_end_step=100", - "target_sparsity=0.9" + "sparsity_function_end_step=100", "target_sparsity=0.9", + "weight_sparsity_map=[conv1:0.8,conv2/kernel:0.8]" ] TEST_HPARAMS = ",".join(PARAM_LIST) @@ -55,9 +55,10 @@ class PruningHParamsTest(test.TestCase): self.assertEqual(p._spec.name, "test") self.assertAlmostEqual(p._spec.threshold_decay, 0.9) self.assertEqual(p._spec.pruning_frequency, 10) - self.assertAllEqual(p._spec.do_not_prune, ["conv1", "conv2"]) self.assertEqual(p._spec.sparsity_function_end_step, 100) self.assertAlmostEqual(p._spec.target_sparsity, 0.9) + self.assertEqual(p._weight_sparsity_map["conv1"], 0.8) + self.assertEqual(p._weight_sparsity_map["conv2/kernel"], 0.8) def testInitWithExternalSparsity(self): with self.test_session(): @@ -211,6 +212,37 @@ class PruningTest(test.TestCase): expected_non_zero_count = [100, 100, 80, 80, 60, 60, 40, 40, 40, 40] self.assertAllEqual(expected_non_zero_count, non_zero_count) + def testWeightSpecificSparsity(self): + param_list = [ + "begin_pruning_step=1", "pruning_frequency=1", "end_pruning_step=100", + "target_sparsity=0.5", "weight_sparsity_map=[layer2/weights:0.75]", + "threshold_decay=0.0" + ] + test_spec = ",".join(param_list) + pruning_hparams = pruning.get_pruning_hparams().parse(test_spec) + + with variable_scope.variable_scope("layer1"): + w1 = variables.Variable( + math_ops.linspace(1.0, 100.0, 100), name="weights") + _ = pruning.apply_mask(w1) + with variable_scope.variable_scope("layer2"): + w2 = variables.Variable( + math_ops.linspace(1.0, 100.0, 100), name="weights") + _ = pruning.apply_mask(w2) + + p = pruning.Pruning(pruning_hparams) + mask_update_op = p.conditional_mask_update_op() + increment_global_step = state_ops.assign_add(self.global_step, 1) + + with self.test_session() as session: + variables.global_variables_initializer().run() + for _ in range(110): + session.run(mask_update_op) + session.run(increment_global_step) + + self.assertAllEqual( + session.run(pruning.get_weight_sparsity()), [0.5, 0.75]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/model_pruning/python/strip_pruning_vars.py b/tensorflow/contrib/model_pruning/python/strip_pruning_vars.py new file mode 100644 index 0000000000000000000000000000000000000000..3385103807f6dbdab2d27882c670a3ccf6a26e9d --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/strip_pruning_vars.py @@ -0,0 +1,103 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Removes the auxiliary variables and ops added by the pruning library. + +Usage: + +bazel build tensorflow/contrib/model_pruning:strip_pruning_vars && \ +bazel-bin/tensorflow/contrib/model_pruning/strip_pruning_vars \ +--checkpoint_dir=/tmp/model_ckpts \ +--output_node_names=softmax \ +--output_dir=/tmp \ +--filename=pruning_stripped.pb +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import os +import sys + +from tensorflow.contrib.model_pruning.python import strip_pruning_vars_lib +from tensorflow.python.framework import graph_io +from tensorflow.python.platform import app +from tensorflow.python.platform import tf_logging as logging + +FLAGS = None + + +def strip_pruning_vars(checkpoint_dir, output_node_names, output_dir, filename): + """Remove pruning-related auxiliary variables and ops from the graph. + + Accepts training checkpoints and produces a GraphDef in which the pruning vars + and ops have been removed. + + Args: + checkpoint_dir: Path to the checkpoints. + output_node_names: The name of the output nodes, comma separated. + output_dir: Directory where to write the graph. + filename: Output GraphDef file name. + + Returns: + None + + Raises: + ValueError: if output_nodes_names are not provided. + """ + if not output_node_names: + raise ValueError( + 'Need to specify atleast 1 output node through output_node_names flag') + output_node_names = output_node_names.replace(' ', '').split(',') + + initial_graph_def = strip_pruning_vars_lib.graph_def_from_checkpoint( + checkpoint_dir, output_node_names) + + final_graph_def = strip_pruning_vars_lib.strip_pruning_vars_fn( + initial_graph_def, output_node_names) + graph_io.write_graph(final_graph_def, output_dir, filename, as_text=False) + logging.info('\nFinal graph written to %s', os.path.join( + output_dir, filename)) + + +def main(unused_args): + return strip_pruning_vars(FLAGS.checkpoint_dir, FLAGS.output_node_names, + FLAGS.output_dir, FLAGS.filename) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.register('type', 'bool', lambda v: v.lower() == 'true') + parser.add_argument( + '--checkpoint_dir', type=str, default='', help='Path to the checkpoints.') + parser.add_argument( + '--output_node_names', + type=str, + default='', + help='The name of the output nodes, comma separated.') + parser.add_argument( + '--output_dir', + type=str, + default='/tmp', + help='Directory where to write the graph.') + parser.add_argument( + '--filename', + type=str, + default='pruning_stripped.pb', + help='Output \'GraphDef\' file name.') + + FLAGS, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/model_pruning/python/strip_pruning_vars_lib.py b/tensorflow/contrib/model_pruning/python/strip_pruning_vars_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..fc4b10863f7c46235059f948fbbfcfcf83d3e15b --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/strip_pruning_vars_lib.py @@ -0,0 +1,142 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities to remove pruning-related ops and variables from a GraphDef. +""" + +# pylint: disable=missing-docstring +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import graph_util +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import saver as saver_lib + + +def _node_name(tensor_name): + """Remove the trailing ':0' from the variable name.""" + if ':' not in tensor_name: + return tensor_name + + return tensor_name.split(':')[0] + + +def _tensor_name(node_name): + """Appends the :0 in the op name to get the canonical tensor name.""" + if ':' in node_name: + return node_name + + return node_name + ':0' + + +def _get_masked_weights(input_graph_def): + """Extracts masked_weights from the graph as a dict of {var_name:ndarray}.""" + input_graph = ops.Graph() + with input_graph.as_default(): + importer.import_graph_def(input_graph_def, name='') + + with session.Session(graph=input_graph) as sess: + masked_weights_dict = {} + for node in input_graph_def.node: + if 'masked_weight' in node.name: + masked_weight_val = sess.run( + sess.graph.get_tensor_by_name(_tensor_name(node.name))) + logging.info( + '%s has %d values, %1.2f%% zeros \n', node.name, + np.size(masked_weight_val), + 100 - float(100 * np.count_nonzero(masked_weight_val)) / + np.size(masked_weight_val)) + masked_weights_dict.update({node.name: masked_weight_val}) + return masked_weights_dict + + +def strip_pruning_vars_fn(input_graph_def, output_node_names): + """Removes mask variable from the graph. + + Replaces the masked_weight tensor with element-wise multiplication of mask + and the corresponding weight variable. + + Args: + input_graph_def: A GraphDef in which the variables have been converted to + constants. This is typically the output of + tf.graph_util.convert_variables_to_constant() + output_node_names: List of name strings for the result nodes of the graph + + Returns: + A GraphDef in which pruning-related variables have been removed + """ + masked_weights_dict = _get_masked_weights(input_graph_def) + pruned_graph_def = graph_pb2.GraphDef() + + # Replace masked_weight with a const op containing the + # result of tf.multiply(mask,weight) + for node in input_graph_def.node: + output_node = node_def_pb2.NodeDef() + if 'masked_weight' in node.name: + output_node.op = 'Const' + output_node.name = node.name + dtype = node.attr['T'] + data = masked_weights_dict[node.name] + output_node.attr['dtype'].CopyFrom(dtype) + output_node.attr['value'].CopyFrom( + attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(data))) + + else: + output_node.CopyFrom(node) + pruned_graph_def.node.extend([output_node]) + + # Remove stranded nodes: mask and weights + return graph_util.extract_sub_graph(pruned_graph_def, output_node_names) + + +def graph_def_from_checkpoint(checkpoint_dir, output_node_names): + """Converts checkpoint data to GraphDef. + + Reads the latest checkpoint data and produces a GraphDef in which the + variables have been converted to constants. + + Args: + checkpoint_dir: Path to the checkpoints. + output_node_names: List of name strings for the result nodes of the graph. + + Returns: + A GraphDef from the latest checkpoint + + Raises: + ValueError: if no checkpoint is found + """ + checkpoint_path = saver_lib.latest_checkpoint(checkpoint_dir) + if checkpoint_path is None: + raise ValueError('Could not find a checkpoint at: {0}.' + .format(checkpoint_dir)) + + saver_for_restore = saver_lib.import_meta_graph( + checkpoint_path + '.meta', clear_devices=True) + with session.Session() as sess: + saver_for_restore.restore(sess, checkpoint_path) + graph_def = ops.get_default_graph().as_graph_def() + output_graph_def = graph_util.convert_variables_to_constants( + sess, graph_def, output_node_names) + + return output_graph_def diff --git a/tensorflow/contrib/model_pruning/python/strip_pruning_vars_test.py b/tensorflow/contrib/model_pruning/python/strip_pruning_vars_test.py new file mode 100644 index 0000000000000000000000000000000000000000..255daa036099c0d3ef2dbc5eb37fdb0c31c71383 --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/strip_pruning_vars_test.py @@ -0,0 +1,232 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for strip_pruning_vars.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re + +from tensorflow.contrib.model_pruning.python import pruning +from tensorflow.contrib.model_pruning.python import strip_pruning_vars_lib +from tensorflow.contrib.model_pruning.python.layers import layers +from tensorflow.contrib.model_pruning.python.layers import rnn_cells +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import graph_util +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell as tf_rnn_cells +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import training_util + + +def _get_number_pruning_vars(graph_def): + number_vars = 0 + for node in graph_def.node: + if re.match(r"^.*(mask$)|(threshold$)", node.name): + number_vars += 1 + return number_vars + + +def _get_node_names(tensor_names): + return [ + strip_pruning_vars_lib._node_name(tensor_name) + for tensor_name in tensor_names + ] + + +class StripPruningVarsTest(test.TestCase): + + def setUp(self): + param_list = [ + "pruning_frequency=1", "begin_pruning_step=1", "end_pruning_step=10", + "nbins=2048", "threshold_decay=0.0" + ] + self.initial_graph = ops.Graph() + self.initial_graph_def = None + self.final_graph = ops.Graph() + self.final_graph_def = None + self.pruning_spec = ",".join(param_list) + with self.initial_graph.as_default(): + self.sparsity = variables.Variable(0.5, name="sparsity") + self.global_step = training_util.get_or_create_global_step() + self.increment_global_step = state_ops.assign_add(self.global_step, 1) + self.mask_update_op = None + + def _build_convolutional_model(self, number_of_layers): + # Create a graph with several conv2d layers + kernel_size = 3 + base_depth = 4 + depth_step = 7 + height, width = 7, 9 + with variable_scope.variable_scope("conv_model"): + input_tensor = array_ops.ones((8, height, width, base_depth)) + top_layer = input_tensor + for ix in range(number_of_layers): + top_layer = layers.masked_conv2d( + top_layer, + base_depth + (ix + 1) * depth_step, + kernel_size, + scope="Conv_" + str(ix)) + + return top_layer + + def _build_fully_connected_model(self, number_of_layers): + base_depth = 4 + depth_step = 7 + + input_tensor = array_ops.ones((8, base_depth)) + + top_layer = input_tensor + + with variable_scope.variable_scope("fc_model"): + for ix in range(number_of_layers): + top_layer = layers.masked_fully_connected( + top_layer, base_depth + (ix + 1) * depth_step) + + return top_layer + + def _build_lstm_model(self, number_of_layers): + batch_size = 8 + dim = 10 + inputs = variables.Variable(random_ops.random_normal([batch_size, dim])) + + def lstm_cell(): + return rnn_cells.MaskedBasicLSTMCell( + dim, forget_bias=0.0, state_is_tuple=True, reuse=False) + + cell = tf_rnn_cells.MultiRNNCell( + [lstm_cell() for _ in range(number_of_layers)], state_is_tuple=True) + + outputs = rnn.static_rnn( + cell, [inputs], + initial_state=cell.zero_state(batch_size, dtypes.float32)) + + return outputs + + def _prune_model(self, session): + pruning_hparams = pruning.get_pruning_hparams().parse(self.pruning_spec) + p = pruning.Pruning(pruning_hparams, sparsity=self.sparsity) + self.mask_update_op = p.conditional_mask_update_op() + + variables.global_variables_initializer().run() + for _ in range(20): + session.run(self.mask_update_op) + session.run(self.increment_global_step) + + def _get_outputs(self, session, input_graph, tensors_list, graph_prefix=None): + outputs = [] + + for output_tensor in tensors_list: + if graph_prefix: + output_tensor = graph_prefix + "/" + output_tensor + outputs.append( + session.run(session.graph.get_tensor_by_name(output_tensor))) + + return outputs + + def _get_initial_outputs(self, output_tensor_names_list): + with self.test_session(graph=self.initial_graph) as sess1: + self._prune_model(sess1) + reference_outputs = self._get_outputs(sess1, self.initial_graph, + output_tensor_names_list) + + self.initial_graph_def = graph_util.convert_variables_to_constants( + sess1, sess1.graph.as_graph_def(), + _get_node_names(output_tensor_names_list)) + return reference_outputs + + def _get_final_outputs(self, output_tensor_names_list): + self.final_graph_def = strip_pruning_vars_lib.strip_pruning_vars_fn( + self.initial_graph_def, _get_node_names(output_tensor_names_list)) + _ = importer.import_graph_def(self.final_graph_def, name="final") + + with self.test_session(self.final_graph) as sess2: + final_outputs = self._get_outputs( + sess2, + self.final_graph, + output_tensor_names_list, + graph_prefix="final") + return final_outputs + + def _check_removal_of_pruning_vars(self, number_masked_layers): + self.assertEqual( + _get_number_pruning_vars(self.initial_graph_def), number_masked_layers) + self.assertEqual(_get_number_pruning_vars(self.final_graph_def), 0) + + def _check_output_equivalence(self, initial_outputs, final_outputs): + for initial_output, final_output in zip(initial_outputs, final_outputs): + self.assertAllEqual(initial_output, final_output) + + def testConvolutionalModel(self): + with self.initial_graph.as_default(): + number_masked_conv_layers = 5 + top_layer = self._build_convolutional_model(number_masked_conv_layers) + output_tensor_names = [top_layer.name] + initial_outputs = self._get_initial_outputs(output_tensor_names) + + # Remove pruning-related nodes. + with self.final_graph.as_default(): + final_outputs = self._get_final_outputs(output_tensor_names) + + # Check that the final graph has no pruning-related vars + self._check_removal_of_pruning_vars(number_masked_conv_layers) + + # Check that outputs remain the same after removal of pruning-related nodes + self._check_output_equivalence(initial_outputs, final_outputs) + + def testFullyConnectedModel(self): + with self.initial_graph.as_default(): + number_masked_fc_layers = 3 + top_layer = self._build_fully_connected_model(number_masked_fc_layers) + output_tensor_names = [top_layer.name] + initial_outputs = self._get_initial_outputs(output_tensor_names) + + # Remove pruning-related nodes. + with self.final_graph.as_default(): + final_outputs = self._get_final_outputs(output_tensor_names) + + # Check that the final graph has no pruning-related vars + self._check_removal_of_pruning_vars(number_masked_fc_layers) + + # Check that outputs remain the same after removal of pruning-related nodes + self._check_output_equivalence(initial_outputs, final_outputs) + + def testLSTMModel(self): + with self.initial_graph.as_default(): + number_masked_lstm_layers = 2 + outputs = self._build_lstm_model(number_masked_lstm_layers) + output_tensor_names = [outputs[0][0].name] + initial_outputs = self._get_initial_outputs(output_tensor_names) + + # Remove pruning-related nodes. + with self.final_graph.as_default(): + final_outputs = self._get_final_outputs(output_tensor_names) + + # Check that the final graph has no pruning-related vars + self._check_removal_of_pruning_vars(number_masked_lstm_layers) + + # Check that outputs remain the same after removal of pruning-related nodes + self._check_output_equivalence(initial_outputs, final_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/mpi_collectives/BUILD b/tensorflow/contrib/mpi_collectives/BUILD index a7be92a35e0d62a61f7923ac61bb2c1267d039c6..ecac06354d2ce796f2a6021cdf2370d7c30ccab7 100644 --- a/tensorflow/contrib/mpi_collectives/BUILD +++ b/tensorflow/contrib/mpi_collectives/BUILD @@ -52,6 +52,7 @@ tf_custom_op_library( deps = [ ":mpi_defines", ":mpi_message_proto_cc", + "//tensorflow/stream_executor:stream_executor_headers_lib", "//third_party/mpi", ], ) diff --git a/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc index ed22ee667f1d73b3f86f77e09bad9bfec7e46391..e4b0c2c6541836243347d2950686c60ef06d2bfc 100644 --- a/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc +++ b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc @@ -73,7 +73,7 @@ limitations under the License. */ template -using StatusOr = se::port::StatusOr; +using StatusOr = stream_executor::port::StatusOr; using CPUDevice = Eigen::ThreadPoolDevice; using GPUDevice = Eigen::GpuDevice; diff --git a/tensorflow/contrib/mpi_collectives/mpi_ops.py b/tensorflow/contrib/mpi_collectives/mpi_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..bd7096d9cee2d32bde5227a95038ae65cd8a6e18 --- /dev/null +++ b/tensorflow/contrib/mpi_collectives/mpi_ops.py @@ -0,0 +1,163 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Inter-process communication using MPI.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.python.framework import errors +from tensorflow.python.framework import load_library +from tensorflow.python.framework import ops +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import tf_logging as logging + + +def _load_library(name, op_list=None): + """Loads a .so file containing the specified operators. + + Args: + name: The name of the .so file to load. + op_list: A list of names of operators that the library should have. If None + then the .so file's contents will not be verified. + + Raises: + NameError if one of the required ops is missing. + """ + try: + filename = resource_loader.get_path_to_datafile(name) + library = load_library.load_op_library(filename) + for expected_op in (op_list or []): + for lib_op in library.OP_LIST.op: + if lib_op.name == expected_op: + break + else: + raise NameError('Could not find operator %s in dynamic library %s' % + (expected_op, name)) + return library + except errors.NotFoundError: + logging.warning('%s file could not be loaded.', name) + + +MPI_LIB = _load_library( + 'mpi_collectives.so', + ['MPISize', 'MPIRank', 'MPILocalRank', 'MPIAllgather', 'MPIAllreduce']) + + +def size(name=None): + """An op which returns the number of MPI processes. + + This is equivalent to running `MPI_Comm_size(MPI_COMM_WORLD, ...)` to get the + size of the global communicator. + + Returns: + An integer scalar containing the number of MPI processes. + """ + return MPI_LIB.mpi_size(name=name) + + +ops.NotDifferentiable('MPISize') + + +def rank(name=None): + """An op which returns the MPI rank of the calling process. + + This is equivalent to running `MPI_Comm_rank(MPI_COMM_WORLD, ...)` to get the + rank of the current process in the global communicator. + + Returns: + An integer scalar with the MPI rank of the calling process. + """ + return MPI_LIB.mpi_rank(name=name) + + +ops.NotDifferentiable('MPIRank') + + +def init(name=None): + """An op which initializes MPI on the device on which it is run. + + All future MPI ops must be run on the same device that the `init` op was run + on. + """ + return MPI_LIB.mpi_init(name=name) + + +ops.NotDifferentiable('MPIInit') + + +def local_rank(name=None): + """An op which returns the local MPI rank of the calling process, within the + node that it is running on. For example, if there are seven processes running + on a node, their local ranks will be zero through six, inclusive. + + This is equivalent to running `MPI_Comm_rank(...)` on a new communicator + which only includes processes on the same node. + + Returns: + An integer scalar with the local MPI rank of the calling process. + """ + return MPI_LIB.mpi_local_rank(name=name) + + +ops.NotDifferentiable('MPILocalRank') + + +def _allreduce(tensor, name=None): + """An op which sums an input tensor over all the MPI processes. + + The reduction operation is keyed by the name of the op. The tensor type and + shape must be the same on all MPI processes for a given name. The reduction + will not start until all processes are ready to send and receive the tensor. + + Returns: + A tensor of the same shape and type as `tensor`, summed across all + processes. + """ + return MPI_LIB.mpi_allreduce(tensor, name=name) + + +ops.NotDifferentiable('MPIAllreduce') + + +def allgather(tensor, name=None): + """An op which concatenates the input tensor with the same input tensor on + all other MPI processes. + + The concatenation is done on the first dimension, so the input tensors on the + different processes must have the same rank and shape, except for the first + dimension, which is allowed to be different. + + Returns: + A tensor of the same type as `tensor`, concatenated on dimension zero + across all processes. The shape is identical to the input shape, except for + the first dimension, which may be greater and is the sum of all first + dimensions of the tensors in different MPI processes. + """ + # Specify that first allgather is to collect the tensor gather sizes, + # indicated by passing in a scalar (0-D tensor) of value 0 + sizes_flag = tf.constant(0, dtype=tf.int64, name='size_flag_const') + my_size = tf.slice( + tf.shape(tensor, out_type=tf.int64), [0], [1], name='size_slice') + if name is None: + name = 'allgather' + sizing_name = '{}_sizing'.format(name) + sizes = MPI_LIB.mpi_allgather(my_size, sizes_flag, name=sizing_name) + return MPI_LIB.mpi_allgather(tensor, sizes, name=name) + + +ops.NotDifferentiable('MPIAllgather') diff --git a/tensorflow/contrib/mpi_collectives/ring.cc b/tensorflow/contrib/mpi_collectives/ring.cc new file mode 100644 index 0000000000000000000000000000000000000000..d93233eb210b80df10fd9c2c7975ce77112d18a2 --- /dev/null +++ b/tensorflow/contrib/mpi_collectives/ring.cc @@ -0,0 +1,80 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef TENSORFLOW_USE_MPI + +#define EIGEN_USE_THREADS + +#include "tensorflow/contrib/mpi_collectives/ring.h" + +namespace tensorflow { +namespace contrib { +namespace mpi { + +using CPUDevice = Eigen::ThreadPoolDevice; + +extern template MPI_Datatype MPIType(); +extern template MPI_Datatype MPIType(); +extern template MPI_Datatype MPIType(); +extern template DataType TensorFlowDataType(); +extern template DataType TensorFlowDataType(); +extern template DataType TensorFlowDataType(); + +// Generate all necessary specializations for RingAllreduce. +template Status RingAllreduce(OpKernelContext*, const Tensor*, + Tensor*, Tensor*); +template Status RingAllreduce(OpKernelContext*, + const Tensor*, Tensor*, + Tensor*); +template Status RingAllreduce(OpKernelContext*, const Tensor*, + Tensor*, Tensor*); + +// Generate all necessary specializations for RingAllgather. +template Status RingAllgather(OpKernelContext*, const Tensor*, + const std::vector&, + Tensor*); +template Status RingAllgather(OpKernelContext*, + const Tensor*, + const std::vector&, + Tensor*); +template Status RingAllgather(OpKernelContext*, const Tensor*, + const std::vector&, + Tensor*); + +// Copy data on a CPU using a straight-forward memcpy. +template <> +void CopyTensorData(void* dst, void* src, size_t size) { + std::memcpy(dst, src, size); +}; + +// Accumulate values on a CPU. +#define GENERATE_ACCUMULATE(type) \ + template <> \ + void AccumulateTensorData(type * dst, type * src, \ + size_t size) { \ + for (unsigned int i = 0; i < size; i++) { \ + dst[i] += src[i]; \ + } \ + }; +GENERATE_ACCUMULATE(int); +GENERATE_ACCUMULATE(long long); +GENERATE_ACCUMULATE(float); +#undef GENERATE_ACCUMULATE + +} // namespace mpi +} // namespace contrib +} // namespace tensorflow + +#endif // TENSORFLOW_USE_MPI diff --git a/tensorflow/contrib/mpi_collectives/ring.cu.cc b/tensorflow/contrib/mpi_collectives/ring.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..2f3eef366a9a3c10e59cd5298fc1626e1094dff8 --- /dev/null +++ b/tensorflow/contrib/mpi_collectives/ring.cu.cc @@ -0,0 +1,117 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef TENSORFLOW_USE_MPI + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/contrib/mpi_collectives/ring.h" + +namespace tensorflow { +namespace contrib { +namespace mpi { + +using CPUDevice = Eigen::ThreadPoolDevice; + +template <> +MPI_Datatype MPIType() { + return MPI_FLOAT; +}; +template <> +MPI_Datatype MPIType() { + return MPI_INT; +}; +template <> +MPI_Datatype MPIType() { + return MPI_LONG_LONG; +}; + +template <> +DataType TensorFlowDataType() { + return DT_FLOAT; +}; +template <> +DataType TensorFlowDataType() { + return DT_INT32; +}; +template <> +DataType TensorFlowDataType() { + return DT_INT64; +}; + +// Generate all necessary specializations for RingAllreduce. +template Status RingAllreduce(OpKernelContext*, const Tensor*, + Tensor*, Tensor*); +template Status RingAllreduce(OpKernelContext*, + const Tensor*, Tensor*, + Tensor*); +template Status RingAllreduce(OpKernelContext*, const Tensor*, + Tensor*, Tensor*); + +// Generate all necessary specializations for RingAllgather. +template Status RingAllgather(OpKernelContext*, const Tensor*, + const std::vector&, + Tensor*); +template Status RingAllgather(OpKernelContext*, + const Tensor*, + const std::vector&, + Tensor*); +template Status RingAllgather(OpKernelContext*, const Tensor*, + const std::vector&, + Tensor*); + +// Synchronously copy data on the GPU, using a different stream than the default +// and than TensorFlow to avoid synchronizing on operations unrelated to the +// allreduce. +template <> +void CopyTensorData(void* dst, void* src, size_t size) { + auto stream = CudaStreamForMPI(); + cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream); + cudaStreamSynchronize(stream); +}; + +// Elementwise accumulation kernel for GPU. +template +__global__ void elemwise_accum(T* out, const T* in, const size_t N) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + out[i] += in[i]; + } +} + +// Synchronously accumulate tensors on the GPU, using a different stream than +// the default and than TensorFlow to avoid synchronizing on operations +// unrelated to the allreduce. +#define GENERATE_ACCUMULATE(type) \ + template <> \ + void AccumulateTensorData(type * dst, type * src, \ + size_t size) { \ + auto stream = CudaStreamForMPI(); \ + elemwise_accum<<<32, 256, 0, stream>>>(dst, src, size); \ + cudaStreamSynchronize(stream); \ + }; +GENERATE_ACCUMULATE(int); +GENERATE_ACCUMULATE(long long); +GENERATE_ACCUMULATE(float); +#undef GENERATE_ACCUMULATE + +} // namespace mpi +} // namespace contrib +} // namespace tensorflow +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_USE_MPI diff --git a/tensorflow/contrib/mpi_collectives/ring.h b/tensorflow/contrib/mpi_collectives/ring.h new file mode 100644 index 0000000000000000000000000000000000000000..cae57ce60eb09509af69f8ccab9eacedea361548 --- /dev/null +++ b/tensorflow/contrib/mpi_collectives/ring.h @@ -0,0 +1,327 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_MPI_H_ +#define TENSORFLOW_CONTRIB_MPI_H_ + +#ifdef TENSORFLOW_USE_MPI + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" + +#if GOOGLE_CUDA +#include "cuda_runtime.h" +#endif + +// Needed to avoid header issues with C++-supporting MPI implementations +#define OMPI_SKIP_MPICXX +#include "third_party/mpi/mpi.h" + +#define TAG_TENSOR 12 + +namespace tensorflow { +namespace contrib { +namespace mpi { + +using CPUDevice = Eigen::ThreadPoolDevice; +using GPUDevice = Eigen::GpuDevice; + +// Convert from templated types to values we can pass to MPI. +template +MPI_Datatype MPIType(); + +// Convert from templated types to TensorFlow data types. +template +DataType TensorFlowDataType(); + +#define MPI_REQUIRES_OK(MPI_STATUS) \ + if ((MPI_STATUS) != MPI_SUCCESS) { \ + return errors::Unknown("MPI operation failed unexpectedly."); \ + } + +// Copy data from one tensor to another tensor. +// This uses a custom CUDA stream on GPU, which is necessary to overlay the +// backpropagation computations with the allreduce. +template +void CopyTensorData(void* destination, void* source, size_t size); + +// Add a tensor into another tensor, accumulating in place. +// This uses a custom CUDA stream on GPU, which is necessary to overlay the +// backpropagation computations with the allreduce. +template +void AccumulateTensorData(T* destination, T* source, size_t size); + +// We need to get the right stream for doing CUDA memory transfers and +// operations, which is possibly different from the standard TensorFlow stream. +#if GOOGLE_CUDA +cudaStream_t CudaStreamForMPI(); +#endif + +/* Perform a ring allreduce on the data. Allocate the necessary output tensor + * and store it in the output parameter. + * + * Assumes that all MPI processes are doing an allreduce of the same tensor, + * with the same dimensions. + * + * A ring allreduce is a bandwidth-optimal way to do an allreduce. To do the + * allreduce, the nodes involved are arranged in a ring: + * + * .--0--. + * / \ + * 3 1 + * \ / + * *--2--* + * + * Each node always sends to the next clockwise node in the ring, and receives + * from the previous one. + * + * The allreduce is done in two parts: a scatter-reduce and an allgather. In + * the scatter reduce, a reduction is done, so that each node ends up with a + * chunk of the final output tensor which has contributions from all other + * nodes. In the allgather, those chunks are distributed among all the nodes, + * so that all nodes have the entire output tensor. + * + * Both of these operations are done by dividing the input tensor into N + * evenly sized chunks (where N is the number of nodes in the ring). + * + * The scatter-reduce is done in N-1 steps. In the ith step, node j will send + * the (j - i)th chunk and receive the (j - i - 1)th chunk, adding it in to + * its existing data for that chunk. For example, in the first iteration with + * the ring depicted above, you will have the following transfers: + * + * Segment 0: Node 0 --> Node 1 + * Segment 1: Node 1 --> Node 2 + * Segment 2: Node 2 --> Node 3 + * Segment 3: Node 3 --> Node 0 + * + * In the second iteration, you'll have the following transfers: + * + * Segment 0: Node 1 --> Node 2 + * Segment 1: Node 2 --> Node 3 + * Segment 2: Node 3 --> Node 0 + * Segment 3: Node 0 --> Node 1 + * + * After this iteration, Node 2 has 3 of the four contributions to Segment 0. + * The last iteration has the following transfers: + * + * Segment 0: Node 2 --> Node 3 + * Segment 1: Node 3 --> Node 0 + * Segment 2: Node 0 --> Node 1 + * Segment 3: Node 1 --> Node 2 + * + * After this iteration, Node 3 has the fully accumulated Segment 0; Node 0 + * has the fully accumulated Segment 1; and so on. The scatter-reduce is + * complete. + * + * Next, the allgather distributes these fully accumululated chunks across all + * nodes. Communication proceeds in the same ring, once again in N-1 steps. At + * the ith step, node j will send chunk (j - i + 1) and receive chunk (j - i). + * For example, at the first iteration, the following transfers will occur: + * + * Segment 0: Node 3 --> Node 0 + * Segment 1: Node 0 --> Node 1 + * Segment 2: Node 1 --> Node 2 + * Segment 3: Node 2 --> Node 3 + * + * After the first iteration, Node 0 will have a fully accumulated Segment 0 + * (from Node 3) and Segment 1. In the next iteration, Node 0 will send its + * just-received Segment 0 onward to Node 1, and receive Segment 3 from Node 3. + * After this has continued for N - 1 iterations, all nodes will have a the + * fully accumulated tensor. + * + * Each node will do (N-1) sends for the scatter-reduce and (N-1) sends for the + * allgather. Each send will contain K / N bytes, if there are K bytes in the + * original tensor on every node. Thus, each node sends and receives 2K(N - 1)/N + * bytes of data, and the performance of the allreduce (assuming no latency in + * connections) is constrained by the slowest interconnect between the nodes. + * + */ +template +Status RingAllreduce(OpKernelContext* context, const Tensor* input, + Tensor* temp, Tensor* output) { + // Acquire MPI size and rank + int n, r; + MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n)); + MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r)); + + T* buffer = (T*)output->tensor_data().data(); + + CopyTensorData((void*)buffer, (void*)input->tensor_data().data(), + output->tensor_data().size()); + + // Calculate segment sizes and segment ends + const size_t elements_to_reduce = input->NumElements(); + const size_t segment_size = elements_to_reduce / n; + std::vector segment_sizes(n, segment_size); + + const size_t residual = elements_to_reduce % n; + for (size_t i = 0; i < residual; ++i) { + segment_sizes[i]++; + } + + std::vector segment_starts(n); + segment_starts[0] = 0; + for (size_t i = 1; i < segment_starts.size(); ++i) { + segment_starts[i] = segment_starts[i - 1] + segment_sizes[i - 1]; + } + + assert(segment_starts[n - 1] + segment_sizes[n - 1] == elements_to_reduce); + + T* segment_recv = (T*)temp->tensor_data().data(); + + // Receive from your left neighbor with wrap-around + const size_t recv_from = ((r - 1) + n) % n; + + // Send to your right neighbor with wrap-around + const size_t send_to = (r + 1) % n; + + MPI_Status recv_status; + MPI_Request recv_req; + + // Now start ring. At every step, for every rank, we iterate through + // segments with wraparound and send and recv from our neighbors and reduce + // locally. At the i'th iteration, rank r, sends segment (r-i) and receives + // segment (r-i-1). + for (int i = 0; i < n - 1; i++) { + const size_t send_seg_id = ((r - i) + n) % n; + const size_t recv_seg_id = ((r - i - 1) + n) % n; + + T* segment_send = &(buffer[segment_starts[send_seg_id]]); + + MPI_REQUIRES_OK(MPI_Irecv(segment_recv, segment_sizes[recv_seg_id], + MPIType(), recv_from, TAG_TENSOR, + MPI_COMM_WORLD, &recv_req)); + + MPI_REQUIRES_OK(MPI_Send(segment_send, segment_sizes[send_seg_id], + MPIType(), send_to, TAG_TENSOR, + MPI_COMM_WORLD)); + + T* segment_update = &(buffer[segment_starts[recv_seg_id]]); + + // Wait for recv to complete before reduction + MPI_REQUIRES_OK(MPI_Wait(&recv_req, &recv_status)); + + const size_t recv_seg_size = segment_sizes[recv_seg_id]; + AccumulateTensorData(segment_update, segment_recv, + recv_seg_size); + } + + // Now start pipelined ring allgather. At every step, for every rank, we + // iterate through segments with wraparound and send and recv from our + // neighbors. At the i'th iteration, rank r, sends segment (r-i+1) and + // receives segment (r-i). + for (size_t i = 0; i < n - 1; ++i) { + const size_t send_seg_id = ((r - i + 1) + n) % n; + const size_t recv_seg_id = ((r - i) + n) % n; + + // Segment to send - at every iteration we send segment (r-i+1) + T* segment_send = &(buffer[segment_starts[send_seg_id]]); + + // Segment to recv - at every iteration we receive segment (r-i) + T* segment_recv = &(buffer[segment_starts[recv_seg_id]]); + + MPI_REQUIRES_OK(MPI_Sendrecv( + segment_send, segment_sizes[send_seg_id], MPIType(), send_to, + TAG_TENSOR, segment_recv, segment_sizes[recv_seg_id], MPIType(), + recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_status)); + } + + return Status::OK(); +} + +// Perform a ring allgather on a Tensor. Other ranks may allgather with a +// tensor which differs in the first dimension only; all other dimensions must +// be the same. +// +// For more information on the ring allgather, read the documentation for the +// ring allreduce, which includes a ring allgather. +template +Status RingAllgather(OpKernelContext* context, const Tensor* input, + const std::vector& sizes, Tensor* output) { + // Acquire MPI size and rank + int n, r; + MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n)); + MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r)); + + assert(sizes.size() == n); + assert(input->dim_size(0) == sizes[r]); + + // Compute number of elements in every "row". We can't compute number of + // elements in every chunks, because those chunks are variable length. + size_t elements_per_row = 1; + for (int i = 1; i < input->shape().dims(); i++) { + elements_per_row *= input->dim_size(i); + } + + // Copy data from input tensor to correct place in output tensor. + std::vector segment_starts(n); + segment_starts[0] = 0; + for (int i = 1; i < n; i++) { + segment_starts[i] = segment_starts[i - 1] + elements_per_row * sizes[i - 1]; + } + size_t offset = segment_starts[r]; + + // Copy data to the right offset for this rank. + T* buffer = (T*)output->tensor_data().data(); + CopyTensorData((void*)(buffer + offset), + (void*)input->tensor_data().data(), + elements_per_row * sizes[r] * sizeof(T)); + + // Receive from your left neighbor with wrap-around + const size_t recv_from = ((r - 1) + n) % n; + + // Send to your right neighbor with wrap-around + const size_t send_to = (r + 1) % n; + + // Perform a ring allgather. At every step, for every rank, we iterate + // through segments with wraparound and send and recv from our neighbors. + // At the i'th iteration, rank r, sends segment (r-i) and receives segment + // (r-1-i). + MPI_Status recv_status; + for (size_t i = 0; i < n - 1; ++i) { + const size_t send_seg_id = ((r - i) + n) % n; + const size_t recv_seg_id = ((r - i - 1) + n) % n; + + // Segment to send - at every iteration we send segment (r-i) + size_t offset_send = segment_starts[send_seg_id]; + size_t rows_send = sizes[send_seg_id]; + T* segment_send = &(buffer[offset_send]); + + // Segment to recv - at every iteration we receive segment (r-1-i) + size_t offset_recv = segment_starts[recv_seg_id]; + size_t rows_recv = sizes[recv_seg_id]; + T* segment_recv = &(buffer[offset_recv]); + + MPI_REQUIRES_OK(MPI_Sendrecv( + segment_send, elements_per_row * rows_send, MPIType(), send_to, + TAG_TENSOR, segment_recv, elements_per_row * rows_recv, MPIType(), + recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_status)); + } + + return Status::OK(); +} + +} // namespace mpi +} // namespace contrib +} // namespace tensorflow + +#endif // TENSORFLOW_USE_MPI + +#undef TENSORFLOW_CONTRIB_MPI_H_ +#endif // TENSORFLOW_CONTRIB_MPI_H_ diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.cc b/tensorflow/contrib/nccl/kernels/nccl_manager.cc index b1cb89391ceaa70813be47cc1bba0c16f4f70e77..99fecf96517935bf3bde3636df83b4a9a4e1c779 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_manager.cc +++ b/tensorflow/contrib/nccl/kernels/nccl_manager.cc @@ -445,7 +445,7 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) { se::Stream* comm_stream = nccl_stream->stream.get(); ScopedActivateExecutorContext scoped_context(nccl_stream->executor); const cudaStream_t* cu_stream = reinterpret_cast( - comm_stream->implementation()->CudaStreamMemberHack()); + comm_stream->implementation()->GpuStreamMemberHack()); while (true) { // Find collective to run. diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops.py b/tensorflow/contrib/nccl/python/ops/nccl_ops.py index 029b01412d96ca03d4ecf7bf4d7d9872864e3ddc..fa597cf3efcf915311047f3a483772c45cc314fd 100644 --- a/tensorflow/contrib/nccl/python/ops/nccl_ops.py +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops.py @@ -63,12 +63,12 @@ def _all_sum_grad(op, grad): Raises: LookupError: If `reduction` is not `sum`. """ - if op.get_attr('reduction') != 'sum': + if op.get_attr('reduction') != b'sum': raise LookupError('No gradient defined for NcclAllReduce except sum.') _check_device(grad, expected=op.device) num_devices = op.get_attr('num_devices') - shared_name = op.get_attr('shared_name') + '_grad' + shared_name = op.get_attr('shared_name') + b'_grad' with ops.device(op.device): return gen_nccl_ops.nccl_all_reduce( @@ -162,7 +162,7 @@ def _reduce_sum_grad(op, grad): Raises: LookupError: If the reduction attribute of op is not `sum`. """ - if op.get_attr('reduction') != 'sum': + if op.get_attr('reduction') != b'sum': raise LookupError('No gradient defined for NcclReduce except sum.') _check_device(grad, expected=op.device) diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index bbdf962d0480e52045d31f65b3d137ed3f11f2f1..778b710d78a2095b8a1315018641c67419c26b98 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -27,6 +27,7 @@ py_library( "python/training/nadam_optimizer.py", "python/training/powersign.py", "python/training/reg_adagrad_optimizer.py", + "python/training/shampoo.py", "python/training/sign_decay.py", "python/training/variable_clipping_optimizer.py", "python/training/weight_decay_optimizers.py", @@ -344,3 +345,23 @@ py_test( "//third_party/py/numpy", ], ) + +py_test( + name = "shampoo_test", + size = "large", + srcs = ["python/training/shampoo_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index 157ed6a278bb699724d3854426d780a3a58823db..9471fb018162ee377e9c614d6e4d745b4282165a 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -22,17 +22,18 @@ from __future__ import print_function from tensorflow.contrib.opt.python.training.adamax import * from tensorflow.contrib.opt.python.training.addsign import * from tensorflow.contrib.opt.python.training.drop_stale_gradient_optimizer import * +from tensorflow.contrib.opt.python.training.elastic_average_optimizer import * from tensorflow.contrib.opt.python.training.external_optimizer import * +from tensorflow.contrib.opt.python.training.ggt import * from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import * +from tensorflow.contrib.opt.python.training.model_average_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import * from tensorflow.contrib.opt.python.training.nadam_optimizer import * +from tensorflow.contrib.opt.python.training.shampoo import * from tensorflow.contrib.opt.python.training.weight_decay_optimizers import * from tensorflow.contrib.opt.python.training.powersign import * from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import * -from tensorflow.contrib.opt.python.training.elastic_average_optimizer import * -from tensorflow.contrib.opt.python.training.model_average_optimizer import * -from tensorflow.contrib.opt.python.training.ggt import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented @@ -61,6 +62,7 @@ _allowed_symbols = [ 'ModelAverageOptimizer', 'ModelAverageCustomGetter', 'GGTOptimizer', + 'ShampooOptimizer', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/opt/python/training/addsign_test.py b/tensorflow/contrib/opt/python/training/addsign_test.py index 08d45ed73f3ae4b580d7078272e79fef22ef67c5..628a735e721d2f0c594dd59b5193499dfd7da02e 100644 --- a/tensorflow/contrib/opt/python/training/addsign_test.py +++ b/tensorflow/contrib/opt/python/training/addsign_test.py @@ -214,7 +214,7 @@ class AddSignTest(test.TestCase): # Run 7 steps of AddSign # first 4 steps with positive gradient # last 3 steps with negative gradient (sign(gm) should be -1) - for t in range(1, 4): + for t in range(1, 8): if t < 5: update.run() else: @@ -222,7 +222,7 @@ class AddSignTest(test.TestCase): var0_np, m0 = addsign_update_numpy( var0_np, - grads0_np, + grads0_np if t < 5 else -grads0_np, m0, learning_rate, alpha=alpha, @@ -232,7 +232,7 @@ class AddSignTest(test.TestCase): ) var1_np, m1 = addsign_update_numpy( var1_np, - grads1_np, + grads1_np if t < 5 else -grads1_np, m1, learning_rate, alpha=alpha, diff --git a/tensorflow/contrib/opt/python/training/ggt.py b/tensorflow/contrib/opt/python/training/ggt.py index 928c453517f825ed2d305ec498d07ac29c065f1a..cae952d8f50acbc3a176697fb3989db6c9ac3e9b 100644 --- a/tensorflow/contrib/opt/python/training/ggt.py +++ b/tensorflow/contrib/opt/python/training/ggt.py @@ -33,7 +33,7 @@ class GGTOptimizer(optimizer_v2.OptimizerV2): GGT has an advantage over sgd and adam on large models with poor conditioning, for example language models and CNNs, - see [ABCHSZZ 2018]([pdf](https://arxiv.org/pdf/1806.02958.pdf)). + see [[ABCHSZZ 2018]](https://arxiv.org/pdf/1806.02958.pdf). """ def __init__(self, diff --git a/tensorflow/contrib/opt/python/training/powersign_test.py b/tensorflow/contrib/opt/python/training/powersign_test.py index 5214082dd66f00eadadad71d50f7e00b178b8c10..0bcf5d230a8b7b5b778d233a79922dc34449f8dd 100644 --- a/tensorflow/contrib/opt/python/training/powersign_test.py +++ b/tensorflow/contrib/opt/python/training/powersign_test.py @@ -216,7 +216,7 @@ class PowerSignTest(test.TestCase): self.assertAllClose([1.0, 2.0], var0.eval()) self.assertAllClose([3.0, 4.0], var1.eval()) - # Run 3 steps of powersign + # Run 7 steps of powersign # first 4 steps with positive gradient # last 3 steps with negative gradient (sign(gm) should be -1) for t in range(1, 8): diff --git a/tensorflow/contrib/opt/python/training/shampoo.py b/tensorflow/contrib/opt/python/training/shampoo.py new file mode 100644 index 0000000000000000000000000000000000000000..a98866b180330727a8331faf1f1d9b7e398108e4 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/shampoo.py @@ -0,0 +1,474 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""The Shampoo Optimizer. + +Variant of Adagrad using one preconditioner matrix per variable dimension. +For details, see https://arxiv.org/abs/1802.09568 +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.platform import tf_logging +from tensorflow.python.training import optimizer + + +def GetParam(var, timestep): + if callable(var): + return var(timestep) + else: + return var + + +class ShampooOptimizer(optimizer.Optimizer): + """The Shampoo Optimizer + + Variant of Adagrad using one preconditioner matrix per variable dimension. + For details, see https://arxiv.org/abs/1802.09568 + + gbar is time-weighted accumulated gradient: + gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t] + + mat_gbar is time-weighted accumulated gradient square: + mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1] + + mat_gbar_weight[t] * gg_j[t] + where if g[t] = g_abcd then gg_a[t] = g_abcd g_a'bcd (Einstein notation) + + Update rule: + w[t+1] = w[t] - learning_rate[t] * Prod_j mat_gbar_j[t]^(-alpha/n) gbar[t] + Again, mat_gbar_j[t]^(-alpha) gbar[t] is a tensor contraction along the + j'th dimension of gbar[t] with the first dimension of + mat_gbar_j[t]^(-alpha/n), where alpha is a hyperparameter, + and n = rank of the variable. + Prod_j represents doing this contraction for all j in 0..n-1. + + Typically learning_rate is constant, but could be time dependent by passing + a lambda function that depends on step. + """ + + def __init__(self, + global_step=0, + max_matrix_size=768, + gbar_decay=0.0, + gbar_weight=1.0, + mat_gbar_decay=1.0, + mat_gbar_weight=1.0, + learning_rate=1.0, + svd_interval=1, + precond_update_interval=1, + epsilon=0.1, + alpha=0.5, + use_iterative_root=False, + use_locking=False, + name="Shampoo"): + """Default values of the various hyper-parameters. + + gbar_decay, gbar_weight etc. can be a float or a time varying parameter. + For time-varying parameters use e.g. "lambda T: T / (T + 1.0)" + where the expression in the lambda is a tensorflow expression + + Args: + global_step: tensorflow variable indicating the step. + max_matrix_size: We do not perform SVD for matrices larger than this. + gbar_decay: + gbar_weight: Used to update gbar: + gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t] + mat_gbar_decay: + mat_gbar_weight: Used to update mat_gbar: + mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1] + + mat_gbar_weight[t] * gg_j[t] + learning_rate: Similar to SGD + svd_interval: We should do SVD after this many steps. Default = 1, i.e. + every step. Usually 20 leads to no loss of accuracy, and + 50 or 100 is also OK. May also want more often early, + and less often later - set in caller as for example: + "svd_interval = lambda(T): tf.cond( + T < 2000, lambda: 20.0, lambda: 1000.0)" + precond_update_interval: We should update the preconditioners after + this many steps. Default = 1. Usually less than + svd_interval. + epsilon: epsilon * I_n is added to each mat_gbar_j for stability + alpha: total power of the preconditioners. + use_iterative_root: should the optimizer use SVD (faster) or the + iterative root method (for TPU) for finding the + roots of PSD matrices. + use_locking: + name: name of optimizer. + """ + + super(ShampooOptimizer, self).__init__(use_locking, name) + + self._global_step = math_ops.to_float(global_step) + self._max_matrix_size = max_matrix_size + self._gbar_decay = gbar_decay + self._gbar_weight = gbar_weight + self._mat_gbar_decay = mat_gbar_decay + self._mat_gbar_weight = mat_gbar_weight + self._learning_rate = learning_rate + self._svd_interval = svd_interval + self._precond_update_interval = precond_update_interval + self._epsilon = epsilon + self._alpha = alpha + self._use_iterative_root = use_iterative_root + self._name = name + + def _create_slots(self, var_list): + for v in var_list: + with ops.colocate_with(v): + _ = self._zeros_slot(v, "gbar", self._name) + shape = np.array(v.get_shape()) + for i, d in enumerate(shape): + d_tensor = ops.convert_to_tensor(d) + if d < self._max_matrix_size: + mat_g_init = array_ops.zeros_like(linalg_ops.eye(d_tensor)) + if self._svd_interval > 1: + _ = self._get_or_make_slot(v, linalg_ops.eye(d_tensor), + "H_" + str(i), self._name) + else: + mat_g_init = array_ops.zeros([d_tensor]) + + _ = self._get_or_make_slot(v, mat_g_init, "Gbar_" + str(i), + self._name) + + def _resource_apply_dense(self, grad, var): + return self._apply_dense(grad, var) + + def _apply_dense(self, grad, var): + return self._apply_gradient(grad, var) + + def _resource_apply_sparse(self, grad_values, var, grad_indices): + return self._apply_sparse_shared(grad_values, grad_indices, var) + + def _apply_sparse(self, grad, var): + return self._apply_sparse_shared(grad.values, grad.indices, var) + + def _apply_sparse_shared(self, grad_values, grad_indices, var): + if var.get_shape()[0] < self._max_matrix_size or self._gbar_decay != 0.0: + # The dimension is small enough, we can make the variable dense and + # do a dense update + dense_grad = array_ops.scatter_nd( + array_ops.expand_dims(grad_indices, axis=1), grad_values, + array_ops.shape(var, out_type=grad_indices.dtype)) + return self._apply_gradient(dense_grad, var) + return self._apply_gradient(grad_values, var, grad_indices) + + def _weighted_average(self, var, weight, weight_t, rest): + """Computes exponential weighted average: var = weight_t * var + rest. + + Important to ensure that var does not occur in rest, otherwise + we can get race conditions in a distributed setting. + + Args: + var: variable to be updated + weight: parameter to be checked. If it is a constant, we can optimize. + weight_t: current value of parameter, used for weighting + rest: the remaining tensor to be added + + Returns: + updated variable. + """ + if weight == 0.0: + return rest # no need to update var, we will never use it. + if weight == 1.0: # common case + return state_ops.assign_add(var, rest) + # The op below can cause race conditions in a distributed setting, + # since computing weight_t * var + rest can take some time, during + # which var may be set by another worker. To prevent this, it should + # be implemented as a C++ op. + return var.assign_add((weight_t - 1) * var + rest) + + def _update_mat_g(self, mat_g, grad, axes, mat_gbar_decay, + mat_gbar_weight, i): + """Updates the cumulative outer products of the gradients. + + Args: + mat_g: the matrix to be updated + grad: the gradient of the variable + axes: a list of k-1 integers 0 to k-1, except i + mat_gbar_decay: constant for weighted average: + mat_g = mat_g * decay + grad * weight + mat_gbar_weight: constant for weighted average + i: index of dimension to be updated. + + Returns: + updated mat_g = mat_g * mat_gbar_decay + grad_outer * mat_gbar_weight + + In Einstein notation if i = 0: grad_outer_aa'= g_abcd g_a'bcd + thus grad_outer is a matrix d_i x d_i, where d_i is the size of the + i'th dimension of g. + Alternate view: If mat_i(grad) is the flattening of grad to a + d_i x (d_1d_2...d_{i-1}d_{i+1}...d_k) matrix, then + grad_outer = mat_i(grad) mat_i(grad).transpose + """ + grad_outer = math_ops.tensordot(grad, grad, axes=(axes, axes), + name="grad_outer_" + str(i)) + return self._weighted_average(mat_g, self._mat_gbar_decay, mat_gbar_decay, + mat_gbar_weight * grad_outer) + + def _compute_power_svd(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name): + """Computes mat_h = mat_g^alpha using svd. mat_g is a symmetric PSD matrix. + + Args: + var: the variable we are updating. + mat_g: the symmetric PSD matrix whose power it to be computed + mat_g_size: size of mat_g + alpha: a real number + mat_h_slot_name: name of slot to store the power, if needed. + + Returns: + mat_h = mat_g^alpha + + Stores mat_h in the appropriate slot, if it exists. + Note that mat_g is PSD. So we could use linalg_ops.self_adjoint_eig. + """ + if mat_g_size == 1: + mat_h = math_ops.pow(mat_g + self._epsilon, alpha) + else: + damping = self._epsilon * linalg_ops.eye(math_ops.to_int32(mat_g_size)) + diag_d, mat_u, mat_v = linalg_ops.svd(mat_g + damping, full_matrices=True) + mat_h = math_ops.matmul( + mat_v * math_ops.pow(math_ops.maximum(diag_d, self._epsilon), alpha), + array_ops.transpose(mat_u)) + if mat_h_slot_name is not None: + return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h) + return mat_h + + def _compute_power_iter(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name, + iter_count=100, epsilon=1e-6): + """Computes mat_g^alpha, where alpha = -1/p, p a positive integer. + + We use an iterative Schur-Newton method from equation 3.2 on page 9 of: + + A Schur-Newton Method for the Matrix p-th Root and its Inverse + by Chun-Hua Guo and Nicholas J. Higham + SIAM Journal on Matrix Analysis and Applications, + 2006, Vol. 28, No. 3 : pp. 788-804 + https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf + + Args: + var: the variable we are updating. + mat_g: the symmetric PSD matrix whose power it to be computed + mat_g_size: size of mat_g. + alpha: exponent, must be -1/p for p a positive integer. + mat_h_slot_name: name of slot to store the power, if needed. + iter_count: Maximum number of iterations. + epsilon: accuracy indicator, useful for early termination. + + Returns: + mat_g^alpha + """ + + identity = linalg_ops.eye(math_ops.to_int32(mat_g_size)) + + def MatPower(mat_m, p): + """Computes mat_m^p, for p a positive integer. + + Power p is known at graph compile time, so no need for loop and cond. + Args: + mat_m: a square matrix + p: a positive integer + + Returns: + mat_m^p + """ + assert p == int(p) and p > 0 + power = None + while p > 0: + if p % 2 == 1: + power = math_ops.matmul(mat_m, power) if power is not None else mat_m + p //= 2 + mat_m = math_ops.matmul(mat_m, mat_m) + return power + + def IterCondition(i, mat_m, _): + return math_ops.logical_and( + i < iter_count, + math_ops.reduce_max(math_ops.abs(mat_m - identity)) > epsilon) + + def IterBody(i, mat_m, mat_x): + mat_m_i = (1 - alpha) * identity + alpha * mat_m + return (i + 1, math_ops.matmul(MatPower(mat_m_i, -1.0/alpha), mat_m), + math_ops.matmul(mat_x, mat_m_i)) + + if mat_g_size == 1: + mat_h = math_ops.pow(mat_g + self._epsilon, alpha) + else: + damped_mat_g = mat_g + self._epsilon * identity + z = (1 - 1 / alpha) / (2 * linalg_ops.norm(damped_mat_g)) + # The best value for z is + # (1 - 1/alpha) * (c_max^{-alpha} - c_min^{-alpha}) / + # (c_max^{1-alpha} - c_min^{1-alpha}) + # where c_max and c_min are the largest and smallest singular values of + # damped_mat_g. + # The above estimate assumes that c_max > c_min * 2^p. (p = -1/alpha) + # Can replace above line by the one below, but it is less accurate, + # hence needs more iterations to converge. + # z = (1 - 1/alpha) / math_ops.trace(damped_mat_g) + # If we want the method to always converge, use z = 1 / norm(damped_mat_g) + # or z = 1 / math_ops.trace(damped_mat_g), but these can result in many + # extra iterations. + _, _, mat_h = control_flow_ops.while_loop( + IterCondition, IterBody, + [0, damped_mat_g * z, identity * math_ops.pow(z, -alpha)]) + if mat_h_slot_name is not None: + return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h) + return mat_h + + def _compute_power(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name=None): + """Just a switch between the iterative power vs svd.""" + with ops.name_scope("matrix_iterative_power"): + if self._use_iterative_root: + return self._compute_power_iter(var, mat_g, mat_g_size, alpha, + mat_h_slot_name) + else: + return self._compute_power_svd(var, mat_g, mat_g_size, alpha, + mat_h_slot_name) + + def _apply_gradient(self, grad, var, indices=None): + """The main function to update a variable. + + Args: + grad: A Tensor containing gradient to apply. + var: A Tensor containing the variable to update. + indices: An array of integers, for sparse update. + + Returns: + Updated variable var = var - learning_rate * preconditioner * grad + + If the gradient is dense, var and grad have the same shape. + If the update is sparse, then the first dimension of the gradient and var + may differ, others are all the same. In this case the indices array + provides the set of indices of the variable which are to be updated with + each row of the gradient. + """ + global_step = self._global_step + 1 + + # Update accumulated weighted average of gradients + gbar = self.get_slot(var, "gbar") + gbar_decay_t = GetParam(self._gbar_decay, global_step) + gbar_weight_t = GetParam(self._gbar_weight, global_step) + if indices is not None: + # Note - the sparse update is not easily implemented, since the + # algorithm needs all indices of gbar to be updated + # if mat_gbar_decay != 1 or mat_gbar_decay != 0. + # One way to make mat_gbar_decay = 1 is by rescaling. + # If we want the update: + # G_{t+1} = a_{t+1} G_t + b_{t+1} w_t + # define: + # r_{t+1} = a_{t+1} * r_t + # h_t = G_t / r_t + # Then: + # h_{t+1} = h_t + (b_{t+1} / r_{t+1}) * w_t + # So we get the mat_gbar_decay = 1 as desired. + # We can implement this in a future version as needed. + # However we still need gbar_decay = 0, otherwise all indices + # of the variable will need to be updated. + if self._gbar_decay != 0.0: + tf_logging.warning("Not applying momentum for variable: %s" % var.name) + gbar_updated = grad + else: + gbar_updated = self._weighted_average(gbar, self._gbar_decay, + gbar_decay_t, + gbar_weight_t * grad) + + # Update the preconditioners and compute the preconditioned gradient + shape = var.get_shape() + mat_g_list = [] + for i in range(len(shape)): + mat_g_list.append(self.get_slot(var, "Gbar_" + str(i))) + mat_gbar_decay_t = GetParam(self._mat_gbar_decay, global_step) + mat_gbar_weight_t = GetParam(self._mat_gbar_weight, global_step) + + preconditioned_grad = gbar_updated + v_rank = len(mat_g_list) + neg_alpha = - GetParam(self._alpha, global_step) / v_rank + svd_interval = GetParam(self._svd_interval, global_step) + precond_update_interval = GetParam(self._precond_update_interval, + global_step) + for i, mat_g in enumerate(mat_g_list): + # axes is the list of indices to reduce - everything but the current i. + axes = list(range(i)) + list(range(i+1, v_rank)) + if shape[i] < self._max_matrix_size: + # If the tensor size is sufficiently small perform full Shampoo update + # Note if precond_update_interval > 1 and mat_gbar_decay_t != 1, this + # is not strictly correct. However we will use it for now, and + # fix if needed. (G_1 = aG + bg ==> G_n = a^n G + (1+a+..+a^{n-1})bg) + + # pylint: disable=g-long-lambda,cell-var-from-loop + mat_g_updated = control_flow_ops.cond( + math_ops.mod(global_step, precond_update_interval) < 1, + lambda: self._update_mat_g( + mat_g, grad, axes, mat_gbar_decay_t, + mat_gbar_weight_t * precond_update_interval, i), + lambda: mat_g) + + if self._svd_interval == 1: + mat_h = self._compute_power(var, mat_g_updated, shape[i], neg_alpha) + else: + mat_h = control_flow_ops.cond( + math_ops.mod(global_step, svd_interval) < 1, + lambda: self._compute_power(var, mat_g_updated, shape[i], + neg_alpha, "H_" + str(i)), + lambda: self.get_slot(var, "H_" + str(i))) + + # mat_h is a square matrix of size d_i x d_i + # preconditioned_grad is a d_i x ... x d_n x d_0 x ... d_{i-1} tensor + # After contraction with a d_i x d_i tensor + # it becomes a d_{i+1} x ... x d_n x d_0 x ... d_i tensor + # (the first dimension is contracted out, and the second dimension of + # mat_h is appended). After going through all the indices, it becomes + # a d_0 x ... x d_n tensor again. + preconditioned_grad = math_ops.tensordot(preconditioned_grad, mat_h, + axes=([0], [0]), + name="precond_" + str(i)) + else: + # Tensor size is too large -- perform diagonal Shampoo update + grad_outer = math_ops.reduce_sum(grad * grad, axis=axes) + if i == 0 and indices is not None: + assert self._mat_gbar_decay == 1.0 + mat_g_updated = state_ops.scatter_add(mat_g, indices, + mat_gbar_weight_t * grad_outer) + mat_h = math_ops.pow( + array_ops.gather(mat_g_updated, indices) + self._epsilon, + neg_alpha) + else: + mat_g_updated = self._weighted_average(mat_g, + self._mat_gbar_decay, + mat_gbar_decay_t, + mat_gbar_weight_t * grad_outer) + mat_h = math_ops.pow(mat_g_updated + self._epsilon, neg_alpha) + + # Need to do the transpose to ensure that the tensor becomes + # a d_{i+1} x ... x d_n x d_0 x ... d_i tensor as described above. + preconditioned_grad = array_ops.transpose( + preconditioned_grad, perm=list(range(1, v_rank)) + [0]) * mat_h + + # Update the variable based on the Shampoo update + learning_rate_t = GetParam(self._learning_rate, global_step) + if indices is not None: + var_updated = state_ops.scatter_add( + var, indices, -learning_rate_t * preconditioned_grad) + else: + var_updated = state_ops.assign_sub(var, + learning_rate_t * preconditioned_grad) + return var_updated diff --git a/tensorflow/contrib/opt/python/training/shampoo_test.py b/tensorflow/contrib/opt/python/training/shampoo_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2e0a202ae293664d85ece884a505096455cde73c --- /dev/null +++ b/tensorflow/contrib/opt/python/training/shampoo_test.py @@ -0,0 +1,734 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Functional tests for AdaMoo optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.opt.python.training import shampoo +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + +TOLERANCE = 1e-3 + + +def np_power(mat_g, alpha): + """Computes mat_g^alpha for a square symmetric matrix mat_g.""" + + mat_u, diag_d, mat_v = np.linalg.svd(mat_g) + diag_d = np.power(diag_d, alpha) + return np.dot(np.dot(mat_u, np.diag(diag_d)), mat_v) + + +class ShampooTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters(('Var', False), ('ResourceVar', True)) + def testBasicVector(self, use_resource_var): + """Similar to the full Adagrad update.""" + + size = 20 + init_var_np = np.zeros(size) + grad_np = np.random.rand(size) + grad_np_2 = np.random.rand(size) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * mat_g^{-0.5} * grad + # lr = 1 + mat_g = np.outer(grad_np, grad_np) + mat_h = np_power(mat_g + 0.1 * np.eye(size), -0.5) + new_val_np = init_var_np - np.dot(mat_h, grad_np) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g += np.outer(grad_np_2, grad_np_2) + mat_h = np_power(mat_g + 0.1 * np.eye(size), -0.5) + new_val_np -= np.dot(mat_h, grad_np_2) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters(('Var', False), ('ResourceVar', True)) + def testBasicMatrix(self, use_resource_var): + """Check update when gradient is a matrix.""" + size = [10, 5] + init_var_np = np.zeros(size) + grad_np = np.random.rand(size[0], size[1]) + grad_np_2 = np.random.rand(size[0], size[1]) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * mat_g1^{-0.25} * grad * mat_g2^{-0.25} + # lr = 1 + mat_g1 = np.dot(grad_np, grad_np.transpose()) + mat_left = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.25) + mat_g2 = np.dot(grad_np.transpose(), grad_np) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np = init_var_np - np.dot(np.dot(mat_left, grad_np), mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 += np.dot(grad_np_2, grad_np_2.transpose()) + mat_left = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.25) + mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np -= np.dot(np.dot(mat_left, grad_np_2), mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + def _testBasicTensor(self, use_iterative_root, use_resource_var): + """Check update when gradient is a tensor. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. + use_resource_var: use resource var as variables. + """ + size = [10, 5, 7] + init_var_np = np.zeros(size) + grad_np = np.random.rand(size[0], size[1], size[2]) + grad_np_2 = np.random.rand(size[0], size[1], size[2]) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad + # lr = 1 + mat_g1 = np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2])) + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2 = np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2])) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3 = np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1])) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + precond_grad = np.tensordot(grad_np, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np = init_var_np - precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 += np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + precond_grad = np.tensordot(grad_np_2, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np -= precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters( + ('SVDWithVar', False, False), + ('SVDWithResourceVar', False, True), + ('IterRootWithVar', True, False), + ('IterRootWithResourceVar', True, True), + ) + def testBasicTensor(self, use_iterative_root, use_resource_var): + self._testBasicTensor(use_iterative_root, use_resource_var) + + @parameterized.named_parameters(('Var', False), ('ResourceVar', True)) + def testLargeVector(self, use_resource_var): + """This is just the diagonal Adagrad update.""" + + size = 2000 + init_var_np = np.zeros(size) + grad_np = np.random.rand(size) + grad_np_2 = np.random.rand(size) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * gg^{-0.5} * grad + # lr = 1 + mat_g = grad_np * grad_np + 0.1 + new_val_np = init_var_np - np.power(mat_g, -0.5) * grad_np + + self.assertAllCloseAccordingToType(new_val_np, new_val) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g += grad_np_2 * grad_np_2 + new_val_np -= np.power(mat_g, -0.5) * grad_np_2 + + self.assertAllCloseAccordingToType(new_val_np, new_val) + + @parameterized.named_parameters(('Var', False), ('ResourceVar', True)) + def testLargeMatrix(self, use_resource_var): + """Gradient is a matrix, one of whose dimensions is large. + + We do diagonal updates for large dimensions. + + Args: + use_resource_var: use resource var as variables. + """ + + size = [2000, 3] + init_var_np = np.zeros(size) + grad_np = np.random.rand(size[0], size[1]) + grad_np_2 = np.random.rand(size[0], size[1]) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * mat_left * grad * mat_right + # where the mat_left * grad is just element-wise product, + # with broadcasting + # lr = 1 + + mat_g1 = np.sum(grad_np * grad_np, axis=1, keepdims=True) + mat_left = np.power(mat_g1 + 0.1, -0.25) + mat_g2 = np.dot(grad_np.transpose(), grad_np) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np = init_var_np - np.dot(grad_np * mat_left, mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 += np.sum(grad_np_2 * grad_np_2, axis=1, keepdims=True) + mat_left = np.power(mat_g1 + 0.1, -0.25) + mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np -= np.dot(grad_np_2 * mat_left, mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters(('Var', False)) + def testSparseUpdateLarge(self, use_resource_var): + """Check update when gradient is of type IndexSlices. + + We do diagonal updates for the first dimension, unless it is very small. + + Args: + use_resource_var: use resource var as variables. + """ + size = [2000, 3] + sample_size_1 = 100 + init_var_np = np.zeros(size) + grad_indices = np.sort(np.random.choice(np.arange(size[0]), sample_size_1, + replace=False)) + grad_np = np.random.rand(sample_size_1, size[1]) + + sample_size_2 = 7 + grad_indices_2 = np.sort(np.random.choice(np.arange(size[0]), sample_size_2, + replace=False)) + grad_np_2 = np.random.rand(sample_size_2, size[1]) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = ops.IndexedSlices( + constant_op.constant(grad_np, dtype=dtypes.float32), + constant_op.constant(grad_indices), + constant_op.constant(size)) + grad_2 = ops.IndexedSlices( + constant_op.constant(grad_np_2, dtype=dtypes.float32), + constant_op.constant(grad_indices_2), + constant_op.constant(size)) + + opt = shampoo.ShampooOptimizer(global_step) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * mat_left * grad * mat_right + # where the mat_left * grad is just element-wise product, + # with broadcasting + # lr = 1 + # In this case the update lr * mat_left * grad * mat_right is + # of size 10 x 2. + # So the correct indices of var need to be updated. + + mat_g1 = np.sum(grad_np * grad_np, axis=1, keepdims=True) + mat_g1_acc = np.zeros((size[0], 1)) + mat_g1_acc[grad_indices] += mat_g1 + mat_left = np.power(mat_g1 + 0.1, -0.25) + mat_g2 = np.dot(grad_np.transpose(), grad_np) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np = init_var_np + new_val_np[grad_indices, :] -= np.dot(grad_np * mat_left, mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 = np.sum(grad_np_2 * grad_np_2, axis=1, keepdims=True) + mat_g1_acc[grad_indices_2] += mat_g1 + mat_left = np.power(mat_g1_acc[grad_indices_2] + 0.1, -0.25) + mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) + mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25) + new_val_np[grad_indices_2, :] -= np.dot(grad_np_2 * mat_left, mat_right) + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + def _testSparseUpdateSmall(self, use_iterative_root, use_resource_var): + """Gradient is of type IndexSlices, but the first dimension is small. + + We create dense gradient and do the full update with SVD etc. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. + use_resource_var: use resource var as variables. + """ + + size = [100, 3, 5] + sample_size = 10 + init_var_np = np.zeros(size) + grad_indices = np.sort(np.random.choice(np.arange(size[0]), sample_size, + replace=False)) + grad_np = np.random.rand(sample_size, size[1], size[2]) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = ops.IndexedSlices( + constant_op.constant(grad_np, dtype=dtypes.float32), + constant_op.constant(grad_indices), + constant_op.constant(size)) + + opt = shampoo.ShampooOptimizer(global_step, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.125} grad + # lr = 1 + grad_dense = np.zeros_like(init_var_np) + grad_dense[grad_indices] = grad_np + + mat_g1 = np.tensordot(grad_dense, grad_dense, axes=([1, 2], [1, 2])) + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2 = np.tensordot(grad_dense, grad_dense, axes=([0, 2], [0, 2])) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3 = np.tensordot(grad_dense, grad_dense, axes=([0, 1], [0, 1])) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + precond_grad = np.tensordot(grad_dense, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np = init_var_np - precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters( + ('SVDWithVar', False, False), + ('SVDWithResourceVar', False, True), + ('IterRootWithVar', True, False), + ('IterRootWithResourceVar', True, True), + ) + def testSparseUpdateSmall(self, use_iterative_root, use_resource_var): + self._testSparseUpdateSmall(use_iterative_root, use_resource_var) + + def _testBasicTensorWithMomentum(self, use_iterative_root, use_resource_var): + """Check update with momentum when gradient is a tensor. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. + use_resource_var: use resource var as variables. + """ + size = [10, 5, 7] + init_var_np = np.zeros(size) + grad_np = np.random.rand(size[0], size[1], size[2]) + grad_np_2 = np.random.rand(size[0], size[1], size[2]) + gbar_decay = 0.9 + gbar_weight = 0.1 + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = constant_op.constant(grad_np, dtype=dtypes.float32) + grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) + + opt = shampoo.ShampooOptimizer(global_step, gbar_decay=gbar_decay, + gbar_weight=gbar_weight, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + update_2 = opt.apply_gradients(zip([grad_2], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + # Run a step of Shampoo + update.run() + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad + # lr = 1 + mat_g1 = np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2])) + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2 = np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2])) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3 = np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1])) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + gbar_np = gbar_weight * grad_np + precond_grad = np.tensordot(gbar_np, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np = init_var_np - precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + # Run another step of Shampoo + update_2.run() + new_val = sess.run(var) + + mat_g1 += np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + gbar_np_2 = gbar_decay * gbar_np + gbar_weight * grad_np_2 + precond_grad = np.tensordot(gbar_np_2, mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np -= precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters( + ('SVDWithVar', False, False), + ('SVDWithResourceVar', False, True), + ('IterRootWithVar', True, False), + ('IterRootWithResourceVar', True, True), + ) + def testBasicTensorWithMomentum(self, use_iterative_root, use_resource_var): + self._testBasicTensorWithMomentum(use_iterative_root, use_resource_var) + + def _testDelayedSVD(self, use_iterative_root, use_resource_var): + """Performing the SVD every nth step. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. + use_resource_var: use resource var as variables. + """ + size = [10, 5, 7] + init_var_np = np.zeros(size).astype(np.float32) + iterations = 20 + svd_interval = 5 + grad_np = np.random.rand( + iterations, size[0], size[1], size[2]).astype(np.float32) + mat_g1_a = np.eye(size[0]) + mat_g1 = np.zeros_like(mat_g1_a) + mat_g2_a = np.eye(size[1]) + mat_g2 = np.zeros_like(mat_g2_a) + mat_g3_a = np.eye(size[2]) + mat_g3 = np.zeros_like(mat_g3_a) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = array_ops.placeholder(dtypes.float32, shape=size) + + opt = shampoo.ShampooOptimizer(global_step, svd_interval=svd_interval, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + new_val_np = init_var_np + + # Run n steps of Shampoo + for i in range(iterations): + _ = sess.run(update, feed_dict={grad: grad_np[i]}) + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad + # lr = 1 + mat_g1 += np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) + mat_g2 += np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) + mat_g3 += np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) + if (i + 1) % svd_interval == 0: + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np -= precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters( + ('SVDWithVar', False, False), + ('SVDWithResourceVar', False, True), + ('IterRootWithVar', True, False), + ('IterRootWithResourceVar', True, True), + ) + def testDelayedSVD(self, use_iterative_root, use_resource_var): + self._testDelayedSVD(use_iterative_root, use_resource_var) + + def _testDelayedPrecondUpdate(self, use_iterative_root, use_resource_var): + """Update the squared sum every nth step, drop the other steps. + + Args: + use_iterative_root: use iterative power method or SVD to find nth roots. + use_resource_var: use resource var as variables. + """ + size = [10, 5, 7] + init_var_np = np.zeros(size).astype(np.float32) + iterations = 100 + grad_np = np.random.rand( + iterations, size[0], size[1], size[2]).astype(np.float32) + svd_interval = 20 + precond_update_interval = 5 + mat_g1_a = np.eye(size[0]) + mat_g1 = np.zeros_like(mat_g1_a) + mat_g2_a = np.eye(size[1]) + mat_g2 = np.zeros_like(mat_g2_a) + mat_g3_a = np.eye(size[2]) + mat_g3 = np.zeros_like(mat_g3_a) + + with self.test_session() as sess: + global_step = variables.Variable( + 0, dtype=dtypes.int64, use_resource=use_resource_var) + var = variables.Variable( + init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) + grad = array_ops.placeholder(dtypes.float32, shape=size) + + opt = shampoo.ShampooOptimizer( + global_step, svd_interval=svd_interval, + precond_update_interval=precond_update_interval, + use_iterative_root=use_iterative_root) + update = opt.apply_gradients(zip([grad], [var]), + global_step=global_step) + variables.global_variables_initializer().run() + + init_val = sess.run(var) + self.assertAllCloseAccordingToType(init_var_np, init_val) + new_val_np = init_var_np + + # Run n steps of Shampoo + for i in range(iterations): + _ = sess.run(update, feed_dict={grad: grad_np[i]}) + new_val = sess.run(var) + + # let up compute this in numpy + # Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad + # lr = 1 + if (i + 1) % precond_update_interval == 0: + mat_g1 += (np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) + * precond_update_interval) + mat_g2 += (np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) + * precond_update_interval) + mat_g3 += (np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) + * precond_update_interval) + + if (i + 1) % svd_interval == 0: + mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0) + mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0) + mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0) + + precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0])) + precond_grad = np.tensordot(precond_grad, mat_g3_a, axes=([0], [0])) + new_val_np -= precond_grad + + self.assertAllCloseAccordingToType(new_val_np, new_val, + atol=TOLERANCE, rtol=TOLERANCE) + + @parameterized.named_parameters( + ('SVDWithVar', False, False), + ('SVDWithResourceVar', False, True), + ('IterRootWithVar', True, False), + ('IterRootWithResourceVar', True, True), + ) + def testDelayedPrecondUpdate(self, use_iterative_root, use_resource_var): + self._testDelayedPrecondUpdate(use_iterative_root, use_resource_var) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py index 8aa40aeb45d4ec15140bdfc5ebd824e8aa08d8d9..b9cf40eb7b2d11c98b93c51213145ca4e2670318 100644 --- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py @@ -19,13 +19,13 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops -from tensorflow.python.training import optimizer from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops from tensorflow.python.training import adam from tensorflow.python.training import momentum as momentum_opt +from tensorflow.python.training import optimizer from tensorflow.python.util.tf_export import tf_export -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import resource_variable_ops class DecoupledWeightDecayExtension(object): @@ -65,7 +65,7 @@ class DecoupledWeightDecayExtension(object): Args: weight_decay: A `Tensor` or a floating point value, the factor by which a variable is decayed in the update step. - decay_var_list: Optional list or tuple or set of `Variable` objects to + **kwargs: Optional list or tuple or set of `Variable` objects to decay. """ self._decay_var_list = None # is set in minimize or apply_gradients @@ -85,6 +85,28 @@ class DecoupledWeightDecayExtension(object): If decay_var_list is None, all variables in var_list are decayed. For more information see the documentation of Optimizer.minimize. + + Args: + loss: A `Tensor` containing the value to minimize. + global_step: Optional `Variable` to increment by one after the + variables have been updated. + var_list: Optional list or tuple of `Variable` objects to update to + minimize `loss`. Defaults to the list of variables collected in + the graph under the key `GraphKeys.TRAINABLE_VARIABLES`. + gate_gradients: How to gate the computation of gradients. Can be + `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. + aggregation_method: Specifies the method used to combine gradient terms. + Valid values are defined in the class `AggregationMethod`. + colocate_gradients_with_ops: If True, try colocating gradients with + the corresponding op. + name: Optional name for the returned operation. + grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. + decay_var_list: Optional list of decay variables. + + Returns: + An Operation that updates the variables in `var_list`. If `global_step` + was not `None`, that operation also increments `global_step`. + """ self._decay_var_list = set(decay_var_list) if decay_var_list else False return super(DecoupledWeightDecayExtension, self).minimize( @@ -103,6 +125,19 @@ class DecoupledWeightDecayExtension(object): are decayed. For more information see the documentation of Optimizer.apply_gradients. + + Args: + grads_and_vars: List of (gradient, variable) pairs as returned by + `compute_gradients()`. + global_step: Optional `Variable` to increment by one after the + variables have been updated. + name: Optional name for the returned operation. Default to the + name passed to the `Optimizer` constructor. + decay_var_list: Optional list of decay variables. + + Returns: + An `Operation` that applies the specified gradients. If `global_step` + was not None, that operation also increments `global_step`. """ self._decay_var_list = set(decay_var_list) if decay_var_list else False return super(DecoupledWeightDecayExtension, self).apply_gradients( @@ -197,6 +232,7 @@ def extend_with_decoupled_weight_decay(base_optimizer): A new optimizer class that inherits from DecoupledWeightDecayExtension and base_optimizer. """ + class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension, base_optimizer): """Base_optimizer with decoupled weight decay. diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py index 74d1cdbbdac8724518937d141a976abf9fec6ce3..76d8a5697acb79e7748175c4a81dfdd85807dd49 100644 --- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.opt.python.training import weight_decay_optimizers from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -29,7 +30,6 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import adam -from tensorflow.contrib.opt.python.training import weight_decay_optimizers WEIGHT_DECAY = 0.01 @@ -91,7 +91,6 @@ class WeightDecayOptimizerTest(test.TestCase): opt = optimizer() update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - if not context.executing_eagerly(): with ops.Graph().as_default(): # Shouldn't return non-slot variables from other graphs. @@ -171,9 +170,9 @@ class ExtendWithWeightDecayTest(WeightDecayOptimizerTest): @staticmethod def get_optimizer(): - AdamW = weight_decay_optimizers.extend_with_decoupled_weight_decay( + adamw = weight_decay_optimizers.extend_with_decoupled_weight_decay( adam.AdamOptimizer) - return AdamW(WEIGHT_DECAY) + return adamw(WEIGHT_DECAY) def testBasic(self): self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m", @@ -185,6 +184,5 @@ class ExtendWithWeightDecayTest(WeightDecayOptimizerTest): use_resource=True) - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/optimizer_v2/BUILD b/tensorflow/contrib/optimizer_v2/BUILD index 5225ecc14fef3cec9506eceb776805b74a87714e..3ba3ee29ec79687df522eb330665a2ce80061682 100644 --- a/tensorflow/contrib/optimizer_v2/BUILD +++ b/tensorflow/contrib/optimizer_v2/BUILD @@ -193,6 +193,7 @@ cuda_py_test( srcs = ["rmsprop_test.py"], additional_deps = [ ":training", + "@absl_py//absl/testing:parameterized", "//tensorflow/python:array_ops", "//tensorflow/python:embedding_ops", "//tensorflow/python:framework", diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py index 06ab58188a2fffa0e3a810d451875ca951a077b9..28a531dfecf275c48fea54310b93b5266a79899a 100644 --- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py +++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py @@ -41,6 +41,7 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import template from tensorflow.python.ops import variable_scope +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import saver as core_saver from tensorflow.python.training import training_util from tensorflow.python.training.checkpointable import tracking @@ -278,7 +279,8 @@ class CheckpointingTests(test.TestCase): root = util.Checkpoint( optimizer=optimizer, model=model, optimizer_step=training_util.get_or_create_global_step()) - root.restore(core_saver.latest_checkpoint(checkpoint_directory)) + root.restore(checkpoint_management.latest_checkpoint( + checkpoint_directory)) for _ in range(num_training_steps): # TODO(allenl): Use a Dataset and serialize/checkpoint it. input_value = constant_op.constant([[3.]]) @@ -306,7 +308,8 @@ class CheckpointingTests(test.TestCase): train_op = optimizer.minimize( model(input_value), global_step=root.global_step) - checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) + checkpoint_path = checkpoint_management.latest_checkpoint( + checkpoint_directory) with self.test_session(graph=ops.get_default_graph()) as session: status = root.restore(save_path=checkpoint_path) status.initialize_or_restore(session=session) @@ -339,7 +342,8 @@ class CheckpointingTests(test.TestCase): root = util.Checkpoint( optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) - checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) + checkpoint_path = checkpoint_management.latest_checkpoint( + checkpoint_directory) status = root.restore(save_path=checkpoint_path) input_value = constant_op.constant([[3.]]) train_fn = functools.partial( @@ -372,7 +376,8 @@ class CheckpointingTests(test.TestCase): root = util.Checkpoint( optimizer=optimizer, model=model, global_step=training_util.get_or_create_global_step()) - checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) + checkpoint_path = checkpoint_management.latest_checkpoint( + checkpoint_directory) status = root.restore(save_path=checkpoint_path) def train_fn(): @function.defun diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index c6f3bd6ee18fa353944e2fc303573894933f5b27..8c11d8bcfdf76bc12e13ffb58f917978e966476e 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -766,7 +766,8 @@ class OptimizerV2(optimizer_v1.Optimizer): # *after* loss() is evaluated, so we know what loss reduction it uses. if scale_loss_by_num_towers is None: scale_loss_by_num_towers = ( - distribute_lib.get_loss_reduction() == "mean") + distribute_lib.get_loss_reduction() == + variable_scope.VariableAggregation.MEAN) if scale_loss_by_num_towers: num_towers = distribute_lib.get_distribution_strategy().num_towers if num_towers > 1: @@ -784,7 +785,8 @@ class OptimizerV2(optimizer_v1.Optimizer): # Scale loss for number of towers (non-callable-loss case). if scale_loss_by_num_towers is None: scale_loss_by_num_towers = ( - distribute_lib.get_loss_reduction() == "mean") + distribute_lib.get_loss_reduction() == + variable_scope.VariableAggregation.MEAN) if scale_loss_by_num_towers: num_towers = distribute_lib.get_distribution_strategy().num_towers if num_towers > 1: @@ -896,7 +898,8 @@ class OptimizerV2(optimizer_v1.Optimizer): def _distributed_apply(self, distribution, grads_and_vars, global_step, name): """`apply_gradients` for use with a `DistributionStrategy`.""" - reduced_grads = distribution.batch_reduce("sum", grads_and_vars) + reduced_grads = distribution.batch_reduce( + variable_scope.VariableAggregation.SUM, grads_and_vars) var_list = [v for _, v in grads_and_vars] grads_and_vars = zip(reduced_grads, var_list) diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py index ec033c4a0163ba9ed39e55fa9e92dfdadc9a1b2f..a44bfd1bfd97e678fbf4c402ef5b1298dc518f75 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py @@ -38,12 +38,8 @@ class OptimizerTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testBasic(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - # Note that we name the variables uniquely here since the variables don't - # seem to be getting deleted at the end of the loop. - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype, - name='a_%d' % i) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, - name='b_%d' % i) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) def loss(): return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop # Note that for eager execution, minimize expects a function instead of a @@ -131,12 +127,8 @@ class OptimizerTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testNoGradients(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - # Note that we name the variables uniquely here since the variables don't - # seem to be getting deleted at the end of the loop. - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype, - name='a%d' % i) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, - name='b%d' % i) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) # pylint: disable=cell-var-from-loop def loss(): return 5 * var0 @@ -149,12 +141,8 @@ class OptimizerTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testNoGradientsForAnyVariables_Minimize(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - # Note that we name the variables uniquely here since the variables don't - # seem to be getting deleted at the end of the loop. - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype, - name='a_%d' % i) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, - name='b_%d' % i) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) def loss(): return constant_op.constant(5.0) sgd_op = gradient_descent.GradientDescentOptimizer(3.0) @@ -165,12 +153,8 @@ class OptimizerTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testNoGradientsForAnyVariables_ApplyGradients(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - # Note that we name the variables uniquely here since the variables don't - # seem to be getting deleted at the end of the loop. - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype, - name='a_%d' % i) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, - name='b_%d' % i) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) sgd_op = gradient_descent.GradientDescentOptimizer(3.0) with self.assertRaisesRegexp(ValueError, 'No gradients provided for any variable'): @@ -179,12 +163,8 @@ class OptimizerTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testGradientsAsVariables(self): for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): - # Note that we name the variables uniquely here since the variables don't - # seem to be getting deleted at the end of the loop. - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype, - name='a%d' % i) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype, - name='b%d' % i) + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) def loss(): return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop sgd_op = gradient_descent.GradientDescentOptimizer(3.0) diff --git a/tensorflow/contrib/optimizer_v2/rmsprop_test.py b/tensorflow/contrib/optimizer_v2/rmsprop_test.py index ed68f6afbf8bf9678649c1ce6fc59c3b91026dc0..dc23ef241a43900ed40f029f1b857820459e43d0 100644 --- a/tensorflow/contrib/optimizer_v2/rmsprop_test.py +++ b/tensorflow/contrib/optimizer_v2/rmsprop_test.py @@ -19,15 +19,16 @@ from __future__ import division from __future__ import print_function import copy -import itertools import math +from absl.testing import parameterized import numpy as np from tensorflow.contrib.optimizer_v2 import rmsprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -48,13 +49,8 @@ _TEST_PARAM_VALUES = [ [0.5, 0.95, 0.9, 1e-5, True, False], ] -_TESTPARAMS = [ - [data_type] + values - for data_type, values in itertools.product(_DATA_TYPES, _TEST_PARAM_VALUES) -] - -class RMSPropOptimizerTest(test.TestCase): +class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase): def _rmsprop_update_numpy(self, var, g, mg, rms, mom, lr, decay, momentum, epsilon, centered): @@ -87,362 +83,366 @@ class RMSPropOptimizerTest(test.TestCase): var_t[gindex] = var[gindex] - mom_t[gindex] return var_t, mg_t, rms_t, mom_t - def testDense(self): - # TODO(yori): Use ParameterizedTest when available - for (dtype, learning_rate, decay, momentum, - epsilon, centered, use_resource) in _TESTPARAMS: - with self.test_session(use_gpu=True): - # Initialize variables for numpy implementation. - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1, 0.2], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01, 0.2], dtype=dtype.as_numpy_dtype) - - if use_resource: - var0 = resource_variable_ops.ResourceVariable(var0_np) - var1 = resource_variable_ops.ResourceVariable(var1_np) - else: - var0 = variables.Variable(var0_np) - var1 = variables.Variable(var1_np) - grads0 = constant_op.constant(grads0_np) - grads1 = constant_op.constant(grads1_np) - opt = rmsprop.RMSPropOptimizer( - learning_rate=learning_rate, - decay=decay, - momentum=momentum, - epsilon=epsilon, - centered=centered) - - update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - - mg0 = opt.get_slot(var0, "mg") - self.assertEqual(mg0 is not None, centered) - mg1 = opt.get_slot(var1, "mg") - self.assertEqual(mg1 is not None, centered) - rms0 = opt.get_slot(var0, "rms") - self.assertTrue(rms0 is not None) - rms1 = opt.get_slot(var1, "rms") - self.assertTrue(rms1 is not None) - mom0 = opt.get_slot(var0, "momentum") - self.assertTrue(mom0 is not None) - mom1 = opt.get_slot(var1, "momentum") - self.assertTrue(mom1 is not None) - - mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) - rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) - mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) - - # Run 4 steps of RMSProp - for _ in range(1, 5): - update.run() - - var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy( - var0_np, grads0_np, mg0_np, rms0_np, mom0_np, learning_rate, - decay, momentum, epsilon, centered) - var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy( - var1_np, grads1_np, mg1_np, rms1_np, mom1_np, learning_rate, - decay, momentum, epsilon, centered) - - # Validate updated params - if centered: - self.assertAllCloseAccordingToType(mg0_np, mg0.eval()) - self.assertAllCloseAccordingToType(mg1_np, mg1.eval()) - self.assertAllCloseAccordingToType(rms0_np, rms0.eval()) - self.assertAllCloseAccordingToType(rms1_np, rms1.eval()) - self.assertAllCloseAccordingToType(mom0_np, mom0.eval()) - self.assertAllCloseAccordingToType(mom1_np, mom1.eval()) - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) - - def testMinimizeSparseResourceVariable(self): - for dtype in [dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) - x = constant_op.constant([[4.0], [5.0]], dtype=dtype) - pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) - loss = pred * pred - sgd_op = rmsprop.RMSPropOptimizer( - learning_rate=1.0, - decay=0.0, - momentum=0.0, - epsilon=0.0, - centered=False).minimize(loss) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - self.assertAllCloseAccordingToType( - [[0., 1.]], var0.eval(), atol=0.01) - - def testMinimizeSparseResourceVariableCentered(self): - for dtype in [dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) - x = constant_op.constant([[4.0], [5.0]], dtype=dtype) - pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) - loss = pred * pred - sgd_op = rmsprop.RMSPropOptimizer( - learning_rate=1.0, - decay=0.0, - momentum=0.0, - epsilon=1.0, - centered=True).minimize(loss) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - self.assertAllCloseAccordingToType( - [[-111, -138]], var0.eval(), atol=0.01) - - def testSparse(self): - # TODO(yori): Use ParameterizedTest when available - for (dtype, learning_rate, decay, - momentum, epsilon, centered, _) in _TESTPARAMS: - with self.test_session(use_gpu=True): - # Initialize variables for numpy implementation. - var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) - grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) - grads1_np = np.array([0.01], dtype=dtype.as_numpy_dtype) - + @parameterized.named_parameters( + *test_util.generate_combinations_with_testcase_name( + dtype=_DATA_TYPES, param_value=_TEST_PARAM_VALUES)) + def testDense(self, dtype, param_value): + (learning_rate, decay, momentum, epsilon, centered, use_resource) = tuple( + param_value) + with self.test_session(use_gpu=True): + # Initialize variables for numpy implementation. + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.2], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.2], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: var0 = variables.Variable(var0_np) var1 = variables.Variable(var1_np) - grads0_np_indices = np.array([0], dtype=np.int32) - grads0 = ops.IndexedSlices( - constant_op.constant(grads0_np), - constant_op.constant(grads0_np_indices), constant_op.constant([1])) - grads1_np_indices = np.array([1], dtype=np.int32) - grads1 = ops.IndexedSlices( - constant_op.constant(grads1_np), - constant_op.constant(grads1_np_indices), constant_op.constant([1])) - opt = rmsprop.RMSPropOptimizer( - learning_rate=learning_rate, - decay=decay, - momentum=momentum, - epsilon=epsilon, - centered=centered) - update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - - mg0 = opt.get_slot(var0, "mg") - self.assertEqual(mg0 is not None, centered) - mg1 = opt.get_slot(var1, "mg") - self.assertEqual(mg1 is not None, centered) - rms0 = opt.get_slot(var0, "rms") - self.assertTrue(rms0 is not None) - rms1 = opt.get_slot(var1, "rms") - self.assertTrue(rms1 is not None) - mom0 = opt.get_slot(var0, "momentum") - self.assertTrue(mom0 is not None) - mom1 = opt.get_slot(var1, "momentum") - self.assertTrue(mom1 is not None) - - mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) - rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) - mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) - - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) - - # Run 4 steps of RMSProp - for _ in range(1, 5): - update.run() - - var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy( - var0_np, grads0_np_indices, grads0_np, mg0_np, rms0_np, mom0_np, - learning_rate, decay, momentum, epsilon, centered) - var1_np, mg1_np, rms1_np, mom1_np = self._sparse_rmsprop_update_numpy( - var1_np, grads1_np_indices, grads1_np, mg1_np, rms1_np, mom1_np, - learning_rate, decay, momentum, epsilon, centered) - - # Validate updated params - if centered: - self.assertAllCloseAccordingToType(mg0_np, mg0.eval()) - self.assertAllCloseAccordingToType(mg1_np, mg1.eval()) - self.assertAllCloseAccordingToType(rms0_np, rms0.eval()) - self.assertAllCloseAccordingToType(rms1_np, rms1.eval()) - self.assertAllCloseAccordingToType(mom0_np, mom0.eval()) - self.assertAllCloseAccordingToType(mom1_np, mom1.eval()) - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) - - def testWithoutMomentum(self): - for dtype in [dtypes.half, dtypes.float32]: - with self.test_session(use_gpu=True): - var0 = variables.Variable([1.0, 2.0], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - opt = rmsprop.RMSPropOptimizer( - learning_rate=2.0, decay=0.9, momentum=0.0, epsilon=1.0) - update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - - rms0 = opt.get_slot(var0, "rms") - self.assertTrue(rms0 is not None) - rms1 = opt.get_slot(var1, "rms") - self.assertTrue(rms1 is not None) - mom0 = opt.get_slot(var0, "momentum") - self.assertTrue(mom0 is not None) - mom1 = opt.get_slot(var1, "momentum") - self.assertTrue(mom1 is not None) - - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) - # Step 1: the rms accumulators where 1. So we should see a normal - # update: v -= grad * learning_rate - update.run() - # Check the root mean square accumulators. - self.assertAllCloseAccordingToType( - np.array([0.901, 0.901]), rms0.eval()) - self.assertAllCloseAccordingToType( - np.array([0.90001, 0.90001]), rms1.eval()) - # Check the parameters. - self.assertAllCloseAccordingToType( - np.array([ - 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)), - 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) - ]), var0.eval()) - self.assertAllCloseAccordingToType( - np.array([ - 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)), - 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) - ]), var1.eval()) - # Step 2: the root mean square accumulators contain the previous update. - update.run() - # Check the rms accumulators. - self.assertAllCloseAccordingToType( - np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval()) - self.assertAllCloseAccordingToType( - np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval()) - # Check the parameters. - self.assertAllCloseAccordingToType( - np.array([ - 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) - - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)), - 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) - - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)) - ]), var0.eval()) - self.assertAllCloseAccordingToType( - np.array([ - 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) - - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)), - 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) - - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)) - ]), var1.eval()) - - def testWithMomentum(self): - for dtype in [dtypes.half, dtypes.float32]: - with self.test_session(use_gpu=True): - var0 = variables.Variable([1.0, 2.0], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - - opt = rmsprop.RMSPropOptimizer( - learning_rate=2.0, decay=0.9, momentum=0.5, epsilon=1e-5) - update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - - rms0 = opt.get_slot(var0, "rms") - self.assertTrue(rms0 is not None) - rms1 = opt.get_slot(var1, "rms") - self.assertTrue(rms1 is not None) - mom0 = opt.get_slot(var0, "momentum") - self.assertTrue(mom0 is not None) - mom1 = opt.get_slot(var1, "momentum") - self.assertTrue(mom1 is not None) - - # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) - # Step 1: rms = 1, mom = 0. So we should see a normal - # update: v -= grad * learning_rate + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = rmsprop.RMSPropOptimizer( + learning_rate=learning_rate, + decay=decay, + momentum=momentum, + epsilon=epsilon, + centered=centered) + + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + mg0 = opt.get_slot(var0, "mg") + self.assertEqual(mg0 is not None, centered) + mg1 = opt.get_slot(var1, "mg") + self.assertEqual(mg1 is not None, centered) + rms0 = opt.get_slot(var0, "rms") + self.assertIsNotNone(rms0) + rms1 = opt.get_slot(var1, "rms") + self.assertIsNotNone(rms1) + mom0 = opt.get_slot(var0, "momentum") + self.assertIsNotNone(mom0) + mom1 = opt.get_slot(var1, "momentum") + self.assertIsNotNone(mom1) + + mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) + rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) + mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 4 steps of RMSProp + for _ in range(4): update.run() - # Check the root mean square accumulators. - self.assertAllCloseAccordingToType( - np.array([0.901, 0.901]), rms0.eval()) - self.assertAllCloseAccordingToType( - np.array([0.90001, 0.90001]), rms1.eval()) - # Check the momentum accumulators - self.assertAllCloseAccordingToType( - np.array([(0.1 * 2.0 / math.sqrt(0.901 + 1e-5)), - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5))]), mom0.eval()) - self.assertAllCloseAccordingToType( - np.array([(0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)), - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5))]), mom1.eval()) - - # Check that the parameters. - self.assertAllCloseAccordingToType( - np.array([ - 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)), - 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) - ]), var0.eval()) - self.assertAllCloseAccordingToType( - np.array([ - 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)), - 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) - ]), var1.eval()) - - # Step 2: the root mean square accumulators contain the previous update. + + var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy( + var0_np, grads0_np, mg0_np, rms0_np, mom0_np, learning_rate, + decay, momentum, epsilon, centered) + var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy( + var1_np, grads1_np, mg1_np, rms1_np, mom1_np, learning_rate, + decay, momentum, epsilon, centered) + + # Validate updated params + if centered: + self.assertAllCloseAccordingToType(mg0_np, mg0.eval()) + self.assertAllCloseAccordingToType(mg1_np, mg1.eval()) + self.assertAllCloseAccordingToType(rms0_np, rms0.eval()) + self.assertAllCloseAccordingToType(rms1_np, rms1.eval()) + self.assertAllCloseAccordingToType(mom0_np, mom0.eval()) + self.assertAllCloseAccordingToType(mom1_np, mom1.eval()) + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + @parameterized.parameters([dtypes.float32, dtypes.float64]) + def testMinimizeSparseResourceVariable(self, dtype): + with self.test_session(): + var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) + x = constant_op.constant([[4.0], [5.0]], dtype=dtype) + pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) + loss = pred * pred + sgd_op = rmsprop.RMSPropOptimizer( + learning_rate=1.0, + decay=0.0, + momentum=0.0, + epsilon=0.0, + centered=False).minimize(loss) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + self.assertAllCloseAccordingToType( + [[0., 1.]], var0.eval(), atol=0.01) + + @parameterized.parameters([dtypes.float32, dtypes.float64]) + def testMinimizeSparseResourceVariableCentered(self, dtype): + with self.test_session(): + var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype) + x = constant_op.constant([[4.0], [5.0]], dtype=dtype) + pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x) + loss = pred * pred + sgd_op = rmsprop.RMSPropOptimizer( + learning_rate=1.0, + decay=0.0, + momentum=0.0, + epsilon=1.0, + centered=True).minimize(loss) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + self.assertAllCloseAccordingToType( + [[-111, -138]], var0.eval(), atol=0.01) + + @parameterized.named_parameters( + *test_util.generate_combinations_with_testcase_name( + dtype=_DATA_TYPES, param_value=_TEST_PARAM_VALUES)) + def testSparse(self, dtype, param_value): + (learning_rate, decay, momentum, epsilon, centered, _) = tuple( + param_value) + with self.test_session(use_gpu=True): + # Initialize variables for numpy implementation. + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0_np_indices = np.array([0], dtype=np.int32) + grads0 = ops.IndexedSlices( + constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), constant_op.constant([1])) + grads1_np_indices = np.array([1], dtype=np.int32) + grads1 = ops.IndexedSlices( + constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), constant_op.constant([1])) + opt = rmsprop.RMSPropOptimizer( + learning_rate=learning_rate, + decay=decay, + momentum=momentum, + epsilon=epsilon, + centered=centered) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + mg0 = opt.get_slot(var0, "mg") + self.assertEqual(mg0 is not None, centered) + mg1 = opt.get_slot(var1, "mg") + self.assertEqual(mg1 is not None, centered) + rms0 = opt.get_slot(var0, "rms") + self.assertIsNotNone(rms0) + rms1 = opt.get_slot(var1, "rms") + self.assertIsNotNone(rms1) + mom0 = opt.get_slot(var0, "momentum") + self.assertIsNotNone(mom0) + mom1 = opt.get_slot(var1, "momentum") + self.assertIsNotNone(mom1) + + mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) + rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) + mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 4 steps of RMSProp + for _ in range(4): update.run() - # Check the rms accumulators. - self.assertAllCloseAccordingToType( - np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval()) - self.assertAllCloseAccordingToType( - np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval()) - self.assertAllCloseAccordingToType( - np.array([ - 0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)), - 0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)) - ]), mom0.eval()) - self.assertAllCloseAccordingToType( - np.array([ - 0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)), - 0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)) - ]), mom1.eval()) - - # Check the parameters. - self.assertAllCloseAccordingToType( - np.array([ - 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) - - (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))), - 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) - - (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + - (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))) - ]), var0.eval()) - - self.assertAllCloseAccordingToType( - np.array([ - 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) - - (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))), - 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) - - (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + - (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))) - ]), var1.eval()) + + var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy( + var0_np, grads0_np_indices, grads0_np, mg0_np, rms0_np, mom0_np, + learning_rate, decay, momentum, epsilon, centered) + var1_np, mg1_np, rms1_np, mom1_np = self._sparse_rmsprop_update_numpy( + var1_np, grads1_np_indices, grads1_np, mg1_np, rms1_np, mom1_np, + learning_rate, decay, momentum, epsilon, centered) + + # Validate updated params + if centered: + self.assertAllCloseAccordingToType(mg0_np, mg0.eval()) + self.assertAllCloseAccordingToType(mg1_np, mg1.eval()) + self.assertAllCloseAccordingToType(rms0_np, rms0.eval()) + self.assertAllCloseAccordingToType(rms1_np, rms1.eval()) + self.assertAllCloseAccordingToType(mom0_np, mom0.eval()) + self.assertAllCloseAccordingToType(mom1_np, mom1.eval()) + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + @parameterized.parameters(_DATA_TYPES) + def testWithoutMomentum(self, dtype): + with self.test_session(use_gpu=True): + var0 = variables.Variable([1.0, 2.0], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + opt = rmsprop.RMSPropOptimizer( + learning_rate=2.0, decay=0.9, momentum=0.0, epsilon=1.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + rms0 = opt.get_slot(var0, "rms") + self.assertIsNotNone(rms0) + rms1 = opt.get_slot(var1, "rms") + self.assertIsNotNone(rms1) + mom0 = opt.get_slot(var0, "momentum") + self.assertIsNotNone(mom0) + mom1 = opt.get_slot(var1, "momentum") + self.assertIsNotNone(mom1) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + # Step 1: the rms accumulators where 1. So we should see a normal + # update: v -= grad * learning_rate + update.run() + # Check the root mean square accumulators. + self.assertAllCloseAccordingToType( + np.array([0.901, 0.901]), rms0.eval()) + self.assertAllCloseAccordingToType( + np.array([0.90001, 0.90001]), rms1.eval()) + # Check the parameters. + self.assertAllCloseAccordingToType( + np.array([ + 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)), + 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) + ]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([ + 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)), + 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) + ]), var1.eval()) + # Step 2: the root mean square accumulators contain the previous update. + update.run() + # Check the rms accumulators. + self.assertAllCloseAccordingToType( + np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval()) + self.assertAllCloseAccordingToType( + np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval()) + # Check the parameters. + self.assertAllCloseAccordingToType( + np.array([ + 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) - + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)), + 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) - + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)) + ]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([ + 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) - + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)), + 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) - + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)) + ]), var1.eval()) + + @parameterized.parameters(_DATA_TYPES) + def testWithMomentum(self, dtype): + with self.test_session(use_gpu=True): + var0 = variables.Variable([1.0, 2.0], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + + opt = rmsprop.RMSPropOptimizer( + learning_rate=2.0, decay=0.9, momentum=0.5, epsilon=1e-5) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + rms0 = opt.get_slot(var0, "rms") + self.assertIsNotNone(rms0) + rms1 = opt.get_slot(var1, "rms") + self.assertIsNotNone(rms1) + mom0 = opt.get_slot(var0, "momentum") + self.assertIsNotNone(mom0) + mom1 = opt.get_slot(var1, "momentum") + self.assertIsNotNone(mom1) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + # Step 1: rms = 1, mom = 0. So we should see a normal + # update: v -= grad * learning_rate + update.run() + # Check the root mean square accumulators. + self.assertAllCloseAccordingToType( + np.array([0.901, 0.901]), rms0.eval()) + self.assertAllCloseAccordingToType( + np.array([0.90001, 0.90001]), rms1.eval()) + # Check the momentum accumulators + self.assertAllCloseAccordingToType( + np.array([(0.1 * 2.0 / math.sqrt(0.901 + 1e-5)), + (0.1 * 2.0 / math.sqrt(0.901 + 1e-5))]), mom0.eval()) + self.assertAllCloseAccordingToType( + np.array([(0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)), + (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5))]), mom1.eval()) + + # Check that the parameters. + self.assertAllCloseAccordingToType( + np.array([ + 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)), + 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + ]), var0.eval()) + self.assertAllCloseAccordingToType( + np.array([ + 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)), + 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + ]), var1.eval()) + + # Step 2: the root mean square accumulators contain the previous update. + update.run() + # Check the rms accumulators. + self.assertAllCloseAccordingToType( + np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval()) + self.assertAllCloseAccordingToType( + np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval()) + self.assertAllCloseAccordingToType( + np.array([ + 0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)), + 0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5)) + ]), mom0.eval()) + self.assertAllCloseAccordingToType( + np.array([ + 0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)), + 0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)) + ]), mom1.eval()) + + # Check the parameters. + self.assertAllCloseAccordingToType( + np.array([ + 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) - + (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))), + 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) - + (0.5 * (0.1 * 2.0 / math.sqrt(0.901 + 1e-5)) + + (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1e-5))) + ]), var0.eval()) + + self.assertAllCloseAccordingToType( + np.array([ + 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) - + (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))), + 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) - + (0.5 * (0.01 * 2.0 / math.sqrt(0.90001 + 1e-5)) + + (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5))) + ]), var1.eval()) if __name__ == "__main__": diff --git a/tensorflow/contrib/predictor/contrib_estimator_predictor.py b/tensorflow/contrib/predictor/contrib_estimator_predictor.py index af3b2ad1b531b835f484a155efcc57bbe634f2df..c2166594e598857065a7fd109ec599a3b36e2d2c 100644 --- a/tensorflow/contrib/predictor/contrib_estimator_predictor.py +++ b/tensorflow/contrib/predictor/contrib_estimator_predictor.py @@ -22,8 +22,8 @@ from __future__ import print_function from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils from tensorflow.contrib.predictor import predictor from tensorflow.python.framework import ops +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import monitored_session -from tensorflow.python.training import saver class ContribEstimatorPredictor(predictor.Predictor): @@ -57,7 +57,8 @@ class ContribEstimatorPredictor(predictor.Predictor): # pylint: disable=protected-access model_fn_ops = estimator._get_predict_ops(input_fn_ops.features) # pylint: enable=protected-access - checkpoint_path = saver.latest_checkpoint(estimator.model_dir) + checkpoint_path = checkpoint_management.latest_checkpoint( + estimator.model_dir) self._session = monitored_session.MonitoredSession( session_creator=monitored_session.ChiefSessionCreator( config=config, diff --git a/tensorflow/contrib/predictor/predictor_factories.py b/tensorflow/contrib/predictor/predictor_factories.py index f275bc15adfa0a51a48964dff8edddbd45500e45..7886744b3ce7fc438bc73cb81bccfd0ddeea873e 100644 --- a/tensorflow/contrib/predictor/predictor_factories.py +++ b/tensorflow/contrib/predictor/predictor_factories.py @@ -108,6 +108,8 @@ def from_estimator(estimator, def from_saved_model(export_dir, signature_def_key=None, signature_def=None, + input_names=None, + output_names=None, tags=None, graph=None, config=None): @@ -121,6 +123,12 @@ def from_saved_model(export_dir, signature_def: A `SignatureDef` proto specifying the inputs and outputs for prediction. Only one of `signature_def_key` and `signature_def` should be specified. + input_names: A dictionary mapping strings to `Tensor`s in the `SavedModel` + that represent the input. The keys can be any string of the user's + choosing. + output_names: A dictionary mapping strings to `Tensor`s in the + `SavedModel` that represent the output. The keys can be any string of + the user's choosing. tags: Optional. Tags that will be used to retrieve the correct `SignatureDef`. Defaults to `DEFAULT_TAGS`. graph: Optional. The Tensorflow `graph` in which prediction should be @@ -138,6 +146,8 @@ def from_saved_model(export_dir, export_dir, signature_def_key=signature_def_key, signature_def=signature_def, + input_names=input_names, + output_names=output_names, tags=tags, graph=graph, config=config) diff --git a/tensorflow/contrib/proto/BUILD b/tensorflow/contrib/proto/BUILD index 3e9b1a0b8d8ec7c3c5fe5d1f2cf896dbb6c3de72..b27142cf4a6413eccb8489ea3eb775060ffd787b 100644 --- a/tensorflow/contrib/proto/BUILD +++ b/tensorflow/contrib/proto/BUILD @@ -16,17 +16,3 @@ py_library( "//tensorflow/contrib/proto/python/ops:encode_proto_op_py", ], ) - -py_library( - name = "proto_pip", - data = [ - "//tensorflow/contrib/proto/python/kernel_tests:test_messages", - ] + if_static( - [], - otherwise = ["//tensorflow/contrib/proto/python/kernel_tests:libtestexample.so"], - ), - deps = [ - ":proto", - "//tensorflow/contrib/proto/python/kernel_tests:py_test_deps", - ], -) diff --git a/tensorflow/contrib/proto/python/kernel_tests/BUILD b/tensorflow/contrib/proto/python/kernel_tests/BUILD index a380a131f86abc8dd921a123afdb964bf6c2466c..125c1cee292092e55bc17294a29f175c8cc3999c 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/BUILD +++ b/tensorflow/contrib/proto/python/kernel_tests/BUILD @@ -4,47 +4,41 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -# Much of the work in this BUILD file actually happens in the corresponding -# build_defs.bzl, which creates an individual testcase for each example .pbtxt -# file in this directory. -# -load(":build_defs.bzl", "decode_proto_test_suite") -load(":build_defs.bzl", "encode_proto_test_suite") - -# This expands to a tf_py_test for each test file. -# It defines the test_suite :decode_proto_op_tests. -decode_proto_test_suite( - name = "decode_proto_tests", - examples = glob(["*.pbtxt"]), -) - -# This expands to a tf_py_test for each test file. -# It defines the test_suite :encode_proto_op_tests. -encode_proto_test_suite( - name = "encode_proto_tests", - examples = glob(["*.pbtxt"]), -) - -# Below here are tests that are not tied to an example text proto. -filegroup( - name = "test_messages", - srcs = glob(["*.pbtxt"]), -) - load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") tf_py_test( - name = "decode_proto_fail_test", + name = "decode_proto_op_test", size = "small", - srcs = ["decode_proto_fail_test.py"], + srcs = ["decode_proto_op_test.py"], additional_deps = [ + ":decode_proto_op_test_base", + ":py_test_deps", + "//tensorflow/contrib/proto:proto", + "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", + ], + data = if_static( + [], + otherwise = [":libtestexample.so"], + ), + tags = [ + "no_pip", # TODO(b/78026780) + "no_windows", # TODO(b/78028010) + ], +) + +tf_py_test( + name = "encode_proto_op_test", + size = "small", + srcs = ["encode_proto_op_test.py"], + additional_deps = [ + ":encode_proto_op_test_base", ":py_test_deps", - "//third_party/py/numpy", "//tensorflow/contrib/proto:proto", "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", + "//tensorflow/contrib/proto/python/ops:encode_proto_op_py", ], data = if_static( [], @@ -57,19 +51,41 @@ tf_py_test( ) py_library( - name = "test_case", - srcs = ["test_case.py"], - deps = ["//tensorflow/python:client_testlib"], + name = "proto_op_test_base", + testonly = 1, + srcs = ["proto_op_test_base.py"], + deps = [ + ":test_example_proto_py", + "//tensorflow/python:client_testlib", + ], +) + +py_library( + name = "decode_proto_op_test_base", + testonly = 1, + srcs = ["decode_proto_op_test_base.py"], + deps = [ + ":proto_op_test_base", + ":test_example_proto_py", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], ) py_library( - name = "py_test_deps", + name = "encode_proto_op_test_base", + testonly = 1, + srcs = ["encode_proto_op_test_base.py"], deps = [ - ":test_case", + ":proto_op_test_base", ":test_example_proto_py", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) +py_library(name = "py_test_deps") + tf_proto_library( name = "test_example_proto", srcs = ["test_example.proto"], @@ -84,3 +100,30 @@ tf_cc_shared_object( ":test_example_proto_cc", ], ) + +py_library( + name = "descriptor_source_test_base", + testonly = 1, + srcs = ["descriptor_source_test_base.py"], + deps = [ + ":proto_op_test_base", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + "@protobuf_archive//:protobuf_python", + ], +) + +tf_py_test( + name = "descriptor_source_test", + size = "small", + srcs = ["descriptor_source_test.py"], + additional_deps = [ + ":descriptor_source_test_base", + "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", + "//tensorflow/contrib/proto/python/ops:encode_proto_op_py", + "//tensorflow/python:client_testlib", + ], + tags = [ + "no_pip", + ], +) diff --git a/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl b/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl deleted file mode 100644 index f425601691e21b36914f340d53ccadf9b4e3641f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl +++ /dev/null @@ -1,89 +0,0 @@ -"""BUILD rules for generating file-driven proto test cases. - -The decode_proto_test_suite() and encode_proto_test_suite() rules take a list -of text protos and generates a tf_py_test() for each one. -""" - -load("//tensorflow:tensorflow.bzl", "tf_py_test") -load("//tensorflow:tensorflow.bzl", "register_extension_info") -load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") - -def _test_name(test, path): - return "%s_%s_test" % (test, path.split("/")[-1].split(".")[0]) - -def decode_proto_test_suite(name, examples): - """Build the decode_proto py_test for each test filename.""" - for test_filename in examples: - tf_py_test( - name = _test_name("decode_proto", test_filename), - srcs = ["decode_proto_op_test.py"], - size = "small", - data = [test_filename] + if_static( - [], - otherwise = [":libtestexample.so"], - ), - main = "decode_proto_op_test.py", - args = [ - "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename), - ], - additional_deps = [ - ":py_test_deps", - "//third_party/py/numpy", - "//tensorflow/contrib/proto:proto", - "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", - ], - tags = [ - "no_pip", # TODO(b/78026780) - "no_windows", # TODO(b/78028010) - ], - ) - native.test_suite( - name = name, - tests = [":" + _test_name("decode_proto", test_filename) - for test_filename in examples], - ) - -def encode_proto_test_suite(name, examples): - """Build the encode_proto py_test for each test filename.""" - for test_filename in examples: - tf_py_test( - name = _test_name("encode_proto", test_filename), - srcs = ["encode_proto_op_test.py"], - size = "small", - data = [test_filename] + if_static( - [], - otherwise = [":libtestexample.so"], - ), - main = "encode_proto_op_test.py", - args = [ - "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename), - ], - additional_deps = [ - ":py_test_deps", - "//third_party/py/numpy", - "//tensorflow/contrib/proto:proto", - "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", - "//tensorflow/contrib/proto/python/ops:encode_proto_op_py", - ], - tags = [ - "no_pip", # TODO(b/78026780) - "no_windows", # TODO(b/78028010) - ], - ) - native.test_suite( - name = name, - tests = [":" + _test_name("encode_proto", test_filename) - for test_filename in examples], - ) - -register_extension_info( - extension_name = "decode_proto_test_suite", - label_regex_map = { - "deps": "deps:decode_example_.*", - }) - -register_extension_info( - extension_name = "encode_proto_test_suite", - label_regex_map = { - "deps": "deps:encode_example_.*", - }) diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py deleted file mode 100644 index 5298342ee79b08a50b13ce8715e891a332efb3bc..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py +++ /dev/null @@ -1,68 +0,0 @@ -# ============================================================================= -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================= - -# Python3 preparedness imports. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.proto.python.kernel_tests import test_case -from tensorflow.contrib.proto.python.ops import decode_proto_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.platform import test - - -class DecodeProtoFailTest(test_case.ProtoOpTestCase): - """Test failure cases for DecodeToProto.""" - - def _TestCorruptProtobuf(self, sanitize): - """Test failure cases for DecodeToProto.""" - - # The goal here is to check the error reporting. - # Testing against a variety of corrupt protobufs is - # done by fuzzing. - corrupt_proto = 'This is not a binary protobuf' - - # Numpy silently truncates the strings if you don't specify dtype=object. - batch = np.array(corrupt_proto, dtype=object) - msg_type = 'tensorflow.contrib.proto.TestCase' - field_names = ['sizes'] - field_types = [dtypes.int32] - - with self.test_session() as sess: - ctensor, vtensor = decode_proto_op.decode_proto( - batch, - message_type=msg_type, - field_names=field_names, - output_types=field_types, - sanitize=sanitize) - with self.assertRaisesRegexp(errors.DataLossError, - 'Unable to parse binary protobuf' - '|Failed to consume entire buffer'): - _ = sess.run([ctensor] + vtensor) - - def testCorrupt(self): - self._TestCorruptProtobuf(sanitize=False) - - def testSanitizerCorrupt(self): - self._TestCorruptProtobuf(sanitize=True) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py index d1c13c82bc264bc8bcc721eb68ee3916f32ef7a8..934035ec4c97e04846f493817d4b4ed65db94f14 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py +++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py @@ -13,287 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= -"""Table-driven test for decode_proto op. +"""Tests for decode_proto op.""" -This test is run once with each of the *.TestCase.pbtxt files -in the test directory. -""" # Python3 preparedness imports. from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from google.protobuf import text_format - -from tensorflow.contrib.proto.python.kernel_tests import test_case -from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.contrib.proto.python.kernel_tests import decode_proto_op_test_base as test_base from tensorflow.contrib.proto.python.ops import decode_proto_op -from tensorflow.python.framework import dtypes -from tensorflow.python.platform import flags from tensorflow.python.platform import test -FLAGS = flags.FLAGS - -flags.DEFINE_string('message_text_file', None, - 'A file containing a text serialized TestCase protobuf.') - - -class DecodeProtoOpTest(test_case.ProtoOpTestCase): - - def _compareValues(self, fd, vs, evs): - """Compare lists/arrays of field values.""" - - if len(vs) != len(evs): - self.fail('Field %s decoded %d outputs, expected %d' % - (fd.name, len(vs), len(evs))) - for i, ev in enumerate(evs): - # Special case fuzzy match for float32. TensorFlow seems to mess with - # MAX_FLT slightly and the test doesn't work otherwise. - # TODO(nix): ask on TF list about why MAX_FLT doesn't pass through. - if fd.cpp_type == fd.CPPTYPE_FLOAT: - # Numpy isclose() is better than assertIsClose() which uses an absolute - # value comparison. - self.assertTrue( - np.isclose(vs[i], ev), 'expected %r, actual %r' % (ev, vs[i])) - elif fd.cpp_type == fd.CPPTYPE_STRING: - # In Python3 string tensor values will be represented as bytes, so we - # reencode the proto values to match that. - self.assertEqual(vs[i], ev.encode('ascii')) - else: - # Doubles and other types pass through unscathed. - self.assertEqual(vs[i], ev) - - def _compareRepeatedPrimitiveValue(self, batch_shape, sizes, fields, - field_dict): - """Compare protos of type RepeatedPrimitiveValue. - - Args: - batch_shape: the shape of the input tensor of serialized messages. - sizes: int matrix of repeat counts returned by decode_proto - fields: list of test_example_pb2.FieldSpec (types and expected values) - field_dict: map from field names to decoded numpy tensors of values - """ - - # Check that expected values match. - for field in fields: - values = field_dict[field.name] - self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype) - - fd = field.expected.DESCRIPTOR.fields_by_name[field.name] - - # Values has the same shape as the input plus an extra - # dimension for repeats. - self.assertEqual(list(values.shape)[:-1], batch_shape) - - # Nested messages are represented as TF strings, requiring - # some special handling. - if field.name == 'message_value': - vs = [] - for buf in values.flat: - msg = test_example_pb2.PrimitiveValue() - msg.ParseFromString(buf) - vs.append(msg) - evs = getattr(field.expected, field.name) - if len(vs) != len(evs): - self.fail('Field %s decoded %d outputs, expected %d' % - (fd.name, len(vs), len(evs))) - for v, ev in zip(vs, evs): - self.assertEqual(v, ev) - continue - - # This can be a little confusing. For testing we are using - # RepeatedPrimitiveValue in two ways: it's the proto that we - # decode for testing, and it's used in the expected value as a - # union type. The two cases are slightly different: this is the - # second case. - # We may be fetching the uint64_value from the test proto, but - # in the expected proto we store it in the int64_value field - # because TensorFlow doesn't support unsigned int64. - tf_type_to_primitive_value_field = { - dtypes.float32: - 'float_value', - dtypes.float64: - 'double_value', - dtypes.int32: - 'int32_value', - dtypes.uint8: - 'uint8_value', - dtypes.int8: - 'int8_value', - dtypes.string: - 'string_value', - dtypes.int64: - 'int64_value', - dtypes.bool: - 'bool_value', - # Unhandled TensorFlow types: - # DT_INT16 DT_COMPLEX64 DT_QINT8 DT_QUINT8 DT_QINT32 - # DT_BFLOAT16 DT_QINT16 DT_QUINT16 DT_UINT16 - } - tf_field_name = tf_type_to_primitive_value_field.get(field.dtype) - if tf_field_name is None: - self.fail('Unhandled tensorflow type %d' % field.dtype) - - self._compareValues(fd, values.flat, - getattr(field.expected, tf_field_name)) - - def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch, - message_type, message_format, sanitize, - force_disordered=False): - """Run decode tests on a batch of messages. - - Args: - fields: list of test_example_pb2.FieldSpec (types and expected values) - case_sizes: expected sizes array - batch_shape: the shape of the input tensor of serialized messages - batch: list of serialized messages - message_type: descriptor name for messages - message_format: format of messages, 'text' or 'binary' - sanitize: whether to sanitize binary protobuf inputs - force_disordered: whether to force fields encoded out of order. - """ - - if force_disordered: - # Exercise code path that handles out-of-order fields by prepending extra - # fields with tag numbers higher than any real field. Note that this won't - # work with sanitization because that forces reserialization using a - # trusted decoder and encoder. - assert not sanitize - extra_fields = test_example_pb2.ExtraFields() - extra_fields.string_value = 'IGNORE ME' - extra_fields.bool_value = False - extra_msg = extra_fields.SerializeToString() - batch = [extra_msg + msg for msg in batch] - - # Numpy silently truncates the strings if you don't specify dtype=object. - batch = np.array(batch, dtype=object) - batch = np.reshape(batch, batch_shape) - - field_names = [f.name for f in fields] - output_types = [f.dtype for f in fields] - - with self.test_session() as sess: - sizes, vtensor = decode_proto_op.decode_proto( - batch, - message_type=message_type, - field_names=field_names, - output_types=output_types, - message_format=message_format, - sanitize=sanitize) - - vlist = sess.run([sizes] + vtensor) - sizes = vlist[0] - # Values is a list of tensors, one for each field. - value_tensors = vlist[1:] - - # Check that the repeat sizes are correct. - self.assertTrue( - np.all(np.array(sizes.shape) == batch_shape + [len(field_names)])) - - # Check that the decoded sizes match the expected sizes. - self.assertEqual(len(sizes.flat), len(case_sizes)) - self.assertTrue( - np.all(sizes.flat == np.array( - case_sizes, dtype=np.int32))) - - field_dict = dict(zip(field_names, value_tensors)) - - self._compareRepeatedPrimitiveValue(batch_shape, sizes, fields, - field_dict) - - def testBinary(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - - batch = [primitive.SerializeToString() for primitive in case.primitive] - self._runDecodeProtoTests( - case.field, - case.sizes, - list(case.shape), - batch, - 'tensorflow.contrib.proto.RepeatedPrimitiveValue', - 'binary', - sanitize=False) - - def testBinaryDisordered(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - - batch = [primitive.SerializeToString() for primitive in case.primitive] - self._runDecodeProtoTests( - case.field, - case.sizes, - list(case.shape), - batch, - 'tensorflow.contrib.proto.RepeatedPrimitiveValue', - 'binary', - sanitize=False, - force_disordered=True) - - def testPacked(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - - # Now try with the packed serialization. - # We test the packed representations by loading the same test cases - # using PackedPrimitiveValue instead of RepeatedPrimitiveValue. - # To do this we rely on the text format being the same for packed and - # unpacked fields, and reparse the test message using the packed version - # of the proto. - packed_batch = [ - # Note: float_format='.17g' is necessary to ensure preservation of - # doubles and floats in text format. - text_format.Parse( - text_format.MessageToString( - primitive, float_format='.17g'), - test_example_pb2.PackedPrimitiveValue()).SerializeToString() - for primitive in case.primitive - ] - - self._runDecodeProtoTests( - case.field, - case.sizes, - list(case.shape), - packed_batch, - 'tensorflow.contrib.proto.PackedPrimitiveValue', - 'binary', - sanitize=False) - - def testText(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - - # Note: float_format='.17g' is necessary to ensure preservation of - # doubles and floats in text format. - text_batch = [ - text_format.MessageToString( - primitive, float_format='.17g') for primitive in case.primitive - ] - - self._runDecodeProtoTests( - case.field, - case.sizes, - list(case.shape), - text_batch, - 'tensorflow.contrib.proto.RepeatedPrimitiveValue', - 'text', - sanitize=False) - def testSanitizerGood(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) +class DecodeProtoOpTest(test_base.DecodeProtoOpTestBase): - batch = [primitive.SerializeToString() for primitive in case.primitive] - self._runDecodeProtoTests( - case.field, - case.sizes, - list(case.shape), - batch, - 'tensorflow.contrib.proto.RepeatedPrimitiveValue', - 'binary', - sanitize=True) + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + super(DecodeProtoOpTest, self).__init__(decode_proto_op, methodName) if __name__ == '__main__': diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..e3570e38a3aac738b01b28eb4bfdf57e6abbc595 --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py @@ -0,0 +1,303 @@ +# ============================================================================= +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Tests for decode_proto op.""" + +# Python3 preparedness imports. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + + +from google.protobuf import text_format + +from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base +from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors + + +class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): + """Base class for testing proto decoding ops.""" + + def __init__(self, decode_module, methodName='runTest'): # pylint: disable=invalid-name + """DecodeProtoOpTestBase initializer. + + Args: + decode_module: a module containing the `decode_proto_op` method + methodName: the name of the test method (same as for test.TestCase) + """ + + super(DecodeProtoOpTestBase, self).__init__(methodName) + self._decode_module = decode_module + + def _compareValues(self, fd, vs, evs): + """Compare lists/arrays of field values.""" + + if len(vs) != len(evs): + self.fail('Field %s decoded %d outputs, expected %d' % + (fd.name, len(vs), len(evs))) + for i, ev in enumerate(evs): + # Special case fuzzy match for float32. TensorFlow seems to mess with + # MAX_FLT slightly and the test doesn't work otherwise. + # TODO(nix): ask on TF list about why MAX_FLT doesn't pass through. + if fd.cpp_type == fd.CPPTYPE_FLOAT: + # Numpy isclose() is better than assertIsClose() which uses an absolute + # value comparison. + self.assertTrue( + np.isclose(vs[i], ev), 'expected %r, actual %r' % (ev, vs[i])) + elif fd.cpp_type == fd.CPPTYPE_STRING: + # In Python3 string tensor values will be represented as bytes, so we + # reencode the proto values to match that. + self.assertEqual(vs[i], ev.encode('ascii')) + else: + # Doubles and other types pass through unscathed. + self.assertEqual(vs[i], ev) + + def _compareProtos(self, batch_shape, sizes, fields, field_dict): + """Compare protos of type TestValue. + + Args: + batch_shape: the shape of the input tensor of serialized messages. + sizes: int matrix of repeat counts returned by decode_proto + fields: list of test_example_pb2.FieldSpec (types and expected values) + field_dict: map from field names to decoded numpy tensors of values + """ + + # Check that expected values match. + for field in fields: + values = field_dict[field.name] + self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype) + + fd = field.value.DESCRIPTOR.fields_by_name[field.name] + + # Values has the same shape as the input plus an extra + # dimension for repeats. + self.assertEqual(list(values.shape)[:-1], batch_shape) + + # Nested messages are represented as TF strings, requiring + # some special handling. + if field.name == 'message_value': + vs = [] + for buf in values.flat: + msg = test_example_pb2.PrimitiveValue() + msg.ParseFromString(buf) + vs.append(msg) + evs = getattr(field.value, field.name) + if len(vs) != len(evs): + self.fail('Field %s decoded %d outputs, expected %d' % + (fd.name, len(vs), len(evs))) + for v, ev in zip(vs, evs): + self.assertEqual(v, ev) + continue + + tf_type_to_primitive_value_field = { + dtypes.bool: + 'bool_value', + dtypes.float32: + 'float_value', + dtypes.float64: + 'double_value', + dtypes.int8: + 'int8_value', + dtypes.int32: + 'int32_value', + dtypes.int64: + 'int64_value', + dtypes.string: + 'string_value', + dtypes.uint8: + 'uint8_value', + dtypes.uint32: + 'uint32_value', + dtypes.uint64: + 'uint64_value', + } + tf_field_name = tf_type_to_primitive_value_field.get(field.dtype) + if tf_field_name is None: + self.fail('Unhandled tensorflow type %d' % field.dtype) + + self._compareValues(fd, values.flat, + getattr(field.value, tf_field_name)) + + def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch, + message_type, message_format, sanitize, + force_disordered=False): + """Run decode tests on a batch of messages. + + Args: + fields: list of test_example_pb2.FieldSpec (types and expected values) + case_sizes: expected sizes array + batch_shape: the shape of the input tensor of serialized messages + batch: list of serialized messages + message_type: descriptor name for messages + message_format: format of messages, 'text' or 'binary' + sanitize: whether to sanitize binary protobuf inputs + force_disordered: whether to force fields encoded out of order. + """ + + if force_disordered: + # Exercise code path that handles out-of-order fields by prepending extra + # fields with tag numbers higher than any real field. Note that this won't + # work with sanitization because that forces reserialization using a + # trusted decoder and encoder. + assert not sanitize + extra_fields = test_example_pb2.ExtraFields() + extra_fields.string_value = 'IGNORE ME' + extra_fields.bool_value = False + extra_msg = extra_fields.SerializeToString() + batch = [extra_msg + msg for msg in batch] + + # Numpy silently truncates the strings if you don't specify dtype=object. + batch = np.array(batch, dtype=object) + batch = np.reshape(batch, batch_shape) + + field_names = [f.name for f in fields] + output_types = [f.dtype for f in fields] + + with self.test_session() as sess: + sizes, vtensor = self._decode_module.decode_proto( + batch, + message_type=message_type, + field_names=field_names, + output_types=output_types, + message_format=message_format, + sanitize=sanitize) + + vlist = sess.run([sizes] + vtensor) + sizes = vlist[0] + # Values is a list of tensors, one for each field. + value_tensors = vlist[1:] + + # Check that the repeat sizes are correct. + self.assertTrue( + np.all(np.array(sizes.shape) == batch_shape + [len(field_names)])) + + # Check that the decoded sizes match the expected sizes. + self.assertEqual(len(sizes.flat), len(case_sizes)) + self.assertTrue( + np.all(sizes.flat == np.array( + case_sizes, dtype=np.int32))) + + field_dict = dict(zip(field_names, value_tensors)) + + self._compareProtos(batch_shape, sizes, fields, field_dict) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testBinary(self, case): + batch = [value.SerializeToString() for value in case.values] + self._runDecodeProtoTests( + case.fields, + case.sizes, + list(case.shapes), + batch, + 'tensorflow.contrib.proto.TestValue', + 'binary', + sanitize=False) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testBinaryDisordered(self, case): + batch = [value.SerializeToString() for value in case.values] + self._runDecodeProtoTests( + case.fields, + case.sizes, + list(case.shapes), + batch, + 'tensorflow.contrib.proto.TestValue', + 'binary', + sanitize=False, + force_disordered=True) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testPacked(self, case): + # Now try with the packed serialization. + # + # We test the packed representations by loading the same test case using + # PackedTestValue instead of TestValue. To do this we rely on the text + # format being the same for packed and unpacked fields, and reparse the + # test message using the packed version of the proto. + packed_batch = [ + # Note: float_format='.17g' is necessary to ensure preservation of + # doubles and floats in text format. + text_format.Parse( + text_format.MessageToString( + value, float_format='.17g'), + test_example_pb2.PackedTestValue()).SerializeToString() + for value in case.values + ] + + self._runDecodeProtoTests( + case.fields, + case.sizes, + list(case.shapes), + packed_batch, + 'tensorflow.contrib.proto.PackedTestValue', + 'binary', + sanitize=False) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testText(self, case): + # Note: float_format='.17g' is necessary to ensure preservation of + # doubles and floats in text format. + text_batch = [ + text_format.MessageToString( + value, float_format='.17g') for value in case.values + ] + + self._runDecodeProtoTests( + case.fields, + case.sizes, + list(case.shapes), + text_batch, + 'tensorflow.contrib.proto.TestValue', + 'text', + sanitize=False) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testSanitizerGood(self, case): + batch = [value.SerializeToString() for value in case.values] + self._runDecodeProtoTests( + case.fields, + case.sizes, + list(case.shapes), + batch, + 'tensorflow.contrib.proto.TestValue', + 'binary', + sanitize=True) + + @parameterized.parameters((False), (True)) + def testCorruptProtobuf(self, sanitize): + corrupt_proto = 'This is not a binary protobuf' + + # Numpy silently truncates the strings if you don't specify dtype=object. + batch = np.array(corrupt_proto, dtype=object) + msg_type = 'tensorflow.contrib.proto.TestCase' + field_names = ['sizes'] + field_types = [dtypes.int32] + + with self.test_session() as sess: + ctensor, vtensor = self._decode_module.decode_proto( + batch, + message_type=msg_type, + field_names=field_names, + output_types=field_types, + sanitize=sanitize) + with self.assertRaisesRegexp(errors.DataLossError, + 'Unable to parse binary protobuf' + '|Failed to consume entire buffer'): + _ = sess.run([ctensor] + vtensor) diff --git a/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt deleted file mode 100644 index 4e316819077c7dbb28beefd4dc260568f26da680..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt +++ /dev/null @@ -1,94 +0,0 @@ -primitive { - # No fields specified, so we get all defaults -} -shape: 1 -sizes: 0 -field { - name: "double_default" - dtype: DT_DOUBLE - expected { double_value: 1.0 } -} -sizes: 0 -field { - name: "float_default" - dtype: DT_DOUBLE # Try casting the float field to double. - expected { double_value: 2.0 } -} -sizes: 0 -field { - name: "int64_default" - dtype: DT_INT64 - expected { int64_value: 3 } -} -sizes: 0 -field { - name: "uint64_default" - dtype: DT_INT64 - expected { int64_value: 4 } -} -sizes: 0 -field { - name: "int32_default" - dtype: DT_INT32 - expected { int32_value: 5 } -} -sizes: 0 -field { - name: "fixed64_default" - dtype: DT_INT64 - expected { int64_value: 6 } -} -sizes: 0 -field { - name: "fixed32_default" - dtype: DT_INT32 - expected { int32_value: 7 } -} -sizes: 0 -field { - name: "bool_default" - dtype: DT_BOOL - expected { bool_value: true } -} -sizes: 0 -field { - name: "string_default" - dtype: DT_STRING - expected { string_value: "a" } -} -sizes: 0 -field { - name: "bytes_default" - dtype: DT_STRING - expected { string_value: "a longer default string" } -} -sizes: 0 -field { - name: "uint32_default" - dtype: DT_INT32 - expected { int32_value: -1 } -} -sizes: 0 -field { - name: "sfixed32_default" - dtype: DT_INT32 - expected { int32_value: 10 } -} -sizes: 0 -field { - name: "sfixed64_default" - dtype: DT_INT64 - expected { int64_value: 11 } -} -sizes: 0 -field { - name: "sint32_default" - dtype: DT_INT32 - expected { int32_value: 12 } -} -sizes: 0 -field { - name: "sint64_default" - dtype: DT_INT64 - expected { int64_value: 13 } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_case.py b/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test.py similarity index 65% rename from tensorflow/contrib/proto/python/kernel_tests/test_case.py rename to tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test.py index b95202c5df654cfc02339477b242b2c58575a4d5..32ca318f733ce11221539838dfdbcf710dca51a1 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/test_case.py +++ b/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test.py @@ -13,23 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= -"""Test case base for testing proto operations.""" - +"""Tests for proto ops reading descriptors from other sources.""" # Python3 preparedness imports. from __future__ import absolute_import from __future__ import division from __future__ import print_function -import ctypes as ct -import os - +from tensorflow.contrib.proto.python.kernel_tests import descriptor_source_test_base as test_base +from tensorflow.contrib.proto.python.ops import decode_proto_op +from tensorflow.contrib.proto.python.ops import encode_proto_op from tensorflow.python.platform import test -class ProtoOpTestCase(test.TestCase): +class DescriptorSourceTest(test_base.DescriptorSourceTestBase): def __init__(self, methodName='runTest'): # pylint: disable=invalid-name - super(ProtoOpTestCase, self).__init__(methodName) - lib = os.path.join(os.path.dirname(__file__), 'libtestexample.so') - if os.path.isfile(lib): - ct.cdll.LoadLibrary(lib) + super(DescriptorSourceTest, self).__init__(decode_proto_op, encode_proto_op, + methodName) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..9a1c04af324620fc893583ebb17cd99ea3ba166d --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/descriptor_source_test_base.py @@ -0,0 +1,176 @@ +# ============================================================================= +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Tests for proto ops reading descriptors from other sources.""" +# Python3 preparedness imports. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy as np + +from google.protobuf.descriptor_pb2 import FieldDescriptorProto +from google.protobuf.descriptor_pb2 import FileDescriptorSet +from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base +from tensorflow.python.framework import dtypes +from tensorflow.python.platform import test + + +class DescriptorSourceTestBase(test.TestCase): + """Base class for testing descriptor sources.""" + + def __init__(self, decode_module, encode_module, methodName='runTest'): # pylint: disable=invalid-name + """DescriptorSourceTestBase initializer. + + Args: + decode_module: a module containing the `decode_proto_op` method + encode_module: a module containing the `encode_proto_op` method + methodName: the name of the test method (same as for test.TestCase) + """ + + super(DescriptorSourceTestBase, self).__init__(methodName) + self._decode_module = decode_module + self._encode_module = encode_module + + # NOTE: We generate the descriptor programmatically instead of via a compiler + # because of differences between different versions of the compiler. + # + # The generated descriptor should capture the subset of `test_example.proto` + # used in `test_base.simple_test_case()`. + def _createDescriptorFile(self): + set_proto = FileDescriptorSet() + + file_proto = set_proto.file.add( + name='types.proto', + package='tensorflow', + syntax='proto3') + enum_proto = file_proto.enum_type.add(name='DataType') + enum_proto.value.add(name='DT_DOUBLE', number=0) + enum_proto.value.add(name='DT_BOOL', number=1) + + file_proto = set_proto.file.add( + name='test_example.proto', + package='tensorflow.contrib.proto', + dependency=['types.proto']) + message_proto = file_proto.message_type.add(name='TestCase') + message_proto.field.add( + name='values', + number=1, + type=FieldDescriptorProto.TYPE_MESSAGE, + type_name='.tensorflow.contrib.proto.TestValue', + label=FieldDescriptorProto.LABEL_REPEATED) + message_proto.field.add( + name='shapes', + number=2, + type=FieldDescriptorProto.TYPE_INT32, + label=FieldDescriptorProto.LABEL_REPEATED) + message_proto.field.add( + name='sizes', + number=3, + type=FieldDescriptorProto.TYPE_INT32, + label=FieldDescriptorProto.LABEL_REPEATED) + message_proto.field.add( + name='fields', + number=4, + type=FieldDescriptorProto.TYPE_MESSAGE, + type_name='.tensorflow.contrib.proto.FieldSpec', + label=FieldDescriptorProto.LABEL_REPEATED) + + message_proto = file_proto.message_type.add( + name='TestValue') + message_proto.field.add( + name='double_value', + number=1, + type=FieldDescriptorProto.TYPE_DOUBLE, + label=FieldDescriptorProto.LABEL_REPEATED) + message_proto.field.add( + name='bool_value', + number=2, + type=FieldDescriptorProto.TYPE_BOOL, + label=FieldDescriptorProto.LABEL_REPEATED) + + message_proto = file_proto.message_type.add( + name='FieldSpec') + message_proto.field.add( + name='name', + number=1, + type=FieldDescriptorProto.TYPE_STRING, + label=FieldDescriptorProto.LABEL_OPTIONAL) + message_proto.field.add( + name='dtype', + number=2, + type=FieldDescriptorProto.TYPE_ENUM, + type_name='.tensorflow.DataType', + label=FieldDescriptorProto.LABEL_OPTIONAL) + message_proto.field.add( + name='value', + number=3, + type=FieldDescriptorProto.TYPE_MESSAGE, + type_name='.tensorflow.contrib.proto.TestValue', + label=FieldDescriptorProto.LABEL_OPTIONAL) + + fn = os.path.join(self.get_temp_dir(), 'descriptor.pb') + with open(fn, 'wb') as f: + f.write(set_proto.SerializeToString()) + return fn + + def _testRoundtrip(self, descriptor_source): + # Numpy silently truncates the strings if you don't specify dtype=object. + in_bufs = np.array( + [test_base.ProtoOpTestBase.simple_test_case().SerializeToString()], + dtype=object) + message_type = 'tensorflow.contrib.proto.TestCase' + field_names = ['values', 'shapes', 'sizes', 'fields'] + tensor_types = [dtypes.string, dtypes.int32, dtypes.int32, dtypes.string] + + with self.test_session() as sess: + sizes, field_tensors = self._decode_module.decode_proto( + in_bufs, + message_type=message_type, + field_names=field_names, + output_types=tensor_types, + descriptor_source=descriptor_source) + + out_tensors = self._encode_module.encode_proto( + sizes, + field_tensors, + message_type=message_type, + field_names=field_names, + descriptor_source=descriptor_source) + + out_bufs, = sess.run([out_tensors]) + + # Check that the re-encoded tensor has the same shape. + self.assertEqual(in_bufs.shape, out_bufs.shape) + + # Compare the input and output. + for in_buf, out_buf in zip(in_bufs.flat, out_bufs.flat): + # Check that the input and output serialized messages are identical. + # If we fail here, there is a difference in the serialized + # representation but the new serialization still parses. This could + # be harmless (a change in map ordering?) or it could be bad (e.g. + # loss of packing in the encoding). + self.assertEqual(in_buf, out_buf) + + def testWithFileDescriptorSet(self): + # First try parsing with a local proto db, which should fail. + with self.assertRaisesOpError('No descriptor found for message type'): + self._testRoundtrip('local://') + + # Now try parsing with a FileDescriptorSet which contains the test proto. + descriptor_file = self._createDescriptorFile() + self._testRoundtrip(descriptor_file) diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py index 30e58e6336dc66830418c7cd2b3111a851d691b6..fc5cd25d43be1df2480630396c39f7a83e0eb57a 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py +++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py @@ -13,167 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================= -"""Table-driven test for encode_proto op. +"""Tests for encode_proto op.""" -This test is run once with each of the *.TestCase.pbtxt files -in the test directory. - -It tests that encode_proto is a lossless inverse of decode_proto -(for the specified fields). -""" # Python3 readiness boilerplate from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from google.protobuf import text_format - -from tensorflow.contrib.proto.python.kernel_tests import test_case -from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.contrib.proto.python.kernel_tests import encode_proto_op_test_base as test_base from tensorflow.contrib.proto.python.ops import decode_proto_op from tensorflow.contrib.proto.python.ops import encode_proto_op -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import flags from tensorflow.python.platform import test -FLAGS = flags.FLAGS - -flags.DEFINE_string('message_text_file', None, - 'A file containing a text serialized TestCase protobuf.') - - -class EncodeProtoOpTest(test_case.ProtoOpTestCase): - - def testBadInputs(self): - # Invalid field name - with self.test_session(): - with self.assertRaisesOpError('Unknown field: non_existent_field'): - encode_proto_op.encode_proto( - sizes=[[1]], - values=[np.array([[0.0]], dtype=np.int32)], - message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue', - field_names=['non_existent_field']).eval() - - # Incorrect types. - with self.test_session(): - with self.assertRaisesOpError( - 'Incompatible type for field double_value.'): - encode_proto_op.encode_proto( - sizes=[[1]], - values=[np.array([[0.0]], dtype=np.int32)], - message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue', - field_names=['double_value']).eval() - - # Incorrect shapes of sizes. - with self.test_session(): - with self.assertRaisesOpError( - r'sizes should be batch_size \+ \[len\(field_names\)\]'): - sizes = array_ops.placeholder(dtypes.int32) - values = array_ops.placeholder(dtypes.float64) - encode_proto_op.encode_proto( - sizes=sizes, - values=[values], - message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue', - field_names=['double_value']).eval(feed_dict={ - sizes: [[[0, 0]]], - values: [[0.0]] - }) - - # Inconsistent shapes of values. - with self.test_session(): - with self.assertRaisesOpError( - 'Values must match up to the last dimension'): - sizes = array_ops.placeholder(dtypes.int32) - values1 = array_ops.placeholder(dtypes.float64) - values2 = array_ops.placeholder(dtypes.int32) - (encode_proto_op.encode_proto( - sizes=[[1, 1]], - values=[values1, values2], - message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue', - field_names=['double_value', 'int32_value']).eval(feed_dict={ - values1: [[0.0]], - values2: [[0], [0]] - })) - - def _testRoundtrip(self, in_bufs, message_type, fields): - - field_names = [f.name for f in fields] - out_types = [f.dtype for f in fields] - - with self.test_session() as sess: - sizes, field_tensors = decode_proto_op.decode_proto( - in_bufs, - message_type=message_type, - field_names=field_names, - output_types=out_types) - - out_tensors = encode_proto_op.encode_proto( - sizes, - field_tensors, - message_type=message_type, - field_names=field_names) - - out_bufs, = sess.run([out_tensors]) - - # Check that the re-encoded tensor has the same shape. - self.assertEqual(in_bufs.shape, out_bufs.shape) - - # Compare the input and output. - for in_buf, out_buf in zip(in_bufs.flat, out_bufs.flat): - in_obj = test_example_pb2.RepeatedPrimitiveValue() - in_obj.ParseFromString(in_buf) - - out_obj = test_example_pb2.RepeatedPrimitiveValue() - out_obj.ParseFromString(out_buf) - - # Check that the deserialized objects are identical. - self.assertEqual(in_obj, out_obj) - - # Check that the input and output serialized messages are identical. - # If we fail here, there is a difference in the serialized - # representation but the new serialization still parses. This could - # be harmless (a change in map ordering?) or it could be bad (e.g. - # loss of packing in the encoding). - self.assertEqual(in_buf, out_buf) - - def testRoundtrip(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - - in_bufs = [primitive.SerializeToString() for primitive in case.primitive] - - # np.array silently truncates strings if you don't specify dtype=object. - in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shape)) - return self._testRoundtrip( - in_bufs, 'tensorflow.contrib.proto.RepeatedPrimitiveValue', case.field) - - def testRoundtripPacked(self): - with open(FLAGS.message_text_file, 'r') as fp: - case = text_format.Parse(fp.read(), test_example_pb2.TestCase()) - # Now try with the packed serialization. - # We test the packed representations by loading the same test cases - # using PackedPrimitiveValue instead of RepeatedPrimitiveValue. - # To do this we rely on the text format being the same for packed and - # unpacked fields, and reparse the test message using the packed version - # of the proto. - in_bufs = [ - # Note: float_format='.17g' is necessary to ensure preservation of - # doubles and floats in text format. - text_format.Parse( - text_format.MessageToString( - primitive, float_format='.17g'), - test_example_pb2.PackedPrimitiveValue()).SerializeToString() - for primitive in case.primitive - ] +class EncodeProtoOpTest(test_base.EncodeProtoOpTestBase): - # np.array silently truncates strings if you don't specify dtype=object. - in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shape)) - return self._testRoundtrip( - in_bufs, 'tensorflow.contrib.proto.PackedPrimitiveValue', case.field) + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + super(EncodeProtoOpTest, self).__init__(decode_proto_op, encode_proto_op, + methodName) if __name__ == '__main__': diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..07dfb924d3ede5bdb9b848c5eb0d3382ec053121 --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py @@ -0,0 +1,177 @@ +# ============================================================================= +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Table-driven test for encode_proto op. + +This test is run once with each of the *.TestCase.pbtxt files +in the test directory. + +It tests that encode_proto is a lossless inverse of decode_proto +(for the specified fields). +""" +# Python3 readiness boilerplate +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from google.protobuf import text_format + +from tensorflow.contrib.proto.python.kernel_tests import proto_op_test_base as test_base +from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops + + +class EncodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): + """Base class for testing proto encoding ops.""" + + def __init__(self, decode_module, encode_module, methodName='runTest'): # pylint: disable=invalid-name + """EncodeProtoOpTestBase initializer. + + Args: + decode_module: a module containing the `decode_proto_op` method + encode_module: a module containing the `encode_proto_op` method + methodName: the name of the test method (same as for test.TestCase) + """ + + super(EncodeProtoOpTestBase, self).__init__(methodName) + self._decode_module = decode_module + self._encode_module = encode_module + + def testBadInputs(self): + # Invalid field name + with self.test_session(): + with self.assertRaisesOpError('Unknown field: non_existent_field'): + self._encode_module.encode_proto( + sizes=[[1]], + values=[np.array([[0.0]], dtype=np.int32)], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['non_existent_field']).eval() + + # Incorrect types. + with self.test_session(): + with self.assertRaisesOpError( + 'Incompatible type for field double_value.'): + self._encode_module.encode_proto( + sizes=[[1]], + values=[np.array([[0.0]], dtype=np.int32)], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['double_value']).eval() + + # Incorrect shapes of sizes. + with self.test_session(): + with self.assertRaisesOpError( + r'sizes should be batch_size \+ \[len\(field_names\)\]'): + sizes = array_ops.placeholder(dtypes.int32) + values = array_ops.placeholder(dtypes.float64) + self._encode_module.encode_proto( + sizes=sizes, + values=[values], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['double_value']).eval(feed_dict={ + sizes: [[[0, 0]]], + values: [[0.0]] + }) + + # Inconsistent shapes of values. + with self.test_session(): + with self.assertRaisesOpError( + 'Values must match up to the last dimension'): + sizes = array_ops.placeholder(dtypes.int32) + values1 = array_ops.placeholder(dtypes.float64) + values2 = array_ops.placeholder(dtypes.int32) + (self._encode_module.encode_proto( + sizes=[[1, 1]], + values=[values1, values2], + message_type='tensorflow.contrib.proto.TestValue', + field_names=['double_value', 'int32_value']).eval(feed_dict={ + values1: [[0.0]], + values2: [[0], [0]] + })) + + def _testRoundtrip(self, in_bufs, message_type, fields): + + field_names = [f.name for f in fields] + out_types = [f.dtype for f in fields] + + with self.test_session() as sess: + sizes, field_tensors = self._decode_module.decode_proto( + in_bufs, + message_type=message_type, + field_names=field_names, + output_types=out_types) + + out_tensors = self._encode_module.encode_proto( + sizes, + field_tensors, + message_type=message_type, + field_names=field_names) + + out_bufs, = sess.run([out_tensors]) + + # Check that the re-encoded tensor has the same shape. + self.assertEqual(in_bufs.shape, out_bufs.shape) + + # Compare the input and output. + for in_buf, out_buf in zip(in_bufs.flat, out_bufs.flat): + in_obj = test_example_pb2.TestValue() + in_obj.ParseFromString(in_buf) + + out_obj = test_example_pb2.TestValue() + out_obj.ParseFromString(out_buf) + + # Check that the deserialized objects are identical. + self.assertEqual(in_obj, out_obj) + + # Check that the input and output serialized messages are identical. + # If we fail here, there is a difference in the serialized + # representation but the new serialization still parses. This could + # be harmless (a change in map ordering?) or it could be bad (e.g. + # loss of packing in the encoding). + self.assertEqual(in_buf, out_buf) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testRoundtrip(self, case): + in_bufs = [value.SerializeToString() for value in case.values] + + # np.array silently truncates strings if you don't specify dtype=object. + in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shapes)) + return self._testRoundtrip( + in_bufs, 'tensorflow.contrib.proto.TestValue', case.fields) + + @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters()) + def testRoundtripPacked(self, case): + # Now try with the packed serialization. + # We test the packed representations by loading the same test cases using + # PackedTestValue instead of TestValue. To do this we rely on the text + # format being the same for packed and unpacked fields, and reparse the test + # message using the packed version of the proto. + in_bufs = [ + # Note: float_format='.17g' is necessary to ensure preservation of + # doubles and floats in text format. + text_format.Parse( + text_format.MessageToString( + value, float_format='.17g'), + test_example_pb2.PackedTestValue()).SerializeToString() + for value in case.values + ] + + # np.array silently truncates strings if you don't specify dtype=object. + in_bufs = np.reshape(np.array(in_bufs, dtype=object), list(case.shapes)) + return self._testRoundtrip( + in_bufs, 'tensorflow.contrib.proto.PackedTestValue', case.fields) diff --git a/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt deleted file mode 100644 index b170f89c0f00dd9dffd5785197bb3bfd1ca2cfee..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt +++ /dev/null @@ -1,161 +0,0 @@ -primitive { - double_value: -1.7976931348623158e+308 - double_value: 2.2250738585072014e-308 - double_value: 1.7976931348623158e+308 - float_value: -3.402823466e+38 - float_value: 1.175494351e-38 - float_value: 3.402823466e+38 - int64_value: -9223372036854775808 - int64_value: 9223372036854775807 - uint64_value: 0 - uint64_value: 18446744073709551615 - int32_value: -2147483648 - int32_value: 2147483647 - fixed64_value: 0 - fixed64_value: 18446744073709551615 - fixed32_value: 0 - fixed32_value: 4294967295 - bool_value: false - bool_value: true - string_value: "" - string_value: "I refer to the infinite." - uint32_value: 0 - uint32_value: 4294967295 - sfixed32_value: -2147483648 - sfixed32_value: 2147483647 - sfixed64_value: -9223372036854775808 - sfixed64_value: 9223372036854775807 - sint32_value: -2147483648 - sint32_value: 2147483647 - sint64_value: -9223372036854775808 - sint64_value: 9223372036854775807 -} -shape: 1 -sizes: 3 -sizes: 3 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -sizes: 2 -field { - name: "double_value" - dtype: DT_DOUBLE - expected { - double_value: -1.7976931348623158e+308 - double_value: 2.2250738585072014e-308 - double_value: 1.7976931348623158e+308 - } -} -field { - name: "float_value" - dtype: DT_FLOAT - expected { - float_value: -3.402823466e+38 - float_value: 1.175494351e-38 - float_value: 3.402823466e+38 - } -} -field { - name: "int64_value" - dtype: DT_INT64 - expected { - int64_value: -9223372036854775808 - int64_value: 9223372036854775807 - } -} -field { - name: "uint64_value" - dtype: DT_INT64 - expected { - int64_value: 0 - int64_value: -1 - } -} -field { - name: "int32_value" - dtype: DT_INT32 - expected { - int32_value: -2147483648 - int32_value: 2147483647 - } -} -field { - name: "fixed64_value" - dtype: DT_INT64 - expected { - int64_value: 0 - int64_value: -1 # unsigned is 18446744073709551615 - } -} -field { - name: "fixed32_value" - dtype: DT_INT32 - expected { - int32_value: 0 - int32_value: -1 # unsigned is 4294967295 - } -} -field { - name: "bool_value" - dtype: DT_BOOL - expected { - bool_value: false - bool_value: true - } -} -field { - name: "string_value" - dtype: DT_STRING - expected { - string_value: "" - string_value: "I refer to the infinite." - } -} -field { - name: "uint32_value" - dtype: DT_INT32 - expected { - int32_value: 0 - int32_value: -1 # unsigned is 4294967295 - } -} -field { - name: "sfixed32_value" - dtype: DT_INT32 - expected { - int32_value: -2147483648 - int32_value: 2147483647 - } -} -field { - name: "sfixed64_value" - dtype: DT_INT64 - expected { - int64_value: -9223372036854775808 - int64_value: 9223372036854775807 - } -} -field { - name: "sint32_value" - dtype: DT_INT32 - expected { - int32_value: -2147483648 - int32_value: 2147483647 - } -} -field { - name: "sint64_value" - dtype: DT_INT64 - expected { - int64_value: -9223372036854775808 - int64_value: 9223372036854775807 - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt deleted file mode 100644 index c664e52851b5bb3c439544537ce6402fc7cf3362..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt +++ /dev/null @@ -1,16 +0,0 @@ -primitive { - message_value { - double_value: 23.5 - } -} -shape: 1 -sizes: 1 -field { - name: "message_value" - dtype: DT_STRING - expected { - message_value { - double_value: 23.5 - } - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt deleted file mode 100644 index 125651d7eaa1901e4804712bb807322b02ed5bc6..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt +++ /dev/null @@ -1,20 +0,0 @@ -primitive { - bool_value: true -} -shape: 1 -sizes: 1 -sizes: 0 -field { - name: "bool_value" - dtype: DT_BOOL - expected { - bool_value: true - } -} -field { - name: "double_value" - dtype: DT_DOUBLE - expected { - double_value: 0.0 - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt deleted file mode 100644 index bc07efc8f3038c6c540855c97b2254575e517ef3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt +++ /dev/null @@ -1,29 +0,0 @@ -primitive { - fixed32_value: 4294967295 - uint32_value: 4294967295 -} -shape: 1 -sizes: 1 -field { - name: "fixed32_value" - dtype: DT_INT64 - expected { - int64_value: 4294967295 - } -} -sizes: 1 -field { - name: "uint32_value" - dtype: DT_INT64 - expected { - int64_value: 4294967295 - } -} -sizes: 0 -field { - name: "uint32_default" - dtype: DT_INT64 - expected { - int64_value: 4294967295 # Comes from an explicitly-specified default - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..2950c7dfdc59a11ba7d2c07d8406bd4af26b5bd9 --- /dev/null +++ b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py @@ -0,0 +1,419 @@ +# ============================================================================= +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Test case base for testing proto operations.""" + +# Python3 preparedness imports. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ctypes as ct +import os + +from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2 +from tensorflow.core.framework import types_pb2 +from tensorflow.python.platform import test + + +class ProtoOpTestBase(test.TestCase): + """Base class for testing proto decoding and encoding ops.""" + + def __init__(self, methodName="runTest"): # pylint: disable=invalid-name + super(ProtoOpTestBase, self).__init__(methodName) + lib = os.path.join(os.path.dirname(__file__), "libtestexample.so") + if os.path.isfile(lib): + ct.cdll.LoadLibrary(lib) + + @staticmethod + def named_parameters(): + return ( + ("defaults", ProtoOpTestBase.defaults_test_case()), + ("minmax", ProtoOpTestBase.minmax_test_case()), + ("nested", ProtoOpTestBase.nested_test_case()), + ("optional", ProtoOpTestBase.optional_test_case()), + ("promote", ProtoOpTestBase.promote_test_case()), + ("ragged", ProtoOpTestBase.ragged_test_case()), + ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()), + ("simple", ProtoOpTestBase.simple_test_case()), + ) + + @staticmethod + def defaults_test_case(): + test_case = test_example_pb2.TestCase() + test_case.values.add() # No fields specified, so we get all defaults. + test_case.shapes.append(1) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "double_value_with_default" + field.dtype = types_pb2.DT_DOUBLE + field.value.double_value.append(1.0) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "float_value_with_default" + field.dtype = types_pb2.DT_FLOAT + field.value.float_value.append(2.0) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "int64_value_with_default" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(3) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "sfixed64_value_with_default" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(11) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "sint64_value_with_default" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(13) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "uint64_value_with_default" + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(4) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "fixed64_value_with_default" + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(6) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "int32_value_with_default" + field.dtype = types_pb2.DT_INT32 + field.value.int32_value.append(5) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "sfixed32_value_with_default" + field.dtype = types_pb2.DT_INT32 + field.value.int32_value.append(10) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "sint32_value_with_default" + field.dtype = types_pb2.DT_INT32 + field.value.int32_value.append(12) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "uint32_value_with_default" + field.dtype = types_pb2.DT_UINT32 + field.value.uint32_value.append(9) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "fixed32_value_with_default" + field.dtype = types_pb2.DT_UINT32 + field.value.uint32_value.append(7) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "bool_value_with_default" + field.dtype = types_pb2.DT_BOOL + field.value.bool_value.append(True) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "string_value_with_default" + field.dtype = types_pb2.DT_STRING + field.value.string_value.append("a") + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "bytes_value_with_default" + field.dtype = types_pb2.DT_STRING + field.value.string_value.append("a longer default string") + return test_case + + @staticmethod + def minmax_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + value.double_value.append(-1.7976931348623158e+308) + value.double_value.append(2.2250738585072014e-308) + value.double_value.append(1.7976931348623158e+308) + value.float_value.append(-3.402823466e+38) + value.float_value.append(1.175494351e-38) + value.float_value.append(3.402823466e+38) + value.int64_value.append(-9223372036854775808) + value.int64_value.append(9223372036854775807) + value.sfixed64_value.append(-9223372036854775808) + value.sfixed64_value.append(9223372036854775807) + value.sint64_value.append(-9223372036854775808) + value.sint64_value.append(9223372036854775807) + value.uint64_value.append(0) + value.uint64_value.append(18446744073709551615) + value.fixed64_value.append(0) + value.fixed64_value.append(18446744073709551615) + value.int32_value.append(-2147483648) + value.int32_value.append(2147483647) + value.sfixed32_value.append(-2147483648) + value.sfixed32_value.append(2147483647) + value.sint32_value.append(-2147483648) + value.sint32_value.append(2147483647) + value.uint32_value.append(0) + value.uint32_value.append(4294967295) + value.fixed32_value.append(0) + value.fixed32_value.append(4294967295) + value.bool_value.append(False) + value.bool_value.append(True) + value.string_value.append("") + value.string_value.append("I refer to the infinite.") + test_case.shapes.append(1) + test_case.sizes.append(3) + field = test_case.fields.add() + field.name = "double_value" + field.dtype = types_pb2.DT_DOUBLE + field.value.double_value.append(-1.7976931348623158e+308) + field.value.double_value.append(2.2250738585072014e-308) + field.value.double_value.append(1.7976931348623158e+308) + test_case.sizes.append(3) + field = test_case.fields.add() + field.name = "float_value" + field.dtype = types_pb2.DT_FLOAT + field.value.float_value.append(-3.402823466e+38) + field.value.float_value.append(1.175494351e-38) + field.value.float_value.append(3.402823466e+38) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "int64_value" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(-9223372036854775808) + field.value.int64_value.append(9223372036854775807) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "sfixed64_value" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(-9223372036854775808) + field.value.int64_value.append(9223372036854775807) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "sint64_value" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(-9223372036854775808) + field.value.int64_value.append(9223372036854775807) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "uint64_value" + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(0) + field.value.uint64_value.append(18446744073709551615) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "fixed64_value" + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(0) + field.value.uint64_value.append(18446744073709551615) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "int32_value" + field.dtype = types_pb2.DT_INT32 + field.value.int32_value.append(-2147483648) + field.value.int32_value.append(2147483647) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "sfixed32_value" + field.dtype = types_pb2.DT_INT32 + field.value.int32_value.append(-2147483648) + field.value.int32_value.append(2147483647) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "sint32_value" + field.dtype = types_pb2.DT_INT32 + field.value.int32_value.append(-2147483648) + field.value.int32_value.append(2147483647) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "uint32_value" + field.dtype = types_pb2.DT_UINT32 + field.value.uint32_value.append(0) + field.value.uint32_value.append(4294967295) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "fixed32_value" + field.dtype = types_pb2.DT_UINT32 + field.value.uint32_value.append(0) + field.value.uint32_value.append(4294967295) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "bool_value" + field.dtype = types_pb2.DT_BOOL + field.value.bool_value.append(False) + field.value.bool_value.append(True) + test_case.sizes.append(2) + field = test_case.fields.add() + field.name = "string_value" + field.dtype = types_pb2.DT_STRING + field.value.string_value.append("") + field.value.string_value.append("I refer to the infinite.") + return test_case + + @staticmethod + def nested_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + message_value = value.message_value.add() + message_value.double_value = 23.5 + test_case.shapes.append(1) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "message_value" + field.dtype = types_pb2.DT_STRING + message_value = field.value.message_value.add() + message_value.double_value = 23.5 + return test_case + + @staticmethod + def optional_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + value.bool_value.append(True) + test_case.shapes.append(1) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "bool_value" + field.dtype = types_pb2.DT_BOOL + field.value.bool_value.append(True) + test_case.sizes.append(0) + field = test_case.fields.add() + field.name = "double_value" + field.dtype = types_pb2.DT_DOUBLE + field.value.double_value.append(0.0) + return test_case + + @staticmethod + def promote_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + value.sint32_value.append(2147483647) + value.sfixed32_value.append(2147483647) + value.int32_value.append(2147483647) + value.fixed32_value.append(4294967295) + value.uint32_value.append(4294967295) + test_case.shapes.append(1) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "sint32_value" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(2147483647) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "sfixed32_value" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(2147483647) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "int32_value" + field.dtype = types_pb2.DT_INT64 + field.value.int64_value.append(2147483647) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "fixed32_value" + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(4294967295) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "uint32_value" + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(4294967295) + return test_case + + @staticmethod + def ragged_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + value.double_value.append(23.5) + value.double_value.append(123.0) + value.bool_value.append(True) + value = test_case.values.add() + value.double_value.append(3.1) + value.bool_value.append(False) + test_case.shapes.append(2) + test_case.sizes.append(2) + test_case.sizes.append(1) + test_case.sizes.append(1) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "double_value" + field.dtype = types_pb2.DT_DOUBLE + field.value.double_value.append(23.5) + field.value.double_value.append(123.0) + field.value.double_value.append(3.1) + field.value.double_value.append(0.0) + field = test_case.fields.add() + field.name = "bool_value" + field.dtype = types_pb2.DT_BOOL + field.value.bool_value.append(True) + field.value.bool_value.append(False) + return test_case + + @staticmethod + def shaped_batch_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + value.double_value.append(23.5) + value.bool_value.append(True) + value = test_case.values.add() + value.double_value.append(44.0) + value.bool_value.append(False) + value = test_case.values.add() + value.double_value.append(3.14159) + value.bool_value.append(True) + value = test_case.values.add() + value.double_value.append(1.414) + value.bool_value.append(True) + value = test_case.values.add() + value.double_value.append(-32.2) + value.bool_value.append(False) + value = test_case.values.add() + value.double_value.append(0.0001) + value.bool_value.append(True) + test_case.shapes.append(3) + test_case.shapes.append(2) + for _ in range(12): + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "double_value" + field.dtype = types_pb2.DT_DOUBLE + field.value.double_value.append(23.5) + field.value.double_value.append(44.0) + field.value.double_value.append(3.14159) + field.value.double_value.append(1.414) + field.value.double_value.append(-32.2) + field.value.double_value.append(0.0001) + field = test_case.fields.add() + field.name = "bool_value" + field.dtype = types_pb2.DT_BOOL + field.value.bool_value.append(True) + field.value.bool_value.append(False) + field.value.bool_value.append(True) + field.value.bool_value.append(True) + field.value.bool_value.append(False) + field.value.bool_value.append(True) + return test_case + + @staticmethod + def simple_test_case(): + test_case = test_example_pb2.TestCase() + value = test_case.values.add() + value.double_value.append(23.5) + value.bool_value.append(True) + test_case.shapes.append(1) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "double_value" + field.dtype = types_pb2.DT_DOUBLE + field.value.double_value.append(23.5) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "bool_value" + field.dtype = types_pb2.DT_BOOL + field.value.bool_value.append(True) + return test_case diff --git a/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt deleted file mode 100644 index 61c7ac53f72b0764a0d57241cbdcdd93fcbd9279..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt +++ /dev/null @@ -1,32 +0,0 @@ -primitive { - double_value: 23.5 - double_value: 123.0 - bool_value: true -} -primitive { - double_value: 3.1 - bool_value: false -} -shape: 2 -sizes: 2 -sizes: 1 -sizes: 1 -sizes: 1 -field { - name: "double_value" - dtype: DT_DOUBLE - expected { - double_value: 23.5 - double_value: 123.0 - double_value: 3.1 - double_value: 0.0 - } -} -field { - name: "bool_value" - dtype: DT_BOOL - expected { - bool_value: true - bool_value: false - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt deleted file mode 100644 index f4828076d52dc5d03a887c4a445dbcf52414c361..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt +++ /dev/null @@ -1,62 +0,0 @@ -primitive { - double_value: 23.5 - bool_value: true -} -primitive { - double_value: 44.0 - bool_value: false -} -primitive { - double_value: 3.14159 - bool_value: true -} -primitive { - double_value: 1.414 - bool_value: true -} -primitive { - double_value: -32.2 - bool_value: false -} -primitive { - double_value: 0.0001 - bool_value: true -} -shape: 3 -shape: 2 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -sizes: 1 -field { - name: "double_value" - dtype: DT_DOUBLE - expected { - double_value: 23.5 - double_value: 44.0 - double_value: 3.14159 - double_value: 1.414 - double_value: -32.2 - double_value: 0.0001 - } -} -field { - name: "bool_value" - dtype: DT_BOOL - expected { - bool_value: true - bool_value: false - bool_value: true - bool_value: true - bool_value: false - bool_value: true - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt deleted file mode 100644 index dc20ac147b0e772f05b4fc614f9f56513aceb1d5..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt +++ /dev/null @@ -1,21 +0,0 @@ -primitive { - double_value: 23.5 - bool_value: true -} -shape: 1 -sizes: 1 -sizes: 1 -field { - name: "double_value" - dtype: DT_DOUBLE - expected { - double_value: 23.5 - } -} -field { - name: "bool_value" - dtype: DT_BOOL - expected { - bool_value: true - } -} diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto index a2c88e372bf7c6b7f14c5bb55776b66c4c06bcd4..674d881220a1113631def47c5111e3ef401b99f3 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto +++ b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto @@ -1,6 +1,4 @@ // Test description and protos to work with it. -// -// Many of the protos in this file are for unit tests that haven't been written yet. syntax = "proto2"; @@ -8,54 +6,27 @@ import "tensorflow/core/framework/types.proto"; package tensorflow.contrib.proto; -// A TestCase holds a proto and a bunch of assertions -// about how it should decode. +// A TestCase holds a proto and assertions about how it should decode. message TestCase { - // A batch of primitives to be serialized and decoded. - repeated RepeatedPrimitiveValue primitive = 1; - // The shape of the batch. - repeated int32 shape = 2; + // Batches of primitive values. + repeated TestValue values = 1; + // The batch shapes. + repeated int32 shapes = 2; // Expected sizes for each field. repeated int32 sizes = 3; // Expected values for each field. - repeated FieldSpec field = 4; + repeated FieldSpec fields = 4; }; // FieldSpec describes the expected output for a single field. message FieldSpec { optional string name = 1; optional tensorflow.DataType dtype = 2; - optional RepeatedPrimitiveValue expected = 3; + optional TestValue value = 3; }; +// NOTE: This definition must be kept in sync with PackedTestValue. message TestValue { - optional PrimitiveValue primitive_value = 1; - optional EnumValue enum_value = 2; - optional MessageValue message_value = 3; - optional RepeatedMessageValue repeated_message_value = 4; - optional RepeatedPrimitiveValue repeated_primitive_value = 6; -} - -message PrimitiveValue { - optional double double_value = 1; - optional float float_value = 2; - optional int64 int64_value = 3; - optional uint64 uint64_value = 4; - optional int32 int32_value = 5; - optional fixed64 fixed64_value = 6; - optional fixed32 fixed32_value = 7; - optional bool bool_value = 8; - optional string string_value = 9; - optional bytes bytes_value = 12; - optional uint32 uint32_value = 13; - optional sfixed32 sfixed32_value = 15; - optional sfixed64 sfixed64_value = 16; - optional sint32 sint32_value = 17; - optional sint64 sint64_value = 18; -} - -// NOTE: This definition must be kept in sync with PackedPrimitiveValue. -message RepeatedPrimitiveValue { repeated double double_value = 1; repeated float float_value = 2; repeated int64 int64_value = 3; @@ -74,30 +45,31 @@ message RepeatedPrimitiveValue { repeated PrimitiveValue message_value = 19; // Optional fields with explicitly-specified defaults. - optional double double_default = 20 [default = 1.0]; - optional float float_default = 21 [default = 2.0]; - optional int64 int64_default = 22 [default = 3]; - optional uint64 uint64_default = 23 [default = 4]; - optional int32 int32_default = 24 [default = 5]; - optional fixed64 fixed64_default = 25 [default = 6]; - optional fixed32 fixed32_default = 26 [default = 7]; - optional bool bool_default = 27 [default = true]; - optional string string_default = 28 [default = "a"]; - optional bytes bytes_default = 29 [default = "a longer default string"]; - optional uint32 uint32_default = 30 [default = 4294967295]; - optional sfixed32 sfixed32_default = 31 [default = 10]; - optional sfixed64 sfixed64_default = 32 [default = 11]; - optional sint32 sint32_default = 33 [default = 12]; - optional sint64 sint64_default = 34 [default = 13]; + optional double double_value_with_default = 20 [default = 1.0]; + optional float float_value_with_default = 21 [default = 2.0]; + optional int64 int64_value_with_default = 22 [default = 3]; + optional uint64 uint64_value_with_default = 23 [default = 4]; + optional int32 int32_value_with_default = 24 [default = 5]; + optional fixed64 fixed64_value_with_default = 25 [default = 6]; + optional fixed32 fixed32_value_with_default = 26 [default = 7]; + optional bool bool_value_with_default = 27 [default = true]; + optional string string_value_with_default = 28 [default = "a"]; + optional bytes bytes_value_with_default = 29 + [default = "a longer default string"]; + optional uint32 uint32_value_with_default = 30 [default = 9]; + optional sfixed32 sfixed32_value_with_default = 31 [default = 10]; + optional sfixed64 sfixed64_value_with_default = 32 [default = 11]; + optional sint32 sint32_value_with_default = 33 [default = 12]; + optional sint64 sint64_value_with_default = 34 [default = 13]; } -// A PackedPrimitiveValue looks exactly the same as a RepeatedPrimitiveValue -// in the text format, but the binary serializion is different. -// We test the packed representations by loading the same test cases -// using this definition instead of RepeatedPrimitiveValue. -// NOTE: This definition must be kept in sync with RepeatedPrimitiveValue -// in every way except the packed=true declaration. -message PackedPrimitiveValue { +// A PackedTestValue looks exactly the same as a TestValue in the text format, +// but the binary serializion is different. We test the packed representations +// by loading the same test cases using this definition instead of TestValue. +// +// NOTE: This definition must be kept in sync with TestValue in every way except +// the packed=true declaration. +message PackedTestValue { repeated double double_value = 1 [packed = true]; repeated float float_value = 2 [packed = true]; repeated int64 int64_value = 3 [packed = true]; @@ -115,23 +87,53 @@ message PackedPrimitiveValue { repeated sint64 sint64_value = 18 [packed = true]; repeated PrimitiveValue message_value = 19; - optional double double_default = 20 [default = 1.0]; - optional float float_default = 21 [default = 2.0]; - optional int64 int64_default = 22 [default = 3]; - optional uint64 uint64_default = 23 [default = 4]; - optional int32 int32_default = 24 [default = 5]; - optional fixed64 fixed64_default = 25 [default = 6]; - optional fixed32 fixed32_default = 26 [default = 7]; - optional bool bool_default = 27 [default = true]; - optional string string_default = 28 [default = "a"]; - optional bytes bytes_default = 29 [default = "a longer default string"]; - optional uint32 uint32_default = 30 [default = 4294967295]; - optional sfixed32 sfixed32_default = 31 [default = 10]; - optional sfixed64 sfixed64_default = 32 [default = 11]; - optional sint32 sint32_default = 33 [default = 12]; - optional sint64 sint64_default = 34 [default = 13]; + optional double double_value_with_default = 20 [default = 1.0]; + optional float float_value_with_default = 21 [default = 2.0]; + optional int64 int64_value_with_default = 22 [default = 3]; + optional uint64 uint64_value_with_default = 23 [default = 4]; + optional int32 int32_value_with_default = 24 [default = 5]; + optional fixed64 fixed64_value_with_default = 25 [default = 6]; + optional fixed32 fixed32_value_with_default = 26 [default = 7]; + optional bool bool_value_with_default = 27 [default = true]; + optional string string_value_with_default = 28 [default = "a"]; + optional bytes bytes_value_with_default = 29 + [default = "a longer default string"]; + optional uint32 uint32_value_with_default = 30 [default = 9]; + optional sfixed32 sfixed32_value_with_default = 31 [default = 10]; + optional sfixed64 sfixed64_value_with_default = 32 [default = 11]; + optional sint32 sint32_value_with_default = 33 [default = 12]; + optional sint64 sint64_value_with_default = 34 [default = 13]; } +message PrimitiveValue { + optional double double_value = 1; + optional float float_value = 2; + optional int64 int64_value = 3; + optional uint64 uint64_value = 4; + optional int32 int32_value = 5; + optional fixed64 fixed64_value = 6; + optional fixed32 fixed32_value = 7; + optional bool bool_value = 8; + optional string string_value = 9; + optional bytes bytes_value = 12; + optional uint32 uint32_value = 13; + optional sfixed32 sfixed32_value = 15; + optional sfixed64 sfixed64_value = 16; + optional sint32 sint32_value = 17; + optional sint64 sint64_value = 18; +} + +// Message containing fields with field numbers higher than any field above. +// An instance of this message is prepended to each binary message in the test +// to exercise the code path that handles fields encoded out of order of field +// number. +message ExtraFields { + optional string string_value = 1776; + optional bool bool_value = 1777; +} + +// The messages below are for yet-to-be created tests. + message EnumValue { enum Color { RED = 0; @@ -171,12 +173,3 @@ message RepeatedMessageValue { repeated NestedMessageValue message_values = 11; } - -// Message containing fields with field numbers higher than any field above. An -// instance of this message is prepended to each binary message in the test to -// exercise the code path that handles fields encoded out of order of field -// number. -message ExtraFields { - optional string string_value = 1776; - optional bool bool_value = 1777; -} diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index e3c48998305e9d9b6c185fd4c0f324fa0449c691..d9f179bee48de587976872dabb470cfd5c69114c 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -120,6 +120,7 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay): scaled_weight_tensor = math_ops.multiply( weights, multiplier_tensor, name='mul_fold') + new_layer_tensor = _CloneWithNewOperands( match.layer_op, match.input_tensor, scaled_weight_tensor, match.batch_to_space_op) @@ -368,20 +369,20 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, lambda: bn_decay_zero, lambda: match.bn_decay_mean_tensor, name='freeze_moving_mean') + graph_editor.reroute_ts( [bn_decay_mean_out], [match.bn_decay_mean_tensor], can_modify=bn_decay_mean_consumers) - if fused_batch_norm is False: - bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers()) - bn_decay_var_out = utils.smart_cond( - use_mv_avg, - lambda: bn_decay_zero, - lambda: match.bn_decay_var_tensor, - name='freeze_moving_var') - graph_editor.reroute_ts( - [bn_decay_var_out], [match.bn_decay_var_tensor], - can_modify=bn_decay_var_consumers) + bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers()) + bn_decay_var_out = utils.smart_cond( + use_mv_avg, + lambda: bn_decay_zero, + lambda: match.bn_decay_var_tensor, + name='freeze_moving_var') + graph_editor.reroute_ts( + [bn_decay_var_out], [match.bn_decay_var_tensor], + can_modify=bn_decay_var_consumers) correction_recip = utils.smart_cond( use_mv_avg, diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py index 7c907ffd92c1ae0c762e41cc429b0e6ce053f6b9..3f8063cc022726cb745d42aba3c834c71e876e70 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py @@ -128,6 +128,9 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) + if freeze_batch_norm_delay is not None: + self._AssertMovingAveragesAreFrozen(g, scope) + for op in g.get_operations(): self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) @@ -216,6 +219,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): ]) output_op_names = [scope + '/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) + if freeze_batch_norm_delay is not None: + self._AssertMovingAveragesAreFrozen(g, scope) for op in g.get_operations(): self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) @@ -284,6 +289,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) + if freeze_batch_norm_delay is not None: + self._AssertMovingAveragesAreFrozen(g, scope) for op in g.get_operations(): self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) @@ -351,6 +358,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) + if freeze_batch_norm_delay is not None: + self._AssertMovingAveragesAreFrozen(g, scope) for op in g.get_operations(): self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) @@ -431,6 +440,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) + if freeze_batch_norm_delay is not None: + self._AssertMovingAveragesAreFrozen(g, scope) for op in g.get_operations(): self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) @@ -515,6 +526,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): ]) output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] self._AssertOutputGoesToOps(folded_add, g, output_op_names) + if freeze_batch_norm_delay is not None: + self._AssertMovingAveragesAreFrozen(g, scope) for op in g.get_operations(): self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name) @@ -644,6 +657,22 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase): out_op = graph.get_operation_by_name(out_op_name) self.assertIn(op.outputs[0].name, [str(t.name) for t in out_op.inputs]) + def _AssertMovingAveragesAreFrozen(self, graph, scope): + """Asserts to check if moving mean and variance are frozen. + + Args: + graph: Graph where the operations are located. + scope: Scope of batch norm op + """ + moving_average_mult = graph.get_operation_by_name( + scope + '/BatchNorm/AssignMovingAvg/mul') + self.assertTrue( + moving_average_mult.inputs[1].name.find('freeze_moving_mean/Merge') > 0) + moving_var_mult = graph.get_operation_by_name( + scope + '/BatchNorm/AssignMovingAvg_1/mul') + self.assertTrue( + moving_var_mult.inputs[1].name.find('freeze_moving_var/Merge') > 0) + def _CopyGraph(self, graph): """Return a copy of graph.""" meta_graph = saver_lib.export_meta_graph( diff --git a/tensorflow/contrib/quantize/python/quant_ops_test.py b/tensorflow/contrib/quantize/python/quant_ops_test.py index c2a8def48012c808da18587c8ff462fa33a363c0..a45840009b758881c14fb64b2d39af6cd4ec4bc4 100644 --- a/tensorflow/contrib/quantize/python/quant_ops_test.py +++ b/tensorflow/contrib/quantize/python/quant_ops_test.py @@ -75,7 +75,7 @@ class QuantOpsTest(googletest.TestCase): self.assertGreater(max_value, 0.0) self.assertLess(max_value, 1.0) - def testVariablesNotParitioned_LastValue(self): + def testVariablesNotPartitioned_LastValue(self): # Variables added should not use a default partiioner since they are # scalar. There would be a tensorflow error thrown if the partitioner was # respected by the rewrite. @@ -90,7 +90,7 @@ class QuantOpsTest(googletest.TestCase): is_training=True, vars_collection=_MIN_MAX_VARS) - def testVariablesNotParitioned_MovingAvg(self): + def testVariablesNotPartitioned_MovingAvg(self): # Variables added should not use a default partiioner since they are # scalar. There would be a tensorflow error thrown if the partitioner was # respected by the rewrite. diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 19e5bef1ea48ca4441cdef6b1a74e98e9cf6ddb9..cb66fd1f76bcdb0a8f77fc7c476511576368ab4e 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -198,7 +198,7 @@ def _FindLayersToQuantize(graph): | [post_conv_correction] | - biasadd|folded_bias + [biasadd|folded_bias] | [bypass] | @@ -261,6 +261,16 @@ def _FindLayersToQuantize(graph): layer_output_pattern = graph_matcher.OneofPattern( [batch_to_space_pattern, layer_pattern]) + + # For separable convolutions, we are looking for a conv, followed by a conv + # with no activations between the two. + sep_conv_pattern = graph_matcher.OpTypePattern( + '|'.join(_QUANTIZABLE_TYPES), + inputs=[ + graph_matcher.OneofPattern([layer_output_pattern]), + graph_matcher.OpTypePattern('*') + ], + ordered_inputs=False) folded_bias_mul_pattern = graph_matcher.OpTypePattern( 'Mul', inputs=[graph_matcher.OpTypePattern('*'), layer_output_pattern], @@ -278,6 +288,13 @@ def _FindLayersToQuantize(graph): ], ordered_inputs=False) + # batch_norms with forced updates have an Identity operation at the end. + # TODO(suharshs): Find a way to easily skip extra Identity operations. The + # current issue is that doing so can often match patterns across many layers + # incorrectly. + batch_norm_identity = graph_matcher.OpTypePattern( + 'Identity', inputs=[folded_bias_add_pattern]) + bias_add_pattern = graph_matcher.OpTypePattern( 'Add|BiasAdd', inputs=[layer_output_pattern, '*'], ordered_inputs=False) @@ -286,21 +303,24 @@ def _FindLayersToQuantize(graph): 'Add', inputs=[ graph_matcher.OneofPattern( - [bias_add_pattern, folded_bias_add_pattern]), '*' + [bias_add_pattern, folded_bias_add_pattern, batch_norm_identity]), + '*' ], ordered_inputs=False) # The input to the activation can come from bias add, fold bias add, the # bypasses. # TODO(suharshs): We should ideally skip Identity operations instead of - # treating them as an activation. + # treating them as activations. activation_pattern = graph_matcher.OpTypePattern( '|'.join(_ACTIVATION_TYPES) + '|Identity', inputs=[ graph_matcher.OneofPattern([ bias_add_pattern, folded_bias_add_pattern, + batch_norm_identity, bypass_pattern, + layer_pattern, ]) ]) @@ -384,6 +404,17 @@ def _FindLayersToQuantize(graph): layer_matches.append( _LayerMatch(layer_op, weight_tensor, activation_op, None, None, None)) + # Look for separable convolutions here + sep_conv_matcher = graph_matcher.GraphMatcher(sep_conv_pattern) + for match_result in sep_conv_matcher.match_graph(graph): + layer_op = match_result.get_op(layer_pattern) + weight_tensor = match_result.get_tensor(weight_identity_pattern) + activation_op = match_result.get_op(layer_pattern) + if layer_op not in matched_layer_set: + matched_layer_set.add(layer_op) + layer_matches.append( + _LayerMatch(layer_op, weight_tensor, activation_op, None, None, None)) + return layer_matches diff --git a/tensorflow/contrib/quantize/python/quantize_graph.py b/tensorflow/contrib/quantize/python/quantize_graph.py index 11d052d7f491dc029d1bda9b47364d6e9c880a67..2944f964c7078814111c96890f18abe1607b68fc 100644 --- a/tensorflow/contrib/quantize/python/quantize_graph.py +++ b/tensorflow/contrib/quantize/python/quantize_graph.py @@ -191,6 +191,7 @@ def experimental_create_training_graph(input_graph=None, def experimental_create_eval_graph(input_graph=None, weight_bits=8, activation_bits=8, + quant_delay=None, scope=None): """Rewrites an eval input_graph in place for simulated quantization. @@ -209,6 +210,8 @@ def experimental_create_eval_graph(input_graph=None, default graph. weight_bits: Number of bits to use for quantizing weights. activation_bits: Number of bits to use for quantizing activations. + quant_delay: Number of steps after which weights and activations are + quantized during eval. scope: The scope to be transformed. If it's not None, only the ops which are in this scope will be transformed. @@ -221,4 +224,5 @@ def experimental_create_eval_graph(input_graph=None, is_training=False, weight_bits=weight_bits, activation_bits=activation_bits, + quant_delay=quant_delay, scope=scope) diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py index 5e3af0a567536ef6fcfd86d82e94c0ba21077a85..31a2955ddb3b32f2b07c6125c8f83ffba335cc5f 100644 --- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py +++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py @@ -654,8 +654,80 @@ class QuantizeTest(test_util.TensorFlowTestCase): graph_def_after = str(graph.as_graph_def()) self.assertEqual(graph_def_before, graph_def_after) - def _BatchNormParams(self, fused=False): - return {'center': True, 'scale': True, 'decay': 1.0 - 0.003, 'fused': fused} + def testBatchNormForcedUpdates(self): + parameter_list = [ + # (activation, activation_op_name, fused_batch_norm) + (nn_ops.relu6, 'Relu6', False), + (nn_ops.relu, 'Relu', False), + (array_ops.identity, 'Identity', False), + (nn_ops.relu6, 'Relu6', True), + (nn_ops.relu, 'Relu', True), + (array_ops.identity, 'Identity', True), + ] + for params in parameter_list: + self._TestBatchNormForcedUpdates(params[0], params[1], params[2], False) + self._TestBatchNormForcedUpdates(params[0], params[1], params[2], True) + + def _TestBatchNormForcedUpdates(self, activation, activation_op_name, + fused_batch_norm, use_resource): + """post_activation bypass quantization should happen with forced updates.""" + graph = ops.Graph() + with graph.as_default(): + variable_scope.get_variable_scope().set_use_resource(use_resource) + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + input2 = array_ops.zeros((batch_size, height / 2, width / 2, 32)) + # Setting updates_collections to None forces updates adding an extra + # identity operation following batch norms. + bn_params = self._BatchNormParams( + fused=fused_batch_norm, force_updates=True) + conv = conv2d( + input1, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation, + normalizer_fn=batch_norm, + normalizer_params=bn_params, + scope='test/test') + bypass_tensor = math_ops.add(conv, input2, name='test/add') + # The output of the post_activation bypass will be another layer. + _ = conv2d( + bypass_tensor, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + normalizer_fn=batch_norm, + normalizer_params=bn_params, + activation_fn=activation, + scope='test/unused') + + fold_batch_norms.FoldBatchNorms(graph, is_training=True) + quantize.Quantize(graph, is_training=True) + + # Ensure that the bypass node is preceded by and followed by a + # FakeQuantWithMinMaxVar operation, since the output of the Add isn't an + # activation. + self.assertTrue('FakeQuantWithMinMaxVars' in + [c.type for c in bypass_tensor.consumers()]) + self.assertTrue('FakeQuantWithMinMaxVars' in + [i.op.type for i in bypass_tensor.op.inputs]) + + with open('/tmp/bn_quant_test.pbtxt', 'w') as f: + f.write(str(graph.as_graph_def())) + + def _BatchNormParams(self, fused=False, force_updates=False): + params = { + 'center': True, + 'scale': True, + 'decay': 1.0 - 0.003, + 'fused': fused + } + if force_updates: + params['updates_collections'] = None + return params def _WeightInit(self, stddev): """Returns truncated normal variable initializer. diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py index 92ca4a1b0c3126ebccf2b525f01f4d6455c4d527..06ebcdfee1617af0c13cd6ed09a2ec5190c5a718 100644 --- a/tensorflow/contrib/quantize/python/quantize_test.py +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -122,12 +122,67 @@ class QuantizeTest(test_util.TensorFlowTestCase): array_ops.identity(node, name='control_dependency') quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + # Check if output of bias add is quantized + quantization_node_name = 'FakeQuantWithMinMaxVars' + conv_quant = graph.get_operation_by_name('test/test/conv_quant/' + + quantization_node_name) + self.assertEqual(conv_quant.type, quantization_node_name) + + for op in graph.get_operations(): + if op.type == quantization_node_name: + quant_op = graph.get_operation_by_name(op.name) + # Scan through all FakeQuant operations, ensuring that the activation + # identity op isn't in the consumers of the operation. + consumers = [] + for output in quant_op.outputs: + consumers.extend(output.consumers()) + + self.assertNotIn('test/relu6', [c.name for c in consumers]) + + def testInsertQuantOpInSeparableConv2d(self): + self._RunTestOverParameters(self._TestInsertQuantOpInSeparableConv2d) + + def _TestInsertQuantOpInSeparableConv2d(self, is_training): + graph = ops.Graph() + with graph.as_default(): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + input2 = array_ops.zeros((batch_size, height / 2, width / 2, depth)) + conv = separable_conv2d( + input1, + 3, [5, 5], + stride=2, + depth_multiplier=1.0, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + scope='test/test') + node = math_ops.add(conv, input2, name='test/add') + node = nn_ops.relu6(node, name='test/relu6') + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + # Check if output of bias add is quantized quantization_node_name = 'FakeQuantWithMinMaxVars' conv_quant = graph.get_operation_by_name('test/test/conv_quant/' + quantization_node_name) self.assertEqual(conv_quant.type, quantization_node_name) + # Check if weights for both convs inside seperable conv are quantized + pointwise_weight_quant = graph.get_operation_by_name( + 'test/test/weights_quant/' + quantization_node_name) + self.assertEqual(pointwise_weight_quant.type, quantization_node_name) + depthwise_weight_quant = graph.get_operation_by_name( + 'test/test/separable_conv2d/weights_quant/' + quantization_node_name) + self.assertEqual(depthwise_weight_quant.type, quantization_node_name) + + # Check if activations after first depthwise conv are quantized. + depthwise_act_quant = graph.get_operation_by_name( + 'test/test/separable_conv2d/act_quant/' + quantization_node_name) + self.assertEqual(depthwise_act_quant.type, quantization_node_name) + for op in graph.get_operations(): if op.type == quantization_node_name: quant_op = graph.get_operation_by_name(op.name) @@ -139,6 +194,33 @@ class QuantizeTest(test_util.TensorFlowTestCase): self.assertNotIn('test/relu6', [c.name for c in consumers]) + def testLayerActivationQuantized(self): + self._RunTestOverParameters(self._TestLayerActivationQuantized) + + def _TestLayerActivationQuantized(self, is_training): + graph = ops.Graph() + with graph.as_default(): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + _ = conv2d( + input1, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=nn_ops.relu6, + biases_initializer=None, + scope='test') + # Ensure that both weights and output of activations are quantized + # when we have a conv->relu6 with no bias add + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + activation_op = graph.get_operation_by_name('test/Relu6') + conv_op = graph.get_operation_by_name('test/Conv2D') + self.assertTrue('test/weights_quant/FakeQuantWithMinMaxVars:0' in + [tensor_in.name for tensor_in in conv_op.inputs]) + self.assertTrue('FakeQuantWithMinMaxVars' in + [op.type for op in activation_op.outputs[0].consumers()]) + def testFinalLayerQuantized(self): self._RunTestOverParameters(self._TestFinalLayerQuantized) diff --git a/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py b/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py index 0f19ac7dbe0cee2eb6c780ec5ea6266bc847abd7..f23194a6f2e64e0619049bac51891d6d6099831f 100644 --- a/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py +++ b/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py @@ -61,10 +61,17 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase): func, args = self._CELLDEFS[celldef_name] return func(*args) - def _CreateInputs(self): - inputs = np.random.random([FunctionalRnnTest._BATCH_SIZE, - FunctionalRnnTest._TOTAL_TIME, - FunctionalRnnTest._INPUT_SIZE]) + def _CreateInputs(self, time_major=False): + if time_major: + inputs = np.random.random([ + FunctionalRnnTest._TOTAL_TIME, FunctionalRnnTest._BATCH_SIZE, + FunctionalRnnTest._INPUT_SIZE + ]) + else: + inputs = np.random.random([ + FunctionalRnnTest._BATCH_SIZE, FunctionalRnnTest._TOTAL_TIME, + FunctionalRnnTest._INPUT_SIZE + ]) # Always leave one time slot empty, to check max_length behavior. sequence_length = np.random.randint( 0, high=FunctionalRnnTest._TOTAL_TIME - 1, @@ -72,15 +79,51 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase): dtype=np.int) return (inputs, sequence_length) - def _CreateRnnGraph(self, create_rnn_computation_func, cell, tf_inputs, - tf_sequence_length, initial_state=None, - time_major=None, scope=None): - tf_result = create_rnn_computation_func(cell=cell, inputs=tf_inputs, - sequence_length=tf_sequence_length, - initial_state=initial_state, - dtype=dtypes.float32, - time_major=time_major, - scope=scope) + def _CreateSymmetricInputs(self): + # total time = batch size + inputs = np.zeros( + (FunctionalRnnTest._BATCH_SIZE, FunctionalRnnTest._BATCH_SIZE, + FunctionalRnnTest._INPUT_SIZE)) + for i in range(FunctionalRnnTest._BATCH_SIZE): + for j in range(i, FunctionalRnnTest._BATCH_SIZE): + inputs[i][j] = np.random.random([FunctionalRnnTest._INPUT_SIZE]) + inputs[j][i] = inputs[i][j] + + # Always leave one time slot empty, to check max_length behavior. + sequence_length = np.random.randint( + 0, + high=FunctionalRnnTest._BATCH_SIZE - 1, + size=FunctionalRnnTest._BATCH_SIZE, + dtype=np.int) + return (inputs, sequence_length) + + def _CreateRnnGraph(self, + create_rnn_computation_func, + cell, + tf_inputs, + tf_sequence_length, + is_bidirectional, + initial_state=None, + time_major=None, + scope=None): + if is_bidirectional: + tf_result = create_rnn_computation_func( + cell_fw=cell, + cell_bw=cell, + inputs=tf_inputs, + sequence_length=tf_sequence_length, + dtype=dtypes.float32, + time_major=time_major, + scope=scope) + else: + tf_result = create_rnn_computation_func( + cell=cell, + inputs=tf_inputs, + sequence_length=tf_sequence_length, + initial_state=initial_state, + dtype=dtypes.float32, + time_major=time_major, + scope=scope) grad = gradients_impl.gradients(tf_result, variables.trainable_variables()) return {'inference': tf_result, 'grad': grad} @@ -102,15 +145,26 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase): variable_cache[n] = v def _RunRnn(self, numpy_inputs, numpy_slen, cell_name, variable_cache, - is_dynamic): + is_dynamic, time_major=None, is_bidirectional=False): with ops.Graph().as_default() as graph: tf_inputs = array_ops.placeholder( dtypes.float32, shape=numpy_inputs.shape) tf_slen = array_ops.placeholder(dtypes.int32) feeds = {tf_inputs: numpy_inputs, tf_slen: numpy_slen} cell = self._CreateCell(cell_name) - fn = rnn_lib.dynamic_rnn if is_dynamic else functional_rnn.functional_rnn - fetches = self._CreateRnnGraph(fn, cell, tf_inputs, tf_slen) + if is_dynamic: + if is_bidirectional: + fn = rnn_lib.bidirectional_dynamic_rnn + else: + fn = rnn_lib.dynamic_rnn + else: + if is_bidirectional: + fn = functional_rnn.bidirectional_functional_rnn + else: + fn = functional_rnn.functional_rnn + + fetches = self._CreateRnnGraph( + fn, cell, tf_inputs, tf_slen, is_bidirectional, time_major=time_major) with self.test_session(graph=graph) as sess: sess.run(variables.global_variables_initializer()) # Note that cell.trainable_variables it not always set. @@ -158,6 +212,78 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase): self.assertAllClose(dyn_rnn['inference'], func_rnn['inference']) self.assertAllClose(dyn_rnn['grad'], func_rnn['grad']) + def testLstmWithTimeMajorInputs(self): + """Checks an LSTM against the reference implementation, with time_major.""" + time_major = True + np_inputs, np_slen = self._CreateInputs(time_major=True) + var_cache = {} + args = [np_inputs, np_slen, 'lstm', var_cache] + _, func_rnn = self._RunRnn(*(args + [False]), time_major=time_major) + _, dyn_rnn = self._RunRnn(*(args + [True]), time_major=time_major) + self.assertAllClose(dyn_rnn['inference'], func_rnn['inference']) + self.assertAllClose(dyn_rnn['grad'], func_rnn['grad']) + + def testBidirectionalLstmWithTimeMajorInputs(self): + """Checks a bi-directional LSTM with time-major inputs.""" + time_major = True + np_inputs, np_slen = self._CreateInputs(time_major) + var_cache = {} + args = [np_inputs, np_slen, 'lstm', var_cache] + _, func_rnn = self._RunRnn( + *(args + [False]), time_major=time_major, is_bidirectional=True) + _, dyn_rnn = self._RunRnn( + *(args + [True]), time_major=time_major, is_bidirectional=True) + self.assertAllClose(dyn_rnn['inference'], func_rnn['inference']) + # TODO(b/112170761): comment out this line after the bug is fixed. + # self.assertAllClose(dyn_rnn['grad'], func_rnn['grad']) + + def testBidirectionalLstm(self): + """Checks time-major and batch-major rnn produce consistent results.""" + time_major_inputs, np_slen = self._CreateInputs(True) + batch_major_inputs = np.transpose(time_major_inputs, [1, 0, 2]) + var_cache = {} + args = [np_slen, 'lstm', var_cache, False] + _, time_major_rnn = self._RunRnn( + *([time_major_inputs] + args), time_major=True, is_bidirectional=True) + _, batch_major_rnn = self._RunRnn( + *([batch_major_inputs]+ args), time_major=False, is_bidirectional=True) + # Convert the batch-major outputs to be time-major before the comparasion. + outputs, state = batch_major_rnn['inference'] + outputs = [np.transpose(x, [1, 0, 2]) for x in outputs] + batch_major_rnn['inference'] = [outputs, state] + self.assertAllClose(time_major_rnn['inference'], + batch_major_rnn['inference']) + self.assertAllClose(time_major_rnn['grad'], batch_major_rnn['grad']) + + def testBidirectionalLstmWithSymmetricInputs(self): + """Checks a bi-directional LSTM with symmetric inputs. + + time-major and batch-major rnn produce the same result with symmetric + inputs. + """ + np_inputs, np_slen = self._CreateSymmetricInputs() + var_cache = {} + args = [np_inputs, np_slen, 'lstm', var_cache] + _, time_major_func_rnn = self._RunRnn( + *(args + [False]), time_major=True, is_bidirectional=True) + _, batch_major_func_rnn = self._RunRnn( + *(args + [False]), time_major=False, is_bidirectional=True) + _, time_major_dyn_rnn = self._RunRnn( + *(args + [True]), time_major=True, is_bidirectional=True) + _, batch_major_dyn_rnn = self._RunRnn( + *(args + [True]), time_major=False, is_bidirectional=True) + self.assertAllClose(time_major_func_rnn['inference'], + batch_major_func_rnn['inference']) + self.assertAllClose(time_major_func_rnn['grad'], + batch_major_func_rnn['grad']) + self.assertAllClose(time_major_dyn_rnn['inference'], + batch_major_dyn_rnn['inference']) + self.assertAllClose(time_major_dyn_rnn['grad'], batch_major_dyn_rnn['grad']) + self.assertAllClose(time_major_func_rnn['inference'], + batch_major_dyn_rnn['inference']) + self.assertAllClose(time_major_func_rnn['grad'], + batch_major_dyn_rnn['grad']) + if __name__ == '__main__': test_lib.main() diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py index a085474c1bf6117ba5663139c78d8f08f71392d3..67a8f59c3c03d01a5957a9eff8bd026e70770a45 100644 --- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py +++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py @@ -206,7 +206,7 @@ def _PickFinalStateFromHistory(acc_state, sequence_length): lengths = array_ops.tile(array_ops.reshape(sequence_length, [-1, 1]), [1, max_time]) last_idx = math_ops.cast(math_ops.equal(output_time, lengths - 1), - dtype=dtypes.float32) + dtype=state_var.dtype) last_idx = array_ops.transpose(last_idx) last_idx_for_bcast = array_ops.expand_dims(last_idx, -1) sliced = math_ops.multiply(last_idx_for_bcast, state_var) @@ -284,8 +284,13 @@ def functional_rnn(cell, inputs, sequence_length=None, inputs=inputs, cell_fn=func_cell.cell_step, use_tpu=use_tpu) - return _PostProcessOutput(extended_acc_state, extended_final_state, - func_cell, inputs_flat[0].shape[0], sequence_length) + tf_output, tf_state = _PostProcessOutput( + extended_acc_state, extended_final_state, func_cell, + inputs_flat[0].shape[0], sequence_length) + + if time_major: + tf_output = array_ops.transpose(tf_output, [1, 0, 2]) + return tf_output, tf_state def bidirectional_functional_rnn( diff --git a/tensorflow/contrib/recurrent/python/ops/recurrent.py b/tensorflow/contrib/recurrent/python/ops/recurrent.py index fa16b82ab62f27d034c3ca7584e7e1ca14be6f9b..4f289e0c85e2260a44a8ea2f3f1d6cacbc839f66 100644 --- a/tensorflow/contrib/recurrent/python/ops/recurrent.py +++ b/tensorflow/contrib/recurrent/python/ops/recurrent.py @@ -79,7 +79,7 @@ def _Index(struct, index): """ index = ops.convert_to_tensor(index) index.get_shape().assert_has_rank(0) - return nest.map_structure(lambda x: x[index], struct) + return nest.map_structure(lambda x: array_ops.gather(x, index), struct) def _Update(struct_acc, struct_x, t): diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 4eb5c920b3517a8968ff730003e786ae2a9c9e26..2a84629080d20e38807a4be87e51646c3046ebf3 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -118,7 +118,6 @@ cuda_py_tests( "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", "//tensorflow/python:rnn", "//tensorflow/python:rnn_cell", "//tensorflow/python:variable_scope", diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py index 07227bcb77d353200ee46763d51727ed9c0974a1..cb437f2a2f252fcb0763587b07fed19be5887282 100644 --- a/tensorflow/contrib/rnn/__init__.py +++ b/tensorflow/contrib/rnn/__init__.py @@ -59,6 +59,9 @@ See @{$python/contrib.rnn} guide. @@HighwayWrapper @@GLSTMCell @@SRUCell +@@IndRNNCell +@@IndyGRUCell +@@IndyLSTMCell @@AttentionCellWrapper diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index 86f1e27abd53d011f37f06851dd6d0977853c8f4..85f0f8ced91e15cd0f9b3bc51f3a9e3aee12c978 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import functools import os import numpy as np @@ -35,7 +34,6 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope @@ -117,6 +115,27 @@ class RNNCellTest(test.TestCase): }) self.assertEqual(res[0].shape, (1, 2)) + def testIndRNNCell(self): + with self.test_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2]) + m = array_ops.zeros([1, 2]) + cell = contrib_rnn_cell.IndRNNCell(2) + g, _ = cell(x, m) + self.assertEqual([ + "root/ind_rnn_cell/%s_w:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/ind_rnn_cell/%s_u:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/ind_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME + ], [v.name for v in cell.trainable_variables]) + self.assertFalse(cell.non_trainable_variables) + sess.run([variables_lib.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + self.assertEqual(res[0].shape, (1, 2)) + def testGRUCell(self): with self.test_session() as sess: with variable_scope.variable_scope( @@ -145,6 +164,34 @@ class RNNCellTest(test.TestCase): # Smoke test self.assertAllClose(res[0], [[0.156736, 0.156736]]) + def testIndyGRUCell(self): + with self.test_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2]) + m = array_ops.zeros([1, 2]) + g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m) + sess.run([variables_lib.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + # Smoke test + self.assertAllClose(res[0], [[0.185265, 0.17704]]) + with variable_scope.variable_scope( + "other", initializer=init_ops.constant_initializer(0.5)): + # Test IndyGRUCell with input_size != num_units. + x = array_ops.zeros([1, 3]) + m = array_ops.zeros([1, 2]) + g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m) + sess.run([variables_lib.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + # Smoke test + self.assertAllClose(res[0], [[0.155127, 0.157328]]) + def testSRUCell(self): with self.test_session() as sess: with variable_scope.variable_scope( @@ -345,6 +392,72 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[1], expected_mem0) self.assertAllClose(res[2], expected_mem1) + def testIndyLSTMCell(self): + for dtype in [dtypes.float16, dtypes.float32]: + np_dtype = dtype.as_numpy_dtype + with self.test_session(graph=ops.Graph()) as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2], dtype=dtype) + state_0 = (array_ops.zeros([1, 2], dtype=dtype),) * 2 + state_1 = (array_ops.zeros([1, 2], dtype=dtype),) * 2 + cell = rnn_cell_impl.MultiRNNCell( + [contrib_rnn_cell.IndyLSTMCell(2) for _ in range(2)]) + self.assertEqual(cell.dtype, None) + self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name) + self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name) + cell.get_config() # Should not throw an error + g, (out_state_0, out_state_1) = cell(x, (state_0, state_1)) + # Layer infers the input type. + self.assertEqual(cell.dtype, dtype.name) + expected_variable_names = [ + "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_w:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_u:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s:0" % + rnn_cell_impl._BIAS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_w:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_u:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s:0" % + rnn_cell_impl._BIAS_VARIABLE_NAME + ] + self.assertEqual(expected_variable_names, + [v.name for v in cell.trainable_variables]) + self.assertFalse(cell.non_trainable_variables) + sess.run([variables_lib.global_variables_initializer()]) + res = sess.run( + [g, out_state_0, out_state_1], { + x.name: np.array([[1., 1.]]), + state_0[0].name: 0.1 * np.ones([1, 2]), + state_0[1].name: 0.1 * np.ones([1, 2]), + state_1[0].name: 0.1 * np.ones([1, 2]), + state_1[1].name: 0.1 * np.ones([1, 2]), + }) + self.assertEqual(len(res), 3) + variables = variables_lib.global_variables() + self.assertEqual(expected_variable_names, [v.name for v in variables]) + # Only check the range of outputs as this is just a smoke test. + self.assertAllInRange(res[0], -1.0, 1.0) + self.assertAllInRange(res[1], -1.0, 1.0) + self.assertAllInRange(res[2], -1.0, 1.0) + with variable_scope.variable_scope( + "other", initializer=init_ops.constant_initializer(0.5)): + # Test IndyLSTMCell with input_size != num_units. + x = array_ops.zeros([1, 3], dtype=dtype) + state = (array_ops.zeros([1, 2], dtype=dtype),) * 2 + g, out_state = contrib_rnn_cell.IndyLSTMCell(2)(x, state) + sess.run([variables_lib.global_variables_initializer()]) + res = sess.run( + [g, out_state], { + x.name: np.array([[1., 1., 1.]], dtype=np_dtype), + state[0].name: 0.1 * np.ones([1, 2], dtype=np_dtype), + state[1].name: 0.1 * np.ones([1, 2], dtype=np_dtype), + }) + self.assertEqual(len(res), 2) + def testLSTMCell(self): with self.test_session() as sess: num_units = 8 @@ -935,50 +1048,6 @@ class DropoutWrapperTest(test.TestCase): self.assertAllClose(res0[1].h, res1[1].h) -class SlimRNNCellTest(test.TestCase): - - def testBasicRNNCell(self): - with self.test_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 2]) - my_cell = functools.partial(basic_rnn_cell, num_units=2) - # pylint: disable=protected-access - g, _ = rnn_cell_impl._SlimRNNCell(my_cell)(x, m) - # pylint: enable=protected-access - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([g], { - x.name: np.array([[1., 1.]]), - m.name: np.array([[0.1, 0.1]]) - }) - self.assertEqual(res[0].shape, (1, 2)) - - def testBasicRNNCellMatch(self): - batch_size = 32 - input_size = 100 - num_units = 10 - with self.test_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - inputs = random_ops.random_uniform((batch_size, input_size)) - _, initial_state = basic_rnn_cell(inputs, None, num_units) - rnn_cell = rnn_cell_impl.BasicRNNCell(num_units) - outputs, state = rnn_cell(inputs, initial_state) - variable_scope.get_variable_scope().reuse_variables() - my_cell = functools.partial(basic_rnn_cell, num_units=num_units) - # pylint: disable=protected-access - slim_cell = rnn_cell_impl._SlimRNNCell(my_cell) - # pylint: enable=protected-access - slim_outputs, slim_state = slim_cell(inputs, initial_state) - self.assertEqual(slim_outputs.get_shape(), outputs.get_shape()) - self.assertEqual(slim_state.get_shape(), state.get_shape()) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run([slim_outputs, slim_state, outputs, state]) - self.assertAllClose(res[0], res[2]) - self.assertAllClose(res[1], res[3]) - - def basic_rnn_cell(inputs, state, num_units, scope=None): if state is None: if inputs is not None: diff --git a/tensorflow/contrib/rnn/python/ops/rnn.py b/tensorflow/contrib/rnn/python/ops/rnn.py index 2f0caadda336b878e58e973e1c995cbec65d5732..0266b72dcb15e4aba01a9a31b4be75c5b84d44da 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn.py +++ b/tensorflow/contrib/rnn/python/ops/rnn.py @@ -175,7 +175,7 @@ def stack_bidirectional_dynamic_rnn(cells_fw, Returns: A tuple (outputs, output_state_fw, output_state_bw) where: outputs: Output `Tensor` shaped: - `batch_size, max_time, layers_output]`. Where layers_output + `[batch_size, max_time, layers_output]`. Where layers_output are depth-concatenated forward and backward outputs. output_states_fw is the final states, one tensor per layer, of the forward rnn. diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index b12e2cd5eddc3f8abdba62781692673a40e41d9b..1816b469ee5bf338453a82d18663f97f6565dc0c 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -23,6 +23,7 @@ import math from tensorflow.contrib.compiler import jit from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.rnn.python.ops import core_rnn_cell +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops @@ -30,6 +31,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_impl # pylint: disable=unused-import @@ -3050,3 +3052,343 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell): new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h) return new_h, new_state + + +class IndRNNCell(rnn_cell_impl.LayerRNNCell): + """Independently Recurrent Neural Network (IndRNN) cell + (cf. https://arxiv.org/abs/1803.04831). + + Args: + num_units: int, The number of units in the RNN cell. + activation: Nonlinearity to use. Default: `tanh`. + reuse: (optional) Python boolean describing whether to reuse variables + in an existing scope. If not `True`, and the existing scope already has + the given variables, an error is raised. + name: String, the name of the layer. Layers with the same name will + share weights, but to avoid mistakes we require reuse=True in such + cases. + dtype: Default dtype of the layer (default of `None` means use the type + of the first input). Required when `build` is called before `call`. + """ + + def __init__(self, + num_units, + activation=None, + reuse=None, + name=None, + dtype=None): + super(IndRNNCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) + + # Inputs must be 2-dimensional. + self.input_spec = base_layer.InputSpec(ndim=2) + + self._num_units = num_units + self._activation = activation or math_ops.tanh + + @property + def state_size(self): + return self._num_units + + @property + def output_size(self): + return self._num_units + + def build(self, inputs_shape): + if inputs_shape[1].value is None: + raise ValueError( + "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape) + + input_depth = inputs_shape[1].value + # pylint: disable=protected-access + self._kernel_w = self.add_variable( + "%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + shape=[input_depth, self._num_units]) + self._kernel_u = self.add_variable( + "%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + shape=[1, self._num_units], + initializer=init_ops.random_uniform_initializer( + minval=-1, maxval=1, dtype=self.dtype)) + self._bias = self.add_variable( + rnn_cell_impl._BIAS_VARIABLE_NAME, + shape=[self._num_units], + initializer=init_ops.zeros_initializer(dtype=self.dtype)) + # pylint: enable=protected-access + + self.built = True + + def call(self, inputs, state): + """IndRNN: output = new_state = act(W * input + u * state + B).""" + + gate_inputs = math_ops.matmul(inputs, self._kernel_w) + ( + state * self._kernel_u) + gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) + output = self._activation(gate_inputs) + return output, output + + +class IndyGRUCell(rnn_cell_impl.LayerRNNCell): + r"""Independently Gated Recurrent Unit cell. + + Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to GRUCell, + yet with the \(U_r\), \(U_z\), and \(U\) matrices in equations 5, 6, and + 8 of http://arxiv.org/abs/1406.1078 respectively replaced by diagonal + matrices, i.e. a Hadamard product with a single vector: + + $$r_j = \sigma\left([\mathbf W_r\mathbf x]_j + + [\mathbf u_r\circ \mathbf h_{(t-1)}]_j\right)$$ + $$z_j = \sigma\left([\mathbf W_z\mathbf x]_j + + [\mathbf u_z\circ \mathbf h_{(t-1)}]_j\right)$$ + $$\tilde{h}^{(t)}_j = \phi\left([\mathbf W \mathbf x]_j + + [\mathbf u \circ \mathbf r \circ \mathbf h_{(t-1)}]_j\right)$$ + + where \(\circ\) denotes the Hadamard operator. This means that each IndyGRU + node sees only its own state, as opposed to seeing all states in the same + layer. + + TODO(gonnet): Write a paper describing this and add a reference here. + + Args: + num_units: int, The number of units in the GRU cell. + activation: Nonlinearity to use. Default: `tanh`. + reuse: (optional) Python boolean describing whether to reuse variables + in an existing scope. If not `True`, and the existing scope already has + the given variables, an error is raised. + kernel_initializer: (optional) The initializer to use for the weight + matrices applied to the input. + bias_initializer: (optional) The initializer to use for the bias. + name: String, the name of the layer. Layers with the same name will + share weights, but to avoid mistakes we require reuse=True in such + cases. + dtype: Default dtype of the layer (default of `None` means use the type + of the first input). Required when `build` is called before `call`. + """ + + def __init__(self, + num_units, + activation=None, + reuse=None, + kernel_initializer=None, + bias_initializer=None, + name=None, + dtype=None): + super(IndyGRUCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) + + # Inputs must be 2-dimensional. + self.input_spec = base_layer.InputSpec(ndim=2) + + self._num_units = num_units + self._activation = activation or math_ops.tanh + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer + + @property + def state_size(self): + return self._num_units + + @property + def output_size(self): + return self._num_units + + def build(self, inputs_shape): + if inputs_shape[1].value is None: + raise ValueError( + "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape) + + input_depth = inputs_shape[1].value + # pylint: disable=protected-access + self._gate_kernel_w = self.add_variable( + "gates/%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + shape=[input_depth, 2 * self._num_units], + initializer=self._kernel_initializer) + self._gate_kernel_u = self.add_variable( + "gates/%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + shape=[1, 2 * self._num_units], + initializer=init_ops.random_uniform_initializer( + minval=-1, maxval=1, dtype=self.dtype)) + self._gate_bias = self.add_variable( + "gates/%s" % rnn_cell_impl._BIAS_VARIABLE_NAME, + shape=[2 * self._num_units], + initializer=(self._bias_initializer + if self._bias_initializer is not None else + init_ops.constant_initializer(1.0, dtype=self.dtype))) + self._candidate_kernel_w = self.add_variable( + "candidate/%s" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + shape=[input_depth, self._num_units], + initializer=self._kernel_initializer) + self._candidate_kernel_u = self.add_variable( + "candidate/%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + shape=[1, self._num_units], + initializer=init_ops.random_uniform_initializer( + minval=-1, maxval=1, dtype=self.dtype)) + self._candidate_bias = self.add_variable( + "candidate/%s" % rnn_cell_impl._BIAS_VARIABLE_NAME, + shape=[self._num_units], + initializer=(self._bias_initializer + if self._bias_initializer is not None else + init_ops.zeros_initializer(dtype=self.dtype))) + # pylint: enable=protected-access + + self.built = True + + def call(self, inputs, state): + """Gated recurrent unit (GRU) with nunits cells.""" + + gate_inputs = math_ops.matmul(inputs, self._gate_kernel_w) + ( + gen_array_ops.tile(state, [1, 2]) * self._gate_kernel_u) + gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias) + + value = math_ops.sigmoid(gate_inputs) + r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) + + r_state = r * state + + candidate = math_ops.matmul(inputs, self._candidate_kernel_w) + ( + r_state * self._candidate_kernel_u) + candidate = nn_ops.bias_add(candidate, self._candidate_bias) + + c = self._activation(candidate) + new_h = u * state + (1 - u) * c + return new_h, new_h + + +class IndyLSTMCell(rnn_cell_impl.LayerRNNCell): + r"""Basic IndyLSTM recurrent network cell. + + Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to + BasicLSTMCell, yet with the \(U_f\), \(U_i\), \(U_o\) and \(U_c\) + matrices in + https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate + replaced by diagonal matrices, i.e. a Hadamard product with a single vector: + + $$f_t = \sigma_g\left(W_f x_t + u_f \circ h_{t-1} + b_f\right)$$ + $$i_t = \sigma_g\left(W_i x_t + u_i \circ h_{t-1} + b_i\right)$$ + $$o_t = \sigma_g\left(W_o x_t + u_o \circ h_{t-1} + b_o\right)$$ + $$c_t = f_t \circ c_{t-1} + + i_t \circ \sigma_c\left(W_c x_t + u_c \circ h_{t-1} + b_c\right)$$ + + where \(\circ\) denotes the Hadamard operator. This means that each IndyLSTM + node sees only its own state \(h\) and \(c\), as opposed to seeing all + states in the same layer. + + We add forget_bias (default: 1) to the biases of the forget gate in order to + reduce the scale of forgetting in the beginning of the training. + + It does not allow cell clipping, a projection layer, and does not + use peep-hole connections: it is the basic baseline. + + For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell} + that follows. + + TODO(gonnet): Write a paper describing this and add a reference here. + """ + + def __init__(self, + num_units, + forget_bias=1.0, + activation=None, + reuse=None, + kernel_initializer=None, + bias_initializer=None, + name=None, + dtype=None): + """Initialize the IndyLSTM cell. + + Args: + num_units: int, The number of units in the LSTM cell. + forget_bias: float, The bias added to forget gates (see above). + Must set to `0.0` manually when restoring from CudnnLSTM-trained + checkpoints. + activation: Activation function of the inner states. Default: `tanh`. + reuse: (optional) Python boolean describing whether to reuse variables + in an existing scope. If not `True`, and the existing scope already has + the given variables, an error is raised. + kernel_initializer: (optional) The initializer to use for the weight + matrix applied to the inputs. + bias_initializer: (optional) The initializer to use for the bias. + name: String, the name of the layer. Layers with the same name will + share weights, but to avoid mistakes we require reuse=True in such + cases. + dtype: Default dtype of the layer (default of `None` means use the type + of the first input). Required when `build` is called before `call`. + """ + super(IndyLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) + + # Inputs must be 2-dimensional. + self.input_spec = base_layer.InputSpec(ndim=2) + + self._num_units = num_units + self._forget_bias = forget_bias + self._activation = activation or math_ops.tanh + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer + + @property + def state_size(self): + return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units) + + @property + def output_size(self): + return self._num_units + + def build(self, inputs_shape): + if inputs_shape[1].value is None: + raise ValueError( + "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape) + + input_depth = inputs_shape[1].value + # pylint: disable=protected-access + self._kernel_w = self.add_variable( + "%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + shape=[input_depth, 4 * self._num_units], + initializer=self._kernel_initializer) + self._kernel_u = self.add_variable( + "%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + shape=[1, 4 * self._num_units], + initializer=init_ops.random_uniform_initializer( + minval=-1, maxval=1, dtype=self.dtype)) + self._bias = self.add_variable( + rnn_cell_impl._BIAS_VARIABLE_NAME, + shape=[4 * self._num_units], + initializer=(self._bias_initializer + if self._bias_initializer is not None else + init_ops.zeros_initializer(dtype=self.dtype))) + # pylint: enable=protected-access + + self.built = True + + def call(self, inputs, state): + """Independent Long short-term memory cell (IndyLSTM). + + Args: + inputs: `2-D` tensor with shape `[batch_size, input_size]`. + state: An `LSTMStateTuple` of state tensors, each shaped + `[batch_size, num_units]`. + + Returns: + A pair containing the new hidden state, and the new state (a + `LSTMStateTuple`). + """ + sigmoid = math_ops.sigmoid + one = constant_op.constant(1, dtype=dtypes.int32) + c, h = state + + gate_inputs = math_ops.matmul(inputs, self._kernel_w) + gate_inputs += gen_array_ops.tile(h, [1, 4]) * self._kernel_u + gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + i, j, f, o = array_ops.split( + value=gate_inputs, num_or_size_splits=4, axis=one) + + forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype) + # Note that using `add` and `multiply` instead of `+` and `*` gives a + # performance improvement. So using those at the cost of readability. + add = math_ops.add + multiply = math_ops.multiply + new_c = add( + multiply(c, sigmoid(add(f, forget_bias_tensor))), + multiply(sigmoid(i), self._activation(j))) + new_h = multiply(self._activation(new_c), sigmoid(o)) + + new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h) + return new_h, new_state diff --git a/tensorflow/contrib/rpc/python/kernel_tests/BUILD b/tensorflow/contrib/rpc/python/kernel_tests/BUILD index 2311c15a68c46090cec0f97bd950296506b0817e..cb0b89ae55b96361428c7845d4d6aab72543feb7 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/BUILD +++ b/tensorflow/contrib/rpc/python/kernel_tests/BUILD @@ -1,5 +1,3 @@ -# TODO(b/76425722): Port everything in here to OS (currently excluded). - package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 @@ -17,7 +15,6 @@ tf_proto_library( srcs = ["test_example.proto"], has_services = 1, cc_api_version = 2, - protodeps = ["//tensorflow/core:protos_all"], ) py_library( diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py index 27273d16b1c09eba60e124e632b353b09ea2d063..1c23c28860dac6203ea4ec8e808f63d3e9e467e2 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py +++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py @@ -51,23 +51,23 @@ class RpcOpTestBase(object): def testScalarHostPortRpc(self): with self.test_session() as sess: request_tensors = ( - test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString()) + test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString()) response_tensors = self.rpc( - method=self.get_method_name('IncrementTestShapes'), + method=self.get_method_name('Increment'), address=self._address, request=request_tensors) self.assertEqual(response_tensors.shape, ()) response_values = sess.run(response_tensors) response_message = test_example_pb2.TestCase() self.assertTrue(response_message.ParseFromString(response_values)) - self.assertAllEqual([2, 3, 4], response_message.shape) + self.assertAllEqual([2, 3, 4], response_message.values) def testScalarHostPortTryRpc(self): with self.test_session() as sess: request_tensors = ( - test_example_pb2.TestCase(shape=[1, 2, 3]).SerializeToString()) + test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString()) response_tensors, status_code, status_message = self.try_rpc( - method=self.get_method_name('IncrementTestShapes'), + method=self.get_method_name('Increment'), address=self._address, request=request_tensors) self.assertEqual(status_code.shape, ()) @@ -77,7 +77,7 @@ class RpcOpTestBase(object): sess.run((response_tensors, status_code, status_message))) response_message = test_example_pb2.TestCase() self.assertTrue(response_message.ParseFromString(response_values)) - self.assertAllEqual([2, 3, 4], response_message.shape) + self.assertAllEqual([2, 3, 4], response_message.values) # For the base Rpc op, don't expect to get error status back. self.assertEqual(errors.OK, status_code_values) self.assertEqual(b'', status_message_values) @@ -86,7 +86,7 @@ class RpcOpTestBase(object): with self.test_session() as sess: request_tensors = [] response_tensors = self.rpc( - method=self.get_method_name('IncrementTestShapes'), + method=self.get_method_name('Increment'), address=self._address, request=request_tensors) self.assertAllEqual(response_tensors.shape, [0]) @@ -95,7 +95,7 @@ class RpcOpTestBase(object): def testInvalidMethod(self): for method in [ - '/InvalidService.IncrementTestShapes', + '/InvalidService.Increment', self.get_method_name('InvalidMethodName') ]: with self.test_session() as sess: @@ -115,12 +115,12 @@ class RpcOpTestBase(object): with self.assertRaises(errors.UnavailableError): sess.run( self.rpc( - method=self.get_method_name('IncrementTestShapes'), + method=self.get_method_name('Increment'), address=address, request='')) _, status_code_value, status_message_value = sess.run( self.try_rpc( - method=self.get_method_name('IncrementTestShapes'), + method=self.get_method_name('Increment'), address=address, request='')) self.assertEqual(errors.UNAVAILABLE, status_code_value) @@ -182,10 +182,10 @@ class RpcOpTestBase(object): with self.test_session() as sess: request_tensors = [ test_example_pb2.TestCase( - shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20) + values=[i, i + 1, i + 2]).SerializeToString() for i in range(20) ] response_tensors = self.rpc( - method=self.get_method_name('IncrementTestShapes'), + method=self.get_method_name('Increment'), address=self._address, request=request_tensors) self.assertEqual(response_tensors.shape, (20,)) @@ -194,17 +194,17 @@ class RpcOpTestBase(object): for i in range(20): response_message = test_example_pb2.TestCase() self.assertTrue(response_message.ParseFromString(response_values[i])) - self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape) + self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values) def testVecHostPortManyParallelRpcs(self): with self.test_session() as sess: request_tensors = [ test_example_pb2.TestCase( - shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20) + values=[i, i + 1, i + 2]).SerializeToString() for i in range(20) ] many_response_tensors = [ self.rpc( - method=self.get_method_name('IncrementTestShapes'), + method=self.get_method_name('Increment'), address=self._address, request=request_tensors) for _ in range(10) ] @@ -216,25 +216,25 @@ class RpcOpTestBase(object): for i in range(20): response_message = test_example_pb2.TestCase() self.assertTrue(response_message.ParseFromString(response_values[i])) - self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape) + self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values) def testVecHostPortRpcUsingEncodeAndDecodeProto(self): with self.test_session() as sess: request_tensors = encode_proto_op.encode_proto( message_type='tensorflow.contrib.rpc.TestCase', - field_names=['shape'], + field_names=['values'], sizes=[[3]] * 20, values=[ [[i, i + 1, i + 2] for i in range(20)], ]) response_tensor_strings = self.rpc( - method=self.get_method_name('IncrementTestShapes'), + method=self.get_method_name('Increment'), address=self._address, request=request_tensors) _, (response_shape,) = decode_proto_op.decode_proto( bytes=response_tensor_strings, message_type='tensorflow.contrib.rpc.TestCase', - field_names=['shape'], + field_names=['values'], output_types=[dtypes.int32]) response_shape_values = sess.run(response_shape) self.assertAllEqual([[i + 1, i + 2, i + 3] @@ -285,9 +285,9 @@ class RpcOpTestBase(object): addresses = flatten([[ self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@' ] for _ in range(10)]) - request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString() + request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString() response_tensors, status_code, _ = self.try_rpc( - method=self.get_method_name('IncrementTestShapes'), + method=self.get_method_name('Increment'), address=addresses, request=request) response_tensors_values, status_code_values = sess.run((response_tensors, @@ -303,9 +303,9 @@ class RpcOpTestBase(object): flatten = lambda x: list(itertools.chain.from_iterable(x)) with self.test_session() as sess: methods = flatten( - [[self.get_method_name('IncrementTestShapes'), 'InvalidMethodName'] + [[self.get_method_name('Increment'), 'InvalidMethodName'] for _ in range(10)]) - request = test_example_pb2.TestCase(shape=[0, 1, 2]).SerializeToString() + request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString() response_tensors, status_code, _ = self.try_rpc( method=methods, address=self._address, request=request) response_tensors_values, status_code_values = sess.run((response_tensors, @@ -325,10 +325,10 @@ class RpcOpTestBase(object): ] for _ in range(10)]) requests = [ test_example_pb2.TestCase( - shape=[i, i + 1, i + 2]).SerializeToString() for i in range(20) + values=[i, i + 1, i + 2]).SerializeToString() for i in range(20) ] response_tensors, status_code, _ = self.try_rpc( - method=self.get_method_name('IncrementTestShapes'), + method=self.get_method_name('Increment'), address=addresses, request=requests) response_tensors_values, status_code_values = sess.run((response_tensors, @@ -343,4 +343,4 @@ class RpcOpTestBase(object): response_message = test_example_pb2.TestCase() self.assertTrue( response_message.ParseFromString(response_tensors_values[i])) - self.assertAllEqual([i + 1, i + 2, i + 3], response_message.shape) + self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values) diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py index 7cbd636cb16e3befc9ae27cb231696634e859a22..265254aa51c64ff5a76ad3a9f7e081c56dd639e7 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py +++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_servicer.py @@ -30,8 +30,8 @@ from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2_grpc class RpcOpTestServicer(test_example_pb2_grpc.TestCaseServiceServicer): """Test servicer for RpcOp tests.""" - def IncrementTestShapes(self, request, context): - """Increment the entries in the shape attribute of request. + def Increment(self, request, context): + """Increment the entries in the `values` attribute of request. Args: request: input TestCase. @@ -40,8 +40,8 @@ class RpcOpTestServicer(test_example_pb2_grpc.TestCaseServiceServicer): Returns: output TestCase. """ - for i in range(len(request.shape)): - request.shape[i] += 1 + for i in range(len(request.values)): + request.values[i] += 1 return request def AlwaysFailWithInvalidArgument(self, request, context): diff --git a/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto b/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto index 96f4550f62bc17e713abe1f3843ec0964f57b046..8141466349afcebcd104153a9f28c8f382458098 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto +++ b/tensorflow/contrib/rpc/python/kernel_tests/test_example.proto @@ -1,29 +1,17 @@ // Test description and protos to work with it. -// -// Many of the protos in this file are for unit tests that haven't been written yet. syntax = "proto2"; -import "tensorflow/core/framework/types.proto"; - package tensorflow.contrib.rpc; -// A TestCase holds a proto and a bunch of assertions -// about how it should decode. +// A TestCase holds a sequence of values. message TestCase { - // A batch of primitives to be serialized and decoded. - repeated RepeatedPrimitiveValue primitive = 1; - // The shape of the batch. - repeated int32 shape = 2; - // Expected sizes for each field. - repeated int32 sizes = 3; - // Expected values for each field. - repeated FieldSpec field = 4; + repeated int32 values = 1; }; service TestCaseService { - // Copy input, and increment each entry in 'shape' by 1. - rpc IncrementTestShapes(TestCase) returns (TestCase) { + // Copy input, and increment each entry in 'values' by 1. + rpc Increment(TestCase) returns (TestCase) { } // Sleep forever. @@ -42,130 +30,3 @@ service TestCaseService { rpc SometimesFailWithInvalidArgument(TestCase) returns (TestCase) { } }; - -// FieldSpec describes the expected output for a single field. -message FieldSpec { - optional string name = 1; - optional tensorflow.DataType dtype = 2; - optional RepeatedPrimitiveValue expected = 3; -}; - -message TestValue { - optional PrimitiveValue primitive_value = 1; - optional EnumValue enum_value = 2; - optional MessageValue message_value = 3; - optional RepeatedMessageValue repeated_message_value = 4; - optional RepeatedPrimitiveValue repeated_primitive_value = 6; -} - -message PrimitiveValue { - optional double double_value = 1; - optional float float_value = 2; - optional int64 int64_value = 3; - optional uint64 uint64_value = 4; - optional int32 int32_value = 5; - optional fixed64 fixed64_value = 6; - optional fixed32 fixed32_value = 7; - optional bool bool_value = 8; - optional string string_value = 9; - optional bytes bytes_value = 12; - optional uint32 uint32_value = 13; - optional sfixed32 sfixed32_value = 15; - optional sfixed64 sfixed64_value = 16; - optional sint32 sint32_value = 17; - optional sint64 sint64_value = 18; -} - -// NOTE: This definition must be kept in sync with PackedPrimitiveValue. -message RepeatedPrimitiveValue { - repeated double double_value = 1; - repeated float float_value = 2; - repeated int64 int64_value = 3; - repeated uint64 uint64_value = 4; - repeated int32 int32_value = 5; - repeated fixed64 fixed64_value = 6; - repeated fixed32 fixed32_value = 7; - repeated bool bool_value = 8; - repeated string string_value = 9; - repeated bytes bytes_value = 12; - repeated uint32 uint32_value = 13; - repeated sfixed32 sfixed32_value = 15; - repeated sfixed64 sfixed64_value = 16; - repeated sint32 sint32_value = 17; - repeated sint64 sint64_value = 18; - repeated PrimitiveValue message_value = 19; -} - -// A PackedPrimitiveValue looks exactly the same as a RepeatedPrimitiveValue -// in the text format, but the binary serializion is different. -// We test the packed representations by loading the same test cases -// using this definition instead of RepeatedPrimitiveValue. -// NOTE: This definition must be kept in sync with RepeatedPrimitiveValue -// in every way except the packed=true declaration. -message PackedPrimitiveValue { - repeated double double_value = 1 [packed = true]; - repeated float float_value = 2 [packed = true]; - repeated int64 int64_value = 3 [packed = true]; - repeated uint64 uint64_value = 4 [packed = true]; - repeated int32 int32_value = 5 [packed = true]; - repeated fixed64 fixed64_value = 6 [packed = true]; - repeated fixed32 fixed32_value = 7 [packed = true]; - repeated bool bool_value = 8 [packed = true]; - repeated string string_value = 9; - repeated bytes bytes_value = 12; - repeated uint32 uint32_value = 13 [packed = true]; - repeated sfixed32 sfixed32_value = 15 [packed = true]; - repeated sfixed64 sfixed64_value = 16 [packed = true]; - repeated sint32 sint32_value = 17 [packed = true]; - repeated sint64 sint64_value = 18 [packed = true]; - repeated PrimitiveValue message_value = 19; -} - -message EnumValue { - enum Color { - RED = 0; - ORANGE = 1; - YELLOW = 2; - GREEN = 3; - BLUE = 4; - INDIGO = 5; - VIOLET = 6; - }; - optional Color enum_value = 14; - repeated Color repeated_enum_value = 15; -} - - -message InnerMessageValue { - optional float float_value = 2; - repeated bytes bytes_values = 8; -} - -message MiddleMessageValue { - repeated int32 int32_values = 5; - optional InnerMessageValue message_value = 11; - optional uint32 uint32_value = 13; -} - -message MessageValue { - optional double double_value = 1; - optional MiddleMessageValue message_value = 11; -} - -message RepeatedMessageValue { - message NestedMessageValue { - optional float float_value = 2; - repeated bytes bytes_values = 8; - } - - repeated NestedMessageValue message_values = 11; -} - -// Message containing fields with field numbers higher than any field above. An -// instance of this message is prepended to each binary message in the test to -// exercise the code path that handles fields encoded out of order of field -// number. -message ExtraFields { - optional string string_value = 1776; - optional bool bool_value = 1777; -} diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD index 26fd4e2023806765ea4088f4c13a780ca7338bff..fbb50befdfb2ccbd97465c11f8219e604a0ebc18 100644 --- a/tensorflow/contrib/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/BUILD @@ -93,3 +93,32 @@ py_test( "//tensorflow/python/saved_model:utils", ], ) + +py_library( + name = "keras_saved_model", + srcs = ["python/saved_model/keras_saved_model.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/python:lib", + "//tensorflow/python:util", + "//tensorflow/python/keras:engine", + "//tensorflow/python/saved_model:constants", + ], +) + +py_test( + name = "keras_saved_model_test", + size = "small", + srcs = ["python/saved_model/keras_saved_model_test.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + ":saved_model_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:training", + "//tensorflow/python/keras", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/saved_model/__init__.py b/tensorflow/contrib/saved_model/__init__.py index b4f27a055dad7a5b95112d561cc878609a558f8d..95e1a8967b2223fd3feb112af3cbe0c5991d2d03 100644 --- a/tensorflow/contrib/saved_model/__init__.py +++ b/tensorflow/contrib/saved_model/__init__.py @@ -24,11 +24,12 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,wildcard-import,line-too-long +from tensorflow.contrib.saved_model.python.saved_model.keras_saved_model import * from tensorflow.contrib.saved_model.python.saved_model.signature_def_utils import * # pylint: enable=unused-import,widcard-import,line-too-long from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ["get_signature_def_by_key"] +_allowed_symbols = ["get_signature_def_by_key", "load_model", "save_model"] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/saved_model/python/saved_model/__init__.py b/tensorflow/contrib/saved_model/python/saved_model/__init__.py index 7b91622b6127413ce122c4166a18255b65365d32..e3b76bb6f34846f02ccdf623d48ddd9c5909fdce 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/__init__.py +++ b/tensorflow/contrib/saved_model/python/saved_model/__init__.py @@ -24,5 +24,6 @@ from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import +from tensorflow.contrib.saved_model.python.saved_model import keras_saved_model from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils # pylint: enable=wildcard-import diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e2a969f053d3f1ded8aecd6411a62a198df48bb0 --- /dev/null +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py @@ -0,0 +1,108 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=protected-access +"""Utility functions to save/load keras Model to/from SavedModel.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.keras.models import model_from_json +from tensorflow.python.lib.io import file_io +from tensorflow.python.saved_model import constants +from tensorflow.python.util import compat + + +def save_model(model, saved_model_path): + """Save a `tf.keras.Model` into Tensorflow SavedModel format. + + `save_model` generates such files/folders under the `saved_model_path` folder: + 1) an asset folder containing the json string of the model's + configuration(topology). + 2) a checkpoint containing the model weights. + + Note that subclassed models can not be saved via this function, unless you + provide an implementation for get_config() and from_config(). + Also note that `tf.keras.optimizers.Optimizer` instances can not currently be + saved to checkpoints. Use optimizers from `tf.train`. + + Args: + model: A `tf.keras.Model` to be saved. + saved_model_path: a string specifying the path to the SavedModel directory. + + Raises: + NotImplementedError: If the passed in model is a subclassed model. + """ + if not model._is_graph_network: + raise NotImplementedError + + # save model configuration as a json string under assets folder. + model_json = model.to_json() + assets_destination_dir = os.path.join( + compat.as_bytes(saved_model_path), + compat.as_bytes(constants.ASSETS_DIRECTORY)) + + if not file_io.file_exists(assets_destination_dir): + file_io.recursive_create_dir(assets_destination_dir) + + model_json_filepath = os.path.join( + compat.as_bytes(assets_destination_dir), + compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON)) + file_io.write_string_to_file(model_json_filepath, model_json) + + # save model weights in checkpoint format. + checkpoint_destination_dir = os.path.join( + compat.as_bytes(saved_model_path), + compat.as_bytes(constants.VARIABLES_DIRECTORY)) + + if not file_io.file_exists(checkpoint_destination_dir): + file_io.recursive_create_dir(checkpoint_destination_dir) + + checkpoint_prefix = os.path.join( + compat.as_text(checkpoint_destination_dir), + compat.as_text(constants.VARIABLES_FILENAME)) + model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True) + + +def load_model(saved_model_path): + """Load a keras.Model from SavedModel. + + load_model reinstantiates model state by: + 1) loading model topology from json (this will eventually come + from metagraph). + 2) loading model weights from checkpoint. + + Args: + saved_model_path: a string specifying the path to an existing SavedModel. + + Returns: + a keras.Model instance. + """ + # restore model topology from json string + model_json_filepath = os.path.join( + compat.as_bytes(saved_model_path), + compat.as_bytes(constants.ASSETS_DIRECTORY), + compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON)) + model_json = file_io.read_file_to_string(model_json_filepath) + model = model_from_json(model_json) + + # restore model weights + checkpoint_prefix = os.path.join( + compat.as_text(saved_model_path), + compat.as_text(constants.VARIABLES_DIRECTORY), + compat.as_text(constants.VARIABLES_FILENAME)) + model.load_weights(checkpoint_prefix) + return model diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py new file mode 100644 index 0000000000000000000000000000000000000000..107ae1b07b777570e4124337595ceecd6e33cd0b --- /dev/null +++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py @@ -0,0 +1,201 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=protected-access +"""Tests for saving/loading function for keras Model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +import numpy as np + +from tensorflow.contrib.saved_model.python.saved_model import keras_saved_model +from tensorflow.python import keras +from tensorflow.python.framework import test_util +from tensorflow.python.keras.engine import training +from tensorflow.python.platform import test +from tensorflow.python.training import training as training_module + + +class TestModelSavingandLoading(test.TestCase): + + def test_saving_sequential_model(self): + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.RepeatVector(3)) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy], + sample_weight_mode='temporal') + x = np.random.random((1, 3)) + y = np.random.random((1, 3, 3)) + model.train_on_batch(x, y) + + ref_y = model.predict(x) + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + + temp_saved_model = os.path.join(temp_dir, 'saved_model') + keras_saved_model.save_model(model, temp_saved_model) + + loaded_model = keras_saved_model.load_model(temp_saved_model) + y = loaded_model.predict(x) + self.assertAllClose(ref_y, y, atol=1e-05) + + @test_util.run_in_graph_and_eager_modes + def test_saving_sequential_model_without_compile(self): + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.RepeatVector(3)) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) + + x = np.random.random((1, 3)) + ref_y = model.predict(x) + + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + + temp_saved_model = os.path.join(temp_dir, 'saved_model') + keras_saved_model.save_model(model, temp_saved_model) + loaded_model = keras_saved_model.load_model(temp_saved_model) + + y = loaded_model.predict(x) + self.assertAllClose(ref_y, y, atol=1e-05) + + def test_saving_functional_model(self): + with self.test_session(): + inputs = keras.layers.Input(shape=(3,)) + x = keras.layers.Dense(2)(inputs) + output = keras.layers.Dense(3)(x) + + model = keras.models.Model(inputs, output) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy]) + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + + ref_y = model.predict(x) + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + + temp_saved_model = os.path.join(temp_dir, 'saved_model') + keras_saved_model.save_model(model, temp_saved_model) + loaded_model = keras_saved_model.load_model(temp_saved_model) + + y = loaded_model.predict(x) + self.assertAllClose(ref_y, y, atol=1e-05) + + @test_util.run_in_graph_and_eager_modes + def test_saving_functional_model_without_compile(self): + with self.test_session(): + inputs = keras.layers.Input(shape=(3,)) + x = keras.layers.Dense(2)(inputs) + output = keras.layers.Dense(3)(x) + + model = keras.models.Model(inputs, output) + + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + + ref_y = model.predict(x) + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + + temp_saved_model = os.path.join(temp_dir, 'saved_model') + keras_saved_model.save_model(model, temp_saved_model) + loaded_model = keras_saved_model.load_model(temp_saved_model) + + y = loaded_model.predict(x) + self.assertAllClose(ref_y, y, atol=1e-05) + + @test_util.run_in_graph_and_eager_modes + def test_saving_with_tf_optimizer(self): + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.Dense(3)) + model.compile( + loss='mse', + optimizer=training_module.RMSPropOptimizer(0.1), + metrics=['acc']) + + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + + ref_y = model.predict(x) + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + + temp_saved_model = os.path.join(temp_dir, 'saved_model') + keras_saved_model.save_model(model, temp_saved_model) + loaded_model = keras_saved_model.load_model(temp_saved_model) + loaded_model.compile( + loss='mse', + optimizer=training_module.RMSPropOptimizer(0.1), + metrics=['acc']) + y = loaded_model.predict(x) + self.assertAllClose(ref_y, y, atol=1e-05) + + # test that new updates are the same with both models + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + + ref_loss = model.train_on_batch(x, y) + loss = loaded_model.train_on_batch(x, y) + self.assertAllClose(ref_loss, loss, atol=1e-05) + + ref_y = model.predict(x) + y = loaded_model.predict(x) + self.assertAllClose(ref_y, y, atol=1e-05) + + # test saving/loading again + keras_saved_model.save_model(loaded_model, temp_saved_model) + loaded_model = keras_saved_model.load_model(temp_saved_model) + y = loaded_model.predict(x) + self.assertAllClose(ref_y, y, atol=1e-05) + + def test_saving_subclassed_model_raise_error(self): + # For now, saving subclassed model should raise an error. It should be + # avoided later with loading from SavedModel.pb. + + class SubclassedModel(training.Model): + + def __init__(self): + super(SubclassedModel, self).__init__() + self.layer1 = keras.layers.Dense(3) + self.layer2 = keras.layers.Dense(1) + + def call(self, inp): + return self.layer2(self.layer1(inp)) + + model = SubclassedModel() + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + temp_saved_model = os.path.join(temp_dir, 'saved_model') + with self.assertRaises(NotImplementedError): + keras_saved_model.save_model(model, temp_saved_model) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py index 178328619f087789df040489cd150ba018cc8d14..4073b390fc72cf0f84edd0d2ab56df5ffeb3e2e5 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py @@ -132,6 +132,48 @@ class TestGatherTree(test.TestCase): def test_gather_tree_from_array_2d(self): self._test_gather_tree_from_array(depth_ndims=2) + def test_gather_tree_from_array_complex_trajectory(self): + # Max. time = 7, batch = 1, beam = 5. + array = np.expand_dims(np.array( + [[[25, 12, 114, 89, 97]], + [[9, 91, 64, 11, 162]], + [[34, 34, 34, 34, 34]], + [[2, 4, 2, 2, 4]], + [[2, 3, 6, 2, 2]], + [[2, 2, 2, 3, 2]], + [[2, 2, 2, 2, 2]]]), -1) + parent_ids = np.array( + [[[0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0]], + [[0, 1, 2, 3, 4]], + [[0, 0, 1, 2, 1]], + [[0, 1, 1, 2, 3]], + [[0, 1, 3, 1, 2]], + [[0, 1, 2, 3, 4]]]) + expected_array = np.expand_dims(np.array( + [[[25, 25, 25, 25, 25]], + [[9, 9, 91, 9, 9]], + [[34, 34, 34, 34, 34]], + [[2, 4, 2, 4, 4]], + [[2, 3, 6, 3, 6]], + [[2, 2, 2, 3, 2]], + [[2, 2, 2, 2, 2]]]), -1) + sequence_length = [[4, 6, 4, 7, 6]] + + array = ops.convert_to_tensor( + array, dtype=dtypes.float32) + parent_ids = ops.convert_to_tensor( + parent_ids, dtype=dtypes.int32) + expected_array = ops.convert_to_tensor( + expected_array, dtype=dtypes.float32) + + sorted_array = beam_search_decoder.gather_tree_from_array( + array, parent_ids, sequence_length) + + with self.test_session() as sess: + sorted_array, expected_array = sess.run([sorted_array, expected_array]) + self.assertAllEqual(expected_array, sorted_array) + class TestArrayShapeChecks(test.TestCase): diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index 184144f64a56358206014a0f75473b4a9b16617a..f17dbb0fe3c13c3a43f043b82772949737dfb2de 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -145,24 +145,20 @@ def gather_tree_from_array(t, parent_ids, sequence_length): array_ops.expand_dims(math_ops.range(beam_width), 0), 0) beam_ids = array_ops.tile(beam_ids, [max_time, batch_size, 1]) - mask = array_ops.sequence_mask( - sequence_length, maxlen=max_time, dtype=dtypes.int32) - mask = array_ops.transpose(mask, perm=[2, 0, 1]) - - # Use beam_width + 1 to mark the end of beam. - masked_beam_ids = (beam_ids * mask) + (1 - mask) * (beam_width + 1) - max_sequence_lengths = math_ops.to_int32( math_ops.reduce_max(sequence_length, axis=1)) sorted_beam_ids = beam_search_ops.gather_tree( - step_ids=masked_beam_ids, + step_ids=beam_ids, parent_ids=parent_ids, max_sequence_lengths=max_sequence_lengths, end_token=beam_width + 1) # For out of range steps, simply copy the same beam. + in_bound_steps = array_ops.transpose( + array_ops.sequence_mask(sequence_length, maxlen=max_time), + perm=[2, 0, 1]) sorted_beam_ids = array_ops.where( - math_ops.cast(mask, dtypes.bool), x=sorted_beam_ids, y=beam_ids) + in_bound_steps, x=sorted_beam_ids, y=beam_ids) # Generate indices for gather_nd. time_ind = array_ops.tile(array_ops.reshape( @@ -250,7 +246,7 @@ class BeamSearchDecoder(decoder.Decoder): ``` tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( encoder_outputs, multiplier=beam_width) - tiled_encoder_final_state = tf.conrib.seq2seq.tile_batch( + tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch( encoder_final_state, multiplier=beam_width) tiled_sequence_length = tf.contrib.seq2seq.tile_batch( sequence_length, multiplier=beam_width) diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py index e69725ff8ab1ba4de880c914a6f5fdad5e54566d..f58268eff525a4b592c79acb32207e1a3f62bdc7 100644 --- a/tensorflow/contrib/seq2seq/python/ops/decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py @@ -21,6 +21,7 @@ from __future__ import print_function import abc import six +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -182,19 +183,20 @@ def dynamic_decode(decoder, raise TypeError("Expected decoder to be type Decoder, but saw: %s" % type(decoder)) - def _is_xla_tensor(tensor): - try: - op = tensor.op - except AttributeError: - return False - if control_flow_util.IsInXLAContext(op): - return True - return False - with variable_scope.variable_scope(scope, "decoder") as varscope: - # Properly cache variable values inside the while_loop - if varscope.caching_device is None: - varscope.set_caching_device(lambda op: op.device) + # Determine context types. + ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access + is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None + in_while_loop = ( + control_flow_util.GetContainingWhileContext(ctxt) is not None) + # Properly cache variable values inside the while_loop. + # Don't set a caching device when running in a loop, since it is possible + # that train steps could be wrapped in a tf.while_loop. In that scenario + # caching prevents forward computations in loop iterations from re-reading + # the updated weights. + if not context.executing_eagerly() and not in_while_loop: + if varscope.caching_device is None: + varscope.set_caching_device(lambda op: op.device) if maximum_iterations is not None: maximum_iterations = ops.convert_to_tensor( @@ -208,9 +210,6 @@ def dynamic_decode(decoder, decoder.output_dtype, decoder.batch_size) - is_xla = False - if any([_is_xla_tensor(i) for i in nest.flatten(initial_inputs)]): - is_xla = True if is_xla and maximum_iterations is None: raise ValueError("maximum_iterations is required for XLA compilation.") if maximum_iterations is not None: diff --git a/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py index 345eb6cfaa67fd4cda6e7e3f01a1243bbf3c9fa1..f4348e80eac54933d67cdf7bd281d6a9c6c10381 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py @@ -53,7 +53,8 @@ def spectrogram_to_mel_matrix(num_mel_bins=20, num_spectrogram_bins=129, audio_sample_rate=8000, lower_edge_hertz=125.0, - upper_edge_hertz=3800.0): + upper_edge_hertz=3800.0, + unused_dtype=None): """Return a matrix that can post-multiply spectrogram rows to make mel. Copied from @@ -132,9 +133,9 @@ class LinearToMelTest(test.TestCase): # lower_edge_hertz, upper_edge_hertz) to test. configs = [ # Defaults. - (20, 129, 8000.0, 125.0, 3800.0), + (20, 129, 8000.0, 125.0, 3800.0, dtypes.float64), # Settings used by Tacotron (https://arxiv.org/abs/1703.10135). - (80, 1025, 24000.0, 80.0, 12000.0) + (80, 1025, 24000.0, 80.0, 12000.0, dtypes.float64) ] with self.test_session(use_gpu=True): for config in configs: @@ -143,7 +144,8 @@ class LinearToMelTest(test.TestCase): self.assertAllClose(mel_matrix_np, mel_matrix.eval(), atol=3e-6) def test_dtypes(self): - for dtype in (dtypes.float16, dtypes.float32, dtypes.float64): + # LinSpace is not supported for tf.float16. + for dtype in (dtypes.bfloat16, dtypes.float32, dtypes.float64): self.assertEqual(dtype, mel_ops.linear_to_mel_weight_matrix(dtype=dtype).dtype) @@ -167,7 +169,8 @@ class LinearToMelTest(test.TestCase): def test_constant_folding(self): """Mel functions should be constant foldable.""" - for dtype in (dtypes.float16, dtypes.float32, dtypes.float64): + # TODO(rjryan): tf.bloat16 cannot be constant folded by Grappler. + for dtype in (dtypes.float32, dtypes.float64): g = ops.Graph() with g.as_default(): mel_matrix = mel_ops.linear_to_mel_weight_matrix(dtype=dtype) diff --git a/tensorflow/contrib/signal/python/ops/mel_ops.py b/tensorflow/contrib/signal/python/ops/mel_ops.py index 1e84006116daa3f28c760037cb9eeafd53eaafb8..062d84aea183ab61501a8b07521adb1a1a17c63c 100644 --- a/tensorflow/contrib/signal/python/ops/mel_ops.py +++ b/tensorflow/contrib/signal/python/ops/mel_ops.py @@ -151,22 +151,21 @@ def linear_to_mel_weight_matrix(num_mel_bins=20, _validate_arguments(num_mel_bins, sample_rate, lower_edge_hertz, upper_edge_hertz, dtype) - # To preserve accuracy, we compute the matrix at float64 precision and then - # cast to `dtype` at the end. This function can be constant folded by graph - # optimization since there are no Tensor inputs. + # This function can be constant folded by graph optimization since there are + # no Tensor inputs. sample_rate = ops.convert_to_tensor( - sample_rate, dtypes.float64, name='sample_rate') + sample_rate, dtype, name='sample_rate') lower_edge_hertz = ops.convert_to_tensor( - lower_edge_hertz, dtypes.float64, name='lower_edge_hertz') + lower_edge_hertz, dtype, name='lower_edge_hertz') upper_edge_hertz = ops.convert_to_tensor( - upper_edge_hertz, dtypes.float64, name='upper_edge_hertz') - zero_float64 = ops.convert_to_tensor(0.0, dtypes.float64) + upper_edge_hertz, dtype, name='upper_edge_hertz') + zero = ops.convert_to_tensor(0.0, dtype) # HTK excludes the spectrogram DC bin. bands_to_zero = 1 nyquist_hertz = sample_rate / 2.0 linear_frequencies = math_ops.linspace( - zero_float64, nyquist_hertz, num_spectrogram_bins)[bands_to_zero:] + zero, nyquist_hertz, num_spectrogram_bins)[bands_to_zero:] spectrogram_bins_mel = array_ops.expand_dims( _hertz_to_mel(linear_frequencies), 1) @@ -193,11 +192,8 @@ def linear_to_mel_weight_matrix(num_mel_bins=20, # Intersect the line segments with each other and zero. mel_weights_matrix = math_ops.maximum( - zero_float64, math_ops.minimum(lower_slopes, upper_slopes)) + zero, math_ops.minimum(lower_slopes, upper_slopes)) # Re-add the zeroed lower bins we sliced out above. - mel_weights_matrix = array_ops.pad( - mel_weights_matrix, [[bands_to_zero, 0], [0, 0]]) - - # Cast to the desired type. - return math_ops.cast(mel_weights_matrix, dtype, name=name) + return array_ops.pad( + mel_weights_matrix, [[bands_to_zero, 0], [0, 0]], name=name) diff --git a/tensorflow/contrib/signal/python/ops/reconstruction_ops.py b/tensorflow/contrib/signal/python/ops/reconstruction_ops.py index 653c030a04c2bbc7e3ee49b9c85a781fb49de8d0..4db8dc2ca090534f2cda66bd55c30dfa389b860a 100644 --- a/tensorflow/contrib/signal/python/ops/reconstruction_ops.py +++ b/tensorflow/contrib/signal/python/ops/reconstruction_ops.py @@ -90,22 +90,28 @@ def overlap_and_add(signal, frame_step, name=None): raise ValueError("frame_step must be an integer. Got %s" % frame_step.dtype) - # If frame_length and frame_step are known at graph construction time, check - # frame_step is less than or equal to frame_length. - frame_step_static = tensor_util.constant_value(frame_step) - if (frame_step_static is not None and signal.shape.ndims is not None and - signal.shape[-1].value is not None and - frame_step_static > signal.shape[-1].value): - raise ValueError( - "frame_step (%d) must be less than or equal to frame_length (%d)" % ( - frame_step_static, signal.shape[-1].value)) - signal_shape = array_ops.shape(signal) # All dimensions that are not part of the overlap-and-add. Can be empty for # rank 2 inputs. outer_dimensions = signal_shape[:-2] + # If frame_length and frame_step are known at graph construction time, check + # frame_step is less than or equal to frame_length. + frame_step_static = tensor_util.constant_value(frame_step) + if (frame_step_static is not None and signal.shape.ndims is not None and + signal.shape[-1].value is not None): + if frame_step_static > signal.shape[-1].value: + raise ValueError( + "frame_step (%d) must be less than or equal to " + "frame_length (%d)" % ( + frame_step_static, signal.shape[-1].value)) + # If frame_length is equal to frame_step, there's no overlap so just + # reshape the tensor. + if frame_step_static == signal.shape[-1].value: + return array_ops.reshape(signal, array_ops.concat( + [outer_dimensions, [-1]], 0)) + signal_rank = array_ops.rank(signal) frames = signal_shape[-2] frame_length = signal_shape[-1] diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py index 3d0308aaf3da3b5b16fd22a2905db36917e8c97b..2c97834523424d0fab56330b4d9355a75427e0ef 100644 --- a/tensorflow/contrib/slim/python/slim/evaluation_test.py +++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py @@ -33,7 +33,6 @@ from tensorflow.python.debug.lib import debug_data from tensorflow.python.debug.wrappers import hooks from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics @@ -242,7 +241,7 @@ class SingleEvaluationTest(test.TestCase): checkpoint_path = os.path.join(self.get_temp_dir(), 'this_file_doesnt_exist') log_dir = os.path.join(self.get_temp_dir(), 'error_raised') - with self.assertRaises(errors.NotFoundError): + with self.assertRaises(ValueError): evaluation.evaluate_once('', checkpoint_path, log_dir) def _prepareCheckpoint(self, checkpoint_path): diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 3e41e3d0b48ea06f9cb8c1862e27eacb5ebc4417..4d1807130c57039976dfa57c27bb0d4807e75212 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -20,6 +20,8 @@ import os import tempfile import time +import sqlite3 + import numpy as np import six @@ -275,6 +277,22 @@ class EagerFileTest(test_util.TensorFlowTestCase): class EagerDbTest(summary_test_util.SummaryDbTest): + def testDbURIOpen(self): + tmpdb_path = os.path.join(self.get_temp_dir(), 'tmpDbURITest.sqlite') + tmpdb_uri = six.moves.urllib_parse.urljoin("file:", tmpdb_path) + tmpdb_writer = summary_ops.create_db_writer( + tmpdb_uri, + "experimentA", + "run1", + "user1") + with summary_ops.always_record_summaries(): + with tmpdb_writer.as_default(): + summary_ops.scalar('t1', 2.0) + tmpdb = sqlite3.connect(tmpdb_path) + num = get_one(tmpdb, 'SELECT count(*) FROM Tags WHERE tag_name = "t1"') + self.assertEqual(num, 1) + tmpdb.close() + def testIntegerSummaries(self): step = training_util.create_global_step() writer = self.create_db_writer() diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index 136856c0156c41046f9af61cdd6e3d5f8213309e..164f3e58e6c0b2486d270c457500c8dca0c7e7eb 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -223,7 +223,6 @@ tf_kernel_library( ":model_ops_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", ], alwayslink = 1, ) @@ -319,7 +318,6 @@ tf_kernel_library( ":stats_ops_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", ], alwayslink = 1, ) diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc index 7e25579070eef13682dedfcd3c9e435333f65687..6cb2c881e2428dfcac3187bf7364582e857b9879 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc @@ -51,7 +51,8 @@ std::unique_ptr CreateBinaryDecisionNodeEvaluator( InequalityDecisionNodeEvaluator::InequalityDecisionNodeEvaluator( const decision_trees::InequalityTest& test, int32 left, int32 right) : BinaryDecisionNodeEvaluator(left, right) { - safe_strto32(test.feature_id().id().value(), &feature_num_); + CHECK(safe_strto32(test.feature_id().id().value(), &feature_num_)) + << "Invalid feature ID: [" << test.feature_id().id().value() << "]"; threshold_ = test.threshold().float_value(); include_equals_ = test.type() == decision_trees::InequalityTest::LESS_OR_EQUAL; @@ -72,7 +73,9 @@ ObliqueInequalityDecisionNodeEvaluator::ObliqueInequalityDecisionNodeEvaluator( : BinaryDecisionNodeEvaluator(left, right) { for (int i = 0; i < test.oblique().features_size(); ++i) { int32 val; - safe_strto32(test.oblique().features(i).id().value(), &val); + CHECK(safe_strto32(test.oblique().features(i).id().value(), &val)) + << "Invalid feature ID: [" << test.oblique().features(i).id().value() + << "]"; feature_num_.push_back(val); feature_weights_.push_back(test.oblique().weights(i)); } @@ -97,7 +100,8 @@ int32 ObliqueInequalityDecisionNodeEvaluator::Decide( MatchingValuesDecisionNodeEvaluator::MatchingValuesDecisionNodeEvaluator( const decision_trees::MatchingValuesTest& test, int32 left, int32 right) : BinaryDecisionNodeEvaluator(left, right) { - safe_strto32(test.feature_id().id().value(), &feature_num_); + CHECK(safe_strto32(test.feature_id().id().value(), &feature_num_)) + << "Invalid feature ID: [" << test.feature_id().id().value() << "]"; for (const auto& val : test.value()) { values_.push_back(val.float_value()); } diff --git a/tensorflow/contrib/tensorboard/db/BUILD b/tensorflow/contrib/tensorboard/db/BUILD index 3f6b4cdc9ad10f5089f28af35a8be408918c7f90..6507546ee9f81108add181a9c83064c9860005e2 100644 --- a/tensorflow/contrib/tensorboard/db/BUILD +++ b/tensorflow/contrib/tensorboard/db/BUILD @@ -106,6 +106,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:png_internal", "//tensorflow/core:protos_all_cc", ], ) diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index adda0b758b172f5e80c165e4b28dbdbecef2ba16..fc0d22d112efcccd1a3be6388d36478cf2076ff5 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -3,7 +3,7 @@ # and provide TensorRT operators and converter package. # APIs are meant to change over time. -package(default_visibility = ["//tensorflow:__subpackages__"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 @@ -11,7 +11,6 @@ exports_files(["LICENSE"]) load( "//tensorflow:tensorflow.bzl", - "py_test", "tf_cc_test", "tf_copts", "tf_cuda_library", @@ -20,6 +19,7 @@ load( "tf_gen_op_libs", "tf_gen_op_wrapper_py", ) +load("//tensorflow:tensorflow.bzl", "cuda_py_tests") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") @@ -33,11 +33,13 @@ tf_cuda_cc_test( size = "small", srcs = ["tensorrt_test.cc"], tags = [ - "manual", - "notap", + "no_windows", + "nomac", ], deps = [ + "//tensorflow/core:gpu_init", "//tensorflow/core:lib", + "//tensorflow/core:stream_executor", "//tensorflow/core:test", "//tensorflow/core:test_main", ] + if_tensorrt([ @@ -83,10 +85,12 @@ cc_library( copts = tf_copts(), visibility = ["//visibility:public"], deps = [ + ":test_utils", + ":trt_allocator", + ":trt_conversion", ":trt_logging", ":trt_plugins", ":trt_resources", - ":trt_conversion", ":utils", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib_proto_parsing", @@ -119,7 +123,6 @@ tf_cuda_library( tf_gen_op_wrapper_py( name = "trt_engine_op", - gen_locally = True, deps = [ ":trt_engine_op_op_lib", ":trt_logging", @@ -156,6 +159,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":tf_trt_integration_test_base", ":trt_convert_py", ":trt_ops_py", "//tensorflow/python:errors", @@ -181,11 +185,17 @@ py_library( ], ) +# TODO(aaroey): this wrapper has been causing troubles of double linking, so +# either get rid of it, or split to make it contain minimum dependencies. tf_py_wrap_cc( name = "wrap_conversion", srcs = ["trt_conversion.i"], copts = tf_copts(), + swig_includes = [ + "//tensorflow/python:platform/base.i", + ], deps = [ + ":test_utils", ":trt_conversion", ":trt_engine_op_kernel", "//third_party/python_runtime:headers", @@ -195,17 +205,16 @@ tf_py_wrap_cc( tf_cuda_library( name = "trt_resources", srcs = [ - "resources/trt_allocator.cc", "resources/trt_int8_calibrator.cc", "resources/trt_resource_manager.cc", ], hdrs = [ - "resources/trt_allocator.h", "resources/trt_int8_calibrator.h", "resources/trt_resource_manager.h", "resources/trt_resources.h", ], deps = [ + ":trt_allocator", ":trt_logging", ":utils", "//tensorflow/core:framework_headers_lib", @@ -216,6 +225,34 @@ tf_cuda_library( ]), ) +tf_cuda_library( + name = "trt_allocator", + srcs = ["resources/trt_allocator.cc"], + hdrs = ["resources/trt_allocator.h"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), +) + +tf_cc_test( + name = "trt_allocator_test", + size = "small", + srcs = ["resources/trt_allocator_test.cc"], + tags = [ + "no_windows", + "nomac", + ], + deps = [ + ":trt_allocator", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + # Library for the node-level conversion portion of TensorRT operation creation tf_cuda_library( name = "trt_conversion", @@ -231,6 +268,8 @@ tf_cuda_library( ], deps = [ ":segment", + ":test_utils", + ":trt_allocator", ":trt_plugins", ":trt_logging", ":trt_resources", @@ -240,7 +279,6 @@ tf_cuda_library( "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", - "//tensorflow/core:gpu_runtime", "//tensorflow/core:framework_lite", "//tensorflow/core:graph", "//tensorflow/core:lib", @@ -275,13 +313,21 @@ tf_cc_test( name = "segment_test", size = "small", srcs = ["segment/segment_test.cc"], + tags = [ + "no_windows", + "nomac", + ], deps = [ ":segment", - "//tensorflow/c:c_api", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + "//tensorflow/core:core_cpu", "//tensorflow/core:lib", + "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) @@ -311,8 +357,9 @@ tf_cuda_cc_test( size = "small", srcs = ["plugin/trt_plugin_factory_test.cc"], tags = [ - "manual", - "notap", + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", ], deps = [ ":trt_plugins", @@ -325,23 +372,61 @@ tf_cuda_cc_test( ]), ) -py_test( +py_library( + name = "tf_trt_integration_test_base", + srcs = ["test/tf_trt_integration_test_base.py"], + deps = [ + ":trt_convert_py", + ":trt_ops_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + ], +) + +cuda_py_tests( name = "tf_trt_integration_test", - srcs = ["test/tf_trt_integration_test.py"], - main = "test/tf_trt_integration_test.py", - srcs_version = "PY2AND3", - tags = [ - "manual", - "notap", + srcs = [ + "test/base_test.py", + # "test/batch_matmul_test.py", + # "test/biasadd_matmul_test.py", + # "test/binary_tensor_weight_broadcast_test.py", # Blocked by trt4 installation + # "test/concatenation_test.py", # Blocked by trt4 installation + "test/const_broadcast_test.py", + "test/multi_connection_neighbor_engine_test.py", + "test/neighboring_engine_test.py", + # "test/unary_test.py", # Blocked by trt4 installation + # "test/vgg_block_nchw_test.py", + # "test/vgg_block_test.py", + "test/memory_alignment_test.py", ], - deps = [ - ":init_py", + additional_deps = [ + ":tf_trt_integration_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", ], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], ) cc_library( name = "utils", + srcs = ["convert/utils.cc"], hdrs = ["convert/utils.h"], + copts = tf_copts(), + deps = [ + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "test_utils", + srcs = ["test/utils.cc"], + hdrs = ["test/utils.h"], + deps = [ + "//tensorflow/core:lib", + "@com_googlesource_code_re2//:re2", + ], ) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 13986127ba879c6787595f1d7b9485352fa0d635..21ec8b0b30c595a1fad01b69bce9b16393742704 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -29,9 +30,7 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" #include "tensorflow/contrib/tensorrt/resources/trt_resources.h" #include "tensorflow/contrib/tensorrt/segment/segment.h" -#include "tensorflow/core/common_runtime/gpu/gpu_id.h" -#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" -#include "tensorflow/core/common_runtime/gpu/process_state.h" +#include "tensorflow/contrib/tensorrt/test/utils.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" @@ -86,27 +85,48 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) { // TODO(jie): Segmentation shouldn't associated with op name. // Split it into a registration for each kernel. static const std::set candidate_ops = { - "Identity", - "Snapshot", - "Const", - "Conv2D", - "MaxPool", - "BiasAdd", - "Relu", - "Add", - "Mul", - "Sub", - "Rsqrt", - "Pad", - "Mean", - "AvgPool", - "ConcatV2", - "DepthwiseConv2dNative", - "FusedBatchNorm", - "FusedBatchNormV2", - // TODO(ben,jie): ... + "Identity", + "Snapshot", + "Const", + "Conv2D", + "MaxPool", + "BiasAdd", + "Relu", + "Add", + "Mul", + "Sub", + "Rsqrt", + "Pad", + "Mean", + "AvgPool", + "ConcatV2", + "DepthwiseConv2dNative", + "FusedBatchNorm", + "FusedBatchNormV2", + "Div", + "RealDiv", + "Rsqrt", + "Reciprocal", + "Exp", + "Log", + "Sqrt", + "Abs", + "Neg", +#if NV_TENSORRT_MAJOR > 3 + "MatMul", + "BatchMatMul", + "Softmax", + "Minimum", + "Maximum", + "TopKV2", + "Sum", + "Prod", + "Max", + "Min", +#endif + // TODO(ben,jie): ... }; - // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h) + // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.cc) return (candidate_ops.count(node->type_string()) || PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); } @@ -152,7 +172,7 @@ tensorflow::Status ConvertCalibGraphToInferGraph( "Need to run graph with calibration data first!"); } if (cres->calibrator_) { - cres->calibrator_->setDone(); + cres->calibrator_->waitAndSetDone(); cres->thr_->join(); const auto& calibration_table = cres->calibrator_->getCalibrationTableAsString(); @@ -168,26 +188,50 @@ tensorflow::Status ConvertCalibGraphToInferGraph( "Can't get TRTCalibrator from resource manager!"); } cres->Unref(); - calib_rm->Cleanup(container_name); + TF_RETURN_IF_ERROR(calib_rm->Cleanup(container_name)); } } return tensorflow::Status::OK(); } -// Entry function from Python. tensorflow::Status ConvertGraphDefToTensorRT( const tensorflow::GraphDef& graph_def, const std::vector& output_names, size_t max_batch_size, size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, int precision_mode, int minimum_segment_size, bool is_dyn_op, int max_cached_engines, std::vector cached_engine_batches) { - // optimization pass + // Create GrapplerItem. tensorflow::grappler::GrapplerItem item; item.fetch = output_names; item.graph = graph_def; - // grappler requires a virtual cluster with a proper GPU device - // in order to calculate flops>0 or fails with FATAL - // We add numbers from a Pascal card here to have flops>0 + + // TODO(aaroey): we should have used single machine cluster like the + // following, but the problem is then wrap_conversion will depend on + // direct_session and cause double linking problems. To fix this we need to + // fix or get rid of the swig dependency. Here we use VirtualCluster + // as a work around, and we need to create a session to initialize the + // underlying device before calling this method. +#if 0 + // Create single machine cluster. Note that this will create a session and + // initialize the gpu devices. + const int num_cpu_cores = + tensorflow::grappler::GetNumAvailableLogicalCPUCores(); + const int num_gpus = tensorflow::grappler::GetNumAvailableGPUs(); + VLOG(2) << "cpu_cores: " << num_cpu_cores; + VLOG(2) << "gpus: " << num_gpus; + const int timeout_s = 60 * 10; + std::unique_ptr cluster( + new tensorflow::grappler::SingleMachine( + timeout_s, num_cpu_cores, num_gpus)); + // These settings are the defaults in tensorflow/python/grappler/cluster.py. + cluster->DisableDetailedStats(true); + cluster->AllowSoftPlacement(true); + cluster->SetNumWarmupSteps(10); + TF_RETURN_IF_ERROR(cluster->Provision()); +#else + // Create virtual cluster. Grappler requires a virtual cluster with a proper + // GPU device in order to calculate flops>0 or fails with FATAL in dbg mode. + // We add numbers from a Pascal card here to have flops>0. tensorflow::DeviceProperties device_properties; device_properties.set_type("GPU"); device_properties.mutable_environment()->insert({"architecture", "6"}); @@ -196,47 +240,43 @@ tensorflow::Status ConvertGraphDefToTensorRT( std::unique_ptr cluster( new tensorflow::grappler::VirtualCluster( {{"/GPU:0", device_properties}})); +#endif - // single machine - int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores(); - int num_gpus = tensorflow::grappler::GetNumAvailableGPUs(); - VLOG(2) << "cpu_cores: " << num_cpu_cores; - VLOG(2) << "gpus: " << num_gpus; + // Create RewriterConfig. tensorflow::RewriterConfig rw_cfg; - // use only const folding and layout for the time being since new optimizers - // break the graph for us + // TODO(aaroey): use only const folding and layout for the time being since + // new optimizers break the graph for trt. rw_cfg.add_optimizers("constfold"); rw_cfg.add_optimizers("layout"); - rw_cfg.set_meta_optimizer_iterations(tensorflow::RewriterConfig::ONE); + auto optimizer = rw_cfg.add_custom_optimizers(); + optimizer->set_name("TensorRTOptimizer"); + auto& parameters = *(optimizer->mutable_parameter_map()); + parameters["minimum_segment_size"].set_i(minimum_segment_size); + parameters["max_batch_size"].set_i(max_batch_size); + parameters["is_dynamic_op"].set_b(is_dyn_op); + parameters["max_workspace_size_bytes"].set_i(max_workspace_size_bytes); + TF_RETURN_IF_ERROR(GetPrecisionModeName( + precision_mode, parameters["precision_mode"].mutable_s())); + parameters["maximum_cached_engines"].set_i(max_cached_engines); + if (!cached_engine_batches.empty()) { + auto list = parameters["cached_engine_batches"].mutable_list(); + for (const int batch : cached_engine_batches) { + list->add_i(batch); + } + } + + // Run optimizer. tensorflow::grappler::MetaOptimizer meta_opt(nullptr, rw_cfg); - tensorflow::GraphDef gdef; - TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, &gdef)); - item.graph = gdef; - - // AJ refactoring shape inference through grappler/GraphProperties. - tensorflow::grappler::GraphProperties static_graph_properties(item); - TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); - // Build full graph - ConversionParams cp; - cp.input_graph_def = &gdef; - cp.output_names = &output_names; - cp.max_batch_size = max_batch_size; - cp.output_graph_def = new_graph_def; - cp.precision_mode = precision_mode; - cp.is_dyn_op = is_dyn_op; - cp.max_cached_engines = max_cached_engines; - cp.cached_engine_batches = cached_engine_batches; - cp.minimum_segment_size = minimum_segment_size; - cp.graph_properties = &static_graph_properties; - cp.max_workspace_size_bytes = max_workspace_size_bytes; + TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, new_graph_def)); + if (VLOG_IS_ON(5)) { std::fstream f; f.open("TRTConversionInput.pb", std::fstream::out | std::fstream::binary | std::fstream::trunc); - f << gdef.SerializeAsString(); + f << new_graph_def->SerializeAsString(); f.close(); } - return ConvertAfterShapes(cp); + return Status::OK(); } // Function to get subsegment information structure. @@ -247,23 +287,23 @@ tensorflow::Status GetEngineInfo( const std::unordered_map& node_map, const std::vector& reverse_topo_order, EngineInfo* info) { - std::vector subgraph_node_ids; + std::vector subgraph_node_ids; // Topologically sorted node ids. + std::set subgraph_node_names = segment_nodes; + std::set added_const_node_ids; // Used to prevent double insertion. std::set segment_devices; - int input_port = 0; - int output_port = 0; // Map from src_node_name+port to the unique port numbers of the TRT op, where // the src_node_name is the name of the source node of the input/output // edge, thus there must not be any duplicates since source nodes of // input/output edges must be in different split of the graph. // TODO(aaroey): consider using node id and port instead. - std::unordered_map created_edges; + // TODO(aaroey): using topo order instead of reverting reverse topo order. + std::unordered_map input_to_engine_port, output_to_engine_port; for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend(); ++it) { const auto& node_name = (*it)->name(); - if (segment_nodes.count(node_name) == 0) continue; - auto node = node_map.at(node_name); + auto node = *it; auto node_device = node->requested_device(); if (!node_device.empty()) { segment_devices.insert(node_device); @@ -275,59 +315,94 @@ tensorflow::Status GetEngineInfo( << " neither have requested device nor assigned device"; } } - int node_id = node->id(); + const int node_id = node->id(); subgraph_node_ids.push_back(node_id); + // Create input connections. for (const auto edge : node->in_edges()) { auto input_node = edge->src(); - if (segment_nodes.count(input_node->name()) == 0) { - // Add constant input node into the segment. We don't care if it has - // other output edges going into other engines or TF nodes. Since we add - // it only to the subsegment node list, not the subsegment itself, it - // won't be removed from the graph. If it doesn't have any edges, TF - // will prune it out. - if (input_node->type_string() == "Const") { - subgraph_node_ids.push_back(input_node->id()); - } else if (!edge->IsControlEdge() && !input_node->IsSource()) { - string s(input_node->name()); - StrAppend(&s, ":", edge->src_output()); - VLOG(1) << "Input edge = " << s; - int port = input_port; - if (created_edges.count(s)) { - port = created_edges.at(s); - } else { - created_edges.insert({s, port}); - input_port++; - } - info->connections.emplace_back(input_node->name(), input_node->id(), - edge->src_output(), node_name, node_id, - edge->dst_input(), true, port); + if (input_node->IsSource() || segment_nodes.count(input_node->name())) { + continue; + } + if (edge->IsControlEdge()) { + // Control input. + info->connections.emplace_back(input_node->name(), input_node->id(), + node_name, node_id, + /*input_edge=*/true); + } else if (input_node->type_string() == "Const") { + // Add constant data input nodes into the segment graphdef (thus also in + // the engine). We don't care if it has other output edges going into + // other engines or TF nodes. Since we add it only to the segment + // graphdef, not the segment itself, it won't be removed from the graph. + // If it doesn't have any edges, TF will prune it out. + // + // Note that the segmenter already ensure that the constant data input + // is valid and suppported by the engine. + if (!added_const_node_ids.insert(input_node->id()).second) { + // Already added before. + continue; } + VLOG(1) << "Adding const node " << input_node->name(); + QCHECK(subgraph_node_names.insert(input_node->name()).second); + // Since we already add (duplicate) the const input node to the segment + // graphdef, it's now not a data dependency any more, but to make the + // dependency correct we still add a control dependency. + info->connections.emplace_back(input_node->name(), input_node->id(), + node_name, node_id, + /*input_edge=*/true); + } else { + // Non-const data input. + int port = Graph::kControlSlot - 1; + // Use the source non-segment node name/port as key. + const string s = StrCat(input_node->name(), ":", edge->src_output()); + VLOG(1) << "Input edge = " << s; + if (input_to_engine_port.count(s)) { + port = input_to_engine_port.at(s); + } else { + port = input_to_engine_port.size(); + input_to_engine_port.insert({s, port}); + } + info->connections.emplace_back( + input_node->name(), input_node->id(), edge->src_output(), node_name, + node_id, edge->dst_input(), /*input_edge=*/true, port); } } + // Create output connections. for (const auto edge : node->out_edges()) { auto output_node = edge->dst(); - if (segment_nodes.count(output_node->name()) == 0 && - !edge->IsControlEdge() && !output_node->IsSink()) { - string s(node_name); - StrAppend(&s, ":", edge->src_output()); + if (output_node->IsSink() || segment_nodes.count(output_node->name())) { + continue; + } + if (edge->IsControlEdge()) { + // Control output. + info->connections.emplace_back(output_node->name(), output_node->id(), + node_name, node_id, + /*input_edge=*/false); + } else { + // Data output. + int port = Graph::kControlSlot - 1; + // Use the source segment node name/port as key. + const string s = StrCat(node_name, ":", edge->src_output()); VLOG(1) << "Output edge = " << s; - int port = output_port; - if (created_edges.count(s)) { - port = created_edges.at(s); + if (output_to_engine_port.count(s)) { + port = output_to_engine_port.at(s); } else { - created_edges.insert({s, port}); - output_port++; + port = output_to_engine_port.size(); + output_to_engine_port.insert({s, port}); } - info->connections.emplace_back(output_node->name(), output_node->id(), - edge->dst_input(), node_name, node_id, - edge->src_output(), false, port); + info->connections.emplace_back( + output_node->name(), output_node->id(), edge->dst_input(), + node_name, node_id, edge->src_output(), /*input_edge=*/false, port); } } - } + } // For each segment node in topological order. + // Construct the const nodes first. + subgraph_node_ids.insert(subgraph_node_ids.begin(), + added_const_node_ids.begin(), + added_const_node_ids.end()); TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef( - g, graph_properties, subgraph_node_ids, &info->connections, - &info->segment_graph_def, &info->engine_name)); + g, graph_properties, subgraph_node_names, subgraph_node_ids, + &info->connections, &info->segment_graph_def, &info->engine_name)); // TODO(sami): This should not happen once segmenter is updated. if (segment_devices.size() == 1) { info->device = *segment_devices.begin(); @@ -337,92 +412,137 @@ tensorflow::Status GetEngineInfo( << "but this shouldn't have happened"; info->device = *segment_devices.begin(); } else { - VLOG(1) << "Segment devices size is 0"; + LOG(ERROR) << "Can't find a device placement for the op!"; } return Status::OK(); } -// Function to insert a TRT node into the graph. The graph is not modified if -// the returned status is not ok. -// 'alloc' is only used for creating static engine. -tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, - const std::vector& infos, int pos, +// Helper function to update edge connection from the removed node to the +// engine node. If an outside node is gone, it must have been absorbed into +// an engine node. Find the engine node. +void UpdateToEngineNode(const std::vector& infos, + const size_t my_engine_id, + const std::vector& engine_nodes, + const bool is_input_edge, const string& node_name, + tensorflow::Node** node, int* port) { + for (size_t t = 0; t < infos.size(); ++t) { + if (t == my_engine_id) { + continue; + } + const auto& info = infos.at(t); + for (const auto& eng_conn : info.connections) { + // If the connection being updated is an input connection, the source of + // the connection must be an output connection of another engine. And vise + // versa. + if (is_input_edge == eng_conn.is_input_edge) continue; + if (eng_conn.inside_node_name == node_name && + eng_conn.inside_port == *port) { + *node = CHECK_NOTNULL(engine_nodes[t]); + QCHECK_EQ(info.engine_name, (**node).name()) + << "Engine name mismatch: " << info.engine_name << " vs " + << (**node).name(); + *port = eng_conn.port_number; + return; + } + } + } + LOG(FATAL) << "Node " << (**node).name() << " not found in any engine."; +} + +// Function to insert a TRT engine node into the graph. +// Create engine nodes in the following way: +// 1. Each invocation of CreateTRTNode creates an engine node for infos[pos] +// 2. When an engine node is created, add it into the graph with necessary +// re-wiring. +// 2.1. If the outside connected node is existing, connect the engine +// node to it. +// 2.2. If the outside connected node is gone, it must have been absorted +// into another engine node (which was processed before the processing +// one). Connect to the pre-existing engine node instead. +// 3. In this way, we ensure the graph is topologically sort-able after each +// invocation of CreateTRTNode(). +tensorflow::Status CreateTRTNode(const std::vector& infos, int pos, + int max_batch_size, tensorflow::Graph* graph, nvinfer1::IGpuAllocator* alloc, - int max_batch_size) { + std::vector* engine_nodes) { const auto& info = infos.at(pos); - std::vector out_shapes; - std::vector input_shapes; - std::vector shapes; + TRT_RETURN_IF_TEST_VALUE(StrCat(info.engine_name, ":CreateTRTNode"), "fail"); + std::vector output_shape_protos; + std::vector input_shape_protos; + std::vector input_shapes; std::vector inputs; + std::vector input_nodes; + std::vector control_input_nodes; + std::unordered_set control_input_names; std::vector out_types; - VLOG(1) << "Processing " << info.engine_name; - // Update the shape and data types of input/output nodes, and find all unique - // inputs. + VLOG(1) << "Processing " << info.engine_name; + // Collect needed info for creating the engine node in the graph for (const auto& conn : info.connections) { - if (!conn.is_input_edge) { - // Set the shapes and data types of output edge. - tensorflow::TensorShapeProto out_shape; - // shape of the output node inside segment - conn.inside_shape.AsProto(&out_shape); - if (out_shapes.size() <= conn.port_number) { - out_shapes.resize(conn.port_number + 1); - out_types.resize(conn.port_number + 1); + // Control edges + if (conn.is_control_edge()) { + // Skip control outputs for now. control output info are not needed for + // node creation and will be processed later. + if (!conn.is_input_edge) continue; + + // Rewrire control input if it's not found in original graph. + tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id); + int port = tensorflow::Graph::kControlSlot; + if (!input_node) { + UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true, + conn.outside_node_name, &input_node, &port); + QCHECK_EQ(Graph::kControlSlot, port); } - out_shapes.at(conn.port_number) = out_shape; - out_types.at(conn.port_number) = conn.connection_type; - continue; - } - - // Set the shapes and data types of input edge. - tensorflow::TensorShapeProto in_shape; - conn.outside_shape.AsProto(&in_shape); - if (input_shapes.size() <= conn.port_number) { - input_shapes.resize(conn.port_number + 1); - shapes.resize(conn.port_number + 1); - } - input_shapes.at(conn.port_number) = in_shape; - shapes.at(conn.port_number) = conn.outside_shape; - - string input_node = conn.outside_node_name; - int input_port = conn.outside_port; - bool found_engine = false; - // Rewire the inputs to other engines if they contain original input node. - // Note that we use the information of the engine here, not the information - // of the created TRT nodes, so we're able to find all the connections to - // any other engines beforehand. - for (size_t t = 0; t < infos.size(); ++t) { - if (t == pos) continue; - auto& engine_info = infos.at(t); - for (const auto& eng_conn : engine_info.connections) { - if (eng_conn.is_input_edge) continue; - if (eng_conn.inside_node_name == input_node) { - input_node = engine_info.engine_name; - if (eng_conn.inside_port == input_port) { - input_port = eng_conn.port_number; - found_engine = true; - break; - } - } + if (!control_input_names.insert(input_node->name()).second) { + continue; } - if (found_engine) break; - } - VLOG(1) << "Engine Input " << input_node << ":" << input_port << " -> " - << info.engine_name << ":" << inputs.size(); - // Skip duplicate inputs. - bool new_input = true; - for (const auto& inp : inputs) { - if (inp.node == input_node && inp.index == input_port) { - new_input = false; - break; + control_input_nodes.push_back(input_node); + VLOG(1) << "Engine Control Input " << input_node->name() << " -> " + << info.engine_name; + } else { + // Data edges + if (!conn.is_input_edge) { + // Set the shapes and data types of output edge. + tensorflow::TensorShapeProto out_shape; + // shape of the output node inside segment + conn.inside_shape.AsProto(&out_shape); + if (output_shape_protos.size() <= conn.port_number) { + output_shape_protos.resize(conn.port_number + 1); + out_types.resize(conn.port_number + 1); + } + output_shape_protos.at(conn.port_number) = out_shape; + out_types.at(conn.port_number) = conn.connection_type; + } else { + // Set the shapes and data types of input edge. + tensorflow::TensorShapeProto in_shape; + conn.outside_shape.AsProto(&in_shape); + if (input_shape_protos.size() <= conn.port_number) { + input_shape_protos.resize(conn.port_number + 1); + input_shapes.resize(conn.port_number + 1); + } + input_shape_protos.at(conn.port_number) = in_shape; + input_shapes.at(conn.port_number) = conn.outside_shape; + + // Rewrire data input if it's not found in original graph. + tensorflow::Node* input_node = graph->FindNodeId(conn.outside_id); + int port = conn.outside_port; + if (!input_node) { + UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true, + conn.outside_node_name, &input_node, &port); + } + if (std::find_if( + std::begin(inputs), std::end(inputs), + [input_node, &port](const NodeDefBuilder::NodeOut& inp) { + return inp.node == input_node->name() && inp.index == port; + }) == std::end(inputs)) { + inputs.emplace_back(input_node->name(), port, conn.connection_type); + input_nodes.push_back(CHECK_NOTNULL(input_node)); + VLOG(1) << "Engine Input " << input_node->name() << ":" << port + << " -> " << info.engine_name << ":" << inputs.size() - 1; + } } } - if (new_input) { - inputs.emplace_back(input_node, input_port, conn.connection_type); - } } - - // Build the engine and get its serialized representation. string segment_string; if (info.engine_type == EngineInfo::EngineType::TRTStatic || info.precision_mode == INT8MODE) { @@ -437,8 +557,8 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, TF_RETURN_IF_ERROR(ConvertGraphDefToEngine( info.segment_graph_def, info.precision_mode == INT8MODE ? FP32MODE : info.precision_mode, - max_batch_size, info.max_workspace_size_bytes, shapes, &trt_logger, - alloc, /*calibrator=*/nullptr, &engine, + max_batch_size, info.max_workspace_size_bytes, input_shapes, + &trt_logger, alloc, /*calibrator=*/nullptr, &engine, /*convert_successfully=*/nullptr)); TrtUniquePtrType engine_data(engine->serialize()); segment_string = @@ -454,21 +574,10 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, // TODO(aaroey): use enum instead, and add a helper method to do the // conversion. string prec_string; - switch (info.precision_mode) { - case FP32MODE: - prec_string = "FP32"; - break; - case FP16MODE: - prec_string = "FP16"; - break; - case INT8MODE: - prec_string = "INT8"; - if (!TRTResourceManager::instance()->getManager("TRTCalibration")) { - LOG(ERROR) << "Failed to construct calibration storage"; - } - break; - default: - return tensorflow::errors::OutOfRange("Unknown precision mode"); + TF_RETURN_IF_ERROR(GetPrecisionModeName(info.precision_mode, &prec_string)); + if (info.precision_mode == INT8MODE && + !TRTResourceManager::instance()->getManager("TRTCalibration")) { + LOG(ERROR) << "Failed to construct calibration storage"; } tensorflow::NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp"); if (!info.device.empty()) node_builder.Device(info.device); @@ -480,14 +589,18 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, VLOG(1) << ins; } node_builder.Input(inputs); + for (const string& c : control_input_names) { + node_builder.ControlInput(c); + } + if (info.engine_type == EngineInfo::EngineType::TRTStatic && info.cached_engine_batches.size()) { LOG(WARNING) << "Cached engine batches are ignored for static engines"; } tensorflow::NodeDef trt_node; tensorflow::Status status = - node_builder.Attr("input_shapes", input_shapes) - .Attr("output_shapes", out_shapes) + node_builder.Attr("input_shapes", input_shape_protos) + .Attr("output_shapes", output_shape_protos) .Attr("static_engine", info.engine_type == EngineInfo::EngineType::TRTStatic) .Attr("segment_funcdef_name", @@ -508,34 +621,55 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, // Up until this point, graph is not modified. If we return !status.ok() from // here, this segment will be skipped + // TODO(aaroey): let it return proper error status for the following logic + // instead of checking fail. tensorflow::Node* engine_node = graph->AddNode(trt_node, &status); + (*engine_nodes)[pos] = engine_node; if (!status.ok()) { LOG(ERROR) << "Adding node failed " << status; return status; } + // Add control input and input edges to the engine node. + for (const auto in : control_input_nodes) { + VLOG(1) << "Connecting control edge from " << in->name() << " to " + << engine_node->name(); + graph->AddControlEdge(in, engine_node); + } + VLOG(1) << "input_nodes size = " << input_nodes.size(); + for (int i = 0; i < input_nodes.size(); ++i) { + Node* n = CHECK_NOTNULL(input_nodes[i]); + const auto& in = inputs[i]; + VLOG(1) << "Connecting data edge from " << n->name() << ":" << in.index + << " to " << engine_node->name() << ":" << i; + graph->AddEdge(n, in.index, engine_node, i); + } + // Updates the inputs of output edges destination nodes, and point them to the // engine node. for (auto& conn : info.connections) { - if (conn.is_input_edge) continue; - VLOG(1) << " Updating DBG " << engine_node->name() << " out_port " - << conn.port_number << " out_id " << conn.outside_id - << " name=" << conn.outside_node_name; - auto dst_node = graph->FindNodeId(conn.outside_id); - // dst_node can only be removed if it is an input node of another engine. - // In this case, other engines input edge is updated in nodedef to point to - // this engine. Even though edge doesn't exists in the graph, when it is - // deserialized again, correct edges will be constructed. This is a problem - // of graph->AddNode(). - if (!dst_node) continue; + if (conn.is_input_edge) { + continue; + } + tensorflow::Node* output_node = graph->FindNodeId(conn.outside_id); + int port = conn.outside_port; + if (!output_node) { + UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/false, + conn.outside_node_name, &output_node, &port); + } VLOG(1) << "Updating " << engine_node->name() << ":" << conn.port_number - << " to " << dst_node->name() << ":" << conn.outside_port; - auto new_edge = graph->AddEdge(engine_node, conn.port_number, dst_node, - conn.outside_port); - CHECK(new_edge) << "Adding a new edge failed " << engine_node->name() << ":" - << conn.port_number << " -> " << dst_node->name() << ":" - << conn.outside_port; + << " to " << output_node->name() << ":" << port; + if (conn.is_control_edge()) { + QCHECK_EQ(Graph::kControlSlot, port); + graph->AddControlEdge(engine_node, output_node); + } else { + auto new_edge = + graph->AddEdge(engine_node, conn.port_number, output_node, port); + QCHECK(new_edge) << "Adding a new edge failed " << engine_node->name() + << ":" << conn.port_number << " -> " + << output_node->name() << ":" << conn.outside_port; + } } - return status; + return Status::OK(); } // Function to construct a funcdef from the segment and add it to the graph. @@ -596,7 +730,9 @@ tensorflow::Status RegisterSegmentFunctionToFunctionLibrary( edge->src()->output_type(edge->src_output())); VLOG(1) << " input " << nout.node << ":" << nout.index << " dtype=" << tensorflow::DataTypeString(nout.data_type); - node_builder.Input({nout}); + // nvcc complains that Input() is + // ambiguous, so do not use Input({nout}). + node_builder.Input(nout); TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0)) .Attr("index", i) .Finalize(&nd)); @@ -633,77 +769,42 @@ tensorflow::Status RegisterSegmentFunctionToFunctionLibrary( } std::pair GetDeviceAndAllocator( - ConversionParams& params, EngineInfo& engine) { + const ConversionParams& params, const EngineInfo& engine) { int cuda_device_id = -1; - auto check_device_id = [](int tfid) -> int { - tensorflow::TfGpuId tf_gpu_id(tfid); - CudaGpuId cuda_gpu_id; - Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id); - if (s.ok()) { - VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device " - << cuda_gpu_id.value(); - return cuda_gpu_id.value(); - } - VLOG(2) << "TF GPU with id " << tfid << " do not exist " << s; - return -1; - }; tensorflow::Allocator* dev_allocator = nullptr; - // we need to us PM here since in python path there is no way to get - // to allocators. - // TODO(sami): when grappler devices become available else path will not be - // necessary - auto pm = tensorflow::ProcessState::singleton(); - if (params.cluster) { // get allocator - tensorflow::Device* device = nullptr; - if (params.cluster->GetDeviceSet()) { - device = params.cluster->GetDeviceSet()->FindDeviceByName(engine.device); + if (params.cluster) { + std::vector devices; + if (!engine.device.empty() && params.cluster->GetDeviceSet()) { + DeviceNameUtils::ParsedName parsed_name; + if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name) && + parsed_name.has_id) { + params.cluster->GetDeviceSet()->FindMatchingDevices(parsed_name, + &devices); + } } - if (device) { + if (!devices.empty()) { + if (devices.size() > 1) { + string msg = "Found multiple matching devices using name '"; + StrAppend(&msg, engine.device, "': "); + for (auto d : devices) StrAppend(&msg, d->name(), ", "); + StrAppend(&msg, ". Will get the allocator from first one."); + LOG(WARNING) << msg; + } tensorflow::AllocatorAttributes alloc_attr; - dev_allocator = device->GetAllocator(alloc_attr); - VLOG(1) << "Using allocator " << dev_allocator->Name(); + cuda_device_id = devices[0]->tensorflow_gpu_device_info()->gpu_id; + dev_allocator = devices[0]->GetAllocator(alloc_attr); + VLOG(1) << "Using allocator " << dev_allocator->Name() + << " and cuda_device_id " << cuda_device_id; } else { LOG(WARNING) << "Cluster is set but device '" << engine.device << "' is not found in the cluster"; } - } else { // cluster not found, possibly a python call - VLOG(1) << "Cluster is not set, probably called from python"; - int found_device = 0; - bool try_gpu_ids = true; - // if device is set, try to find the device. Might be a problem for multi - // host case but TensorRT do not support multi host setups yet. - if (!engine.device.empty()) { - DeviceNameUtils::ParsedName parsed_name; - if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name)) { - cuda_device_id = parsed_name.has_id ? parsed_name.id : -1; - } - try_gpu_ids = !parsed_name.has_id; - } - if (try_gpu_ids) { - while (found_device < 100) { - cuda_device_id = check_device_id(found_device); - if (cuda_device_id >= 0) break; - found_device++; - } - } - if (found_device == 100) { - LOG(ERROR) << " Can't find a GPU device to work with. Please " - "instantiate a session to initialize devices"; - return std::make_pair(cuda_device_id, dev_allocator); - } - LOG(WARNING) - << "Can't determine the device, constructing an allocator at device " - << found_device; - tensorflow::GPUOptions gpuoptions; - // this will be a noop if device is already initialized - gpuoptions.set_allow_growth(true); - tensorflow::TfGpuId tf_gpu_id(found_device); - dev_allocator = pm->GetGPUAllocator(gpuoptions, tf_gpu_id, 1); } return std::make_pair(cuda_device_id, dev_allocator); } // Entry function from optimization pass. +// TODO(aaeory): parameter should use pointer type. tensorflow::Status ConvertAfterShapes(ConversionParams& params) { // Convert graphdef to graph. tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(), @@ -721,7 +822,8 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { segment_options.minimum_segment_size = params.minimum_segment_size; tensorflow::tensorrt::segment::SegmentNodesVector initial_segments; TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph( - &graph, IsTensorRTCandidate, segment_options, &initial_segments)); + &graph, IsTensorRTCandidate, InputEdgeValidator(*params.graph_properties), + OutputEdgeValidator(), segment_options, &initial_segments)); if (initial_segments.size() > 1) { VLOG(0) << "MULTIPLE tensorrt candidate conversion: " << initial_segments.size(); @@ -789,6 +891,8 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { LOG(ERROR) << "Couldn't get current device: " << cudaGetErrorString(err); } VLOG(1) << "Current cuda device is " << old_cuda_device; + std::vector engine_nodes; + engine_nodes.resize(engine_segments.size()); for (int i = 0; i < engine_segments.size(); ++i) { auto& engine = engine_segments.at(i); // Partition the workspace size by the average of node ratio and segment @@ -801,7 +905,7 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { // The allocator is used to build the engine. The build and the built engine // will be destroyed after we get the serialized engine string, so it's fine // to use unique_ptr here. - std::unique_ptr alloc; + std::unique_ptr alloc; auto device_alloc = GetDeviceAndAllocator(params, engine); int cuda_device_id = 0; if (device_alloc.first >= 0) { @@ -812,19 +916,21 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) { LOG(WARNING) << "Can't identify the cuda device. Running on device 0 "; } cudaSetDevice(cuda_device_id); - auto status = CreateTRTNode(&graph, engine_segments, i, alloc.get(), - params.max_batch_size); + auto status = CreateTRTNode(engine_segments, i, params.max_batch_size, + &graph, alloc.get(), &engine_nodes); // If status is ok, we successfully added the node to the graph and can // remove segment ops. Otherwise graph is not modified. + const string msg = StrCat("Engine ", engine.engine_name, + " creation for segment ", i, ", composed of ", + converted_segments.at(i).first.size(), " nodes"); if (status.ok()) { + LOG(INFO) << msg << " succeeded."; for (auto node_name : converted_segments.at(i).first) { graph.RemoveNode(node_map.at(node_name)); } } else { // Graph is not modified. - LOG(WARNING) << "Engine creation for segment " << i << ", composed of " - << converted_segments.at(i).first.size() << " nodes failed: " - << status << ". Skipping..."; + LOG(WARNING) << msg << " failed: " << status << ". Skipping..."; } } cudaSetDevice(old_cuda_device); diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 146b9c7344b0a9c2b3ec87b395e9b1096dbef06c..35fa590254137d62fea868882d5c225848829ca1 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -16,11 +16,13 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" #include +#include #include #include #include #include #include +#include #include #include @@ -49,15 +51,34 @@ limitations under the License. #if GOOGLE_TENSORRT #include "tensorrt/include/NvInfer.h" -// Check if the types are equal. Cast to int first so that failure log message -// would work! -#define CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2) +// Check if the types are equal. Cast to int first so that failure log message +// would work! +#define TFTRT_CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2) + +#define TFTRT_INTERNAL_ERROR_AT_NODE(node) \ + do { \ + return tensorflow::errors::Internal( \ + "TFTRT::", __FUNCTION__, "failed to add TRT layer, at: ", node); \ + } while (0) + +#define TFTRT_RETURN_ERROR_IF_FALSE(status, node) \ + do { \ + if (status == false) { \ + TFTRT_INTERNAL_ERROR_AT_NODE(node); \ + } \ + } while (0) + +#define TFTRT_RETURN_ERROR_IF_NULLPTR(ptr, node) \ + do { \ + if (ptr == nullptr) { \ + TFTRT_INTERNAL_ERROR_AT_NODE(node); \ + } \ + } while (0) namespace tensorflow { namespace tensorrt { namespace convert { using ::tensorflow::str_util::Split; - using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; @@ -75,13 +96,163 @@ inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, case tensorflow::DataType::DT_HALF: *trt_dtype = nvinfer1::DataType::kHALF; break; +#if NV_TENSORRT_MAJOR > 3 + case tensorflow::DataType::DT_INT32: + *trt_dtype = nvinfer1::DataType::kINT32; + break; +#endif default: return tensorflow::errors::InvalidArgument( - "Unsupported data type " + tensorflow::DataTypeString(tf_dtype)); + "Unsupported data type ", tensorflow::DataTypeString(tf_dtype)); } return tensorflow::Status::OK(); } +void GetInputProperties(const grappler::GraphProperties& graph_properties, + const Node* outside_node, const int out_port, + PartialTensorShape* shape, + tensorflow::DataType* dtype) { + if (graph_properties.HasOutputProperties(outside_node->name())) { + auto output_params = + graph_properties.GetOutputProperties(outside_node->name()); + auto out_shape = output_params.at(out_port); + *dtype = out_shape.dtype(); + *shape = out_shape.shape(); + } else { + VLOG(0) << "Unknown output shape" << outside_node->name(); + *dtype = outside_node->output_type(out_port); + } +} + +void GetOutputProperties(const grappler::GraphProperties& graph_properties, + const Node* outside_node, const int in_port, + PartialTensorShape* shape, + tensorflow::DataType* dtype) { + if (graph_properties.HasInputProperties(outside_node->name())) { + auto input_params = + graph_properties.GetInputProperties(outside_node->name()); + auto in_shape = input_params.at(in_port); + *dtype = in_shape.dtype(); + *shape = in_shape.shape(); + } else { + *dtype = outside_node->input_type(in_port); + } +} + +tensorflow::Status ValidateInputProperties(const PartialTensorShape& shape, + const tensorflow::DataType dtype, + nvinfer1::DataType* trt_dtype) { + // TODO(aaroey): some of these checks also apply to IsTensorRTCandidate(), so + // put them there instead. + TF_RETURN_IF_ERROR(ConvertDType(dtype, trt_dtype)); + if (shape.dims() < 0) { + return tensorflow::errors::InvalidArgument("Input tensor rank is unknown."); + } + if (shape.dims() > 9) { + return tensorflow::errors::OutOfRange( + "Input tensor rank is greater than 8."); + } + for (int d = 1; d < shape.dims(); ++d) { + if (shape.dim_size(d) < 0) { + return tensorflow::errors::InvalidArgument( + "Input tensor has a unknown non-batch dimemension at dim ", d); + } + } + return Status::OK(); +} + +// Return whether or not the broadcast is feasible; +bool TensorRTGetBroadcastShape(const nvinfer1::Dims& operand_l, + const bool operand_l_is_tensor, + const nvinfer1::Dims& operand_r, + const bool operand_r_is_tensor, + nvinfer1::Dims* operand_l_new_shape, + nvinfer1::Dims* operand_r_new_shape) { + // *************************************************************************** + // TensorRT Elementwise op supports broadcast but requires both tensor to be + // of Identical rank + // + // We consider case of: + // 1. operand_l to be a Tensor & operand_r to be a Const; + // 2. operand_l to be a Tensor & operand_r to be a Tensor; + // note: const op const (constant folding) should fallback to TensorFlow + // + // broadcast scheme: + // T: 1 3 5 (tensor would not have batch dimension) + // W: 1 1 3 1 (weight would have all explicit dimensions) + // i. fill in explicit dimensions + // -> T: -1 1 3 5 (we put a -1 for batch dimension) + // -> W: 1 1 3 1 + // ii. compare broadcast feasibility + // + // We cannot support the following since TensorRT does not allow manipulation + // on batch dimension, we cannot generate output with proper shape + // T: 3 5 1 + // W: 1 1 1 1 3 5 1 + // -> T: 1 1 1 -1 3 5 1 + // -> W: 1 1 1 1 3 5 1 + // *************************************************************************** + const int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1; + const size_t element_size = sizeof(operand_l.d[0]); + + // fill in dimensions + int l_s[max_nb_dims]; + std::fill(l_s, l_s + max_nb_dims, 1); + int l_d = operand_l_is_tensor ? operand_l.nbDims + 1 : operand_l.nbDims; + int r_s[max_nb_dims]; + std::fill(r_s, r_s + max_nb_dims, 1); + int r_d = operand_r_is_tensor ? operand_r.nbDims + 1 : operand_r.nbDims; + + int max_d = std::max(l_d, r_d); + std::memcpy(l_s + max_d - operand_l.nbDims, operand_l.d, + operand_l.nbDims * element_size); + std::memcpy(r_s + max_d - operand_r.nbDims, operand_r.d, + operand_r.nbDims * element_size); + + // set -1 for batch dimension, since batch size is not supposed to be + // broadcasted + if (operand_l_is_tensor) { + if (max_d != l_d) { // if broadcast beyond batch dimension, fail + return false; + } + l_s[0] = -1; + } + if (operand_r_is_tensor) { + if (max_d != r_d) { // if broadcast beyond batch dimension, fail + return false; + } + r_s[0] = -1; + } + + // compare broadcast feasibility + for (int i = max_d - 1; i >= 0; i--) { + if ((l_s[i] != r_s[i]) && (l_s[i] != 1) && (r_s[i] != 1)) { + return false; + } + } + + // output new TensorRT Dimension (stripping the batch dimension) + operand_l_new_shape->nbDims = max_d - 1; + std::memcpy(operand_l_new_shape->d, l_s + 1, (max_d - 1) * element_size); + operand_r_new_shape->nbDims = max_d - 1; + std::memcpy(operand_r_new_shape->d, r_s + 1, (max_d - 1) * element_size); + + return true; +} + +inline bool DimsEqual(const nvinfer1::Dims& dim_l, + const nvinfer1::Dims& dim_r) { + if (dim_l.nbDims != dim_r.nbDims) { + return false; + } + for (int i = 0; i < dim_l.nbDims; i++) { + if (dim_l.d[i] != dim_r.d[i]) { + return false; + } + } + return true; +} + inline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) { nvinfer1::Dims dims; dims.nbDims = tensor.dims(); @@ -91,7 +262,7 @@ inline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) { return dims; } -inline int64_t GetShapeSize(nvinfer1::Dims shape) { +inline int64_t GetShapeSize(const nvinfer1::Dims& shape) { // Returns total number of elements in shape int64_t count = 1; for (int d = 0; d < shape.nbDims; ++d) { @@ -104,7 +275,7 @@ static std::vector> CreateSamePadding( const nvinfer1::DimsHW& stride, const nvinfer1::DimsHW& kernel, const std::vector& input_dims) { std::vector> padding(input_dims.size()); - CHECK_EQ((size_t)stride.nbDims, input_dims.size()); // TODO(jie): N+C? NC+? + CHECK_EQ(stride.nbDims, input_dims.size()); // TODO(jie): N+C? NC+? for (size_t i = 0; i < input_dims.size(); ++i) { // Formula to calculate the padding @@ -134,6 +305,7 @@ string GetCommonNameScope(const string& op_name_a, const string& op_name_b) { return op_name_a.substr(0, last_scope_separator); } +// Class to convert TF weight to TRT weight. class TRT_ShapedWeights { public: TRT_ShapedWeights(tensorflow::DataType type, const void* values, @@ -145,12 +317,14 @@ class TRT_ShapedWeights { explicit TRT_ShapedWeights(tensorflow::DataType type) : shape_(), type_(type), values_(nullptr), empty_weight_flag_(true) {} + // TODO(aaroey): use rvalue reference. TRT_ShapedWeights(const TRT_ShapedWeights& rhs) : shape_(rhs.shape_), type_(rhs.type_), values_(rhs.values_), empty_weight_flag_(rhs.empty_weight_flag_) {} + // TODO(aaroey): use GetShapeSize() instead. int64_t count() const { int64_t c = 1; for (int i = 0; i < shape_.nbDims; i++) c *= shape_.d[i]; @@ -168,6 +342,7 @@ class TRT_ShapedWeights { const void* GetValues() const { return values_; } + // TODO(aaroey): get rid of this method. void SetValues(const void* values) { values_ = values; } size_t size_bytes() const { @@ -178,10 +353,12 @@ class TRT_ShapedWeights { // Default converter operator nvinfer1::Weights() const { return GetWeightsForTRT(); } + // TODO(aaroey): make these private. nvinfer1::Dims shape_; tensorflow::DataType type_; private: + // TODO(aaroey): this should not be const as it's always from TRTWeightStore. const void* values_; bool empty_weight_flag_; }; @@ -192,6 +369,7 @@ class TRT_TensorOrWeights { : tensor_(tensor), weights_(DT_FLOAT), variant_(TRT_NODE_TENSOR) {} explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights) : tensor_(nullptr), weights_(weights), variant_(TRT_NODE_WEIGHTS) {} + // TODO(aaroey): use rvalue reference. TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs) : tensor_(rhs.tensor_), weights_(rhs.weights_), variant_(rhs.variant_) {} ~TRT_TensorOrWeights() {} @@ -200,19 +378,19 @@ class TRT_TensorOrWeights { bool is_weights() const { return variant_ == TRT_NODE_WEIGHTS; } nvinfer1::ITensor* tensor() { - CHECK_EQ(is_tensor(), true); + CHECK(is_tensor()); return tensor_; } const nvinfer1::ITensor* tensor() const { - CHECK_EQ(is_tensor(), true); + CHECK(is_tensor()); return tensor_; } TRT_ShapedWeights& weights() { - CHECK_EQ(is_weights(), true); + CHECK(is_weights()); return weights_; } const TRT_ShapedWeights& weights() const { - CHECK_EQ(is_weights(), true); + CHECK(is_weights()); return weights_; } nvinfer1::Dims shape() const { @@ -236,21 +414,25 @@ class TFAttrs { attrs_.insert({attr.first, &attr.second}); } } - bool count(string key) const { return attrs_.count(key); } - tensorflow::AttrValue const* at(string key) const { + + bool count(const string& key) const { return attrs_.count(key); } + + tensorflow::AttrValue const* at(const string& key) const { if (!attrs_.count(key)) { LOG(FATAL) << "Attribute not found: " << key; } return attrs_.at(key); } + template T get(const string& key) const; + template T get(const string& key, const T& default_value) const { return attrs_.count(key) ? this->get(key) : default_value; } - std::vector GetAllAttrKey() { + std::vector GetAllAttrKeys() const { std::vector attr_list; for (const auto& attr_item : attrs_) { attr_list.emplace_back(attr_item.first); @@ -285,15 +467,6 @@ std::vector TFAttrs::get>(const string& key) const { auto attr = this->at(key)->list().s(); return std::vector(attr.begin(), attr.end()); } -template <> -nvinfer1::Dims TFAttrs::get(const string& key) const { - auto values = this->get>(key); - nvinfer1::Dims dims; - dims.nbDims = values.size(); - std::copy(values.begin(), values.end(), dims.d); - // Note: No dimension type information is included - return dims; -} template <> nvinfer1::DataType TFAttrs::get(const string& key) const { @@ -319,10 +492,11 @@ bool TFAttrs::get(const string& key) const { } // TODO(jie): reorder4 & reorder2 should be merged? +// TODO(aaroey): fix the order of parameters. template -void Reorder4(nvinfer1::DimsNCHW shape, const T* idata, - nvinfer1::DimsNCHW istrides, T* odata, - nvinfer1::DimsNCHW ostrides) { +void Reorder4(const nvinfer1::DimsNCHW& shape, const T* idata, + const nvinfer1::DimsNCHW& istrides, T* odata, + const nvinfer1::DimsNCHW& ostrides) { for (int n = 0; n < shape.n(); ++n) { for (int c = 0; c < shape.c(); ++c) { for (int h = 0; h < shape.h(); ++h) { @@ -337,12 +511,13 @@ void Reorder4(nvinfer1::DimsNCHW shape, const T* idata, } template -void Reorder2(nvinfer1::DimsHW shape, const T* idata, nvinfer1::DimsHW istrides, - T* odata, nvinfer1::DimsHW ostrides) { +void Reorder2(const nvinfer1::DimsHW& shape, const T* idata, + const nvinfer1::DimsHW& istrides, T* odata, + const nvinfer1::DimsHW& ostrides) { for (int h = 0; h < shape.h(); ++h) { for (int w = 0; w < shape.w(); ++w) { odata[h * ostrides.h() + w * ostrides.w()] = - idata[h * ostrides.h() + w * ostrides.w()]; + idata[h * istrides.h() + w * istrides.w()]; } } } @@ -350,16 +525,17 @@ void Reorder2(nvinfer1::DimsHW shape, const T* idata, nvinfer1::DimsHW istrides, // TODO(jie): fallback to tensorflow!! void ReorderCKtoKC(const TRT_ShapedWeights& iweights, TRT_ShapedWeights* oweights) { - int c = iweights.shape_.d[0]; - int k = iweights.shape_.d[1]; + const int c = iweights.shape_.d[0]; + const int k = iweights.shape_.d[1]; oweights->shape_.d[0] = k; oweights->shape_.d[1] = c; - nvinfer1::DimsHW istrides = {1, k}; - nvinfer1::DimsHW ostrides = {c, 1}; + const nvinfer1::DimsHW istrides = {1, k}; + const nvinfer1::DimsHW ostrides = {c, 1}; switch (iweights.type_) { case tensorflow::DataType::DT_FLOAT: { Reorder2({k, c}, static_cast(iweights.GetValues()), istrides, + // TODO(aaroey): get rid of all the const_cast like this. static_cast(const_cast(oweights->GetValues())), ostrides); break; @@ -382,21 +558,24 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, TRT_ShapedWeights* oweights, int num_groups) { CHECK_EQ(iweights.type_, oweights->type_); CHECK_EQ(iweights.size_bytes(), oweights->size_bytes()); - int r = iweights.shape_.d[0]; - int s = iweights.shape_.d[1]; - // TRT requires GKcRS, while TF depthwise has RSCK - // where c=1, C=G + // K indexes over output channels, C over input channels, and R and S over the + // height and width of the convolution + const int r = iweights.shape_.d[0]; + const int s = iweights.shape_.d[1]; + // TRT requires GKcRS, while TF depthwise has RSCK where c=1, C=G VLOG(2) << "num_groups: " << num_groups; - int c = iweights.shape_.d[2] / num_groups; + const int c = iweights.shape_.d[2] / num_groups; VLOG(2) << "c" << iweights.shape_.d[2] << " then " << c; - int k = iweights.shape_.d[3] * num_groups; + const int k = iweights.shape_.d[3] * num_groups; VLOG(2) << "k" << iweights.shape_.d[3] << " then " << k; + VLOG(2) << "r" << iweights.shape_.d[0] << " then " << r; + VLOG(2) << "s" << iweights.shape_.d[1] << " then " << s; oweights->shape_.d[0] = k / num_groups; oweights->shape_.d[1] = c * num_groups; oweights->shape_.d[2] = r; oweights->shape_.d[3] = s; - nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k}; - nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1}; + const nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k}; + const nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1}; switch (iweights.type_) { case tensorflow::DataType::DT_FLOAT: { Reorder4({k, c, r, s}, static_cast(iweights.GetValues()), @@ -428,11 +607,14 @@ using OpConverter = std::vector*)>; class Converter { + // TODO(aaroey): fix the order of members. std::unordered_map trt_tensors_; std::unordered_map op_registry_; OpConverter plugin_converter_; nvinfer1::INetworkDefinition* trt_network_; std::list> temp_bufs_; + // TODO(aaroey): inline the definition of TRTWeightStore here, and add APIs to + // operate the stored weights instead of operating it directly. TRTWeightStore* weight_store_; bool fp16_; void register_op_converters(); @@ -440,7 +622,7 @@ class Converter { std::vector* inputs) { for (auto const& input_name : node_def.input()) { /************************************************************************* - * TODO(jie) handle case 1) here + * TODO(jie): handle case 1) here. * Normalizes the inputs and extracts associated metadata: * 1) Inputs can contain a colon followed by a suffix of characters. * That suffix may be a single number (e.g. inputName:1) or several @@ -454,6 +636,7 @@ class Converter { if (input_name[0] == '^') continue; string name = input_name; auto first = name.find_first_of(':'); + // TODO(aaroey): why removing the colon but not the zero? A bug? if (first != string::npos && first + 2 == name.size() && name[first + 1] == '0') name.erase(first); @@ -462,12 +645,13 @@ class Converter { if (trt_tensors_.count(name)) { inputs->push_back(trt_tensors_.at(name)); } else { - string str("Node "); - StrAppend(&str, node_def.name(), " should have an input named '", name, + // TODO(aaroey): this should not happen, make it a CHECK. + // TODO(aaroey): use StrCat for pattern like this. + string msg("Node "); + StrAppend(&msg, node_def.name(), " should have an input named '", name, "' but it is not available"); - LOG(WARNING) << "input: " << name << " not available for node at " - << node_def.name(); - return tensorflow::errors::InvalidArgument(str); + LOG(ERROR) << msg; + return tensorflow::errors::InvalidArgument(msg); } } return tensorflow::Status::OK(); @@ -488,6 +672,7 @@ class Converter { weights.SetValues(weight_store_->store_.back().data()); return weights; } + // TODO(aaroey): fix all the namings. bool isFP16() { return fp16_; } TRT_ShapedWeights get_temp_weights_like(const TRT_ShapedWeights& weights) { return this->get_temp_weights(weights.type_, weights.shape_); @@ -496,9 +681,10 @@ class Converter { tensorflow::Status convert_node(const tensorflow::NodeDef& node_def) { std::vector inputs; TF_RETURN_IF_ERROR(this->get_inputs(node_def, &inputs)); - string op = node_def.op(); + const string& op = node_def.op(); std::vector outputs; if (PluginFactoryTensorRT::GetInstance()->IsPlugin(op)) { + // TODO(aaroey): plugin_converter_ is not set, fix it. TF_RETURN_IF_ERROR(plugin_converter_(*this, node_def, inputs, &outputs)); } else { if (!op_registry_.count(op)) { @@ -509,7 +695,7 @@ class Converter { TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs)); } for (size_t i = 0; i < outputs.size(); ++i) { - TRT_TensorOrWeights output = outputs.at(i); + TRT_TensorOrWeights& output = outputs[i]; // TODO(jie): tf protobuf seems to be omitting the :0 suffix string output_name = node_def.name(); if (i != 0) output_name = StrCat(output_name, ":", i); @@ -527,26 +713,29 @@ class Converter { nvinfer1::INetworkDefinition* network() { return trt_network_; } - TRT_TensorOrWeights get_tensor(string name) { + TRT_TensorOrWeights get_tensor(const string& name) { if (!trt_tensors_.count(name)) { return TRT_TensorOrWeights(nullptr); } return trt_tensors_.at(name); } - bool insert_input_tensor(string name, nvinfer1::ITensor* tensor) { + bool insert_input_tensor(const string& name, nvinfer1::ITensor* tensor) { return trt_tensors_.insert({name, TRT_TensorOrWeights(tensor)}).second; } nvinfer1::ITensor* TransposeTensor(nvinfer1::ITensor* input_tensor, - std::vector order) { - auto dims = input_tensor->getDimensions(); + const std::vector& order) { + const auto dims = input_tensor->getDimensions(); // TODO(jie): change the return to status and properly exit if (order.size() - 1 != size_t(dims.nbDims)) LOG(ERROR) << "Dimension does not match, fail gracefully"; nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor); + if (layer == nullptr) { + return nullptr; + } nvinfer1::Permutation permutation; for (int32_t i = 0; i < dims.nbDims; ++i) { permutation.order[i] = order[i + 1] - 1; @@ -577,13 +766,14 @@ TRT_ShapedWeights ConvertFP32ToFP16(Converter& ctx, } return weights; } + // **************************************************************************** // Constant folding functions // TODO(jie): once optimizer kicks in, we should have done constant folding // there. -//*****************************************************************************/ +// ***************************************************************************** struct LambdaFactory { - enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB }; + enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB, RECIP }; OP_CATEGORY op; template @@ -595,6 +785,8 @@ struct LambdaFactory { } case OP_CATEGORY::NEG: return [](T t) -> T { return -t; }; + case OP_CATEGORY::RECIP: + return [](T t) -> T { return 1.0 / t; }; default: VLOG(2) << "Not supported op for unary: " << static_cast(op); return nullptr; @@ -628,7 +820,6 @@ struct LambdaFactory { VLOG(2) << "LAMBDA VAL : " << val; return l + val; }; - // Return [val](T l)-> T {return l+val;}; case OP_CATEGORY::SUB: return [val](T l) -> T { VLOG(2) << "LAMBDA VAL : " << val; @@ -688,11 +879,13 @@ std::function LambdaFactory::unary() { } case OP_CATEGORY::NEG: return [](Eigen::half t) -> Eigen::half { return -t; }; + // TODO(aaroey): can we support RECIP? default: VLOG(2) << "Not supported op for unary: " << static_cast(op); return nullptr; } } + tensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights, TRT_ShapedWeights* oweights, LambdaFactory unary_op) { @@ -738,6 +931,7 @@ tensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l, if (iweights_l.count() != iweights_r.count()) { // We only supports broadcast of RankZero if (iweights_l.count() == 1) { + // TODO(aaroey): Remove loggings like this. VLOG(2) << "I bet it is not working!" << (*inp_l); std::transform(inp_r, inp_r + iweights_r.count(), oup, binary_op.broadcast_l(*inp_l)); @@ -790,117 +984,21 @@ tensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l, return tensorflow::Status::OK(); } -tensorflow::Status ConstantFoldUnary( - Converter& ctx, const tensorflow::NodeDef& node_def, - const std::vector& inputs, - std::vector* outputs) { - TRT_ShapedWeights weights_input = inputs.at(0).weights(); - - // Allocate output weights - TRT_ShapedWeights weights_output = ctx.get_temp_weights_like(weights_input); - - // FIXME assume type matches input weights - // Get trt type & shape - // Maybe this part has to be moved into the block of rsqrt later - // Check type consistency - CHECK_EQ(weights_input.type_, - TFAttrs(node_def).get("T")); - - LambdaFactory unary_op; - if (node_def.op() == "Rsqrt") { - // Compute rsqrt - unary_op.op = LambdaFactory::OP_CATEGORY::RSQRT; - auto ret = UnaryCompute(weights_input, &weights_output, unary_op); - // Pass the output - if (ret == tensorflow::Status::OK()) { - outputs->push_back(TRT_TensorOrWeights(weights_output)); - } - return ret; - } else { - return tensorflow::errors::Unimplemented("Binary op not supported: " + - node_def.op()); - } -} - -// TODO(jie,ben) broadcast is needed yet not implemented -// Let's get the simple stuff working first. Maybe we should fall back to TF -// approach for constant folding -tensorflow::Status ConstantFoldBinary( - Converter& ctx, const tensorflow::NodeDef& node_def, - const std::vector& inputs, - std::vector* outputs) { - TRT_ShapedWeights weights_input_l = inputs.at(0).weights(); - TRT_ShapedWeights weights_input_r = inputs.at(1).weights(); - - // Check type consistency - CHECK_EQ(weights_input_l.type_, weights_input_r.type_); - - if (weights_input_l.shape_.nbDims != weights_input_r.shape_.nbDims) - return tensorflow::errors::Unimplemented( - "Binary op implicit broadcast not supported: " + node_def.op()); - - // TODO(jie): constant fold should really fall back to TF. - int num_dims = weights_input_l.shape_.nbDims; - nvinfer1::Dims output_shape; - output_shape.nbDims = num_dims; - VLOG(2) << "nb_dims: " << num_dims - << ", the other: " << weights_input_r.shape_.nbDims; - for (int i = 0; i < num_dims; i++) { - if (weights_input_l.shape_.d[i] == weights_input_r.shape_.d[i]) { - output_shape.d[i] = weights_input_l.shape_.d[i]; - } else if (weights_input_l.shape_.d[i] == 1 || - weights_input_r.shape_.d[i] == 1) { - output_shape.d[i] = - std::max(weights_input_l.shape_.d[i], weights_input_r.shape_.d[i]); - } else { - return tensorflow::errors::Unimplemented( - "Binary op with incompatible shape at, " + node_def.op()); - } - VLOG(2) << "left: " << weights_input_l.shape_.d[i] - << "right: " << weights_input_r.shape_.d[i] - << "output: " << output_shape.d[i]; - } - - // FIXME assume type matches input weights - // Get trt type & shape - TFAttrs attrs(node_def); - // Maybe this part has to be moved into the block of rsqrt later - tensorflow::DataType dtype = attrs.get("T"); - - // Allocate output weights - TRT_ShapedWeights weights_output = ctx.get_temp_weights(dtype, output_shape); - - LambdaFactory binary_op; - if (node_def.op() == "Sub") { - binary_op.op = LambdaFactory::OP_CATEGORY::SUB; - } else if (node_def.op() == "Mul") { - binary_op.op = LambdaFactory::OP_CATEGORY::MUL; - } else if (node_def.op() == "Add") { - binary_op.op = LambdaFactory::OP_CATEGORY::ADD; - } else { - return tensorflow::errors::Unimplemented("Binary op not supported: " + - node_def.op()); - } - auto ret = BinaryCompute(weights_input_l, weights_input_r, &weights_output, - binary_op); - - // Pass the output - if (ret == tensorflow::Status::OK()) { - outputs->push_back(TRT_TensorOrWeights(weights_output)); - } - - return ret; -} - // TODO(jie): broadcast is needed yet not implemented. // Only implemented channel wise for the time being tensorflow::Status BinaryTensorOpWeight( Converter& ctx, const tensorflow::NodeDef& node_def, const nvinfer1::ITensor* tensor, TRT_ShapedWeights weights, - std::vector* outputs) { - // FIXME assume type matches input weights - // Get trt type & shape - // Maybe this part has to be moved into the block of rsqrt later + bool swapped_inputs, std::vector* outputs) { + // tensor is the left operand while weights is the right operand; + // when swapped_inputs set to true, those two are swapped. + // TODO(aaroey): use a set. + if (node_def.op() != "Sub" && node_def.op() != "Add" && + node_def.op() != "Mul" && node_def.op() != "Div" && + node_def.op() != "RealDiv") { + return tensorflow::errors::Unimplemented( + "op not supported: " + node_def.op() + ", at: " + node_def.name()); + } // Check type consistency nvinfer1::DataType ttype; @@ -910,6 +1008,12 @@ tensorflow::Status BinaryTensorOpWeight( auto dims_w = weights.shape_; auto dims_t = tensor->getDimensions(); + // TODO(jie): addScale checks for input tensor dimension + if (dims_t.nbDims != 3) { + return tensorflow::errors::InvalidArgument( + "addScale requires tensor with rank 3, " + node_def.name()); + } + // default to element-wise auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; @@ -980,6 +1084,7 @@ tensorflow::Status BinaryTensorOpWeight( permutation[dims_t.nbDims] = 1; tensor = ctx.TransposeTensor(const_cast(tensor), permutation); + TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name()); } else { return tensorflow::errors::InvalidArgument( "Transpose cannot be applied, " + node_def.name()); @@ -997,11 +1102,35 @@ tensorflow::Status BinaryTensorOpWeight( // Maybe I should do a switch if (node_def.op() == "Sub") { - TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights); - LambdaFactory unary_op; - unary_op.op = LambdaFactory::OP_CATEGORY::NEG; - TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op)); - shift_weights = neg_weights; + if (swapped_inputs) { + shift_weights = weights; + nvinfer1::IUnaryLayer* layer = + ctx.network()->addUnary(*const_cast(tensor), + nvinfer1::UnaryOperation::kNEG); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + tensor = layer->getOutput(0); + } else { + TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights); + LambdaFactory unary_op; + unary_op.op = LambdaFactory::OP_CATEGORY::NEG; + TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op)); + shift_weights = neg_weights; + } + } else if (node_def.op() == "Div" || node_def.op() == "RealDiv") { + if (swapped_inputs) { + scale_weights = weights; + nvinfer1::IUnaryLayer* layer = + ctx.network()->addUnary(*const_cast(tensor), + nvinfer1::UnaryOperation::kRECIP); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + tensor = layer->getOutput(0); + } else { + TRT_ShapedWeights recip_weights = ctx.get_temp_weights_like(weights); + LambdaFactory unary_op; + unary_op.op = LambdaFactory::OP_CATEGORY::RECIP; + TF_RETURN_IF_ERROR(UnaryCompute(weights, &recip_weights, unary_op)); + scale_weights = recip_weights; + } } else if (node_def.op() == "Mul") { scale_weights = weights; } else if (node_def.op() == "Add") { @@ -1014,11 +1143,13 @@ tensorflow::Status BinaryTensorOpWeight( nvinfer1::IScaleLayer* layer = ctx.network()->addScale( *const_cast(tensor), scale_mode, shift_weights, scale_weights, power_weights); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); // transpose back dimension if (permutation_flag) { output_tensor = ctx.TransposeTensor(output_tensor, permutation); + TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name()); } // Pass the output @@ -1042,20 +1173,31 @@ tensorflow::Status ConvertConv2DHelper( if (data_format == "NHWC") { tensor = ctx.TransposeTensor(const_cast(tensor), {0, 3, 1, 2}); + TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name()); h_index = 1; w_index = 2; // TODO(jie): transpose it } // tensor after transpose (NCHW) - auto tensor_dim = tensor->getDimensions(); + const auto tensor_dim = tensor->getDimensions(); int num_groups = group; - if (num_groups == 0) // depthwise convolution - num_groups = tensor_dim.d[0]; + if (num_groups == 0) num_groups = tensor_dim.d[0]; // depthwise convolution VLOG(2) << "groups count: " << num_groups; TRT_ShapedWeights weights_rsck = inputs.at(1).weights(); + + VLOG(2) << "weight shape: " << weights_rsck.shape_.nbDims; + for (int i = 0; i < weights_rsck.shape_.nbDims; i++) { + VLOG(2) << weights_rsck.shape_.d[i]; + } + + if (weights_rsck.shape_.nbDims != 4) { + return tensorflow::errors::Internal( + "Conv2D expects kernel of dimension 4, at: " + node_def.name()); + } + if (ctx.isFP16()) { weights_rsck = ConvertFP32ToFP16(ctx, inputs.at(1).weights()); } @@ -1063,18 +1205,22 @@ tensorflow::Status ConvertConv2DHelper( TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck); ReorderRSCKToKCRS(weights_rsck, &weights, num_groups); TRT_ShapedWeights biases(weights.type_); - int noutput = weights.shape_.d[0] * num_groups; + const int noutput = weights.shape_.d[0] * num_groups; nvinfer1::DimsHW kernel_size; kernel_size.h() = weights.shape_.d[2]; kernel_size.w() = weights.shape_.d[3]; + VLOG(2) << "RSCK: "; + for (int i = 0; i < 4; i++) { + VLOG(2) << " " << weights.shape_.d[i]; + } VLOG(2) << "kernel size: " << kernel_size.h() << ", " << kernel_size.w(); // TODO(jie): stride. (NHWC/NCHW) - auto tf_stride = attrs.get>("strides"); + const auto tf_stride = attrs.get>("strides"); VLOG(2) << "h_INDEX" << h_index << ", w_index " << w_index; VLOG(2) << "stride!!!: " << tf_stride[0] << tf_stride[1] << tf_stride[2] << tf_stride[3]; - nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); + const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); std::vector> padding; // TODO(jie): padding. @@ -1102,6 +1248,7 @@ tensorflow::Status ConvertConv2DHelper( *const_cast(tensor), nvinfer1::DimsHW(padding[0].first, padding[1].first), nvinfer1::DimsHW(padding[0].second, padding[1].second)); + TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name()); padding = {{0, 0}, {0, 0}}; tensor = pad_layer->getOutput(0); auto dim_after = tensor->getDimensions(); @@ -1112,6 +1259,7 @@ tensorflow::Status ConvertConv2DHelper( nvinfer1::IConvolutionLayer* layer = ctx.network()->addConvolution(*const_cast(tensor), noutput, kernel_size, weights, biases); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); layer->setStride(stride); layer->setPadding({padding[0].first, padding[1].first}); @@ -1126,6 +1274,7 @@ tensorflow::Status ConvertConv2DHelper( if (data_format == "NHWC") { // TODO(jie): transpose it back! output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1}); + TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name()); } else { VLOG(2) << "NCHW !!!!"; } @@ -1147,35 +1296,91 @@ tensorflow::Status ConvertConv2DHelper( node_def.name()); } +// Helper function converts input into tensor with shape specified by dims. +bool PrepareTensorForShape(Converter& ctx, const TRT_TensorOrWeights& input, + const nvinfer1::Dims& dims, + const nvinfer1::ITensor** tensor) { + if (input.is_tensor()) { + if (DimsEqual(input.shape(), dims)) { + *tensor = input.tensor(); + } else { + nvinfer1::IShuffleLayer* layer = ctx.network()->addShuffle( + *const_cast(input.tensor())); + if (layer != nullptr) { + layer->setReshapeDimensions(dims); + *tensor = layer->getOutput(0); + } else { + return false; + } + } + } else { +#if NV_TENSORRT_MAJOR > 3 + nvinfer1::IConstantLayer* layer = + ctx.network()->addConstant(dims, input.weights()); + if (layer != nullptr) { + *tensor = layer->getOutput(0); + } else { + return false; + } +#else + return false; +#endif + } + return true; +} + tensorflow::Status BinaryTensorOpTensor( Converter& ctx, const tensorflow::NodeDef& node_def, - const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r, + const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r, std::vector* outputs) { static const std::unordered_map ops{ {"Add", nvinfer1::ElementWiseOperation::kSUM}, {"Mul", nvinfer1::ElementWiseOperation::kPROD}, {"Sub", nvinfer1::ElementWiseOperation::kSUB}, {"Div", nvinfer1::ElementWiseOperation::kDIV}, + {"RealDiv", nvinfer1::ElementWiseOperation::kDIV}, + {"Minimum", nvinfer1::ElementWiseOperation::kMIN}, + {"Maximum", nvinfer1::ElementWiseOperation::kMAX}, }; - // FIXME assume type matches input weights + const nvinfer1::ITensor* tensor_l; + const nvinfer1::ITensor* tensor_r; + + nvinfer1::Dims dim_l; + nvinfer1::Dims dim_r; + + if (!TensorRTGetBroadcastShape(operand_l.shape(), operand_l.is_tensor(), + operand_r.shape(), operand_r.is_tensor(), + &dim_l, &dim_r)) { + return tensorflow::errors::InvalidArgument( + "Binary op broadcast scheme not supported by TensorRT op: " + + node_def.op() + ", at: " + node_def.name()); + } + + TFTRT_RETURN_ERROR_IF_FALSE( + PrepareTensorForShape(ctx, operand_l, dim_l, &tensor_l), node_def.name()); + TFTRT_RETURN_ERROR_IF_FALSE( + PrepareTensorForShape(ctx, operand_r, dim_r, &tensor_r), node_def.name()); + // get trt type & shape TFAttrs attrs(node_def); // maybe this part has to be moved into the block of rsqrt later nvinfer1::DataType dtype = attrs.get("T"); // check type consistency - CHECK_EQ_TYPE(tensor_l->getType(), dtype); - CHECK_EQ_TYPE(tensor_r->getType(), dtype); + TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype); + TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype); auto op_pair = ops.find(node_def.op()); - if (op_pair == ops.end()) + if (op_pair == ops.end()) { return tensorflow::errors::Unimplemented( - "binary op: " + node_def.op() + - " not supported at: " + node_def.name()); + "binary op: ", node_def.op(), " not supported at: ", node_def.name()); + } nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise( + // TODO(aaroey): will tensor_l/tensor_r get modified? *const_cast(tensor_l), *const_cast(tensor_r), op_pair->second); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); @@ -1202,7 +1407,7 @@ tensorflow::Status ConvertPlugin(Converter& ctx, // passing attributes // TODO(jie): support more general attribute TFAttrs attrs(node_def); - auto attr_key_vector = attrs.GetAllAttrKey(); + auto attr_key_vector = attrs.GetAllAttrKeys(); for (auto attr_key : attr_key_vector) { // TODO(jie): support only list of float for toy example here. auto data = attrs.get>(attr_key); @@ -1223,29 +1428,6 @@ tensorflow::Status ConvertPlugin(Converter& ctx, return tensorflow::Status::OK(); } -tensorflow::Status ConvertPlaceholder( - Converter& ctx, const tensorflow::NodeDef& node_def, - const std::vector& inputs, - std::vector* outputs) { - VLOG(2) << "Placeholder should have been replace already"; - return tensorflow::errors::Unimplemented("cannot convert Placeholder op"); - // OK this make sense since we are supposed to replace it with input - TFAttrs attrs(node_def); - nvinfer1::DataType dtype = attrs.get("dtype"); - nvinfer1::Dims dims = attrs.get("shape"); - - dims.nbDims--; - for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1]; - - nvinfer1::ITensor* output = - ctx.network()->addInput(node_def.name().c_str(), dtype, dims); - if (!output) { - return tensorflow::errors::InvalidArgument("Failed to create Input layer"); - } - outputs->push_back(TRT_TensorOrWeights(output)); - return tensorflow::Status::OK(); -} - tensorflow::Status ConvertConv2D(Converter& ctx, const tensorflow::NodeDef& node_def, const std::vector& inputs, @@ -1271,65 +1453,64 @@ tensorflow::Status ConvertPool(Converter& ctx, int h_index = 2; int w_index = 3; - auto data_format = attrs.get("data_format"); + const auto data_format = attrs.get("data_format"); if (data_format == "NHWC") { h_index = 1; w_index = 2; tensor = ctx.TransposeTensor(const_cast(tensor), {0, 3, 1, 2}); - } else { - VLOG(2) << "NCHW !!!!"; + TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name()); } + nvinfer1::PoolingType type; - // TODO(jie): support other pooling type - if (node_def.op() == "MaxPool") + if (node_def.op() == "MaxPool") { type = nvinfer1::PoolingType::kMAX; - else if (node_def.op() == "AvgPool") + } else if (node_def.op() == "AvgPool") { type = nvinfer1::PoolingType::kAVERAGE; - else - return tensorflow::errors::Unimplemented("Only supports Max pool"); + } else { + return tensorflow::errors::Unimplemented("Unsupported pool type: ", + node_def.op()); + } - // TODO(jie): NCHW - auto tf_stride = attrs.get>("strides"); - nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); + const auto tf_stride = attrs.get>("strides"); + const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); - auto tf_kernel = attrs.get>("ksize"); - nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]); + const auto tf_kernel = attrs.get>("ksize"); + const nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]); auto tensor_dim = tensor->getDimensions(); std::vector> padding; - // TODO(jie): padding. - if (attrs.get("padding") == "SAME") { + const string padding_type = attrs.get("padding"); + if (padding_type == "SAME") { // This is NCHW tensor with no batch dimension. // 1 -> h // 2 -> w padding = CreateSamePadding( stride, ksize, {static_cast(tensor_dim.d[1]), static_cast(tensor_dim.d[2])}); - } else if (attrs.get("padding") == "VALID") { - // No padding for valid padding here - VLOG(2) << "No padding added for VALID padding in pool" << node_def.name(); + } else if (padding_type == "VALID") { padding = {{0, 0}, {0, 0}}; } else { - return tensorflow::errors::Unimplemented( - "Current MaxPool cannot support padding other than SAME"); + return tensorflow::errors::Unimplemented("Unsupported padding type: ", + padding_type); } if (padding[0].first != padding[0].second || padding[1].first != padding[1].second) { - // TODO(jie): handle asymmetric padding VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second << padding[1].first << padding[1].second; auto pad_layer = ctx.network()->addPadding( *const_cast(tensor), nvinfer1::DimsHW(padding[0].first, padding[1].first), nvinfer1::DimsHW(padding[0].second, padding[1].second)); + TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name()); padding = {{0, 0}, {0, 0}}; tensor = pad_layer->getOutput(0); } nvinfer1::IPoolingLayer* layer = ctx.network()->addPooling( *const_cast(tensor), type, ksize); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); layer->setStride(stride); layer->setPadding({padding[0].first, padding[1].first}); @@ -1337,10 +1518,8 @@ tensorflow::Status ConvertPool(Converter& ctx, nvinfer1::ITensor* output_tensor = layer->getOutput(0); if (data_format == "NHWC") { - // TODO(jie): transpose it back! output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1}); - } else { - VLOG(2) << "NCHW !!!!"; + TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name()); } outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); @@ -1353,6 +1532,7 @@ tensorflow::Status ConvertActivation( const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); nvinfer1::IActivationLayer* layer = ctx.network()->addActivation( *const_cast(tensor), nvinfer1::ActivationType::kRELU); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); @@ -1363,40 +1543,61 @@ tensorflow::Status ConvertScale(Converter& ctx, const std::vector& inputs, std::vector* outputs) { if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) + !inputs.at(1).is_weights()) { return tensorflow::errors::Unimplemented( - "Only supports tensor op weight for now, at " + node_def.name()); - // Implement tensor binaryOp weight [channel wise] for now; - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + "ConvertScale only supports tensorweight: ", node_def.name()); + } + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); TRT_ShapedWeights weights = inputs.at(1).weights(); if (ctx.isFP16()) { weights = ConvertFP32ToFP16(ctx, inputs.at(1).weights()); } TRT_ShapedWeights empty_weights(weights.type_); - TFAttrs attrs(node_def); - // Transpose NHWC - auto data_format = attrs.get("data_format"); + const auto data_format = attrs.get("data_format"); + int channel_index; + const auto dims = tensor->getDimensions(); if (data_format == "NHWC") { - tensor = ctx.TransposeTensor(const_cast(tensor), - {0, 3, 1, 2}); - // TODO(jie): transpose it + // 1). NHWC is really N+C + channel_index = dims.nbDims - 1; // batch dimension is implicit here! } else { - VLOG(2) << "NCHW !!!!"; + // 2). NCHW is really N+CHW + channel_index = dims.nbDims - 3; // batch dimension is implicit here! } - auto dims = tensor->getDimensions(); - VLOG(2) << "tensor dimensions: " << dims.nbDims; - for (int i = 0; i < dims.nbDims; i++) { - VLOG(2) << "i: " << dims.d[i]; + nvinfer1::Permutation permutation; + for (int32_t i = 0; i < dims.nbDims; ++i) { + permutation.order[i] = i; } - dims = weights.shape_; - VLOG(2) << "tensor dimensions: " << dims.nbDims; - for (int i = 0; i < dims.nbDims; i++) { - VLOG(2) << "i: " << dims.d[i]; + + if (channel_index >= 0) { + permutation.order[0] = channel_index; + permutation.order[channel_index] = 0; + } else { + return tensorflow::errors::Unimplemented( + "TFTRT::BiasAdd cannot apply on batch dimension, at ", node_def.name()); + } + + // TensorRT addScale requires input to be of rank 3, we need to apply + // transpose as well as reshape + if (channel_index != 0 || dims.nbDims != 3) { + nvinfer1::IShuffleLayer* shuffle_layer = + ctx.network()->addShuffle(*const_cast(tensor)); + TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name()); + nvinfer1::Dims reshape_dims; + reshape_dims.nbDims = 3; + reshape_dims.d[0] = 0; // 0 copy from the input + reshape_dims.d[1] = dims.nbDims >= 2 ? 0 : 1; // 0 copy from the input + reshape_dims.d[2] = dims.nbDims >= 3 ? -1 : 1; // -1 infer from the rest + if (channel_index != 0) { + // maybe we do not need this check. concerned about TRT optimization + shuffle_layer->setFirstTranspose(permutation); + } + shuffle_layer->setReshapeDimensions(reshape_dims); + tensor = shuffle_layer->getOutput(0); } nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL; @@ -1407,14 +1608,26 @@ tensorflow::Status ConvertScale(Converter& ctx, nvinfer1::IScaleLayer* layer = ctx.network()->addScale(*const_cast(tensor), mode, weights, empty_weights, empty_weights); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); - if (data_format == "NHWC") { - // TODO(jie): transpose it back! - output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1}); - } else { - VLOG(2) << "NCHW !!!!"; + + // restore transpose & reshape + if (channel_index != 0 || dims.nbDims != 3) { + nvinfer1::IShuffleLayer* shuffle_layer = ctx.network()->addShuffle( + *const_cast(output_tensor)); + TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name()); + nvinfer1::Dims reshape_dims = dims; + int tmp = reshape_dims.d[channel_index]; + reshape_dims.d[channel_index] = reshape_dims.d[0]; + reshape_dims.d[0] = tmp; + shuffle_layer->setReshapeDimensions(reshape_dims); + if (channel_index != 0) { + shuffle_layer->setSecondTranspose(permutation); + } + output_tensor = shuffle_layer->getOutput(0); } + outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } @@ -1431,11 +1644,13 @@ tensorflow::Status ConvertConst(Converter& ctx, // Create shaped weights as output tensorflow::Tensor tensor; - if (!tensor.FromProto(weights_tensor)) - return tensorflow::errors::Internal("Cannot parse weight tensor proto: " + + if (!tensor.FromProto(weights_tensor)) { + return tensorflow::errors::Internal("Cannot parse weight tensor proto: ", node_def.name()); + } TRT_ShapedWeights weights(dtype); + // TODO(aaroey): we should choose the array using dtype and shape. if (!weights_tensor.float_val().empty()) { VLOG(2) << "SCALAR!!!" << node_def.name(); nvinfer1::Dims scalar_shape; @@ -1443,22 +1658,16 @@ tensorflow::Status ConvertConst(Converter& ctx, VLOG(2) << "dimensions: " << tensor.dims(); VLOG(2) << "size: " << weights_tensor.float_val_size(); scalar_shape = GetTensorShape(tensor); + VLOG(2) << "details: "; for (int i = 0; i < scalar_shape.nbDims; i++) VLOG(2) << scalar_shape.d[i]; - if (GetShapeSize(scalar_shape) != weights_tensor.float_val_size()) { - if (weights_tensor.float_val_size() == 1 || - scalar_shape.d[0] == weights_tensor.float_val_size()) { - scalar_shape.nbDims = 1; - // no dimension provided. flatten it - scalar_shape.d[0] = weights_tensor.float_val_size(); - scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL; - } else { - LOG(WARNING) << "Broadcast on weights only supports kCHANNEL and" - << " kUNIFORM, at: " << node_def.name(); - string err_str("Broadcast method is not supported for '"); - StrAppend(&err_str, node_def.name(), "' of type ", node_def.op()); - return tensorflow::errors::InvalidArgument(err_str); - } + if (GetShapeSize(scalar_shape) != weights_tensor.float_val_size() && + weights_tensor.float_val_size() != 1) { + LOG(ERROR) << "Broadcast on weights only supports kCHANNEL and" + << " kUNIFORM, at: " << node_def.name(); + string err_str("Broadcast method is not supported for '"); + StrAppend(&err_str, node_def.name(), "' of type ", node_def.op()); + return tensorflow::errors::InvalidArgument(err_str); } } else { VLOG(2) << "Dimensions: " << tensor.dims(); @@ -1468,39 +1677,42 @@ tensorflow::Status ConvertConst(Converter& ctx, scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL; for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) { scalar_shape.d[i] = 0; - scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL; } } + // TODO(aaroey): use GetShapeSize(). size_t len_data = tensorflow::DataTypeSize(dtype); for (int i = 0; i < scalar_shape.nbDims; i++) len_data *= scalar_shape.d[i]; ctx.weight_store()->store_.push_back(std::vector(len_data)); void* dst = static_cast(&(ctx.weight_store()->store_.back()[0])); - std::vector tensor_data( - weights_tensor.float_val().begin(), - weights_tensor.float_val() - .end()); // make a local copy first to flatten - memcpy(dst, tensor_data.data(), len_data); // store into weight store + if (weights_tensor.float_val_size() == 1) { + std::fill_n((float*)dst, GetShapeSize(scalar_shape), + *weights_tensor.float_val().begin()); + } else { + // TODO(aaroey): get rid of this copy as RepeatedField is always + // contiguous make a local copy first to flatten doesn't have to be + // contiguous + std::vector tensor_data(weights_tensor.float_val().begin(), + weights_tensor.float_val().end()); + memcpy(dst, tensor_data.data(), len_data); // store into weight store + } + VLOG(2) << "create shape details: "; + for (int i = 0; i < scalar_shape.nbDims; i++) VLOG(2) << scalar_shape.d[i]; weights = TRT_ShapedWeights(dtype, dst, scalar_shape); } else if (!weights_tensor.int_val().empty()) { + // TODO(aaroey): this is very similar to the above code for float, merge + // them. VLOG(2) << "int!!!" << node_def.name(); nvinfer1::Dims scalar_shape; if (tensor.dims() > 0) { VLOG(2) << "dimensions: " << tensor.dims(); scalar_shape = GetTensorShape(tensor); - if (GetShapeSize(scalar_shape) != weights_tensor.int_val_size()) { - if (weights_tensor.int_val_size() == 1 || - scalar_shape.d[0] == weights_tensor.int_val_size()) { - scalar_shape.nbDims = 1; - // no dimension provided. flatten it - scalar_shape.d[0] = weights_tensor.int_val_size(); - scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL; - } else { - LOG(WARNING) << "Broadcast on weights only supports kCHANNEL and" - << " kUNIFORM, at: " << node_def.name(); - string err_str("Broadcast method is not supported for '"); - StrAppend(&err_str, node_def.name(), "' of type ", node_def.op()); - return tensorflow::errors::InvalidArgument(err_str); - } + if (GetShapeSize(scalar_shape) != weights_tensor.int_val_size() && + weights_tensor.int_val_size() != 1) { + LOG(WARNING) << "Broadcast on weights only supports kCHANNEL and" + << " kUNIFORM, at: " << node_def.name(); + string err_str("Broadcast method is not supported for '"); + StrAppend(&err_str, node_def.name(), "' of type ", node_def.op()); + return tensorflow::errors::InvalidArgument(err_str); } } else { VLOG(2) << "dimensions: " << tensor.dims(); @@ -1513,23 +1725,30 @@ tensorflow::Status ConvertConst(Converter& ctx, scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL; } } - // we should not have converted //if (ctx.isFP16()) { + // we should not have converted size_t len_data = tensorflow::DataTypeSize(dtype); for (int i = 0; i < scalar_shape.nbDims; i++) len_data *= scalar_shape.d[i]; size_t len_tensor = weights_tensor.int_val_size() * sizeof(int32); len_data = std::max(len_data, len_tensor); ctx.weight_store()->store_.push_back(std::vector(len_data)); void* dst = static_cast(&(ctx.weight_store()->store_.back()[0])); - std::vector tensor_data( - weights_tensor.int_val().begin(), - weights_tensor.int_val().end()); // make a local copy first to flatten - // doesn't have to be contigous - memcpy(dst, tensor_data.data(), len_tensor); // store into weight store + if (weights_tensor.int_val_size() == 1) { + std::fill_n((int*)dst, GetShapeSize(scalar_shape), + *weights_tensor.int_val().begin()); + } else { + // TODO(aaroey): get rid of this copy as RepeatedField is always + // contiguous make a local copy first to flatten doesn't have to be + // contiguous + std::vector tensor_data(weights_tensor.int_val().begin(), + weights_tensor.int_val().end()); + memcpy(dst, tensor_data.data(), len_tensor); // store into weight store + } weights = TRT_ShapedWeights(dtype, dst, scalar_shape); } else if (!weights_tensor.tensor_content().empty()) { - // obsolete method. - // After optimization path, we do not see weights in this format. - // fp16 conversion technically should be needed here. + // obsolete method. + // After optimization path, we do not see weights in this format. + // TODO(aaroey): why? + // fp16 conversion technically should be needed here. VLOG(2) << "TENSOR!!!" << node_def.name(); const auto& content = weights_tensor.tensor_content(); @@ -1543,8 +1762,8 @@ tensorflow::Status ConvertConst(Converter& ctx, content, static_cast(const_cast(weights.GetValues()))); } } else { - return tensorflow::errors::Unimplemented( - "Not supported constant type, at " + node_def.name()); + return tensorflow::errors::Unimplemented("Not supported constant type, at ", + node_def.name()); } // Pass the output outputs->push_back(TRT_TensorOrWeights(weights)); @@ -1563,96 +1782,144 @@ tensorflow::Status ConvertBinary(Converter& ctx, const tensorflow::NodeDef& node_def, const std::vector& inputs, std::vector* outputs) { - if (inputs.size() != 2) + if (inputs.size() != 2) { return tensorflow::errors::FailedPrecondition( - "Binary ops require two tensor input, at " + node_def.name()); - - if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) - return ConstantFoldBinary(ctx, node_def, inputs, outputs); - - if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) - return BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(), - inputs.at(1).weights(), outputs); + "Binary ops require two tensor input, at ", node_def.name()); + } - if (inputs.at(0).is_weights() && inputs.at(1).is_tensor()) - return BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(), - inputs.at(0).weights(), outputs); + // Constant folding should have been done by TensorFlow - if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) - return BinaryTensorOpTensor(ctx, node_def, inputs.at(0).tensor(), - inputs.at(1).tensor(), outputs); + if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) { + return tensorflow::errors::Unimplemented( + "Constant folding is falled back to TensorFlow, binary op received " + "both input as constant at: ", + node_def.name()); + } - return tensorflow::errors::Unknown("Binary op input error, at " + - node_def.name()); + // Try to convert into Scale layer first (for better performance) + // Since scale layer supports restricted broadcast policy and op types, we + // allow failure and try to handle it through Elementwise op + // (BinaryTensorOpTensor) + Status status = tensorflow::Status::OK(); + if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) { + status = BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(), + inputs.at(1).weights(), false, outputs); + } else if (inputs.at(0).is_weights() && inputs.at(1).is_tensor()) { + status = BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(), + inputs.at(0).weights(), true, outputs); +#if NV_TENSORRT_MAJOR == 3 + } else { +#else + } + if ((inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) || !status.ok()) { +#endif + status = BinaryTensorOpTensor(ctx, node_def, inputs.at(0), inputs.at(1), + outputs); + } + return status; } tensorflow::Status ConvertUnary(Converter& ctx, const tensorflow::NodeDef& node_def, const std::vector& inputs, std::vector* outputs) { - if (inputs.size() != 1) + static const std::unordered_map ops{ + {"Neg", nvinfer1::UnaryOperation::kNEG}, + {"Exp", nvinfer1::UnaryOperation::kEXP}, + {"Log", nvinfer1::UnaryOperation::kLOG}, + {"Sqrt", nvinfer1::UnaryOperation::kSQRT}, + {"Abs", nvinfer1::UnaryOperation::kABS}, + {"Reciprocal", nvinfer1::UnaryOperation::kRECIP}, + }; + + if (inputs.size() != 1) { return tensorflow::errors::FailedPrecondition( - "Unary ops require single tensor input, at " + node_def.name()); + "Unary ops require single tensor input, at ", node_def.name()); + } - if (inputs.at(0).is_weights()) - return ConstantFoldUnary(ctx, node_def, inputs, outputs); - else if (inputs.at(0).is_tensor()) +#if NV_TENSORRT_MAJOR == 3 + if (inputs.at(0).is_weights()) { return tensorflow::errors::Unimplemented( - "Unary op for tensor not supported, at " + node_def.name()); + "Constant folding for unary op is not supported", node_def.name()); + } +#endif + + // TODO(jie): check type + const nvinfer1::ITensor* tensor; + TFTRT_RETURN_ERROR_IF_FALSE( + PrepareTensorForShape(ctx, inputs.at(0), inputs.at(0).shape(), &tensor), + node_def.name()); - return tensorflow::errors::Unknown("Binary op input error, at " + - node_def.name()); + nvinfer1::IUnaryLayer* layer; + if (node_def.op() == "Rsqrt") { + layer = ctx.network()->addUnary(*const_cast(tensor), + nvinfer1::UnaryOperation::kSQRT); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + tensor = layer->getOutput(0); + layer = ctx.network()->addUnary(*const_cast(tensor), + nvinfer1::UnaryOperation::kRECIP); + } else if (ops.count(node_def.op()) != 0) { + layer = ctx.network()->addUnary(*const_cast(tensor), + ops.at(node_def.op())); + } else { + return tensorflow::errors::InvalidArgument( + "Binary op: ", node_def.op(), " not supported, at ", node_def.name()); + } + + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); } -tensorflow::Status ConvertReduce(Converter& ctx, - const tensorflow::NodeDef& node_def, - const std::vector& inputs, - std::vector* outputs) { +#if NV_TENSORRT_MAJOR == 3 +tensorflow::Status ConvertReducePool( + Converter& ctx, const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs) { if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) + !inputs.at(1).is_weights()) { return tensorflow::errors::InvalidArgument( - "Input expects tensor and weights, at" + node_def.name()); + "Input expects tensor and weights, at", node_def.name()); + } // Implement tensor binaryOp weight [channel wise] for now; const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - auto dims = tensor->getDimensions(); + const auto dims = tensor->getDimensions(); // Restore implicit batch dimension - int nb_dims = dims.nbDims + 1; + const int nb_dims = dims.nbDims + 1; TRT_ShapedWeights index_list = inputs.at(1).weights(); - TFAttrs attrs(node_def); - // TODO(jie): handle data type. - // Index type here is done through TF type, so I can leverage their - // EnumToDataType for my cast auto index_type = attrs.get("Tidx"); // Only expect to handle INT32 as attributes for now - if (index_type != tensorflow::DataType::DT_INT32) + if (index_type != tensorflow::DataType::DT_INT32) { return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32"); - auto index_list_data = + } + const auto index_list_data = static_cast(const_cast(index_list.GetValues())); - // Hack warning: have to fall back to pool layer since reduce is not in public - // TRT yet. - if (nb_dims != 4) + if (nb_dims != 4) { return tensorflow::errors::InvalidArgument( - "TRT only support reduce on 4 dimensional tensors, at" + + "TRT only support reduce on 4 dimensional tensors, at", node_def.name()); - if (index_list.count() > 2) + } + if (index_list.count() > 2) { return tensorflow::errors::InvalidArgument( - "TRT cannot support reduce on more than 2 dimensions, at" + + "TRT cannot support reduce on more than 2 dimensions, at", node_def.name()); + } std::set idx_set; // We cannot operate on Channel. permutation flag used to transpose tensor int permuted_index = -1; for (int i = 0; i < index_list.count(); i++) { - if (index_list_data[i] == 0) - return tensorflow::errors::InvalidArgument("TRT cannot reduce at 0, at" + + if (index_list_data[i] == 0) { + return tensorflow::errors::InvalidArgument("TRT cannot reduce at 0, at", node_def.name()); + } if (index_list_data[i] == 1) permuted_index = 1; - idx_set.emplace(index_list_data[i]); } @@ -1673,6 +1940,7 @@ tensorflow::Status ConvertReduce(Converter& ctx, // Apply permutation before extracting dimension for pool_kernel tensor = ctx.TransposeTensor(const_cast(tensor), permutation_order); + TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name()); } // Apply permutation before extracting dimension for pool_kernel @@ -1685,34 +1953,104 @@ tensorflow::Status ConvertReduce(Converter& ctx, nvinfer1::IPoolingLayer* layer = ctx.network()->addPooling(*const_cast(tensor), nvinfer1::PoolingType::kAVERAGE, pool_kernel); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); output_tensor = layer->getOutput(0); } else { - return tensorflow::errors::Unimplemented( - "Op not supported " + node_def.op() + " , at " + node_def.name()); + return tensorflow::errors::Unimplemented("Op not supported ", node_def.op(), + " , at ", node_def.name()); } if (permuted_index != -1) { // Apply permutation before extracting dimension for pool_kernel output_tensor = ctx.TransposeTensor( const_cast(output_tensor), permutation_order); + TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name()); } outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } +#elif NV_TENSORRT_MAJOR > 3 +tensorflow::Status ConvertReduce(Converter& ctx, + const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs) { + if (inputs.size() != 2 || !inputs.at(0).is_tensor() || + !inputs.at(1).is_weights()) { + return tensorflow::errors::InvalidArgument( + "Input expects tensor and weights, at", node_def.name()); + } + + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); + TRT_ShapedWeights index_list = inputs.at(1).weights(); + + TFAttrs attrs(node_def); + auto index_type = attrs.get("Tidx"); + + // Only expect to handle INT32 as attributes for now + if (index_type != tensorflow::DataType::DT_INT32) { + return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32"); + } + + const auto keep_dims = attrs.get("keep_dims"); + auto index_list_data = + static_cast(const_cast(index_list.GetValues())); + + int axes = 0; + if (index_list.count() == 0) { + return tensorflow::errors::InvalidArgument( + "TRT cannot support reduce on all (batch) dimensions, at", + node_def.name()); + } else { + for (int i = 0; i < index_list.count(); i++) { + if (index_list_data[i] == 0) { + return tensorflow::errors::InvalidArgument( + "TRT cannot reduce at batch dimension, at", node_def.name()); + } + axes |= (1 << (index_list_data[i] - 1)); + } + } + + nvinfer1::ReduceOperation reduce_operation; + if (node_def.op() == "Sum") { + reduce_operation = nvinfer1::ReduceOperation::kSUM; + } else if (node_def.op() == "Prod") { + reduce_operation = nvinfer1::ReduceOperation::kPROD; + } else if (node_def.op() == "Max") { + reduce_operation = nvinfer1::ReduceOperation::kMAX; + } else if (node_def.op() == "Min") { + reduce_operation = nvinfer1::ReduceOperation::kMIN; + } else if (node_def.op() == "Mean") { + reduce_operation = nvinfer1::ReduceOperation::kAVG; + } else { + return tensorflow::errors::Unimplemented("Op not supported ", node_def.op(), + " , at ", node_def.name()); + } + + nvinfer1::ILayer* layer = + ctx.network()->addReduce(*const_cast(tensor), + reduce_operation, axes, keep_dims); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + + outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + return tensorflow::Status::OK(); +} +#endif tensorflow::Status ConvertPad(Converter& ctx, const tensorflow::NodeDef& node_def, const std::vector& inputs, std::vector* outputs) { + // TODO(aaroey): make a routine for this check and reuse it. if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) + !inputs.at(1).is_weights()) { return tensorflow::errors::InvalidArgument( - "Input expects tensor and weights, at" + node_def.name()); + "Input expects tensor and weights, at", node_def.name()); + } // Implement tensor binaryOp weight [channel wise] for now; const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - auto dims = tensor->getDimensions(); + const auto dims = tensor->getDimensions(); // Restore implicit batch dimension - int nb_dims = dims.nbDims + 1; + const int nb_dims = dims.nbDims + 1; TRT_ShapedWeights pads = inputs.at(1).weights(); @@ -1722,21 +2060,24 @@ tensorflow::Status ConvertPad(Converter& ctx, auto padding_type = attrs.get("Tpaddings"); // TODO(jie): handle data type conversion for TRT? - if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2) + if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2) { return tensorflow::errors::InvalidArgument( - "Pad only supports explicit padding on 4 dimensional tensor, at " + + "Pad only supports explicit padding on 4 dimensional tensor, at ", node_def.name()); + } // Only expect to handle INT32 as attributes for now - if (padding_type != tensorflow::DataType::DT_INT32) + if (padding_type != tensorflow::DataType::DT_INT32) { return tensorflow::errors::Unimplemented( "Tpaddings supports only DT_INT32"); + } auto pad_data = static_cast(const_cast(pads.GetValues())); std::vector pad_index; for (int i = 0; i < nb_dims; i++) { - if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0) + if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0) { pad_index.push_back(i); + } } // No padding at all, we should exit @@ -1746,20 +2087,23 @@ tensorflow::Status ConvertPad(Converter& ctx, } // Only supports padding on less than 2 axis GIE-2579 - if (pad_index.size() > 2) + if (pad_index.size() > 2) { return tensorflow::errors::InvalidArgument( "Padding layer does not support padding on > 2"); + } // Padding on batch dimension is not supported - if (pad_index[0] == 0) + if (pad_index[0] == 0) { return tensorflow::errors::InvalidArgument( "Padding layer does not support padding on batch dimension"); + } // Not doing the legit thing here. ignoring padding on dim 1 and 3; // TODO(jie): implement pad as uff parser - if (pad_index.size() == 2 && pad_index[0] == 0 && pad_index[1] == 3) + if (pad_index.size() == 2 && pad_index[0] == 0 && pad_index[1] == 3) { return tensorflow::errors::Unimplemented( "Padding layer does not support padding on dimension 1 and 3 yet"); + } bool legit_pad = true; nvinfer1::DimsHW pre_padding(0, 0); @@ -1770,6 +2114,7 @@ tensorflow::Status ConvertPad(Converter& ctx, legit_pad = false; tensor = ctx.TransposeTensor(const_cast(tensor), {0, 3, 2, 1}); + TFTRT_RETURN_ERROR_IF_NULLPTR(tensor, node_def.name()); permuted_pad_index[0] = 3; } @@ -1786,11 +2131,14 @@ tensorflow::Status ConvertPad(Converter& ctx, nvinfer1::IPaddingLayer* layer = ctx.network()->addPadding( *const_cast(tensor), pre_padding, post_padding); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); - if (!legit_pad) + if (!legit_pad) { output_tensor = ctx.TransposeTensor( const_cast(output_tensor), {0, 3, 2, 1}); + TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name()); + } outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); @@ -1803,9 +2151,10 @@ tensorflow::Status ConvertConcat(Converter& ctx, // not including the last input (axis) here int input_size = static_cast(inputs.size()) - 1; - if (!inputs.at(0).is_tensor()) + if (!inputs.at(0).is_tensor()) { return tensorflow::errors::InvalidArgument( - "Concat in TRT support only Tensor input, at " + node_def.name()); + "Concat in TRT support only Tensor input, at ", node_def.name()); + } // We are retrieving the axis TRT_ShapedWeights axis = inputs.at(input_size).weights(); @@ -1816,8 +2165,8 @@ tensorflow::Status ConvertConcat(Converter& ctx, // TODO(jie): handle data type // Only expect to handle INT32 as index attributes for now if (index_type != tensorflow::DataType::DT_INT32) - return tensorflow::errors::Unimplemented( - "Tidx supports only DT_INT32, at " + node_def.name()); + return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32, at ", + node_def.name()); int index = *(static_cast(const_cast(axis.GetValues()))); @@ -1825,23 +2174,29 @@ tensorflow::Status ConvertConcat(Converter& ctx, auto dim = inputs.at(0).tensor()->getDimensions(); // dimension check - if (index > dim.nbDims + 1) + if (index > dim.nbDims + 1) { return tensorflow::errors::InvalidArgument( - "Concatenate on axis out of dimension range, at " + node_def.name()); - - if (index == 0) + "Concatenate on axis out of dimension range, at ", node_def.name()); + } + if (index == 0) { return tensorflow::errors::InvalidArgument( - "Concatenate on batch dimension not supported, at " + node_def.name()); + "Concatenate on batch dimension not supported, at ", node_def.name()); + } + if (index < 0) { + index = dim.nbDims + index + 1; + } +#if NV_TENSORRT_MAJOR == 3 // incase we need permutation; std::vector permutation_order(dim.nbDims + 1); for (int i = 0; i < dim.nbDims + 1; i++) permutation_order[i] = i; if (index != 1) { - permutation_order[1] = index - 1; - permutation_order[index - 1] = 1; + permutation_order[1] = index; + permutation_order[index] = 1; } +#endif std::vector inputs_vec; // Shap chack (all input tensor should have same shape) @@ -1849,24 +2204,28 @@ tensorflow::Status ConvertConcat(Converter& ctx, for (int i = 0; i < input_size; i++) { auto tensor_i = inputs.at(i).tensor(); auto dim_i = tensor_i->getDimensions(); - if (dim_i.nbDims != dim.nbDims) + if (dim_i.nbDims != dim.nbDims) { return tensorflow::errors::InvalidArgument( - "Concatenate receives inputs with inconsistent dimensions, at " + + "Concatenate receives inputs with inconsistent dimensions, at ", node_def.name()); - + } for (int j = 0; j < dim.nbDims; j++) { // check dimension consistency on non-concatenate axis - if (j != index - 1 && dim_i.d[j] != dim.d[j]) + if (j != index - 1 && dim_i.d[j] != dim.d[j]) { return tensorflow::errors::InvalidArgument( - "Concatenate receives inputs with inconsistent shape, at" + + "Concatenate receives inputs with inconsistent shape, at", node_def.name()); + } } - // TRT does concatenation only on channel! - if (index != 1) +#if NV_TENSORRT_MAJOR == 3 + // TRT3 does concatenation only on channel! + if (index != 1) { tensor_i = ctx.TransposeTensor(const_cast(tensor_i), permutation_order); - + TFTRT_RETURN_ERROR_IF_NULLPTR(tensor_i, node_def.name()); + } +#endif inputs_vec.push_back(tensor_i); } @@ -1874,11 +2233,18 @@ tensorflow::Status ConvertConcat(Converter& ctx, nvinfer1::IConcatenationLayer* layer = ctx.network()->addConcatenation( const_cast(inputs_vec.data()), inputs_vec.size()); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); +#if NV_TENSORRT_MAJOR > 3 + layer->setAxis(index - 1); +#endif nvinfer1::ITensor* output_tensor = layer->getOutput(0); +#if NV_TENSORRT_MAJOR == 3 if (index != 1) { output_tensor = ctx.TransposeTensor(output_tensor, permutation_order); + TFTRT_RETURN_ERROR_IF_NULLPTR(output_tensor, node_def.name()); } +#endif outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } @@ -1997,112 +2363,243 @@ tensorflow::Status ConvertFusedBatchNorm( combined_offset_weights.GetWeightsForTRT(), combined_scale_weights.GetWeightsForTRT(), dummy_power_weights.GetWeightsForTRT()); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } -tensorflow::Status ConvertMatMul(Converter& ctx, - const tensorflow::NodeDef& node_def, - const std::vector& inputs, - std::vector* outputs) { - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - - // TODO(jie): transpose! - TFAttrs attrs(node_def); +#if NV_TENSORRT_MAJOR > 3 +tensorflow::Status ConvertMatMulHelper( + Converter& ctx, TRT_TensorOrWeights tensor_input, + TRT_ShapedWeights weights_raw, bool transpose_weight, string node_name, + std::vector* outputs) { + nvinfer1::ITensor* output_tensor; + if (!tensor_input.is_tensor()) { + return tensorflow::errors::InvalidArgument("Input 0 expects tensor"); + } + const nvinfer1::ITensor* tensor = tensor_input.tensor(); - TRT_ShapedWeights weights_ck = inputs.at(1).weights(); - TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_ck); - ReorderCKtoKC(weights_ck, &weights); + TRT_ShapedWeights weights(weights_raw.type_); + if (transpose_weight) { + weights = weights_raw; + } else { + TRT_ShapedWeights weights_ck = weights_raw; + weights = ctx.get_temp_weights_like(weights_ck); + ReorderCKtoKC(weights_raw, &weights); + } TRT_ShapedWeights biases(weights.type_); int noutput = weights.shape_.d[0]; + auto input_dim = tensor->getDimensions(); + while (input_dim.nbDims != 3) { + input_dim.d[input_dim.nbDims++] = 1; + } + TFTRT_RETURN_ERROR_IF_FALSE( + PrepareTensorForShape(ctx, tensor_input, input_dim, &tensor), node_name); + nvinfer1::IFullyConnectedLayer* layer = ctx.network()->addFullyConnected( *const_cast(tensor), noutput, weights, biases); - - nvinfer1::ITensor* output_tensor = layer->getOutput(0); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_name); + output_tensor = layer->getOutput(0); + + const nvinfer1::ITensor* temp_tensor; + auto output_dim = output_tensor->getDimensions(); + output_dim.nbDims = 1; + TFTRT_RETURN_ERROR_IF_FALSE( + PrepareTensorForShape(ctx, TRT_TensorOrWeights(output_tensor), output_dim, + &temp_tensor), + node_name); + output_tensor = const_cast(temp_tensor); outputs->push_back(TRT_TensorOrWeights(output_tensor)); return tensorflow::Status::OK(); } -tensorflow::Status ConvertReshape( +// inputs are both two dimensional (tensorflow::ops::MatMul) +tensorflow::Status ConvertMatMul(Converter& ctx, + const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs) { + if (!inputs.at(0).is_tensor()) { + return tensorflow::errors::InvalidArgument("Input 0 expects tensor, at" + + node_def.name()); + } + + TFAttrs attrs(node_def); + // TODO(jie): INT32 should be converted? + tensorflow::DataType tf_dtype = attrs.get("T"); + if (tf_dtype != tensorflow::DataType::DT_FLOAT && + tf_dtype != tensorflow::DataType::DT_HALF) { + return tensorflow::errors::Unimplemented( + "data type is not supported, for node " + node_def.name() + " got " + + tensorflow::DataTypeString(tf_dtype)); + } + bool transpose_a = attrs.get("transpose_a"); + bool transpose_b = attrs.get("transpose_b"); + + // FullyConnected: + if (transpose_a) { + return tensorflow::errors::Internal( + "Transpose_a is not supported for TensorRT FullyConnected (op: " + + node_def.op() + "), at: " + node_def.name()); + } + if (inputs.at(1).is_tensor()) { + return tensorflow::errors::Internal( + "Operand 1 must be constant for TensorRT FullyConnected (op: " + + node_def.op() + "), at: " + node_def.name()); + } + return ConvertMatMulHelper(ctx, inputs.at(0), inputs.at(1).weights(), + transpose_b, node_def.name(), outputs); +} + +tensorflow::Status ConvertBatchMatMul( Converter& ctx, const tensorflow::NodeDef& node_def, const std::vector& inputs, std::vector* outputs) { - if (inputs.size() != 2 || !inputs.at(0).is_tensor() || - !inputs.at(1).is_weights()) - return tensorflow::errors::InvalidArgument( - "Input expects tensor and weights, at" + node_def.name()); + TFAttrs attrs(node_def); - // implement tensor binaryOp weight [channel wise] for now; - const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - auto dims = tensor->getDimensions(); - // restore implicit batch dimension + // TODO(jie): INT32 should be converted? + tensorflow::DataType tf_dtype = attrs.get("T"); + if (tf_dtype != tensorflow::DataType::DT_FLOAT && + tf_dtype != tensorflow::DataType::DT_HALF) { + return tensorflow::errors::Unimplemented( + "data type is not supported, for node " + node_def.name() + " got " + + tensorflow::DataTypeString(tf_dtype)); + } - TRT_ShapedWeights shape = inputs.at(1).weights(); + bool transpose_a = attrs.get("adj_x"); + bool transpose_b = attrs.get("adj_y"); - TFAttrs attrs(node_def); + auto dims = inputs.at(0).shape(); + if (dims.nbDims == 1) { // NC * CK is only supported through fully connected + if (transpose_a == false && inputs.at(0).is_tensor() && + inputs.at(1).is_weights()) { + return ConvertMatMulHelper(ctx, inputs.at(0), inputs.at(1).weights(), + transpose_b, node_def.name(), outputs); + } else { + return tensorflow::errors::InvalidArgument( + "Invalid configuration for MatMul, at: " + node_def.name()); + } + } - auto padding_type = attrs.get("Tshape"); + const nvinfer1::ITensor* tensor_l; + const nvinfer1::ITensor* tensor_r; + auto dims_l = inputs.at(0).shape(); + auto dims_r = inputs.at(1).shape(); + if (inputs.at(0).is_weights()) { + if (inputs.at(0).shape().d[0] != 1) { + return tensorflow::errors::InvalidArgument( + "Input 0 as weight assumes broadcast across batch for MatMul, at: " + + node_def.name()); + } else { + for (int i = 0; i < dims_l.nbDims - 1; i++) { + dims_l.d[i] = dims_l.d[i + 1]; + } + dims_l.nbDims--; + } + } + if (inputs.at(1).is_weights()) { + if (inputs.at(1).shape().d[0] != 1) { + return tensorflow::errors::InvalidArgument( + "Input 1 as weight assumes broadcast across batch for MatMul, at: " + + node_def.name()); + } else { + for (int i = 0; i < dims_r.nbDims - 1; i++) { + dims_r.d[i] = dims_r.d[i + 1]; + } + dims_r.nbDims--; + } + } - if (shape.shape_.nbDims != 1) - return tensorflow::errors::InvalidArgument( - "reshape new shape is not 1 dimensional, at " + node_def.name()); + TFTRT_RETURN_ERROR_IF_FALSE( + PrepareTensorForShape(ctx, inputs.at(0), dims_l, &tensor_l), + node_def.name()); + TFTRT_RETURN_ERROR_IF_FALSE( + PrepareTensorForShape(ctx, inputs.at(1), dims_r, &tensor_r), + node_def.name()); - // Only expect to handle INT32 as attributes for now - if (padding_type != tensorflow::DataType::DT_INT32) - return tensorflow::errors::Unimplemented( - "reshape new shape supports only DT_INT32, at " + node_def.name()); + nvinfer1::IMatrixMultiplyLayer* layer = ctx.network()->addMatrixMultiply( + *const_cast(tensor_l), transpose_a, + *const_cast(tensor_r), transpose_b); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} +#endif - auto shape_data = static_cast(const_cast(shape.GetValues())); +#if NV_TENSORRT_MAJOR > 3 +tensorflow::Status ConvertSoftmax( + Converter& ctx, const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs) { + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - if (shape_data[0] != -1) + int nbDims = tensor->getDimensions().nbDims; + if (nbDims == 0) { return tensorflow::errors::InvalidArgument( - "reshape new shape first dimension is not -1, at " + node_def.name()); + "TensorRT Softmax cannot apply on batch dimension, at" + + node_def.name()); + } + nvinfer1::ISoftMaxLayer* layer = + ctx.network()->addSoftMax(*const_cast(tensor)); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + // Tensorflow SoftMax assumes applying softmax on the last dimension. + layer->setAxes(1 << (nbDims - 1)); - auto shape_num_dims = shape.shape_.d[0]; - VLOG(2) << "shape dimensions: " << shape_num_dims; - int volume_w = 1; - for (int i = 1; i < shape.shape_.d[0]; i++) volume_w *= shape_data[i]; + nvinfer1::ITensor* output_tensor = layer->getOutput(0); + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return tensorflow::Status::OK(); +} +#endif - int volume_t = 1; - for (int i = 0; i < dims.nbDims; i++) volume_t *= dims.d[i]; +#if NV_TENSORRT_MAJOR > 3 +tensorflow::Status ConvertTopK(Converter& ctx, + const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs) { + const nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - VLOG(2) << "volume: " << volume_t << " volume weights: " << volume_w; - if (volume_w != volume_t) + int nbDims = tensor->getDimensions().nbDims; + if (nbDims == 0) { return tensorflow::errors::InvalidArgument( - "volume does not agree between tensor and new shape, at " + - node_def.name()); + "TensorRT TopK cannot apply on batch dimension, at" + node_def.name()); + } - nvinfer1::IShuffleLayer* layer = - ctx.network()->addShuffle(*const_cast(tensor)); + TRT_ShapedWeights k_w = inputs.at(1).weights(); + int k = *(static_cast(const_cast(k_w.GetValues()))); - nvinfer1::Dims reshape_dims; - VLOG(2) << "new dimension: " << shape_num_dims - 1; - reshape_dims.nbDims = shape_num_dims - 1; - for (int32_t i = 0; i < reshape_dims.nbDims; ++i) { - reshape_dims.d[i] = shape_data[i + 1]; + nvinfer1::TopKOperation op; + uint32_t reducedAxes = 0; + if (node_def.op() == "TopKV2") { + op = nvinfer1::TopKOperation::kMAX; + reducedAxes |= 1 << (nbDims - 1); + } else { + return tensorflow::errors::Unimplemented( + "Operation: " + node_def.op() + + " not implemented, at: " + node_def.name()); } - layer->setReshapeDimensions(reshape_dims); - VLOG(2) << "new dimension: " << shape_num_dims - 1; - nvinfer1::ITensor* output_tensor = layer->getOutput(0); - auto dims_output = output_tensor->getDimensions(); - VLOG(2) << "output tensor dimension:" << dims_output.nbDims; - outputs->push_back(TRT_TensorOrWeights(output_tensor)); + nvinfer1::ITopKLayer* layer = ctx.network()->addTopK( + *const_cast(tensor), op, k, reducedAxes); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + + nvinfer1::ITensor* output_value_tensor = layer->getOutput(0); + nvinfer1::ITensor* output_indices_tensor = layer->getOutput(1); + outputs->push_back(TRT_TensorOrWeights(output_value_tensor)); + outputs->push_back(TRT_TensorOrWeights(output_indices_tensor)); return tensorflow::Status::OK(); } +#endif void Converter::register_op_converters() { // vgg_16 slim implementation - op_registry_["Placeholder"] = ConvertPlaceholder; op_registry_["Conv2D"] = ConvertConv2D; op_registry_["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; op_registry_["Relu"] = ConvertActivation; op_registry_["MaxPool"] = ConvertPool; op_registry_["AvgPool"] = ConvertPool; - // This could be really handled as ConvertBinary op_registry_["BiasAdd"] = ConvertScale; op_registry_["Const"] = ConvertConst; // TODO(ben,jie): this is a temp hack. @@ -2113,17 +2610,39 @@ void Converter::register_op_converters() { op_registry_["Add"] = ConvertBinary; op_registry_["Mul"] = ConvertBinary; op_registry_["Sub"] = ConvertBinary; - op_registry_["Rsqrt"] = ConvertUnary; - op_registry_["Mean"] = ConvertReduce; op_registry_["Pad"] = ConvertPad; - // TODO(ben,jie): Add more ops op_registry_["ConcatV2"] = ConvertConcat; - op_registry_["MatMul"] = ConvertMatMul; - op_registry_["Reshape"] = ConvertReshape; op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm; op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm; + op_registry_["Div"] = ConvertBinary; + op_registry_["RealDiv"] = ConvertBinary; + + op_registry_["Rsqrt"] = ConvertUnary; + op_registry_["Reciprocal"] = ConvertUnary; + op_registry_["Exp"] = ConvertUnary; + op_registry_["Log"] = ConvertUnary; + op_registry_["Sqrt"] = ConvertUnary; + op_registry_["Abs"] = ConvertUnary; + op_registry_["Neg"] = ConvertUnary; +#if NV_TENSORRT_MAJOR == 3 + op_registry_["Mean"] = ConvertReducePool; +#endif +#if NV_TENSORRT_MAJOR > 3 + op_registry_["Sum"] = ConvertReduce; + op_registry_["Prod"] = ConvertReduce; + op_registry_["Max"] = ConvertReduce; + op_registry_["Min"] = ConvertReduce; + op_registry_["Mean"] = ConvertReduce; + op_registry_["Maximum"] = ConvertBinary; + op_registry_["Minimum"] = ConvertBinary; + op_registry_["Softmax"] = ConvertSoftmax; + op_registry_["MatMul"] = ConvertMatMul; + op_registry_["BatchMatMul"] = ConvertBatchMatMul; + op_registry_["TopKV2"] = ConvertTopK; +#endif + plugin_converter_ = ConvertPlugin; } @@ -2172,30 +2691,27 @@ tensorflow::Status ConvertGraphDefToEngine( // Graph nodes are already topologically sorted during construction for (const auto& node_def : gdef.node()) { string node_name = node_def.name(); - VLOG(1) << "Converting op name=" << node_name << ", op=" << node_def.op(); + VLOG(2) << "Converting op name=" << node_name << ", op=" << node_def.op(); if (tensorflow::str_util::StartsWith(node_name, kInputPHName) && (node_def.op() == "Placeholder")) { nvinfer1::DimsCHW input_dim_pseudo_chw; for (int i = 0; i < 8; i++) input_dim_pseudo_chw.d[i] = 0; - nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT); - auto type_status = - ConvertDType(node_def.attr().at("dtype").type(), &dtype); - if (type_status != tensorflow::Status::OK()) { - LOG(WARNING) << "Type conversion failed for " << node_name; - return type_status; - } int32 slot_number = -1; - if (!tensorflow::strings::safe_strto32(node_name.c_str() + 8, - &slot_number)) { - LOG(ERROR) << "Failed to parse slot number from " << node_name - << " +8= " << node_name.c_str() + 8; + if (!tensorflow::strings::safe_strto32( + node_name.c_str() + strlen(kInputPHName), &slot_number)) { + return tensorflow::errors::InvalidArgument( + "Failed to parse slot number from ", node_name); } + nvinfer1::DataType dtype; auto shape = input_shapes.at(slot_number); - if (shape.dims() > 8) { - LOG(ERROR) << "Tensor rank is greater than 8 for " << node_name - << " at input slot " << slot_number; - return tensorflow::errors::OutOfRange( - "Input tensor rank is greater than 8"); + auto status = ValidateInputProperties( + shape, node_def.attr().at("dtype").type(), &dtype); + if (!status.ok()) { + const string error_message = + StrCat("Validation failed for ", node_name, " and input slot ", + slot_number, ": ", status.error_message()); + LOG(WARNING) << error_message; + return Status(status.code(), error_message); } if (VLOG_IS_ON(1)) { string dim_str("dims="); @@ -2226,10 +2742,10 @@ tensorflow::Status ConvertGraphDefToEngine( } else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) && (node_def.op() == "Identity")) { int32 slot_number = -1; - if (!tensorflow::strings::safe_strto32(node_name.c_str() + 9, - &slot_number)) { - LOG(ERROR) << "Failed to parse slot number from " << node_name - << " +9=" << node_name.c_str() + 9; + if (!tensorflow::strings::safe_strto32( + node_name.c_str() + strlen(kOutputPHName), &slot_number)) { + return tensorflow::errors::InvalidArgument( + "Failed to parse slot number from ", node_name); } if (output_tensors.size() <= slot_number) { output_tensors.resize(slot_number + 1); @@ -2273,6 +2789,7 @@ tensorflow::Status ConvertGraphDefToEngine( tensorflow::Status ConvertSegmentToGraphDef( const tensorflow::Graph* graph, const tensorflow::grappler::GraphProperties& graph_properties, + const std::set& subgraph_node_names, const std::vector& subgraph_node_ids, // In topological order std::vector* connections, tensorflow::GraphDef* segment_def, string* common_scope) { @@ -2281,6 +2798,7 @@ tensorflow::Status ConvertSegmentToGraphDef( // nodes in the segment graphdef. for (size_t i = 0; i < connections->size(); ++i) { auto& connection = connections->at(i); + if (connection.is_control_edge()) continue; auto outside_node = graph->FindNodeId(connection.outside_id); if (!outside_node) { // This should never happen, unless the original graph is problematic. @@ -2288,38 +2806,20 @@ tensorflow::Status ConvertSegmentToGraphDef( "Cannot find node with id ", connection.outside_id, " in the graph."); } // Updates the shape and data types of input/output connections. - tensorflow::DataType input_type = tensorflow::DT_FLOAT; + tensorflow::DataType dtype; tensorflow::PartialTensorShape partial_shape; if (connection.is_input_edge) { - if (graph_properties.HasOutputProperties(connection.outside_node_name)) { - auto output_params = - graph_properties.GetOutputProperties(connection.outside_node_name); - auto out_shape = output_params.at(connection.outside_port); - input_type = out_shape.dtype(); - std::vector dims; - partial_shape = out_shape.shape(); - connection.outside_shape = partial_shape; - } else { - VLOG(0) << "Unknown output shape" << outside_node->name(); - input_type = graph->FindNodeId(connection.outside_id) - ->output_type(connection.outside_port); - } - connection.connection_type = input_type; - - } else { // output edge - if (graph_properties.HasInputProperties(connection.outside_node_name)) { - auto input_params = - graph_properties.GetInputProperties(connection.outside_node_name); - auto in_shape = input_params.at(connection.outside_port); - input_type = in_shape.dtype(); - partial_shape = in_shape.shape(); - connection.inside_shape = partial_shape; - } else { - input_type = graph->FindNodeId(connection.inside_id) - ->output_type(connection.outside_port); - } - connection.connection_type = input_type; + GetInputProperties(graph_properties, + graph->FindNodeId(connection.outside_id), + connection.outside_port, &partial_shape, &dtype); + connection.outside_shape = partial_shape; + } else { + GetOutputProperties(graph_properties, + graph->FindNodeId(connection.outside_id), + connection.outside_port, &partial_shape, &dtype); + connection.inside_shape = partial_shape; } + connection.connection_type = dtype; // Add dummy input/output nodes to the segment graphdef. if (connection.is_input_edge) { @@ -2335,7 +2835,7 @@ tensorflow::Status ConvertSegmentToGraphDef( auto seg_node = segment_def->add_node(); tensorflow::NodeDefBuilder builder(node_name, "Placeholder"); auto status = builder.Attr("shape", partial_shape) - .Attr("dtype", input_type) + .Attr("dtype", dtype) .Finalize(seg_node); VLOG(1) << "Constructing input " << node_name << " for the edge " << connection.outside_node_name << ":" << connection.outside_port @@ -2353,7 +2853,7 @@ tensorflow::Status ConvertSegmentToGraphDef( marker_nodes.insert(node_name); auto seg_node = segment_def->add_node(); tensorflow::NodeDefBuilder builder(node_name, "Identity"); - auto status = builder.Input(connection.inside_node_name, 0, input_type) + auto status = builder.Input(connection.inside_node_name, 0, dtype) .Finalize(seg_node); VLOG(1) << "Constructing output " << node_name << " for the edge " << connection.inside_node_name << ":" << connection.inside_port @@ -2371,12 +2871,12 @@ tensorflow::Status ConvertSegmentToGraphDef( old_to_new_id_map[node_id] = segment_def->node_size(); auto snode = segment_def->add_node(); snode->CopyFrom(node->def()); - VLOG(1) << "Copying " << snode->name() << " to subgraph"; + VLOG(2) << "Copying " << snode->name() << " to subgraph"; } // Update the inputs of the new input nodes to point to placeholder nodes. for (int i = 0; i < connections->size(); ++i) { auto& connection = connections->at(i); - if (!connection.is_input_edge) continue; + if (connection.is_control_edge() || !connection.is_input_edge) continue; auto snode = segment_def->mutable_node(old_to_new_id_map[connection.inside_id]); const string placeholder_name = @@ -2386,11 +2886,76 @@ tensorflow::Status ConvertSegmentToGraphDef( << placeholder_name; snode->set_input(connection.inside_port, placeholder_name); } + // Remove control inputs that are not inside the segment. + for (int i = 0; i < segment_def->node_size(); ++i) { + auto snode = segment_def->mutable_node(i); + const int input_size = snode->input_size(); + int input_idx = 0; + int actual_input_idx = 0; + while (input_idx < input_size) { + TensorId input = ParseTensorName(snode->input(input_idx)); + if (!subgraph_node_names.count( + string(input.first.data(), input.first.size())) && + !str_util::StartsWith(input.first, kInputPHName)) { + if (input.second == Graph::kControlSlot) { + VLOG(1) << "... removing control inputs " << input.first + << " from subgraph."; + ++input_idx; + continue; + } else { + return tensorflow::errors::InvalidArgument( + "Found non control input outside the segment that is not an " + "engine connection to ", + snode->name(), ": ", input.first); + } + } + if (actual_input_idx != input_idx) { + snode->set_input(actual_input_idx, snode->input(input_idx)); + } + ++input_idx; + ++actual_input_idx; + } + for (int remove = input_size - actual_input_idx; remove > 0; --remove) { + snode->mutable_input()->RemoveLast(); + } + } *common_scope = local_scope; VLOG(0) << "Segment @scope '" << local_scope << "', converted to graph"; return tensorflow::Status::OK(); } +bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const { + if (in_edge->IsControlEdge()) return true; + PartialTensorShape shape; + tensorflow::DataType dtype; + GetInputProperties(graph_properties_, in_edge->src(), in_edge->src_output(), + &shape, &dtype); + nvinfer1::DataType trt_dtype; + Status status = ValidateInputProperties(shape, dtype, &trt_dtype); + if (!status.ok()) { + VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name() + << ": " << status; + return false; + } + if (shape.dims() < 3 && in_edge->src()->type_string() != "Const") { + VLOG(1) << "--> Need to remove input node " << in_edge->dst()->name() + << " which has an input at port " << in_edge->dst_input() + << " with #dim<3 and is not a const: " << shape; + return false; + } + return true; +} + +bool OutputEdgeValidator::operator()(const tensorflow::Edge* out_edge) const { + if (out_edge->IsControlEdge()) return true; + if (out_edge->src()->type_string() == "Const") { + VLOG(1) << "--> Need to remove output node " << out_edge->src()->name() + << " which is a Const."; + return false; + } + return true; +} + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index 7684d8d4a23ae22c855d82fc54931151a976eb2f..a60253740fe0b27dcd9c20618d6d05aa7001a1a1 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" #include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" #include "tensorflow/core/framework/graph.pb.h" @@ -35,19 +36,15 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -static const char* kInputPHName = "InputPH_"; -static const char* kOutputPHName = "OutputPH_"; +static const char* kInputPHName = "TensorRTInputPH_"; +static const char* kOutputPHName = "TensorRTOutputPH_"; namespace convert { -// TODO(aaroey): use an enum instead. -const int FP32MODE = 0; -const int FP16MODE = 1; -const int INT8MODE = 2; - struct EngineConnection { + // Constructs a non-control edge. EngineConnection(const string& outside, int out_id, int out_port, - const string& inside, int in_id, int in_port, - bool input_edge, int port) + const string& inside, int in_id, int in_port, + bool input_edge, int port) : outside_node_name(outside), outside_id(out_id), outside_port(out_port), @@ -57,21 +54,35 @@ struct EngineConnection { is_input_edge(input_edge), port_number(port) {} + // Constructs a control edge. + EngineConnection(const string& outside, int out_id, const string& inside, + int in_id, bool input_edge) + : outside_node_name(outside), + outside_id(out_id), + outside_port(Graph::kControlSlot), + inside_node_name(inside), + inside_id(in_id), + inside_port(Graph::kControlSlot), + is_input_edge(input_edge), + port_number(Graph::kControlSlot) {} + + bool is_control_edge() const { return port_number == Graph::kControlSlot; } + const string outside_node_name; const int outside_id; const int outside_port; - tensorflow::PartialTensorShape outside_shape; + tensorflow::PartialTensorShape outside_shape; // Only set for input edge. const string inside_node_name; const int inside_id; const int inside_port; - tensorflow::PartialTensorShape inside_shape; + tensorflow::PartialTensorShape inside_shape; // Only set for output edge. tensorflow::DataType connection_type; - bool is_input_edge; + const bool is_input_edge; - // The port number of the TRT node connecting to this edge. - int port_number; + // The port number of the TRT node connected with this edge. + const int port_number; }; struct EngineInfo { @@ -84,7 +95,9 @@ struct EngineInfo { string device; tensorflow::GraphDef segment_graph_def; - // The segment nodes that are on one side of the edges are topological sorted. + // Non-control input connections inside this vector are sorted in a way such + // that, the segment nodes connecting to them are topological sorted. + // In addition, for non-control connections, there must be no duplicates. std::vector connections; enum class EngineType { TRTStatic = 0, TRTDynamic = 1 }; @@ -100,13 +113,17 @@ struct EngineInfo { // (OutputPH_*). This function needs to be called before TensorRT nodes // inserted in order to correctly get sizes from the original graph. // +// - subgraph_node_names: the node names of the subgraph. // - subgraph_node_ids: the node ids of the subgraph, must be sorted in // topological order. // - segment_def: the output GraphDef, whose non-input/output nodedefs will be // sorted in topological order. +// +// TODO(aaroey): add tests to validate these properties. tensorflow::Status ConvertSegmentToGraphDef( const tensorflow::Graph* graph, const tensorflow::grappler::GraphProperties& graph_properties, + const std::set& subgraph_node_names, const std::vector& subgraph_node_ids, std::vector* connections, tensorflow::GraphDef* segment_def, string* common_scope); @@ -128,6 +145,30 @@ tensorflow::Status ConvertGraphDefToEngine( TrtUniquePtrType* engine, bool* convert_successfully); +// Helper class for the segmenter to determine whether an input edge to the TRT +// segment is valid. +class InputEdgeValidator { + public: + InputEdgeValidator(const grappler::GraphProperties& graph_properties) + : graph_properties_(graph_properties) {} + + // Return true if the specified edge is eligible to be an input edge of the + // TRT segment. + bool operator()(const tensorflow::Edge* in_edge) const; + + private: + const grappler::GraphProperties& graph_properties_; +}; + +// Helper class for the segmenter to determine whether an output edge from the +// TRT segment is valid. +class OutputEdgeValidator { + public: + // Return true if the specified edge is eligible to be an output edge of the + // TRT segment. + bool operator()(const tensorflow::Edge* out_edge) const; +}; + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc index ec9dbfa13bfd0a158dcf41cf1fdb7128a2adf641..f33f2cc4d68f5ac10eafeb744f8162bfca0abfab 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc @@ -17,9 +17,11 @@ limitations under the License. #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stacktrace.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -188,9 +190,6 @@ tensorflow::Status TRTOptimizationPass::Optimize( tensorflow::grappler::Cluster* cluster, const tensorflow::grappler::GrapplerItem& item, GraphDef* optimized_graph) { VLOG(1) << "Called TRTOptimization Pass " << name_; - if (VLOG_IS_ON(1)) { - PrintDebugInfo(cluster, item); - } // This is a hack to workaround optimizer issue. MetaOptimizer calls // optimization passes on function objects as well, we should not modify // generated funcdefs! This is fragile but we don't have any other option @@ -202,6 +201,10 @@ tensorflow::Status TRTOptimizationPass::Optimize( *optimized_graph = item.graph; return tensorflow::Status::OK(); } + if (VLOG_IS_ON(1)) { + VLOG(2) << CurrentStackTrace(); + PrintDebugInfo(cluster, item); + } int max_dim = -1; if (item.feed.size()) { for (const auto& f : item.feed) { @@ -232,8 +235,25 @@ tensorflow::Status TRTOptimizationPass::Optimize( tensorflow::grappler::GraphProperties static_graph_properties(item); TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); tensorflow::tensorrt::convert::ConversionParams cp; + + std::vector nodes_to_preserve; + for (const auto& n : item.NodesToPreserve()) { + auto tokens = str_util::Split(n, ":"); + string s = tokens.at(0); + for (int i = 1; i < tokens.size() - 1; ++i) { + StrAppend(&s, ":", tokens.at(i)); + } + int dumm_port = -1; + // If the last token is not an integer, it must be part of the name. + // Otherwise it is port number. + if (tokens.size() > 1 && + !strings::safe_strto32(tokens.back(), &dumm_port)) { + StrAppend(&s, ":", tokens.back()); + } + nodes_to_preserve.push_back(s); + } cp.input_graph_def = &item.graph; - cp.output_names = &item.fetch; + cp.output_names = &nodes_to_preserve; cp.max_batch_size = maximum_batch_size_; cp.max_workspace_size_bytes = maximum_workspace_size_; cp.output_graph_def = optimized_graph; diff --git a/tensorflow/contrib/tensorrt/convert/utils.cc b/tensorflow/contrib/tensorrt/convert/utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..e7a1febb8c076891596741fe30721e7acca15a73 --- /dev/null +++ b/tensorflow/contrib/tensorrt/convert/utils.cc @@ -0,0 +1,69 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/convert/utils.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace tensorrt { + +bool IsGoogleTensorRTEnabled() { + // TODO(laigd): consider also checking if tensorrt shared libraries are + // accessible. We can then direct users to this function to make sure they can + // safely write code that uses tensorrt conditionally. E.g. if it does not + // check for for tensorrt, and user mistakenly uses tensorrt, they will just + // crash and burn. +#if GOOGLE_CUDA && GOOGLE_TENSORRT + return true; +#else + return false; +#endif +} + +Status GetPrecisionModeName(const int precision_mode, string* name) { + switch (precision_mode) { + case FP32MODE: + *name = "FP32"; + break; + case FP16MODE: + *name = "FP16"; + break; + case INT8MODE: + *name = "INT8"; + break; + default: + return tensorflow::errors::OutOfRange("Unknown precision mode"); + } + return Status::OK(); +} + +Status GetPrecisionMode(const string& name, int* precision_mode) { + if (name == "FP32") { + *precision_mode = FP32MODE; + } else if (name == "FP16") { + *precision_mode = FP16MODE; + } else if (name == "INT8") { + *precision_mode = INT8MODE; + } else { + return tensorflow::errors::InvalidArgument("Invalid precision mode name: ", + name); + } + return Status::OK(); +} + +} // namespace tensorrt +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/utils.h b/tensorflow/contrib/tensorrt/convert/utils.h index f601c06701fdbf983b708cf5f5c7d22634bb810b..0592f31462af2b20f3a13fe5119e89c2ba42dd8a 100644 --- a/tensorflow/contrib/tensorrt/convert/utils.h +++ b/tensorflow/contrib/tensorrt/convert/utils.h @@ -18,6 +18,8 @@ limitations under the License. #include +#include "tensorflow/core/lib/core/status.h" + namespace tensorflow { namespace tensorrt { @@ -31,6 +33,17 @@ struct TrtDestroyer { template using TrtUniquePtrType = std::unique_ptr>; +bool IsGoogleTensorRTEnabled(); + +// TODO(aaroey): use an enum instead. +const int FP32MODE = 0; +const int FP16MODE = 1; +const int INT8MODE = 2; + +Status GetPrecisionModeName(const int precision_mode, string* name); + +Status GetPrecisionMode(const string& name, int* precision_mode); + } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index a89cf3ab8bfaecc74fc5890ccb7e7a7147278182..69058c5826822c519a69d50860c06b8ab3ec6578 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -112,7 +112,9 @@ cuda_py_test( ], tags = [ "manual", + "no_windows", "noguitar", + "nomac", "notap", ], ) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc index 988b35f74f3989481f59c52c6320623a26704327..2de79737501a11d9760f9a7d3953cf132e512145 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc @@ -65,7 +65,7 @@ class IncPluginTRT : public OpKernel { reinterpret_cast(context->op_device_context() ->stream() ->implementation() - ->CudaStreamMemberHack())); + ->GpuStreamMemberHack())); IncrementKernel(input_tensor.flat().data(), inc_, output_tensor->flat().data(), input_shape.num_elements(), *stream); diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 75e32559bb055a49ccef2100c208c6277c0c4b60..2b42d81f475189f74a934c3aeed7d7fc34d4eb53 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -15,11 +15,14 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h" #include + #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" #include "tensorflow/contrib/tensorrt/convert/utils.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" #include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "tensorflow/contrib/tensorrt/test/utils.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -43,11 +46,11 @@ using ::tensorflow::strings::StrCat; // Helps simultaneous execution of native and TRT engines. class AsyncHelper : public tensorflow::core::RefCounted { public: - AsyncHelper(tensorflow::AsyncOpKernel::DoneCallback done) { done_ = done; } + AsyncHelper(AsyncOpKernel::DoneCallback done) { done_ = done; } ~AsyncHelper() override { done_(); } private: - tensorflow::AsyncOpKernel::DoneCallback done_; + AsyncOpKernel::DoneCallback done_; }; #define TYPECASE(dt, X, Y) \ @@ -120,15 +123,9 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) context->GetAttr("calibration_data", &calibration_data)); OP_REQUIRES_OK(context, context->GetAttr("segment_funcdef_name", &funcdef_name_)); - if (precision_string == "FP32") { - precision_mode_ = convert::FP32MODE; - } else if (precision_string == "FP16") { - precision_mode_ = convert::FP16MODE; - } else if (precision_string == "INT8") { - precision_mode_ = convert::INT8MODE; - } + OP_REQUIRES_OK(context, GetPrecisionMode(precision_string, &precision_mode_)); calibration_mode_ = - (precision_mode_ == convert::INT8MODE && calibration_data.size() == 0); + (precision_mode_ == INT8MODE && calibration_data.size() == 0); if (calibration_data.size()) { calibrator_.reset(new TRTInt8Calibrator(calibration_data)); calibration_data.resize(0); @@ -150,7 +147,7 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) } } -void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx, +void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper) { if (!calibration_mode_) { VLOG(1) << "Executing native engine"; @@ -177,7 +174,7 @@ void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx, helper->Ref(); // Increment count for calculating native graph VLOG(1) << "Executing native segment " << name(); lib->Run(opts, native_func_, inputs, outputs, - [ctx, outputs, helper](const tensorflow::Status& s) { + [this, ctx, outputs, helper](const tensorflow::Status& s) { tensorflow::core::ScopedUnref sc(helper); VLOG(1) << "Native Segment completed"; if (!s.ok()) { @@ -187,11 +184,13 @@ void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx, for (size_t t = 0; t < outputs->size(); ++t) { ctx->set_output(t, outputs->at(t)); } + test::AddTestValue(StrCat(this->name(), ":ExecuteNativeSegment"), + "done"); delete outputs; }); } -void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx, +void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper) { helper->Ref(); tensorflow::core::ScopedUnref sc(helper); @@ -230,13 +229,14 @@ void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx, reinterpret_cast(ctx->op_device_context() ->stream() ->implementation() - ->CudaStreamMemberHack())); + ->GpuStreamMemberHack())); calib_res->calibrator_->setBatch(input_data, *stream); + test::AddTestValue(StrCat(name(), ":ExecuteCalibration"), "done"); VLOG(2) << "Passed calibration data"; ExecuteNativeSegment(ctx, helper); } -int TRTEngineOp::GetEngineBatch(tensorflow::OpKernelContext* ctx) { +int TRTEngineOp::GetEngineBatch(OpKernelContext* ctx) { int num_batch = ctx->input(0).shape().dim_size(0); int smallest_engine = 0; for (const auto i : cached_engine_batches_) { @@ -252,21 +252,20 @@ int TRTEngineOp::GetEngineBatch(tensorflow::OpKernelContext* ctx) { cached_engine_batches_.push_back(num_batch); VLOG(1) << "Running with batch size " << num_batch; } else { - string s("Engine buffer is full. buffer limit= "); - StrAppend(&s, max_cached_engines_, ", current entries= "); - for (auto i : cached_engine_batches_) StrAppend(&s, i, ", "); - StrAppend(&s, "Requested batch= ", num_batch); - LOG(ERROR) << s; - ctx->SetStatus(tensorflow::errors::ResourceExhausted( - "Requested batch size is not available and engine cache is full")); + string msg = + StrCat("Engine buffer is full. buffer limit=", max_cached_engines_, + ", current entries="); + for (auto i : cached_engine_batches_) StrAppend(&msg, i, ","); + StrAppend(&msg, " requested batch=", num_batch); + LOG(WARNING) << msg; return -1; } } return smallest_engine; } -void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, - tensorflow::AsyncOpKernel::DoneCallback done) { +void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, + AsyncOpKernel::DoneCallback done) { auto helper = new AsyncHelper(done); tensorflow::core::ScopedUnref sc(helper); if (calibration_mode_) { @@ -274,32 +273,54 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, return; } const int smallest_engine = GetEngineBatch(ctx); - if (smallest_engine < 0) return; // GetEngineBatch already set the status. + if (smallest_engine < 0) { + LOG(WARNING) << "Failed to get engine batch, running native segment for " + << name(); + ExecuteNativeSegment(ctx, helper); + return; + } const int num_batch = ctx->input(0).shape().dim_size(0); auto& engine_ctx_pair = GetEngine(smallest_engine, ctx); auto& trt_engine_ptr = engine_ctx_pair.first; if (!trt_engine_ptr) { LOG(WARNING) << "Engine retrieval for batch size " << num_batch - << " failed Running native segment"; + << " failed. Running native segment for " << name(); ExecuteNativeSegment(ctx, helper); return; } + const bool retry = ExecuteTrtEngine(ctx, num_batch, trt_engine_ptr.get(), + engine_ctx_pair.second.get()); + if (retry) { + LOG(WARNING) << "Failed to execute engine, " + << "retrying with native segment for " << name(); + ExecuteNativeSegment(ctx, helper); + return; + } +} +bool TRTEngineOp::ExecuteTrtEngine( + OpKernelContext* ctx, const int num_batch, + nvinfer1::ICudaEngine* trt_engine_ptr, + nvinfer1::IExecutionContext* trt_execution_context_ptr) { + const bool kRetry = true; const int num_binding = ctx->num_inputs() + ctx->num_outputs(); std::vector buffers(num_binding); for (int i = 0; i < ctx->num_inputs(); i++) { - const string inp_name = StrCat(kInputPHName, i); + const string input_name = StrCat(kInputPHName, i); const size_t binding_index = - trt_engine_ptr->getBindingIndex(inp_name.c_str()); + trt_engine_ptr->getBindingIndex(input_name.c_str()); + if (binding_index == -1) { + LOG(ERROR) << "Input node not found, at " << input_name; + return kRetry; + } const Tensor& input_tensor = ctx->input(i); const TensorShape& input_shape = input_tensor.shape(); if (num_batch != input_shape.dim_size(0)) { - LOG(ERROR) << "input data inconsistent batch size"; - ctx->SetStatus(tensorflow::errors::FailedPrecondition( - "Different batch sizes between input tensors")); - return; + LOG(ERROR) << "Input data has inconsistent batch size: " << num_batch + << " vs " << input_shape.dim_size(0); + return kRetry; } auto dtype = trt_engine_ptr->getBindingDataType(binding_index); switch (dtype) { @@ -308,27 +329,26 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, break; case nvinfer1::DataType::kHALF: LOG(ERROR) << "FP16 inputs are not supported yet!"; - ctx->SetStatus(tensorflow::errors::InvalidArgument( - "FP16 inputs are not supported!")); - return; + return kRetry; case nvinfer1::DataType::kINT8: LOG(ERROR) << "INT8 inputs are not supported yet!"; - ctx->SetStatus(tensorflow::errors::InvalidArgument( - "INT8 inputs are not supported!")); - return; + return kRetry; +#if NV_TENSORRT_MAJOR > 3 + case nvinfer1::DataType::kINT32: + buffers[binding_index] = (void*)(input_tensor.flat().data()); + break; +#endif default: LOG(ERROR) << "Unknown TRT data type: " << int(dtype); - ctx->SetStatus(tensorflow::errors::InvalidArgument( - "Unknown ouput TRT data type! ", static_cast(dtype))); - return; + return kRetry; } } for (int i = 0; i < ctx->num_outputs(); i++) { // Create an output tensor const string output_name = StrCat(kOutputPHName, i); - const size_t binding_index = trt_engine_ptr->getBindingIndex( - output_name.c_str()); + const size_t binding_index = + trt_engine_ptr->getBindingIndex(output_name.c_str()); Tensor* output_tensor = nullptr; TensorShape output_shape; @@ -337,20 +357,23 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, std::vector trt_shape(dims.nbDims + 1); trt_shape[0] = num_batch; for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j]; - OP_REQUIRES_OK( - ctx, TensorShapeUtils::MakeShape(trt_shape.data(), trt_shape.size(), - &output_shape)); + auto status = TensorShapeUtils::MakeShape( + trt_shape.data(), trt_shape.size(), &output_shape); + if (!status.ok()) { + LOG(ERROR) << "Failed to get output shape: " << status; + return kRetry; + } } else { - LOG(ERROR) << "output node not found, at " << output_name; - ctx->SetStatus(tensorflow::errors::Internal("output ", output_name, - " couldn't be found!")); - return; + LOG(ERROR) << "Output node not found, at " << output_name; + return kRetry; } auto status = ctx->allocate_output(i, output_shape, &output_tensor); if (!status.ok()) { LOG(ERROR) << "Allocating output failed with " << status; ctx->SetStatus(status); - return; + // Do not retry since we cannot allocate the same output twice. + // TODO(aaroey): ideally we should retry, fix this. + return !kRetry; } auto dtype = trt_engine_ptr->getBindingDataType(binding_index); switch (dtype) { @@ -359,39 +382,39 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, reinterpret_cast(output_tensor->flat().data()); break; case nvinfer1::DataType::kHALF: - LOG(ERROR) << "half size is not supported yet!"; - ctx->SetStatus(tensorflow::errors::InvalidArgument( - "Half outputs are not supported!")); - return; + LOG(WARNING) << "half size is not supported yet!"; + return kRetry; case nvinfer1::DataType::kINT8: - LOG(ERROR) << "int8 is not supported yet!"; - ctx->SetStatus(tensorflow::errors::InvalidArgument( - "INT8 outputs are not supported!")); - return; + LOG(WARNING) << "int8 is not supported yet!"; + return kRetry; +#if NV_TENSORRT_MAJOR > 3 + case nvinfer1::DataType::kINT32: + buffers[binding_index] = + reinterpret_cast(output_tensor->flat().data()); + break; +#endif default: - LOG(ERROR) << "Unknown TRT data type: " << static_cast(dtype); - ctx->SetStatus(tensorflow::errors::InvalidArgument( - "Unsupported output data type! ", int(dtype))); - return; + LOG(WARNING) << "Unknown TRT data type: " << static_cast(dtype); + return kRetry; } } - // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files + // Copied from cuda_kernel_helper since it seems only valid in *.cu.cc files const cudaStream_t* stream = CHECK_NOTNULL( reinterpret_cast(ctx->op_device_context() ->stream() ->implementation() - ->CudaStreamMemberHack())); + ->GpuStreamMemberHack())); // TODO(jie): trt enqueue does not return error - auto& trt_execution_context_ptr = engine_ctx_pair.second; auto ret = trt_execution_context_ptr->enqueue(num_batch, &buffers[0], *stream, nullptr); if (!ret) { - LOG(ERROR) << "Failed to enqueue batch for TRT engine: " << name(); - ctx->SetStatus(tensorflow::errors::Internal( - "Failed to enqueue batch for TRT engine: ", name())); + LOG(WARNING) << "Failed to enqueue batch for TRT engine: " << name(); + return kRetry; } - // sync should be done by TF. + test::AddTestValue(StrCat(name(), ":ExecuteTrtEngine"), "done"); + // Synchronization will be done by TF. + return !kRetry; } TRTEngineOp::~TRTEngineOp() { @@ -411,8 +434,6 @@ nvinfer1::IGpuAllocator* TRTEngineOp::GetAllocator(OpKernelContext* ctx) { if (!alloc) { LOG(ERROR) << "Can't find device allocator for gpu device " << device->name(); - ctx->SetStatus(tensorflow::errors::Internal( - "Can't get device allocator for device ", device->name())); return nullptr; } allocator_.reset(new TRTDeviceAllocator(alloc)); @@ -420,10 +441,10 @@ nvinfer1::IGpuAllocator* TRTEngineOp::GetAllocator(OpKernelContext* ctx) { } TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, - OpKernelContext* ctx) { + OpKernelContext* ctx) { static EngineCtxPair null_pair = { - TrtUniquePtrType(nullptr), - TrtUniquePtrType(nullptr)}; + TrtUniquePtrType(nullptr), + TrtUniquePtrType(nullptr)}; // TODO(sami): This method needs to be re-written to use resource manager and // with LRU mechanism option. tensorflow::mutex_lock lock(engine_mutex_); @@ -439,23 +460,25 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, #if NV_TENSORRT_MAJOR > 3 auto allocator = GetAllocator(ctx); if (allocator == nullptr) { - // GetAllocator already set the Status. return null_pair; } infer->setGpuAllocator(allocator); #endif TrtUniquePtrType static_engine( infer->deserializeCudaEngine(serialized_segment_.c_str(), - serialized_segment_.size(), nullptr)); + serialized_segment_.size(), + PluginFactoryTensorRT::GetInstance())); auto raw_static_engine = static_engine.get(); const auto max_batch_size = raw_static_engine->getMaxBatchSize(); engine_map_[max_batch_size] = { - std::move(static_engine), - TrtUniquePtrType( - raw_static_engine->createExecutionContext())}; + std::move(static_engine), + TrtUniquePtrType( + raw_static_engine->createExecutionContext())}; // Runtime is safe to delete after engine creation serialized_segment_.clear(); - if (max_batch_size < batch_size) return null_pair; + if (max_batch_size < batch_size) { + return null_pair; + } return engine_map_.at(max_batch_size); } // static_engine_ @@ -467,7 +490,6 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, #if NV_TENSORRT_MAJOR > 3 allocator = GetAllocator(ctx); if (allocator == nullptr) { - // GetAllocator already set the Status. return null_pair; } #endif @@ -491,9 +513,8 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, // retry in the future. engine_map_[batch_size] = {nullptr, nullptr}; } - LOG(ERROR) << "Engine creation for batch size " << batch_size - << " failed " << status; - ctx->SetStatus(tensorflow::errors::Internal("Engine creation failed!")); + LOG(WARNING) << "Engine creation for batch size " << batch_size + << " failed " << status; return null_pair; } VLOG(1) << "Conversion is done"; @@ -505,7 +526,7 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, } tensorflow::Status TRTEngineOp::AllocateCalibrationResources( - tensorflow::OpKernelContext* ctx, TRTCalibrationResource** cr) { + OpKernelContext* ctx, TRTCalibrationResource** cr) { auto cres = new TRTCalibrationResource(); *cr = cres; // Get the allocator. @@ -569,7 +590,7 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources( // TODO(aaroey): maybe setting the max batch size using the python // calibration wrapper class. auto s = convert::ConvertGraphDefToEngine( - *segment_graph, convert::INT8MODE, cres->calibrator_->getBatchSize(), + *segment_graph, INT8MODE, cres->calibrator_->getBatchSize(), workspace_size_bytes, shapes, &cres->logger_, cres->allocator_.get(), cres->calibrator_.get(), &cres->engine_, /*convert_successfully=*/nullptr); diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h index 6fe318be6a6bc9f01ce3b52e0430f2090b53002b..8fe06758914261035c90a6fda3f114a63a8ac93a 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h @@ -35,7 +35,7 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -class TRTInt8Calibrator; +struct TRTInt8Calibrator; class TRTCalibrationResource; class AsyncHelper; // TODO(Sami): Remove this file? @@ -60,6 +60,12 @@ class TRTEngineOp : public AsyncOpKernel { // Execute replaced native segment as function Op. void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper); + // Execute the tensorrt engine. Returns whether we need to retry by running + // the native segment. + bool ExecuteTrtEngine(OpKernelContext* ctx, const int num_batch, + nvinfer1::ICudaEngine* trt_engine_ptr, + nvinfer1::IExecutionContext* trt_execution_context_ptr); + // Allocate necessary resources for calibration Status AllocateCalibrationResources(OpKernelContext* ctx, TRTCalibrationResource** cr); @@ -81,7 +87,7 @@ class TRTEngineOp : public AsyncOpKernel { std::vector output_nodes_; // keep device allocator for TRT. - std::unique_ptr allocator_; + std::unique_ptr allocator_; // serialized protobuf segment or trt engine depending on static_engine_ flag. string serialized_segment_; diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc index 383635f428812984915a8c46ad3b92cc7b28a5f7..e0c7b6272379a20e3dacb6cd7c3b39de735d844d 100644 --- a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc @@ -42,8 +42,14 @@ REGISTER_OP("TRTEngineOp") .Attr("precision_mode: {'FP32', 'FP16', 'INT8', 'INT8CALIB'}") .Attr("calibration_data: string = ''") .Input("in_tensor: InT") - .Output("out_tensor: OutT") - .SetShapeFn(shape_inference::TRTEngineOpShapeInference); + .Output("out_tensor: OutT"); +// TODO(jie): TF requires concrete output shape for concrete input shapes. +// This is tricky for batch dimension, since we cannot ensure which input +// would carry the correct batch dimension (for the current stage of the +// implementation, we do require all input tensor to carry the same batch +// size, but this could change in the future). Hence we disable shape +// inference function as a workaround. +// .SetShapeFn(shape_inference::TRTEngineOpShapeInference); } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc index 2bc591484dcaf5b35c39f3d0523dd89dcd152e6a..cccc91226265ed139fb8db0b71c40b868f729562 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc @@ -65,9 +65,6 @@ bool PluginFactoryTensorRT::RegisterPlugin( void PluginFactoryTensorRT::DestroyPlugins() { tensorflow::mutex_lock lock(instance_m_); - for (auto& owned_plugin_ptr : owned_plugins_) { - owned_plugin_ptr.release(); - } owned_plugins_.clear(); } diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py index 0b2321b5fc7bcbd53c01d1c97cafcfcb229a83ef..7cdfe2b1a612be2eec473d806d0eb44b611ca68a 100644 --- a/tensorflow/contrib/tensorrt/python/__init__.py +++ b/tensorflow/contrib/tensorrt/python/__init__.py @@ -20,6 +20,11 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long from tensorflow.contrib.tensorrt.python.ops import trt_engine_op +from tensorflow.contrib.tensorrt.python.trt_convert import add_test_value from tensorflow.contrib.tensorrt.python.trt_convert import calib_graph_to_infer_graph +from tensorflow.contrib.tensorrt.python.trt_convert import clear_test_values from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph +from tensorflow.contrib.tensorrt.python.trt_convert import enable_test_value +from tensorflow.contrib.tensorrt.python.trt_convert import get_test_value +from tensorflow.contrib.tensorrt.python.trt_convert import is_tensorrt_enabled # pylint: enable=unused-import,line-too-long diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index 79f512dbcf6bd4d84b98cf69630778734566391c..4116f2fe30aa5c0c9ea139100291abe3b13da94b 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -20,25 +20,26 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long import six as _six +from tensorflow.contrib.tensorrt.wrap_conversion import add_test_value from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert +from tensorflow.contrib.tensorrt.wrap_conversion import clear_test_values +from tensorflow.contrib.tensorrt.wrap_conversion import enable_test_value from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_version from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version -from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert +from tensorflow.contrib.tensorrt.wrap_conversion import get_test_value +from tensorflow.contrib.tensorrt.wrap_conversion import is_tensorrt_enabled from tensorflow.core.framework import graph_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 -from tensorflow.python.framework import errors from tensorflow.python.framework import errors_impl as _impl -from tensorflow.python.framework import meta_graph +from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.grappler import tf_optimizer from tensorflow.python.platform import tf_logging -from tensorflow.python.util import compat - +from tensorflow.python.training import saver # pylint: enable=unused-import,line-too-long -# TODO(skama): get outputs from session when implemented as c++ -# optimization pass def create_inference_graph(input_graph_def, outputs, max_batch_size=1, @@ -47,7 +48,7 @@ def create_inference_graph(input_graph_def, minimum_segment_size=3, is_dynamic_op=False, maximum_cached_engines=1, - cached_engine_batches=[]): + cached_engine_batches=None): """Python wrapper for the TRT transformation. Args: @@ -86,8 +87,7 @@ def create_inference_graph(input_graph_def, (".".join([str(x) for x in compiled_version]), ".".join([str(x) for x in loaded_version])) + ". Please make sure that correct version of TensorRT " + - "is available in the system and added to ldconfig or LD_LIBRARY_PATH" - ) + "is available in the system and added to ldconfig or LD_LIBRARY_PATH") raise RuntimeError("Incompatible TensorRT library version") for i in zip(loaded_version, compiled_version): if i[0] != i[1]: @@ -120,41 +120,42 @@ def create_inference_graph(input_graph_def, to_bytes = py3bytes to_string = py3string - out_names = [] - for i in outputs: - if isinstance(i, ops.Tensor): - out_names.append(to_bytes(i.name)) - else: - out_names.append(to_bytes(i)) - - input_graph_def_str = input_graph_def.SerializeToString() - - # TODO(sami): Fix this when we can return status from C++ library - # There is a problem with the TF internal library setup that doesn't - # allow us to return a status object from C++. Thus we return a - # pair or strings where first one is encoded status and the second - # one is the transformed graphs protobuf string. - out = trt_convert(input_graph_def_str, out_names, max_batch_size, - max_workspace_size_bytes, mode, minimum_segment_size, - is_dynamic_op, maximum_cached_engines, - cached_engine_batches) - status = to_string(out[0]) - output_graph_def_string = out[1] - del input_graph_def_str # Save some memory - if len(status) < 2: - raise _impl.UnknownError(None, None, status) - if status[:2] != "OK": - msg = status.split(";") - if len(msg) == 1: - raise RuntimeError("Status message is malformed {}".format(status)) - # pylint: disable=protected-access - raise _impl._make_specific_exception(None, None, ";".join(msg[1:]), - int(msg[0])) - # pylint: enable=protected-access - output_graph_def = graph_pb2.GraphDef() - output_graph_def.ParseFromString(output_graph_def_string) - del output_graph_def_string # Save some memory - return output_graph_def + # Create MetaGraphDef + graph = ops.Graph() + with graph.as_default(): + importer.import_graph_def(input_graph_def, name="") + meta_graph = saver.export_meta_graph( + graph_def=graph.as_graph_def(), graph=graph) + if outputs: + output_collection = meta_graph_pb2.CollectionDef() + output_list = output_collection.node_list.value + for i in outputs: + if isinstance(i, ops.Tensor): + output_list.append(to_bytes(i.name)) + else: + output_list.append(to_bytes(i)) + meta_graph.collection_def["train_op"].CopyFrom(output_collection) + + # Create RewriterConfig. + rewriter_cfg = rewriter_config_pb2.RewriterConfig() + rewriter_cfg.optimizers.extend(["constfold", "layout"]) + optimizer = rewriter_cfg.custom_optimizers.add() + optimizer.name = "TensorRTOptimizer" + optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size + optimizer.parameter_map["max_batch_size"].i = max_batch_size + optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op + optimizer.parameter_map[ + "max_workspace_size_bytes"].i = max_workspace_size_bytes + optimizer.parameter_map["precision_mode"].s = to_bytes(precision_mode) + optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines + if cached_engine_batches: + if not isinstance(cached_engine_batches, list): + raise TypeError("cached_engine_batches should be a list.") + optimizer.parameter_map["cached_engine_batches"].list.i.extend( + cached_engine_batches) + + return tf_optimizer.OptimizeGraph( + rewriter_cfg, meta_graph, graph_id=b"tf_graph") def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False): diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc index 9f115990c3a3e6e92093e5f0d82b985af1b25482..d8f97bfbbc7adb10a5dda6fbc2f7a660f6cd7742 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc +++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc @@ -19,12 +19,42 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT +#include "cuda/include/cuda_runtime_api.h" +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +namespace tensorflow { +namespace tensorrt { + +// std::align is not supported, so this method mimic its behavior. +void* Align(size_t alignment, size_t size, void*& ptr, size_t& space) { + QCHECK_GT(alignment, 0) << "alignment must be greater than 0."; + QCHECK_EQ(0, alignment & (alignment - 1)) << "Alignment must be power of 2."; + QCHECK_GT(size, 0) << "size must be greater than 0."; + QCHECK(ptr) << "ptr must not be nullptr."; + QCHECK_GT(space, 0) << "space must be greater than 0."; + const uintptr_t ptr_val = reinterpret_cast(ptr); + QCHECK_GE(ptr_val + space, ptr_val) << "Provided space overflows."; + if (size > space) return nullptr; + const uintptr_t aligned_ptr_val = ((ptr_val + alignment - 1) & -alignment); + if (aligned_ptr_val > ptr_val + space - size) return nullptr; + ptr = reinterpret_cast(aligned_ptr_val); + const uintptr_t diff = aligned_ptr_val - ptr_val; + space -= diff; + return ptr; +} + +} // namespace tensorrt +} // namespace tensorflow + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT #if NV_TENSORRT_MAJOR > 2 -#include "cuda/include/cuda_runtime_api.h" namespace tensorflow { namespace tensorrt { + void* TRTCudaAllocator::allocate(uint64_t size, uint64_t alignment, uint32_t flags) { assert((alignment & (alignment - 1)) == 0); // zero or a power of 2. @@ -37,10 +67,23 @@ void TRTCudaAllocator::free(void* memory) { cudaFree(memory); } void* TRTDeviceAllocator::allocate(uint64_t size, uint64_t alignment, uint32_t flags) { + // WAR for allocator alignment requirement. Certain cuda API calls require GPU + // memory with alignemtn to cudaDeviceProp::textureAlignment. + // See issue #20856 + alignment = 512; assert((alignment & (alignment - 1)) == 0); // zero or a power of 2. - void* mem = allocator_->AllocateRaw(alignment, size); - VLOG(2) << "Allocated " << size << " bytes with alignment " << alignment - << " @ " << mem; + size_t total_size = size + alignment; + void* mem = allocator_->AllocateRaw(alignment, total_size); + if (!mem) return nullptr; + + void* alloc_mem = mem; + QCHECK(Align(alignment, size, mem, total_size)); + if (mem != alloc_mem) { + QCHECK(mem_map_.insert({mem, alloc_mem}).second); + } + VLOG(2) << "Allocated " << total_size << " bytes memory @" << alloc_mem + << "; aligned to " << size << " bytes @" << mem << " with alignment " + << alignment; return mem; } @@ -51,12 +94,20 @@ TRTDeviceAllocator::TRTDeviceAllocator(tensorflow::Allocator* allocator) void TRTDeviceAllocator::free(void* memory) { VLOG(2) << "Deallocating @ " << memory; - allocator_->DeallocateRaw(memory); + // allocated memory adjusted for alignment, restore the original pointer + if (memory) { + auto alloc_mem = mem_map_.find(memory); + if (alloc_mem != mem_map_.end()) { + memory = alloc_mem->second; + mem_map_.erase(alloc_mem->first); + } + allocator_->DeallocateRaw(memory); + } } } // namespace tensorrt } // namespace tensorflow #endif -#endif -#endif +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.h b/tensorflow/contrib/tensorrt/resources/trt_allocator.h index c5d2cec730f4ae97e4c6bcc19897fd9f321122a7..6f944920835b475fc7d12167dbcefa0111b6fb19 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.h +++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.h @@ -16,13 +16,25 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ #define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ -#include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include + #include "tensorflow/core/framework/allocator.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT #include "tensorrt/include/NvInfer.h" +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +namespace tensorflow { +namespace tensorrt { +// std::align is not supported, so this function mimic its behavior. +void* Align(size_t alignment, size_t size, void*& ptr, size_t& space); +} // namespace tensorrt +} // namespace tensorflow +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT #if NV_TENSORRT_MAJOR == 3 // Define interface here temporarily until TRT 4.0 is released namespace nvinfer1 { @@ -37,7 +49,14 @@ class IGpuAllocator { namespace tensorflow { namespace tensorrt { -class TRTCudaAllocator : public nvinfer1::IGpuAllocator { +class TRTBaseAllocator : public nvinfer1::IGpuAllocator { + // Base allocator class so we can have a virtual destructor; + public: + // python wrapper seems to be not happy with an pure virtual destructor; + virtual ~TRTBaseAllocator() = default; +}; + +class TRTCudaAllocator : public TRTBaseAllocator { // Allocator implementation that is using cuda allocator instead of device // allocator in case we can't get device allocator from TF. public: @@ -47,10 +66,13 @@ class TRTCudaAllocator : public nvinfer1::IGpuAllocator { void free(void* memory) override; }; -class TRTDeviceAllocator : public nvinfer1::IGpuAllocator { +class TRTDeviceAllocator : public TRTBaseAllocator { // Allocator implementation wrapping TF device allocators. public: TRTDeviceAllocator(tensorflow::Allocator* allocator); + + // TODO(aaroey): base class doesn't have a virtual destructor, work with + // Nvidia to fix it. virtual ~TRTDeviceAllocator() { VLOG(1) << "Destroying allocator attached to " << allocator_->Name(); } @@ -59,6 +81,9 @@ class TRTDeviceAllocator : public nvinfer1::IGpuAllocator { private: tensorflow::Allocator* allocator_; + + // supporting alignment from allocation request requires a map to free; + std::unordered_map mem_map_; }; } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc b/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f515ed03f245f11ad461bac07970c5001a56aaad --- /dev/null +++ b/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" + +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace tensorrt { + +bool RunTest(const size_t alignment, const size_t size, + const intptr_t orig_ptr_val, const size_t orig_space) { + void* const orig_ptr = reinterpret_cast(orig_ptr_val); + void* ptr = orig_ptr; + size_t space = orig_space; + void* result = Align(alignment, size, ptr, space); + if (result == nullptr) { + EXPECT_EQ(orig_ptr, ptr); + EXPECT_EQ(orig_space, space); + return false; + } else { + EXPECT_EQ(result, ptr); + const intptr_t ptr_val = reinterpret_cast(ptr); + EXPECT_EQ(0, ptr_val % alignment); + EXPECT_GE(ptr_val, orig_ptr_val); + EXPECT_GE(space, size); + EXPECT_LE(space, orig_space); + EXPECT_EQ(ptr_val + space, orig_ptr_val + orig_space); + return true; + } +} + +TEST(TRTAllocatorTest, Align) { + for (const size_t space : + {1, 2, 3, 4, 7, 8, 9, 10, 16, 32, 511, 512, 513, 700, 12345}) { + for (size_t alignment = 1; alignment <= space * 4; alignment *= 2) { + for (const intptr_t ptr_val : + {1ul, alignment == 1 ? 1ul : alignment - 1, alignment, alignment + 1, + alignment + (alignment / 2)}) { + if (ptr_val % alignment == 0) { + for (const size_t size : + {1ul, space == 1 ? 1ul : space - 1, space, space + 1}) { + EXPECT_EQ(space >= size, RunTest(alignment, size, ptr_val, space)); + } + } else { + EXPECT_FALSE(RunTest(alignment, space, ptr_val, space)); + const size_t diff = alignment - ptr_val % alignment; + if (space > diff) { + EXPECT_TRUE( + RunTest(alignment, space - diff, ptr_val + diff, space - diff)); + for (const size_t size : + {1ul, space - diff > 1 ? space - diff - 1 : 1ul, space - diff, + space - diff + 1, space - 1}) { + EXPECT_EQ(space - diff >= size, + RunTest(alignment, size, ptr_val, space)); + } + } else { + EXPECT_FALSE(RunTest(alignment, 1, ptr_val, space)); + } + } + } + } + } +} + +} // namespace tensorrt +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc index 32e81858b95d76a2baebb4804a1326fbbb6144c7..dab1dd9343be7d5b033a3e04bf0b49fbbf37e9e5 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc @@ -36,13 +36,14 @@ TRTInt8Calibrator::TRTInt8Calibrator( : batch_size_(batch_size), done_(false), dev_buffers_(dev_buffers), + // Make sure setBatch() waits until getBatch() is called (the first time). calib_running_(true), batch_is_set_(false), engine_name_(engine_name) {} TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data) : batch_size_(0), - done_(false), + done_(true), calib_running_(false), batch_is_set_(false), calibration_table_(calib_data) {} @@ -50,13 +51,14 @@ TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data) bool TRTInt8Calibrator::setBatch(const std::unordered_map& data, const cudaStream_t stream) { tensorflow::mutex_lock lock(cond_mtx_); - // wait while calibration is running. - while ((calib_running_ || batch_is_set_) && !done_) { - cond_.wait(lock); - } + + // Wait while the queue is full or calibration is running. + while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lock); if (done_) return false; CHECK(!calib_running_ && !batch_is_set_); VLOG(1) << "Set Batch Waiting finished"; + + // Sets the batch. for (const auto it : data) { auto devptr = dev_buffers_.find(it.first); if (devptr == dev_buffers_.end()) { @@ -76,8 +78,8 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map& data, } // TODO(Sami, aaorey): Find an alternative way! - cudaStreamSynchronize( - stream); // we have to wait for the stream before returning! + // we have to wait for the stream before returning! + cudaStreamSynchronize(stream); batch_is_set_ = true; cond_.notify_all(); return true; @@ -86,21 +88,21 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map& data, bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, int num_bindings) { tensorflow::mutex_lock lock(cond_mtx_); + // Notify finish of last round of calibration. calib_running_ = false; cond_.notify_all(); - // wait until new batch arrives - while ((!batch_is_set_ && !done_)) { - cond_.wait(lock); - } + + // Wait until new batch arrives + while ((!batch_is_set_ && !done_)) cond_.wait(lock); if (done_) return false; + // Gets the batch for (int i = 0; i < num_bindings; i++) { auto it = dev_buffers_.find(names[i]); if (it == dev_buffers_.end()) { LOG(FATAL) << "Calibration engine asked for unknown tensor name '" << names[i] << "' at position " << i; } - bindings[i] = it->second.first; } batch_is_set_ = false; @@ -108,6 +110,17 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, return true; } +void TRTInt8Calibrator::waitAndSetDone() { + tensorflow::mutex_lock lock(cond_mtx_); + // Wait while the queue is full or calibration is running, so we don't miss + // the last batch. + while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lock); + if (!done_) { + done_ = true; + cond_.notify_all(); + } +} + const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) { if (calibration_table_.empty()) return nullptr; length = calibration_table_.size(); diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h index 994312d7c3c93ba04394b7d9542d261c57c5609b..65466c9741989fda5f82fc27d813d026f35fe386 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h @@ -36,10 +36,13 @@ namespace tensorrt { struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { public: + // Construct a calibrator for future calibration. TRTInt8Calibrator( const std::unordered_map>& dev_buffers, int batch_size, string engine_name); + // Construct a finalized calibrator where we don't need to run calibration any + // more, as the calibration data is provided. TRTInt8Calibrator(const string& calibration_data); ~TRTInt8Calibrator(); @@ -52,6 +55,11 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { bool setBatch(const std::unordered_map& data, const cudaStream_t stream); + // Wait until the last batch is consumed by the calibrator and set done. + void waitAndSetDone(); + + // Notify that calibration is done and future batches provided by setBatch() + // will be ignored. void setDone(); // If not null, calibration is skipped. diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h index b7d5ffd6748ba34c6c4ddbfbfbb44edb6bf2aca8..d7d56cb95e033ea55bd3aa385a707e7a7cfc557b 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resources.h +++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h @@ -64,7 +64,7 @@ class TRTCalibrationResource : public tensorflow::ResourceBase { std::unique_ptr calibrator_; TrtUniquePtrType builder_; TrtUniquePtrType engine_; - std::unique_ptr allocator_; + std::unique_ptr allocator_; tensorflow::tensorrt::Logger logger_; // TODO(sami): Use threadpool threads! std::unique_ptr thr_; diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc index cc42913ecadc3e15fbb4a4a322f125579f075da2..b43f1b190f5f8cfe98959dd9f2838e4d45759e5c 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/contrib/tensorrt/segment/segment.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/segment/segment.h" +#include #include #include #include @@ -32,6 +33,7 @@ namespace tensorflow { namespace tensorrt { namespace segment { using ::tensorflow::strings::StrAppend; + // A simple graph representation to mirror tensorflow::Graph. This structure // helps saving memory since segmenter modifies the graph in place, preventing // the need to create a copy of the graph. It is composed of edges and nodes. @@ -215,7 +217,7 @@ namespace { bool CheckCycles(const std::unique_ptr& g, const SimpleNode* src, const std::vector& start) { - // copied from TF ReverseDFS. + // Copied from TF ReverseDFS, which only works for tensorflow::Graph. struct Work { SimpleNode* node; bool leave; // Are we entering or leaving n? @@ -269,6 +271,24 @@ bool CanContractEdge(const SimpleEdge* edge, // 1. Get all nodes incoming to 'dst', excluding 'src' // 2. Reverse DFS from those nodes // 3. If reverse DFS reaches 'src' then we have a cycle + // + // TODO(aaroey): there are several problems with the current approach: + // 1. src->dst->src, this is not detected but it should be; + // 2. src->dst->...(any node sequence that doesn't contain src)...->dst, this + // is detected but it should not be. + // + // Note that it's fine that dst connects back to src indirectly (i.e. through + // a path with length > 1 that consists of intermedia nodes other than src). + // While loops is one example. + // + // The goal is to make sure that the trt subgraph: + // 1. has no loops (i.e. is a DAG), and + // 2. if there is a path in the subgraph from X to Y (X and Y are both nodes + // in the subgraph), then all paths from X to Y are in the subgraph. + // + // To achieve this goal, the correct way seems to be: + // 1. remove any direct edge from src->dst; + // 2. detect if src can reach dst, if so they cannot be merged. std::vector dfs_start_nodes; for (SimpleNode* node : dst->in_nodes()) { if (node != src) { @@ -276,8 +296,8 @@ bool CanContractEdge(const SimpleEdge* edge, } } - bool is_cycle = CheckCycles(graph, src, dfs_start_nodes); - return !is_cycle; + const bool has_cycle = CheckCycles(graph, src, dfs_start_nodes); + return !has_cycle; } } // namespace @@ -342,22 +362,20 @@ void ContractEdge(SimpleEdge* edge, SimpleGraph* graph, } tensorflow::Status SegmentGraph( - const tensorflow::GraphDef& gdef, - const std::function& candidate_fn, - const SegmentOptions& options, SegmentNodesVector* segments) { - // Create a Graph representation of the GraphDef. - tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(), - gdef.library()); - tensorflow::Graph graph(flib); - TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph( - tensorflow::GraphConstructorOptions(), gdef, &graph)); - return SegmentGraph(&graph, candidate_fn, options, segments); -} - -tensorflow::Status SegmentGraph( - tensorflow::Graph* tf_graph, + const tensorflow::Graph* tf_graph, const std::function& candidate_fn, + const std::function& input_candidate_fn, + const std::function& output_candidate_fn, const SegmentOptions& options, SegmentNodesVector* segments) { + // Steps: + // 1. run the segmentation algorithm to find all the segments, which uses + // candidate_fn to determine the candidates segment nodes; + // 2. for each segments, remove the nodes that are inputs/outputs of the + // segment but are not eligible, using input/output_candidate_fn to + // determine the eligibilities; + // 3. convert the segment into expected return format and return the result. + + // --------------------------------- Step 1 --------------------------------- auto graph = std::unique_ptr(new SimpleGraph(tf_graph)); // Use a union-find to collect the nodes that belong to the same // segment. A node value of nullptr indicates that the node is not a candidate @@ -372,14 +390,19 @@ tensorflow::Status SegmentGraph( node_segments.emplace_back(node); } - // The segmentation algorithm below visits nodes in reverse - // topological order and attempts to merge nodes along output - // edges. That means that subgraphs grow from the output-side of the - // network towards the inputs. In general this is not guaranteed to - // produce a globally optimal segmentation. In the future if we have - // a measure of how beneficial it is to include a given node in a - // TRT subgraph then we can revisit this algorithm to take advantage - // of that information. + // The segmentation algorithm below visits nodes in reverse topological order + // and attempts to merge nodes along output edges. That means that subgraphs + // grow from the output-side of the network towards the inputs. + // + // In general this is not guaranteed to produce a globally optimal + // segmentation. For exaample, consider graph with node {A, B, C, D} and edges + // {A->B, A->C, B->D, C->D), where A, B, D are trt compatible but C is not, so + // in theory we can choose to contract either A, B or B, D but not both, but + // here it always choose to contract B, D. + // + // In the future if we have a measure of how beneficial it is to include a + // given node in a TRT subgraph then we can revisit this algorithm to take + // advantage of that information. std::vector tforder; tensorflow::GetPostOrder(*tf_graph, &tforder); // use postorder implementation from tensorflow and construct mirror in @@ -391,44 +414,39 @@ tensorflow::Status SegmentGraph( } for (const SimpleNode* node : order) { // All output nodes of 'node' have been visited... - VLOG(2) << "Trying node " << node->name() << " id=" << node->id(); - + VLOG(3) << "Trying node " << node->name() << " id=" << node->id(); // 'node' must be a TRT candidate... if (node_segments[node->id()].Value() == nullptr) { - VLOG(2) << "... not a TRT candidate"; + VLOG(3) << "... not a TRT candidate"; continue; } - // Contract output edges to combine 'node' with output // nodes. Iterate since combining two nodes may unblock other // combining. while (true) { std::set contract_edges; for (const SimpleEdge* out_edge : node->out_edges()) { - VLOG(2) << "... out node " << out_edge->dst()->name() << " ( " + VLOG(3) << "... out node " << out_edge->dst()->name() << " ( " << out_edge->dst()->id() << " <- " << node->id() << " )"; if (out_edge->IsControlEdge()) { - VLOG(2) << "... ... Control Edge, Skipping"; + VLOG(3) << "... ... Control Edge, Skipping"; continue; } // Out node must be TRT candidate... if (node_segments[out_edge->dst()->id()].Value() == nullptr) { - VLOG(2) << "... ... not a TRT candidate"; + VLOG(3) << "... ... not a TRT candidate"; continue; } - if (CanContractEdge(out_edge, graph)) { - VLOG(2) << "... ... can contract"; + VLOG(3) << "... ... can contract"; contract_edges.insert(out_edge); } else { - VLOG(2) << "... ... cannot contract, would form cycle"; + VLOG(3) << "... ... cannot contract, would form cycle"; } } - if (contract_edges.empty()) { break; } - // Contract edges and collect the adjacent nodes into the same // segment/subgraph. while (!contract_edges.empty()) { @@ -436,7 +454,7 @@ tensorflow::Status SegmentGraph( const SimpleNode* src = contract_edge->src(); const SimpleNode* dst = contract_edge->dst(); - VLOG(2) << "Merge " << src->name() << " <- " << dst->name() << " (" + VLOG(3) << "Merge " << src->name() << " <- " << dst->name() << " (" << src->id() << " <- " << dst->id(); node_segments[src->id()].Merge(&node_segments[dst->id()]); @@ -457,11 +475,22 @@ tensorflow::Status SegmentGraph( // Collect the segments/subgraphs. Each subgraph is represented by a // set of the names of the nodes in that subgraph. - std::unordered_map> sg_map; + + // A map from the segment identifier (currently the name of the root node of + // the segment tree) to the segment nodes set. + std::map> sg_map; + + // A map from the segment identifier (currently the name of the root node of + // the segment tree) to the device names that the nodes in the segment are + // assigned to. + // + // TODO(aaroey): nodes assigned to different devices should not be merged, + // fix this. std::unordered_map> device_maps; + for (auto& u : node_segments) { if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) { - sg_map[u.ParentValue()->name()].insert(u.Value()->name()); + sg_map[u.ParentValue()->name()].insert(u.Value()->tf_node()); auto tf_node = u.Value()->tf_node(); // has_assigned_device_name() is expected to return true // when called from optimization pass. However, since graph @@ -482,25 +511,113 @@ tensorflow::Status SegmentGraph( } } + // --------------------------------- Step 2 --------------------------------- + // Remove ineligible input/output nodes. + for (auto& itr : sg_map) { + std::set& segment_nodes = itr.second; + VLOG(1) << "Segment original size: " << segment_nodes.size(); + while (true) { + std::deque in_nodes_que, out_nodes_que; + // Find an input node that is not eligible and add it to the queue. + // Nodes that has no incoming edges should not be treated as "input", + // as there are really no inputs to them. Similar for output nodes. + for (auto node : segment_nodes) { + bool added = false; + for (const tensorflow::Edge* edge : node->in_edges()) { + if (!edge->IsControlEdge() && !edge->src()->IsSource() && + !segment_nodes.count(edge->src())) { // 'node' is an input node. + if (!input_candidate_fn(edge)) { + in_nodes_que.push_back(node); + added = true; + break; + } + } + } + if (added) continue; // Only adding the node once to either queue. + for (const tensorflow::Edge* edge : node->out_edges()) { + if (!edge->dst()->IsSink() && !edge->IsControlEdge() && + !segment_nodes.count(edge->dst())) { // 'node' is an output node. + if (!output_candidate_fn(edge)) { + out_nodes_que.push_back(node); + break; + } + } + } + } + if (in_nodes_que.empty() && out_nodes_que.empty()) { + // No more ineligible input/output nodes. + break; + } + // Now for each ineligible node, remove all of its inputs or outputs from + // the subgraph. + // + // It can be proven that, if the original subgraph: + // 1. is a DAG, and + // 2. all paths between two nodes in the subgraph are all inside the + // subgraph + // then after doing this operation the resulting subgraph will keep the + // same properties 1 and 2. + // + // For simplicity we use heuristics: for input and const output nodes + // remove all their inputs, and for non-const output nodes remove all + // their outputs. In this way, for common cases the number of removed + // nodes should be minimum. + auto remove_nodes = [&segment_nodes]( + bool is_input_nodes, + std::deque* que) { + // Run a BFS on the queue to find all the input/output nodes. + std::set visited; + std::set logged(que->begin(), que->end()); + while (!que->empty()) { + auto node = que->front(); + que->pop_front(); + if (!visited.insert(node).second) continue; + segment_nodes.erase(node); + for (auto in : (is_input_nodes || node->type_string() == "Const") + ? node->in_nodes() + : node->out_nodes()) { + if (segment_nodes.count(in)) { + que->push_back(in); + if (VLOG_IS_ON(2)) { + if (!logged.count(in)) { + VLOG(2) << "----> Need to remove node " << in->name() + << " because one of its " + << (is_input_nodes ? "output" : "input") + << " nodes in the graph was removed: " + << node->name(); + logged.insert(in); + } + } + } + } + } + }; + remove_nodes(true, &in_nodes_que); + remove_nodes(false, &out_nodes_que); + } + VLOG(1) << "Segment new size: " << segment_nodes.size(); + } + + // --------------------------------- Step 3 --------------------------------- // Convert the segments into the expected return format for (const auto& itr : sg_map) { - const auto& segment_node_names = itr.second; + const std::set& segment_nodes = itr.second; if (VLOG_IS_ON(1)) { - string s; - for (const auto& name : segment_node_names) { - s += " " + name; - } - VLOG(1) << "Segment " << segments->size() << ":" << s; + string s = "parent=" + itr.first + ":"; + for (auto node : segment_nodes) s += " " + node->name(); + VLOG(1) << "Segment " << segments->size() << ": " << s; } // Don't use small segments. - if (static_cast(segment_node_names.size()) < - options.minimum_segment_size) { + if (static_cast(segment_nodes.size()) < options.minimum_segment_size) { VLOG(1) << "Segment " << segments->size() << " has only " - << segment_node_names.size() << " nodes, dropping"; + << segment_nodes.size() << " nodes, dropping"; continue; } + // TODO(sami): Make segmenter placement aware once trtscopes are in place + std::set segment_node_names; + for (auto node : itr.second) segment_node_names.insert(node->name()); const auto& dev_itr = device_maps.find(itr.first); if (dev_itr == device_maps.end() || dev_itr->second.empty()) { VLOG(1) << "No device assigned to segment " << segments->size(); diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h index 81b4bfe49fe375d19f4c7811459f38e25d2edea8..8c44eb782aa37052680d0e06023f29dc65e327c6 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.h +++ b/tensorflow/contrib/tensorrt/segment/segment.h @@ -40,22 +40,6 @@ struct SegmentOptions { std::set exclude_node_list; }; -// Get the subgraphs of a graph that can be handled by TensorRT. -// -// @param gdef The GraphDef describing the network -// @param candidate_fn A function that returns true for a NodeDef if -// that node can be handled by TensorRT. -// @param segments Returns the TensorRT segments/subgraphs. Each entry -// in the vector describes a subgraph by giving a set of the names of -// all the NodeDefs in that subgraph. -// @return the status. -// -// TODO(aaroey): remove this method. -tensorflow::Status SegmentGraph( - const tensorflow::GraphDef& gdef, - const std::function& candidate_fn, - const SegmentOptions& options, SegmentNodesVector* segments); - // Get the subgraphs of a graph that can be handled by TensorRT. // // @param graph tensorflow::Graph of the network @@ -66,8 +50,10 @@ tensorflow::Status SegmentGraph( // all the NodeDefs in that subgraph. // @return the status. tensorflow::Status SegmentGraph( - tensorflow::Graph* tf_graph, + const tensorflow::Graph* tf_graph, const std::function& candidate_fn, + const std::function& input_candidate_fn, + const std::function& output_candidate_fn, const SegmentOptions& options, SegmentNodesVector* segments); } // namespace segment diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc index f5b2d258d70d5577a9d68f2d9f6d6e678ede97ce..5937fa8259a39339e92b150862d195ee1f23f70a 100644 --- a/tensorflow/contrib/tensorrt/segment/segment_test.cc +++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc @@ -14,350 +14,245 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/segment/segment.h" -#include "tensorflow/c/c_api.h" -#include "tensorflow/core/framework/graph.pb.h" + +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/graph/testlib.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session.h" namespace tensorflow { namespace tensorrt { namespace segment { namespace test { +namespace ops = ::tensorflow::ops; class SegmentTest : public ::testing::Test { - public: - bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def); - - TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name); - TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, - TF_Status* s, const char* name); - - std::function MakeCandidateFn( - const std::set& node_names); - protected: - void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name, - TF_Operation** op); - void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, - TF_Status* s, const char* name, TF_Operation** op, bool check); - - SegmentOptions default_options_; -}; - -bool SegmentTest::GetGraphDef(TF_Graph* graph, - tensorflow::GraphDef* graph_def) { - TF_Status* s = TF_NewStatus(); - TF_Buffer* buffer = TF_NewBuffer(); - TF_GraphToGraphDef(graph, buffer, s); - bool ret = TF_GetCode(s) == TF_OK; - EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length); - TF_DeleteBuffer(buffer); - TF_DeleteStatus(s); - return ret; -} + std::function MakeCandidateFn( + const std::set& node_names) { + return [node_names](const tensorflow::Node* node) -> bool { + return node_names.find(node->name()) != node_names.end(); + }; + } -std::function SegmentTest::MakeCandidateFn( - const std::set& node_names) { - return [node_names](const tensorflow::Node* node) -> bool { - return node_names.find(node->name()) != node_names.end(); - }; -} + std::function MakeInputEdgeCandidateFn( + const std::set& node_names) { + return [node_names](const tensorflow::Edge* in_edge) -> bool { + return node_names.find(in_edge->dst()->name()) != node_names.end(); + }; + } -void SegmentTest::PlaceholderHelper(TF_Graph* graph, TF_Status* s, - const char* name, TF_Operation** op) { - TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name); - TF_SetAttrType(desc, "dtype", TF_INT32); - *op = TF_FinishOperation(desc, s); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - ASSERT_NE(*op, nullptr); -} + std::function MakeOutputEdgeCandidateFn( + const std::set& node_names) { + return [node_names](const tensorflow::Edge* out_edge) -> bool { + return node_names.find(out_edge->src()->name()) != node_names.end(); + }; + } -TF_Operation* SegmentTest::Placeholder(TF_Graph* graph, TF_Status* s, - const char* name) { - TF_Operation* op; - PlaceholderHelper(graph, s, name, &op); - return op; -} + void RunTest(const tensorflow::Graph* graph, + const std::set& candidates, + const std::set& input_candidates, + const std::set& output_candidates, + const std::vector>& expected_segments) { + SegmentNodesVector segments; + TF_EXPECT_OK(SegmentGraph(graph, MakeCandidateFn(candidates), + MakeInputEdgeCandidateFn(input_candidates), + MakeOutputEdgeCandidateFn(output_candidates), + default_options_, &segments)); + ValidateSegment(segments, expected_segments); + } -void SegmentTest::AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, - TF_Status* s, const char* name, TF_Operation** op, - bool check) { - TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); - TF_Output add_inputs[2] = {{l, 0}, {r, 0}}; - TF_AddInputList(desc, add_inputs, 2); - *op = TF_FinishOperation(desc, s); - if (check) { - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - ASSERT_NE(*op, nullptr); + void ValidateSegment(const SegmentNodesVector& segments, + const std::vector>& expected_segments) { + EXPECT_EQ(expected_segments.size(), segments.size()); + for (int i = 0; i < segments.size(); ++i) { + const auto& segment_node_names = segments[i].first; + const auto& expected = expected_segments[i]; + for (const auto& name : expected) { + EXPECT_TRUE(segment_node_names.count(name)) + << "Segment " << i << " is missing expected node: " << name; + } + if (segment_node_names.size() == expected.size()) continue; + for (const auto& name : segment_node_names) { + EXPECT_TRUE(expected.count(name)) + << "Unexpected node found in segment " << i << ": " << name; + } + } } -} -TF_Operation* SegmentTest::Add(TF_Operation* l, TF_Operation* r, - TF_Graph* graph, TF_Status* s, - const char* name) { - TF_Operation* op; - AddHelper(l, r, graph, s, name, &op, true); - return op; + SegmentOptions default_options_; +}; + +std::set operator-(const std::set& lhs, const string& rhs) { + std::set result = lhs; + CHECK(result.erase(rhs)); + return result; } TEST_F(SegmentTest, Empty) { - TF_Graph* graph = TF_NewGraph(); - - GraphDef graph_def; - ASSERT_TRUE(GetGraphDef(graph, &graph_def)); - - SegmentNodesVector segments; - ASSERT_EQ( - SegmentGraph(graph_def, MakeCandidateFn({}), default_options_, &segments), - tensorflow::Status::OK()); - + Scope s = Scope::NewRootScope(); + tensorflow::Graph g(OpRegistry::Global()); + TF_EXPECT_OK(s.ToGraph(&g)); // Expect no segments/subgraphs. - EXPECT_TRUE(segments.empty()); - TF_DeleteGraph(graph); + RunTest(&g, {}, {}, {}, {}); } TEST_F(SegmentTest, Simple) { - TF_Status* s = TF_NewStatus(); - TF_Graph* graph = TF_NewGraph(); - // feed - // // || + // // \\ // add0 add1 - // | | / + // | \ / // | add2 - // | / || + // | / \\ // add3 add4 - // | / + // \ / // - // - TF_Operation* feed = Placeholder(graph, s, "feed"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); - - TF_Operation* add0 = Add(feed, feed, graph, s, "add0"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add1 = Add(feed, feed, graph, s, "add1"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add2 = Add(add0, add1, graph, s, "add2"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add3 = Add(add0, add2, graph, s, "add3"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("add3"), string(TF_OperationName(add3))); - TF_Operation* add4 = Add(add2, add2, graph, s, "add4"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("add4"), string(TF_OperationName(add4))); - - GraphDef graph_def; - ASSERT_TRUE(GetGraphDef(graph, &graph_def)); - - SegmentNodesVector segments; - ASSERT_EQ( - SegmentGraph(graph_def, - MakeCandidateFn({"add0", "add1", "add2", "add3", "add4"}), - default_options_, &segments), - tensorflow::Status::OK()); - - // Expect all Add operations to be collapsed into a single segment - ASSERT_EQ(segments.size(), 1); - std::vector expected{"add0", "add1", "add2", "add3", "add4"}; - for (const auto& ex : expected) { - EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end()) - << "Missing expected node " << ex; - } - TF_DeleteGraph(graph); - TF_DeleteStatus(s); + Scope s = Scope::NewRootScope(); + auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT); + auto add0 = ops::Add(s.WithOpName("add0"), feed, feed); + auto add1 = ops::Add(s.WithOpName("add1"), feed, feed); + auto add2 = ops::Add(s.WithOpName("add2"), add0, add1); + auto add3 = ops::Add(s.WithOpName("add3"), add0, add2); + auto add4 = ops::Add(s.WithOpName("add4"), add2, add2); + tensorflow::Graph g(OpRegistry::Global()); + TF_EXPECT_OK(s.ToGraph(&g)); + + // All Add operations are candidates, and we expect all of them to be + // collapsed into a single segment + const std::set all_adds = {"add0", "add1", "add2", "add3", "add4"}; + RunTest(&g, all_adds, all_adds, all_adds, {all_adds}); + + // Make add1 not a candidate, and we expect all other Add operations to be + // collapsed into a single segment + auto without_add1 = all_adds - "add1"; + RunTest(&g, without_add1, without_add1, without_add1, {without_add1}); + + // Make add1 not a candidate and add2 not an input candidate, and we expect + // add0 and add2 are removed from the segment. + auto without_add2 = all_adds - "add2"; + RunTest(&g, without_add1, without_add2, without_add1, {{"add3", "add4"}}); + + // Making add2 not an input candidate itself won't affect anything. + RunTest(&g, all_adds, without_add2, all_adds, {all_adds}); + + // Making add1 not an input candidate. + RunTest(&g, all_adds, without_add1, all_adds, {without_add1}); + + // Making add3 not an output candidate doesn't affect anything, since it's + // output is sink. + auto without_add3 = all_adds - "add3"; + RunTest(&g, all_adds, all_adds, without_add3, {all_adds}); } TEST_F(SegmentTest, AvoidCycle) { - TF_Status* s = TF_NewStatus(); - TF_Graph* graph = TF_NewGraph(); - - // add2 is not a TRT candidate so add0/add3 cannot be formed as a - // subgraph - // // feed - // // || + // // \\ // add0 add1 - // | | / + // | \ / // | add2 - // | / || + // | / \\ // add3 add4 - // | / + // \ / // - // - TF_Operation* feed = Placeholder(graph, s, "feed"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); - - TF_Operation* add0 = Add(feed, feed, graph, s, "add0"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add1 = Add(feed, feed, graph, s, "add1"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add2 = Add(add0, add1, graph, s, "add2"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add3 = Add(add0, add2, graph, s, "add3"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("add3"), string(TF_OperationName(add3))); - TF_Operation* add4 = Add(add2, add2, graph, s, "add4"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("add4"), string(TF_OperationName(add4))); - - GraphDef graph_def; - ASSERT_TRUE(GetGraphDef(graph, &graph_def)); - - SegmentNodesVector segments; - ASSERT_EQ( - SegmentGraph(graph_def, MakeCandidateFn({"add0", "add1", "add3", "add4"}), - default_options_, &segments), - tensorflow::Status::OK()); - - // Expect no subgraphs - EXPECT_EQ(segments.size(), 0); - TF_DeleteGraph(graph); - TF_DeleteStatus(s); + Scope s = Scope::NewRootScope(); + auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT); + auto add0 = ops::Add(s.WithOpName("add0"), feed, feed); + auto add1 = ops::Add(s.WithOpName("add1"), feed, feed); + auto add2 = ops::Add(s.WithOpName("add2"), add0, add1); + auto add3 = ops::Add(s.WithOpName("add3"), add0, add2); + auto add4 = ops::Add(s.WithOpName("add4"), add2, add2); + tensorflow::Graph g(OpRegistry::Global()); + TF_EXPECT_OK(s.ToGraph(&g)); + + // add2 is not a TRT candidate so there should be no segments generated. + const std::set without_add2 = {"add0", "add1", "add3", "add4"}; + RunTest(&g, without_add2, without_add2, without_add2, {}); } TEST_F(SegmentTest, Multiple) { - TF_Status* s = TF_NewStatus(); - TF_Graph* graph = TF_NewGraph(); - - // add5 is not a TRT candidate so two subgraphs should be formed - // - // feed - // // || || - // add0 add1 add7 - // | | / / || - // | add2-----add5 add8 - // | / | | | | - // add3 add4 add6 - // | | / - // - // - TF_Operation* feed = Placeholder(graph, s, "feed"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); - - TF_Operation* add0 = Add(feed, feed, graph, s, "add0"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add1 = Add(feed, feed, graph, s, "add1"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add7 = Add(feed, feed, graph, s, "add7"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add2 = Add(add0, add1, graph, s, "add2"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add5 = Add(add2, add7, graph, s, "add5"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add8 = Add(add7, add7, graph, s, "add8"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add3 = Add(add0, add2, graph, s, "add3"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("add3"), string(TF_OperationName(add3))); - TF_Operation* add4 = Add(add2, add5, graph, s, "add4"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("add4"), string(TF_OperationName(add4))); - TF_Operation* add6 = Add(add5, add8, graph, s, "add6"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("add6"), string(TF_OperationName(add6))); - - GraphDef graph_def; - ASSERT_TRUE(GetGraphDef(graph, &graph_def)); - - SegmentNodesVector segments; - ASSERT_EQ(SegmentGraph(graph_def, - MakeCandidateFn({"add0", "add1", "add2", "add3", - "add4", "add6", "add7", "add8"}), - default_options_, &segments), - tensorflow::Status::OK()); - - // Expect two subgraphs - EXPECT_EQ(segments.size(), 2); - - std::vector expected0{"add6", "add8"}; - for (const auto& ex : expected0) { - EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end()) - << "Missing expected node " << ex; - } - - std::vector expected1{"add0", "add1", "add2", "add3"}; - for (const auto& ex : expected1) { - EXPECT_TRUE(segments[1].first.find(ex) != segments[1].first.end()) - << "Missing expected node " << ex; - } - TF_DeleteGraph(graph); - TF_DeleteStatus(s); + // feed + // // || \\ + // add0 add1 add7 + // | \ / / \\ + // | add2 / \\ + // | || \ | || + // | || add5 add8 + // | / \ / \ / + // add3 add4 add6 + // \ | / + // + Scope s = Scope::NewRootScope(); + auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT); + auto add0 = ops::Add(s.WithOpName("add0"), feed, feed); + auto add1 = ops::Add(s.WithOpName("add1"), feed, feed); + auto add7 = ops::Add(s.WithOpName("add7"), feed, feed); + auto add2 = ops::Add(s.WithOpName("add2"), add0, add1); + auto add5 = ops::Add(s.WithOpName("add5"), add2, add7); + auto add8 = ops::Add(s.WithOpName("add8"), add7, add7); + auto add3 = ops::Add(s.WithOpName("add3"), add0, add2); + auto add4 = ops::Add(s.WithOpName("add4"), add2, add5); + auto add6 = ops::Add(s.WithOpName("add6"), add5, add8); + tensorflow::Graph g(OpRegistry::Global()); + TF_EXPECT_OK(s.ToGraph(&g)); + + const std::set all_adds = {"add0", "add1", "add2", "add3", "add4", + "add5", "add6", "add7", "add8"}; + // Make add5 not a TRT candidate, and we expect two segments. + auto without_add5 = all_adds - "add5"; + RunTest(&g, without_add5, without_add5, without_add5, + {{"add0", "add1", "add2", "add3"}, {"add6", "add8"}}); + + // Make add8 not a candidate and add6 not an input candidate, then all direct + // and indirect inputs of add6 will be removed from the segment. + auto without_add8 = all_adds - "add8"; + auto without_add6 = all_adds - "add6"; + RunTest(&g, without_add8, without_add6, all_adds, {{"add3", "add4"}}); + + // Make add3 not a candidate and add0 not an output candidate, then all + // direct and indirect outputs of add0 will be removed from the segment. + auto without_add3 = all_adds - "add3"; + auto without_add0 = all_adds - "add0"; + RunTest(&g, without_add3, all_adds, without_add0, {{"add1", "add7", "add8"}}); } TEST_F(SegmentTest, BigIfElse) { - TF_Status* s = TF_NewStatus(); - TF_Graph* graph = TF_NewGraph(); - - // add2 is not a TRT candidate - // // feed // || // add0 - // // || + // // \\ // add1 add4 // || || // add2 add5 // || || // add3 add6 - // || // + // \\ // // add7 // || // - // - TF_Operation* feed = Placeholder(graph, s, "feed"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("feed"), string(TF_OperationName(feed))); - - TF_Operation* add0 = Add(feed, feed, graph, s, "add0"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add1 = Add(add0, add0, graph, s, "add1"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add2 = Add(add1, add1, graph, s, "add2"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add3 = Add(add2, add2, graph, s, "add3"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add4 = Add(add0, add0, graph, s, "add4"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add5 = Add(add4, add4, graph, s, "add5"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add6 = Add(add5, add5, graph, s, "add6"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Operation* add7 = Add(add3, add6, graph, s, "add7"); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(string("add7"), string(TF_OperationName(add7))); - - GraphDef graph_def; - ASSERT_TRUE(GetGraphDef(graph, &graph_def)); - - SegmentNodesVector segments; - ASSERT_EQ(SegmentGraph(graph_def, - MakeCandidateFn({"add0", "add1", "add3", "add4", - "add5", "add6", "add7"}), - default_options_, &segments), - tensorflow::Status::OK()); - - // Expect 2 subgraphs - EXPECT_EQ(segments.size(), 2); - - std::vector expected0{"add3", "add4", "add5", "add6", "add7"}; - for (const auto& ex : expected0) { - EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end()) - << "Missing expected node " << ex; - } - - std::vector expected1{"add0", "add1"}; - for (const auto& ex : expected1) { - EXPECT_TRUE(segments[1].first.find(ex) != segments[1].first.end()) - << "Missing expected node " << ex; - } - TF_DeleteGraph(graph); - TF_DeleteStatus(s); + Scope s = Scope::NewRootScope(); + auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT); + auto add0 = ops::Add(s.WithOpName("add0"), feed, feed); + auto add1 = ops::Add(s.WithOpName("add1"), add0, add0); + auto add2 = ops::Add(s.WithOpName("add2"), add1, add1); + auto add3 = ops::Add(s.WithOpName("add3"), add2, add2); + auto add4 = ops::Add(s.WithOpName("add4"), add0, add0); + auto add5 = ops::Add(s.WithOpName("add5"), add4, add4); + auto add6 = ops::Add(s.WithOpName("add6"), add5, add5); + auto add7 = ops::Add(s.WithOpName("add7"), add3, add6); + tensorflow::Graph g(OpRegistry::Global()); + TF_EXPECT_OK(s.ToGraph(&g)); + + // Make add2 not a TRT candidate, and we expect 2 segments. + const std::set all_adds = {"add0", "add1", "add2", "add3", + "add4", "add5", "add6", "add7"}; + RunTest(&g, all_adds - "add2", all_adds, all_adds, + {{"add0", "add1"}, {"add3", "add4", "add5", "add6", "add7"}}); } } // namespace test diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc index 227ac120dde8c986379c687987cd1bd822d559f7..f30dba59ad55317d7ad7730e4dc66c9aba4e6a6b 100644 --- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -28,36 +28,50 @@ limitations under the License. namespace tensorflow { namespace shape_inference { -tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) { - std::vector shapes; - for (int i = 0; i < context->num_outputs(); ++i) { - context->set_output(i, context->UnknownShape()); +tensorflow::Status TRTEngineOpShapeInference(InferenceContext* c) { + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->UnknownShape()); } - auto status = context->GetAttr("input_shapes", &shapes); - // it is ok to not to have shapes - if (!status.ok()) return Status::OK(); - if ((int)shapes.size() != context->num_inputs()) return Status::OK(); - bool different_input = false; - for (int i = 0; i < context->num_inputs(); ++i) { - if (shapes.at(i) != context->input_tensor(i)->shape()) - different_input = true; + + // Check the sanity of the input shapes. + std::vector input_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("input_shapes", &input_shapes)); + if (input_shapes.size() != c->num_inputs()) { + return tensorflow::errors::InvalidArgument( + "The actual number of inputs doesn't match the number of input " + "shapes set in the attr: ", + c->num_inputs(), " vs ", input_shapes.size()); + } + bool input_match = true; + for (int i = 0; i < c->num_inputs(); ++i) { + ShapeHandle handle; + TF_RETURN_IF_ERROR( + c->MakeShapeFromTensorShape(input_shapes.at(i), &handle)); + ShapeHandle merged; + if (!c->Merge(c->input(i), handle, &merged).ok()) { + // Input shape doesn't match what was set in attr, fine. + input_match = false; + } } - if (different_input) return Status::OK(); - shapes.resize(0); - status = context->GetAttr("output_shapes", &shapes); - if (!status.ok()) return Status::OK(); - if ((int)shapes.size() != context->num_outputs()) return Status::OK(); - std::vector shape_handles(shapes.size()); - for (size_t i = 0; i < shapes.size(); ++i) { - status = - context->MakeShapeFromTensorShape(shapes.at(i), &shape_handles.at(i)); - if (!status.ok()) return Status::OK(); + + // Check the sanity of the output shapes. + std::vector output_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + if (output_shapes.size() != c->num_outputs()) { + return tensorflow::errors::InvalidArgument( + "The actual number of outputs doesn't match the number of output " + "shapes set in the attr: ", + c->num_outputs(), " vs ", output_shapes.size()); } - for (int i = 0; i < context->num_outputs(); ++i) { - context->set_output(i, shape_handles.at(i)); + for (size_t i = 0; i < output_shapes.size(); ++i) { + ShapeHandle handle; + TF_RETURN_IF_ERROR( + c->MakeShapeFromTensorShape(output_shapes.at(i), &handle)); + if (input_match) c->set_output(i, handle); } return Status::OK(); } + } // namespace shape_inference } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/tensorrt_test.cc b/tensorflow/contrib/tensorrt/tensorrt_test.cc index 3712a9a6fe349d949ef2666652b9d750538d5535..769982c6456f76663e50fe3ec59651127e3720ac 100644 --- a/tensorflow/contrib/tensorrt/tensorrt_test.cc +++ b/tensorflow/contrib/tensorrt/tensorrt_test.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/common_runtime/gpu/gpu_init.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/test.h" #if GOOGLE_CUDA @@ -130,6 +132,13 @@ void Execute(nvinfer1::IExecutionContext* context, const float* input, } TEST(TensorrtTest, BasicFunctions) { + // Handle the case where the test is run on machine with no gpu available. + if (CHECK_NOTNULL(GPUMachineManager())->VisibleDeviceCount() <= 0) { + LOG(WARNING) << "No gpu device available, probably not being run on a gpu " + "machine. Skipping..."; + return; + } + // Create the network model. nvinfer1::IHostMemory* model = CreateNetwork(); // Use the model to create an engine and then an execution context. diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8ea5a6373525a8045d13f70aa9e12d66d4c08f0a --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/base_test.py @@ -0,0 +1,346 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Basic tests for TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.python import trt_convert +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + + +class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing single segment.""" + # TODO(aaroey): test graph with different dtypes. + dtype = dtypes.float32 + input_name = "input" + input_dims = [100, 24, 24, 2] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtype, shape=[None] + input_dims[1:], name=input_name) + with g.device("/GPU:0"): + conv_filter = constant_op.constant( + [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], + name="weights", + dtype=dtype) + conv = nn.conv2d( + input=inp, + filter=conv_filter, + strides=[1, 2, 2, 1], + padding="SAME", + name="conv") + bias = constant_op.constant( + [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtype) + added = nn.bias_add(conv, bias, name="bias_add") + relu = nn.relu(added, "relu") + identity = array_ops.identity(relu, "identity") + pool = nn_ops.max_pool( + identity, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") + array_ops.squeeze(pool, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which + # breaks the connection check, fix it. + # - my_trt_op_0 should have ["weights", "conv", "bias", "bias_add", + # "relu", "identity", "max_pool"] + expected_engines=["my_trt_op_0"], + expected_output_dims=(100, 6, 6, 6), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing multiple segment.""" + # TODO(aaroey): test graph with different dtypes. + dtype = dtypes.float32 + input_name = "input" + input_dims = [100, 24, 24, 2] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtype, shape=[None] + input_dims[1:], name=input_name) + with g.device("/GPU:0"): + conv_filter = constant_op.constant( + [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], + name="weights", + dtype=dtype) + conv = nn.conv2d( + input=inp, + filter=conv_filter, + strides=[1, 2, 2, 1], + padding="SAME", + name="conv") + c1 = constant_op.constant( + np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c1") + p = math_ops.mul(conv, c1, name="mul") + c2 = constant_op.constant( + np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c2") + q = math_ops.div(conv, c2, name="div") + + edge = self.trt_incompatible_op(q, name="incompatible") + edge = math_ops.div(edge, edge, name="div1") + r = math_ops.add(edge, edge, name="add") + + p = math_ops.sub(p, edge, name="sub") + q = math_ops.mul(q, edge, name="mul1") + s = math_ops.add(p, q, name="add1") + s = math_ops.sub(s, r, name="sub1") + array_ops.squeeze(s, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which + # breaks the connection check, fix it. + # - my_trt_op_0 should have ["mul", "sub", "div1", "mul1", "add1", + # "add", "sub1"]; + # - my_trt_op_1 should have ["weights","conv", "div"] + expected_engines=["my_trt_op_0", "my_trt_op_1"], + expected_output_dims=(100, 12, 12, 6), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase): + + def setUp(self): + """Setup method.""" + super(PartiallyConvertedTestA, self).setUp() + # Let it fail to build the second engine. + trt_convert.add_test_value("my_trt_op_1:CreateTRTNode", "fail") + + def GetParams(self): + """Create a graph containing two segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + n = inp + for i in range(2): + c = constant_op.constant(1.0, name="c%d" % i) + n = math_ops.add(n, c, name="add%d" % i) + n = math_ops.mul(n, n, name="mul%d" % i) + edge = self.trt_incompatible_op(n, name="incompatible") + with g.control_dependencies([edge]): + c = constant_op.constant(1.0, name="c2") + n = math_ops.add(n, c, name="add2") + n = math_ops.mul(n, n, name="mul2") + c = constant_op.constant(1.0, name="c3") + n = math_ops.add(n, c, name="add3") + n = math_ops.mul(n, n, name="mul3") + array_ops.squeeze(n, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={ + # Only the first engine is built. + "my_trt_op_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"] + }, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) + + +class PartiallyConvertedTestB(PartiallyConvertedTestA): + + def setUp(self): + """Setup method.""" + super(PartiallyConvertedTestB, self).setUp() + # Let it fail to build the first engine. + trt_convert.clear_test_values("") + trt_convert.add_test_value("my_trt_op_0:CreateTRTNode", "fail") + + def GetParams(self): + """Create a graph containing two segment.""" + return super(PartiallyConvertedTestB, self).GetParams()._replace( + expected_engines={ + # Only the second engine is built. + "my_trt_op_1": ["c2", "c3", "add2", "add3", "mul2", "mul3"] + }) + + +class ConstInputTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing multiple segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + n = inp + c = constant_op.constant(1.0, name="c") + # Adds control dependency from the constant op to a trt incompatible op, + # and adds control dependency from the trt incompatible op to all other + # ops, to make sure the constant op cannot be contracted with any trt + # segment that depends on it. + with g.control_dependencies([c]): + d = self.trt_incompatible_op(n, name="incompatible") + with g.control_dependencies([d]): + n = math_ops.add(n, c, name="add") + n = math_ops.mul(n, n, name="mul") + n = math_ops.add(n, n, name="add1") + n = self.trt_incompatible_op(n, name="incompatible1") + with g.control_dependencies([d]): + n = math_ops.add(n, c, name="add2") + n = math_ops.mul(n, n, name="mul1") + n = math_ops.add(n, n, name="add3") + array_ops.squeeze(n, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={ + "my_trt_op_0": ["add", "add1", "mul"], + "my_trt_op_1": ["add2", "add3", "mul1"] + }, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) + + +class ConstDataInputSingleEngineTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing single segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + n = inp + c = constant_op.constant(1.0, name="c") + n = math_ops.add(n, c, name="add") + n = math_ops.mul(n, n, name="mul") + n = math_ops.add(n, n, name="add1") + array_ops.squeeze(n, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={"my_trt_op_0": ["c", "add", "add1", "mul"]}, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) + + +class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing multiple segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + n = inp + c = constant_op.constant(1.0, name="c") + n = math_ops.add(n, c, name="add") + n = math_ops.mul(n, n, name="mul") + n = math_ops.add(n, n, name="add1") + n = self.trt_incompatible_op(n, name="incompatible1") + n = math_ops.add(n, c, name="add2") + n = math_ops.mul(n, n, name="mul1") + n = math_ops.add(n, n, name="add3") + array_ops.squeeze(n, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={ + "my_trt_op_0": ["add2", "add3", "mul1"], + # Why segment ["add", "add1", "mul"] was assigned segment id 1 + # instead of 0: the parent node of this segment is actually const + # node 'c', but it's removed later since it's const output of the + # segment which is not allowed. + "my_trt_op_1": ["add", "add1", "mul"] + }, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) + + +class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Create a graph containing multiple segment.""" + input_name = "input" + input_dims = [2, 32, 32, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name=input_name) + with g.device("/GPU:0"): + c1 = constant_op.constant(1.0, name="c1") + c2 = constant_op.constant(1.0, name="c2") + d1 = constant_op.constant(1.0, name="d1") + d2 = self.trt_incompatible_op(inp, name="d2") + with g.control_dependencies([d1, d2]): + add = math_ops.add(inp, c1, name="add") + with g.control_dependencies([d1, d2]): + mul = math_ops.mul(add, add, name="mul") + with g.control_dependencies([d1, d2]): + add1 = math_ops.add(mul, mul, name="add1") + edge = self.trt_incompatible_op(add1, name="incompatible") + with g.control_dependencies([d1, d2, add, mul]): + add2 = math_ops.add(edge, c2, name="add2") + with g.control_dependencies([d1, d2, add1, mul]): + mul1 = math_ops.mul(add2, add2, name="mul1") + with g.control_dependencies([d1, d2, add, add1]): + add3 = math_ops.add(mul1, mul1, name="add3") + array_ops.squeeze(add3, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={ + "my_trt_op_0": ["c1", "add", "add1", "mul"], + "my_trt_op_1": ["c2", "add2", "add3", "mul1"] + }, + expected_output_dims=tuple(input_dims), + allclose_atol=1.e-06, + allclose_rtol=1.e-06) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2e1107e30383926f6428c6551682caf66cd97498 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py @@ -0,0 +1,76 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class BatchMatMulTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Testing conversion of BatchMatMul in TF-TRT conversion.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [12, 5, 8, 12] + w1_name = "matmul_w1" + w1_dims = [12, 5, 12, 7] + w2_name = "matmul_w2" + w2_dims = [12, 12, 7] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtype, shape=[None] + input_dims[1:], name=input_name) + w1 = array_ops.placeholder(dtype=dtype, shape=w1_dims, name=w1_name) + w2 = array_ops.placeholder(dtype=dtype, shape=w2_dims, name=w2_name) + with g.device("/GPU:0"): + b = constant_op.constant(np.random.randn(12, 5, 12, 7), dtype=dtype) + c = constant_op.constant(np.random.randn(5, 1, 1), dtype=dtype) + d = constant_op.constant(np.random.randn(5, 1, 1), dtype=dtype) + x1 = math_ops.matmul(inp, b) + x1 = x1 + c + x2 = math_ops.matmul(inp, w1) + x2 = x2 * d + e = gen_array_ops.reshape(inp, [12, 40, 12]) + x3 = math_ops.matmul(e, w2) + f = constant_op.constant(np.random.randn(40, 1), dtype=dtype) + x3 = x3 + f + x3 = gen_array_ops.reshape(x3, [12, 5, 8, 7]) + out = x1 + x2 + x3 + array_ops.squeeze(out, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name, w1_name, w2_name], + input_dims=[input_dims, w1_dims, w2_dims], + expected_engines=["my_trt_op_0"], + expected_output_dims=(12, 5, 8, 7), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8be32f59b48e64412466370950298feafc03b35c --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py @@ -0,0 +1,115 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import test + + +class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Testing conversion of BiasAdd MatMul in TF-TRT conversion.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [48, 12] + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) + + b = constant_op.constant(np.random.randn(12, 4), dtype=dtype) + x1 = math_ops.matmul(x, b) + b = constant_op.constant(np.random.randn(1, 4), dtype=dtype) + x1 = x1 + b + + b = constant_op.constant(np.random.randn(48, 4), dtype=dtype) + x2 = math_ops.matmul(x, b, transpose_a=True) + x2 = gen_array_ops.reshape(x2, [48, 1]) + + b = constant_op.constant(np.random.randn(4, 12), dtype=dtype) + x3 = math_ops.matmul(x, b, transpose_b=True) + + b = constant_op.constant(np.random.randn(16, 48), dtype=dtype) + x4 = math_ops.matmul(x, b, transpose_b=True, transpose_a=True) + x4 = gen_array_ops.reshape(x4, [48, 4]) + + x5 = gen_array_ops.reshape(x, [4, 144]) + b = constant_op.constant(np.random.randn(144, 48), dtype=dtype) + x5 = math_ops.matmul(x5, b) + b = constant_op.constant(np.random.randn(48), dtype=dtype) + x5 = nn.bias_add(x5, b) + x5 = gen_array_ops.reshape(x5, [48, 4]) + + x6 = gen_array_ops.reshape(x, [4, 12, 12]) + b = constant_op.constant(np.random.randn(12), dtype=dtype) + x6 = nn.bias_add(x6, b, data_format="NHWC") + x6 = gen_array_ops.reshape(x6, [48, -1]) + + x7 = gen_array_ops.reshape(x, [4, 12, 3, 4]) + b = constant_op.constant(np.random.randn(4), dtype=dtype) + x7 = nn.bias_add(x7, b, data_format="NHWC") + x7 = gen_array_ops.reshape(x7, [48, -1]) + + x8 = gen_array_ops.reshape(x, [4, 12, 3, 2, 2]) + b = constant_op.constant(np.random.randn(2), dtype=dtype) + x8 = nn.bias_add(x8, b, data_format="NHWC") + x8 = gen_array_ops.reshape(x8, [48, -1]) + + x9 = gen_array_ops.reshape(x, [4, 12, 3, 2, 2]) + b = constant_op.constant(np.random.randn(3), dtype=dtype) + x9 = nn.bias_add(x9, b, data_format="NCHW") + x9 = gen_array_ops.reshape(x9, [48, -1]) + + x10 = gen_array_ops.reshape(x, [4, 12, 3, 4]) + b = constant_op.constant(np.random.randn(12), dtype=dtype) + x10 = nn.bias_add(x10, b, data_format="NCHW") + x10 = gen_array_ops.reshape(x10, [48, -1]) + + x11 = gen_array_ops.reshape(x, [4, 12, 12]) + b = constant_op.constant(np.random.randn(4), dtype=dtype) + x11 = nn.bias_add(x11, b, data_format="NCHW") + x11 = gen_array_ops.reshape(x11, [48, -1]) + + out = array_ops.concat( + [x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11], axis=-1) + out = array_ops.squeeze(out, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines=[ + "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3", + "my_trt_op_4", "my_trt_op_5", "my_trt_op_6" + ], + expected_output_dims=(48, 89), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9316b14da07d5f7e47953504680e14d5d20c17a4 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py @@ -0,0 +1,136 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Tests for scale & elementwise layers in TF-TRT.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [10, 24, 24, 20] + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) + # scale + a = constant_op.constant(np.random.randn(1), dtype=dtype) + f = x + a + x = math_ops.sigmoid(f) + # scale + a = constant_op.constant(np.random.randn(1), dtype=dtype) + f = a + x + x = math_ops.sigmoid(f) + # scale + a = constant_op.constant(np.random.randn(24, 1, 1), dtype=dtype) + f = x + a + x = math_ops.sigmoid(f) + # scale + a = constant_op.constant(np.random.randn(24, 1, 1), dtype=dtype) + f = a + x + x = math_ops.sigmoid(f) + # scale + a = constant_op.constant(np.random.randn(24, 24, 20), dtype=dtype) + f = a + x + x = math_ops.sigmoid(f) + # scale + a = constant_op.constant(np.random.randn(24, 24, 20), dtype=dtype) + f = x + a + x = math_ops.sigmoid(f) + # elementwise + a = constant_op.constant(np.random.randn(20), dtype=dtype) + f = x + a + x = math_ops.sigmoid(f) + # elementwise + a = constant_op.constant(np.random.randn(20), dtype=dtype) + f = a + x + x = math_ops.sigmoid(f) + # elementwise + a = constant_op.constant(np.random.randn(1, 24, 1, 1), dtype=dtype) + f = a + x + x = math_ops.sigmoid(f) + # elementwise + a = constant_op.constant(np.random.randn(1, 24, 1, 1), dtype=dtype) + f = x + a + x = math_ops.sigmoid(f) + # elementwise + a = constant_op.constant(np.random.randn(1, 24, 24, 1), dtype=dtype) + f = a + x + x = math_ops.sigmoid(f) + # elementwise + a = constant_op.constant(np.random.randn(1, 24, 24, 1), dtype=dtype) + f = x + a + x = math_ops.sigmoid(f) + # elementwise + a = constant_op.constant(np.random.randn(1, 24, 24, 20), dtype=dtype) + f = a + x + x = math_ops.sigmoid(f) + # elementwise + a = constant_op.constant(np.random.randn(1, 24, 24, 20), dtype=dtype) + f = x + a + x = math_ops.sigmoid(f) + # elementwise + a = constant_op.constant(np.random.randn(24, 20), dtype=dtype) + f = a + x + x = math_ops.sigmoid(f) + # elementwise + a = constant_op.constant(np.random.randn(24, 20), dtype=dtype) + f = x + a + x = math_ops.sigmoid(f) + gen_array_ops.reshape(x, [5, -1], name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines=[ + "my_trt_op_0", + "my_trt_op_1", + "my_trt_op_2", + "my_trt_op_3", + "my_trt_op_4", + "my_trt_op_5", + "my_trt_op_6", + "my_trt_op_7", + "my_trt_op_8", + "my_trt_op_9", + "my_trt_op_10", + "my_trt_op_11", + "my_trt_op_12", + "my_trt_op_13", + "my_trt_op_14", + "my_trt_op_15", + ], + expected_output_dims=(5, 23040), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/concatenation_test.py b/tensorflow/contrib/tensorrt/test/concatenation_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1874b9dd45390407d3d36798cae620848df50c8d --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/concatenation_test.py @@ -0,0 +1,83 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.platform import test + + +class ConcatenationTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Testing Concatenation in TF-TRT conversion.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [2, 3, 3, 1] + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) + # scale + a = constant_op.constant(np.random.randn(3, 1, 1), dtype=dtype) + r1 = x / a + a = constant_op.constant(np.random.randn(3, 1, 1), dtype=dtype) + r2 = a / x + a = constant_op.constant(np.random.randn(1, 3, 1), dtype=dtype) + r3 = a + x + a = constant_op.constant(np.random.randn(1, 3, 1), dtype=dtype) + r4 = x * a + a = constant_op.constant(np.random.randn(3, 1, 1), dtype=dtype) + r5 = x - a + a = constant_op.constant(np.random.randn(3, 1, 1), dtype=dtype) + r6 = a - x + a = constant_op.constant(np.random.randn(3, 1), dtype=dtype) + r7 = x - a + a = constant_op.constant(np.random.randn(3, 1), dtype=dtype) + r8 = a - x + a = constant_op.constant(np.random.randn(3, 1, 1), dtype=dtype) + r9 = gen_math_ops.maximum(x, a) + a = constant_op.constant(np.random.randn(3, 1), dtype=dtype) + r10 = gen_math_ops.minimum(a, x) + a = constant_op.constant(np.random.randn(3), dtype=dtype) + r11 = x * a + a = constant_op.constant(np.random.randn(1), dtype=dtype) + r12 = a * x + concat1 = array_ops.concat([r1, r2, r3, r4, r5, r6], axis=-1) + concat2 = array_ops.concat([r7, r8, r9, r10, r11, r12], axis=3) + x = array_ops.concat([concat1, concat2], axis=-1) + gen_array_ops.reshape(x, [2, -1], name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines=["my_trt_op_0"], + expected_output_dims=(2, 126), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8c59000b70e04cedc84308249865cfcb23ce80a3 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py @@ -0,0 +1,68 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import test + + +class ConstBroadcastTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Test for Constant broadcasting in TF-TRT.""" + dtype = dtypes.float32 + input_name = 'input' + input_dims = [5, 12, 12, 2] + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) + filt1 = constant_op.constant( + 0.3, shape=(3, 3, 2, 1), dtype=dtype, name='filt1') + y1 = nn.conv2d(x, filt1, strides=[1, 1, 1, 1], padding='SAME', name='y1') + z1 = nn.relu(y1, name='z1') + filt2 = constant_op.constant( + np.random.randn(9), shape=(3, 3, 1, 1), dtype=dtype, name='filt2') + y2 = nn.conv2d(z1, filt2, strides=[1, 1, 1, 1], padding='SAME', name='y2') + z2 = nn.relu(y2, name='z') + filt3 = constant_op.constant( + np.random.randn(3, 3, 1, 1), + shape=(3, 3, 1, 1), + dtype=dtype, + name='filt3') + y3 = nn.conv2d(z2, filt3, strides=[1, 1, 1, 1], padding='SAME', name='y3') + nn.relu(y3, name='output') + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines=['my_trt_op_0'], + expected_output_dims=(5, 12, 12, 1), + allclose_atol=1.e-02, + allclose_rtol=1.e-02) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py new file mode 100644 index 0000000000000000000000000000000000000000..66eb6be757d3f4dcc390435486f7ed4f6517f875 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py @@ -0,0 +1,72 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import test + + +class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Testing conversion of BatchMatMul in TF-TRT conversion.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [2, 15, 15, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtype, shape=[None] + input_dims[1:], name=input_name) + with g.device("/GPU:0"): + e1 = constant_op.constant( + np.random.randn(1, 1, 3, 5), name="kernel_1", dtype=dtype) + e2 = constant_op.constant( + np.random.randn(1, 1, 5, 10), name="kernel_2", dtype=dtype) + conv = nn.conv2d( + input=inp, + filter=e1, + strides=[1, 1, 1, 1], + padding="VALID", + name="conv") + out = nn.conv2d( + input=conv, + filter=e2, + strides=[1, 1, 1, 1], + padding="VALID", + name="conv_2") + array_ops.squeeze(out, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines=["my_trt_op_0"], + expected_output_dims=(2, 15, 15, 10), + allclose_atol=1.e-02, + allclose_rtol=1.e-02) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fd55b8cd99171fe34424e48a417eb8981b051c17 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py @@ -0,0 +1,87 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import test + + +class MultiConnectionNeighborEngineTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Test for multi connection neighboring nodes wiring tests in TF-TRT.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [2, 3, 7, 5] + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) + e = constant_op.constant( + np.random.normal(.05, .005, [3, 2, 3, 4]), + name="weights", + dtype=dtype) + conv = nn.conv2d( + input=x, + filter=e, + data_format="NCHW", + strides=[1, 1, 1, 1], + padding="VALID", + name="conv") + b = constant_op.constant( + np.random.normal(2.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=dtype) + t = conv + b + + b = constant_op.constant( + np.random.normal(5.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=dtype) + q = conv - b + edge = math_ops.sigmoid(q) + + b = constant_op.constant( + np.random.normal(5.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=dtype) + d = b + conv + edge3 = math_ops.sigmoid(d) + + edge1 = gen_math_ops.tan(conv) + t = t - edge1 + q = q + edge + t = t + q + t = t + d + t = t - edge3 + array_ops.squeeze(t, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines=["my_trt_op_0", "my_trt_op_1"], + expected_output_dims=(2, 4, 5, 4), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py new file mode 100644 index 0000000000000000000000000000000000000000..51c905a50b29c017719d66f9049e9b1bc3a9ec97 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py @@ -0,0 +1,72 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import test + + +class NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Neighboring node wiring tests in TF-TRT conversion.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [2, 3, 7, 5] + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) + e = constant_op.constant( + np.random.normal(.3, 0.05, [3, 2, 3, 4]), name="weights", dtype=dtype) + conv = nn.conv2d( + input=x, + filter=e, + data_format="NCHW", + strides=[1, 1, 1, 1], + padding="VALID", + name="conv") + b = constant_op.constant( + np.random.normal(1.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=dtype) + t = math_ops.mul(conv, b, name="mul") + e = self.trt_incompatible_op(conv, name="incompatible") + t = math_ops.sub(t, e, name="sub") + array_ops.squeeze(t, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines={ + "my_trt_op_0": ["bias", "mul", "sub"], + "my_trt_op_1": ["weights", "conv"] + }, + expected_output_dims=(2, 4, 5, 4), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py deleted file mode 100644 index d9c41f90d0ab111b48c37aeaae5f0ce3177646c2..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py +++ /dev/null @@ -1,347 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Script to test TF-TensorRT integration.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from collections import namedtuple -import itertools -import warnings -import numpy as np -import six - -from tensorflow.contrib import tensorrt as trt -from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.protobuf import rewriter_config_pb2 -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import importer -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import nn_ops -from tensorflow.python.platform import test - -INPUT_NAME = "input" -OUTPUT_NAME = "output" -INPUT_DIMS = [100, 24, 24, 2] -MODE_FP32 = "FP32" -MODE_FP16 = "FP16" -MODE_INT8 = "INT8" - -if six.PY2: - to_bytes = lambda s: s - to_string = lambda s: s -else: - to_bytes = lambda s: s.encode("utf-8", errors="surrogateescape") - to_string = lambda s: s.decode("utf-8") - - -# TODO(aaroey): test graph with different dtypes. -def GetSingleEngineGraphDef(dtype=dtypes.float32): - """Create a graph containing single segment.""" - g = ops.Graph() - with g.as_default(): - inp = array_ops.placeholder( - dtype=dtype, shape=[None] + INPUT_DIMS[1:], name=INPUT_NAME) - with g.device("/GPU:0"): - conv_filter = constant_op.constant( - [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], - name="weights", - dtype=dtype) - conv = nn.conv2d( - input=inp, - filter=conv_filter, - strides=[1, 2, 2, 1], - padding="SAME", - name="conv") - bias = constant_op.constant( - [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtype) - added = nn.bias_add(conv, bias, name="bias_add") - relu = nn.relu(added, "relu") - identity = array_ops.identity(relu, "identity") - pool = nn_ops.max_pool( - identity, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") - array_ops.squeeze(pool, name=OUTPUT_NAME) - return g.as_graph_def() - - -# TODO(aaroey): test graph with different dtypes. -def GetMultiEngineGraphDef(dtype=dtypes.float32): - """Create a graph containing multiple segment.""" - g = ops.Graph() - with g.as_default(): - inp = array_ops.placeholder( - dtype=dtype, shape=[None] + INPUT_DIMS[1:], name=INPUT_NAME) - with g.device("/GPU:0"): - conv_filter = constant_op.constant( - [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]], - name="weights", - dtype=dtype) - conv = nn.conv2d( - input=inp, - filter=conv_filter, - strides=[1, 2, 2, 1], - padding="SAME", - name="conv") - c1 = constant_op.constant( - np.random.randn(INPUT_DIMS[0], 12, 12, 6), dtype=dtype) - p = conv * c1 - c2 = constant_op.constant( - np.random.randn(INPUT_DIMS[0], 12, 12, 6), dtype=dtype) - q = conv / c2 - - edge = math_ops.sin(q) - edge /= edge - r = edge + edge - - p -= edge - q *= edge - s = p + q - s -= r - array_ops.squeeze(s, name=OUTPUT_NAME) - return g.as_graph_def() - - -TestGraph = namedtuple("TestGraph", - ["gdef", "num_expected_engines", "expected_output_dims"]) - -TEST_GRAPHS = { - "SingleEngineGraph": - TestGraph( - gdef=GetSingleEngineGraphDef(), - num_expected_engines=1, - expected_output_dims=(100, 6, 6, 6)), - "MultiEngineGraph": - TestGraph( - gdef=GetMultiEngineGraphDef(), - num_expected_engines=2, - expected_output_dims=(100, 12, 12, 6)), - # TODO(aaroey): add a large complex graph to test. -} - - -class TfTrtIntegrationTest(test_util.TensorFlowTestCase): - """Class to test Tensorflow-TensorRT integration.""" - - def setUp(self): - """Setup method.""" - super(TfTrtIntegrationTest, self).setUp() - warnings.simplefilter("always") - self._input = np.random.random_sample(INPUT_DIMS) - - def _GetConfigProto(self, - use_optimizer, - precision_mode=None, - is_dynamic_op=None): - if use_optimizer: - rewriter_cfg = rewriter_config_pb2.RewriterConfig() - rewriter_cfg.optimizers.extend(["constfold", "layout"]) - custom_op = rewriter_cfg.custom_optimizers.add() - custom_op.name = "TensorRTOptimizer" - custom_op.parameter_map["minimum_segment_size"].i = 3 - custom_op.parameter_map["max_batch_size"].i = self._input.shape[0] - custom_op.parameter_map["is_dynamic_op"].b = is_dynamic_op - custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25 - custom_op.parameter_map["precision_mode"].s = to_bytes(precision_mode) - graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg) - else: - graph_options = config_pb2.GraphOptions() - - gpu_options = config_pb2.GPUOptions() - if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: - gpu_options.per_process_gpu_memory_fraction = 0.50 - - config = config_pb2.ConfigProto( - gpu_options=gpu_options, graph_options=graph_options) - return config - - def _RunGraph(self, graph_key, gdef, input_data, config, num_runs=2): - """Run given graphdef multiple times.""" - g = ops.Graph() - with g.as_default(): - inp, out = importer.import_graph_def( - graph_def=gdef, return_elements=[INPUT_NAME, OUTPUT_NAME], name="") - inp = inp.outputs[0] - out = out.outputs[0] - with self.test_session( - graph=g, config=config, use_gpu=True, force_gpu=True) as sess: - val = None - # Defaults to 2 runs to verify result across multiple runs is same. - for _ in range(num_runs): - new_val = sess.run(out, {inp: input_data}) - self.assertEquals(TEST_GRAPHS[graph_key].expected_output_dims, - new_val.shape) - if val is not None: - self.assertAllEqual(new_val, val) - val = new_val - return val - - # Use real data that is representative of the inference dataset - # for calibration. For this test script it is random data. - def _RunCalibration(self, graph_key, gdef, input_data, config): - """Run calibration on given graph.""" - return self._RunGraph(graph_key, gdef, input_data, config, 30) - - def _GetTrtGraph(self, gdef, precision_mode, is_dynamic_op): - """Return trt converted graph.""" - return trt.create_inference_graph( - input_graph_def=gdef, - outputs=[OUTPUT_NAME], - max_batch_size=self._input.shape[0], - max_workspace_size_bytes=1 << 25, - precision_mode=precision_mode, - minimum_segment_size=2, - is_dynamic_op=is_dynamic_op) - - def _VerifyGraphDef(self, - graph_key, - gdef, - precision_mode=None, - is_calibrated=None, - dynamic_engine=None): - num_engines = 0 - for n in gdef.node: - if n.op == "TRTEngineOp": - num_engines += 1 - self.assertNotEqual("", n.attr["serialized_segment"].s) - self.assertNotEqual("", n.attr["segment_funcdef_name"].s) - self.assertEquals(n.attr["precision_mode"].s, precision_mode) - self.assertEquals(n.attr["static_engine"].b, not dynamic_engine) - if precision_mode == MODE_INT8 and is_calibrated: - self.assertNotEqual("", n.attr["calibration_data"].s) - else: - self.assertEquals("", n.attr["calibration_data"].s) - if precision_mode is None: - self.assertEquals(num_engines, 0) - else: - self.assertEquals(num_engines, - TEST_GRAPHS[graph_key].num_expected_engines) - - def _RunTest(self, graph_key, use_optimizer, precision_mode, - dynamic_infer_engine, dynamic_calib_engine): - assert precision_mode in [MODE_FP32, MODE_FP16, MODE_INT8] - input_gdef = TEST_GRAPHS[graph_key].gdef - self._VerifyGraphDef(graph_key, input_gdef) - - # Get reference result without running trt. - config_no_trt = self._GetConfigProto(False) - print("Running original graph w/o trt, config:\n%s" % str(config_no_trt)) - ref_result = self._RunGraph(graph_key, input_gdef, self._input, - config_no_trt) - - # Run calibration if necessary. - if precision_mode == MODE_INT8: - - calib_config = self._GetConfigProto(use_optimizer, precision_mode, - dynamic_calib_engine) - print("Running calibration graph, config:\n%s" % str(calib_config)) - if use_optimizer: - self.assertTrue(False) - # TODO(aaroey): uncomment this and get infer_gdef when this mode is - # supported. - # result = self._RunCalibration(graph_key, input_gdef, self._input, - # calib_config) - else: - calib_gdef = self._GetTrtGraph(input_gdef, precision_mode, - dynamic_calib_engine) - self._VerifyGraphDef(graph_key, calib_gdef, precision_mode, False, - dynamic_calib_engine) - result = self._RunCalibration(graph_key, calib_gdef, self._input, - calib_config) - infer_gdef = trt.calib_graph_to_infer_graph(calib_gdef) - self._VerifyGraphDef(graph_key, infer_gdef, precision_mode, True, - dynamic_calib_engine) - self.assertAllClose(ref_result, result, rtol=1.e-03) - else: - infer_gdef = input_gdef - - # Run inference. - infer_config = self._GetConfigProto(use_optimizer, precision_mode, - dynamic_infer_engine) - print("Running final inference graph, config:\n%s" % str(infer_config)) - if use_optimizer: - result = self._RunGraph(graph_key, infer_gdef, self._input, infer_config) - else: - trt_infer_gdef = self._GetTrtGraph(infer_gdef, precision_mode, - dynamic_infer_engine) - self._VerifyGraphDef(graph_key, trt_infer_gdef, precision_mode, True, - dynamic_infer_engine) - result = self._RunGraph(graph_key, trt_infer_gdef, self._input, - infer_config) - self.assertAllClose(ref_result, result, rtol=1.e-03) - - def testIdempotence(self): - # Test that applying tensorrt optimizer or offline conversion tools multiple - # times to the same graph will result in same graph. - # TODO(aaroey): implement this. - pass - - -def GetTests(): - - def _GetTest(g, u, p, i, c): - - def _Test(self): - print("Running test with parameters: graph_key=%s, use_optimizer=%s, " - "precision_mode=%s, dynamic_infer_engine=%s, " - "dynamic_calib_engine=%s" % (g, u, p, i, c)) - self._RunTest(g, u, p, i, c) - - return _Test - - use_optimizer_options = [False, True] - precision_mode_options = [MODE_FP32, MODE_FP16, MODE_INT8] - dynamic_infer_engine_options = [False, True] - dynamic_calib_engine_options = [False, True] - for (graph_key, use_optimizer, precision_mode, - dynamic_infer_engine, dynamic_calib_engine) in itertools.product( - TEST_GRAPHS, use_optimizer_options, precision_mode_options, - dynamic_infer_engine_options, dynamic_calib_engine_options): - if precision_mode == MODE_INT8: - if not dynamic_calib_engine and dynamic_infer_engine: - # TODO(aaroey): test this case, the conversion from static calibration - # engine to dynamic inference engine should be a noop. - continue - if use_optimizer: - # TODO(aaroey): if use_optimizer is True we need to get the inference - # graphdef using custom python wrapper class, which is not currently - # supported yet. - continue - if not dynamic_calib_engine: - # TODO(aaroey): construction of static calibration engine is not - # supported yet. - continue - if dynamic_calib_engine and not dynamic_infer_engine: - # TODO(aaroey): construction of static inference engine using dynamic - # calibration engine is not supported yet. - continue - else: # In non int8 mode. - if dynamic_calib_engine: - # dynamic_calib_engine doesn't affect non-int8 modes, so just let - # related tests run once on dynamic_calib_engine=False. - continue - yield _GetTest(graph_key, use_optimizer, precision_mode, - dynamic_infer_engine, dynamic_calib_engine) - - -if __name__ == "__main__": - for index, t in enumerate(GetTests()): - setattr(TfTrtIntegrationTest, "testTfTRT_" + str(index), t) - test.main() diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..6f85ada4649563d099c6054e8e17da27954071f7 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py @@ -0,0 +1,461 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import namedtuple +import itertools +import os +import warnings +import numpy as np +import six + +from tensorflow.contrib.tensorrt.python import trt_convert +# pylint: disable=unused-import +from tensorflow.contrib.tensorrt.python.ops import trt_engine_op +# pylint: enable=unused-import +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.framework import graph_io +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import tf_logging as logging + +TfTrtIntegrationTestParams = namedtuple("TfTrtIntegrationTestParams", [ + "gdef", "input_names", "input_dims", "expected_engines", + "expected_output_dims", "allclose_atol", "allclose_rtol" +]) + +RunParams = namedtuple( + "RunParams", + ["use_optimizer", "precision_mode", "dynamic_engine", "test_name"]) + +PRECISION_MODES = ["FP32", "FP16", "INT8"] + + +def _IsQuantizationMode(mode): + return mode == "INT8" + + +class GraphState(object): + ORIGINAL = 0 + CALIBRATE = 1 + INFERENCE = 2 + + +class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): + """Class to test Tensorflow-TensorRT integration.""" + + @property + def output_name(self): + return "output" + + @property + def trt_incompatible_op(self): + return math_ops.sin + + @property + def precision_modes(self): + return ["FP32", "FP16", "INT8"] + + # str is bytes in py2, but unicode in py3. + def _ToUnicode(self, s): + if six.PY2: + if isinstance(s, unicode): + return s + return s.decode("utf-8") + else: + if isinstance(s, str): + return s + return s.decode("utf-8") + + def _ToBytes(self, s): + if six.PY2: + if isinstance(s, unicode): + return s.encode("utf-8") + return s + else: + if isinstance(s, str): + return s.encode("utf-8") + return s + + def _ToString(self, s): + if six.PY2: + if isinstance(s, unicode): + return s.encode("utf-8") + return s + else: + if isinstance(s, str): + return s + return s.decode("utf-8") + + @classmethod + def setUpClass(cls): + """Setup method for the module.""" + super(TfTrtIntegrationTestBase, cls).setUpClass() + trt_convert.enable_test_value() + + def setUp(self): + """Setup method.""" + super(TfTrtIntegrationTestBase, self).setUp() + warnings.simplefilter("always") + trt_convert.clear_test_values("") + + def GetParams(self): + """Return a TfTrtIntegrationTestParams for test, implemented by subclass.""" + raise NotImplementedError() + + def _PrepareRun(self, params, graph_state): + """Set up necessary testing environment before calling sess.run().""" + # Clear test values added by TRTEngineOp. + trt_convert.clear_test_values("my_trt_op_.*:ExecuteTrtEngine") + trt_convert.clear_test_values("my_trt_op_.*:ExecuteCalibration") + trt_convert.clear_test_values("my_trt_op_.*:ExecuteNativeSegment") + + def _VerifyRun(self, params, graph_state): + """Verify the state after sess.run().""" + for engine_name in params.expected_engines: + if graph_state == GraphState.ORIGINAL: + self._ExpectCalibration(engine_name, "") + self._ExpectNativeSegment(engine_name, "") + self._ExpectTrtEngine(engine_name, "") + elif graph_state == GraphState.CALIBRATE: + self._ExpectCalibration(engine_name, "done") + self._ExpectNativeSegment(engine_name, "done") + self._ExpectTrtEngine(engine_name, "") + elif graph_state == GraphState.INFERENCE: + self._ExpectCalibration(engine_name, "") + self._ExpectNativeSegment(engine_name, "") + self._ExpectTrtEngine(engine_name, "done") + + def _GetConfigProto(self, params, run_params, graph_state): + """Get config proto based on specific settings.""" + if graph_state != GraphState.ORIGINAL and run_params.use_optimizer: + rewriter_cfg = rewriter_config_pb2.RewriterConfig() + rewriter_cfg.optimizers.extend(["constfold", "layout"]) + custom_op = rewriter_cfg.custom_optimizers.add() + custom_op.name = "TensorRTOptimizer" + custom_op.parameter_map["minimum_segment_size"].i = 2 + custom_op.parameter_map["max_batch_size"].i = max( + [dims[0] for dims in params.input_dims]) + custom_op.parameter_map["is_dynamic_op"].b = run_params.dynamic_engine + custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25 + custom_op.parameter_map["precision_mode"].s = self._ToBytes( + run_params.precision_mode) + graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg) + else: + graph_options = config_pb2.GraphOptions() + + gpu_options = config_pb2.GPUOptions() + gpu_options.allow_growth = True + if trt_convert.get_linked_tensorrt_version()[0] == 3: + gpu_options.per_process_gpu_memory_fraction = 0.50 + + config = config_pb2.ConfigProto( + gpu_options=gpu_options, graph_options=graph_options) + return config + + def _ExpectTestValue(self, engine_name, method, expected_value): + label = "%s:%s" % (engine_name, method) + actual_value = trt_convert.get_test_value(label) + self.assertEqual( + expected_value, + actual_value, + msg="Unexpected test value with label %s. Actual: %s; expected: %s" % + (label, actual_value, expected_value)) + + def _ExpectCalibration(self, engine_name, value): + self._ExpectTestValue(engine_name, "ExecuteCalibration", value) + + def _ExpectTrtEngine(self, engine_name, value): + self._ExpectTestValue(engine_name, "ExecuteTrtEngine", value) + + def _ExpectNativeSegment(self, engine_name, value): + self._ExpectTestValue(engine_name, "ExecuteNativeSegment", value) + + def _RunGraph(self, params, gdef, input_data, config, graph_state, + num_runs=2): + """Run given graphdef multiple times.""" + assert len(params.input_names) == len(input_data) + g = ops.Graph() + with g.as_default(): + io_ops = importer.import_graph_def( + graph_def=gdef, + return_elements=params.input_names + [self.output_name], + name="") + inp = [i.outputs[0] for i in io_ops[:-1]] + assert len(inp) == len(input_data) + out = io_ops[-1].outputs[0] + with self.test_session( + graph=g, config=config, use_gpu=True, force_gpu=True) as sess: + val = None + # Defaults to 2 runs to verify result across multiple runs is same. + for _ in range(num_runs): + self._PrepareRun(params, graph_state) + new_val = sess.run(out, + {inp[i]: input_data[i] for i in range(len(inp))}) + self.assertEqual(params.expected_output_dims, new_val.shape) + if val is not None: + self.assertAllEqual(val, new_val) + val = new_val + self._VerifyRun(params, graph_state) + return val + + # Use real data that is representative of the inference dataset + # for calibration. For this test script it is random data. + def _RunCalibration(self, params, gdef, input_data, config): + """Run calibration on given graph.""" + return self._RunGraph( + params, gdef, input_data, config, GraphState.CALIBRATE, num_runs=5) + + def _GetTrtGraphDef(self, params, run_params, gdef): + """Return trt converted graphdef.""" + return trt_convert.create_inference_graph( + input_graph_def=gdef, + outputs=[self.output_name], + max_batch_size=max([dims[0] for dims in params.input_dims]), + max_workspace_size_bytes=1 << 25, + precision_mode=run_params.precision_mode, + minimum_segment_size=2, + is_dynamic_op=run_params.dynamic_engine) + + def _WriteGraph(self, params, run_params, gdef, graph_state): + if graph_state == GraphState.ORIGINAL: + label = "Original" + elif graph_state == GraphState.CALIBRATE: + label = "CalibEngine" + elif graph_state == GraphState.INFERENCE: + label = "InferEngine" + graph_name = ( + self.__class__.__name__ + "_" + run_params.test_name + "_" + label + + ".pbtxt") + temp_dir = os.getenv("TRT_TEST_TMPDIR", self.get_temp_dir()) + logging.info("Writing graph to %s/%s", temp_dir, graph_name) + graph_io.write_graph(gdef, temp_dir, graph_name) + + def _VerifyConnections(self, params, converted_gdef): + old_to_new_node_map = { + self._ToString(node.name): self._ToString(node.name) + for node in params.gdef.node + } + for engine_name, node_names in params.expected_engines.items(): + for node_name in node_names: + old_to_new_node_map[node_name] = engine_name + name_to_node_map = { + self._ToString(node.name): node for node in params.gdef.node + } + + def _InputName(inp): + inp = self._ToString(inp) + prefix = "" + if inp[0] == "^": + prefix = "^" + inp = inp[1:] + parts = inp.split(":") + if len(parts) > 1 and parts[-1].isdigit(): + inp = inp[:-len(parts[-1]) - 1] + return (prefix, inp) + + expected_input_map = {} + for node in params.gdef.node: + name_str = self._ToString(node.name) + target_node_name = old_to_new_node_map[name_str] + is_engine_op = (target_node_name != name_str) + if target_node_name not in expected_input_map: + expected_input_map[target_node_name] = set() + input_set = expected_input_map[target_node_name] + for inp in node.input: + (prefix, inp_name) = _InputName(inp) + # Add the input only if it's outside the segment (note that it could be + # in a different engine). + if (not is_engine_op or + old_to_new_node_map[inp_name] != target_node_name): + if is_engine_op and name_to_node_map[inp_name].op == "Const": + # Const data input nodes to the segment has been copied to the + # segment graphdef and the engine, and the dependency has been + # converted to control dependendy. + input_set.add("^" + old_to_new_node_map[inp_name]) + else: + input_set.add(prefix + old_to_new_node_map[inp_name]) + + actual_input_map = {} + for node in converted_gdef.node: + name_str = self._ToString(node.name) + actual_input_map[name_str] = set() + input_set = actual_input_map[name_str] + for inp in node.input: + (prefix, node_name) = _InputName(inp) + input_set.add(prefix + node_name) + + self.assertEqual( + expected_input_map, + actual_input_map, + msg="expected:\n%s\nvs actual:\n%s" % (sorted( + expected_input_map.items()), sorted(actual_input_map.items()))) + + def _VerifyGraphDef(self, params, run_params, gdef, graph_state): + self._WriteGraph(params, run_params, gdef, graph_state) + + num_engines = 0 + for node in gdef.node: + if node.op == "TRTEngineOp": + num_engines += 1 + self.assertTrue(node.name in params.expected_engines) + self.assertTrue(len(node.attr["serialized_segment"].s)) + self.assertTrue(len(node.attr["segment_funcdef_name"].s)) + self.assertEqual( + self._ToBytes(run_params.precision_mode), + node.attr["precision_mode"].s) + + is_dynamic_engine = not node.attr["static_engine"].b + self.assertEqual(run_params.dynamic_engine, is_dynamic_engine) + + has_calibration_data = len(node.attr["calibration_data"].s) + if (_IsQuantizationMode(run_params.precision_mode) and + graph_state == GraphState.INFERENCE): + self.assertTrue(has_calibration_data) + else: + self.assertFalse(has_calibration_data) + if graph_state == GraphState.ORIGINAL: + self.assertEqual(0, num_engines) + else: + self.assertEqual(num_engines, len(params.expected_engines)) + if isinstance(params.expected_engines, dict): + self._VerifyConnections(params, gdef) + # TODO(aaroey): consider verifying the corresponding TF function. + + def RunTest(self, params, run_params): + assert run_params.precision_mode in PRECISION_MODES + input_data = [np.random.random_sample(dims) for dims in params.input_dims] + input_gdef = params.gdef + self._VerifyGraphDef(params, run_params, input_gdef, GraphState.ORIGINAL) + + # Get reference result without running trt. + config_no_trt = self._GetConfigProto(params, run_params, + GraphState.ORIGINAL) + logging.info("Running original graph w/o trt, config:\n%s", + str(config_no_trt)) + ref_result = self._RunGraph(params, input_gdef, input_data, config_no_trt, + GraphState.ORIGINAL) + + # Run calibration if necessary. + if _IsQuantizationMode(run_params.precision_mode): + + calib_config = self._GetConfigProto(params, run_params, + GraphState.CALIBRATE) + logging.info("Running calibration graph, config:\n%s", str(calib_config)) + if run_params.use_optimizer: + result = self._RunCalibration(params, input_gdef, input_data, + calib_config) + else: + calib_gdef = self._GetTrtGraphDef(params, run_params, input_gdef) + self._VerifyGraphDef(params, run_params, calib_gdef, + GraphState.CALIBRATE) + result = self._RunCalibration(params, calib_gdef, input_data, + calib_config) + infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef) + self._VerifyGraphDef(params, run_params, infer_gdef, GraphState.INFERENCE) + + self.assertAllClose( + ref_result, + result, + atol=params.allclose_atol, + rtol=params.allclose_rtol) + else: + infer_gdef = input_gdef + + # Run inference. + infer_config = self._GetConfigProto(params, run_params, + GraphState.INFERENCE) + logging.info("Running final inference graph, config:\n%s", + str(infer_config)) + if run_params.use_optimizer: + result = self._RunGraph(params, infer_gdef, input_data, infer_config, + GraphState.INFERENCE) + else: + trt_infer_gdef = self._GetTrtGraphDef(params, run_params, infer_gdef) + self._VerifyGraphDef(params, run_params, trt_infer_gdef, + GraphState.INFERENCE) + result = self._RunGraph(params, trt_infer_gdef, input_data, infer_config, + GraphState.INFERENCE) + + self.assertAllClose( + ref_result, + result, + atol=params.allclose_atol, + rtol=params.allclose_rtol) + + def testIdempotence(self): + # Test that applying tensorrt optimizer or offline conversion tools multiple + # times to the same graph will result in same graph. + # + # TODO(aaroey): currently the conversion is not deterministic, this is + # mainly because during tensorflow::ConvertGraphDefToGraph(), the graph uses + # EdgeSet which use a map keyed by Edge*, so the order of input/output edges + # of a node is nondeterministic, thus the order for segmenter to contract + # edges is nondeterministic. Need to evaluate whether we should fix this. + pass + + +def _AddTests(test_class): + """Adds test methods to TfTrtIntegrationTestBase.""" + + def _GetTest(run_params): + """Gets a single test method based on the parameters.""" + + def _Test(self): + params = self.GetParams() + logging.info( + "Running test %s with parameters: use_optimizer=%s, " + "precision_mode=%s, dynamic_engine=%s", + "testTfTrt_" + run_params.test_name, run_params.use_optimizer, + run_params.precision_mode, run_params.dynamic_engine) + self.RunTest(params, run_params) + + return _Test + + use_optimizer_options = [False, True] + dynamic_engine_options = [False, True] + for (use_optimizer, precision_mode, dynamic_engine) in itertools.product( + use_optimizer_options, PRECISION_MODES, dynamic_engine_options): + if _IsQuantizationMode(precision_mode): + if use_optimizer: + # TODO(aaroey): if use_optimizer is True we need to get the inference + # graphdef using custom python wrapper class, which is not currently + # supported yet. + continue + if not dynamic_engine: + # TODO(aaroey): construction of static calibration engine is not + # supported yet. + continue + + conversion = "OptimizerConversion" if use_optimizer else "ToolConversion" + engine_type = ("DynamicEngine" if dynamic_engine else "StaticEngine") + test_name = "%s_%s_%s" % (conversion, precision_mode, engine_type) + run_params = RunParams( + use_optimizer=use_optimizer, + precision_mode=precision_mode, + dynamic_engine=dynamic_engine, + test_name=test_name) + setattr(test_class, "testTfTrt_" + test_name, _GetTest(run_params)) + + +if trt_convert.is_tensorrt_enabled(): + _AddTests(TfTrtIntegrationTestBase) diff --git a/tensorflow/contrib/tensorrt/test/unary_test.py b/tensorflow/contrib/tensorrt/test/unary_test.py new file mode 100644 index 0000000000000000000000000000000000000000..500057a36d60efa3b7f96f22e27973444ecc277c --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/unary_test.py @@ -0,0 +1,113 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class UnaryTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Test for unary operations in TF-TRT.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [12, 5, 8, 1, 1, 12] + input2_name = "input_2" + input2_dims = [12, 5, 8, 1, 12, 1, 1] + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) + q = math_ops.abs(x) + q = q + 1.0 + q = gen_math_ops.exp(q) + q = gen_math_ops.log(q) + q = array_ops.squeeze(q, axis=-2) + q = math_ops.abs(q) + q = q + 2.2 + q = gen_math_ops.sqrt(q) + q = gen_math_ops.rsqrt(q) + q = math_ops.negative(q) + q = array_ops.squeeze(q, axis=3) + q = math_ops.abs(q) + q = q + 3.0 + a = gen_math_ops.reciprocal(q) + + x = constant_op.constant(np.random.randn(5, 8, 12), dtype=dtype) + q = math_ops.abs(x) + q = q + 2.0 + q = gen_math_ops.exp(q) + q = gen_math_ops.log(q) + q = math_ops.abs(q) + q = q + 2.1 + q = gen_math_ops.sqrt(q) + q = gen_math_ops.rsqrt(q) + q = math_ops.negative(q) + q = math_ops.abs(q) + q = q + 4.0 + b = gen_math_ops.reciprocal(q) + + # TODO(jie): this one will break, broadcasting on batch. + x = array_ops.placeholder( + dtype=dtype, shape=input2_dims, name=input2_name) + q = math_ops.abs(x) + q = q + 5.0 + q = gen_math_ops.exp(q) + q = array_ops.squeeze(q, axis=[-1, -2, 3]) + q = gen_math_ops.log(q) + q = math_ops.abs(q) + q = q + 5.1 + q = gen_array_ops.reshape(q, [12, 5, 1, 1, 8, 1, 12]) + q = array_ops.squeeze(q, axis=[5, 2, 3]) + q = gen_math_ops.sqrt(q) + q = math_ops.abs(q) + q = q + 5.2 + q = gen_math_ops.rsqrt(q) + q = math_ops.negative(q) + q = math_ops.abs(q) + q = q + 5.3 + c = gen_math_ops.reciprocal(q) + + q = a * b + q = q / c + array_ops.squeeze(q, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name, input2_name], + input_dims=[input_dims, input2_dims], + expected_engines=[ + "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3", + "my_trt_op_4" + ], + expected_output_dims=(12, 5, 8, 12), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/utils.cc b/tensorflow/contrib/tensorrt/test/utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..276308b3a0a6ce864969afb0179c6a3f00d6b70b --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/utils.cc @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/test/utils.h" + +#include +#include + +#include "re2/re2.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace tensorrt { +namespace test { + +// TODO(aaroey): make this class thread-safe. +class TestValueManager { + public: + static TestValueManager* singleton() { + static TestValueManager* manager = new TestValueManager(); + return manager; + } + + void Enable() { + VLOG(1) << "Enabling test value"; + enabled_ = true; + } + + void Add(const string& label, const string& value) { + if (TF_PREDICT_FALSE(enabled_)) { + QCHECK_NE("", value); + VLOG(1) << "Adding test value: " << label << " -> " << value; + values_.insert({label, value}); + } + } + + string Get(const string& label) { + if (TF_PREDICT_FALSE(enabled_)) { + VLOG(1) << "Getting test value by " << label; + auto itr = values_.find(label); + if (itr == values_.end()) return ""; + return itr->second; + } + return ""; + } + + void Clear(const string& pattern) { + if (TF_PREDICT_FALSE(enabled_)) { + VLOG(1) << "Clearing test values"; + if (pattern.empty()) { + values_.clear(); + return; + } + std::vector keys_to_clear; + for (const auto& kv : values_) { + if (RE2::FullMatch(kv.first, pattern)) { + keys_to_clear.push_back(kv.first); + } + } + for (const string& key : keys_to_clear) { + values_.erase(key); + } + } + } + + private: + TestValueManager() : enabled_(false) {} + + bool enabled_; + std::unordered_map values_; +}; + +void EnableTestValue() { TestValueManager::singleton()->Enable(); } + +void ClearTestValues(const string& pattern) { + TestValueManager::singleton()->Clear(pattern); +} + +void AddTestValue(const string& label, const string& value) { + TestValueManager::singleton()->Add(label, value); +} + +string GetTestValue(const string& label) { + return TestValueManager::singleton()->Get(label); +} + +} // namespace test +} // namespace tensorrt +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/test/utils.h b/tensorflow/contrib/tensorrt/test/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..4bb4120206cfaae70107e55d1818e3af2f02717a --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/utils.h @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace tensorrt { +namespace test { + +// Helper methods to inject values used by testing tools. +void EnableTestValue(); +void ClearTestValues(const string& pattern); +void AddTestValue(const string& label, const string& value); +string GetTestValue(const string& label); + +#define TRT_RETURN_IF_TEST_VALUE(label, value_to_return) \ + do { \ + if (::tensorflow::tensorrt::test::GetTestValue(label) == \ + value_to_return) { \ + return errors::Internal("Injected manually"); \ + } \ + } while (0) + +} // namespace test +} // namespace tensorrt +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_TENSORRT_TEST_UTILS_H_ diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ab4d224db4d88c91c9b06d278b404879d989a834 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py @@ -0,0 +1,82 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_impl +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + + +class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Single vgg layer in NCHW unit tests in TF-TRT.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [5, 2, 8, 8] + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) + x, _, _ = nn_impl.fused_batch_norm( + x, + np.random.randn(2).astype(np.float32), + np.random.randn(2).astype(np.float32), + mean=np.random.randn(2).astype(np.float32), + variance=np.random.randn(2).astype(np.float32), + data_format="NCHW", + is_training=False) + e = constant_op.constant( + np.random.randn(1, 1, 2, 6), name="weights", dtype=dtype) + conv = nn.conv2d( + input=x, + filter=e, + data_format="NCHW", + strides=[1, 1, 2, 2], + padding="SAME", + name="conv") + b = constant_op.constant(np.random.randn(6), name="bias", dtype=dtype) + t = nn.bias_add(conv, b, data_format="NCHW", name="biasAdd") + relu = nn.relu(t, "relu") + idty = array_ops.identity(relu, "ID") + v = nn_ops.max_pool( + idty, [1, 1, 2, 2], [1, 1, 2, 2], + "VALID", + data_format="NCHW", + name="max_pool") + array_ops.squeeze(v, name="output") + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines=["my_trt_op_0"], + expected_output_dims=(5, 6, 2, 2), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_test.py new file mode 100644 index 0000000000000000000000000000000000000000..56bdf848eadbdde3d5896e415ecd9754ed387eeb --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/vgg_block_test.py @@ -0,0 +1,73 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_impl +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + + +class VGGBlockTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Single vgg layer test in TF-TRT conversion.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [5, 8, 8, 2] + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) + x, _, _ = nn_impl.fused_batch_norm( + x, + np.random.randn(2).astype(np.float32), + np.random.randn(2).astype(np.float32), + mean=np.random.randn(2).astype(np.float32), + variance=np.random.randn(2).astype(np.float32), + is_training=False) + e = constant_op.constant( + np.random.randn(1, 1, 2, 6), name="weights", dtype=dtype) + conv = nn.conv2d( + input=x, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv") + b = constant_op.constant(np.random.randn(6), name="bias", dtype=dtype) + t = nn.bias_add(conv, b, name="biasAdd") + relu = nn.relu(t, "relu") + idty = array_ops.identity(relu, "ID") + v = nn_ops.max_pool( + idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") + array_ops.squeeze(v, name="output") + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + expected_engines=["my_trt_op_0"], + expected_output_dims=(5, 2, 2, 6), + allclose_atol=1.e-03, + allclose_rtol=1.e-03) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i index d6628cd1eb69e46b188de613dee803a2e0dd07d4..6ea15fb8eff13663625420288a37ba002d57fa47 100644 --- a/tensorflow/contrib/tensorrt/trt_conversion.i +++ b/tensorflow/contrib/tensorrt/trt_conversion.i @@ -100,81 +100,23 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong); #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/util/stat_summarizer.h" #include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/contrib/tensorrt/test/utils.h" %} %ignoreall %unignore tensorflow; -%unignore trt_convert; %unignore calib_convert; %unignore get_linked_tensorrt_version; %unignore get_loaded_tensorrt_version; +%unignore is_tensorrt_enabled; +%unignore enable_test_value; +%unignore clear_test_values; +%unignore add_test_value; +%unignore get_test_value; %{ -std::pair trt_convert( - string graph_def_string, // The serialized GraphDef string. - std::vector output_names, - size_t max_batch_size, - size_t max_workspace_size_bytes, - int precision_mode, - int minimum_segment_size, - bool is_dyn_op, - int max_cached_engines, - std::vector cached_engine_batches - // Unfortunately we can't use TF_Status here since it - // is in c/c_api and brings in a lot of other libraries - // which in turn declare ops. These ops are included - // statically in our library and cause an abort when - // module is loaded due to double registration - // until Tensorflow properly exposes these headers - // we have to work around this by returning a string - // and converting it to exception on python side. - //,TF_Status* out_status) { -) { -#if GOOGLE_CUDA && GOOGLE_TENSORRT - string out_status; - - tensorflow::GraphDef graph_def; - if (!graph_def.ParseFromString(graph_def_string)) { - out_status = "InvalidArgument;Couldn't interpret input as a GraphDef"; - return std::pair{out_status, ""}; - } - - if(precision_mode < 0 || precision_mode > 2){ - out_status = "InvalidArgument;Invalid precision_mode"; - return std::pair{out_status, ""}; - } - if (!output_names.size()) { - out_status = "InvalidArgument;Size of the output_names vector is 0"; - return std::pair{out_status, ""}; - } - tensorflow::GraphDef out_graph; - tensorflow::Status conversion_status = - tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT( - graph_def, output_names, max_batch_size, max_workspace_size_bytes, - &out_graph, precision_mode, minimum_segment_size, - is_dyn_op, max_cached_engines, cached_engine_batches); - if (!conversion_status.ok()) { - auto retCode = (int)conversion_status.code(); - char buff[2000]; - snprintf(buff, 2000, "%d;%s", retCode, - conversion_status.error_message().c_str()); - out_status = buff; - return std::pair{out_status, ""}; - } - string result; - if (!out_graph.SerializeToString(&result)) { - out_status = "InvalidArgument;Couldn't serialize output as a GraphDef"; - return std::pair{out_status, ""}; - } - out_status = "OK;All good!"; - return std::pair{out_status, result}; -#else - // Returns FAILED_PRECONDITION. - return std::pair{"9;TensorRT is not enabled!", ""}; -#endif // GOOGLE_CUDA && GOOGLE_TENSORRT -} - std::pair calib_convert( string graph_def_string, bool is_dyn_op // unfortunately we can't use TF_Status here since it @@ -232,7 +174,8 @@ version_struct get_linked_tensorrt_version() { #endif // GOOGLE_CUDA && GOOGLE_TENSORRT return s; } -version_struct get_loaded_tensorrt_version(){ + +version_struct get_loaded_tensorrt_version() { // Return the version from the loaded library. version_struct s; #if GOOGLE_CUDA && GOOGLE_TENSORRT @@ -244,19 +187,48 @@ version_struct get_loaded_tensorrt_version(){ return s; } -%} +bool is_tensorrt_enabled() { + return tensorflow::tensorrt::IsGoogleTensorRTEnabled(); +} + +void enable_test_value() { + tensorflow::tensorrt::test::EnableTestValue(); +} + +#if PY_MAJOR_VERSION < 3 +#define TRT_PY_TO_CPP_STRING PyString_AsString +#define TRT_CPP_TO_PY_STRING PyString_FromString +#else +#define TRT_PY_TO_CPP_STRING PyUnicode_AsUTF8 +#define TRT_CPP_TO_PY_STRING PyUnicode_FromString +#endif -std::pair calib_convert(string graph_def_string, bool is_dyn_op); +void clear_test_values(PyObject* pattern) { + tensorflow::tensorrt::test::ClearTestValues( + string(TRT_PY_TO_CPP_STRING(pattern))); +} -std::pair trt_convert(string graph_def_string, - std::vector output_names, - size_t max_batch_size, - size_t max_workspace_size_bytes, - int precision_mode, int minimum_segment_size, - bool is_dyn_op, - int max_cached_engines, - std::vector cached_engine_batches); +void add_test_value(PyObject* label, PyObject* value) { + tensorflow::tensorrt::test::AddTestValue( + string(TRT_PY_TO_CPP_STRING(label)), string(TRT_PY_TO_CPP_STRING(value))); +} + +PyObject* get_test_value(PyObject* label) { + string value = tensorflow::tensorrt::test::GetTestValue( + string(TRT_PY_TO_CPP_STRING(label))); + return TRT_CPP_TO_PY_STRING(value.c_str()); +} + +%} + +std::pair calib_convert( + string graph_def_string, bool is_dyn_op); version_struct get_linked_tensorrt_version(); version_struct get_loaded_tensorrt_version(); +bool is_tensorrt_enabled(); +void enable_test_value(); +void clear_test_values(PyObject* pattern); +void add_test_value(PyObject* label, PyObject* value); +PyObject* get_test_value(PyObject* label); %unignoreall diff --git a/tensorflow/contrib/timeseries/__init__.py b/tensorflow/contrib/timeseries/__init__.py index 11db56b1b7a48b401efeece91283eb7084747c14..654a4db098757a969c2d298f7ed490083e63b9da 100644 --- a/tensorflow/contrib/timeseries/__init__.py +++ b/tensorflow/contrib/timeseries/__init__.py @@ -27,6 +27,9 @@ @@TrainEvalFeatures @@FilteringResults + +@@TimeSeriesRegressor +@@OneShotPredictionHead """ from __future__ import absolute_import diff --git a/tensorflow/contrib/timeseries/examples/multivariate.py b/tensorflow/contrib/timeseries/examples/multivariate.py index ed799542fd50cd150f13533c5f33bd67ed09fff6..e81cb18ad7b928a6fd2a748ea6b258c49cf722ae 100644 --- a/tensorflow/contrib/timeseries/examples/multivariate.py +++ b/tensorflow/contrib/timeseries/examples/multivariate.py @@ -80,8 +80,8 @@ def multivariate_train_and_sample( session=session, steps=1)) next_sample = numpy.random.multivariate_normal( # Squeeze out the batch and series length dimensions (both 1). - mean=numpy.squeeze(current_prediction["mean"], axis=[0, 1]), - cov=numpy.squeeze(current_prediction["covariance"], axis=[0, 1])) + mean=numpy.squeeze(current_prediction["mean"], axis=(0, 1)), + cov=numpy.squeeze(current_prediction["covariance"], axis=(0, 1))) # Update model state so that future predictions are conditional on the # value we just sampled. filtering_features = { diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index e4963596d38dbe8aea98fddbc67dbbf761c215c8..0e96c1fbd43ef45e9ff1e090a6d5489ab186484a 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -157,9 +157,11 @@ py_library( py_test( name = "head_test", + size = "large", srcs = [ "head_test.py", ], + shard_count = 4, srcs_version = "PY2AND3", tags = ["no_pip_gpu"], # b/63391119 deps = [ @@ -184,6 +186,7 @@ py_test( "//tensorflow/python/saved_model:loader", "//tensorflow/python/saved_model:tag_constants", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/timeseries/python/timeseries/__init__.py b/tensorflow/contrib/timeseries/python/timeseries/__init__.py index c683dad71de8f8502f08a4e823faa79d60d5604d..8462138339cda8557d9c9ee6e79d4c7a67ad1aa7 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/__init__.py +++ b/tensorflow/contrib/timeseries/python/timeseries/__init__.py @@ -24,5 +24,6 @@ from tensorflow.contrib.timeseries.python.timeseries import saved_model_utils from tensorflow.contrib.timeseries.python.timeseries.ar_model import * from tensorflow.contrib.timeseries.python.timeseries.estimators import * from tensorflow.contrib.timeseries.python.timeseries.feature_keys import * +from tensorflow.contrib.timeseries.python.timeseries.head import * from tensorflow.contrib.timeseries.python.timeseries.input_pipeline import * # pylint: enable=wildcard-import diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py index 4ec8d26116159fee3ac00581010d1603ac9e19f3..0ddc4b4144da25206735b0480aa0886374ed43a8 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py @@ -37,6 +37,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.training import training as train from tensorflow.python.util import nest @@ -79,12 +80,137 @@ class TimeSeriesRegressor(estimator_lib.Estimator): model_dir=model_dir, config=config) - # TODO(allenl): A parsing input receiver function, which takes a serialized - # tf.Example containing all features (times, values, any exogenous features) - # and serialized model state (possibly also as a tf.Example). - def build_raw_serving_input_receiver_fn(self, - default_batch_size=None, - default_series_length=None): + def _model_start_state_placeholders( + self, batch_size_tensor, static_batch_size=None): + """Creates placeholders with zeroed start state for the current model.""" + gathered_state = {} + # Models may not know the shape of their state without creating some + # variables/ops. Avoid polluting the default graph by making a new one. We + # use only static metadata from the returned Tensors. + with ops.Graph().as_default(): + self._model.initialize_graph() + # Evaluate the initial state as same-dtype "zero" values. These zero + # constants aren't used, but are necessary for feeding to + # placeholder_with_default for the "cold start" case where state is not + # fed to the model. + def _zeros_like_constant(tensor): + return tensor_util.constant_value(array_ops.zeros_like(tensor)) + start_state = nest.map_structure( + _zeros_like_constant, self._model.get_start_state()) + for prefixed_state_name, state in ts_head_lib.state_to_dictionary( + start_state).items(): + state_shape_with_batch = tensor_shape.TensorShape( + (static_batch_size,)).concatenate(state.shape) + default_state_broadcast = array_ops.tile( + state[None, ...], + multiples=array_ops.concat( + [batch_size_tensor[None], + array_ops.ones(len(state.shape), dtype=dtypes.int32)], + axis=0)) + gathered_state[prefixed_state_name] = array_ops.placeholder_with_default( + input=default_state_broadcast, + name=prefixed_state_name, + shape=state_shape_with_batch) + return gathered_state + + def build_one_shot_parsing_serving_input_receiver_fn( + self, filtering_length, prediction_length, default_batch_size=None, + values_input_dtype=None, truncate_values=False): + """Build an input_receiver_fn for export_savedmodel accepting tf.Examples. + + Only compatible with `OneShotPredictionHead` (see `head`). + + Args: + filtering_length: The number of time steps used as input to the model, for + which values are provided. If more than `filtering_length` values are + provided (via `truncate_values`), only the first `filtering_length` + values are used. + prediction_length: The number of time steps requested as predictions from + the model. Times and all exogenous features must be provided for these + steps. + default_batch_size: If specified, must be a scalar integer. Sets the batch + size in the static shape information of all feature Tensors, which means + only this batch size will be accepted by the exported model. If None + (default), static shape information for batch sizes is omitted. + values_input_dtype: An optional dtype specification for values in the + tf.Example protos (either float32 or int64, since these are the numeric + types supported by tf.Example). After parsing, values are cast to the + model's dtype (float32 or float64). + truncate_values: If True, expects `filtering_length + prediction_length` + values to be provided, but only uses the first `filtering_length`. If + False (default), exactly `filtering_length` values must be provided. + + Returns: + An input_receiver_fn which may be passed to the Estimator's + export_savedmodel. + + Expects features contained in a vector of serialized tf.Examples with + shape [batch size] (dtype `tf.string`), each tf.Example containing + features with the following shapes: + times: [filtering_length + prediction_length] integer + values: [filtering_length, num features] floating point. If + `truncate_values` is True, expects `filtering_length + + prediction_length` values but only uses the first `filtering_length`. + all exogenous features: [filtering_length + prediction_length, ...] + (various dtypes) + """ + if values_input_dtype is None: + values_input_dtype = dtypes.float32 + if truncate_values: + values_proto_length = filtering_length + prediction_length + else: + values_proto_length = filtering_length + + def _serving_input_receiver_fn(): + """A receiver function to be passed to export_savedmodel.""" + times_column = feature_column.numeric_column( + key=feature_keys.TrainEvalFeatures.TIMES, dtype=dtypes.int64) + values_column = feature_column.numeric_column( + key=feature_keys.TrainEvalFeatures.VALUES, dtype=values_input_dtype, + shape=(self._model.num_features,)) + parsed_features_no_sequence = ( + feature_column.make_parse_example_spec( + list(self._model.exogenous_feature_columns) + + [times_column, values_column])) + parsed_features = {} + for key, feature_spec in parsed_features_no_sequence.items(): + if isinstance(feature_spec, parsing_ops.FixedLenFeature): + if key == feature_keys.TrainEvalFeatures.VALUES: + parsed_features[key] = feature_spec._replace( + shape=((values_proto_length,) + + feature_spec.shape)) + else: + parsed_features[key] = feature_spec._replace( + shape=((filtering_length + prediction_length,) + + feature_spec.shape)) + elif feature_spec.dtype == dtypes.string: + parsed_features[key] = parsing_ops.FixedLenFeature( + shape=(filtering_length + prediction_length,), + dtype=dtypes.string) + else: # VarLenFeature + raise ValueError("VarLenFeatures not supported, got %s for key %s" + % (feature_spec, key)) + tfexamples = array_ops.placeholder( + shape=[default_batch_size], dtype=dtypes.string, name="input") + features = parsing_ops.parse_example( + serialized=tfexamples, + features=parsed_features) + features[feature_keys.TrainEvalFeatures.TIMES] = array_ops.squeeze( + features[feature_keys.TrainEvalFeatures.TIMES], axis=-1) + features[feature_keys.TrainEvalFeatures.VALUES] = math_ops.cast( + features[feature_keys.TrainEvalFeatures.VALUES], + dtype=self._model.dtype)[:, :filtering_length] + features.update( + self._model_start_state_placeholders( + batch_size_tensor=array_ops.shape( + features[feature_keys.TrainEvalFeatures.TIMES])[0], + static_batch_size=default_batch_size)) + return export_lib.ServingInputReceiver( + features, {"examples": tfexamples}) + return _serving_input_receiver_fn + + def build_raw_serving_input_receiver_fn( + self, default_batch_size=None, default_series_length=None): """Build an input_receiver_fn for export_savedmodel which accepts arrays. Automatically creates placeholders for exogenous `FeatureColumn`s passed to @@ -149,34 +275,10 @@ class TimeSeriesRegressor(estimator_lib.Estimator): + batch_only_feature_shape[1:]) placeholders[feature_key] = array_ops.placeholder( dtype=value_dtype, name=feature_key, shape=feature_shape) - # Models may not know the shape of their state without creating some - # variables/ops. Avoid polluting the default graph by making a new one. We - # use only static metadata from the returned Tensors. - with ops.Graph().as_default(): - self._model.initialize_graph() - # Evaluate the initial state as same-dtype "zero" values. These zero - # constants aren't used, but are necessary for feeding to - # placeholder_with_default for the "cold start" case where state is not - # fed to the model. - def _zeros_like_constant(tensor): - return tensor_util.constant_value(array_ops.zeros_like(tensor)) - start_state = nest.map_structure( - _zeros_like_constant, self._model.get_start_state()) batch_size_tensor = array_ops.shape(time_placeholder)[0] - for prefixed_state_name, state in ts_head_lib.state_to_dictionary( - start_state).items(): - state_shape_with_batch = tensor_shape.TensorShape( - (default_batch_size,)).concatenate(state.shape) - default_state_broadcast = array_ops.tile( - state[None, ...], - multiples=array_ops.concat( - [batch_size_tensor[None], - array_ops.ones(len(state.shape), dtype=dtypes.int32)], - axis=0)) - placeholders[prefixed_state_name] = array_ops.placeholder_with_default( - input=default_state_broadcast, - name=prefixed_state_name, - shape=state_shape_with_batch) + placeholders.update( + self._model_start_state_placeholders( + batch_size_tensor, static_batch_size=default_batch_size)) return export_lib.ServingInputReceiver(placeholders, placeholders) return _serving_input_receiver_fn @@ -288,7 +390,7 @@ class StateSpaceRegressor(TimeSeriesRegressor): """An Estimator for general state space models.""" def __init__(self, model, state_manager=None, optimizer=None, model_dir=None, - config=None): + config=None, head_type=ts_head_lib.TimeSeriesRegressionHead): """See TimeSeriesRegressor. Uses the ChainingStateManager by default.""" if not isinstance(model, state_space_model.StateSpaceModel): raise ValueError( @@ -301,7 +403,8 @@ class StateSpaceRegressor(TimeSeriesRegressor): state_manager=state_manager, optimizer=optimizer, model_dir=model_dir, - config=config) + config=config, + head_type=head_type) class StructuralEnsembleRegressor(StateSpaceRegressor): @@ -344,7 +447,8 @@ class StructuralEnsembleRegressor(StateSpaceRegressor): anomaly_prior_probability=None, optimizer=None, model_dir=None, - config=None): + config=None, + head_type=ts_head_lib.TimeSeriesRegressionHead): """Initialize the Estimator. Args: @@ -401,6 +505,8 @@ class StructuralEnsembleRegressor(StateSpaceRegressor): from tf.train.Optimizer. Defaults to Adam with step size 0.02. model_dir: See `Estimator`. config: See `Estimator`. + head_type: The kind of head to use for the model (inheriting from + `TimeSeriesRegressionHead`). """ if anomaly_prior_probability is not None: filtering_postprocessor = StateInterpolatingAnomalyDetector( @@ -424,4 +530,5 @@ class StructuralEnsembleRegressor(StateSpaceRegressor): model=model, optimizer=optimizer, model_dir=model_dir, - config=config) + config=config, + head_type=head_type) diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py index f236329fdb038ba5ab432c6b97f44bda7ccfe815..32194e400e6ada594ef2a067bf612826a6e4acd3 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head.py @@ -19,24 +19,22 @@ from __future__ import print_function import re -from tensorflow.python.training import training_util -from tensorflow.contrib.layers.python.layers import optimizers - from tensorflow.contrib.timeseries.python.timeseries import feature_keys - from tensorflow.python.estimator import estimator_lib from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.estimator.export import export_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.util import nest from tensorflow.python.summary import summary +from tensorflow.python.training import training_util +from tensorflow.python.util import nest class _NoStatePredictOutput(export_lib.PredictOutput): @@ -102,12 +100,9 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce use_resource=True): model_outputs = self.create_loss(features, mode) - train_op = optimizers.optimize_loss( + train_op = self.optimizer.minimize( model_outputs.loss, - global_step=training_util.get_global_step(), - optimizer=self.optimizer, - # Learning rate is set in the Optimizer object - learning_rate=None) + global_step=training_util.get_global_step()) return estimator_lib.EstimatorSpec( loss=model_outputs.loss, mode=mode, @@ -186,7 +181,7 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce return math_ops.cast(value, self.model.dtype) if name == feature_keys.PredictionFeatures.STATE_TUPLE: return value # Correct dtypes are model-dependent - return ops.convert_to_tensor(value) + return sparse_tensor.convert_to_tensor_or_sparse_tensor(value) def _gather_state(self, features): """Returns `features` with state packed, indicates if packing was done.""" @@ -208,6 +203,29 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce flat_sequence=[tensor for _, _, tensor in numbered_state]) return features, True + def _check_predict_features(self, features): + """Raises errors if features are not suitable for prediction.""" + if feature_keys.PredictionFeatures.TIMES not in features: + raise ValueError("Expected a '{}' feature for prediction.".format( + feature_keys.PredictionFeatures.TIMES)) + if feature_keys.PredictionFeatures.STATE_TUPLE not in features: + raise ValueError("Expected a '{}' feature for prediction.".format( + feature_keys.PredictionFeatures.STATE_TUPLE)) + times_feature = features[feature_keys.PredictionFeatures.TIMES] + if not times_feature.get_shape().is_compatible_with([None, None]): + raise ValueError( + ("Expected shape (batch dimension, window size) for feature '{}' " + "(got shape {})").format(feature_keys.PredictionFeatures.TIMES, + times_feature.get_shape())) + _check_feature_shapes_compatible_with( + features=features, + compatible_with_name=feature_keys.PredictionFeatures.TIMES, + compatible_with_value=times_feature, + ignore=set([ + # Model-dependent shapes + feature_keys.PredictionFeatures.STATE_TUPLE + ])) + def create_estimator_spec(self, features, mode, labels=None): """Performs basic error checking and returns an EstimatorSpec.""" with ops.name_scope(self._name, "head"): @@ -236,7 +254,7 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce mode == estimator_lib.ModeKeys.EVAL): _check_train_eval_features(features, self.model) elif mode == estimator_lib.ModeKeys.PREDICT: - _check_predict_features(features) + self._check_predict_features(features) else: raise ValueError("Unknown mode '{}' passed to model_fn.".format(mode)) @@ -273,6 +291,44 @@ class OneShotPredictionHead(TimeSeriesRegressionHead): each time predictions are requested when using this head. """ + def _check_predict_features(self, features): + """Raises errors if features are not suitable for one-shot prediction.""" + if feature_keys.PredictionFeatures.TIMES not in features: + raise ValueError("Expected a '{}' feature for prediction.".format( + feature_keys.PredictionFeatures.TIMES)) + if feature_keys.TrainEvalFeatures.VALUES not in features: + raise ValueError("Expected a '{}' feature for prediction.".format( + feature_keys.TrainEvalFeatures.VALUES)) + if feature_keys.PredictionFeatures.STATE_TUPLE not in features: + raise ValueError("Expected a '{}' feature for prediction.".format( + feature_keys.PredictionFeatures.STATE_TUPLE)) + times_feature = features[feature_keys.PredictionFeatures.TIMES] + if not times_feature.get_shape().is_compatible_with([None, None]): + raise ValueError( + ("Expected shape (batch dimension, window size) for feature '{}' " + "(got shape {})").format(feature_keys.PredictionFeatures.TIMES, + times_feature.get_shape())) + _check_feature_shapes_compatible_with( + features=features, + compatible_with_name=feature_keys.PredictionFeatures.TIMES, + compatible_with_value=times_feature, + ignore=set([ + # Model-dependent shapes + feature_keys.PredictionFeatures.STATE_TUPLE, + # One shot prediction head relies on values being shorter than + # times. Even though we're predicting eventually, we need values for + # the filtering phase. + feature_keys.TrainEvalFeatures.VALUES, + ])) + + def _evaluate_ops(self, features): + """Add ops for evaluation (aka filtering) to the graph.""" + spec = super(OneShotPredictionHead, self)._evaluate_ops(features) + # No state is fed to OneShotPredictionHead, so we don't return it; it being + # a tuple can cause issues for downstream infrastructure. + del spec.eval_metric_ops[feature_keys.State.STATE_TUPLE] + return spec + def _serving_ops(self, features): """Add ops for serving to the graph.""" with variable_scope.variable_scope("model", use_resource=True): @@ -339,29 +395,6 @@ def _check_feature_shapes_compatible_with(features, times_shape=compatible_with_value.get_shape())) -def _check_predict_features(features): - """Raises errors if features are not suitable for prediction.""" - if feature_keys.PredictionFeatures.TIMES not in features: - raise ValueError("Expected a '{}' feature for prediction.".format( - feature_keys.PredictionFeatures.TIMES)) - if feature_keys.PredictionFeatures.STATE_TUPLE not in features: - raise ValueError("Expected a '{}' feature for prediction.".format( - feature_keys.PredictionFeatures.STATE_TUPLE)) - times_feature = features[feature_keys.PredictionFeatures.TIMES] - if not times_feature.get_shape().is_compatible_with([None, None]): - raise ValueError( - ("Expected shape (batch dimension, window size) for feature '{}' " - "(got shape {})").format(feature_keys.PredictionFeatures.TIMES, - times_feature.get_shape())) - _check_feature_shapes_compatible_with( - features=features, - compatible_with_name=feature_keys.PredictionFeatures.TIMES, - compatible_with_value=times_feature, - ignore=set([ - feature_keys.PredictionFeatures.STATE_TUPLE # Model-dependent shapes - ])) - - def _check_train_eval_features(features, model): """Raise errors if features are not suitable for training/evaluation.""" if feature_keys.TrainEvalFeatures.TIMES not in features: diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py index ed8f29c321719e552c25f4d2183fdf4eb282e4b7..bda3b53aca0d0156e542e2bedcadf5caa6b3d2cf 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py @@ -18,17 +18,23 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools +import os + +from absl.testing import parameterized import numpy import six from tensorflow.contrib.estimator.python.estimator import extenders from tensorflow.contrib.timeseries.examples import lstm as lstm_example +from tensorflow.contrib.timeseries.python.timeseries import ar_model from tensorflow.contrib.timeseries.python.timeseries import estimators as ts_estimators from tensorflow.contrib.timeseries.python.timeseries import feature_keys from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib from tensorflow.contrib.timeseries.python.timeseries import input_pipeline from tensorflow.contrib.timeseries.python.timeseries import model from tensorflow.contrib.timeseries.python.timeseries import state_management +from tensorflow.core.example import example_pb2 from tensorflow.python.client import session as session_lib from tensorflow.python.estimator import estimator_lib @@ -317,10 +323,56 @@ class PredictFeatureCheckingTests(test.TestCase): mode=estimator_lib.ModeKeys.PREDICT) -class OneShotTests(test.TestCase): - - def test_one_shot_prediction_head_export(self): - model_dir = self.get_temp_dir() +def _custom_time_series_regressor( + model_dir, head_type, exogenous_feature_columns): + return ts_estimators.TimeSeriesRegressor( + model=lstm_example._LSTMModel( + num_features=5, num_units=128, + exogenous_feature_columns=exogenous_feature_columns), + optimizer=adam.AdamOptimizer(0.001), + config=estimator_lib.RunConfig(tf_random_seed=4), + state_manager=state_management.ChainingStateManager(), + head_type=head_type, + model_dir=model_dir) + + +def _structural_ensemble_regressor( + model_dir, head_type, exogenous_feature_columns): + return ts_estimators.StructuralEnsembleRegressor( + periodicities=None, + num_features=5, + exogenous_feature_columns=exogenous_feature_columns, + head_type=head_type, + model_dir=model_dir) + + +def _ar_lstm_regressor( + model_dir, head_type, exogenous_feature_columns): + return ts_estimators.TimeSeriesRegressor( + model=ar_model.ARModel( + periodicities=10, input_window_size=10, output_window_size=6, + num_features=5, + exogenous_feature_columns=exogenous_feature_columns, + prediction_model_factory=functools.partial( + ar_model.LSTMPredictionModel, + num_units=10)), + head_type=head_type, + model_dir=model_dir) + + +class OneShotTests(parameterized.TestCase): + + @parameterized.named_parameters( + {"testcase_name": "ar_lstm_regressor", + "estimator_factory": _ar_lstm_regressor}, + {"testcase_name": "custom_time_series_regressor", + "estimator_factory": _custom_time_series_regressor}, + {"testcase_name": "structural_ensemble_regressor", + "estimator_factory": _structural_ensemble_regressor}) + def test_one_shot_prediction_head_export(self, estimator_factory): + def _new_temp_dir(): + return os.path.join(test.get_temp_dir(), str(ops.uid())) + model_dir = _new_temp_dir() categorical_column = feature_column.categorical_column_with_hash_bucket( key="categorical_exogenous_feature", hash_bucket_size=16) exogenous_feature_columns = [ @@ -328,15 +380,10 @@ class OneShotTests(test.TestCase): "2d_exogenous_feature", shape=(2,)), feature_column.embedding_column( categorical_column=categorical_column, dimension=10)] - estimator = ts_estimators.TimeSeriesRegressor( - model=lstm_example._LSTMModel( - num_features=5, num_units=128, - exogenous_feature_columns=exogenous_feature_columns), - optimizer=adam.AdamOptimizer(0.001), - config=estimator_lib.RunConfig(tf_random_seed=4), - state_manager=state_management.ChainingStateManager(), - head_type=ts_head_lib.OneShotPredictionHead, - model_dir=model_dir) + estimator = estimator_factory( + model_dir=model_dir, + exogenous_feature_columns=exogenous_feature_columns, + head_type=ts_head_lib.OneShotPredictionHead) train_features = { feature_keys.TrainEvalFeatures.TIMES: numpy.arange( 20, dtype=numpy.int64), @@ -350,8 +397,10 @@ class OneShotTests(test.TestCase): input_pipeline.NumpyReader(train_features), shuffle_seed=2, num_threads=1, batch_size=16, window_size=16) estimator.train(input_fn=train_input_fn, steps=5) + result = estimator.evaluate(input_fn=train_input_fn, steps=1) + self.assertNotIn(feature_keys.State.STATE_TUPLE, result) input_receiver_fn = estimator.build_raw_serving_input_receiver_fn() - export_location = estimator.export_savedmodel(self.get_temp_dir(), + export_location = estimator.export_savedmodel(_new_temp_dir(), input_receiver_fn) graph = ops.Graph() with graph.as_default(): @@ -385,7 +434,42 @@ class OneShotTests(test.TestCase): for output_key, output_value in predict_signature.outputs.items()} output = session.run(fetches, feed_dict=feeds) - self.assertAllEqual((2, 15, 5), output["mean"].shape) + self.assertEqual((2, 15, 5), output["mean"].shape) + # Build a parsing input function, then make a tf.Example for it to parse. + export_location = estimator.export_savedmodel( + _new_temp_dir(), + estimator.build_one_shot_parsing_serving_input_receiver_fn( + filtering_length=20, prediction_length=15)) + graph = ops.Graph() + with graph.as_default(): + with session_lib.Session() as session: + example = example_pb2.Example() + times = example.features.feature[feature_keys.TrainEvalFeatures.TIMES] + values = example.features.feature[feature_keys.TrainEvalFeatures.VALUES] + times.int64_list.value.extend(range(35)) + for i in range(20): + values.float_list.value.extend( + [float(i) * 2. + feature_number + for feature_number in range(5)]) + real_feature = example.features.feature["2d_exogenous_feature"] + categortical_feature = example.features.feature[ + "categorical_exogenous_feature"] + for i in range(35): + real_feature.float_list.value.extend([1, 1]) + categortical_feature.bytes_list.value.append(b"strkey") + # Serialize the tf.Example for feeding to the Session + examples = [example.SerializeToString()] * 2 + signatures = loader.load( + session, [tag_constants.SERVING], export_location) + predict_signature = signatures.signature_def[ + feature_keys.SavedModelLabels.PREDICT] + ((_, input_value),) = predict_signature.inputs.items() + feeds = {graph.as_graph_element(input_value.name): examples} + fetches = {output_key: graph.as_graph_element(output_value.name) + for output_key, output_value + in predict_signature.outputs.items()} + output = session.run(fetches, feed_dict=feeds) + self.assertEqual((2, 15, 5), output["mean"].shape) if __name__ == "__main__": diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 16696793bc2dab977a3dbbfa338e33e5771d0699..1669f6050e7ca92d973a75258a6b57bc62facff2 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -15,8 +15,8 @@ package( default_visibility = [ "//cloud/vmm/testing/tests/tpu:__subpackages__", "//learning/brain:__subpackages__", + "//learning/deepmind:__subpackages__", "//tensorflow:__subpackages__", - "//third_party/cloud_tpu:__subpackages__", ], ) @@ -37,16 +37,17 @@ cc_library( py_library( name = "tpu_estimator", srcs = [ + "python/tpu/error_handling.py", "python/tpu/tpu_config.py", "python/tpu/tpu_context.py", "python/tpu/tpu_estimator.py", - "python/tpu/tpu_system_metadata.py", "python/tpu/util.py", ], srcs_version = "PY2AND3", deps = [ ":tpu_lib", - ":tpu_py", + "//tensorflow/compiler/xla/experimental/xla_sharding", + "//tensorflow/compiler/xla/python_api:xla_shape", "//tensorflow/contrib/training:training_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -133,7 +134,7 @@ py_library( tf_custom_op_py_library( name = "tpu_py", - srcs = glob(["python/ops/*.py"]) + ["__init__.py"], + srcs = glob(["python/ops/*.py"]), dso = [":python/ops/_tpu_ops.so"], kernels = [ ":all_ops", @@ -152,21 +153,62 @@ tf_custom_op_py_library( py_library( name = "tpu", - srcs = ["python/tpu/__init__.py"], + srcs = [ + "__init__.py", + "python/tpu/__init__.py", + ], srcs_version = "PY2AND3", deps = [ + ":keras_support", # split out to avoid cycle with tpu_strategy ":tpu_estimator", ":tpu_lib", ], ) +py_library( + name = "keras_support", + srcs = [ + "python/tpu/keras_support.py", + ], + srcs_version = "PY2AND3", + visibility = [ + "//cloud/vmm/testing/tests/tpu:__subpackages__", + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", + "//third_party/cloud_tpu/models/keras:__subpackages__", + ], + deps = [ + ":tpu_lib", + "//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py", + "//tensorflow/contrib/distribute", + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/contrib/tpu/proto:compilation_result_proto_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:random_ops", + "//tensorflow/python:session", + "//tensorflow/python:tensor_spec", + "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/keras:backend", + "//tensorflow/python/keras:engine", + "//tensorflow/python/keras:layers", + "//third_party/py/numpy", + ], +) + py_library( name = "tpu_lib", srcs = [ "python/tpu/__init__.py", "python/tpu/bfloat16.py", "python/tpu/device_assignment.py", - "python/tpu/keras_support.py", "python/tpu/session_support.py", "python/tpu/topology.py", "python/tpu/tpu.py", @@ -174,6 +216,7 @@ py_library( "python/tpu/tpu_function.py", "python/tpu/tpu_optimizer.py", "python/tpu/tpu_sharding.py", + "python/tpu/tpu_system_metadata.py", "python/tpu/training_loop.py", ], srcs_version = "PY2AND3", @@ -307,3 +350,13 @@ tf_py_test( "//tensorflow/python:framework_test_lib", ], ) + +tf_py_test( + name = "topology_test", + size = "small", + srcs = ["python/tpu/topology_test.py"], + additional_deps = [ + ":tpu", + "//tensorflow/python:framework_test_lib", + ], +) diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py index dc9066855990f372c28dc481959117daa4c2da97..537d94b7979af3e4bd3fb7392c8dcc5a210e98af 100644 --- a/tensorflow/contrib/tpu/__init__.py +++ b/tensorflow/contrib/tpu/__init__.py @@ -18,6 +18,10 @@ @@cross_replica_sum @@infeed_dequeue @@infeed_dequeue_tuple +@@infeed_enqueue +@@infeed_enqueue_tuple +@@outfeed_dequeue +@@outfeed_dequeue_tuple @@outfeed_enqueue @@outfeed_enqueue_tuple @@ -42,9 +46,14 @@ @@TPUEstimator @@TPUEstimatorSpec +@@export_estimator_savedmodel @@RunConfig @@InputPipelineConfig @@TPUConfig +@@bfloat16_scope + +@@TPUDistributionStrategy +@@keras_to_tpu_model """ from __future__ import absolute_import @@ -56,11 +65,13 @@ from tensorflow.contrib.tpu.python import profiler from tensorflow.contrib.tpu.python.ops.tpu_ops import * from tensorflow.contrib.tpu.python.tpu.bfloat16 import * from tensorflow.contrib.tpu.python.tpu.device_assignment import * +from tensorflow.contrib.tpu.python.tpu.keras_support import tpu_model as keras_to_tpu_model +from tensorflow.contrib.tpu.python.tpu.keras_support import TPUDistributionStrategy from tensorflow.contrib.tpu.python.tpu.topology import * from tensorflow.contrib.tpu.python.tpu.tpu import * from tensorflow.contrib.tpu.python.tpu.tpu_config import * from tensorflow.contrib.tpu.python.tpu.tpu_estimator import * -from tensorflow.contrib.tpu.python.tpu.tpu_feed import * +from tensorflow.contrib.tpu.python.tpu.tpu_feed import InfeedQueue from tensorflow.contrib.tpu.python.tpu.tpu_optimizer import * from tensorflow.contrib.tpu.python.tpu.training_loop import * # pylint: enable=wildcard-import,unused-import diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc index f80f5652af79d410946971573ae160fdd0b85f6d..8e6e9aa0cded630f39bfd699def37e06a8b920e8 100644 --- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc @@ -84,8 +84,6 @@ ProfileRequest PopulateProfileRequest(int duration_ms, request.add_tools("memory_viewer"); request.add_tools("overview_page"); *request.mutable_opts() = opts; - std::cout << "Limiting the number of trace events to " << kMaxEvents - << std::endl; return request; } @@ -99,7 +97,6 @@ bool Profile(const string& service_addr, const string& logdir, int duration_ms, ::grpc::ClientContext context; ::grpc::ChannelArguments channel_args; - // TODO(ioeric): use `SetMaxReceiveMessageSize` instead once it's available. // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their // `ValidateHostPortPair` checks for empty host string case. channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, @@ -166,6 +163,85 @@ bool NewSession(const string& service_addr, return new_session_response.empty_trace(); } +// Starts tracing on a single or multiple TPU hosts and saves the result in the +// given logdir. If no trace was collected, retries tracing for +// num_tracing_attempts. +void StartTracing(const tensorflow::string& service_addr, + const tensorflow::string& logdir, + const tensorflow::string& workers_list, + bool include_dataset_ops, int duration_ms, + int num_tracing_attempts) { + // Use the current timestamp as the run name. + tensorflow::string session_id = GetCurrentTimeStampAsString(); + constexpr char kProfilePluginDirectory[] = "plugins/profile/"; + tensorflow::string repository_root = + io::JoinPath(logdir, kProfilePluginDirectory); + std::vector hostnames = + tensorflow::str_util::Split(workers_list, ","); + + bool empty_trace = false; + int remaining_attempts = num_tracing_attempts; + tensorflow::ProfileOptions opts; + opts.set_include_dataset_ops(include_dataset_ops); + while (true) { + std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. " + << "Remaining attempt(s): " << remaining_attempts-- << std::endl; + if (hostnames.empty()) { + empty_trace = tensorflow::tpu::Profile(service_addr, logdir, duration_ms, + repository_root, session_id, opts); + } else { + tensorflow::string tpu_master = service_addr; + empty_trace = + tensorflow::tpu::NewSession(tpu_master, hostnames, duration_ms, + repository_root, session_id, opts); + } + if (remaining_attempts <= 0 || !empty_trace) break; + std::cout << "No trace event is collected. Automatically retrying." + << std::endl + << std::endl; + } + + if (empty_trace) { + std::cout << "No trace event is collected after " << num_tracing_attempts + << " attempt(s). " + << "Perhaps, you want to try again (with more attempts?)." + << std::endl + << "Tip: increase number of attempts with --num_tracing_attempts." + << std::endl; + } +} + +MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level) { + MonitorRequest request; + request.set_duration_ms(duration_ms); + request.set_monitoring_level(monitoring_level); + return request; +} + +// Repeatedly collects profiles and shows user-friendly metrics for +// 'num_queries' time(s). +void StartMonitoring(const tensorflow::string& service_addr, int duration_ms, + int monitoring_level, int num_queries) { + for (int query = 0; query < num_queries; ++query) { + MonitorRequest request = + PopulateMonitorRequest(duration_ms, monitoring_level); + + ::grpc::ClientContext context; + ::grpc::ChannelArguments channel_args; + channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, + std::numeric_limits::max()); + std::unique_ptr stub = + TPUProfiler::NewStub(::grpc::CreateCustomChannel( + "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(), + channel_args)); + MonitorResponse response; + TF_QCHECK_OK(FromGrpcStatus(stub->Monitor(&context, request, &response))); + + std::cout << "Xprof Monitoring Results (Sample " << query + 1 << "):\n\n" + << response.data() << std::flush; + } +} + } // namespace } // namespace tpu } // namespace tensorflow @@ -174,9 +250,11 @@ int main(int argc, char** argv) { tensorflow::string FLAGS_service_addr; tensorflow::string FLAGS_logdir; tensorflow::string FLAGS_workers_list; - int FLAGS_duration_ms = 2000; + int FLAGS_duration_ms = 0; int FLAGS_num_tracing_attempts = 3; bool FLAGS_include_dataset_ops = true; + int FLAGS_monitoring_level = 0; + int FLAGS_num_queries = 100; std::vector flag_list = { tensorflow::Flag("service_addr", &FLAGS_service_addr, "Address of TPU profiler service e.g. localhost:8466"), @@ -186,21 +264,38 @@ int main(int argc, char** argv) { tensorflow::Flag("logdir", &FLAGS_logdir, "Path of TensorBoard log directory e.g. /tmp/tb_log, " "gs://tb_bucket"), - tensorflow::Flag("duration_ms", &FLAGS_duration_ms, - "Duration of tracing in ms. Default is 2000ms."), + tensorflow::Flag( + "duration_ms", &FLAGS_duration_ms, + "Duration of tracing or monitoring in ms. Default is 2000ms for " + "tracing and 1000ms for monitoring."), tensorflow::Flag("num_tracing_attempts", &FLAGS_num_tracing_attempts, "Automatically retry N times when no trace event " "is collected. Default is 3."), tensorflow::Flag("include_dataset_ops", &FLAGS_include_dataset_ops, "Set to false to profile longer TPU device traces."), - }; + tensorflow::Flag("monitoring_level", &FLAGS_monitoring_level, + "Choose a monitoring level between 1 and 2 to monitor " + "your TPU job continuously. Level 2 is more verbose " + "than level 1 and shows more metrics."), + tensorflow::Flag("num_queries", &FLAGS_num_queries, + "This script will run monitoring for num_queries before " + "it stops.")}; std::cout << "Welcome to the Cloud TPU Profiler v" << TPU_PROFILER_VERSION << std::endl; tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); - if (!parse_ok || FLAGS_service_addr.empty() || FLAGS_logdir.empty()) { + if (!parse_ok || FLAGS_service_addr.empty() || + (FLAGS_logdir.empty() && FLAGS_monitoring_level == 0)) { + // Fail if flags are not parsed correctly or service_addr not provided. + // Also, fail if neither logdir is provided (required for tracing) nor + // monitoring level is provided (required for monitoring). + std::cout << usage.c_str() << std::endl; + return 2; + } + if (FLAGS_monitoring_level < 0 || FLAGS_monitoring_level > 2) { + // Invalid monitoring level. std::cout << usage.c_str() << std::endl; return 2; } @@ -213,52 +308,27 @@ int main(int argc, char** argv) { } tensorflow::port::InitMain(argv[0], &argc, &argv); - // Sets the minimum duration_ms and tracing attempts to one. - int duration_ms = std::max(FLAGS_duration_ms, 1); - int remaining_attempts = std::max(FLAGS_num_tracing_attempts, 1); - tensorflow::ProfileOptions opts; - opts.set_include_dataset_ops(FLAGS_include_dataset_ops); - tensorflow::ProfileResponse response; - - // Use the current timestamp as the run name. - tensorflow::string session_id = - tensorflow::tpu::GetCurrentTimeStampAsString(); - constexpr char kProfilePluginDirectory[] = "plugins/profile/"; - tensorflow::string repository_root = - ::tensorflow::io::JoinPath(FLAGS_logdir, kProfilePluginDirectory); - std::vector hostnames = - tensorflow::str_util::Split(FLAGS_workers_list, ","); - - bool empty_trace = false; - while (true) { - std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. " - << "Remaining attempt(s): " << remaining_attempts-- << std::endl; - if (hostnames.empty()) { - empty_trace = tensorflow::tpu::Profile(FLAGS_service_addr, FLAGS_logdir, - duration_ms, repository_root, - session_id, opts); - } else { - tensorflow::string tpu_master = FLAGS_service_addr; - empty_trace = - tensorflow::tpu::NewSession(tpu_master, hostnames, duration_ms, - repository_root, session_id, opts); - } - if (remaining_attempts <= 0 || !empty_trace) break; - std::cout << "No trace event is collected. Automatically retrying." - << std::endl - << std::endl; + // Sets the minimum duration_ms, tracing attempts and num queries. + int duration_ms = std::max(FLAGS_duration_ms, 0); + if (duration_ms == 0) { + // If profiling duration was not set by user or set to a negative value, we + // set it to default values of 2000ms for tracing and 1000ms for monitoring. + duration_ms = FLAGS_monitoring_level == 0 ? 2000 : 1000; } + int num_tracing_attempts = std::max(FLAGS_num_tracing_attempts, 1); + int num_queries = std::max(FLAGS_num_queries, 1); - if (empty_trace) { - std::cout << "No trace event is collected after " - << FLAGS_num_tracing_attempts << " attempt(s). " - << "Perhaps, you want to try again (with more attempts?)." - << std::endl - << "Tip: increase number of attempts with --num_tracing_attempts." + if (FLAGS_monitoring_level != 0) { + std::cout << "Since monitoring level is provided, profile " + << FLAGS_service_addr << " for " << duration_ms + << "ms and show metrics for " << num_queries << " time(s)." << std::endl; - // Don't dump profile data if no trace is collected. - return 0; + tensorflow::tpu::StartMonitoring(FLAGS_service_addr, duration_ms, + FLAGS_monitoring_level, num_queries); + } else { + tensorflow::tpu::StartTracing(FLAGS_service_addr, FLAGS_logdir, + FLAGS_workers_list, FLAGS_include_dataset_ops, + duration_ms, num_tracing_attempts); } - return 0; } diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py index 7a5d01cca42351f6d4d8b41d43756560ce7874d3..438f4428483a86b75ca1feb31d9c43f860fcc287 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py @@ -50,7 +50,8 @@ flags.DEFINE_string( flags.DEFINE_string( 'logdir', None, 'Path of TensorBoard log directory e.g. /tmp/tb_log, ' 'gs://tb_bucket') -flags.DEFINE_integer('duration_ms', 2000, 'Duration of tracing in ms.') +flags.DEFINE_integer('duration_ms', 0, + 'Duration of tracing or monitoring in ms.') flags.DEFINE_integer( 'num_tracing_attempts', 3, 'Automatically retry N times when no trace ' 'event is collected.') @@ -58,6 +59,14 @@ flags.DEFINE_boolean('include_dataset_ops', True, 'Set to false to profile longer TPU ' 'device traces.') +# Monitoring parameters +flags.DEFINE_integer( + 'monitoring_level', 0, 'Choose a monitoring level between ' + '1 and 2 to monitor your TPU job continuously.') +flags.DEFINE_integer( + 'num_queries', 100, + 'This script will run monitoring for num_queries before it stops.') + FLAGS = flags.FLAGS EXECUTABLE = 'data/capture_tpu_profile' JOB_NAME = 'worker' @@ -118,6 +127,8 @@ def main(unused_argv=None): cmd.append('--duration_ms=' + str(FLAGS.duration_ms)) cmd.append('--num_tracing_attempts=' + str(FLAGS.num_tracing_attempts)) cmd.append('--include_dataset_ops=' + str(FLAGS.include_dataset_ops).lower()) + cmd.append('--monitoring_level=' + str(FLAGS.monitoring_level)) + cmd.append('--num_queries=' + str(FLAGS.num_queries)) subprocess.call(cmd) diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto index f0fca63db0bca80cdaa27e491b2a03ae2246c007..da4a95e0450a9d0c20593ca60b69f3ad467d455d 100644 --- a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto +++ b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto @@ -11,6 +11,9 @@ service TPUProfiler { // Starts a profiling session, blocks until it completes, and returns data. rpc Profile(ProfileRequest) returns (ProfileResponse) { } + // Collects profiling data and returns user-friendly metrics. + rpc Monitor(MonitorRequest) returns (MonitorResponse) { + } } message ProfileOptions { @@ -104,3 +107,26 @@ message ProfileResponse { // next-field: 8 } + +message MonitorRequest { + // Duration for which to profile between each update. + uint64 duration_ms = 1; + + // Indicates the level at which we want to monitor. Currently, two levels are + // supported: + // Level 1: An ultra lightweight mode that captures only some utilization + // metrics. + // Level 2: More verbose than level 1. Collects utilization metrics, device + // information, step time information, etc. Do not use this option if the TPU + // host is being very heavily used. + int32 monitoring_level = 2; + + // next-field: 3 +} + +message MonitorResponse { + // Properly formatted string data that can be directly returned back to user. + string data = 1; + + // next-field: 2 +} diff --git a/tensorflow/contrib/tpu/proto/BUILD b/tensorflow/contrib/tpu/proto/BUILD index 26016f47dfb36990fd73267c70619878ac3450e5..598b73b438cb239187a911b2d1425b434c889d8d 100644 --- a/tensorflow/contrib/tpu/proto/BUILD +++ b/tensorflow/contrib/tpu/proto/BUILD @@ -15,6 +15,16 @@ tf_proto_library( "tpu_embedding_config.proto", ], cc_api_version = 2, + protodeps = [":optimization_parameters_proto"], + visibility = ["//visibility:public"], +) + +tf_proto_library( + name = "optimization_parameters_proto", + srcs = [ + "optimization_parameters.proto", + ], + cc_api_version = 2, visibility = ["//visibility:public"], ) diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto new file mode 100644 index 0000000000000000000000000000000000000000..2cc17d6d928370afbb0e3b1e89252f7a687c27d3 --- /dev/null +++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto @@ -0,0 +1,164 @@ +syntax = "proto3"; + +package tensorflow.tpu; + +import "google/protobuf/wrappers.proto"; + +message ClippingLimits { + google.protobuf.FloatValue lower = 1; // -inf if not set + google.protobuf.FloatValue upper = 2; // +inf if not set +} + +// Get the learning rate from a source that can change +// dynamically. +message DynamicLearningRate { +} + +// Source of learning rate to use. +message LearningRate { + oneof learning_rate { + float constant = 1; + DynamicLearningRate dynamic = 2; + } +} + +message AdagradParameters { + float initial_accumulator = 1; +} + +message StochasticGradientDescentParameters { +} + +message FtrlParameters { + float l1 = 1; + float l2 = 2; + float lr_power = 3; + float initial_accum = 4; + float initial_linear = 5; +} + +// The Adam optimizer does not implement hyper-parameter update; use the dynamic +// learning rate feature instead, setting the learning rate to: +// user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t) +// Here, t is the current timestep. +// https://github.com/tensorflow/tensorflow/blob/ab51450c817674c8ff08a7ae4f8ac50cdc4bed8b/tensorflow/python/training/adam.py#L54 +message AdamParameters { + float beta1 = 3; + float beta2 = 4; + float epsilon = 5; + float initial_m = 6; + float initial_v = 7; +} + +message MomentumParameters { + float momentum = 1; + bool use_nesterov = 2; + float initial_accum = 3; +} + +message RmsPropParameters { + float rho = 1; + float momentum = 2; + float epsilon = 3; + float initial_ms = 4; + float initial_mom = 5; +} + +message CenteredRmsPropParameters { + float rho = 1; + float momentum = 2; + float epsilon = 3; + float initial_ms = 4; + float initial_mom = 5; + float initial_mg = 6; +} + +message MdlAdagradLightParameters { + float l2 = 1; + float lr_power = 2; + float min_servable_mdl_benefit = 3; + float mdl_mix_in_margin = 4; + float mdl_benefit_rampup_coeff = 5; + float mdl_min_weight = 6; + float benefit_revisit_scale = 7; + float max_event_benefit = 8; + float max_total_benefit = 9; + float mdl_hard_limit = 10; + bool hard_limit_min_benefit = 11; + bool mdl_regularize = 12; + float initial_accumulator = 13; + float initial_weight = 14; + float initial_benefit = 15; +} + +message AdadeltaParameters { + float rho = 1; + float epsilon = 2; + float initial_accumulator = 3; + float initial_update = 4; +} + +message ProximalAdagradParameters { + float l1 = 1; + float l2 = 2; + float initial_accumulator = 3; +} + +message OptimizationParameters { + // Learning rate used for updating the embedding layer parameters. + LearningRate learning_rate = 13; + reserved 1; // Old learning rate tag. + + // Limits to which to clip the weight values after the backward pass; not + // present means no limits are applied. + ClippingLimits clipping_limits = 2; + + // Limits to which to clip the backward pass gradient before using it for + // updates; not present means no limits are applied. + ClippingLimits gradient_clipping_limits = 7; + + // Whether to use gradient accumulation (do two passes over the input + // gradients: one to accumulate them into a temporary array and another to + // apply them using the actual optimization algorithm). + bool use_gradient_accumulation = 15; + + // Optimization algorithm parameters; which field is selected determines which + // algorithm to use. + oneof parameters { + AdagradParameters adagrad = 3; + StochasticGradientDescentParameters stochastic_gradient_descent = 4; + FtrlParameters ftrl = 5; + AdamParameters adam = 6; + MomentumParameters momentum = 8; + RmsPropParameters rms_prop = 9; + CenteredRmsPropParameters centered_rms_prop = 10; + MdlAdagradLightParameters mdl_adagrad_light = 11; + AdadeltaParameters adadelta = 12; + ProximalAdagradParameters proximal_adagrad = 14; + } +} + +// Specification of an optimization algorithm's state variables (both the main +// value vector and any extra accumulators, etc.). +message StateVariableSpecification { + // Parameter name for the state variable. + string name = 1; + + // A normal state variable that should be saved and restored in checkpoints + // and used as an input or output to non-debug TensorFlow ops. + message UserDefined { + } + + // A state variable that should be filled with a constant and normally hidden + // from users (used for intermediate gradients being accumulated, for + // example). + message FillWithConstant { + double initial_value = 1; + } + + // Usage type of this state variable. + oneof usage { + UserDefined user_defined = 2; + FillWithConstant fill_with_constant = 3; + } +} diff --git a/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto b/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto index b0ec968d3a401f1b80ed1bf6fd7a83a69c068fe2..3476cc89534efb7fe05640935d1387d02737f240 100644 --- a/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto +++ b/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package tensorflow.tpu; +import "tensorflow/contrib/tpu/proto/optimization_parameters.proto"; + // The TPUEmbeddingConfiguration contains specification of TPU Embedding lookups // and gradient updates separate from the TF Graph. message TPUEmbeddingConfiguration { @@ -30,15 +32,6 @@ message TPUEmbeddingConfiguration { // The number of training examples per TensorNode. int32 batch_size = 4; - message GradientDescentOptimizer { - float learning_rate = 1; - } - - message AdagradOptimizer { - float learning_rate = 1; - float initial_accumulator = 2; - } - // Each Embedding message TPUEmbeddingTable { // Name of the embedding table. This will be used to name Variables in the @@ -66,10 +59,7 @@ message TPUEmbeddingConfiguration { // separately to the convolutional or recurrent network. int32 num_features = 5; - oneof optimizer { - GradientDescentOptimizer gradient_descent = 6; - AdagradOptimizer adagrad = 7; - } + OptimizationParameters optimization_parameters = 6; } repeated TPUEmbeddingTable table_config = 5; diff --git a/tensorflow/contrib/tpu/python/tpu/device_assignment.py b/tensorflow/contrib/tpu/python/tpu/device_assignment.py index 726b2d248e3086e1882004827076ed3e563d960d..471b1fa46c679dcab70e9bc12d61ada84cba79bb 100644 --- a/tensorflow/contrib/tpu/python/tpu/device_assignment.py +++ b/tensorflow/contrib/tpu/python/tpu/device_assignment.py @@ -175,6 +175,8 @@ class DeviceAssignment(object): """Returns the physical topology coordinates of a logical core.""" if logical_core is None: logical_core = np.array([0, 0, 0], np.int32) + else: + logical_core = np.asarray(logical_core) if any(logical_core < 0) or any(logical_core >= self.computation_shape): raise ValueError("Invalid core {}; computation shape is {}".format( diff --git a/tensorflow/contrib/tpu/python/tpu/error_handling.py b/tensorflow/contrib/tpu/python/tpu/error_handling.py new file mode 100644 index 0000000000000000000000000000000000000000..52e1ea42370d653d1de7c12eee4b456ec7ce921c --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/error_handling.py @@ -0,0 +1,132 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =================================================================== +"""ErrorRendezvous handler for collecting errors from multiple threads.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import sys +import threading +import time + +import six + +from tensorflow.python.framework import errors +from tensorflow.python.platform import tf_logging as logging + +_UNINTERESTING_ERRORS = (errors.CancelledError,) + + +class ErrorRendezvous(object): + """Resolve errors from multiple threads during TPU execution. + + TPU errors can occur on the infeed or outfeed threads as well as the main + training thread. + + Depending on which thread "wins" and receives the session error first, we may + end up showing users a confusing and non-actionable error message (session + cancelled) instead of a root cause (e.g. a bad filename). + + The rendezvous object provides a location to capture these errors until all + threads terminate. At that point we can choose the most informative error + to report. + """ + + def __init__(self, num_sources): + # string -> (message, traceback) + self._errors = {} + self._num_sources = num_sources + self._session_cancel_timer = None + + def record_error(self, source, exc_info, session=None): + """Report an exception from the given source. + + If a session is passed, a timer will be registered to close it after a few + seconds. This is necessary to ensure the main training loop does not hang + if an infeed/oufeed error occurs. We sleep a few seconds to allow a more + interesting error from another thread to propagate. + + Args: + source: string, source of the error + exc_info: Output from `sys.exc_info` (type, value, traceback) + session: Session to close after delay. + """ + _, value, _ = exc_info + self._errors[source] = exc_info + logging.info('Error recorded from %s: %s', source, value) + + if session is not None and self._session_cancel_timer is None: + + def _cancel_session(): + time.sleep(5) + try: + session.close() + except: # pylint: disable=bare-except + pass + + self._session_cancel_timer = threading.Thread(target=_cancel_session,) + self._session_cancel_timer.daemon = True + self._session_cancel_timer.start() + + def record_done(self, source): + """Mark execution source `source` as done. + + If an error was originally reported from `source` it is left intact. + + Args: + source: `str`, source being recorded + """ + logging.info('%s marked as finished', source) + if source not in self._errors: + self._errors[source] = None + + @contextlib.contextmanager + def catch_errors(self, source, session=None): + """Context manager to report any errors within a block.""" + try: + yield + except Exception: # pylint: disable=broad-except + self.record_error(source, sys.exc_info(), session) + + def raise_errors(self, timeout_sec=0): + """Wait for up to `timeout` seconds for all error sources to finish. + + Preferentially raise "interesting" errors (errors not in the + _UNINTERESTING_ERRORS) set. + + Args: + timeout_sec: Seconds to wait for other error sources. + """ + for _ in range(timeout_sec): + if len(self._errors) == self._num_sources: + break + time.sleep(1) + + kept_errors = [(k, v) for (k, v) in self._errors.items() if v is not None] + + # First check for any interesting errors, then fall back on the session + # cancelled errors etc. + for k, (typ, value, traceback) in kept_errors: + if isinstance(value, _UNINTERESTING_ERRORS): + continue + else: + logging.warn('Reraising captured error') + six.reraise(typ, value, traceback) + + for k, (typ, value, traceback) in kept_errors: + logging.warn('Reraising captured error') + six.reraise(typ, value, traceback) diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 754154438235f4c5e9e8db996acc8d843ab18431..ff893a722f4e77c743edd3b8db77aa90be1e498d 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -45,6 +45,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import collections import contextlib import re @@ -54,16 +55,19 @@ import time import numpy as np from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver -from tensorflow.contrib.distribute.python import tpu_strategy from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu +from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.contrib.tpu.python.tpu import tpu_optimizer from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session as tf_session +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.keras import backend as K from tensorflow.python.keras import models @@ -71,11 +75,17 @@ from tensorflow.python.keras import optimizers as keras_optimizers from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.layers import embeddings from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging -TPUDistributionStrategy = tpu_strategy.TPUStrategy # pylint: disable=invalid-name + +# Work-around dependency cycle between DistributionStrategy and TPU lib. +def TPUDistributionStrategy(*args, **kw): # pylint: disable=invalid-name + from tensorflow.contrib.distribute.python import tpu_strategy # pylint: disable=g-import-not-at-top + return tpu_strategy.TPUStrategy(*args, **kw) class TPUEmbedding(embeddings.Embedding): @@ -99,6 +109,45 @@ class TPUEmbedding(embeddings.Embedding): return math_ops.tensordot(inputs, self.embeddings, 1) +class KerasCrossShardOptimizer(keras_optimizers.Optimizer): + """An optimizer that averages gradients across TPU shards.""" + + def __init__(self, opt, name='KerasCrossShardOptimizer'): + """Construct a new cross-shard optimizer. + + Args: + opt: An existing `Optimizer` to encapsulate. + name: Optional name prefix for the operations created when applying + gradients. Defaults to "KerasCrossShardOptimizer". + + Raises: + ValueError: If reduction is not a valid cross-shard reduction. + """ + super(KerasCrossShardOptimizer, self).__init__() + self._name = name + self._opt = opt + + def get_updates(self, loss, params): + logging.info('Get updates: %s', loss) + self._opt.get_gradients = self.get_gradients + return self._opt.get_updates(loss, params) + + def get_gradients(self, loss, params): + num_shards = tpu_function.get_tpu_context().number_of_shards + grads = super(KerasCrossShardOptimizer, self).get_gradients(loss, params) + return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads] + + def set_weights(self, weights): + self._opt.set_weights() + + def get_weights(self): + return self._opt.get_weights() + + @property + def lr(self): + return self._opt.lr + + class TPUModelOp( collections.namedtuple('TPUModelOp', [ 'compile_op', 'execute_op', 'infeed_tensors', 'infeed_op', 'outfeed_op' @@ -113,8 +162,13 @@ def _valid_name(tensor_name): def _replicated_optimizer(opt): """Wrap the optimizer `opt` with CrossShardOptimizer if applicable.""" - return keras_optimizers.TFOptimizer( - optimizer=tpu_optimizer.CrossShardOptimizer(opt.optimizer)) + if tpu_function.get_tpu_context().number_of_shards == 1: + return opt + + if isinstance(opt, keras_optimizers.TFOptimizer): + return tpu_optimizer.CrossShardOptimizer(opt.optimizer) + else: + return KerasCrossShardOptimizer(opt) class TPURewriteContext(object): @@ -154,7 +208,6 @@ class TPURewriteContext(object): caller_obj = caller_frame.f_locals.get('self') if (caller_obj is not None and isinstance(caller_obj, base_layer.Layer) and name is not None): - logging.info('Intercepted name_scope: %s', caller_obj) return variable_scope.variable_scope( name, default_name, values, reuse=variable_scope.AUTO_REUSE) @@ -163,8 +216,51 @@ class TPURewriteContext(object): self._default_placeholder = array_ops.placeholder self._default_name_scope = ops.name_scope self._default_make_variable = base_layer.make_variable + self._default_random_normal = random_ops.random_normal + self._default_qr = gen_linalg_ops.qr array_ops.placeholder = _placeholder + + # Replace random_ops.random_normal with a dummy function because + # `random_normal` isn't yet implemented on the TPU. Because these + # initialized values are overwritten by the CPU values, this is okay. + def random_normal(shape, + mean=0.0, + stddev=1.0, + dtype=dtypes.float32, + seed=None, + name=None): + del mean + del stddev + del seed + return array_ops.zeros(shape, dtype=dtype, name=name) + + random_ops.random_normal = random_normal + + # Replace gen_linalg_ops.qr because QR decomposition is not yet implemented. + # TODO(saeta): Remove qr override once we confirm the qr implementation is + # ok. + # pylint: disable=redefined-builtin + def qr(input, full_matrices=False, name=None): + """Dummy implementation of qr decomposition.""" + del full_matrices # TODO(saeta): Properly handle the full matrix case. + input_shape = input.shape + if len(input_shape) < 2: + raise ValueError('Invalid shape passed to qr: %s' % input_shape) + p = min(input_shape[-1], input_shape[-2]) + if len(input_shape) == 2: + q = array_ops.zeros((p, p), name=name) + r = array_ops.zeros(input_shape, name=name) + return (r, q) + elif len(input_shape) == 3: + n = input_shape[0] + q = array_ops.zeros((n, p, p), name=name) + r = array_ops.zeros(input_shape, name=name) + return (r, q) + else: + raise ValueError('Invalid shape passed to qr: %s' % input_shape) + gen_linalg_ops.qr = qr + ops.name_scope = _name_scope base_layer.make_variable = variable_scope.get_variable logging.info('Overriding default placeholder.') @@ -174,6 +270,334 @@ class TPURewriteContext(object): array_ops.placeholder = self._default_placeholder ops.name_scope = self._default_name_scope base_layer.make_variable = self._default_make_variable + random_ops.random_normal = self._default_random_normal + gen_linalg_ops.qr = self._default_qr + + +class SizedInfeed(collections.namedtuple('SizedInfeed', + ['sharded_infeed_tensors', + 'infeed_ops'])): + """Represents an instantiation of the infeed ops for a concrete input shape. + + sharded_infeed_tensors: A data structure of Tensors used to represent the + placeholder tensors that must be fed when using feed_dicts. + + infeed_ops: the set of ops that will be run to drive infeed for a single step. + """ + pass + + +class TPUInfeedInstance(object): + """TPUInfeedInstance represents the logic to manage feeding in a single step. + + See the comments on the `TPUInfeedManager` for a description for how infeed + is managed. + """ + + @abc.abstractmethod + def make_input_specs(self, input_tensors): + """Constructs the infeed_specs for the given Infeed instance. + + Args: + input_tensors: The inputs to the model. + + Returns: + A list of + """ + pass + + def make_feed_dict(self, tpu_model_op): + """Constructs a feed_dict for this instance, given the tpu_model_op. + + Args: + tpu_model_op: A `TPUModelOp` representing the TPU Model for this + instance's input spec. + + Returns: + A dictionary to use as the feed_dict of a `session.run` call. + """ + pass + + +class TPUInfeedManager(object): + """TPUInfeedManager manages the data infeeding of data to a TPU computation. + + Because there are multiple data sources (e.g. in-memory NumPy arrays, + `tf.data.Dataset`s), we abstract the different logic behind a single + interface: the `TPUInfeedManager`. + + (1) A `TPUFunction` is called with a set of inputs. Based on the inputs, + `TPUFunction` retrieves the corresponding `TPUInfeedManager` (or constructs a + new one if required). + + (2) The `TPUFunction` calls `make_infeed_instance` on the `TPUInfeedManager` + which returns a `TPUInfeedInstance`. + + (3) The `TPUFunction` checks in the shape cache for a pre-compiled instance of + the model based on the returned `input_specs` from `TPUInfeedInstance`. + + (4) [Optional.] If the model has not already been instantiated for the given + input spec, the `TPUFunction` compiles the model for the input spec (using the + `TPUInfeedManager`). + + (5) The `TPUInfeedInstance` constructs the session.run's feed_dict given the + compiled model instance corresponding to its shape. + """ + + @abc.abstractmethod + def make_infeed_instance(self, inputs): + """Given a single step's input, construct a `TPUInfeedInstance`. + + Args: + inputs: The inputs to a given step. + + Returns: + A subclass of `TPUInfeedInstance`. + """ + pass + + @abc.abstractmethod + def build_infeed_from_input_specs(self, input_specs, execution_mode): + """For a given input specification (size, type), construct the infeed ops. + + This is called only once for a given input specification and builds the + graph ops. It does not have a pointer to the actual infeed data. + + Args: + input_specs: TODO(saeta): Document me! + execution_mode: TODO(saeta): Document me! + + Returns: + A `SizedInfeed` instance. + """ + pass + + +class TPUNumpyInfeedManager(TPUInfeedManager): + """TPU Infeed manager for Numpy inputs.""" + + class NumpyInfeedInstance(TPUInfeedInstance): + """Infeed instance for Numpy inputs.""" + + def __init__(self, sharded_inputs): + self._sharded_inputs = sharded_inputs + + def make_input_specs(self, input_tensors): + # Compute an input specification (used to generate infeed enqueue and + # dequeue operations). We use the shape from our input array and the + # dtype from our model. A user may pass in a float64 for a float32 + # input: for model compatibility we still must generate a float32 infeed. + input_specs = [] + # We use the shape and dtype from the first shard to compute the input + # metadata (`input_specs`); all replicas have the same type and shape. + for tensor, ary in zip(input_tensors, self._sharded_inputs[0]): + input_specs.append( + tensor_spec.TensorSpec(ary.shape, tensor.dtype, + _valid_name(tensor.name))) + + return input_specs + + def make_feed_dict(self, tpu_model_op): + infeed_dict = {} + for infeed_tensors, inputs in zip(tpu_model_op.infeed_tensors, + self._sharded_inputs): + for tensor, value in zip(infeed_tensors, inputs): + infeed_dict[tensor] = value + return infeed_dict + + def __init__(self, distribution_strategy): + self._strategy = distribution_strategy + + def _split_tensors(self, inputs): + """Split input data across shards. + + Each input is sliced along the batch axis. + + Args: + inputs: List of Numpy arrays to run on the TPU. + + Returns: + List of lists containing the input to feed to each TPU shard. + """ + if self._strategy.num_towers == 1: + return [inputs] + + batch_size = inputs[0].shape[0] + assert batch_size % self._strategy.num_towers == 0, ( + 'batch_size must be divisible by strategy.num_towers (%s vs %s)' % + (batch_size, self._strategy.num_towers)) + shard_size = batch_size // self._strategy.num_towers + input_list = [] + for index in range(self._strategy.num_towers): + shard_inputs = [ + x[index * shard_size:(index + 1) * shard_size] for x in inputs + ] + input_list.append(shard_inputs) + return input_list + + def make_infeed_instance(self, inputs): + sharded_inputs = self._split_tensors(inputs) + return self.NumpyInfeedInstance(sharded_inputs) + + def build_infeed_from_input_specs(self, input_specs, execution_mode): + infeed_op = [] + shard_infeed_tensors = [] + + for shard_id in range(self._strategy.num_towers): + with ops.device('/device:CPU:0'): + infeed_tensors = [] + with ops.device('/device:TPU:%d' % shard_id): + for spec in input_specs: + # Construct placeholders for each of the inputs. + infeed_tensors.append( + array_ops.placeholder( + dtype=spec.dtype, + shape=spec.shape, + name='infeed-enqueue-%s-%d' % (spec.name, shard_id))) + shard_infeed_tensors.append(infeed_tensors) + + infeed_op.append( + tpu_ops.infeed_enqueue_tuple( + infeed_tensors, [spec.shape for spec in input_specs], + name='infeed-enqueue-%s-%d' % (execution_mode, shard_id), + device_ordinal=shard_id)) + return SizedInfeed(infeed_ops=infeed_op, + sharded_infeed_tensors=shard_infeed_tensors) + + +class TPUDatasetInfeedManager(TPUInfeedManager): + """Manages infeed for a `tf.data.Dataset` into a TPU computation. + """ + + class DatasetInfeedInstance(TPUInfeedInstance): + """An instance of the TPU infeed.""" + + def __init__(self, input_specs): + self._input_specs = input_specs + + def make_input_specs(self, input_tensors): + # TODO(saeta): Do error checking here! + return self._input_specs + + def make_feed_dict(self, tpu_model_op): + # TODO(saeta): Verify tpu_model_op is as expected! + return {} + + def __init__(self, dataset, distribution_strategy, tpu_session): + """Constructs a TPUDatasetInfeedManager. + + Must be called within a `KerasTPUModel.tpu_session` context! + + Args: + dataset: A `tf.data.Dataset` to infeed. + distribution_strategy: The `TPUDistributionStrategy` used to configure the + Keras TPU model. + tpu_session: The `tf.Session` object used for running the TPU model. + """ + self._verify_dataset_shape(dataset) + self._dataset = dataset + self._strategy = distribution_strategy + dummy_x_shape = dataset.output_shapes[0].as_list() + dummy_x_shape[0] *= distribution_strategy.num_towers + dummy_y_shape = dataset.output_shapes[1].as_list() + dummy_y_shape[0] *= distribution_strategy.num_towers + self._iterator = dataset.make_initializable_iterator() + tpu_session.run(self._iterator.initializer) + + self._get_next_ops = [] + ctrl_deps = [] + for i in range(distribution_strategy.num_towers): + with ops.control_dependencies(ctrl_deps): # Ensure deterministic + # TODO(saeta): Ensure correct placement! + get_next_op = self._iterator.get_next() + self._get_next_ops.append(get_next_op) + ctrl_deps.extend(get_next_op) + + # Use dummy numpy inputs for the rest of Keras' shape checking. We + # intercept them when building the model. + self._dummy_x = np.zeros(dummy_x_shape, + dtype=dataset.output_types[0].as_numpy_dtype) + self._dummy_y = np.zeros(dummy_y_shape, + dtype=dataset.output_types[1].as_numpy_dtype) + + input_specs = [] + if isinstance(self._iterator.output_shapes, tuple): + assert isinstance(self._iterator.output_types, tuple) + assert len(self._iterator.output_shapes) == len( + self._iterator.output_types) + for i in range(len(self._iterator.output_shapes)): + spec = tensor_spec.TensorSpec(self._iterator.output_shapes[i], + self._iterator.output_types[i]) + input_specs.append(spec) + elif isinstance(self._iterator.output_shapes, tensor_shape.TensorShape): + spec = tensor_spec.TensorSpec(self._iterator.output_shapes, + self._iterator.output_types) + input_specs.append(spec) + + self._infeed_instance = self.DatasetInfeedInstance(input_specs) + + def _verify_dataset_shape(self, dataset): + """Verifies a dataset is of an appropriate shape for TPUs.""" + if not isinstance(dataset, dataset_ops.Dataset): + raise ValueError('The function passed as the `x` parameter did not ' + 'return a `tf.data.Dataset`.') + if not isinstance(dataset.output_classes, tuple): + raise ValueError('The dataset must return a tuple of tf.Tensors, ' + 'instead it returns: %s' % dataset.output_classes) + if len(dataset.output_classes) != 2: + raise ValueError( + 'The dataset must return a 2-element tuple, got ' + '%s output classes instead.' % (dataset.output_classes,)) + for i, cls in enumerate(dataset.output_classes): + if cls != ops.Tensor: + raise ValueError('The dataset returned a non-Tensor type (%s) at ' + 'index %d.' % (cls, i)) + for i, shape in enumerate(dataset.output_shapes): + if not shape: + raise ValueError('The dataset returns a scalar tensor in ' + 'tuple index %d. Did you forget to batch? ' + '(Output shapes: %s).' % (i, + dataset.output_shapes)) + for j, dim in enumerate(shape): + if dim.value is None: + if j == 0: + hint = (' Hint: did you use `ds.batch(BATCH_SIZE, ' + 'drop_remainder=True)`?') + else: + hint = '' + raise ValueError( + 'The Keras-TPU integration for `tf.data` ' + 'currently requires static shapes. The provided ' + 'dataset only has a partially defined shape. ' + '(Dimension %d of output tensor %d is not statically known ' + 'for output shapes: %s.%s)' % (i, j, dataset.output_shapes, hint)) + + @property + def dummy_x(self): + return self._dummy_x + + @property + def dummy_y(self): + return self._dummy_y + + def make_infeed_instance(self, inputs): + # TODO(saeta): Verify inputs is as expected. + return self._infeed_instance + + def build_infeed_from_input_specs(self, input_specs, execution_mode): + shard_infeed_tensors = self._get_next_ops + assert len(shard_infeed_tensors) == self._strategy.num_towers + infeed_ops = [] + for shard_id in range(self._strategy.num_towers): + with ops.device('/device:CPU:0'): + infeed_ops.append( + tpu_ops.infeed_enqueue_tuple( + shard_infeed_tensors[shard_id], + [spec.shape for spec in input_specs], + name='infeed-enqueue-%s-%d' % (execution_mode, shard_id), + device_ordinal=shard_id)) + return SizedInfeed(infeed_ops=infeed_ops, + sharded_infeed_tensors=shard_infeed_tensors) class TPUFunction(object): @@ -195,7 +619,13 @@ class TPUFunction(object): self._compilation_cache = {} self._cloned_model = None - def _specialize_model(self, input_specs): + # Copy optimizer configuration. This is done prior to `_specialize_model` + # as the configuration may require evaluating variables in the CPU session. + self._optimizer_config = None + if not isinstance(self.model.optimizer, keras_optimizers.TFOptimizer): + self._optimizer_config = self.model.optimizer.get_config() + + def _specialize_model(self, input_specs, infeed_manager): """Specialize `self.model` (a Keras model) for the given input shapes.""" # Re-create our input and output layers inside our subgraph. They will be # attached to the true computation when we clone our model in `tpu_fn`. @@ -221,8 +651,8 @@ class TPUFunction(object): name='infeed-%s' % self.execution_mode) assert len(infeed_tensors) == len(infeed_layers), ( - 'Infeed inputs did not match model: %s vs %s', (infeed_layers, - infeed_tensors)) + 'Infeed inputs did not match model: %s vs %s' % (infeed_layers, + infeed_tensors)) tpu_targets = [] tpu_input_map = {} @@ -236,11 +666,23 @@ class TPUFunction(object): # Clone our CPU model, running within the TPU device context. with TPURewriteContext(tpu_input_map): - self._cloned_model = models.clone_model(self.model) + # TODO(power): Replicate variables. + with ops.device('/device:TPU:0'): + self._cloned_model = models.clone_model(self.model) + + # Create a copy of the optimizer for this graph. + if isinstance(self.model.optimizer, keras_optimizers.TFOptimizer): + cloned_optimizer = keras_optimizers.TFOptimizer( + self.model.optimizer.optimizer) + else: + logging.info('Cloning %s %s', self.model.optimizer.__class__.__name__, + self._optimizer_config) + cloned_optimizer = self.model.optimizer.__class__.from_config( + self._optimizer_config) if is_training or is_test: self._cloned_model.compile( - optimizer=_replicated_optimizer(self.model.optimizer), + optimizer=_replicated_optimizer(cloned_optimizer), loss=self.model.loss, loss_weights=self.model.loss_weights, metrics=self.model.metrics, @@ -299,37 +741,24 @@ class TPUFunction(object): # Generate CPU side operations to enqueue features/labels and dequeue # outputs from the model call. - infeed_op = [] + sized_infeed = infeed_manager.build_infeed_from_input_specs( + input_specs, self.execution_mode) + # Build output ops. outfeed_op = [] - shard_infeed_tensors = [] - for shard_id in range(self._strategy.num_towers): - with ops.device('/device:TPU:%d' % shard_id): - infeed_tensors = [] - for spec in input_specs: - infeed_tensors.append( - array_ops.placeholder( - dtype=spec.dtype, - shape=spec.shape, - name='infeed-enqueue-%s-%d' % (spec.name, shard_id))) - shard_infeed_tensors.append(infeed_tensors) - - infeed_op.append( - tpu_ops.infeed_enqueue_tuple( - infeed_tensors, [spec.shape for spec in input_specs], - name='infeed-enqueue-%s-%d' % (self.execution_mode, shard_id))) - + with ops.device('/device:CPU:0'): outfeed_op.extend( tpu_ops.outfeed_dequeue_tuple( dtypes=[spec.dtype for spec in self._outfeed_spec], shapes=[spec.shape for spec in self._outfeed_spec], - name='outfeed-dequeue-%s-%d' % (self.execution_mode, shard_id))) + name='outfeed-dequeue-%s-%d' % (self.execution_mode, shard_id), + device_ordinal=shard_id)) return TPUModelOp( compile_op, execute_op, - infeed_tensors=shard_infeed_tensors, - infeed_op=infeed_op, + infeed_tensors=sized_infeed.sharded_infeed_tensors, + infeed_op=sized_infeed.infeed_ops, outfeed_op=outfeed_op) def _test_model_compiles(self, tpu_model_ops): @@ -348,37 +777,17 @@ class TPUFunction(object): logging.info('Finished compiling. Time elapsed: %s secs', end_time - start_time) - def _split_tensors(self, inputs): - """Split input data across shards. - - Each input is sliced along the batch axis. - - Args: - inputs: List of Numpy arrays to run on the TPU. - - Returns: - List of lists containing the input to feed to each TPU shard. - """ - if self._strategy.num_towers == 1: - return [inputs] - - batch_size = inputs[0].shape[0] - assert batch_size % self._strategy.num_towers == 0, ( - 'batch_size must be divisible by strategy.num_towers (%s vs %s)' % - (batch_size, self._strategy.num_towers) - ) - shard_size = batch_size // self._strategy.num_towers - input_list = [] - for index in range(self._strategy.num_towers): - shard_inputs = [ - x[index * shard_size:(index + 1) * shard_size] for x in inputs - ] - input_list.append(shard_inputs) - return input_list - def __call__(self, inputs): assert isinstance(inputs, list) + infeed_manager = None + for x, mgr in self.model._numpy_to_infeed_manager_list: + if inputs[0] is x: + infeed_manager = mgr + break + if infeed_manager is None: + infeed_manager = TPUNumpyInfeedManager(self.model._strategy) + # Strip sample weight from inputs if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or self.execution_mode == model_fn_lib.ModeKeys.EVAL): @@ -387,21 +796,9 @@ class TPUFunction(object): else: input_tensors = self.model._feed_inputs - shard_inputs = self._split_tensors(inputs) + infeed_instance = infeed_manager.make_infeed_instance(inputs) del inputs # To avoid accident usage. - - # Compute an input specification (used to generate infeed enqueue and - # dequeue operations). We use the shape from our input array and the - # dtype from our model. A user may pass in a float64 for a float32 - # input: for model compatibility we still must generate a float32 infeed. - input_specs = [] - - # We use the shape and dtype from the first shard to compute the input - # metadata (`input_specs`); all replicas have the same type and shape. - for tensor, ary in zip(input_tensors, shard_inputs[0]): - input_specs.append( - tensor_spec.TensorSpec(ary.shape, tensor.dtype, - _valid_name(tensor.name))) + input_specs = infeed_instance.make_input_specs(input_tensors) # XLA requires every operation in the graph has a fixed shape. To # handle varying batch sizes we recompile a new sub-graph for each @@ -412,7 +809,8 @@ class TPUFunction(object): with self.model.tpu_session(): logging.info('New input shapes; (re-)compiling: mode=%s, %s', self.execution_mode, input_specs) - new_tpu_model_ops = self._specialize_model(input_specs) + new_tpu_model_ops = self._specialize_model(input_specs, + infeed_manager) self._compilation_cache[shape_key] = new_tpu_model_ops self._test_model_compiles(new_tpu_model_ops) @@ -420,11 +818,7 @@ class TPUFunction(object): self.model._initialize_weights(self._cloned_model) tpu_model_ops = self._compilation_cache[shape_key] - infeed_dict = {} - for infeed_tensors, inputs in zip(tpu_model_ops.infeed_tensors, - shard_inputs): - for tensor, value in zip(infeed_tensors, inputs): - infeed_dict[tensor] = value + infeed_dict = infeed_instance.make_feed_dict(tpu_model_ops) with self.model.tpu_session() as session: _, _, outfeed_outputs = session.run([ @@ -438,9 +832,8 @@ class TPUFunction(object): outputs_per_replica = len(self._outfeed_spec) for i in range(self._strategy.num_towers): - output_group = outfeed_outputs[ - i * outputs_per_replica:(i+1) * outputs_per_replica - ] + output_group = outfeed_outputs[i * outputs_per_replica:(i + 1) * + outputs_per_replica] for j in range(outputs_per_replica): outputs[j].append(output_group[j]) @@ -459,6 +852,11 @@ class KerasTPUModel(models.Model): name=cpu_model.name, ) + # Create a mapping from numpy arrays to infeed managers. + # Note: uses a list of tuples instead of a map because numpy arrays are + # not hashable. + self._numpy_to_infeed_manager_list = [] + self.predict_function = None self.test_function = None self.train_function = None @@ -470,14 +868,16 @@ class KerasTPUModel(models.Model): self._tpu_weights_initialized = False self._graph = ops.Graph() - cluster_resolver = tpu_cluster_resolver.TPUClusterResolver( + self._cluster_resolver = tpu_cluster_resolver.TPUClusterResolver( tpu_name_or_address) - cluster_spec = cluster_resolver.cluster_spec() + master = self._cluster_resolver.master() + cluster_spec = self._cluster_resolver.cluster_spec() self._session = tf_session.Session( graph=self._graph, - target=cluster_resolver.master(), + target=master, config=config_pb2.ConfigProto(isolate_session_state=True)) + # TODO(saeta): Confirm the lines below work in ClusterSpec propagation env. if cluster_spec: self._session.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) @@ -529,10 +929,91 @@ class KerasTPUModel(models.Model): sample_weight_mode, weighted_metrics, target_tensors, **kwargs) - # Keras optimizers are not compatible with TPU rewrite - if not isinstance(self.optimizer, keras_optimizers.TFOptimizer): + def fit(self, + x=None, + y=None, + batch_size=None, + epochs=1, + verbose=1, + callbacks=None, + validation_split=0., + validation_data=None, + shuffle=True, + class_weight=None, + sample_weight=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + **kwargs): + assert not self._numpy_to_infeed_manager_list # Ensure empty. + + infeed_managers = [] # Managers to clean up at the end of the fit call. + if isinstance(x, dataset_ops.Dataset): + # TODO(b/111413240): Support taking a tf.data.Dataset directly. + raise ValueError( + 'Taking a Dataset directly is not yet supported. Please ' + 'wrap your dataset construction code in a function and ' + 'pass that to fit instead. For examples, see: ' + 'https://github.com/tensorflow/tpu/tree/master/models/experimental' + '/keras') + if callable(x): + with self.tpu_session() as sess: + dataset = x() + if steps_per_epoch is None: + raise ValueError('When using tf.data as input to a model, you ' + 'should specify the steps_per_epoch argument.') + if y is not None: + raise ValueError('When using tf.data as input to a model, y must be ' + 'None') + infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess) + # Use dummy numpy inputs for the rest of Keras' shape checking. We + # intercept them when building the model. + x = infeed_manager.dummy_x + y = infeed_manager.dummy_y + infeed_managers.append((x, infeed_manager)) + + if isinstance(validation_data, dataset_ops.Dataset): + # TODO(b/111413240): Support taking a tf.data.Dataset directly. raise ValueError( - 'Optimizer must be a TFOptimizer, got: %s' % self.optimizer) + 'Taking a Dataset directly is not yet supported. Please ' + 'wrap your dataset construction code in a function and ' + 'pass that to fit instead. For examples, see: ' + 'https://github.com/tensorflow/tpu/tree/master/models/experimental' + '/keras') + if callable(validation_data): + with self.tpu_session() as sess: + dataset = validation_data() + if validation_steps is None: + raise ValueError('When using tf.data as validation for a model, you ' + 'should specify the validation_steps argument.') + infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess) + # Use dummy numpy inputs for the rest of Keras' shape checking. We + # intercept them when building the model. + val_x = infeed_manager.dummy_x + val_y = infeed_manager.dummy_y + infeed_managers.append((val_x, infeed_manager)) + validation_data = (val_x, val_y) + + self._numpy_to_infeed_manager_list = infeed_managers + try: + return super(KerasTPUModel, self).fit( + x, + y, + batch_size, + epochs, + verbose, + callbacks, + validation_split, + validation_data, + shuffle, + class_weight, + sample_weight, + initial_epoch, + steps_per_epoch, + validation_steps, + **kwargs) + finally: + self._numpy_to_infeed_manager_list = [] def _make_train_function(self): if not self.train_function: @@ -615,10 +1096,10 @@ class KerasTPUModel(models.Model): K.set_session(default_session) def shutdown(self): - logging.info('Shutting down TPU session.') - with self.tpu_session() as session: - session.run(tpu.shutdown_system()) - + # TODO(b/111364423): Actually shut down the system. + logging.info('Skipping shutting down TPU system.') + # with self.tpu_session() as session: + # session.run(tpu.shutdown_system()) self._session.close() @@ -687,6 +1168,10 @@ def tpu_model(model, tpu_name_or_address=None, strategy=None): Returns: A new `KerasTPUModel` instance. """ + # Force initialization of the CPU model. + model.get_weights() + model.reset_states() + _validate_shapes(model) # TODO(xiejw): Validate TPU model. TPUModel only? # TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset? diff --git a/tensorflow/contrib/tpu/python/tpu/topology.py b/tensorflow/contrib/tpu/python/tpu/topology.py index cda9a63f204ed686b527c95dd5b4fd7786ac60cf..1fb26e701a392d5ef3bc40d5772d4541fa38f773 100644 --- a/tensorflow/contrib/tpu/python/tpu/topology.py +++ b/tensorflow/contrib/tpu/python/tpu/topology.py @@ -55,8 +55,9 @@ class Topology(object): rank 3 numpy int32 array that describes a valid coordinate mapping. """ + self._serialized = serialized + if serialized: - self._serialized = serialized self._parse_topology(serialized) else: self._mesh_shape = np.asarray(mesh_shape, dtype=np.int32) @@ -131,7 +132,7 @@ class Topology(object): proto.mesh_shape[:] = list(self._mesh_shape) proto.num_tasks = self._device_coordinates.shape[0] proto.num_tpu_devices_per_task = self._device_coordinates.shape[1] - proto.device_coordinates = list(self._device_coordinates.flatten()) + proto.device_coordinates.extend(list(self._device_coordinates.flatten())) self._serialized = proto.SerializeToString() return self._serialized diff --git a/tensorflow/contrib/tpu/python/tpu/topology_test.py b/tensorflow/contrib/tpu/python/tpu/topology_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e67fdb263aa48a37f65c3623365ebcf8f98bebd4 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/topology_test.py @@ -0,0 +1,46 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Tests for topology.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tpu.python.tpu import topology + +from tensorflow.python.platform import test + + +class TopologyTest(test.TestCase): + + def testSerialization(self): + """Test if the class is able to generate serialzied string.""" + original_topology = topology.Topology( + mesh_shape=[1, 1, 2], + device_coordinates=[[[0, 0, 0], [0, 0, 1]]], + ) + serialized_str = original_topology.serialized() + new_topology = topology.Topology(serialized=serialized_str) + + # Make sure the topology recovered from serialized str is same as the + # original topology. + self.assertAllEqual( + original_topology.mesh_shape, new_topology.mesh_shape) + self.assertAllEqual( + original_topology.device_coordinates, new_topology.device_coordinates) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 6a64893d9abcd64360554ab00502cdf360b820b6..7fa06d6d560a4b6ffa6d9a3fd0fa208b4c60ee7f 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -151,6 +151,41 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): self._name = name self._unsupported_ops = [] self._pivot = pivot + self._replicated_vars = {} + + def get_replicated_var_handle(self, var): + """Returns a variable handle for replicated TPU variable 'var'. + + This is an method used by an experimental replicated variable + implementation and is not intended as a public API. + + Args: + var: The replicated TPU variable. + + Returns: + The handle of the TPU replicated input node. + """ + handle = self._replicated_vars.get(var) + if handle is not None: + return handle + + # Builds a TPUReplicatedInput node for the variable, if one does not already + # exist. The TPUReplicatedInput node must belong to the enclosing + # control-flow scope of the TPUReplicateContext. + # TODO(phawkins): consider changing the contract of the TPU encapsulation + # so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope + # instead. + + # pylint: disable=protected-access + graph = ops.get_default_graph() + saved_context = graph._get_control_flow_context() + graph._set_control_flow_context(self.outer_context) + handle = tpu_ops.tpu_replicated_input( + [v.handle for v in var._vars], name=var.name + "/handle") + graph._set_control_flow_context(saved_context) + # pylint: enable=protected-access + self._replicated_vars[var] = handle + return handle def report_unsupported_operations(self): if self._unsupported_ops: @@ -279,7 +314,9 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): # Capture the device function stack at the time of first entry # since that is the stack that will be used outside_compilation. graph = ops.get_default_graph() - self._outer_device_function_stack = list(graph._device_function_stack) # pylint: disable=protected-access + # pylint: disable=protected-access + self._outer_device_function_stack = graph._device_function_stack.copy() + # pylint: enable=protected-access super(TPUReplicateContext, self).Enter() def HostComputeCore(self): @@ -598,23 +635,14 @@ def split_compile_and_replicate(computation, with tpu_function.tpu_shard_context( num_replicas), ops.control_dependencies([metadata]): - # For backward compatibility reasons, we tag replicated inputs with the - # _tpu_replicated_input attribute. This does nothing and exists only for - # backward compatibility. - # TODO(phawkins): delete the attr_scope after 6/28/2018. - # pylint: disable=protected-access - with graph._attr_scope({ - "_tpu_replicated_input": attr_value_pb2.AttrValue(b=True) - }): - # Add identity ops so even unused inputs are "consumed" by the - # computation. This is to avoid orphaned TPUReplicatedInput nodes. - # TODO(phawkins): consider instead pruning unused TPUReplicatedInput - # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs. - computation_inputs = [ - array_ops.identity(x, name="replicated_input_{}".format(i)) - for i, x in enumerate(computation_inputs) - ] - # pylint: enable=protected-access + # Add identity ops so even unused inputs are "consumed" by the + # computation. This is to avoid orphaned TPUReplicatedInput nodes. + # TODO(phawkins): consider instead pruning unused TPUReplicatedInput + # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs. + computation_inputs = [ + array_ops.identity(x, name="replicated_input_{}".format(i)) + for i, x in enumerate(computation_inputs) + ] # If there is an infeed queue, adds the dequeued values to the # computation's inputs. @@ -942,8 +970,15 @@ def rewrite(computation, Args: computation: A Python function that builds a computation to apply to the input. If the function takes n inputs, 'inputs' should be - a list of n tensors. If the function returns m outputs, rewrite - will return a list of m tensors. + a list of n tensors. + + `computation` may return a list of operations and tensors. Tensors must + come before operations in the returned list. The return value of + `rewrite` is a list of tensors corresponding to the tensors from the + from `computation`. + + All `Operation`s returned from `computation` will be executed when + evaluating any of the returned output tensors. inputs: A list of input tensors or `None` (equivalent to an empty list). infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple of arguments as inputs to `computation`. @@ -980,6 +1015,19 @@ _BLACKLISTED_INFERENCE_OPS = set([ ]) +def under_tpu_inference_context(): + """Check if it is currently under `tpu.rewrite_for_inference()`.""" + graph = ops.get_default_graph() + + context = graph._get_control_flow_context() # pylint: disable=protected-access + while context: + if isinstance(context, _TPUInferenceContext): + return True + context = context.outer_context + + return False + + class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext): """A `ControlFlowContext` for nodes inside a TPU inference computation. diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py index 6d7331e3c79ade9c12c15de79f550cf3973c4e6c..8d05e081a7c6e0327fedae6dc2c3ba45df40d029 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py @@ -23,8 +23,6 @@ import collections import json import os -import numpy as np - from tensorflow.contrib.tpu.python.tpu import util as util_lib from tensorflow.core.protobuf import config_pb2 from tensorflow.python.estimator import run_config as run_config_lib @@ -43,17 +41,18 @@ class InputPipelineConfig(object): PER_SHARD_V1 = 1 PER_HOST_V1 = 2 PER_HOST_V2 = 3 + BROADCAST = 4 -# TODO(b/72511246) Provide a simplified api to configure model parallelism. class TPUConfig( collections.namedtuple('TPUConfig', [ 'iterations_per_loop', 'num_shards', - 'computation_shape', + 'num_cores_per_replica', 'per_host_input_for_training', 'tpu_job_name', 'initial_infeed_sleep_secs', + 'input_partition_dims', ])): r"""TPU related configuration required by `TPUEstimator`. @@ -67,22 +66,22 @@ class TPUConfig( case, this number equals the total number of TPU cores. For model-parallelism, the total number of TPU cores equals product(computation_shape) * num_shards. - computation_shape: Defaults to `None`, which disables model parallelism. A - list of size 3 which describes the shape of a model replica's block of - cores. This is required by model-parallelism which enables partitioning - the model to multiple cores. For example, [2, 2, 1] means the model is - partitioned across 4 cores which span two cores in both x and y - coordinates. Please refer to @{tf.contrib.tpu.Topology} for the - geometry of a TPU mesh. + num_cores_per_replica: Defaults to `None`, which disables model parallelism. + An integer which describes the number of TPU cores per model replica. This + is required by model-parallelism which enables partitioning + the model to multiple cores. Currently num_cores_per_replica must be + 1, 2, 4, or 8. per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`, - `input_fn` is invoked per-host rather than per-core. With per-host input - pipeline configuration, `input_fn` is invoked once on each host. With the - per-core input pipeline configuration, it is invoked once for each core. + `input_fn` is invoked once on each host. With the per-core input pipeline + configuration, it is invoked once for each core. With a global batch size `train_batch_size` in `TPUEstimator` constructor, the batch size for each shard is `train_batch_size` // #hosts in the `True` or `PER_HOST_V1` mode. In `PER_HOST_V2` mode, it is - `train_batch_size` // #cores. With the per-core input pipeline - configuration, the shard batch size is also `train_batch_size` // #cores. + `train_batch_size` // #cores. In `BROADCAST` mode, `input_fn` is only + invoked once on host 0 and the tensors are broadcasted to all other + replicas. The batch size equals to train_batch_size`. With the per-core + input pipeline configuration, the shard batch size is also + `train_batch_size` // #cores. Note: per_host_input_for_training==PER_SHARD_V1 only supports mode.TRAIN. tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred within TPUEstimator, however when using ClusterSpec propagation in more @@ -91,6 +90,17 @@ class TPUConfig( initial_infeed_sleep_secs: The number of seconds the infeed thread should wait before enqueueing the first batch. This helps avoid timeouts for models that require a long compilation time. + input_partition_dims: A nested list to describe the partition dims + for all the tensors from input_fn(). The structure of + input_partition_dims must match the structure of `features` and + `labels` from input_fn(). The total number of partitions must match + `num_cores_per_replica`. For example, if input_fn() returns two tensors: + images with shape [N, H, W, C] and labels [N]. + input_partition_dims = [[1, 2, 2, 1], None] will split the images to 4 + pieces and feed into 4 TPU cores. labels tensor are directly broadcasted + to all the TPU cores since the partition dims is `None`. + Current limitations: This feature is only supported with the PER_HOST_V2 + input mode. Raises: ValueError: If `computation_shape` or `computation_shape` are invalid. @@ -99,10 +109,11 @@ class TPUConfig( def __new__(cls, iterations_per_loop=2, num_shards=None, - computation_shape=None, + num_cores_per_replica=None, per_host_input_for_training=True, tpu_job_name=None, - initial_infeed_sleep_secs=None): + initial_infeed_sleep_secs=None, + input_partition_dims=None): # Check iterations_per_loop. util_lib.check_positive_integer(iterations_per_loop, @@ -112,19 +123,26 @@ class TPUConfig( if num_shards is not None: util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards') - # Check computation_shape - if computation_shape is not None and len(computation_shape) != 3: - raise ValueError( - 'computation_shape must be a list with length 3 or None; got {}'. - format(str(computation_shape))) + if input_partition_dims is not None: + if len(input_partition_dims) != 1 and len(input_partition_dims) != 2: + raise ValueError( + 'input_partition_dims must be a list/tuple with one or two' + ' elements.') + + if per_host_input_for_training is not InputPipelineConfig.PER_HOST_V2: + raise ValueError( + 'input_partition_dims is only supported in PER_HOST_V2 mode.') + + if num_cores_per_replica is None: + raise ValueError( + 'input_partition_dims requires setting num_cores_per_replica.') - if computation_shape is not None: - computation_shape_array = np.asarray(computation_shape, dtype=np.int32) - # This prevents any computation being replicated across multiple hosts, so - # that each host feeds the same number of computations. - if any(computation_shape_array < 1) or any(computation_shape_array > 2): - raise ValueError('computation_shape elements can only be 1 or 2; got ' - 'computation_shape={}'.format(computation_shape)) + # Parse computation_shape + if num_cores_per_replica is not None: + if num_cores_per_replica not in [1, 2, 4, 8]: + raise ValueError( + 'num_cores_per_replica must be 1, 2, 4, or 8; got {}'.format( + str(num_cores_per_replica))) # per_host_input_for_training may be True, False, or integer in [1..3]. # Map legacy values (True, False) to numeric values. @@ -144,10 +162,11 @@ class TPUConfig( cls, iterations_per_loop=iterations_per_loop, num_shards=num_shards, - computation_shape=computation_shape, + num_cores_per_replica=num_cores_per_replica, per_host_input_for_training=per_host_input_for_training, tpu_job_name=tpu_job_name, - initial_infeed_sleep_secs=initial_infeed_sleep_secs) + initial_infeed_sleep_secs=initial_infeed_sleep_secs, + input_partition_dims=input_partition_dims) class RunConfig(run_config_lib.RunConfig): @@ -214,6 +233,12 @@ class RunConfig(run_config_lib.RunConfig): self._session_config.cluster_def.CopyFrom( self._cluster_spec.as_cluster_def()) + def _maybe_overwrite_session_config_for_distributed_training(self): + # Overrides the parent class session_config overwrite for between-graph. TPU + # runs with in-graph, which should not have device filter. Doing nothing + # ("pass") basically disables it. + pass + @property def evaluation_master(self): return self._evaluation_master diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py index 37ef3dbe1e66efe18b13ab9153ee346c08b9774a..2326fe97a807e6708a9cdc24fea889b998025a45 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import json from tensorflow.contrib.tpu.python.tpu import tpu_config as tpu_config_lib +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.platform import test @@ -33,6 +34,46 @@ def _set_tf_config_env_variable(tf_config): class TPURunConfigTest(test.TestCase): + def test_no_session_config_set_in_local_case(self): + run_config = tpu_config_lib.RunConfig() + self.assertIsNone(run_config.session_config) + + def test_no_session_config_overwrite_in_local_case(self): + session_config = config_pb2.ConfigProto(allow_soft_placement=True) + run_config = tpu_config_lib.RunConfig(session_config=session_config) + self.assertEqual(session_config, run_config.session_config) + + def test_no_session_config_set_with_cluster_spec(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host3:3'], + run_config_lib.TaskType.WORKER: ['host3:4'] + }, + 'task': { + 'type': run_config_lib.TaskType.CHIEF, + 'index': 0 + } + } + with _set_tf_config_env_variable(tf_config): + run_config = tpu_config_lib.RunConfig() + self.assertIsNone(run_config.session_config) + + def test_no_session_config_overwrite_with_cluster_spec(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host3:3'], + run_config_lib.TaskType.WORKER: ['host3:4'] + }, + 'task': { + 'type': run_config_lib.TaskType.CHIEF, + 'index': 0 + } + } + with _set_tf_config_env_variable(tf_config): + session_config = config_pb2.ConfigProto(allow_soft_placement=True) + run_config = tpu_config_lib.RunConfig(session_config=session_config) + self.assertEqual(session_config, run_config.session_config) + def test_fail_with_invalid_num_shards(self): with self.assertRaisesRegexp(ValueError, 'must be positive'): tpu_config_lib.RunConfig( @@ -43,15 +84,11 @@ class TPURunConfigTest(test.TestCase): tpu_config_lib.RunConfig( tpu_config=tpu_config_lib.TPUConfig(iterations_per_loop=0)) - def test_fail_with_invalid_computation_shape(self): - with self.assertRaisesRegexp(ValueError, - 'computation_shape must be a list with length' - ' 3 or None'): - tpu_config_lib.TPUConfig(computation_shape=[2, 1]) - - with self.assertRaisesRegexp(ValueError, - 'computation_shape elements can only be'): - tpu_config_lib.TPUConfig(computation_shape=[1, 3, 1]) + def test_fail_with_invalid_num_cores_per_replica(self): + with self.assertRaisesRegexp( + ValueError, 'num_cores_per_replica must be 1, 2, 4, or 8;' + ' got 7'): + tpu_config_lib.TPUConfig(num_cores_per_replica=7) class TPURunConfigMasterTest(test.TestCase): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index aec59f3885ca7a2046c24ce5b94917ad6c3693e7..806ae1c4c9918be0bf0af8579c12386c0a18aff0 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -21,8 +21,6 @@ from __future__ import print_function from contextlib import contextmanager import copy -import numpy as np - from tensorflow.contrib.tpu.python.tpu import device_assignment as tpu_device_assignment from tensorflow.contrib.tpu.python.tpu import tpu_config from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib @@ -33,15 +31,26 @@ from tensorflow.python.platform import tf_logging as logging _DEFAULT_JOB_NAME = 'tpu_worker' _DEFAULT_COORDINATOR_JOB_NAME = 'coordinator' _LOCAL_MASTERS = ('', 'local') +_NUM_CORES_TO_COMPUTATION_SHAPE = { + 1: [1, 1, 1], + 2: [1, 1, 2], + 4: [1, 2, 2], + 8: [2, 2, 2] +} class TPUContext(object): """The context of current input_fn invocation.""" - def __init__(self, internal_ctx, input_device=None, invocation_index=None): + def __init__(self, + internal_ctx, + input_device=None, + invocation_index=None, + call_from_input_fn=True): self._internal_ctx = internal_ctx self._input_device = input_device self._invocation_index = invocation_index + self._call_from_input_fn = call_from_input_fn def current_input_fn_deployment(self): """The configuration of the current input_fn invocation. @@ -69,11 +78,21 @@ class TPUContext(object): total invocation count is equal to the number of hosts in the system and num replicas consumed by current invocation is equal to number of cores per host. + + Raises: + RuntimeError: If this method must not be called from input_fn. """ + if not self._call_from_input_fn: + raise RuntimeError('This TPUContext instance must not be called from' + ' model_fn.') + if self._internal_ctx.is_input_sharded_per_core(): total_invocation_count = (self._internal_ctx.num_hosts * self._internal_ctx.num_of_replicas_per_host) replicas_consumed = 1 + elif self._internal_ctx.is_input_broadcast_with_iterators(): + total_invocation_count = 1 + replicas_consumed = self._internal_ctx.num_replicas else: total_invocation_count = self._internal_ctx.num_hosts replicas_consumed = self._internal_ctx.num_of_replicas_per_host @@ -105,6 +124,14 @@ class TPUContext(object): 'num_of_replicas_per_host is not supported for model_parallelism') return self._internal_ctx.num_of_replicas_per_host + @property + def device_assignment(self): + """Returns device_assignment object.""" + if self._call_from_input_fn: + raise RuntimeError('This TPUContext instance must not be called from' + ' input_fn.') + return self._internal_ctx.device_assignment + def device_for_replica(self, replica_id): """Returns the tuple of (CPU device and device ordinal) for replica. @@ -119,24 +146,7 @@ class TPUContext(object): # Note that: For the non-model parallelism, the mapping could be # a random permutation. The order should not matter in most cases # as far as model is replicated to all cores in the system. - - # If the precise replica_id to device mapping is required, please - # set the computation_shape as [1,1,1] in TPUConfig to enable - # the model parallelism. - if self._internal_ctx.model_parallelism_enabled: - return RuntimeError( - 'device_for_replica is not yet implemented for model parallelism. ' - 'b/79689078.') - - master = self._internal_ctx.master_job - job_device = '' if master is None else ('/job:%s' % master) - - num_of_replicas_per_host = self._internal_ctx.num_of_replicas_per_host - host_id = replica_id / num_of_replicas_per_host - ordinal_id = replica_id % num_of_replicas_per_host - - host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id) - return (host_device, ordinal_id) + return self._internal_ctx.device_for_replica(replica_id) class _InternalTPUContext(object): @@ -175,9 +185,14 @@ class _InternalTPUContext(object): self._eval_on_tpu = eval_on_tpu self._model_parallelism_enabled = ( - use_tpu and config.tpu_config.computation_shape) + use_tpu and config.tpu_config.num_cores_per_replica) self._mode = None - + num_cores_per_replica = config.tpu_config.num_cores_per_replica + if num_cores_per_replica: + self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[ + num_cores_per_replica] + else: + self._computation_shape = None self._lazy_tpu_system_metadata_dict = {} # key by master address self._lazy_device_assignment_dict = {} # key by master address self._lazy_validation_dict = {} # key by ModeKeys @@ -217,11 +232,16 @@ class _InternalTPUContext(object): if tpu_system_metadata is not None: return tpu_system_metadata + cluster_def = None + if (self._config.session_config and + self._config.session_config.cluster_def.job): + cluster_def = self._config.session_config.cluster_def + # pylint: disable=protected-access tpu_system_metadata = ( tpu_system_metadata_lib._query_tpu_system_metadata( master, - run_config=self._config, + cluster_def=cluster_def, query_topology=self.model_parallelism_enabled)) self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata @@ -238,11 +258,12 @@ class _InternalTPUContext(object): device_assignment = tpu_device_assignment.device_assignment( tpu_system_metadata.topology, - computation_shape=self._config.tpu_config.computation_shape, + computation_shape=self._computation_shape, num_replicas=self.num_replicas) - logging.info('computation_shape: %s', - str(self._config.tpu_config.computation_shape)) + logging.info('num_cores_per_replica: %s', + str(self._config.tpu_config.num_cores_per_replica)) + logging.info('computation_shape: %s', str(self._computation_shape)) logging.info('num_replicas: %d', self.num_replicas) logging.info('device_assignment.topology.device_coordinates: %s', str(device_assignment.topology.device_coordinates)) @@ -256,6 +277,10 @@ class _InternalTPUContext(object): def model_parallelism_enabled(self): return self._model_parallelism_enabled + @property + def input_partition_dims(self): + return self._config.tpu_config.input_partition_dims + @property def device_assignment(self): return (self._get_device_assignment() @@ -283,23 +308,20 @@ class _InternalTPUContext(object): num_cores_in_system = self.num_cores if self.model_parallelism_enabled: - computation_shape_array = np.asarray( - self._config.tpu_config.computation_shape, dtype=np.int32) - num_cores_per_replica = np.prod(computation_shape_array) + num_cores_per_replica = self._config.tpu_config.num_cores_per_replica if num_cores_per_replica > num_cores_in_system: raise ValueError( 'The num of cores required by the model parallelism, specified by ' - 'TPUConfig.computation_shape, is larger than the total num of ' - 'TPU cores in the system. computation_shape: {}, num cores ' - 'in the system: {}'.format( - self._config.tpu_config.computation_shape, - num_cores_in_system)) + 'TPUConfig.num_cores_per_replica, is larger than the total num of ' + 'TPU cores in the system. num_cores_per_replica: {}, num cores ' + 'in the system: {}'.format(num_cores_per_replica, + num_cores_in_system)) if num_cores_in_system % num_cores_per_replica != 0: raise RuntimeError( 'The num of cores in the system ({}) is not divisible by the num ' 'of cores ({}) required by the model parallelism, specified by ' - 'TPUConfig.computation_shape. This should never happen!'.format( + 'TPUConfig.num_cores_per_replica. This should never happen!'.format( num_cores_in_system, num_cores_per_replica)) return num_cores_in_system // num_cores_per_replica @@ -327,6 +349,11 @@ class _InternalTPUContext(object): return (self._config.tpu_config.per_host_input_for_training is tpu_config.InputPipelineConfig.PER_HOST_V2) + def is_input_broadcast_with_iterators(self): + """Return true if input_fn should be run in the full_replicae config.""" + return (self._config.tpu_config.per_host_input_for_training is + tpu_config.InputPipelineConfig.BROADCAST) + def is_running_on_cpu(self, is_export_mode=False): """Determines whether the input_fn and model_fn should be invoked on CPU. @@ -391,7 +418,7 @@ class _InternalTPUContext(object): """Returns the shard batch size for `input_fn`.""" global_batch_size = self.global_batch_size - if self.is_running_on_cpu(): + if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()): return global_batch_size # On TPU @@ -406,7 +433,7 @@ class _InternalTPUContext(object): """Returns the shard batch size for `model_fn`.""" global_batch_size = self.global_batch_size - if self.is_running_on_cpu(): + if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()): return global_batch_size # On TPU. always sharded per shard. @@ -463,17 +490,23 @@ class _InternalTPUContext(object): master = self.master_job - def _placement_function(_sentinal=None, core_id=None, host_id=None): # pylint: disable=invalid-name + def _placement_function(_sentinal=None, replica_id=None, host_id=None): # pylint: disable=invalid-name + """Return the host device given replica_id or host_id.""" assert _sentinal is None - if core_id is not None and host_id is not None: + if replica_id is not None and host_id is not None: raise RuntimeError( - 'core_id and host_id can have only one non-None value.') + 'replica_id and host_id can have only one non-None value.') if master is None: return '/replica:0/task:0/device:CPU:0' else: - if core_id is not None: - host_id = core_id / self.num_of_cores_per_host + if replica_id is not None: + if self.model_parallelism_enabled: + return self.device_assignment.host_device( + replica=replica_id, job=master) + else: + host_id = replica_id / self.num_of_cores_per_host + return '/job:%s/task:%d/device:CPU:0' % (master, host_id) return _placement_function @@ -546,7 +579,7 @@ class _InternalTPUContext(object): 'be ({}), got ({}). For non-model-parallelism, num_replicas should ' 'be the total num of TPU cores in the system. For ' 'model-parallelism, the total number of TPU cores should be ' - 'product(computation_shape) * num_replicas. Please set it ' + 'num_cores_per_replica * num_replicas. Please set it ' 'accordingly or leave it as `None`'.format( self._get_master_address(), num_replicas, user_provided_num_replicas)) @@ -554,7 +587,8 @@ class _InternalTPUContext(object): raise ValueError(message) if mode == model_fn_lib.ModeKeys.TRAIN: - if self._train_batch_size % num_replicas != 0: + if (self._train_batch_size % num_replicas != 0 and + not self.is_input_broadcast_with_iterators()): raise ValueError( 'train batch size {} must be divisible by number of replicas {}' .format(self._train_batch_size, num_replicas)) @@ -564,11 +598,12 @@ class _InternalTPUContext(object): raise ValueError( 'eval_batch_size in TPUEstimator constructor cannot be `None`' 'if .evaluate is running on TPU.') - if self._eval_batch_size % num_replicas != 0: + if (self._eval_batch_size % num_replicas != 0 and + not self.is_input_broadcast_with_iterators()): raise ValueError( 'eval batch size {} must be divisible by number of replicas {}' .format(self._eval_batch_size, num_replicas)) - if num_hosts > 1: + if num_hosts > 1 and not self.is_input_broadcast_with_iterators(): raise ValueError( 'TPUEstimator.evaluate should be running on single TPU worker. ' 'got {}.'.format(num_hosts)) @@ -578,11 +613,12 @@ class _InternalTPUContext(object): raise ValueError( 'predict_batch_size in TPUEstimator constructor should not be ' '`None` if .predict is running on TPU.') - if self._predict_batch_size % num_replicas != 0: + if (self._predict_batch_size % num_replicas != 0 and + not self.is_input_broadcast_with_iterators()): raise ValueError( 'predict batch size {} must be divisible by number of replicas {}' .format(self._predict_batch_size, num_replicas)) - if num_hosts > 1: + if num_hosts > 1 and not self.is_input_broadcast_with_iterators(): raise ValueError( 'TPUEstimator.predict should be running on single TPU worker. ' 'got {}.'.format(num_hosts)) @@ -590,6 +626,33 @@ class _InternalTPUContext(object): # Record the state "validated" into lazy dictionary. self._lazy_validation_dict[mode] = True + def device_for_replica(self, replica_id): + """Returns the tuple of (CPU device and device ordinal) for replica. + + This should be used for full replicate for non-model-parallelism. + + Args: + replica_id: Int, the replica index. + + Returns: + A tuple of device spec for CPU device and int device ordinal. + """ + master = self.master_job + + if self.model_parallelism_enabled: + return (self.device_assignment.host_device( + replica=replica_id, job=master), + self.device_assignment.tpu_ordinal(replica=replica_id)) + + job_device = '' if master is None else ('/job:%s' % master) + + num_of_replicas_per_host = self.num_of_replicas_per_host + host_id = replica_id / num_of_replicas_per_host + ordinal_id = replica_id % num_of_replicas_per_host + + host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id) + return (host_device, ordinal_id) + class _OneCoreTPUContext(_InternalTPUContext): """Special _InternalTPUContext for one core usage.""" @@ -625,7 +688,7 @@ def _get_tpu_context(config, train_batch_size, eval_batch_size, """Returns an instance of `_InternalTPUContext`.""" if (config.tpu_config.num_shards == 1 and - config.tpu_config.computation_shape is None): + config.tpu_config.num_cores_per_replica is None): logging.warning( 'Setting TPUConfig.num_shards==1 is an unsupported behavior. ' 'Please fix as soon as possible (leaving num_shards as None.') diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 5210139336990fad70bd4ff2a9788393de6ec20c..c104b2403c69529625bf7a6d921a952150b31b3a 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -22,9 +22,9 @@ import collections import copy import os import signal +import sys import threading import time -import traceback import numpy as np import six @@ -32,6 +32,7 @@ from six.moves import queue as Queue # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import error_handling from tensorflow.contrib.tpu.python.tpu import session_support from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_config @@ -81,12 +82,17 @@ _TPU_ESTIMATOR = 'tpu_estimator' _ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop' _BATCH_SIZE_KEY = 'batch_size' _CTX_KEY = 'context' +_USE_TPU_KEY = 'use_tpu' _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' _ONE_GIGABYTE = 1024 * 1024 * 1024 _TPU_ENQUEUE_OPS = '_tpu_enqueue_ops' _TPU_TRAIN_OP = '_tpu_train_op' _REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference' +# Ideally _USE_TPU_KEY should be reserved as well. However there are already +# models that make use of this key, thus it can not be reserved now to prevent +# breakage. In the long run, we would like to mitigate this by migrating models +# off of using _USE_TPU_KEY. _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY] @@ -211,8 +217,8 @@ class _SIGNAL(object): class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`. - See `EstimatorSpec` for `mode`, 'predictions, 'loss', 'train_op', and - 'export_outputs`. + See `EstimatorSpec` for `mode`, `predictions`, `loss`, `train_op`, and + `export_outputs`. For evaluation, `eval_metrics `is a tuple of `metric_fn` and `tensors`, where `metric_fn` runs on CPU to generate metrics and `tensors` represents the @@ -226,7 +232,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote size is the first dimension. Once all tensors are available at CPU host from all shards, they are concatenated (on CPU) and passed as positional arguments to the `metric_fn` if `tensors` is list or keyword arguments if `tensors` is - dict. `metric_fn` takes the `tensors` and returns a dict from metric string + a dict. `metric_fn` takes the `tensors` and returns a dict from metric string name to the result of calling a metric function, namely a `(metric_tensor, update_op)` tuple. See `TPUEstimator` for MNIST example how to specify the `eval_metrics`. @@ -252,7 +258,10 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote eval_metrics=None, export_outputs=None, scaffold_fn=None, - host_call=None): + host_call=None, + training_hooks=None, + evaluation_hooks=None, + prediction_hooks=None): """Creates a validated `TPUEstimatorSpec` instance.""" host_calls = {} if eval_metrics is not None: @@ -260,6 +269,17 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote if host_call is not None: host_calls['host_call'] = host_call _OutfeedHostCall.validate(host_calls) + + training_hooks = list(training_hooks or []) + evaluation_hooks = list(evaluation_hooks or []) + prediction_hooks = list(prediction_hooks or []) + + for hook in training_hooks + evaluation_hooks + prediction_hooks: + if not isinstance(hook, session_run_hook.SessionRunHook): + raise TypeError( + 'All hooks must be SessionRunHook instances, given: {}'.format( + hook)) + return super(TPUEstimatorSpec, cls).__new__( cls, mode=mode, @@ -269,7 +289,10 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote eval_metrics=eval_metrics, export_outputs=export_outputs, scaffold_fn=scaffold_fn, - host_call=host_call) + host_call=host_call, + training_hooks=training_hooks, + evaluation_hooks=evaluation_hooks, + prediction_hooks=prediction_hooks) def as_estimator_spec(self): """Creates an equivalent `EstimatorSpec` used by CPU train/eval.""" @@ -285,6 +308,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote hooks = None if self.host_call is not None: hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])] + hooks = list(hooks or []) scaffold = self.scaffold_fn() if self.scaffold_fn else None return model_fn_lib.EstimatorSpec( mode=self.mode, @@ -294,9 +318,9 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote eval_metric_ops=eval_metric_ops, export_outputs=self.export_outputs, scaffold=scaffold, - training_hooks=hooks, - evaluation_hooks=hooks, - prediction_hooks=hooks) + training_hooks=self.training_hooks + hooks, + evaluation_hooks=self.evaluation_hooks + hooks, + prediction_hooks=self.prediction_hooks + hooks) class _OpQueueContext(object): @@ -360,17 +384,17 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): ctx, enqueue_ops, dequeue_ops, - run_infeed_loop_on_coordinator=True): + run_infeed_loop_on_coordinator=True, + rendezvous=None): self._master_job = ctx.master_job self._enqueue_ops = enqueue_ops self._dequeue_ops = dequeue_ops + self._rendezvous = rendezvous self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator self._initial_infeed_sleep_secs = ( ctx.config.tpu_config.initial_infeed_sleep_secs) - self._session_cancel_timer = None - self._feed_error = None self._finished = False @@ -387,62 +411,6 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): for op in summary_writer_init_ops: self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0])) - def _log_error(self, session, error): - """Log an infeed or outfeed error. - - This logs a short error message immediately, and schedules a timer to - emit the full stack trace and error message after a short period of time. - If the main session has terminated by the time the timer triggers, we - assume the real source of the error was from the main session and avoid - emitting a stack trace for the infeed. - - Args: - session: `tf.Session`, session to be terminated error: exception that - triggered logging. - error: the Exception to log. - """ - logging.warning( - '\n\n' - 'Error occurred during infeed/outfeed. This may be due to a compile ' - 'error in the main session. Waiting for a short time for the main ' - 'session to come back.\n\n%s', error) - - self._feed_error = traceback.format_exc() - - # If we've already encountered a feed error, don't schedule another - # cancellation op. - if self._session_cancel_timer: - return - - def _cancel_session(): - """Close the session to avoid the main thread from hanging. - - If input pipeline triggers any error, the infeed thread dies but the main - thread for TPU computation waits for the infeed enqueue forever. Close the - Session to cancel the main thread Session.run execution. - - We sleep for a few seconds before closing to give some time for the TPU - compilation error, if any, propagating, from TPU to CPU host. Compilation - errors should be reported by the main thread so that the program can be - interrupted and users can take action. Due to a race condition, the - infeed thread might see an error first. Closing the session here - immediately would result in a session cancellation exception in the main - thread, instead of the expected compile error. User code that depends on - having the proper exception type will therefore be confused. - """ - time.sleep(5) - - # If the main session is still running, the infeed/outfeed errors are - # legitimate, and should be logged. - if not self._finished and self._feed_error: - logging.error('Feed error: %s', self._feed_error) - logging.error('Closing session. A RuntimeError should follow.') - session.close() - - self._session_cancel_timer = threading.Thread(target=_cancel_session) - self._session_cancel_timer.daemon = True - self._session_cancel_timer.start() - def _run_infeed(self, queue_ctx, session): logging.info('Starting infeed thread controller.') if self._initial_infeed_sleep_secs: @@ -451,7 +419,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): time.sleep(self._initial_infeed_sleep_secs) logging.info('%s thread starting after sleep', self._name) - try: + with self._rendezvous.catch_errors(source='infeed', session=session): if self._run_infeed_loop_on_coordinator: for count, steps in enumerate(queue_ctx.read_iteration_counts()): for i in xrange(steps): @@ -461,19 +429,15 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): for _ in queue_ctx.read_iteration_counts(): session.run(self._enqueue_ops) logging.info('Infeed thread finished, shutting down.') - except Exception as e: # pylint: disable=broad-except - self._log_error(session, e) def _run_outfeed(self, queue_ctx, session): logging.info('Starting outfeed thread controller.') - try: + with self._rendezvous.catch_errors(source='outfeed', session=session): for count, steps in enumerate(queue_ctx.read_iteration_counts()): for i in xrange(steps): logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i) session.run(self._dequeue_ops) logging.info('Outfeed thread finished, shutting down.') - except Exception as e: # pylint: disable=broad-except - self._log_error(session, e) def _create_infeed_controller(self, name, target, args): return _OpQueueContext(name=name, target=target, args=args) @@ -492,11 +456,6 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): def before_run(self, run_context): self._feed_error = None - # Wait for the cancellation timer to complete before continuing. - if self._session_cancel_timer: - self._session_cancel_timer.join() - self._session_cancel_timer = None - iterations = run_context.session.run(self._iterations_per_loop_var) logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations) @@ -507,16 +466,14 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): self._outfeed_controller.send_next_batch_signal(iterations) def end(self, session): - if self._session_cancel_timer: - logging.warning('Feed error occurred; waiting for message.') - self._session_cancel_timer.join() - self._finished = True logging.info('Stop infeed thread controller') self._infeed_controller.join() + self._rendezvous.record_done('infeed') logging.info('Stop output thread controller') self._outfeed_controller.join() + self._rendezvous.record_done('outfeed') logging.info('Shutdown TPU system.') session.run(self._finalize_ops) @@ -524,9 +481,10 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook): - def __init__(self, ctx, enqueue_ops, dequeue_ops): + def __init__(self, ctx, enqueue_ops, dequeue_ops, rendezvous=None): super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__( - ctx, enqueue_ops, dequeue_ops, run_infeed_loop_on_coordinator=False) + ctx, enqueue_ops, dequeue_ops, run_infeed_loop_on_coordinator=False, + rendezvous=rendezvous) def _create_infeed_controller(self, name, target, args): return _OpSignalOnceQueueContext(name=name, target=target, args=args) @@ -696,8 +654,6 @@ def generate_per_core_enqueue_ops_fn_for_host( infeed_queue = tpu_feed.InfeedQueue( number_of_tuple_elements=len(per_host_sharded_inputs[0])) captured_infeed_queue.capture(infeed_queue) - infeed_queue.set_configuration_from_sharded_input_tensors( - per_host_sharded_inputs) per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl) @@ -825,19 +781,105 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( flattened_inputs = ( inputs_structure_recorder.flatten_features_and_labels( features, labels)) - control_deps.extend(flattened_inputs) per_host_sharded_inputs.append(flattened_inputs) + if inputs_structure_recorder.flattened_input_dims: + # pylint: disable=protected-access + infeed_queue = tpu_feed._PartitionedInfeedQueue( + number_of_tuple_elements=len(per_host_sharded_inputs[0]), + host_id=host_id, + input_partition_dims=inputs_structure_recorder.flattened_input_dims, + device_assignment=ctx.device_assignment) + per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( + per_host_sharded_inputs) + else: + infeed_queue = tpu_feed.InfeedQueue( + number_of_tuple_elements=len(per_host_sharded_inputs[0])) + per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( + per_host_sharded_inputs, + tpu_ordinal_function=tpu_ordinal_function_impl) + captured_infeed_queue.capture(infeed_queue) + + return per_host_enqueue_ops + + return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset + + +def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder, + num_hosts): + """Generates infeed enqueue ops for one input_fn on all the hosts.""" + captured_infeed_queue = _CapturedObject() + hooks = [] + device_0 = ctx.tpu_host_placement_function(host_id=0) + with ops.device(device_0): + user_context = tpu_context.TPUContext( + internal_ctx=ctx, input_device=device_0, invocation_index=0) + inputs = _Inputs.from_input_fn(input_fn(user_context)) + + is_dataset = inputs.is_dataset + if ctx.mode == model_fn_lib.ModeKeys.PREDICT: + if not is_dataset: + raise TypeError( + 'For mode PREDICT, `input_fn` must return `Dataset` instead of ' + '`features` and `labels`.') + + inputs = _InputsWithStoppingSignals( + dataset=inputs.dataset, + batch_size=ctx.batch_size_for_input_fn, + add_padding=True) + + if is_dataset: + hooks.append(inputs.dataset_initializer_hook()) + num_replicas_per_host = ctx.num_of_replicas_per_host + + def tpu_ordinal_function_impl(replica_id): + if ctx.device_assignment: + return ctx.device_assignment.tpu_ordinal(replica=replica_id) + else: + return replica_id % num_replicas_per_host + + def device_function_impl(replica_id): + return ctx.tpu_host_placement_function(replica_id=replica_id) + + def enqueue_ops_fn(): + """Generates enqueue ops for all the hosts.""" + broadcasted_inputs = [] + flattened_inputs = None # Cache result from input_fn. + signals = None + for host_id in xrange(num_hosts): + with ops.device(ctx.tpu_host_placement_function(host_id=host_id)): + for _ in xrange(ctx.num_of_replicas_per_host): + # Note: input_fn is only called once at host 0 for the first replica. + # The features and labels returned from that invocation are + # broadcasted to other replicas(including the replicas on other + # hosts). + if flattened_inputs is None: + features, labels = inputs.features_and_labels() # Calls get_next() + signals = inputs.signals() + + inputs_structure_recorder.validate_and_record_structure( + features, labels, signals) + flattened_inputs = ( + inputs_structure_recorder.flatten_features_and_labels( + features, labels, signals)) + broadcasted_inputs.append(flattened_inputs) + infeed_queue = tpu_feed.InfeedQueue( - number_of_tuple_elements=len(per_host_sharded_inputs[0])) + number_of_tuple_elements=len(broadcasted_inputs[0])) captured_infeed_queue.capture(infeed_queue) - infeed_queue.set_configuration_from_sharded_input_tensors( - per_host_sharded_inputs) + enqueue_ops = infeed_queue.generate_enqueue_ops( + broadcasted_inputs, + tpu_ordinal_function=tpu_ordinal_function_impl, + placement_function=device_function_impl) - per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl) - return per_host_enqueue_ops + if signals is None: + return enqueue_ops + else: + return { + 'ops': enqueue_ops, + 'signals': signals, + } return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset @@ -875,21 +917,68 @@ class _InputPipeline(object): class InputsStructureRecorder(object): """The recorder to record inputs structure.""" - def __init__(self): + def __init__(self, input_partition_dims=None): # Holds the structure of inputs self._feature_names = [] self._label_names = [] self._has_labels = False self._signals_helper = None + self._flattened_input_dims = None + + if input_partition_dims: + # This should have been validated in TPUConfig. + assert len(input_partition_dims) <= 2, 'must have 1 or 2 elements.' + if len(input_partition_dims) == 2: + self._feature_dims, self._label_dims = input_partition_dims + else: + self._feature_dims = input_partition_dims[0] + self._label_dims = None + + assert self._feature_dims is not None, ('input_partition_dims[0] must ' + 'not be None') + else: + self._feature_dims = None + self._label_dims = None # Internal state. self._initialized = False + @property + def flattened_input_dims(self): + assert self._initialized, 'InputsStructureRecorder is not initialized.' + return self._flattened_input_dims + def has_labels(self): return self._has_labels + def _flatten_input_dims(self, feature_dims, feature_dims_names, label_dims, + label_dims_names, label_names, has_labels): + """Flatten input dims with the same order as flattened input tensors.""" + flattened_input_dims = [] + if feature_dims_names: + # We need a fixed ordering for matching the tensors in features. + flattened_input_dims.extend( + [feature_dims[name] for name in feature_dims_names]) + else: + flattened_input_dims.append(feature_dims) + + if label_dims_names: + # We need a fixed ordering for matching the tensors in labels. + flattened_input_dims.extend( + [label_dims[name] for name in label_dims_names]) + else: + if label_names: + num_tensors_in_label = len(label_names) + else: + num_tensors_in_label = int(has_labels) + # Setting `None` in input_partition_dims[1] will apply `None` to + # all the tensors in labels, regardless of internal structure. + flattened_input_dims.extend([label_dims] * num_tensors_in_label) + + return flattened_input_dims + def validate_and_record_structure(self, features, labels, signals=None): - """Validates and records the structure of features` and `labels`.""" + """Validates and records the structure of `features` and `labels`.""" def _extract_key_names(tensor_or_dict): if tensor_or_dict is None: @@ -917,6 +1006,24 @@ class _InputPipeline(object): self._feature_names = feature_names self._label_names = label_names self._has_labels = has_labels + if self._feature_dims is not None: + feature_dims_names = _extract_key_names(self._feature_dims) + if feature_dims_names != feature_names: + raise ValueError( + 'TPUConfig.input_partition_dims[0] mismatched feature' + ' keys. Expected {}, got {}'.format(feature_names, + feature_dims_names)) + + label_dims_names = _extract_key_names(self._label_dims) + if self._label_dims is not None and label_dims_names != label_names: + raise ValueError( + 'TPUConfig.input_partition_dims[1] mismatched label' + ' keys. Expected {}, got {}'.format(label_names, + label_dims_names)) + + self._flattened_input_dims = self._flatten_input_dims( + self._feature_dims, feature_dims_names, self._label_dims, + label_dims_names, label_names, has_labels) def flatten_features_and_labels(self, features, labels, signals=None): """Flattens the `features` and `labels` to a single tensor list.""" @@ -1011,7 +1118,8 @@ class _InputPipeline(object): Raises: ValueError: If both `sharded_features` and `num_cores` are `None`. """ - self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder() + self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder( + ctx.input_partition_dims) self._sharded_per_core = ctx.is_input_sharded_per_core() self._input_fn = input_fn @@ -1074,6 +1182,24 @@ class _InputPipeline(object): # Infeed_queue_getter must be called after enqueue_ops_fn is called. infeed_queues.append(captured_infeed_queue.get()) + elif self._ctx.is_input_broadcast_with_iterators(): + # Only calls input_fn in host 0. + host_device = tpu_host_placement_fn(host_id=0) + enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = ( + generate_broadcast_enqueue_ops_fn(self._ctx, self._input_fn, + self._inputs_structure_recorder, + num_hosts)) + all_hooks.extend(hooks) + if is_dataset: + run_infeed_loop_on_coordinator = False + wrap_fn = ( + _wrap_computation_in_while_loop + if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else + _wrap_computation_in_while_loop_with_stopping_signals) + enqueue_ops.append(wrap_fn(device=host_device, op_fn=enqueue_ops_fn)) + else: + enqueue_ops.append(enqueue_ops_fn()) + infeed_queues.append(captured_infeed_queue.get()) else: for host_id in range(num_hosts): host_device = tpu_host_placement_fn(host_id=host_id) @@ -1188,6 +1314,7 @@ class _ModelFnWrapper(object): host_call = _OutfeedHostCall(self._ctx) captured_scaffold_fn = _CapturedObject() + captured_training_hooks = _CapturedObject() def train_step(loss): """Training step function for use inside a while loop.""" @@ -1204,6 +1331,8 @@ class _ModelFnWrapper(object): else: captured_scaffold_fn.capture(None) + captured_training_hooks.capture(estimator_spec.training_hooks) + # We must run train_op to update the variables prior to running the # outfeed. with ops.control_dependencies([train_op]): @@ -1215,7 +1344,8 @@ class _ModelFnWrapper(object): with ops.control_dependencies(host_call_outfeed_ops): return array_ops.identity(loss) - return train_step, host_call, captured_scaffold_fn + return (train_step, host_call, captured_scaffold_fn, + captured_training_hooks) def convert_to_single_tpu_eval_step(self, dequeue_fn): """Converts user provided model_fn` as a single eval step on TPU. @@ -1245,6 +1375,7 @@ class _ModelFnWrapper(object): """ host_calls = _OutfeedHostCall(self._ctx) captured_scaffold_fn = _CapturedObject() + captured_eval_hooks = _CapturedObject() def eval_step(total_loss): """Evaluation step function for use inside a while loop.""" @@ -1259,8 +1390,11 @@ class _ModelFnWrapper(object): loss = tpu_estimator_spec.loss captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) + captured_eval_hooks.capture(tpu_estimator_spec.evaluation_hooks) + to_record = {} - to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics + if tpu_estimator_spec.eval_metrics: + to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics if tpu_estimator_spec.host_call is not None: # We assume that evaluate won't update global step, so we don't wrap # this host_call. @@ -1270,7 +1404,7 @@ class _ModelFnWrapper(object): with ops.control_dependencies(host_calls.create_enqueue_op()): return math_ops.add(total_loss, loss) - return eval_step, host_calls, captured_scaffold_fn + return eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks def convert_to_single_tpu_predict_step(self, dequeue_fn): """Converts user provided model_fn` as a single predict step on TPU. @@ -1285,6 +1419,7 @@ class _ModelFnWrapper(object): """ host_calls = _OutfeedHostCall(self._ctx) captured_scaffold_fn = _CapturedObject() + captured_predict_hooks = _CapturedObject() def predict_step(unused_scalar_stopping_signal): """Evaluation step function for use inside a while loop.""" @@ -1305,6 +1440,7 @@ class _ModelFnWrapper(object): self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions) captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) + captured_predict_hooks.capture(tpu_estimator_spec.prediction_hooks) to_record = {} identity_fn = lambda **kwargs: kwargs to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions] @@ -1316,7 +1452,8 @@ class _ModelFnWrapper(object): with ops.control_dependencies(host_calls.create_enqueue_op()): return _StopSignals.as_scalar_stopping_signal(stopping_signals) - return predict_step, host_calls, captured_scaffold_fn + return (predict_step, host_calls, captured_scaffold_fn, + captured_predict_hooks) def _verify_tpu_spec_predictions(self, predictions): """Validates TPUEstimatorSpec.predictions dict.""" @@ -1414,8 +1551,16 @@ class _ModelFnWrapper(object): if batch_size_for_model_fn is not None: _add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn) + running_on_cpu = self._ctx.is_running_on_cpu(is_export_mode) + _add_item_to_params(params, _USE_TPU_KEY, not running_on_cpu) + + if not running_on_cpu: + user_context = tpu_context.TPUContext( + internal_ctx=self._ctx, call_from_input_fn=False) + _add_item_to_params(params, _CTX_KEY, user_context) + estimator_spec = self._model_fn(features=features, **kwargs) - if (self._ctx.is_running_on_cpu(is_export_mode) and + if (running_on_cpu and isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)): # pylint: disable=protected-access # The estimator_spec will be passed to `Estimator` directly, which expects # type `EstimatorSpec`. @@ -1430,11 +1575,9 @@ class _ModelFnWrapper(object): err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.' if estimator_spec.training_chief_hooks: - raise ValueError(err_msg.format('training_chief_hooks')) - if estimator_spec.training_hooks: - raise ValueError(err_msg.format('training_hooks')) - if estimator_spec.evaluation_hooks: - raise ValueError(err_msg.format('evaluation_hooks')) + raise ValueError( + err_msg.format('training_chief_hooks') + 'If you want' + + ' to pass training hooks, please pass via training_hooks.') if estimator_spec.scaffold: logging.warning('EstimatorSpec.Scaffold is ignored by TPU train/eval. ' @@ -1555,7 +1698,7 @@ class _OutfeedHostCall(object): RuntimeError: If outfeed tensor is scalar. """ if not self._names: - return [] + return {} ret = {} # For each i, dequeue_ops[i] is a list containing the tensors from all @@ -1574,11 +1717,13 @@ class _OutfeedHostCall(object): # Outfeed ops execute on each replica's first logical core. Note: we must # constraint it such that we have at most one outfeed dequeue and enqueue # per replica. - tpu_device_placement_fn = self._ctx.tpu_device_placement_function for i in xrange(self._ctx.num_replicas): - with ops.device(tpu_device_placement_fn(i)): + host_device, ordinal_id = self._ctx.device_for_replica(i) + with ops.device(host_device): outfeed_tensors = tpu_ops.outfeed_dequeue_tuple( - dtypes=tensor_dtypes, shapes=tensor_shapes) + dtypes=tensor_dtypes, + shapes=tensor_shapes, + device_ordinal=ordinal_id) for j, item in enumerate(outfeed_tensors): dequeue_ops[j].append(item) @@ -1593,7 +1738,7 @@ class _OutfeedHostCall(object): # place all ops on tpu host if possible. # # TODO(jhseu): Evaluate whether this is right for summaries. - with ops.device(self._ctx.tpu_host_placement_function(core_id=0)): + with ops.device(self._ctx.tpu_host_placement_function(replica_id=0)): for name in self._names: dequeue_ops = dequeue_ops_by_name[name] for i, item in enumerate(dequeue_ops): @@ -1702,6 +1847,9 @@ class InstallSignalHandlerHook(session_run_hook.SessionRunHook): class TPUEstimator(estimator_lib.Estimator): """Estimator with TPU support. + TPUEstimator also supports training on CPU and GPU. You don't need to define + a separate `tf.estimator.Estimator`. + TPUEstimator handles many of the details of running on TPU devices, such as replicating inputs and models for each core, and returning to host periodically to run hooks. @@ -1739,7 +1887,8 @@ class TPUEstimator(estimator_lib.Estimator): Current limitations: -------------------- - 1. TPU evaluation only works on a single host (one TPU worker). + 1. TPU evaluation only works on a single host (one TPU worker) except + BROADCAST mode. 2. `input_fn` for evaluation should **NOT** raise an end-of-input exception (`OutOfRangeError` or `StopIteration`). And all evaluation steps and all @@ -1910,10 +2059,9 @@ class TPUEstimator(estimator_lib.Estimator): """Constructs an `TPUEstimator` instance. Args: - model_fn: Model function as required by `Estimator`. For training, the - returned `EstimatorSpec` cannot have hooks as it is not supported in - `TPUEstimator`. Instead, the user can pass the training hooks as - an argument to `TPUEstimator.train()`. + model_fn: Model function as required by `Estimator` which returns + EstimatorSpec or TPUEstimatorSpec. `training_hooks`, 'evaluation_hooks', + and `prediction_hooks` must not capure any TPU Tensor inside the model_fn. model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. If `None`, the model_dir in @@ -1978,7 +2126,7 @@ class TPUEstimator(estimator_lib.Estimator): if (config.tpu_config.per_host_input_for_training is tpu_config.InputPipelineConfig.PER_SHARD_V1 and - config.tpu_config.computation_shape): + config.tpu_config.num_cores_per_replica): raise ValueError( 'Model parallelism only supports per host input for training. ' 'Please adjust TPURunconfig.per_host_input_for_training.') @@ -2025,6 +2173,7 @@ class TPUEstimator(estimator_lib.Estimator): self._export_to_tpu = export_to_tpu self._is_input_fn_invoked = None + self._rendezvous = {} def _add_meta_graph_for_mode(self, builder, @@ -2268,6 +2417,65 @@ class TPUEstimator(estimator_lib.Estimator): """ pass + def train(self, + input_fn, + hooks=None, + steps=None, + max_steps=None, + saving_listeners=None): + rendezvous = error_handling.ErrorRendezvous(num_sources=3) + self._rendezvous[model_fn_lib.ModeKeys.TRAIN] = rendezvous + try: + return super(TPUEstimator, self).train( + input_fn=input_fn, hooks=hooks, steps=steps, max_steps=max_steps, + saving_listeners=saving_listeners + ) + except Exception: # pylint: disable=broad-except + rendezvous.record_error('training_loop', sys.exc_info()) + finally: + rendezvous.record_done('training_loop') + rendezvous.raise_errors() + + def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None, + name=None): + rendezvous = error_handling.ErrorRendezvous(num_sources=3) + self._rendezvous[model_fn_lib.ModeKeys.EVAL] = rendezvous + try: + return super(TPUEstimator, self).evaluate( + input_fn, steps=steps, hooks=hooks, checkpoint_path=checkpoint_path, + name=name + ) + except Exception: # pylint: disable=broad-except + rendezvous.record_error('evaluation_loop', sys.exc_info()) + finally: + rendezvous.record_done('evaluation_loop') + rendezvous.raise_errors() + + def predict(self, + input_fn, + predict_keys=None, + hooks=None, + checkpoint_path=None, + yield_single_examples=True): + rendezvous = error_handling.ErrorRendezvous(num_sources=3) + self._rendezvous[model_fn_lib.ModeKeys.PREDICT] = rendezvous + try: + for result in super(TPUEstimator, self).predict( + input_fn=input_fn, + predict_keys=predict_keys, + hooks=hooks, + checkpoint_path=checkpoint_path, + yield_single_examples=yield_single_examples): + yield result + except Exception: # pylint: disable=broad-except + rendezvous.record_error('prediction_loop', sys.exc_info()) + finally: + rendezvous.record_done('prediction_loop') + rendezvous.raise_errors() + + rendezvous.record_done('prediction_loop') + rendezvous.raise_errors() + def _augment_model_fn(self, model_fn, batch_axis): """Returns a new model_fn, which wraps the TPU support.""" @@ -2290,10 +2498,20 @@ class TPUEstimator(estimator_lib.Estimator): # Clear the bit. self._is_input_fn_invoked = None + # examples_hook is added to training_hooks for both CPU and TPU + # execution. + examples_hook = ExamplesPerSecondHook( + ctx.global_batch_size, + output_dir=self.model_dir, + every_n_steps=self._log_every_n_steps) + if ctx.is_running_on_cpu(is_export_mode=is_export_mode): logging.info('Running %s on CPU', mode) - return model_fn_wrapper.call_without_tpu( + estimator_spec = model_fn_wrapper.call_without_tpu( features, labels, is_export_mode=is_export_mode) + estimator_spec = estimator_spec._replace( + training_hooks=estimator_spec.training_hooks + (examples_hook,)) + return estimator_spec assert labels is None, '`labels` passed to `model_fn` must be `None`.' # TPUEstimator._call_input_fn passes `input_fn` as features to here. @@ -2312,7 +2530,7 @@ class TPUEstimator(estimator_lib.Estimator): graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op) if mode == model_fn_lib.ModeKeys.TRAIN: - loss, host_call, scaffold = ( + loss, host_call, scaffold, training_hooks = ( _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) host_ops = host_call.create_tpu_hostcall() if host_ops is None: @@ -2352,7 +2570,9 @@ class TPUEstimator(estimator_lib.Estimator): enqueue_ops, host_ops, run_infeed_loop_on_coordinator=( - run_infeed_loop_on_coordinator)), + run_infeed_loop_on_coordinator), + rendezvous=self._rendezvous[mode], + ), InstallSignalHandlerHook(), training.LoggingTensorHook( { @@ -2361,14 +2581,13 @@ class TPUEstimator(estimator_lib.Estimator): }, every_n_iter=logging_hook_frequency) ]) - examples_hook = ExamplesPerSecondHook( - ctx.global_batch_size, - output_dir=self.model_dir, - every_n_steps=self._log_every_n_steps) examples_hook._set_steps_per_run( # pylint: disable=protected-access self._config.tpu_config.iterations_per_loop) hooks.append(examples_hook) + if training_hooks: + hooks.extend(training_hooks) + chief_hooks = [] if (self._config.save_checkpoints_secs or self._config.save_checkpoints_steps): @@ -2380,6 +2599,7 @@ class TPUEstimator(estimator_lib.Estimator): checkpoint_hook._set_steps_per_run( # pylint: disable=protected-access self._config.tpu_config.iterations_per_loop) chief_hooks.append(checkpoint_hook) + summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss) with ops.control_dependencies([loss]): update_ops = _sync_variables_ops() @@ -2399,7 +2619,7 @@ class TPUEstimator(estimator_lib.Estimator): scaffold=scaffold) if mode == model_fn_lib.ModeKeys.EVAL: - total_loss, host_calls, scaffold = _eval_on_tpu_system( + total_loss, host_calls, scaffold, eval_hooks = _eval_on_tpu_system( ctx, model_fn_wrapper, dequeue_fn) iterations_per_loop_var = _create_or_get_iterations_per_loop() mean_loss = math_ops.div(total_loss, @@ -2424,7 +2644,8 @@ class TPUEstimator(estimator_lib.Estimator): host_call_ret = host_calls.create_tpu_hostcall() eval_metric_ops = {} eval_update_ops = [] - for k, v in host_call_ret['eval_metrics'].items(): + + for k, v in host_call_ret.get('eval_metrics', {}).items(): eval_metric_ops[k] = (v[0], dummy_update_op) eval_update_ops.append(v[1]) @@ -2438,9 +2659,13 @@ class TPUEstimator(estimator_lib.Estimator): enqueue_ops, eval_update_ops + host_ops, run_infeed_loop_on_coordinator=( - run_infeed_loop_on_coordinator)), + run_infeed_loop_on_coordinator), + rendezvous=self._rendezvous[mode]), ] + input_hooks + if eval_hooks: + hooks.extend(eval_hooks) + return model_fn_lib.EstimatorSpec( mode, loss=mean_loss, @@ -2451,8 +2676,9 @@ class TPUEstimator(estimator_lib.Estimator): # Predict assert mode == model_fn_lib.ModeKeys.PREDICT - dummy_predict_op, host_calls, scaffold = _predict_on_tpu_system( - ctx, model_fn_wrapper, dequeue_fn) + (dummy_predict_op, host_calls, + scaffold, prediction_hooks) = _predict_on_tpu_system( + ctx, model_fn_wrapper, dequeue_fn) with ops.control_dependencies([dummy_predict_op]): internal_ops_to_run = _sync_variables_ops() with ops.control_dependencies(internal_ops_to_run): @@ -2504,10 +2730,13 @@ class TPUEstimator(estimator_lib.Estimator): hooks = [ _StoppingPredictHook(scalar_stopping_signal), - TPUInfeedOutfeedSessionHookForPrediction(ctx, enqueue_ops, - host_ops), + TPUInfeedOutfeedSessionHookForPrediction( + ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode]), ] + input_hooks + if prediction_hooks: + hooks.extend(prediction_hooks) + return model_fn_lib.EstimatorSpec( mode, prediction_hooks=hooks, @@ -2591,8 +2820,8 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" iterations_per_loop_var = _create_or_get_iterations_per_loop() - single_tpu_eval_step, host_calls, captured_scaffold_fn = ( - model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn)) + (single_tpu_eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks + ) = model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn) def multi_tpu_eval_steps_on_single_shard(): return training_loop.repeat( @@ -2607,15 +2836,16 @@ def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): device_assignment=ctx.device_assignment) scaffold = _get_scaffold(captured_scaffold_fn) - return loss, host_calls, scaffold + return loss, host_calls, scaffold, captured_eval_hooks.get() def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" iterations_per_loop_var = _create_or_get_iterations_per_loop() - single_tpu_train_step, host_call, captured_scaffold_fn = ( - model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn)) + (single_tpu_train_step, host_call, captured_scaffold_fn, + captured_training_hooks) = ( + model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn)) def multi_tpu_train_steps_on_single_shard(): return training_loop.repeat( @@ -2630,15 +2860,16 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): device_assignment=ctx.device_assignment) scaffold = _get_scaffold(captured_scaffold_fn) - return loss, host_call, scaffold + return loss, host_call, scaffold, captured_training_hooks.get() def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" num_cores = ctx.num_cores - single_tpu_predict_step, host_calls, captured_scaffold_fn = ( - model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn)) + (single_tpu_predict_step, host_calls, captured_scaffold_fn, + captured_predict_hooks + ) = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn) def multi_tpu_predict_steps_on_single_shard(): @@ -2655,10 +2886,11 @@ def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): multi_tpu_predict_steps_on_single_shard, inputs=[], num_shards=num_cores, - outputs_from_all_shards=False) + outputs_from_all_shards=False, + device_assignment=ctx.device_assignment) scaffold = _get_scaffold(captured_scaffold_fn) - return dummy_predict_op, host_calls, scaffold + return dummy_predict_op, host_calls, scaffold, captured_predict_hooks.get() def _wrap_computation_in_while_loop(device, op_fn): @@ -3155,3 +3387,47 @@ def _add_item_to_params(params, key, value): else: # Now params is Python dict. params[key] = value + + +def export_estimator_savedmodel(estimator, + export_dir_base, + serving_input_receiver_fn, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False): + """Export `Estimator` trained model for TPU inference. + + Args: + estimator: `Estimator` with which model has been trained. + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + serving_input_receiver_fn: A function that takes no argument and + returns a `ServingInputReceiver` or `TensorServingInputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. + + Returns: + The string path to the exported directory. + """ + # `TPUEstimator` requires `tpu_config.RunConfig`, so we cannot use + # `estimator.config`. + config = tpu_config.RunConfig(model_dir=estimator.model_dir) + est = TPUEstimator( + estimator._model_fn, # pylint: disable=protected-access + config=config, + params=estimator.params, + use_tpu=True, + train_batch_size=2048, # Does not matter. + eval_batch_size=2048, # Does not matter. + ) + return est.export_savedmodel(export_dir_base, serving_input_receiver_fn, + assets_extra, + as_text, + checkpoint_path, + strip_default_attrs) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py index 604e6600c81a4136a1f10e79a725a887a96f4d86..d9c77a3ea1bbc456f058f36d78eec1f0843ddc79 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py @@ -20,8 +20,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import itertools + +import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding +from tensorflow.compiler.xla.python_api import xla_shape from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_sharding @@ -30,6 +35,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.util import nest class InfeedQueue(object): @@ -461,7 +467,10 @@ class InfeedQueue(object): name=full_name, device_ordinal=tpu_ordinal) - def generate_enqueue_ops(self, sharded_inputs, tpu_ordinal_function=None): + def generate_enqueue_ops(self, + sharded_inputs, + tpu_ordinal_function=None, + placement_function=None): """Generates the host-side Ops to enqueue the shards of a tuple. sharded_inputs is a list, one for each shard, of lists of @@ -483,6 +492,9 @@ class InfeedQueue(object): shard index as input and returns the ordinal of the TPU device the shard's infeed should be placed on. tpu_ordinal_function must be set if the inputs are placed on CPU devices. + placement_function: if not None, a function that takes the shard index as + input and returns the host device where the enqueue op should be placed + on. Returns: A list of host-side Ops, one for each shard, that when executed together @@ -508,8 +520,12 @@ class InfeedQueue(object): tpu_ordinal_function = lambda index: -1 name_prefix = "%s/enqueue" % self._name return [ - self._generate_enqueue_op(shard, name_prefix, index, - tpu_ordinal=tpu_ordinal_function(index)) + self._generate_enqueue_op( + shard, + name_prefix, + index, + tpu_ordinal=tpu_ordinal_function(index), + device=placement_function(index) if placement_function else None) for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards)) ] @@ -630,3 +646,264 @@ class InfeedQueue(object): tpu_ordinal=tpu_ordinal_function(index)) for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards)) ] + + +class _PartitionedInfeedQueue(InfeedQueue): + """A helper object to build a device infeed queue with input partition. + + Args: + number_of_tuple_elements: the number of Tensors fed atomically through the + queue, must be present unless it can be inferred from other arguments. + device_assignment: A TPU `DeviceAssignment` which is used to place all the + partitions to different TPU infeed queues. + host_id: The id of the host machine. + input_partition_dims: A nested list/tuple of integers. Each inner + list/tuple describes how to partition the corresponding input tensor. + tuple_types: If not None, a list of types of the elements of the queue. + tuple_shapes: If not None, a list of shapes of the elements of the queue. + name: The name of the queue. + """ + + def __init__(self, + number_of_tuple_elements, + device_assignment, + host_id, + input_partition_dims=None, + tuple_types=None, + tuple_shapes=None, + name=None): + super(_PartitionedInfeedQueue, self).__init__( + number_of_tuple_elements=number_of_tuple_elements, + tuple_types=tuple_types, + tuple_shapes=None, + shard_dimensions=None, + name="PartitionedInfeedQueue" if name is None else name) + self._input_partition_dims = input_partition_dims + self._host_id = host_id + self._device_assignment = device_assignment + + def generate_dequeue_op(self, tpu_device=0): + """Generate TPU dequeue ops. + + Args: + tpu_device: The TPU device ordinal where the infeed instruction should be + placed. + + Returns: + A list of Outputs corresponding to a partition of infeed dequeued + into XLA, suitable for use within a replicated block. + + Raises: + ValueError: if the types or shapes of the tuple elements have not been + set; or if a dequeue op has already been generated. + """ + self.freeze() + if self._generated_dequeue_op: + raise ValueError("Can't generate two dequeue Ops from the same queue") + self._generated_dequeue_op = True + full_name = "%s/dequeue" % self._name + sharded_shapes = [ + policy.get_sharded_shape(shape) + for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) + ] + with ops.device(tpu.core(tpu_device)): + values = tpu_ops.infeed_dequeue_tuple( + dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) + return self._tag_sharding_attribute_for_dequeued_tensors( + values, self._input_partition_dims) + + def generate_enqueue_ops(self, per_host_sharded_inputs): + """Generates the host-side Ops to enqueue the partitioned inputs. + + per_host_sharded_inputs is a list, one for each replica, of lists of + Tensors. sharded_inputs[i] is the tuple of Tensors to use to feed + replica i. + sharded_inputs[i][j] is partitioned by self._input_partition_dims[j]. + + For example, if sharded_inputs[i][j] is a 2-D Tensor: + [[A, B, C, D], + [E ,F, G, H]] + self._input_partition_dims[j] is [2, 4]. + + sharded_inputs[i][j] will be partitioned and flattened into: + [A, B, C, D, E, F, G, H] and fed into the logical core ids: + [0, 1, 2, 3, 4, 5, 6, 7] respectively. + + Args: + per_host_sharded_inputs: a list of lists of Tensors. The length of the + outer list determines the number of shards. Each inner list indicates + the types and shapes of the tuples in the corresponding shard. + + Returns: + A list of host-side Ops, one for each shard, that when executed together + will enqueue a full-size element of infeed. + + Raises: + ValueError: if the queue configuration has previously been frozen and the + shapes of the elements of sharded_inputs are not compatible with the + frozen configuration; or if the shapes of the elements of sharded_inputs + don't form a consistent unsharded tuple; or if the elements of a tuple + have different device constraints; or if the partition dims are invalid. + TypeError: if the queue configuration has previously been frozen and the + types of the elements of sharded_inputs are not compatible with the + frozen configuration; or if the types of the elements of sharded_inputs + don't form a consistent unsharded tuple. + """ + self.set_configuration_from_sharded_input_tensors(per_host_sharded_inputs) + number_of_replicas_per_host = len(per_host_sharded_inputs) + number_of_tuple_elements = len(per_host_sharded_inputs[0]) + + assert len(self._input_partition_dims) == number_of_tuple_elements + per_host_enqueue_ops = [] + + for replica_index in range(number_of_replicas_per_host): + flattened_inputs = per_host_sharded_inputs[replica_index] + inputs_part_dims_flat = nest.flatten_up_to(flattened_inputs, + self._input_partition_dims) + inputs_parted_iters = [ + iter(self._partition_or_replicate_on_host(x, dims)) for x, dims in + zip(per_host_sharded_inputs[replica_index], inputs_part_dims_flat) + ] + + for core_index in xrange(self._device_assignment.num_cores_per_replica): + # Places different partitions to different logic cores. + logical_core = self._get_logical_core(core_index) + replica_id = self._device_assignment.lookup_replicas( + self._host_id, logical_core)[replica_index] + ordinal = self._device_assignment.tpu_ordinal( + replica=replica_id, logical_core=logical_core) + infeed_inputs = [] + for it in inputs_parted_iters: + input_for_device = next(it, None) + if input_for_device is not None: + infeed_inputs.append(input_for_device) + + if infeed_inputs: + per_host_enqueue_ops.append( + tpu_ops.infeed_enqueue_tuple( + inputs=infeed_inputs, + shapes=[x.shape for x in infeed_inputs], + name="enqueue/replica_{0}/input_{1}".format( + replica_index, core_index), + device_ordinal=ordinal)) + return per_host_enqueue_ops + + def _check_input_partition_dims(self, tensor, dims): + """Checks that input partition dims are valid for the `Tensor`. + + Args: + tensor: Input tensor for partitioning. + dims: A list of integer describes how to partition the input tensor. + + Raises: + ValueError: If the tensor can't be partitioned by dims or the + num_cores_per_replica doesn't match the number of + partitions(dims.prod()). + """ + if dims is None: + return + + dims = np.array(dims) + + if (dims < 1).any(): + raise ValueError("All input partition dims must be >= 1.") + + # No partitioning, so don't perform further checks. + if dims.prod() == 1: + return + + if dims.prod() != self._device_assignment.num_cores_per_replica: + raise ValueError( + "The product of each input parition dim should equal to " + "num_cores_per_replica. (dim = {}, num_cores_per_replica " + "= {})".format(dims, self._device_assignment.num_cores_per_replica)) + if dims.shape[0] != tensor.shape.ndims: + raise ValueError( + "Input partition dims must have the same number of dimensions " + "as the `Tensor` to be partitioned. (tensor shape = {}, input " + "partition dims = {}).".format(tensor.shape.as_list(), dims)) + + tensor.shape.assert_is_fully_defined() + if (np.array(tensor.shape.as_list()) % dims != 0).any(): + raise ValueError( + "All input partition dims must divide exactly into the `Tensor` " + "shape (tensor shape = {}, input partition dims = {}).".format( + tensor.shape.as_list(), dims)) + + def _partition_or_replicate_on_host(self, tensor, dims): + """Partitions or replicates the input tensor. + + The ops inside this function are placed on the host side. + + Args: + tensor: The input tensor which will be partioned or replicated. + dims: A list of integer describes how to partition the input tensor. + Returns: + An iterator of `Tensor`s or a list of partioned tensors. + """ + self._check_input_partition_dims(tensor, dims) + if dims is None: + return itertools.repeat(tensor) + else: + output = [tensor] + for axis, dim in enumerate(dims): + if dim > 1: + output = [array_ops.split(x, dim, axis=axis) for x in output] + output = nest.flatten(output) + return output + + def _tag_sharding_attribute_for_dequeued_tensor(self, tensor, dims): + """Tags appropriate XLA sharding attribute to the dequeued tensor. + + Args: + tensor: The dequeued tensor on TPU. + dims: A list of integer describes how the tensor is partitioned. + + Returns: + The same tensor with the xla_sharding attribute. + """ + if dims is None: + return xla_sharding.replicate(tensor) + elif np.prod(dims) == 1: + return xla_sharding.assign_device(tensor, 0) + else: + tile_shape = np.array(tensor.shape.as_list()) // dims + tile_assignment = np.arange(np.prod(dims)).reshape(dims) + return xla_sharding.tile( + tensor=tensor, + tile_shape=xla_shape.CreateShapeFromDtypeAndTuple( + dtype=np.dtype(tensor.dtype.as_numpy_dtype), + shape_tuple=tile_shape), + tile_assignment=tile_assignment) + + def _tag_sharding_attribute_for_dequeued_tensors(self, dequeues, dims): + """Tags appropriate XLA sharding attribute to the dequeued tensors. + + Args: + dequeues: A list of dequeued tensors on TPU. + dims: A list of integer describes how the tensor is partitioned. + + Returns: + The same dequeues with appropriate xla_sharding attribute. + """ + nest.assert_shallow_structure(dequeues, dims) + return nest.map_structure_up_to( + dequeues, self._tag_sharding_attribute_for_dequeued_tensor, dequeues, + dims) + + def _get_logical_core(self, core_index): + """Maps the core index to the 3D coordinate within replica. + + The lowest dimension number in computation_shape is the slowest varying + dimension (most major). + + Args: + core_index: An integer represents the core index within replcia. + + Returns: + A tuple with three integers which represents the 3D coordinate. + """ + computation_shape = self._device_assignment.computation_shape + return (core_index // (computation_shape[1] * computation_shape[2]), + core_index % (computation_shape[1] * computation_shape[2]) // + computation_shape[2], core_index % computation_shape[2]) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py index 15f99d7eebddd46f9f6902b68f01e42359a72cbe..53d33f40777a1c6d93f19c30b2ef5902d63ad2fd 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py @@ -23,6 +23,7 @@ import collections from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu_function +from tensorflow.python.framework import ops from tensorflow.python.ops.losses import losses from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import optimizer @@ -153,8 +154,9 @@ class CrossShardOptimizer(optimizer.Optimizer): if grad is None: summed_grads_and_vars.append((grad, var)) else: - summed_grads_and_vars.append((tpu_ops.cross_replica_sum( - grad, self._group_assignment), var)) + with ops.colocate_with(grad): + summed_grads_and_vars.append((tpu_ops.cross_replica_sum( + grad, self._group_assignment), var)) return self._opt.apply_gradients(summed_grads_and_vars, global_step, name) def get_slot(self, *args, **kwargs): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py index 894f21d0635ca47d3da1c0d2c3f5c37bac690920..ec682e5829c4df536a043334b74200f0b6259df3 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py @@ -45,7 +45,7 @@ _TPUSystemMetadata = collections.namedtuple('_TPUSystemMetadata', [ ]) -def _query_tpu_system_metadata(master_address, run_config, +def _query_tpu_system_metadata(master_address, cluster_def=None, query_topology=False): """Automatically detects the TPU system metadata in the system.""" tpu_core_count = 0 @@ -61,7 +61,8 @@ def _query_tpu_system_metadata(master_address, run_config, with session_lib.Session( master_address, config=get_session_config_with_timeout( - _PINGING_MASTER_TIMEOUT_IN_MS, run_config)) as sess: + _PINGING_MASTER_TIMEOUT_IN_MS, + cluster_def)) as sess: devices = sess.list_devices() for device in devices: match = _TPU_DEVICE_REG.match(device.name) @@ -105,7 +106,7 @@ def _query_tpu_system_metadata(master_address, run_config, 'TPU worker has some problems. Available devices: {}'.format( master_address, devices)) - topology = _obtain_topology(master_address, run_config) + topology = _obtain_topology(master_address, cluster_def) metadata = _TPUSystemMetadata( num_cores=tpu_core_count, @@ -127,14 +128,15 @@ def _query_tpu_system_metadata(master_address, run_config, return metadata -def _obtain_topology(master_address, run_config): +def _obtain_topology(master_address, cluster_def): + """Obtains TPU fabric topology.""" try: logging.info('Initializing TPU system (master: %s) to fetch topology ' 'for model parallelism. This might take a while.', master_address) with ops.Graph().as_default(): session_config = get_session_config_with_timeout( - _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, run_config) + _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, cluster_def) with session_lib.Session( master_address, config=session_config) as sess: topology = sess.run(tpu.initialize_system()) @@ -146,11 +148,8 @@ def _obtain_topology(master_address, run_config): master_address)) -def get_session_config_with_timeout(timeout_in_secs, run_config): - cluster_def = None - if run_config.session_config and run_config.session_config.cluster_def.job: - cluster_def = run_config.session_config.cluster_def - +def get_session_config_with_timeout(timeout_in_secs, cluster_def): + """Returns a session given a timeout and a cluster configuration.""" config = config_pb2.ConfigProto( operation_timeout_in_ms=timeout_in_secs, cluster_def=cluster_def) return config diff --git a/tensorflow/contrib/training/python/training/evaluation.py b/tensorflow/contrib/training/python/training/evaluation.py index f7fd66d33fc0c329db7daaf87373385156d84217..01bac891da7ddf8523e6cc8c99decf4a61aa2741 100644 --- a/tensorflow/contrib/training/python/training/evaluation.py +++ b/tensorflow/contrib/training/python/training/evaluation.py @@ -142,9 +142,9 @@ from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import evaluation from tensorflow.python.training import monitored_session -from tensorflow.python.training import saver as tf_saver from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util @@ -189,7 +189,7 @@ def wait_for_new_checkpoint(checkpoint_dir, logging.info('Waiting for new checkpoint at %s', checkpoint_dir) stop_time = time.time() + timeout if timeout is not None else None while True: - checkpoint_path = tf_saver.latest_checkpoint(checkpoint_dir) + checkpoint_path = checkpoint_management.latest_checkpoint(checkpoint_dir) if checkpoint_path is None or checkpoint_path == last_checkpoint: if stop_time is not None and time.time() + seconds_to_sleep > stop_time: return None diff --git a/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay.py b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay.py new file mode 100644 index 0000000000000000000000000000000000000000..ed0f398e30a7f3c0b1b9378f8fc5d5bfbea1536a --- /dev/null +++ b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay.py @@ -0,0 +1,187 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""SGDR learning rate decay function.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops, control_flow_ops + + +def sgdr_decay(learning_rate, global_step, initial_period_steps, + t_mul=2.0, m_mul=1.0, name=None): + """Implements Stochastic Gradient Descent with Warm Restarts (SGDR). + + As described in "SGDR: Stochastic Gradient Descent + with Warm Restarts" by Ilya Loshchilov & Frank Hutter, Proceedings of + ICLR'2017, available at https://arxiv.org/pdf/1608.03983.pdf + + The learning rate decreases according to cosine annealing: + + ```python + learning_rate * 0.5 * (1 + cos(x_val * pi)) # for x_val defined in [0, 1] + ``` + + Thus, at the beginning (when the restart index i = 0), + the learning rate decreases for `initial_period_steps` steps from the initial + learning rate `learning_rate` (when `x_val=0`, we get `cos(0)=1`) to + 0 (when `x_val=1`, we get `cos(pi)=-1`). + + The decrease within the i-th period takes `t_i` steps, + where `t_0` = `initial_period_steps` is the user-defined number of batch + iterations (not epochs as in the paper) to be performed before the first + restart is launched. + + Then, we perform the first restart (i=1) by setting the learning rate to + `learning_rate*(m_mul^i)`, where `m_mul in [0,1]` (set to 1 by default). + The i-th restart runs for `t_i=t_0*(t_mul^i)` steps, i.e., every new + restart runs `t_mul` times longer than the previous one. + + Importantly, when one has no access to a validation set, SGDR suggests + to report the best expected / recommended solution in the following way: + When we are within our initial run (i=0), every new solution represents + SGDR's recommended solution. Instead, when i>0, the recommended solution is + the one obtained at the end of each restart. + + Note that the minimum learning rate is set to 0 for simplicity, + you can adjust the code to deal with any positive minimum learning rate + as defined in the paper. + + `initial_period_steps` is the duration of the first period measured in terms + of number of minibatch updates. If one wants to use epochs, one should compute + the number of updates required for an epoch. + + For example, assume the following parameters and intention: + Minibatch size: 100 + Training dataset size: 10000 + If the user wants the first decay period to span across 5 epochs, then + `initial_period_steps` = 5 * 10000/100 = 500 + + Train for 10000 batch iterations with the initial learning rate set to + 0.1, then restart to run 2 times longer, i.e, for 20000 batch iterations + and with the initial learning rate 0.05, then restart again and again, + doubling the runtime of each new period and with two times smaller + initial learning rate. + + To accomplish the above, one would write: + + ```python + ... + global_step = tf.Variable(0, trainable=False) + starter_learning_rate = 0.1 + learning_rate = sgdr_decay(starter_learning_rate, global_step, + initial_period_steps=10000, t_mul=2, m_mul=0.5) + # Passing global_step to minimize() will increment it at each step. + learning_step = ( + tf.train.GradientDescentOptimizer(learning_rate) + .minimize(...my loss..., global_step=global_step) + ) + + # Step | 0 | 1000 | 5000 | 9000 | 9999 | 10000 | 11000 | + # LR | 0.1 | 0.097 | 0.05 | 0.002 | 0.00 | 0.05 | 0.0496 | + + # Step | 20000 | 29000 | 29999 | 30000 | + # LR | 0.025 | 0.0003 | 0.00 | 0.025 | + ``` + + Args: + learning_rate: A scalar `float32` or `float64` `Tensor` or a + Python number. The initial learning rate. + global_step: A scalar `int32` or `int64` `Tensor` or a Python number. + Global step to use for the decay computation. Must not be negative. + initial_period_steps: Duration of the first period measured as the number + of minibatch updates, if one wants to use epochs, one should compute + the number of updates required for an epoch. + t_mul: A scalar `float32` or `float64` `Tensor` or a Python number. + Must be positive. + Used to derive the number of iterations in the i-th period: + `initial_period_steps * (t_mul^i)`. Defaults to 2.0. + m_mul: A scalar `float32` or `float64` `Tensor` or a Python number. + Must be positive. + Used to derive the initial learning rate of the i-th period: + `learning_rate * (m_mul^i)`. Defaults to 1.0 + + Returns: + A scalar `Tensor` of the same type as `learning_rate`. + The learning rate for a provided global_step. + Raises: + ValueError: if `global_step` is not supplied. + """ + + if global_step is None: + raise ValueError("global_step is required for sgdr_decay.") + with ops.name_scope(name, "SGDRDecay", + [learning_rate, global_step, + initial_period_steps, t_mul, m_mul]) as name: + learning_rate = ops.convert_to_tensor(learning_rate, + name="initial_learning_rate") + dtype = learning_rate.dtype + global_step = math_ops.cast(global_step, dtype) + t_0 = math_ops.cast(initial_period_steps, dtype) + t_mul = math_ops.cast(t_mul, dtype) + m_mul = math_ops.cast(m_mul, dtype) + + c_one = math_ops.cast(constant_op.constant(1.0), dtype) + c_half = math_ops.cast(constant_op.constant(0.5), dtype) + c_pi = math_ops.cast(constant_op.constant(math.pi), dtype) + + # Find normalized value of the current step + x_val = math_ops.div(global_step, t_0) + + def compute_step(x_val, geometric=False): + if geometric: + # Consider geometric series where t_mul != 1 + # 1 + t_mul + t_mul^2 ... = (1 - t_mul^i_restart) / (1 - t_mul) + + # First find how many restarts were performed for a given x_val + # Find maximal integer i_restart value for which this equation holds + # x_val >= (1 - t_mul^i_restart) / (1 - t_mul) + # x_val * (1 - t_mul) <= (1 - t_mul^i_restart) + # t_mul^i_restart <= (1 - x_val * (1 - t_mul)) + + # tensorflow allows only log with base e + # i_restart <= log(1 - x_val * (1 - t_mul) / log(t_mul) + # Find how many restarts were performed + + i_restart = math_ops.floor( + math_ops.log(c_one - x_val * (c_one - t_mul)) / math_ops.log(t_mul)) + # Compute the sum of all restarts before the current one + sum_r = (c_one - t_mul ** i_restart) / (c_one - t_mul) + # Compute our position within the current restart + x_val = (x_val - sum_r) / t_mul ** i_restart + + else: + # Find how many restarts were performed + i_restart = math_ops.floor(x_val) + # Compute our position within the current restart + x_val = x_val - i_restart + return i_restart, x_val + + i_restart, x_val = control_flow_ops.cond( + math_ops.equal(t_mul, c_one), + lambda: compute_step(x_val, geometric=False), + lambda: compute_step(x_val, geometric=True)) + + # If m_mul < 1, then the initial learning rate of every new restart will be + # smaller, i.e., by a factor of m_mul ** i_restart at i_restart-th restart + m_fac = learning_rate * (m_mul ** i_restart) + + return math_ops.multiply(c_half * m_fac, + (math_ops.cos(x_val * c_pi) + c_one), name=name) diff --git a/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4a46e9a49ef203384e36698f81d6cbe3a3881ef8 --- /dev/null +++ b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py @@ -0,0 +1,145 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Functional test for sgdr learning rate decay.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +from sgdr_learning_rate_decay import sgdr_decay +from tensorflow.python.platform import googletest +from tensorflow.python.framework import test_util +from tensorflow.python.framework import dtypes +from tensorflow import placeholder + + +class SGDRDecayTest(test_util.TensorFlowTestCase): + """Unit tests for SGDR learning rate decay.""" + + def get_original_values(self, lr, t_e, mult_factor, iter_per_epoch, epochs): + """Get an array with learning rate values from the consecutive steps using + the original implementation + (https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py).""" + t0 = math.pi / 2.0 + tt = 0 + te_next = t_e + + lr_values = [] + sh_lr = lr + for epoch in range(epochs): + for _ in range(iter_per_epoch): + # In the original approach training function is executed here + lr_values.append(sh_lr) + dt = 2.0 * math.pi / float(2.0 * t_e) + tt = tt + float(dt) / iter_per_epoch + if tt >= math.pi: + tt = tt - math.pi + cur_t = t0 + tt + new_lr = lr * (1.0 + math.sin(cur_t)) / 2.0 # lr_min = 0, lr_max = lr + sh_lr = new_lr + if (epoch + 1) == te_next: # time to restart + sh_lr = lr + tt = 0 # by setting to 0 we set lr to lr_max, see above + t_e = t_e * mult_factor # change the period of restarts + te_next = te_next + t_e # note the next restart's epoch + + return lr_values + + def get_sgdr_values(self, lr, initial_period_steps, t_mul, iters): + """Get an array with learning rate values from the consecutive steps + using current tensorflow implementation.""" + with self.test_session(): + step = placeholder(dtypes.int32) + + decay = sgdr_decay(lr, step, initial_period_steps, t_mul) + lr_values = [] + for i in range(iters): + lr_values.append(decay.eval(feed_dict={step: i})) + + return lr_values + + def testCompareToOriginal(self): + """Compare values generated by tensorflow implementation to the values + generated by the original implementation + (https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py).""" + with self.test_session(): + lr = 10.0 + init_steps = 2 + t_mul = 3 + iters = 10 + epochs = 50 + + org_lr = self.get_original_values(lr, init_steps, t_mul, iters, epochs) + sgdr_lr = self.get_sgdr_values(lr, init_steps*iters, t_mul, iters*epochs) + + for org, sgdr in zip(org_lr, sgdr_lr): + self.assertAllClose(org, sgdr) + + def testMDecay(self): + """Test m_mul argument. Check values for learning rate at the beginning + of the first, second, third and fourth period. """ + with self.test_session(): + step = placeholder(dtypes.int32) + + lr = 0.1 + t_e = 10 + t_mul = 3 + m_mul = 0.9 + + decay = sgdr_decay(lr, step, t_e, t_mul, m_mul) + + test_step = 0 + self.assertAllClose(decay.eval(feed_dict={step: test_step}), + lr) + + test_step = t_e + self.assertAllClose(decay.eval(feed_dict={step: test_step}), + lr * m_mul) + + test_step = t_e + t_e*t_mul + self.assertAllClose(decay.eval(feed_dict={step: test_step}), + lr * m_mul**2) + + test_step = t_e + t_e*t_mul + t_e * (t_mul**2) + self.assertAllClose(decay.eval(feed_dict={step: test_step}), + lr * (m_mul**3)) + + def testCos(self): + """Check learning rate values at the beginning, in the middle + and at the end of the period.""" + with self.test_session(): + step = placeholder(dtypes.int32) + lr = 0.2 + t_e = 1000 + t_mul = 1 + + decay = sgdr_decay(lr, step, t_e, t_mul) + + test_step = 0 + self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr) + + test_step = t_e//2 + self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr/2) + + test_step = t_e + self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr) + + test_step = t_e*3//2 + self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr/2) + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/contrib/training/python/training/training_test.py b/tensorflow/contrib/training/python/training/training_test.py index 4877c010fad2c567d26b9674d2904274c0895f55..94cf7788b2bd3bc3fe87eefd599ce88de03042af 100644 --- a/tensorflow/contrib/training/python/training/training_test.py +++ b/tensorflow/contrib/training/python/training/training_test.py @@ -36,6 +36,7 @@ from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import checkpoint_management from tensorflow.python.training import gradient_descent from tensorflow.python.training import monitored_session from tensorflow.python.training import saver as saver_lib @@ -421,7 +422,7 @@ class TrainTest(test.TestCase): train_op = self.create_train_op() model_variables = variables_lib2.global_variables() - model_path = saver_lib.latest_checkpoint(logdir1) + model_path = checkpoint_management.latest_checkpoint(logdir1) assign_fn = variables_lib.assign_from_checkpoint_fn( model_path, model_variables) diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc index 86350a08e57e5050f18d019fe80d70f6381c1f7d..f7c979e86320d59ad033e2b8d7fcdff89ce0d133 100644 --- a/tensorflow/contrib/verbs/rdma.cc +++ b/tensorflow/contrib/verbs/rdma.cc @@ -24,8 +24,8 @@ limitations under the License. #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/process_util.h" #if GOOGLE_CUDA +#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" #include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#include "tensorflow/core/common_runtime/gpu/process_state.h" #endif #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" @@ -1084,7 +1084,7 @@ void RdmaTensorResponse::RecvHandler(Rendezvous::ParsedKey parsed, // The tensor must be copied from GPU to CPU, because either: // 1. The tensor is located on a non GDR compatible GPU. // 2. The tensor's meta-data has changed. - Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); + Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0); copy = Tensor(alloc, in.dtype(), in.shape()); CountCopies(rm_.name_, (void*)DMAHelper::base(&in), (void*)DMAHelper::base(©), in.TotalBytes(), true); @@ -1541,7 +1541,7 @@ bool RdmaTensorRequest::AllocateTensors() { if (mr_ == nullptr) { // Can't RDMA directly to result. Use a proxy. proxy_tensor_ = - new Tensor(ProcessState::singleton()->GetCUDAHostAllocator(0), + new Tensor(GPUProcessState::singleton()->GetCUDAHostAllocator(0), result_tensor_->dtype(), result_tensor_->shape()); rdma_addr_ = DMAHelper::base(proxy_tensor_); mr_ = diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc index 369bd986df5313955bc22d6e5c6d38815908ada3..3cb5e61facf860f2740935f66bf548096296280f 100644 --- a/tensorflow/contrib/verbs/rdma_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_mgr.cc @@ -21,8 +21,10 @@ limitations under the License. #include "tensorflow/contrib/verbs/grpc_verbs_client.h" #include "tensorflow/contrib/verbs/verbs_service.pb.h" #include "tensorflow/core/common_runtime/bfc_allocator.h" +#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" #include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#include "tensorflow/core/common_runtime/gpu/process_state.h" +#include "tensorflow/core/common_runtime/pool_allocator.h" +#include "tensorflow/core/common_runtime/process_state.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" #include "tensorflow/core/framework/allocator_registry.h" @@ -254,37 +256,34 @@ void MRDeleter(ibv_mr* mr) { } } -// TODO(byronyi): remove this class duplicated from the one in -// common/runtime/gpu/pool_allocator.h when it is available in common_runtime -class BasicCPUAllocator : public SubAllocator { - public: - ~BasicCPUAllocator() override {} - - void* Alloc(size_t alignment, size_t num_bytes) override { - return port::AlignedMalloc(num_bytes, alignment); - } - void Free(void* ptr, size_t) override { port::AlignedFree(ptr); } -}; - // TODO(byronyi): remove this class and its registration when the default -// cpu_allocator() returns visitable allocator +// cpu_allocator() returns visitable allocator, or cpu_allocator() is no +// longer in use. class BFCRdmaAllocator : public BFCAllocator { public: BFCRdmaAllocator() - : BFCAllocator(new BasicCPUAllocator(), 1LL << 36, true, "cpu_rdma_bfc") { + : BFCAllocator(new BasicCPUAllocator(port::kNUMANoAffinity), 1LL << 36, + true, "cpu_rdma_bfc") {} +}; +class BFCRdmaAllocatorFactory : public AllocatorFactory { + public: + Allocator* CreateAllocator() { return new BFCRdmaAllocator; } + + SubAllocator* CreateSubAllocator(int numa_node) { + return new BasicCPUAllocator(numa_node); } }; -REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocator); +REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocatorFactory); void RdmaMgr::InitAllocators() { RdmaMemoryMgr::Singleton().pd_ = rdma_adapter_->pd_; Allocator* allocators[] = { #if GOOGLE_CUDA - ProcessState::singleton()->GetCUDAHostAllocator(0), - ProcessState::singleton()->GetCPUAllocator(0), + GPUProcessState::singleton()->GetCUDAHostAllocator(0), #endif // GOOGLE_CUDA + ProcessState::singleton()->GetCPUAllocator(0), cpu_allocator(), }; @@ -323,7 +322,8 @@ void RdmaMgr::InitAllocators() { std::bind(&RdmaMemoryMgr::InsertMemoryRegion, &RdmaMemoryMgr::Singleton(), _1, _2, std::string(buf)); - ProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor); + GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id, + cuda_alloc_visitor); LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id; } #endif // GOOGLE_CUDA diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index c1efc9c0c6f3f24f37f7166382f7348fd51b61bb..1423c7fbcb227c3a00d74fb62c8e2b547d93c41c 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -150,7 +150,6 @@ load( "//third_party/mkl:build_defs.bzl", "if_mkl", ) -load("@io_bazel_rules_closure//closure:defs.bzl", "closure_proto_library") exports_files(["ops/ops.pbtxt"]) @@ -234,7 +233,6 @@ tf_proto_library( srcs = [], cc_api_version = 2, default_header = True, - j2objc_api_version = 1, java_api_version = 2, js_api_version = 2, protodeps = [ @@ -335,6 +333,7 @@ filegroup( "platform/init_main.h", "platform/mem.h", "platform/mutex.h", + "platform/numa.h", "platform/thread_annotations.h", ], visibility = ["//visibility:private"], @@ -663,6 +662,7 @@ cc_library( "lib/random/random_distributions.h", "lib/random/simple_philox.h", "lib/strings/numbers.h", + "lib/strings/proto_serialization.h", "lib/strings/str_util.h", "lib/strings/strcat.h", "lib/strings/stringprintf.h", @@ -847,6 +847,7 @@ tf_cuda_library( "util/sparse/sparse_tensor.h", "util/stat_summarizer.h", "util/stat_summarizer_options.h", + "util/status_util.h", "util/stream_executor_util.h", "util/strided_slice_op.h", "util/tensor_format.h", @@ -883,6 +884,16 @@ cc_library( copts = tf_copts(), ) +tf_cc_test( + name = "stats_calculator_test", + srcs = ["util/stats_calculator_test.cc"], + deps = [ + ":stats_calculator_portable", + ":test", + ":test_main", + ], +) + cc_library( name = "overflow", hdrs = ["util/overflow.h"], @@ -1263,6 +1274,7 @@ cc_library( "//tensorflow/core/kernels:fake_quant_ops", "//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:functional_ops", + "//tensorflow/core/kernels:grappler", "//tensorflow/core/kernels:histogram_op", "//tensorflow/core/kernels:image", "//tensorflow/core/kernels:io", @@ -1644,6 +1656,7 @@ cc_library( copts = tf_copts(android_optimization_level_override = None) + [ "-DSUPPORT_SELECTIVE_REGISTRATION", ], + linkopts = if_android(["-lz"]), tags = [ "manual", "notap", @@ -1667,6 +1680,7 @@ cc_library( copts = tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_android() + [ "-DSUPPORT_SELECTIVE_REGISTRATION", ], + linkopts = if_android(["-lz"]), tags = [ "manual", "notap", @@ -1952,8 +1966,10 @@ LIB_INTERNAL_PRIVATE_HEADERS = ["framework/resource_handle.h"] + glob( "**/*test*", "lib/gif/**/*", "lib/jpeg/**/*", + "lib/png/**/*", "platform/gif.h", "platform/jpeg.h", + "platform/png.h", "platform/**/cuda.h", "platform/**/stream_executor.h", ], @@ -2048,6 +2064,7 @@ cc_library( "lib/hash/crc32c_accelerate.cc", "lib/gif/**/*", "lib/jpeg/**/*", + "lib/png/**/*", "platform/**/env_time.cc", "platform/**/cuda_libdevice_path.cc", "platform/**/device_tracer.cc", @@ -2143,6 +2160,39 @@ cc_library( ], ) +cc_library( + name = "png_internal", + srcs = ["lib/png/png_io.cc"], + hdrs = [ + "lib/bfloat16/bfloat16.h", + "lib/core/casts.h", + "lib/core/stringpiece.h", + "lib/png/png_io.h", + "platform/byte_order.h", + "platform/cpu_info.h", + "platform/default/integral_types.h", + "platform/default/logging.h", + "platform/logging.h", + "platform/macros.h", + "platform/platform.h", + "platform/png.h", + "platform/types.h", + ], + copts = tf_copts(), + linkopts = select({ + "//tensorflow:freebsd": [], + "//tensorflow:windows": [], + "//tensorflow:windows_msvc": [], + "//conditions:default": ["-ldl"], + }), + deps = [ + ":lib", + ":lib_internal", + "//tensorflow/core/platform/default/build_config:png", + "@zlib_archive//:zlib", + ], +) + cc_library( name = "tflite_portable_logging", srcs = [], @@ -2188,6 +2238,7 @@ cc_library( linkopts = ["-ldl"], deps = [ "//tensorflow/core/platform/default/build_config:jpeg", + "//tensorflow/core/platform/default/build_config:logging", ], ) @@ -2216,6 +2267,7 @@ cc_library( linkopts = ["-ldl"], deps = [ "//tensorflow/core/platform/default/build_config:gif", + "//tensorflow/core/platform/default/build_config:logging", ], ) @@ -2242,6 +2294,7 @@ cc_library( copts = tf_copts(), linkopts = ["-ldl"], deps = [ + "//tensorflow/core/platform/default/build_config:logging", "@png_archive//:png", ], ) @@ -2251,7 +2304,6 @@ tf_proto_library( srcs = ERROR_CODES_PROTO_SRCS, cc_api_version = 2, default_header = True, - j2objc_api_version = 1, java_api_version = 2, js_api_version = 2, provide_cc_alias = True, @@ -2273,7 +2325,6 @@ tf_proto_library( srcs = COMMON_PROTO_SRCS + ADDITIONAL_CORE_PROTO_SRCS, cc_api_version = 2, default_header = True, - j2objc_api_version = 1, java_api_version = 2, js_api_version = 2, protodeps = [ @@ -2430,6 +2481,7 @@ tf_cuda_library( "framework/resource_handle.cc", "util/memmapped_file_system.*", "util/memmapped_file_system_writer.*", + "util/stats_calculator.*", "util/version_info.cc", ], ) + select({ @@ -2456,6 +2508,7 @@ tf_cuda_library( ":protos_all_proto_text", ":error_codes_proto_text", ":protos_all_cc", + ":stats_calculator_portable", ":version_lib", "//tensorflow/core/platform/default/build_config:platformlib", "//tensorflow/core/kernels:bounds_check", @@ -2671,6 +2724,8 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/step_stats_collector.h", "common_runtime/threadpool_device.h", "common_runtime/visitable_allocator.h", + "common_runtime/process_state.h", + "common_runtime/pool_allocator.h", "graph/gradients.h", "graph/quantize_training.h", ] + if_mkl(["graph/mkl_graph_util.h"]) @@ -2709,7 +2764,9 @@ tf_cuda_library( "common_runtime/optimization_registry.cc", "common_runtime/parallel_concat_optimizer.cc", "common_runtime/placer.cc", + "common_runtime/pool_allocator.cc", "common_runtime/process_function_library_runtime.cc", + "common_runtime/process_state.cc", "common_runtime/process_util.cc", "common_runtime/renamed_device.cc", "common_runtime/rendezvous_mgr.cc", @@ -2870,6 +2927,14 @@ tf_cuda_library( ] + tf_additional_device_tracer_deps(), ) +cc_library( + name = "session_ref", + srcs = ["common_runtime/session_ref.cc"], + hdrs = ["common_runtime/session_ref.h"], + copts = tf_copts(), + deps = [":core_cpu_base"], +) + cc_library( name = "gpu_id", hdrs = [ @@ -2896,6 +2961,7 @@ cc_library( ) GPU_RUNTIME_HEADERS = [ + "common_runtime/gpu/cuda_host_allocator.h", "common_runtime/gpu/gpu_bfc_allocator.h", "common_runtime/gpu/gpu_cudamalloc_allocator.h", "common_runtime/gpu/gpu_debug_allocator.h", @@ -2905,10 +2971,9 @@ GPU_RUNTIME_HEADERS = [ "common_runtime/gpu/gpu_id_utils.h", "common_runtime/gpu/gpu_init.h", "common_runtime/gpu/gpu_managed_allocator.h", + "common_runtime/gpu/gpu_process_state.h", "common_runtime/gpu/gpu_stream_util.h", "common_runtime/gpu/gpu_util.h", - "common_runtime/gpu/pool_allocator.h", - "common_runtime/gpu/process_state.h", "common_runtime/gpu_device_context.h", ] @@ -2921,11 +2986,10 @@ tf_cuda_library( "common_runtime/gpu/gpu_device.cc", "common_runtime/gpu/gpu_device_factory.cc", "common_runtime/gpu/gpu_managed_allocator.cc", + "common_runtime/gpu/gpu_process_state.cc", "common_runtime/gpu/gpu_stream_util.cc", "common_runtime/gpu/gpu_util.cc", "common_runtime/gpu/gpu_util_platform_specific.cc", - "common_runtime/gpu/pool_allocator.cc", - "common_runtime/gpu/process_state.cc", ], hdrs = GPU_RUNTIME_HEADERS, copts = tf_copts(), @@ -3172,6 +3236,7 @@ tf_cc_tests( "platform/fingerprint_test.cc", "platform/integral_types_test.cc", "platform/logging_test.cc", + "platform/mutex_test.cc", "platform/net_test.cc", "platform/port_test.cc", "platform/profile_utils/cpu_utils_test.cc", @@ -3186,6 +3251,7 @@ tf_cc_tests( ":test", ":test_main", "//third_party/eigen3", + "@zlib_archive//:zlib", ], ) @@ -3235,6 +3301,28 @@ tf_cc_test( ], ) +tf_cc_test( + name = "platform_numa_test", + size = "small", + srcs = ["platform/numa_test.cc"], + tags = [ + # This test will not pass unless it has access to all NUMA nodes + # on the executing machine. + "manual", + "notap", + ], + deps = [ + ":framework", + ":lib", + ":lib_internal", + ":lib_test_internal", + ":protos_all_cc", + ":test", + ":test_main", + "//third_party/eigen3", + ], +) + tf_cc_test( name = "platform_setround_test", size = "small", @@ -3406,6 +3494,7 @@ tf_cc_tests( "framework/tensor_shape_test.cc", "framework/tensor_slice_test.cc", "framework/tensor_test.cc", + "framework/tensor_testutil_test.cc", "framework/tensor_util_test.cc", "framework/tracking_allocator_test.cc", "framework/types_test.cc", @@ -3599,6 +3688,7 @@ tf_cc_test_mkl( deps = [ ":core", ":core_cpu", + ":core_cpu_internal", ":framework", ":framework_internal", ":test", @@ -3672,7 +3762,6 @@ tf_cc_tests_gpu( "common_runtime/gpu/gpu_bfc_allocator_test.cc", "common_runtime/gpu/gpu_device_test.cc", "common_runtime/gpu/gpu_id_manager_test.cc", - "common_runtime/gpu/gpu_event_mgr_test.cc", "common_runtime/gpu/pool_allocator_test.cc", ], linkstatic = tf_kernel_tests_linkstatic(), @@ -3696,6 +3785,23 @@ tf_cc_tests_gpu( ], ) +tf_cc_test_gpu( + name = "gpu_event_mgr_test", + srcs = ["common_runtime/gpu/gpu_event_mgr_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + deps = [ + ":framework", + ":framework_internal", + ":lib", + ":lib_internal", + ":protos_all_cc", + ":test", + ":test_main", + ":testlib", + ], +) + tf_cuda_cc_test( name = "gpu_device_unified_memory_test", size = "small", diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index 477a0b670e49f8aa4ee8c250d4957886eb865ed5..ae03a61ae66ec8d0119d91eefe8c64e61348e9b4 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -149,6 +149,33 @@ void TestAllApiDefAttributeNamesAreValid( } } } + +void TestDeprecatedAttributesSetCorrectly( + const std::unordered_map& api_defs_map) { + for (const auto& name_and_api_def : api_defs_map) { + int num_deprecated_endpoints = 0; + const auto& api_def = name_and_api_def.second; + for (const auto& endpoint : api_def.endpoint()) { + if (endpoint.deprecated()) { + ++num_deprecated_endpoints; + } + } + + const auto& name = name_and_api_def.first; + ASSERT_TRUE(api_def.deprecation_message().empty() || + num_deprecated_endpoints == 0) + << "Endpoints are set to 'deprecated' for deprecated op " << name + << ". If an op is deprecated (i.e. deprecation_message is set), " + << "all the endpoints are deprecated implicitly and 'deprecated' " + << "field should not be set."; + if (num_deprecated_endpoints > 0) { + ASSERT_NE(num_deprecated_endpoints, api_def.endpoint_size()) + << "All " << name << " endpoints are deprecated. Please, set " + << "deprecation_message in api_def_" << name << ".pbtxt instead. " + << "to indicate that the op is deprecated."; + } + } +} } // namespace class BaseApiTest : public ::testing::Test { @@ -171,7 +198,7 @@ TEST_F(BaseApiTest, AllOpsAreInApiDef) { if (excluded_ops->find(op.name()) != excluded_ops->end()) { continue; } - ASSERT_TRUE(api_defs_map_.find(op.name()) != api_defs_map_.end()) + EXPECT_TRUE(api_defs_map_.find(op.name()) != api_defs_map_.end()) << op.name() << " op does not have api_def_*.pbtxt file. " << "Please add api_def_" << op.name() << ".pbtxt file " << "under tensorflow/core/api_def/base_api/ directory."; @@ -236,6 +263,11 @@ TEST_F(BaseApiTest, AllApiDefAttributeNamesAreValid) { TestAllApiDefAttributeNamesAreValid(ops_, api_defs_map_); } +// Checks that deprecation is set correctly. +TEST_F(BaseApiTest, DeprecationSetCorrectly) { + TestDeprecatedAttributesSetCorrectly(api_defs_map_); +} + class PythonApiTest : public ::testing::Test { protected: PythonApiTest() { @@ -272,4 +304,9 @@ TEST_F(PythonApiTest, AllApiDefAttributeNamesAreValid) { TestAllApiDefAttributeNamesAreValid(ops_, api_defs_map_); } +// Checks that deprecation is set correctly. +TEST_F(PythonApiTest, DeprecationSetCorrectly) { + TestDeprecatedAttributesSetCorrectly(api_defs_map_); +} + } // namespace tensorflow diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCenterBias.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCenterBias.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b58b974eb4e43b49d6630449de1a0a6c37a15859 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCenterBias.pbtxt @@ -0,0 +1,41 @@ +op { + graph_op_name: "BoostedTreesCenterBias" + visibility: HIDDEN + in_arg { + name: "tree_ensemble_handle" + description: <